Skip to content

Commit 6291f08

Browse files
Tai78641 and Jerry-Ge
authored and committed
[TOSA] Change PadOp padding to tosa.shape
This patch changes PadOp's padding input to type !tosa.shape<2 * rank> (where rank is the rank of the PadOp's input), instead of a <rank x 2> tensor.

Signed-off-by: Tai Ly <[email protected]>
Change-Id: I08526a699d6b8ebbaf9ee092cd37580e5d78f919
1 parent 7aec7ca commit 6291f08

17 files changed

+196
-159
lines changed

mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,8 @@ ParseResult parseTypeOrAttr(OpAsmParser &parser, TypeAttr &typeAttr,
3939
void printTypeOrAttr(OpAsmPrinter &p, Operation *op, TypeAttr type,
4040
Attribute attr);
4141

42+
bool collectShapeValue(Operation* op, llvm::SmallVector<int64_t>& newShape);
43+
4244
#include "mlir/Dialect/Tosa/IR/TosaInterfaces.h.inc"
4345

4446
} // namespace tosa

mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1557,21 +1557,21 @@ def Tosa_PadOp : Tosa_InferShapedTypeOp<"pad"> {
15571557
Example:
15581558

15591559
```mlir
1560-
%0 = arith.constant dense<[1, 2, 3, 4]> : tensor<4xi32>
1561-
tosa.pad %arg0, %0 : (tensor<1x2xf32>, tensor<4xi32>) -> (tensor<4x9xf32>)
1560+
%0 = tosa.const_shape { value = dense<[1, 2, 3, 4]> : tensor<4xindex> } : () -> !tosa.shape<4>
1561+
tosa.pad %arg0, %0 : (tensor<1x2xf32>, !tosa.shape<4>) -> (tensor<4x9xf32>)
15621562
```
15631563

15641564
Example 2:
15651565

15661566
```mlir
1567-
%0 = arith.constant dense<[-1, 2, 3, 4]> : tensor<4xi32>
1568-
tosa.pad %arg0, %0 : (tensor<1x2xf32>, tensor<4xi32>) -> (tensor<?x9xf32>)
1567+
%0 = tosa.const_shape { value = dense<[-1, 2, 3, 4]> : tensor<4xindex> } : () -> !tosa.shape<4>
1568+
tosa.pad %arg0, %0 : (tensor<1x2xf32>, !tosa.shape<4>) -> (tensor<?x9xf32>)
15691569
```
15701570
}];
15711571

15721572
let arguments = (ins
15731573
Tosa_RankedTensor:$input1,
1574-
TosaTensorRankOf<[Tosa_Int32Or64], [1]>:$padding,
1574+
Tosa_Shape:$padding,
15751575
Optional<Tosa_ScalarTensor>:$pad_const,
15761576
OptionalAttr<Tosa_PadOpQuantizationAttr>:$quantization_info
15771577
);

mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -229,6 +229,11 @@ SmallVector<T> applyTOSAPermutation(ArrayRef<T> input,
229229
return permuted;
230230
}
231231

232+
// Computes shape value using tosa const_shape op.
233+
Value getTosaConstShape(PatternRewriter& rewriter, Location loc,
234+
llvm::ArrayRef<int64_t> shape);
235+
SmallVector<int64_t> convertFromMlirShape(ArrayRef<int64_t> shape);
236+
232237
} // namespace tosa
233238
} // namespace mlir
234239

mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -306,7 +306,16 @@ class PadConverter : public OpConversionPattern<tosa::PadOp> {
306306
ConversionPatternRewriter &rewriter) const final {
307307
auto loc = padOp.getLoc();
308308
auto input = padOp.getInput1();
309-
auto padding = padOp.getPadding();
309+
310+
ElementsAttr paddingElems;
311+
if (!matchPattern(padOp.getPadding(), m_Constant(&paddingElems))) {
312+
return rewriter.notifyMatchFailure(
313+
padOp, "padding must be a static shape value");
314+
}
315+
llvm::SmallVector<int64_t> paddingVals;
316+
for (auto idx : paddingElems.getValues<IntegerAttr>()) {
317+
paddingVals.push_back(static_cast<int64_t>(idx.getInt()));
318+
}
310319

311320
ShapedType inputTy = cast<ShapedType>(input.getType());
312321
Type elementTy = inputTy.getElementType();
@@ -345,18 +354,10 @@ class PadConverter : public OpConversionPattern<tosa::PadOp> {
345354
highValues.reserve(rank);
346355

347356
for (int i = 0; i < rank; i++) {
348-
Value lowIndex = rewriter.create<arith::ConstantIndexOp>(loc, 2 * i);
349-
Value highIndex = rewriter.create<arith::ConstantIndexOp>(loc, 2 * i + 1);
350-
Value lowVal = rewriter.createOrFold<tensor::ExtractOp>(
351-
loc, padding, ValueRange({lowIndex}));
352-
Value highVal = rewriter.createOrFold<tensor::ExtractOp>(
353-
loc, padding, ValueRange({highIndex}));
354-
355-
lowVal = rewriter.createOrFold<arith::IndexCastOp>(
356-
loc, rewriter.getIndexType(), lowVal);
357-
highVal = rewriter.createOrFold<arith::IndexCastOp>(
358-
loc, rewriter.getIndexType(), highVal);
359-
357+
Value lowVal = rewriter.create<arith::ConstantOp>(
358+
loc, rewriter.getIndexAttr(paddingVals[2 * i]));
359+
Value highVal = rewriter.create<arith::ConstantOp>(
360+
loc, rewriter.getIndexAttr(paddingVals[2 * i + 1]));
360361
lowValues.push_back(lowVal);
361362
highValues.push_back(highVal);
362363
}

mlir/lib/Dialect/Tosa/IR/TosaOps.cpp

Lines changed: 40 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -210,6 +210,26 @@ void mlir::tosa::printTypeOrAttr(OpAsmPrinter &p, Operation *op, TypeAttr type,
210210
}
211211
}
212212

213+
//===----------------------------------------------------------------------===//
214+
// TOSA shape inference helper
215+
//===----------------------------------------------------------------------===//
216+
bool mlir::tosa::collectShapeValue(Operation* op, llvm::SmallVector<int64_t>& newShape) {
217+
if (!op) {
218+
return false;
219+
}
220+
if (auto constOp = mlir::dyn_cast<tosa::ConstShapeOp>(op)) {
221+
Attribute constOpAttr = constOp->getAttr("value");
222+
DenseElementsAttr elementsAttr = cast<DenseElementsAttr>(constOpAttr);
223+
for (int i = 0; i < elementsAttr.size(); i++) {
224+
int64_t val = elementsAttr.getValues<int64_t>()[i];
225+
newShape.push_back(val);
226+
}
227+
return true;
228+
}
229+
// for undefined op, return false.
230+
return false;
231+
}
232+
213233
//===----------------------------------------------------------------------===//
214234
// TOSA Operator Verifiers.
215235
//===----------------------------------------------------------------------===//
@@ -823,51 +843,42 @@ LogicalResult tosa::PadOp::inferReturnTypeComponents(
823843
PadOp::Adaptor adaptor,
824844
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
825845
ShapeAdaptor inputShape(adaptor.getInput1().getType());
826-
ShapeAdaptor paddingShape(adaptor.getPadding().getType());
846+
auto paddingRank =
847+
cast<tosa::shapeType>(adaptor.getPadding().getType()).getRank();
827848
SmallVector<int64_t> outputShape;
828849

829-
// If both inputs have unknown shape, we cannot determine the shape of the
830-
// output.
831-
if (!inputShape.hasRank() && !paddingShape.hasRank()) {
832-
inferredReturnShapes.push_back(ShapedTypeComponents());
833-
return success();
834-
}
835-
836-
// If the input rank is unknown we can info the output rank using the
837-
// padding shape's first dim.
850+
// If the input rank is unknown, we can infer the output rank using the
851+
// padding shape's rank divided by 2.
838852
if (!inputShape.hasRank()) {
839-
if (paddingShape.isDynamicDim(0)) {
840-
inferredReturnShapes.push_back(ShapedTypeComponents());
841-
return success();
842-
}
843-
844-
outputShape.resize(paddingShape.getDimSize(0) / 2, ShapedType::kDynamic);
853+
outputShape.resize(paddingRank / 2, ShapedType::kDynamic);
845854
inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
846855
return success();
847856
}
848857

849-
DenseIntElementsAttr paddings;
858+
SmallVector<int64_t> paddingValues;
850859
// If the paddings value is not a constant, all dimensions must be dynamic.
851-
if (!matchPattern(adaptor.getPadding(), m_Constant(&paddings))) {
860+
if (!tosa::collectShapeValue(adaptor.getPadding().getDefiningOp(),
861+
paddingValues)) {
852862
outputShape.resize(inputShape.getRank(), ShapedType::kDynamic);
853863
inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
854864
return success();
855865
}
856866

857-
SmallVector<int64_t> paddingValues;
858-
for (auto val : paddings) {
859-
paddingValues.push_back(val.getSExtValue());
860-
}
861-
862867
outputShape.reserve(inputShape.getRank());
863868
for (int i = 0, s = inputShape.getRank(); i < s; i++) {
864869
if (inputShape.isDynamicDim(i)) {
865870
outputShape.push_back(ShapedType::kDynamic);
866871
continue;
867872
}
873+
auto padFront = paddingValues[i * 2];
874+
auto padBack = paddingValues[i * 2 + 1];
875+
if (padFront < 0 || padBack < 0) {
876+
// if either padding for dim i is -1, output dim is unknown
877+
outputShape.push_back(ShapedType::kDynamic);
878+
continue;
879+
}
868880

869-
outputShape.push_back(inputShape.getDimSize(i) + paddingValues[i * 2] +
870-
paddingValues[i * 2 + 1]);
881+
outputShape.push_back(inputShape.getDimSize(i) + padFront + padBack);
871882
}
872883

873884
inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
@@ -877,17 +888,16 @@ LogicalResult tosa::PadOp::inferReturnTypeComponents(
877888
LogicalResult tosa::PadOp::verify() {
878889
RankedTensorType inputType = getInput1().getType();
879890
RankedTensorType outputType = getOutput().getType();
880-
RankedTensorType paddingType = getPadding().getType();
891+
auto paddingRank = cast<tosa::shapeType>(getPadding().getType()).getRank();
881892

882893
if (inputType.getRank() != outputType.getRank())
883894
return emitOpError() << "expect same input and output tensor rank.";
884-
885-
if (!paddingType.isDynamicDim(0) &&
886-
paddingType.getDimSize(0) != inputType.getRank() * 2)
895+
896+
if (paddingRank != inputType.getRank() * 2)
887897
return emitOpError() << "expected padding tensor dim 0 to have size "
888898
<< inputType.getRank() * 2
889899
<< " (2*rank(shape1)) but got size "
890-
<< paddingType.getDimSize(0);
900+
<< paddingRank;
891901

892902
return success();
893903
}

mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeConv2D.cpp

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -81,11 +81,7 @@ struct Conv2DIsFullyConnected : public OpRewritePattern<tosa::Conv2DOp> {
8181
}
8282
}
8383

84-
auto padSizeTy = RankedTensorType::get({8}, rewriter.getI64Type());
85-
auto padSize =
86-
DenseIntElementsAttr::get(padSizeTy, ArrayRef<int64_t>(pad));
87-
Value padSizeVal =
88-
rewriter.create<tosa::ConstOp>(op->getLoc(), padSizeTy, padSize);
84+
Value padSizeVal = getTosaConstShape(rewriter, op->getLoc(), pad);
8985

9086
auto padTy = RankedTensorType::get({}, inputETy);
9187
auto padAttr = DenseElementsAttr::get(padTy, zeroAttr);

mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -108,11 +108,7 @@ struct DepthwiseConv2DIsMul : public OpRewritePattern<tosa::DepthwiseConv2DOp> {
108108
}
109109
}
110110

111-
auto padSizeTy = RankedTensorType::get({10}, rewriter.getI64Type());
112-
auto padSize =
113-
DenseIntElementsAttr::get(padSizeTy, ArrayRef<int64_t>(pad));
114-
Value padSizeVal =
115-
rewriter.create<tosa::ConstOp>(op->getLoc(), padSizeTy, padSize);
111+
Value padSizeVal = getTosaConstShape(rewriter, op->getLoc(), pad);
116112

117113
auto padTy = RankedTensorType::get({}, inputETy);
118114
auto padAttr = DenseElementsAttr::get(padTy, zeroAttr);

mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp

Lines changed: 11 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -135,15 +135,14 @@ class TransposeConvStridedConverter
135135
int64_t inputChannels = weightTy.getDimSize(3);
136136

137137
// Pad the weight so that it is modulo of the striding.
138-
llvm::SmallVector<int32_t, 8> weightPadding = {0, 0, 0, 0, 0, 0, 0, 0};
138+
llvm::SmallVector<int64_t, 8> weightPadding = {0, 0, 0, 0, 0, 0, 0, 0};
139139
weightPadding[3] =
140140
(weightHeight % stride[0]) ? (stride[0] - weightHeight % stride[0]) : 0;
141141
weightPadding[5] =
142-
(weightWidth % stride[1]) ? (stride[1] - weightWidth % stride[1]) : 0;
143-
DenseElementsAttr weightPaddingAttr = DenseIntElementsAttr::get(
144-
RankedTensorType::get({8}, rewriter.getI32Type()), weightPadding);
145-
Value weightPaddingVal = CreateOpAndInferShape<tosa::ConstOp>(
146-
rewriter, loc, weightPaddingAttr.getType(), weightPaddingAttr);
142+
weightWidth % stride[1] ? stride[1] - weightWidth % stride[1] : 0;
143+
144+
Value weightPaddingVal =
145+
getTosaConstShape(rewriter, op->getLoc(), weightPadding);
147146

148147
if (op.getQuantizationInfo().has_value()) {
149148
auto quantInfo = op.getQuantizationInfo().value();
@@ -197,17 +196,14 @@ class TransposeConvStridedConverter
197196
/* axis = */ rewriter.getI32IntegerAttr(2));
198197

199198
// We need to pad the input far enough that we can pull all values.
200-
llvm::SmallVector<int32_t, 8> inputPadding = {0, 0, 0, 0, 0, 0, 0, 0};
199+
llvm::SmallVector<int64_t, 8> inputPadding = {0, 0, 0, 0, 0, 0, 0, 0};
201200
inputPadding[2] += restridedWeightTy.getDimSize(1) - 1;
202201
inputPadding[3] += restridedWeightTy.getDimSize(1) - 1;
203202
inputPadding[4] += restridedWeightTy.getDimSize(2) - 1;
204203
inputPadding[5] += restridedWeightTy.getDimSize(2) - 1;
205204

206-
DenseElementsAttr inputPaddingAttr = DenseIntElementsAttr::get(
207-
RankedTensorType::get({8}, rewriter.getI32Type()), inputPadding);
208-
209-
Value inputPaddingVal = CreateOpAndInferShape<tosa::ConstOp>(
210-
rewriter, loc, inputPaddingAttr.getType(), inputPaddingAttr);
205+
Value inputPaddingVal =
206+
getTosaConstShape(rewriter, op->getLoc(), inputPadding);
211207

212208
if (op.getQuantizationInfo().has_value()) {
213209
auto quantInfo = op.getQuantizationInfo().value();
@@ -310,17 +306,14 @@ class TransposeConvStridedConverter
310306
rewriter.getDenseI64ArrayAttr(sliceSize))
311307
.getResult();
312308

313-
llvm::SmallVector<int32_t, 8> resultPadding = {0, 0, 0, 0, 0, 0, 0, 0};
309+
llvm::SmallVector<int64_t, 8> resultPadding = {0, 0, 0, 0, 0, 0, 0, 0};
314310
resultPadding[2] = resultPadTop;
315311
resultPadding[3] = resultTy.getDimSize(1) - resultPadTop - sliceSize[1];
316312
resultPadding[4] = resultPadLeft;
317313
resultPadding[5] = resultTy.getDimSize(2) - resultPadLeft - sliceSize[2];
318314

319-
DenseElementsAttr resultPaddingAttr = DenseIntElementsAttr::get(
320-
RankedTensorType::get({8}, rewriter.getI32Type()), resultPadding);
321-
322-
Value resultPaddingVal = CreateOpAndInferShape<tosa::ConstOp>(
323-
rewriter, loc, resultPaddingAttr.getType(), resultPaddingAttr);
315+
Value resultPaddingVal =
316+
getTosaConstShape(rewriter, op->getLoc(), resultPadding);
324317

325318
Value resultPad = CreateOpAndInferShape<tosa::PadOp>(
326319
rewriter, loc, UnrankedTensorType::get(resultETy), slice,

mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,3 +160,17 @@ LogicalResult mlir::tosa::EqualizeRanks(ImplicitLocOpBuilder &builder,
160160

161161
return success();
162162
}
163+
164+
Value mlir::tosa::getTosaConstShape(PatternRewriter& rewriter, Location loc,
165+
llvm::ArrayRef<int64_t> shape) {
166+
auto attr = rewriter.getIndexTensorAttr(shape);
167+
auto type = mlir::tosa::shapeType::get(rewriter.getContext(), shape.size());
168+
mlir::Operation *mlir_op = rewriter.create<tosa::ConstShapeOp>(loc, type, attr);
169+
return mlir_op->getResult(0);
170+
}
171+
172+
SmallVector<int64_t> mlir::tosa::convertFromMlirShape(ArrayRef<int64_t> shape) {
173+
return to_vector(llvm::map_range(shape, [](int64_t dim) {
174+
return ShapedType::isDynamic(dim) ? -1 : dim;
175+
}));
176+
}

0 commit comments

Comments
 (0)