Skip to content

Commit 729f958

Browse files
authored
[TOSA] Add SameOperandsAndResultRank to TOSA Ops (#104501)
[note: this is blocked by: tensorflow/tensorflow#73891 otherwise tensorflow may have lit test failures] This patch adds SameOperandsAndResultRank trait to TOSA operators with ResultsBroadcastableShape trait. SameOperandsAndResultRank trait requiring that all operands and results have matching ranks unless the operand/result is unranked. This also renders the TosaMakeBroadcastable pass unnecessary - but this pass is left in for now just in case it is still used in some flows. The lit test, broadcast.mlir, is removed. This also adds verify of the SameOperandsAndResultRank trait in the TosaInferShapes pass to validate inferred shapes. Signed-off-by: Tai Ly <[email protected]>
1 parent bd56950 commit 729f958

File tree

9 files changed

+126
-378
lines changed

9 files changed

+126
-378
lines changed

mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -231,7 +231,7 @@ def TosaResolvableShapeOperands : NativeOpTrait<"TosaResolvableShapeOperands"> {
231231
//===----------------------------------------------------------------------===//
232232

233233
class Tosa_Op<string mnemonic, list<Trait> traits = []> :
234-
Op<Tosa_Dialect, mnemonic, !listconcat(traits, [TosaOpInterface,
234+
Op<Tosa_Dialect, mnemonic, !listconcat(traits, [TosaOpInterface,
235235
TosaResolvableShapeOperands])> {
236236
}
237237

@@ -241,6 +241,7 @@ class Tosa_ElementwiseOp<string mnemonic, list<Trait> traits = []> :
241241
["inferReturnTypeComponents"]>,
242242
ResultsBroadcastableShape,
243243
TosaElementwiseOperator,
244+
SameOperandsAndResultRank,
244245
Pure])> {
245246
let assemblyFormat =
246247
"operands attr-dict `:` functional-type(operands, results)";

mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -303,6 +303,32 @@ void propagateShapesInRegion(Region &region, TypeModificationState &state) {
303303
}
304304
}
305305

306+
/// Recursively validate tosa ops with SameOperandsAndResultRank trait in region
307+
/// and all nested regions
308+
void validateSameOperandsAndResultRankTrait(Region &region) {
309+
int errs = 0;
310+
for (auto &block : region) {
311+
for (auto &op : block) {
312+
if (!op.getDialect() ||
313+
op.getDialect()->getNamespace() != TosaDialect::getDialectNamespace())
314+
continue;
315+
if (op.hasTrait<OpTrait::SameOperandsAndResultRank>()) {
316+
if (OpTrait::impl::verifySameOperandsAndResultRank(&op).failed()) {
317+
errs++;
318+
}
319+
}
320+
WhileOp whileOp = dyn_cast<WhileOp>(op);
321+
IfOp ifOp = dyn_cast<IfOp>(op);
322+
if (whileOp || ifOp) {
323+
// recurse into whileOp's regions
324+
for (auto &next : op.getRegions()) {
325+
validateSameOperandsAndResultRankTrait(next);
326+
}
327+
}
328+
}
329+
}
330+
}
331+
306332
/// Pass that performs shape propagation across TOSA operations. This includes
307333
/// migrating to within the regions of if/while operations.
308334
struct TosaInferShapes
@@ -313,6 +339,8 @@ struct TosaInferShapes
313339
TypeModificationState state;
314340
propagateShapesInRegion(func.getBody(), state);
315341
state.commit();
342+
343+
validateSameOperandsAndResultRankTrait(func.getBody());
316344
}
317345
};
318346
} // namespace

mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,3 +45,11 @@ func.func @rescale_unsupported_type(%arg0: tensor<13x21x3x!quant.uniform<u8:f32,
4545
%0 = tosa.rescale %arg0 {double_round = false, input_zp = 127 : i32, multiplier = array<i32: 1073741824>, output_zp = -1 : i32, per_channel = false, scale32 = true, shift = array<i8: 30>} : (tensor<13x21x3x!quant.uniform<u8:f32, 0.015655439347028732:127>>) -> tensor<13x21x3x!quant.uniform<i8:f32, 0.015655439347028732:-1>>
4646
return %0 : tensor<13x21x3x!quant.uniform<i8:f32, 0.015655439347028732:-1>>
4747
}
48+
49+
// -----
50+
51+
func.func @test_add_2d_different_ranks(%arg0: tensor<3x4xf32>, %arg1: tensor<2x3x4xf32>) -> tensor<2x3x4xf32> {
52+
// expected-error@+1 {{'tosa.add' op operands don't have matching ranks}}
53+
%0 = "tosa.add"(%arg0, %arg1) : (tensor<3x4xf32>, tensor<2x3x4xf32>) -> tensor<2x3x4xf32>
54+
return %0 : tensor<2x3x4xf32>
55+
}

mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir

