
Commit 4e68f23

[TOSA] Add SameOperandsAndResultRank to TOSA Ops
This patch adds the SameOperandsAndResultRank trait to TOSA operators that carry the ResultsBroadcastableShape trait. The SameOperandsAndResultRank trait requires that all operands and results have matching ranks unless the operand/result is unranked.

This renders the TosaMakeBroadcastable pass unnecessary, but the pass is left in for now in case it is still used in some flows. The lit test broadcast.mlir is removed.

This also adds verification of the SameOperandsAndResultRank trait in the TosaInferShapes pass to validate inferred shapes.

Signed-off-by: Tai Ly <[email protected]>
Change-Id: I27bf16b31f15aa92d42ad5376b8791cf74e4f6ac
1 parent 2adc012 commit 4e68f23
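To illustrate the effect of the new trait, here is a minimal sketch (not part of the commit) of how a rank-mismatched tosa.add now has to be written: the lower-rank operand is first reshaped to the common rank, mirroring the update made to inlining.mlir further down in this diff. The function name is made up for illustration.

// Before this patch, tosa.add accepted operands of different ranks and relied
// on implicit broadcasting / the TosaMakeBroadcastable pass:
//   %0 = "tosa.add"(%a, %b) : (tensor<10xi32>, tensor<i32>) -> tensor<10xi32>
// With SameOperandsAndResultRank, the rank-0 operand is reshaped to rank 1 first:
func.func @rank_aligned_add(%a: tensor<10xi32>, %b: tensor<i32>) -> tensor<10xi32> {
  %b1 = "tosa.reshape"(%b) {new_shape = array<i64: 1>} : (tensor<i32>) -> tensor<1xi32>
  %0 = "tosa.add"(%a, %b1) : (tensor<10xi32>, tensor<1xi32>) -> tensor<10xi32>
  return %0 : tensor<10xi32>
}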

File tree

7 files changed (+97, -365 lines)


mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td

Lines changed: 1 addition & 0 deletions
@@ -219,6 +219,7 @@ class Tosa_ElementwiseOp<string mnemonic, list<Trait> traits = []> :
               DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
                                         ["inferReturnTypeComponents"]>,
               ResultsBroadcastableShape,
+              SameOperandsAndResultRank,
               Pure])> {
   let assemblyFormat =
         "operands attr-dict `:` functional-type(operands, results)";

mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp

Lines changed: 28 additions & 0 deletions
@@ -285,6 +285,32 @@ void propagateShapesInRegion(Region &region, TypeModificationState &state) {
   }
 }
 
+/// Recursively validate TOSA ops with the SameOperandsAndResultRank trait in
+/// the region and all nested regions.
+void validateSameOperandsAndResultRankTrait(Region &region) {
+  int errs = 0;
+  for (auto &block : region) {
+    for (auto &op : block) {
+      if (!op.getDialect() ||
+          op.getDialect()->getNamespace() != TosaDialect::getDialectNamespace())
+        continue;
+      if (op.hasTrait<OpTrait::SameOperandsAndResultRank>()) {
+        if (OpTrait::impl::verifySameOperandsAndResultRank(&op).failed()) {
+          errs++;
+        }
+      }
+      WhileOp whileOp = dyn_cast<WhileOp>(op);
+      IfOp ifOp = dyn_cast<IfOp>(op);
+      if (whileOp || ifOp) {
+        // recurse into the while/if op's regions
+        for (auto &next : op.getRegions()) {
+          validateSameOperandsAndResultRankTrait(next);
+        }
+      }
+    }
+  }
+}
+
 /// Pass that performs shape propagation across TOSA operations. This includes
 /// migrating to within the regions of if/while operations.
 struct TosaInferShapes
@@ -295,6 +321,8 @@ struct TosaInferShapes
     TypeModificationState state;
     propagateShapesInRegion(func.getBody(), state);
     state.commit();
+
+    validateSameOperandsAndResultRankTrait(func.getBody());
   }
 };
 } // namespace
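For context, a hedged sketch of the kind of IR this post-propagation check is meant to flag; the file name and the mlir-opt pass flag below are assumptions rather than part of this diff. The operands have ranks 2 and 3, so verifySameOperandsAndResultRank fails on the tosa.add.

// Hypothetical input, e.g. `mlir-opt --tosa-infer-shapes mismatch.mlir` (pass flag assumed).
// After shape propagation, the new validation walks the func and reports the mismatch.
func.func @rank_mismatch(%arg0: tensor<3x4xf32>, %arg1: tensor<2x3x4xf32>) -> tensor<2x3x4xf32> {
  %0 = "tosa.add"(%arg0, %arg1) : (tensor<3x4xf32>, tensor<2x3x4xf32>) -> tensor<2x3x4xf32>
  return %0 : tensor<2x3x4xf32>
}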

mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir

Lines changed: 3 additions & 16 deletions
@@ -94,6 +94,7 @@ func.func @test_add_0d(%arg0: tensor<f32>, %arg1: tensor<f32>) -> tensor<f32> {
   // CHECK: } -> tensor<f32>
   %0 = tosa.add %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
 
+
   // CHECK: return [[RESULT]] : tensor<f32>
   return %0 : tensor<f32>
 }
@@ -341,23 +342,9 @@ func.func @test_add_2d_all_dynamic(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32
 
 // -----
 
-// CHECK: #[[$MAP0:.+]] = affine_map<(d0, d1, d2) -> (0, d1, d2)>
-// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
-// CHECK-LABEL: @test_add_2d_different_ranks
-// CHECK-SAME: %[[ARG0:[0-9a-zA-Z_]*]]:
-// CHECK-SAME: %[[ARG1:[0-9a-zA-Z_]*]]:
 func.func @test_add_2d_different_ranks(%arg0: tensor<3x4xf32>, %arg1: tensor<2x3x4xf32>) -> tensor<2x3x4xf32> {
-
-  // CHECK: %[[ARG0_EXPANDED:.*]] = tensor.expand_shape %[[ARG0]] {{\[\[}}0, 1], [2]] output_shape [1, 3, 4] : tensor<3x4xf32> into tensor<1x3x4xf32>
-  // CHECK: %[[VAL_0:.*]] = tensor.empty() : tensor<2x3x4xf32>
-  // CHECK: %[[RESULT:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel"]} ins(%[[ARG0_EXPANDED]], %[[ARG1]] : tensor<1x3x4xf32>, tensor<2x3x4xf32>) outs(%[[VAL_0]] : tensor<2x3x4xf32>) {
-  // CHECK: ^bb0(%[[VAL_1:.*]]: f32, %[[VAL_2:.*]]: f32, %[[VAL_3:.*]]: f32):
-  // CHECK: %[[VAL_4:.*]] = arith.addf %[[VAL_1]], %[[VAL_2]] : f32
-  // CHECK: linalg.yield %[[VAL_4]] : f32
-  // CHECK: } -> tensor<2x3x4xf32>
-  %0 = tosa.add %arg0, %arg1 : (tensor<3x4xf32>, tensor<2x3x4xf32>) -> tensor<2x3x4xf32>
-
-  // CHECK: return %[[RESULT]] : tensor<2x3x4xf32>
+  // expected-error@+1 {{'tosa.add' op operands don't have matching ranks}}
+  %0 = "tosa.add"(%arg0, %arg1) : (tensor<3x4xf32>, tensor<2x3x4xf32>) -> tensor<2x3x4xf32>
   return %0 : tensor<2x3x4xf32>
 }
mlir/test/Dialect/Tosa/broadcast.mlir

Lines changed: 0 additions & 285 deletions
This file was deleted.

mlir/test/Dialect/Tosa/constant_folding.mlir

Lines changed: 2 additions & 2 deletions
@@ -15,9 +15,9 @@ func.func @test_const_i64(%arg0 : index) -> tensor<4xi64> {
 }
 
 // CHECK-LABEL: func @try_fold_equal_with_unranked_tensor
-func.func @try_fold_equal_with_unranked_tensor(%arg0: tensor<4xi32>, %arg1: tensor<i32>) {
+func.func @try_fold_equal_with_unranked_tensor(%arg0: tensor<4xi32>, %arg1: tensor<1xi32>) {
   // CHECK: tosa.equal
   // CHECK-NEXT: return
-  %0 = tosa.equal %arg0, %arg1 : (tensor<4xi32>, tensor<i32>) -> tensor<*xi1>
+  %0 = tosa.equal %arg0, %arg1 : (tensor<4xi32>, tensor<1xi32>) -> tensor<*xi1>
   return
 }

mlir/test/Dialect/Tosa/inlining.mlir

Lines changed: 2 additions & 1 deletion
@@ -47,7 +47,8 @@ func.func @inlined_while_fn(%arg0: tensor<i32>, %arg1: tensor<i32>, %arg2: tenso
 }
 func.func private @while_body_50(%arg0: tensor<i32>, %arg1: tensor<i32>, %arg2: tensor<i32>, %arg3: tensor<10xi32>) -> (tensor<i32>, tensor<i32>, tensor<i32>, tensor<10xi32>) {
   %1 = "tosa.add"(%arg0, %arg1) : (tensor<i32>, tensor<i32>) -> tensor<i32>
-  %2 = "tosa.add"(%arg3, %1) : (tensor<10xi32>, tensor<i32>) -> tensor<10xi32>
+  %3 = "tosa.reshape"(%1) {new_shape = array<i64: 1>} : (tensor<i32>) -> tensor<1xi32>
+  %2 = "tosa.add"(%arg3, %3) : (tensor<10xi32>, tensor<1xi32>) -> tensor<10xi32>
   return %1, %arg1, %arg2, %2: tensor<i32>, tensor<i32>, tensor<i32>, tensor<10xi32>
 }
 func.func private @while_cond_40(%arg0: tensor<i32>, %arg1: tensor<i32>, %arg2: tensor<i32>, %arg3: tensor<10xi32>) -> tensor<i1> {
