
Commit 7e622b6

Jerry-Ge and Tai78641 authored
[TOSA] Change PadOp padding to tosa.shape (#123133)
This patch changes PadOp's padding input to type !tosa.shape<2 * rank> (where rank is the rank of the PadOp's input), instead of a rank-1 integer tensor holding the same 2 * rank values. It is part of the TOSA v1.0 effort: https://discourse.llvm.org/t/rfc-tosa-dialect-increment-to-v1-0/83708. The patch updates PadOp, and all code that matches or constructs it, to the TOSA v1.0 form.

Original authors include: @Tai78641 @wonjeon

Co-authored-by: Tai Ly <[email protected]>
1 parent 3057d0f commit 7e622b6

16 files changed: +196 −159 lines
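For pattern writers, the practical effect is that padding operands are now materialized as tosa.const_shape values rather than tosa.const integer tensors. A minimal sketch of the new-style construction, assuming the getTosaConstShape helper this patch introduces and PadOp's (result type, input, padding) builder; buildPad itself is a hypothetical helper for illustration:

```cpp
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"

using namespace mlir;

// Sketch only: buildPad is hypothetical, used to contrast the two forms.
static Value buildPad(PatternRewriter &rewriter, Location loc, Value input,
                      RankedTensorType resultTy) {
  // Before: %p = "tosa.const"() {value = dense<[1, 2, 3, 4]> : tensor<4xi32>}
  // After:  %p = tosa.const_shape {value = dense<[1, 2, 3, 4]> : tensor<4xindex>}
  //              : () -> !tosa.shape<4>
  Value padding = tosa::getTosaConstShape(rewriter, loc, {1, 2, 3, 4});
  return rewriter.create<tosa::PadOp>(loc, resultTy, input, padding);
}
```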

mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td

Lines changed: 5 additions & 5 deletions

@@ -1557,21 +1557,21 @@ def Tosa_PadOp : Tosa_InferShapedTypeOp<"pad"> {
     Example:

     ```mlir
-    %0 = arith.constant dense<[1, 2, 3, 4]> : tensor<4xi32>
-    tosa.pad %arg0, %0 : (tensor<1x2xf32>, tensor<4xi32>) -> (tensor<4x9xf32>)
+    %0 = tosa.const_shape { value = dense<[1, 2, 3, 4]> : tensor<4xindex> } : () -> !tosa.shape<4>
+    tosa.pad %arg0, %0 : (tensor<1x2xf32>, !tosa.shape<4>) -> (tensor<4x9xf32>)
     ```

     Example 2:

     ```mlir
-    %0 = arith.constant dense<[-1, 2, 3, 4]> : tensor<4xi32>
-    tosa.pad %arg0, %0 : (tensor<1x2xf32>, tensor<4xi32>) -> (tensor<?x9xf32>)
+    %0 = tosa.const_shape { value = dense<[-1, 2, 3, 4]> : tensor<4xindex> } : () -> !tosa.shape<4>
+    tosa.pad %arg0, %0 : (tensor<1x2xf32>, !tosa.shape<4>) -> (tensor<?x9xf32>)
     ```
   }];

   let arguments = (ins
     Tosa_RankedTensor:$input1,
-    TosaTensorRankOf<[Tosa_Int32Or64], [1]>:$padding,
+    Tosa_Shape:$padding,
     Optional<Tosa_ScalarTensor>:$pad_const,
     OptionalAttr<Tosa_PadOpQuantizationAttr>:$quantization_info
   );

mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h

Lines changed: 8 additions & 0 deletions

@@ -229,6 +229,14 @@ SmallVector<T> applyTOSAPermutation(ArrayRef<T> input,
   return permuted;
 }

+// Computes shape value using tosa const_shape op.
+Value getTosaConstShape(PatternRewriter &rewriter, Location loc,
+                        llvm::ArrayRef<int64_t> shape);
+SmallVector<int64_t> convertFromMlirShape(ArrayRef<int64_t> shape);
+
+bool getConstShapeValue(Operation *op,
+                        llvm::SmallVector<int64_t> &result_shape);
+
 } // namespace tosa
 } // namespace mlir
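On the matcher side, the intended use of getConstShapeValue mirrors the TosaOps.cpp change further down: walk to the padding operand's defining op and read the values back when they are static. A sketch under that assumption:

```cpp
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"

using namespace mlir;

// Sketch: recover the static [lo0, hi0, lo1, hi1, ...] pairs from a PadOp.
// Fails when the padding operand is not defined by tosa.const_shape.
static LogicalResult getStaticPadding(tosa::PadOp padOp,
                                      llvm::SmallVector<int64_t> &pairs) {
  if (!tosa::getConstShapeValue(padOp.getPadding().getDefiningOp(), pairs))
    return failure();
  return success();
}
```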

mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp

Lines changed: 14 additions & 13 deletions

@@ -306,7 +306,16 @@ class PadConverter : public OpConversionPattern<tosa::PadOp> {
                   ConversionPatternRewriter &rewriter) const final {
     auto loc = padOp.getLoc();
     auto input = padOp.getInput1();
-    auto padding = padOp.getPadding();
+
+    ElementsAttr paddingElems;
+    if (!matchPattern(padOp.getPadding(), m_Constant(&paddingElems))) {
+      return rewriter.notifyMatchFailure(
+          padOp, "padding must be a static shape value");
+    }
+    llvm::SmallVector<int64_t> paddingVals;
+    for (auto idx : paddingElems.getValues<IntegerAttr>()) {
+      paddingVals.push_back(static_cast<int64_t>(idx.getInt()));
+    }

     ShapedType inputTy = cast<ShapedType>(input.getType());
     Type elementTy = inputTy.getElementType();
@@ -345,18 +354,10 @@ class PadConverter : public OpConversionPattern<tosa::PadOp> {
     highValues.reserve(rank);

     for (int i = 0; i < rank; i++) {
-      Value lowIndex = rewriter.create<arith::ConstantIndexOp>(loc, 2 * i);
-      Value highIndex = rewriter.create<arith::ConstantIndexOp>(loc, 2 * i + 1);
-      Value lowVal = rewriter.createOrFold<tensor::ExtractOp>(
-          loc, padding, ValueRange({lowIndex}));
-      Value highVal = rewriter.createOrFold<tensor::ExtractOp>(
-          loc, padding, ValueRange({highIndex}));
-
-      lowVal = rewriter.createOrFold<arith::IndexCastOp>(
-          loc, rewriter.getIndexType(), lowVal);
-      highVal = rewriter.createOrFold<arith::IndexCastOp>(
-          loc, rewriter.getIndexType(), highVal);
-
+      Value lowVal = rewriter.create<arith::ConstantOp>(
+          loc, rewriter.getIndexAttr(paddingVals[2 * i]));
+      Value highVal = rewriter.create<arith::ConstantOp>(
+          loc, rewriter.getIndexAttr(paddingVals[2 * i + 1]));
       lowValues.push_back(lowVal);
       highValues.push_back(highVal);
     }

mlir/lib/Dialect/Tosa/IR/TosaOps.cpp

Lines changed: 20 additions & 30 deletions

@@ -36,6 +36,7 @@ using namespace mlir;
 using namespace mlir::tosa;

 #include "mlir/Dialect/Tosa/IR/TosaOpsDialect.cpp.inc"
+#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"

 //===----------------------------------------------------------------------===//
 // Tosa dialect interface includes.
@@ -822,51 +823,42 @@ LogicalResult tosa::PadOp::inferReturnTypeComponents(
     PadOp::Adaptor adaptor,
     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
   ShapeAdaptor inputShape(adaptor.getInput1().getType());
-  ShapeAdaptor paddingShape(adaptor.getPadding().getType());
+  auto paddingRank =
+      cast<tosa::shapeType>(adaptor.getPadding().getType()).getRank();
   SmallVector<int64_t> outputShape;

-  // If both inputs have unknown shape, we cannot determine the shape of the
-  // output.
-  if (!inputShape.hasRank() && !paddingShape.hasRank()) {
-    inferredReturnShapes.push_back(ShapedTypeComponents());
-    return success();
-  }
-
-  // If the input rank is unknown we can info the output rank using the
-  // padding shape's first dim.
+  // If the input rank is unknown, we can infer the output rank using the
+  // padding shape's rank divided by 2.
   if (!inputShape.hasRank()) {
-    if (paddingShape.isDynamicDim(0)) {
-      inferredReturnShapes.push_back(ShapedTypeComponents());
-      return success();
-    }
-
-    outputShape.resize(paddingShape.getDimSize(0) / 2, ShapedType::kDynamic);
+    outputShape.resize(paddingRank / 2, ShapedType::kDynamic);
     inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
     return success();
   }

-  DenseIntElementsAttr paddings;
+  SmallVector<int64_t> paddingValues;
   // If the paddings value is not a constant, all dimensions must be dynamic.
-  if (!matchPattern(adaptor.getPadding(), m_Constant(&paddings))) {
+  if (!tosa::getConstShapeValue(adaptor.getPadding().getDefiningOp(),
+                                paddingValues)) {
     outputShape.resize(inputShape.getRank(), ShapedType::kDynamic);
     inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
     return success();
   }

-  SmallVector<int64_t> paddingValues;
-  for (auto val : paddings) {
-    paddingValues.push_back(val.getSExtValue());
-  }
-
   outputShape.reserve(inputShape.getRank());
   for (int i = 0, s = inputShape.getRank(); i < s; i++) {
     if (inputShape.isDynamicDim(i)) {
       outputShape.push_back(ShapedType::kDynamic);
       continue;
     }
+    auto padFront = paddingValues[i * 2];
+    auto padBack = paddingValues[i * 2 + 1];
+    if (padFront < 0 || padBack < 0) {
+      // if either padding for dim i is -1, output dim is unknown
+      outputShape.push_back(ShapedType::kDynamic);
+      continue;
+    }

-    outputShape.push_back(inputShape.getDimSize(i) + paddingValues[i * 2] +
-                          paddingValues[i * 2 + 1]);
+    outputShape.push_back(inputShape.getDimSize(i) + padFront + padBack);
   }

   inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
@@ -876,17 +868,15 @@ LogicalResult tosa::PadOp::inferReturnTypeComponents(
 LogicalResult tosa::PadOp::verify() {
   RankedTensorType inputType = getInput1().getType();
   RankedTensorType outputType = getOutput().getType();
-  RankedTensorType paddingType = getPadding().getType();
+  auto paddingRank = cast<tosa::shapeType>(getPadding().getType()).getRank();

   if (inputType.getRank() != outputType.getRank())
     return emitOpError() << "expect same input and output tensor rank.";

-  if (!paddingType.isDynamicDim(0) &&
-      paddingType.getDimSize(0) != inputType.getRank() * 2)
+  if (paddingRank != inputType.getRank() * 2)
     return emitOpError() << "expected padding tensor dim 0 to have size "
                          << inputType.getRank() * 2
-                         << " (2*rank(shape1)) but got size "
-                         << paddingType.getDimSize(0);
+                         << " (2*rank(shape1)) but got size " << paddingRank;

   return success();
 }
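Worked through on the first example from the TosaOps.td diff above: for input tensor<1x2xf32> with padding [1, 2, 3, 4], inference yields dim 0 = 1 + 1 + 2 = 4 and dim 1 = 2 + 3 + 4 = 9, i.e. tensor<4x9xf32>. In Example 2 the leading -1 makes padFront negative, so dim 0 becomes dynamic and the result is tensor<?x9xf32>.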

mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeConv2D.cpp

Lines changed: 1 addition & 5 deletions

@@ -81,11 +81,7 @@ struct Conv2DIsFullyConnected : public OpRewritePattern<tosa::Conv2DOp> {
       }
     }

-    auto padSizeTy = RankedTensorType::get({8}, rewriter.getI64Type());
-    auto padSize =
-        DenseIntElementsAttr::get(padSizeTy, ArrayRef<int64_t>(pad));
-    Value padSizeVal =
-        rewriter.create<tosa::ConstOp>(op->getLoc(), padSizeTy, padSize);
+    Value padSizeVal = getTosaConstShape(rewriter, op->getLoc(), pad);

     auto padTy = RankedTensorType::get({}, inputETy);
     auto padAttr = DenseElementsAttr::get(padTy, zeroAttr);

mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp

Lines changed: 1 addition & 5 deletions

@@ -108,11 +108,7 @@ struct DepthwiseConv2DIsMul : public OpRewritePattern<tosa::DepthwiseConv2DOp> {
       }
     }

-    auto padSizeTy = RankedTensorType::get({10}, rewriter.getI64Type());
-    auto padSize =
-        DenseIntElementsAttr::get(padSizeTy, ArrayRef<int64_t>(pad));
-    Value padSizeVal =
-        rewriter.create<tosa::ConstOp>(op->getLoc(), padSizeTy, padSize);
+    Value padSizeVal = getTosaConstShape(rewriter, op->getLoc(), pad);

     auto padTy = RankedTensorType::get({}, inputETy);
     auto padAttr = DenseElementsAttr::get(padTy, zeroAttr);

mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp

Lines changed: 11 additions & 18 deletions

@@ -135,15 +135,14 @@ class TransposeConvStridedConverter
   int64_t inputChannels = weightTy.getDimSize(3);

   // Pad the weight so that it is modulo of the striding.
-  llvm::SmallVector<int32_t, 8> weightPadding = {0, 0, 0, 0, 0, 0, 0, 0};
+  llvm::SmallVector<int64_t, 8> weightPadding = {0, 0, 0, 0, 0, 0, 0, 0};
   weightPadding[3] =
       (weightHeight % stride[0]) ? (stride[0] - weightHeight % stride[0]) : 0;
   weightPadding[5] =
-      (weightWidth % stride[1]) ? (stride[1] - weightWidth % stride[1]) : 0;
-  DenseElementsAttr weightPaddingAttr = DenseIntElementsAttr::get(
-      RankedTensorType::get({8}, rewriter.getI32Type()), weightPadding);
-  Value weightPaddingVal = CreateOpAndInferShape<tosa::ConstOp>(
-      rewriter, loc, weightPaddingAttr.getType(), weightPaddingAttr);
+      weightWidth % stride[1] ? stride[1] - weightWidth % stride[1] : 0;
+
+  Value weightPaddingVal =
+      getTosaConstShape(rewriter, op->getLoc(), weightPadding);

   if (op.getQuantizationInfo().has_value()) {
     auto quantInfo = op.getQuantizationInfo().value();
@@ -197,17 +196,14 @@ class TransposeConvStridedConverter
       /* axis = */ rewriter.getI32IntegerAttr(2));

   // We need to pad the input far enough that we can pull all values.
-  llvm::SmallVector<int32_t, 8> inputPadding = {0, 0, 0, 0, 0, 0, 0, 0};
+  llvm::SmallVector<int64_t, 8> inputPadding = {0, 0, 0, 0, 0, 0, 0, 0};
   inputPadding[2] += restridedWeightTy.getDimSize(1) - 1;
   inputPadding[3] += restridedWeightTy.getDimSize(1) - 1;
   inputPadding[4] += restridedWeightTy.getDimSize(2) - 1;
   inputPadding[5] += restridedWeightTy.getDimSize(2) - 1;

-  DenseElementsAttr inputPaddingAttr = DenseIntElementsAttr::get(
-      RankedTensorType::get({8}, rewriter.getI32Type()), inputPadding);
-
-  Value inputPaddingVal = CreateOpAndInferShape<tosa::ConstOp>(
-      rewriter, loc, inputPaddingAttr.getType(), inputPaddingAttr);
+  Value inputPaddingVal =
+      getTosaConstShape(rewriter, op->getLoc(), inputPadding);

   if (op.getQuantizationInfo().has_value()) {
     auto quantInfo = op.getQuantizationInfo().value();
@@ -310,17 +306,14 @@ class TransposeConvStridedConverter
                  rewriter.getDenseI64ArrayAttr(sliceSize))
       .getResult();

-  llvm::SmallVector<int32_t, 8> resultPadding = {0, 0, 0, 0, 0, 0, 0, 0};
+  llvm::SmallVector<int64_t, 8> resultPadding = {0, 0, 0, 0, 0, 0, 0, 0};
   resultPadding[2] = resultPadTop;
   resultPadding[3] = resultTy.getDimSize(1) - resultPadTop - sliceSize[1];
   resultPadding[4] = resultPadLeft;
   resultPadding[5] = resultTy.getDimSize(2) - resultPadLeft - sliceSize[2];

-  DenseElementsAttr resultPaddingAttr = DenseIntElementsAttr::get(
-      RankedTensorType::get({8}, rewriter.getI32Type()), resultPadding);
-
-  Value resultPaddingVal = CreateOpAndInferShape<tosa::ConstOp>(
-      rewriter, loc, resultPaddingAttr.getType(), resultPaddingAttr);
+  Value resultPaddingVal =
+      getTosaConstShape(rewriter, op->getLoc(), resultPadding);

   Value resultPad = CreateOpAndInferShape<tosa::PadOp>(
       rewriter, loc, UnrankedTensorType::get(resultETy), slice,

mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp

Lines changed: 33 additions & 0 deletions

@@ -160,3 +160,36 @@ LogicalResult mlir::tosa::EqualizeRanks(ImplicitLocOpBuilder &builder,

   return success();
 }
+
+Value mlir::tosa::getTosaConstShape(PatternRewriter &rewriter, Location loc,
+                                    llvm::ArrayRef<int64_t> shape) {
+  auto attr = rewriter.getIndexTensorAttr(shape);
+  auto type = mlir::tosa::shapeType::get(rewriter.getContext(), shape.size());
+  mlir::Operation *mlir_op =
+      rewriter.create<tosa::ConstShapeOp>(loc, type, attr);
+  return mlir_op->getResult(0);
+}
+
+SmallVector<int64_t> mlir::tosa::convertFromMlirShape(ArrayRef<int64_t> shape) {
+  return to_vector(llvm::map_range(shape, [](int64_t dim) {
+    return ShapedType::isDynamic(dim) ? -1 : dim;
+  }));
+}
+
+bool mlir::tosa::getConstShapeValue(Operation *op,
+                                    llvm::SmallVector<int64_t> &result_shape) {
+  if (!op) {
+    return false;
+  }
+  if (auto constOp = mlir::dyn_cast<tosa::ConstShapeOp>(op)) {
+    Attribute constOpAttr = constOp->getAttr("value");
+    DenseElementsAttr elementsAttr = cast<DenseElementsAttr>(constOpAttr);
+    for (int i = 0; i < elementsAttr.size(); i++) {
+      int64_t val = elementsAttr.getValues<int64_t>()[i];
+      result_shape.push_back(val);
+    }
+    return true;
+  }
+  // for undefined op, return false.
+  return false;
+}
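A quick round-trip of the two new helpers, as a sketch (the surrounding function is hypothetical; rewriter and loc are assumed to come from an enclosing pattern):

```cpp
#include <cassert>

#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"

using namespace mlir;

// Sketch: building a shape constant and reading it back yields the same
// values, since getConstShapeValue matches the tosa::ConstShapeOp producer.
static void roundTrip(PatternRewriter &rewriter, Location loc) {
  llvm::SmallVector<int64_t> pad = {0, 0, 1, 1, 2, 2, 0, 0};
  Value shapeVal = tosa::getTosaConstShape(rewriter, loc, pad);

  llvm::SmallVector<int64_t> readBack;
  bool ok = tosa::getConstShapeValue(shapeVal.getDefiningOp(), readBack);
  assert(ok && readBack == pad && "const_shape round-trip mismatch");
  (void)ok;
}
```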
