[TOSA] Change PadOp padding to tosa.shape #123133

Merged · 1 commit · Jan 22, 2025
10 changes: 5 additions & 5 deletions mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -1557,21 +1557,21 @@ def Tosa_PadOp : Tosa_InferShapedTypeOp<"pad"> {
     Example:

     ```mlir
-    %0 = arith.constant dense<[1, 2, 3, 4]> : tensor<4xi32>
-    tosa.pad %arg0, %0 : (tensor<1x2xf32>, tensor<4xi32>) -> (tensor<4x9xf32>)
+    %0 = tosa.const_shape { value = dense<[1, 2, 3, 4]> : tensor<4xindex> } : () -> !tosa.shape<4>
+    tosa.pad %arg0, %0 : (tensor<1x2xf32>, !tosa.shape<4>) -> (tensor<4x9xf32>)
     ```

     Example 2:

     ```mlir
-    %0 = arith.constant dense<[-1, 2, 3, 4]> : tensor<4xi32>
-    tosa.pad %arg0, %0 : (tensor<1x2xf32>, tensor<4xi32>) -> (tensor<?x9xf32>)
+    %0 = tosa.const_shape { value = dense<[-1, 2, 3, 4]> : tensor<4xindex> } : () -> !tosa.shape<4>
+    tosa.pad %arg0, %0 : (tensor<1x2xf32>, !tosa.shape<4>) -> (tensor<?x9xf32>)
     ```
   }];

   let arguments = (ins
     Tosa_RankedTensor:$input1,
-    TosaTensorRankOf<[Tosa_Int32Or64], [1]>:$padding,
+    Tosa_Shape:$padding,
     Optional<Tosa_ScalarTensor>:$pad_const,
     OptionalAttr<Tosa_PadOpQuantizationAttr>:$quantization_info
   );
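The documented examples omit the optional $pad_const operand; a minimal sketch of how it would look after this change, assuming pad_const remains a rank-0 tensor (the constant value and the generic tosa.const form here are illustrative, not taken from the patch):

```mlir
// Hypothetical: pad a 1x2 input with an explicit pad value of 3.14.
%0 = tosa.const_shape { value = dense<[1, 2, 3, 4]> : tensor<4xindex> } : () -> !tosa.shape<4>
%1 = "tosa.const"() { value = dense<3.14> : tensor<f32> } : () -> tensor<f32>
tosa.pad %arg0, %0, %1 : (tensor<1x2xf32>, !tosa.shape<4>, tensor<f32>) -> (tensor<4x9xf32>)
```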
8 changes: 8 additions & 0 deletions mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h
@@ -229,6 +229,14 @@ SmallVector<T> applyTOSAPermutation(ArrayRef<T> input,
   return permuted;
 }

+// Computes shape value using tosa const_shape op.
+Value getTosaConstShape(PatternRewriter &rewriter, Location loc,
+                        llvm::ArrayRef<int64_t> shape);
+SmallVector<int64_t> convertFromMlirShape(ArrayRef<int64_t> shape);
+
+bool getConstShapeValue(Operation *op,
+                        llvm::SmallVector<int64_t> &result_shape);
+
 } // namespace tosa
 } // namespace mlir
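Both declarations revolve around the tosa.const_shape op: getTosaConstShape materializes one from an int64 array, and getConstShapeValue reads the values back out of its defining op. A minimal sketch of the IR these helpers produce and match (the padding values here are illustrative):

```mlir
// Built by getTosaConstShape(rewriter, loc, {0, 0, 1, 1, 2, 2, 0, 0}):
%padding = tosa.const_shape { value = dense<[0, 0, 1, 1, 2, 2, 0, 0]> : tensor<8xindex> } : () -> !tosa.shape<8>
// getConstShapeValue(%padding's defining op, resultShape) then fills
// resultShape with {0, 0, 1, 1, 2, 2, 0, 0} and returns true.
```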
27 changes: 14 additions & 13 deletions mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
@@ -306,7 +306,16 @@ class PadConverter : public OpConversionPattern<tosa::PadOp> {
                                 ConversionPatternRewriter &rewriter) const final {
     auto loc = padOp.getLoc();
     auto input = padOp.getInput1();
-    auto padding = padOp.getPadding();
+
+    ElementsAttr paddingElems;
+    if (!matchPattern(padOp.getPadding(), m_Constant(&paddingElems))) {
[Review thread on the matchPattern call]

Contributor: A bit surprised this works, maybe I'm missing something. Should we extract the padding in a similar way as below?

    if (!tosa::ExtractConstShapeValue(adaptor.getPadding().getDefiningOp(),
                                      paddingValues))

From: https://github.com/llvm/llvm-project/pull/123133/files#diff-90956ba24a2a97cc56a9a3659c7e46e56f1bd791a869246c6a758f9c93f1434fR841

Tai78641 (Contributor, Jan 21, 2025): The shape constant value is intended to work properly with matchPattern.

Contributor: Thanks, in that case it's a non-blocking comment, just a suggestion for consistency.
+      return rewriter.notifyMatchFailure(
+          padOp, "padding must be a static shape value");
+    }
+    llvm::SmallVector<int64_t> paddingVals;
+    for (auto idx : paddingElems.getValues<IntegerAttr>()) {
+      paddingVals.push_back(static_cast<int64_t>(idx.getInt()));
+    }

     ShapedType inputTy = cast<ShapedType>(input.getType());
     Type elementTy = inputTy.getElementType();
@@ -345,18 +354,10 @@ class PadConverter : public OpConversionPattern<tosa::PadOp> {
     highValues.reserve(rank);

     for (int i = 0; i < rank; i++) {
-      Value lowIndex = rewriter.create<arith::ConstantIndexOp>(loc, 2 * i);
-      Value highIndex = rewriter.create<arith::ConstantIndexOp>(loc, 2 * i + 1);
-      Value lowVal = rewriter.createOrFold<tensor::ExtractOp>(
-          loc, padding, ValueRange({lowIndex}));
-      Value highVal = rewriter.createOrFold<tensor::ExtractOp>(
-          loc, padding, ValueRange({highIndex}));
-
-      lowVal = rewriter.createOrFold<arith::IndexCastOp>(
-          loc, rewriter.getIndexType(), lowVal);
-      highVal = rewriter.createOrFold<arith::IndexCastOp>(
-          loc, rewriter.getIndexType(), highVal);
-
+      Value lowVal = rewriter.create<arith::ConstantOp>(
+          loc, rewriter.getIndexAttr(paddingVals[2 * i]));
+      Value highVal = rewriter.create<arith::ConstantOp>(
+          loc, rewriter.getIndexAttr(paddingVals[2 * i + 1]));
       lowValues.push_back(lowVal);
       highValues.push_back(highVal);
     }
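Because tosa.const_shape folds like any other constant (which, going by the reply in the review thread above, is why matchPattern/m_Constant can extract its value), the converter can now bake the padding amounts into plain index constants. A hedged sketch of the tensor-dialect IR this pattern would emit for the documentation example, assuming a zero pad constant:

```mlir
%cst = arith.constant 0.000000e+00 : f32
// Padding [1, 2, 3, 4] is laid out as [dim0_low, dim0_high, dim1_low, dim1_high].
%c1 = arith.constant 1 : index
%c2 = arith.constant 2 : index
%c3 = arith.constant 3 : index
%c4 = arith.constant 4 : index
%padded = tensor.pad %arg0 low[%c1, %c3] high[%c2, %c4] {
^bb0(%i: index, %j: index):
  tensor.yield %cst : f32
} : tensor<1x2xf32> to tensor<4x9xf32>
```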
50 changes: 20 additions & 30 deletions mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -36,6 +36,7 @@ using namespace mlir;
 using namespace mlir::tosa;

 #include "mlir/Dialect/Tosa/IR/TosaOpsDialect.cpp.inc"
+#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"

 //===----------------------------------------------------------------------===//
 // Tosa dialect interface includes.
@@ -823,51 +824,42 @@ LogicalResult tosa::PadOp::inferReturnTypeComponents(
     PadOp::Adaptor adaptor,
     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
   ShapeAdaptor inputShape(adaptor.getInput1().getType());
-  ShapeAdaptor paddingShape(adaptor.getPadding().getType());
+  auto paddingRank =
+      cast<tosa::shapeType>(adaptor.getPadding().getType()).getRank();
   SmallVector<int64_t> outputShape;

-  // If both inputs have unknown shape, we cannot determine the shape of the
-  // output.
-  if (!inputShape.hasRank() && !paddingShape.hasRank()) {
-    inferredReturnShapes.push_back(ShapedTypeComponents());
-    return success();
-  }
-
-  // If the input rank is unknown we can info the output rank using the
-  // padding shape's first dim.
+  // If the input rank is unknown, we can infer the output rank using the
+  // padding shape's rank divided by 2.
   if (!inputShape.hasRank()) {
-    if (paddingShape.isDynamicDim(0)) {
-      inferredReturnShapes.push_back(ShapedTypeComponents());
-      return success();
-    }
-
-    outputShape.resize(paddingShape.getDimSize(0) / 2, ShapedType::kDynamic);
+    outputShape.resize(paddingRank / 2, ShapedType::kDynamic);
     inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
     return success();
   }

-  DenseIntElementsAttr paddings;
+  SmallVector<int64_t> paddingValues;
   // If the paddings value is not a constant, all dimensions must be dynamic.
-  if (!matchPattern(adaptor.getPadding(), m_Constant(&paddings))) {
+  if (!tosa::getConstShapeValue(adaptor.getPadding().getDefiningOp(),
+                                paddingValues)) {
     outputShape.resize(inputShape.getRank(), ShapedType::kDynamic);
     inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
     return success();
   }

-  SmallVector<int64_t> paddingValues;
-  for (auto val : paddings) {
-    paddingValues.push_back(val.getSExtValue());
-  }
-
   outputShape.reserve(inputShape.getRank());
   for (int i = 0, s = inputShape.getRank(); i < s; i++) {
     if (inputShape.isDynamicDim(i)) {
       outputShape.push_back(ShapedType::kDynamic);
       continue;
     }
+    auto padFront = paddingValues[i * 2];
+    auto padBack = paddingValues[i * 2 + 1];
+    if (padFront < 0 || padBack < 0) {
+      // if either padding for dim i is -1, output dim is unknown
+      outputShape.push_back(ShapedType::kDynamic);
+      continue;
+    }

-    outputShape.push_back(inputShape.getDimSize(i) + paddingValues[i * 2] +
-                          paddingValues[i * 2 + 1]);
+    outputShape.push_back(inputShape.getDimSize(i) + padFront + padBack);
   }

   inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
@@ -877,17 +869,15 @@ LogicalResult tosa::PadOp::inferReturnTypeComponents(
 LogicalResult tosa::PadOp::verify() {
   RankedTensorType inputType = getInput1().getType();
   RankedTensorType outputType = getOutput().getType();
-  RankedTensorType paddingType = getPadding().getType();
+  auto paddingRank = cast<tosa::shapeType>(getPadding().getType()).getRank();

   if (inputType.getRank() != outputType.getRank())
     return emitOpError() << "expect same input and output tensor rank.";

-  if (!paddingType.isDynamicDim(0) &&
-      paddingType.getDimSize(0) != inputType.getRank() * 2)
+  if (paddingRank != inputType.getRank() * 2)
     return emitOpError() << "expected padding tensor dim 0 to have size "
                          << inputType.getRank() * 2
-                         << " (2*rank(shape1)) but got size "
-                         << paddingType.getDimSize(0);
+                         << " (2*rank(shape1)) but got size " << paddingRank;

   return success();
 }
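To make the new rank rule concrete, here is hand-written IR against the verifier code above (a sketch, not taken from the patch's tests): a rank-2 input requires a !tosa.shape<4> padding operand, and any other rank is rejected with the diagnostic shown.

```mlir
// OK: 2 * rank(input) == 4.
%ok = tosa.const_shape { value = dense<[1, 2, 3, 4]> : tensor<4xindex> } : () -> !tosa.shape<4>
%0 = tosa.pad %arg0, %ok : (tensor<1x2xf32>, !tosa.shape<4>) -> (tensor<4x9xf32>)

// error: expected padding tensor dim 0 to have size 4 (2*rank(shape1)) but got size 6
%bad = tosa.const_shape { value = dense<[1, 2, 3, 4, 5, 6]> : tensor<6xindex> } : () -> !tosa.shape<6>
%1 = tosa.pad %arg0, %bad : (tensor<1x2xf32>, !tosa.shape<6>) -> (tensor<4x9xf32>)
```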
6 changes: 1 addition & 5 deletions mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeConv2D.cpp
@@ -81,11 +81,7 @@ struct Conv2DIsFullyConnected : public OpRewritePattern<tosa::Conv2DOp> {
       }
     }

-    auto padSizeTy = RankedTensorType::get({8}, rewriter.getI64Type());
-    auto padSize =
-        DenseIntElementsAttr::get(padSizeTy, ArrayRef<int64_t>(pad));
-    Value padSizeVal =
-        rewriter.create<tosa::ConstOp>(op->getLoc(), padSizeTy, padSize);
+    Value padSizeVal = getTosaConstShape(rewriter, op->getLoc(), pad);

     auto padTy = RankedTensorType::get({}, inputETy);
     auto padAttr = DenseElementsAttr::get(padTy, zeroAttr);
6 changes: 1 addition & 5 deletions mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp
@@ -108,11 +108,7 @@ struct DepthwiseConv2DIsMul : public OpRewritePattern<tosa::DepthwiseConv2DOp> {
       }
     }

-    auto padSizeTy = RankedTensorType::get({10}, rewriter.getI64Type());
-    auto padSize =
-        DenseIntElementsAttr::get(padSizeTy, ArrayRef<int64_t>(pad));
-    Value padSizeVal =
-        rewriter.create<tosa::ConstOp>(op->getLoc(), padSizeTy, padSize);
+    Value padSizeVal = getTosaConstShape(rewriter, op->getLoc(), pad);

     auto padTy = RankedTensorType::get({}, inputETy);
     auto padAttr = DenseElementsAttr::get(padTy, zeroAttr);
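Both decomposition patterns make the same mechanical swap; a hedged before/after sketch of the padding constant for the Conv2D case (the pad values here are illustrative):

```mlir
// Before: padding carried as a rank-1 i64 tensor constant.
%pad_old = "tosa.const"() { value = dense<[0, 0, 1, 1, 1, 1, 0, 0]> : tensor<8xi64> } : () -> tensor<8xi64>
// After: padding carried as a TOSA shape value.
%pad_new = tosa.const_shape { value = dense<[0, 0, 1, 1, 1, 1, 0, 0]> : tensor<8xindex> } : () -> !tosa.shape<8>
```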
29 changes: 11 additions & 18 deletions mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp
@@ -135,15 +135,14 @@ class TransposeConvStridedConverter
     int64_t inputChannels = weightTy.getDimSize(3);

     // Pad the weight so that it is modulo of the striding.
-    llvm::SmallVector<int32_t, 8> weightPadding = {0, 0, 0, 0, 0, 0, 0, 0};
+    llvm::SmallVector<int64_t, 8> weightPadding = {0, 0, 0, 0, 0, 0, 0, 0};
     weightPadding[3] =
         (weightHeight % stride[0]) ? (stride[0] - weightHeight % stride[0]) : 0;
     weightPadding[5] =
-        (weightWidth % stride[1]) ? (stride[1] - weightWidth % stride[1]) : 0;
-    DenseElementsAttr weightPaddingAttr = DenseIntElementsAttr::get(
-        RankedTensorType::get({8}, rewriter.getI32Type()), weightPadding);
-    Value weightPaddingVal = CreateOpAndInferShape<tosa::ConstOp>(
-        rewriter, loc, weightPaddingAttr.getType(), weightPaddingAttr);
+        weightWidth % stride[1] ? stride[1] - weightWidth % stride[1] : 0;
+
+    Value weightPaddingVal =
+        getTosaConstShape(rewriter, op->getLoc(), weightPadding);

     if (op.getQuantizationInfo().has_value()) {
       auto quantInfo = op.getQuantizationInfo().value();
@@ -197,17 +196,14 @@ class TransposeConvStridedConverter
         /* axis = */ rewriter.getI32IntegerAttr(2));

     // We need to pad the input far enough that we can pull all values.
-    llvm::SmallVector<int32_t, 8> inputPadding = {0, 0, 0, 0, 0, 0, 0, 0};
+    llvm::SmallVector<int64_t, 8> inputPadding = {0, 0, 0, 0, 0, 0, 0, 0};
     inputPadding[2] += restridedWeightTy.getDimSize(1) - 1;
     inputPadding[3] += restridedWeightTy.getDimSize(1) - 1;
     inputPadding[4] += restridedWeightTy.getDimSize(2) - 1;
     inputPadding[5] += restridedWeightTy.getDimSize(2) - 1;

-    DenseElementsAttr inputPaddingAttr = DenseIntElementsAttr::get(
-        RankedTensorType::get({8}, rewriter.getI32Type()), inputPadding);
-
-    Value inputPaddingVal = CreateOpAndInferShape<tosa::ConstOp>(
-        rewriter, loc, inputPaddingAttr.getType(), inputPaddingAttr);
+    Value inputPaddingVal =
+        getTosaConstShape(rewriter, op->getLoc(), inputPadding);

     if (op.getQuantizationInfo().has_value()) {
       auto quantInfo = op.getQuantizationInfo().value();
@@ -310,17 +306,14 @@ class TransposeConvStridedConverter
                              rewriter.getDenseI64ArrayAttr(sliceSize))
                  .getResult();

-    llvm::SmallVector<int32_t, 8> resultPadding = {0, 0, 0, 0, 0, 0, 0, 0};
+    llvm::SmallVector<int64_t, 8> resultPadding = {0, 0, 0, 0, 0, 0, 0, 0};
     resultPadding[2] = resultPadTop;
     resultPadding[3] = resultTy.getDimSize(1) - resultPadTop - sliceSize[1];
     resultPadding[4] = resultPadLeft;
     resultPadding[5] = resultTy.getDimSize(2) - resultPadLeft - sliceSize[2];

-    DenseElementsAttr resultPaddingAttr = DenseIntElementsAttr::get(
-        RankedTensorType::get({8}, rewriter.getI32Type()), resultPadding);
-
-    Value resultPaddingVal = CreateOpAndInferShape<tosa::ConstOp>(
-        rewriter, loc, resultPaddingAttr.getType(), resultPaddingAttr);
+    Value resultPaddingVal =
+        getTosaConstShape(rewriter, op->getLoc(), resultPadding);

     Value resultPad = CreateOpAndInferShape<tosa::PadOp>(
         rewriter, loc, UnrankedTensorType::get(resultETy), slice,
33 changes: 33 additions & 0 deletions mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp
@@ -160,3 +160,36 @@ LogicalResult mlir::tosa::EqualizeRanks(ImplicitLocOpBuilder &builder,

   return success();
 }
+
+Value mlir::tosa::getTosaConstShape(PatternRewriter &rewriter, Location loc,
+                                    llvm::ArrayRef<int64_t> shape) {
+  auto attr = rewriter.getIndexTensorAttr(shape);
+  auto type = mlir::tosa::shapeType::get(rewriter.getContext(), shape.size());
+  mlir::Operation *mlir_op =
+      rewriter.create<tosa::ConstShapeOp>(loc, type, attr);
+  return mlir_op->getResult(0);
+}
+
+SmallVector<int64_t> mlir::tosa::convertFromMlirShape(ArrayRef<int64_t> shape) {
+  return to_vector(llvm::map_range(shape, [](int64_t dim) {
+    return ShapedType::isDynamic(dim) ? -1 : dim;
+  }));
+}
+
+bool mlir::tosa::getConstShapeValue(Operation *op,
+                                    llvm::SmallVector<int64_t> &result_shape) {
+  if (!op) {
+    return false;
+  }
+  if (auto constOp = mlir::dyn_cast<tosa::ConstShapeOp>(op)) {
+    Attribute constOpAttr = constOp->getAttr("value");
+    DenseElementsAttr elementsAttr = cast<DenseElementsAttr>(constOpAttr);
+    for (int i = 0; i < elementsAttr.size(); i++) {
+      int64_t val = elementsAttr.getValues<int64_t>()[i];
+      result_shape.push_back(val);
+    }
+    return true;
+  }
+  // for undefined op, return false.
+  return false;
+}