Lines changed: 7 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,7 @@ func.func @test_add_0d(%arg0: tensor<f32>, %arg1: tensor<f32>) -> tensor<f32> {
9494
// CHECK: } -> tensor<f32>
9595
%0 = tosa.add %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
9696

97+
9798
// CHECK: return [[RESULT]] : tensor<f32>
9899
return %0 : tensor<f32>
99100
}
@@ -103,20 +104,20 @@ func.func @test_add_0d(%arg0: tensor<f32>, %arg1: tensor<f32>) -> tensor<f32> {
103104
// CHECK: #[[$MAP0:.+]] = affine_map<(d0, d1) -> (d0, d1)>
104105
// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1) -> (0, d1)>
105106

106-
// CHECK-LABEL: func.func @test_add_0d_broadcast(
107+
// CHECK-LABEL: func.func @test_add_2d_broadcast(
107108
// CHECK-SAME: %[[ARG0:.*]]: tensor<2x1xf32>,
108-
// CHECK-SAME: %[[ARG1:.*]]: tensor<f32>) -> tensor<2x1xf32> {
109-
// CHECK: %[[EXPANDED:.*]] = tensor.expand_shape %[[ARG1]] [] output_shape [1, 1] : tensor<f32> into tensor<1x1xf32>
109+
// CHECK-SAME: %[[ARG1:.*]]: tensor<1x1xf32>) -> tensor<2x1xf32> {
110110
// CHECK: %[[EMPTY_TENSOR:.*]] = tensor.empty() : tensor<2x1xf32>
111-
// CHECK: %[[RESULT:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP0]]], iterator_types = ["parallel", "parallel"]} ins(%[[ARG0]], %[[EXPANDED]] : tensor<2x1xf32>, tensor<1x1xf32>) outs(%[[EMPTY_TENSOR]] : tensor<2x1xf32>) {
111+
// CHECK: %[[RESULT:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP0]]], iterator_types = ["parallel", "parallel"]} ins(%[[ARG0]], %[[ARG1]] : tensor<2x1xf32>, tensor<1x1xf32>) outs(%[[EMPTY_TENSOR]] : tensor<2x1xf32>) {
112112
// CHECK: ^bb0(%[[IN0:.*]]: f32, %[[IN1:.*]]: f32, %[[OUT:.*]]: f32):
113113
// CHECK: %[[ADD:.*]] = arith.addf %[[IN0]], %[[IN1]] : f32
114114
// CHECK: linalg.yield %[[ADD]] : f32
115115
// CHECK: } -> tensor<2x1xf32>
116116
// CHECK: return %[[RESULT]] : tensor<2x1xf32>
117117
// CHECK: }
118-
func.func @test_add_0d_broadcast(%arg0: tensor<2x1xf32>, %arg1: tensor<f32>) -> tensor<2x1xf32> {
119-
%0 = tosa.add %arg0, %arg1 : (tensor<2x1xf32>, tensor<f32>) -> tensor<2x1xf32>
118+
func.func @test_add_2d_broadcast(%arg0: tensor<2x1xf32>, %arg1: tensor<1x1xf32>) -> tensor<2x1xf32> {
119+
// tosa element-wise operators now require operands of equal ranks
120+
%0 = tosa.add %arg0, %arg1 : (tensor<2x1xf32>, tensor<1x1xf32>) -> tensor<2x1xf32>
120121
return %0 : tensor<2x1xf32>
121122
}
122123

@@ -383,28 +384,6 @@ func.func @test_add_2d_all_dynamic(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32
383384

384385
// -----
385386

386-
// CHECK: #[[$MAP0:.+]] = affine_map<(d0, d1, d2) -> (0, d1, d2)>
387-
// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
388-
// CHECK-LABEL: @test_add_2d_different_ranks
389-
// CHECK-SAME: %[[ARG0:[0-9a-zA-Z_]*]]:
390-
// CHECK-SAME: %[[ARG1:[0-9a-zA-Z_]*]]:
391-
func.func @test_add_2d_different_ranks(%arg0: tensor<3x4xf32>, %arg1: tensor<2x3x4xf32>) -> tensor<2x3x4xf32> {
392-
393-
// CHECK: %[[ARG0_EXPANDED:.*]] = tensor.expand_shape %[[ARG0]] {{\[\[}}0, 1], [2]] output_shape [1, 3, 4] : tensor<3x4xf32> into tensor<1x3x4xf32>
394-
// CHECK: %[[VAL_0:.*]] = tensor.empty() : tensor<2x3x4xf32>
395-
// CHECK: %[[RESULT:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel"]} ins(%[[ARG0_EXPANDED]], %[[ARG1]] : tensor<1x3x4xf32>, tensor<2x3x4xf32>) outs(%[[VAL_0]] : tensor<2x3x4xf32>) {
396-
// CHECK: ^bb0(%[[VAL_1:.*]]: f32, %[[VAL_2:.*]]: f32, %[[VAL_3:.*]]: f32):
397-
// CHECK: %[[VAL_4:.*]] = arith.addf %[[VAL_1]], %[[VAL_2]] : f32
398-
// CHECK: linalg.yield %[[VAL_4]] : f32
399-
// CHECK: } -> tensor<2x3x4xf32>
400-
%0 = tosa.add %arg0, %arg1 : (tensor<3x4xf32>, tensor<2x3x4xf32>) -> tensor<2x3x4xf32>
401-
402-
// CHECK: return %[[RESULT]] : tensor<2x3x4xf32>
403-
return %0 : tensor<2x3x4xf32>
404-
}
405-
406-
// -----
407-
408387
// CHECK: #[[$MAP0:.+]] = affine_map<(d0, d1) -> (d0, 0)>
409388
// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1) -> (d0, d1)>
410389
// CHECK-LABEL: @test_select_2d_one_dynamic

0 commit comments

Comments
 (0)