Skip to content

[TOSA] Add SameOperandsAndResultRank to TOSA Ops #104501

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jan 22, 2025

Conversation

Tai78641
Copy link
Contributor

@Tai78641 Tai78641 commented Aug 15, 2024

[note: this is blocked by: https://github.com/tensorflow/tensorflow/pull/73891 otherwise tensorflow may have lit test failures]

This patch adds SameOperandsAndResultRank trait to TOSA operators with ResultsBroadcastableShape trait. SameOperandsAndResultRank trait requiring that all operands and results have matching ranks unless the operand/result is unranked.

This also renders the TosaMakeBroadcastable pass unnecessary - but this pass is left in for now just in case it is still used in some flows. The lit test, broadcast.mlir, is removed.

This also adds verify of the SameOperandsAndResultRank trait in the TosaInferShapes pass to validate inferred shapes.

@llvmbot
Copy link
Member

llvmbot commented Aug 15, 2024

@llvm/pr-subscribers-mlir-tosa
@llvm/pr-subscribers-mlir-linalg

@llvm/pr-subscribers-mlir

Author: Tai Ly (Tai78641)

Changes

This patch adds SameOperandsAndResultRank trait to TOSA operators with ResultsBroadcastableShape trait. SameOperandsAndResultRank trait requiring that all operands and results have matching ranks unless the operand/result is unranked.

This also renders the TosaMakeBroadcastable pass unnecessary - but this pass is left in for now just in case it is still used in some flows. The lit test, broadcast.mlir, is removed.

This also adds verify of the SameOperandsAndResultRank trait in the TosaInferShapes pass to validate inferred shapes.


Patch is 32.72 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/104501.diff

7 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td (+1)
  • (modified) mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp (+28)
  • (modified) mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir (+3-16)
  • (removed) mlir/test/Dialect/Tosa/broadcast.mlir (-285)
  • (modified) mlir/test/Dialect/Tosa/constant_folding.mlir (+2-2)
  • (modified) mlir/test/Dialect/Tosa/inlining.mlir (+2-1)
  • (modified) mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir (+61-61)
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
index 1412c7a2615d20..0542319c96f889 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
@@ -219,6 +219,7 @@ class Tosa_ElementwiseOp<string mnemonic, list<Trait> traits = []> :
               DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
                                         ["inferReturnTypeComponents"]>,
               ResultsBroadcastableShape,
+              SameOperandsAndResultRank,
               Pure])> {
   let assemblyFormat =
       "operands attr-dict `:` functional-type(operands, results)";
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp
index b1d5720541846f..816c8aebc4a644 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp
@@ -285,6 +285,32 @@ void propagateShapesInRegion(Region &region, TypeModificationState &state) {
   }
 }
 
+/// recursively validate tosa ops with SameOperandsAndResultRank trait in region
+/// and all nested regions
+void validateSameOperandsAndResultRankTrait(Region &region) {
+  int errs = 0;
+  for (auto &block : region) {
+    for (auto &op : block) {
+      if (!op.getDialect() ||
+          op.getDialect()->getNamespace() != TosaDialect::getDialectNamespace())
+        continue;
+      if (op.hasTrait<OpTrait::SameOperandsAndResultRank>()) {
+        if (OpTrait::impl::verifySameOperandsAndResultRank(&op).failed()) {
+          errs++;
+        }
+      }
+      WhileOp whileOp = dyn_cast<WhileOp>(op);
+      IfOp ifOp = dyn_cast<IfOp>(op);
+      if (whileOp || ifOp) {
+        // recurse into whileOp's regions
+        for (auto &next : op.getRegions()) {
+          validateSameOperandsAndResultRankTrait(next);
+        }
+      }
+    }
+  }
+}
+
 /// Pass that performs shape propagation across TOSA operations. This includes
 /// migrating to within the regions of if/while operations.
 struct TosaInferShapes
@@ -295,6 +321,8 @@ struct TosaInferShapes
     TypeModificationState state;
     propagateShapesInRegion(func.getBody(), state);
     state.commit();
+
+    validateSameOperandsAndResultRankTrait(func.getBody());
   }
 };
 } // namespace
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index 0e35f8ea9d0cd1..23b35f11aa02f7 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -94,6 +94,7 @@ func.func @test_add_0d(%arg0: tensor<f32>, %arg1: tensor<f32>) -> tensor<f32> {
   // CHECK: } -> tensor<f32>
   %0 = tosa.add %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
 
+
   // CHECK: return [[RESULT]] : tensor<f32>
   return %0 : tensor<f32>
 }
@@ -341,23 +342,9 @@ func.func @test_add_2d_all_dynamic(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32
 
 // -----
 
-// CHECK: #[[$MAP0:.+]] = affine_map<(d0, d1, d2) -> (0, d1, d2)>
-// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
-// CHECK-LABEL: @test_add_2d_different_ranks
-// CHECK-SAME: %[[ARG0:[0-9a-zA-Z_]*]]:
-// CHECK-SAME: %[[ARG1:[0-9a-zA-Z_]*]]:
 func.func @test_add_2d_different_ranks(%arg0: tensor<3x4xf32>, %arg1: tensor<2x3x4xf32>) -> tensor<2x3x4xf32> {
-
-  // CHECK: %[[ARG0_EXPANDED:.*]] = tensor.expand_shape %[[ARG0]] {{\[\[}}0, 1], [2]] output_shape [1, 3, 4] : tensor<3x4xf32> into tensor<1x3x4xf32>
-  // CHECK: %[[VAL_0:.*]] = tensor.empty() : tensor<2x3x4xf32>
-  // CHECK: %[[RESULT:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel"]} ins(%[[ARG0_EXPANDED]], %[[ARG1]] : tensor<1x3x4xf32>, tensor<2x3x4xf32>) outs(%[[VAL_0]] : tensor<2x3x4xf32>) {
-  // CHECK: ^bb0(%[[VAL_1:.*]]: f32, %[[VAL_2:.*]]: f32, %[[VAL_3:.*]]: f32):
-  // CHECK:   %[[VAL_4:.*]] = arith.addf %[[VAL_1]], %[[VAL_2]] : f32
-  // CHECK:   linalg.yield %[[VAL_4]] : f32
-  // CHECK: } -> tensor<2x3x4xf32>
-  %0 = tosa.add %arg0, %arg1 : (tensor<3x4xf32>, tensor<2x3x4xf32>) -> tensor<2x3x4xf32>
-
-  // CHECK: return %[[RESULT]] : tensor<2x3x4xf32>
+  // expected-error@+1 {{'tosa.add' op operands don't have matching ranks}}
+  %0 = "tosa.add"(%arg0, %arg1) : (tensor<3x4xf32>, tensor<2x3x4xf32>) -> tensor<2x3x4xf32>
   return %0 : tensor<2x3x4xf32>
 }
 
diff --git a/mlir/test/Dialect/Tosa/broadcast.mlir b/mlir/test/Dialect/Tosa/broadcast.mlir
deleted file mode 100644
index 7613aa3b8dd03d..00000000000000
--- a/mlir/test/Dialect/Tosa/broadcast.mlir
+++ /dev/null
@@ -1,285 +0,0 @@
-// RUN: mlir-opt --tosa-make-broadcastable %s | FileCheck %s
-
-// -----
-// CHECK-LABEL: broadcast0
-func.func @test_broadcast0(%arg0: tensor<1xf32>, %arg1: tensor<1xf32>) -> tensor<1xf32> {
-  //  CHECK-NOT: reshape
-  %0 = tosa.add %arg0, %arg1 : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
-  return %0 : tensor<1xf32>
-}
-
-// -----
-// CHECK-LABEL: broadcast1
-func.func @test_broadcast1(%arg0: tensor<1xf32>, %arg1: tensor<2x1xf32>) -> tensor<2x1xf32> {
-  // CHECK-DAG: %[[VAR0:.*]] = tosa.reshape %arg0 {new_shape = array<i64: 1, 1>}
-  // CHECK: %[[VAR1:.*]] = tosa.add %[[VAR0]], %arg1
-  %0 = tosa.add %arg0, %arg1 : (tensor<1xf32>, tensor<2x1xf32>) -> tensor<2x1xf32>
-  return %0 : tensor<2x1xf32>
-}
-
-// -----
-// CHECK-LABEL: broadcast2
-func.func @test_broadcast2(%arg0: tensor<2x1xf32>, %arg1: tensor<1xf32>) -> tensor<2x1xf32> {
-  // CHECK-DAG: %[[VAR0:.*]] = tosa.reshape %arg1 {new_shape = array<i64: 1, 1>}
-  // CHECK: %[[VAR1:.*]] = tosa.add %arg0, %[[VAR0]]
-  %0 = tosa.add %arg0, %arg1 : (tensor<2x1xf32>, tensor<1xf32>) -> tensor<2x1xf32>
-  return %0 : tensor<2x1xf32>
-}
-
-// -----
-// CHECK-LABEL: broadcast3
-func.func @test_broadcast3(%arg0: tensor<2x1x1x1xf32>, %arg1: tensor<1xf32>) -> tensor<2x1x1x1xf32> {
-  // CHECK-DAG: %[[VAR0:.*]] = tosa.reshape %arg1 {new_shape = array<i64: 1, 1, 1, 1>}
-  // CHECK: %[[VAR1:.*]] = tosa.add %arg0, %[[VAR0]]
-  %0 = tosa.add %arg0, %arg1 : (tensor<2x1x1x1xf32>, tensor<1xf32>) -> tensor<2x1x1x1xf32>
-  return %0 : tensor<2x1x1x1xf32>
-}
-
-// -----
-// CHECK-LABEL: broadcast4
-func.func @test_broadcast4(%arg0: tensor<1x1x1x2xf32>, %arg1: tensor<1xf32>) -> tensor<1x1x1x2xf32> {
-  // CHECK-DAG: %[[VAR0:.*]] = tosa.reshape %arg1 {new_shape = array<i64: 1, 1, 1, 1>}
-  // CHECK: %[[VAR1:.*]] = tosa.add %arg0, %[[VAR0]]
-  %0 = tosa.add %arg0, %arg1 : (tensor<1x1x1x2xf32>, tensor<1xf32>) -> tensor<1x1x1x2xf32>
-  return %0 : tensor<1x1x1x2xf32>
-}
-
-// -----
-// CHECK-LABEL: broadcast5
-func.func @test_broadcast5(%arg0: tensor<1x1x2x1xf32>, %arg1: tensor<1xf32>) -> tensor<1x1x2x1xf32> {
-  // CHECK-DAG: %[[VAR0:.*]] = tosa.reshape %arg1 {new_shape = array<i64: 1, 1, 1, 1>}
-  // CHECK: %[[VAR1:.*]] = tosa.add %arg0, %[[VAR0]]
-  %0 = tosa.add %arg0, %arg1 : (tensor<1x1x2x1xf32>, tensor<1xf32>) -> tensor<1x1x2x1xf32>
-  return %0 : tensor<1x1x2x1xf32>
-}
-
-// -----
-// CHECK-LABEL: broadcast6
-func.func @test_broadcast6(%arg0: tensor<17x16x15x14xf32>, %arg1: tensor<1xf32>) -> tensor<17x16x15x14xf32> {
-  // CHECK-DAG: %[[VAR0:.*]] = tosa.reshape %arg1 {new_shape = array<i64: 1, 1, 1, 1>}
-  // CHECK: %[[VAR1:.*]] = tosa.add %arg0, %[[VAR0]]
-  %0 = tosa.add %arg0, %arg1 : (tensor<17x16x15x14xf32>, tensor<1xf32>) -> tensor<17x16x15x14xf32>
-  return %0 : tensor<17x16x15x14xf32>
-}
-
-// -----
-// CHECK-LABEL: broadcast7
-func.func @test_broadcast7(%arg0: tensor<17x16x1x14xf32>, %arg1: tensor<1x1xf32>) -> tensor<17x16x1x14xf32> {
-  // CHECK-DAG: %[[VAR0:.*]] = tosa.reshape %arg1 {new_shape = array<i64: 1, 1, 1, 1>}
-  // CHECK: %[[VAR1:.*]] = tosa.add %arg0, %[[VAR0]]
-  %0 = tosa.add %arg0, %arg1 : (tensor<17x16x1x14xf32>, tensor<1x1xf32>) -> tensor<17x16x1x14xf32>
-  return %0 : tensor<17x16x1x14xf32>
-}
-
-// -----
-// CHECK-LABEL: broadcast8
-func.func @test_broadcast8(%arg0: tensor<17x16x15x14xf32>, %arg1: tensor<1x1xf32>) -> tensor<17x16x15x14xf32> {
-  // CHECK-DAG: %[[VAR0:.*]] = tosa.reshape %arg1 {new_shape = array<i64: 1, 1, 1, 1>}
-  // CHECK: %[[VAR1:.*]] = tosa.add %arg0, %[[VAR0]]
-  %0 = tosa.add %arg0, %arg1 : (tensor<17x16x15x14xf32>, tensor<1x1xf32>) -> tensor<17x16x15x14xf32>
-  return %0 : tensor<17x16x15x14xf32>
-}
-
-// -----
-// CHECK-LABEL: broadcast9
-func.func @test_broadcast9(%arg0: tensor<17x16x15x14xf32>, %arg1: tensor<15x1xf32>) -> tensor<17x16x15x14xf32> {
-  // CHECK-DAG: %[[VAR0:.*]] = tosa.reshape %arg1 {new_shape = array<i64: 1, 1, 15, 1>}
-  // CHECK: %[[VAR1:.*]] = tosa.add %arg0, %[[VAR0]]
-  %0 = tosa.add %arg0, %arg1 : (tensor<17x16x15x14xf32>, tensor<15x1xf32>) -> tensor<17x16x15x14xf32>
-  return %0 : tensor<17x16x15x14xf32>
-}
-
-// -----
-// CHECK-LABEL: broadcast10
-func.func @test_broadcast10(%arg0: tensor<17x16x15x14xf32>, %arg1: tensor<15x14xf32>) -> tensor<17x16x15x14xf32> {
-  // CHECK-DAG: %[[VAR0:.*]] = tosa.reshape %arg1 {new_shape = array<i64: 1, 1, 15, 14>}
-  // CHECK: %[[VAR1:.*]] = tosa.add %arg0, %[[VAR0]]
-  %0 = tosa.add %arg0, %arg1 : (tensor<17x16x15x14xf32>, tensor<15x14xf32>) -> tensor<17x16x15x14xf32>
-  return %0 : tensor<17x16x15x14xf32>
-}
-
-// -----
-// CHECK-LABEL: broadcast13
-func.func @test_broadcast13(%arg0: tensor<1xf32>, %arg1: tensor<17x16x15x14xf32>) -> tensor<17x16x15x14xf32> {
-  // CHECK-DAG: %[[VAR0:.*]] = tosa.reshape %arg0 {new_shape = array<i64: 1, 1, 1, 1>}
-  // CHECK: %[[VAR1:.*]] = tosa.add %[[VAR0]], %arg1
-  %0 = tosa.add %arg0, %arg1 : (tensor<1xf32>, tensor<17x16x15x14xf32>) -> tensor<17x16x15x14xf32>
-  return %0 : tensor<17x16x15x14xf32>
-}
-
-// -----
-// CHECK-LABEL: broadcast14
-func.func @test_broadcast14(%arg0: tensor<1x1xf32>, %arg1: tensor<17x16x1x14xf32>) -> tensor<17x16x1x14xf32> {
-  // CHECK-DAG: %[[VAR0:.*]] = tosa.reshape %arg0 {new_shape = array<i64: 1, 1, 1, 1>}
-  // CHECK: %[[VAR1:.*]] = tosa.add %[[VAR0]], %arg1
-  %0 = tosa.add %arg0, %arg1 : (tensor<1x1xf32>, tensor<17x16x1x14xf32>) -> tensor<17x16x1x14xf32>
-  return %0 : tensor<17x16x1x14xf32>
-}
-
-// -----
-// CHECK-LABEL: broadcast15
-func.func @test_broadcast15(%arg0: tensor<1x1xf32>, %arg1: tensor<17x16x15x14xf32>) -> tensor<17x16x15x14xf32> {
-  // CHECK-DAG: %[[VAR0:.*]] = tosa.reshape %arg0 {new_shape = array<i64: 1, 1, 1, 1>}
-  // CHECK: %[[VAR1:.*]] = tosa.add %[[VAR0]], %arg1
-  %0 = tosa.add %arg0, %arg1 : (tensor<1x1xf32>, tensor<17x16x15x14xf32>) -> tensor<17x16x15x14xf32>
-  return %0 : tensor<17x16x15x14xf32>
-}
-
-// -----
-// CHECK-LABEL: broadcast16
-func.func @test_broadcast16(%arg0: tensor<15x1xf32>, %arg1: tensor<17x16x15x14xf32>) -> tensor<17x16x15x14xf32> {
-  // CHECK-DAG: %[[VAR0:.*]] = tosa.reshape %arg0 {new_shape = array<i64: 1, 1, 15, 1>}
-  // CHECK: %[[VAR1:.*]] = tosa.add %[[VAR0]], %arg1
-  %0 = tosa.add %arg0, %arg1 : (tensor<15x1xf32>, tensor<17x16x15x14xf32>) -> tensor<17x16x15x14xf32>
-  return %0 : tensor<17x16x15x14xf32>
-}
-
-// -----
-// CHECK-LABEL: broadcast17
-func.func @test_broadcast17(%arg0: tensor<15x14xf32>, %arg1: tensor<17x16x15x14xf32>) -> tensor<17x16x15x14xf32> {
-  // CHECK-DAG: %[[VAR0:.*]] = tosa.reshape %arg0 {new_shape = array<i64: 1, 1, 15, 14>}
-  // CHECK: %[[VAR1:.*]] = tosa.add %[[VAR0]], %arg1
-  %0 = tosa.add %arg0, %arg1 : (tensor<15x14xf32>, tensor<17x16x15x14xf32>) -> tensor<17x16x15x14xf32>
-  return %0 : tensor<17x16x15x14xf32>
-}
-
-// -----
-// CHECK-LABEL: broadcast18
-func.func @test_broadcast18(%arg0: tensor<14x1xf32>, %arg1: tensor<1x15xf32>) -> tensor<14x15xf32> {
-  // CHECK: %[[VAR1:.*]] = tosa.add %arg0, %arg1
-  %0 = tosa.add %arg0, %arg1 : (tensor<14x1xf32>, tensor<1x15xf32>) -> tensor<14x15xf32>
-  return %0 : tensor<14x15xf32>
-}
-
-// -----
-// CHECK-LABEL: broadcast19
-func.func @test_broadcast19(%arg0: tensor<64x64x1xf32>, %arg1: tensor<1x17xf32>) -> (tensor<64x64x17xf32> ) {
-  // CHECK-DAG: %[[VAR0:.*]] = tosa.reshape %arg1 {new_shape = array<i64: 1, 1, 17>}
-  // CHECK: %[[VAR1:.*]] = tosa.sub %arg0, %[[VAR0]]
-  %0 = tosa.sub %arg0, %arg1 : (tensor<64x64x1xf32>, tensor<1x17xf32>) -> tensor<64x64x17xf32>
-  return %0 : tensor<64x64x17xf32>
-}
-
-// -----
-// CHECK-LABEL: broadcast20
-func.func @test_broadcast20(%arg0: tensor<3x3x4x1xf32>, %arg1: tensor<4x5xf32>) -> (tensor<3x3x4x5xf32> ) {
-  // CHECK-DAG: %[[VAR0:.*]] = tosa.reshape %arg1 {new_shape = array<i64: 1, 1, 4, 5>}
-  // CHECK: %[[VAR1:.*]] = tosa.add %arg0, %[[VAR0]]
-  %0 = tosa.add %arg0, %arg1 : (tensor<3x3x4x1xf32>, tensor<4x5xf32>) -> tensor<3x3x4x5xf32>
-  return %0 : tensor<3x3x4x5xf32>
-}
-
-// -----
-// CHECK-LABEL: broadcast_mul
-func.func @test_broadcast_mul(%arg0: tensor<15x14xi32>, %arg1: tensor<17x16x15x14xi32>) -> tensor<17x16x15x14xi32> {
-  // CHECK-DAG: %[[VAR0:.*]] = tosa.reshape %arg0 {new_shape = array<i64: 1, 1, 15, 14>}
-  // CHECK: %[[VAR1:.*]] = tosa.mul %[[VAR0]], %arg1
-  %0 = tosa.mul %arg0, %arg1 {shift = 1 : i8 } : (tensor<15x14xi32>, tensor<17x16x15x14xi32>) -> tensor<17x16x15x14xi32>
-  return %0 : tensor<17x16x15x14xi32>
-}
-
-// -----
-// CHECK-LABEL: broadcast_arithmetic_right_shift
-func.func @test_broadcast_arithmetic_right_shift(%arg0: tensor<15x14xi32>, %arg1: tensor<17x16x15x14xi32>) -> tensor<17x16x15x14xi32> {
-  // CHECK-DAG: %[[VAR0:.*]] = tosa.reshape %arg0 {new_shape = array<i64: 1, 1, 15, 14>}
-  // CHECK: %[[VAR1:.*]] = tosa.arithmetic_right_shift %[[VAR0]], %arg1
-  %0 = tosa.arithmetic_right_shift %arg0, %arg1 { round = true } : (tensor<15x14xi32>, tensor<17x16x15x14xi32>) -> tensor<17x16x15x14xi32>
-  return %0 : tensor<17x16x15x14xi32>
-}
-
-// -----
-// CHECK-LABEL: broadcast_scalar
-func.func @test_broadcast_scalar(%arg0: tensor<i32>, %arg1: tensor<17x16x15x14xi32>) -> tensor<17x16x15x14xi32> {
-  // CHECK-DAG: %[[VAR0:.*]] = tosa.reshape %arg0 {new_shape = array<i64: 1, 1, 1, 1>}
-  // CHECK: %[[VAR1:.*]] = tosa.add %[[VAR0]], %arg1
-  %0 = tosa.add %arg0, %arg1 : (tensor<i32>, tensor<17x16x15x14xi32>) -> tensor<17x16x15x14xi32>
-  return %0 : tensor<17x16x15x14xi32>
-}
-
-// -----
-// CHECK-LABEL: broadcast_select_both_input
-func.func @test_broadcast_select_both_input(%arg0: tensor<1x16x16xi1>, %arg1: tensor<f32>, %arg2: tensor<f32>) -> tensor<1x16x16xf32> {
-  // CHECK-DAG: %[[VAL_0:.*]] = tosa.reshape %arg1 {new_shape = array<i64: 1, 1, 1>}
-  // CHECK-DAG: %[[VAL_1:.*]] = tosa.reshape %arg2 {new_shape = array<i64: 1, 1, 1>}
-  // CHECK: %[[VAL_2:.*]] = tosa.select %arg0, %[[VAL_0]], %[[VAL_1]]
-  %0 = tosa.select %arg0, %arg1, %arg2 : (tensor<1x16x16xi1>, tensor<f32>, tensor<f32>) -> tensor<1x16x16xf32>
-  return %0 : tensor<1x16x16xf32>
-}
-
-// -----
-// CHECK-LABEL: broadcast_select_one_input
-func.func @test_broadcast_select_one_input(%arg0: tensor<17x16x15x14xi1>, %arg1: tensor<17x16x15x14xf32>, %arg2: tensor<f32>) -> tensor<17x16x15x14xf32> {
-  // CHECK-DAG: %[[VAL_0:.*]] = tosa.reshape %arg2 {new_shape = array<i64: 1, 1, 1, 1>}
-  // CHECK: %[[VAL_1:.*]] = tosa.select %arg0, %arg1, %[[VAL_0]]
-  %0 = tosa.select %arg0, %arg1, %arg2 : (tensor<17x16x15x14xi1>, tensor<17x16x15x14xf32>, tensor<f32>) -> tensor<17x16x15x14xf32>
-  return %0 : tensor<17x16x15x14xf32>
-}
-
-// -----
-// CHECK-LABEL: broadcast_select_predicate
-func.func @test_broadcast_select_predicate(%arg0: tensor<i1>, %arg1: tensor<1x32x32x8xf32>, %arg2: tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32> {
-  // CHECK-DAG: %[[VAL_0:.*]] = tosa.reshape %arg0 {new_shape = array<i64: 1, 1, 1, 1>}
-  // CHECK: %[[VAL_1:.*]] = tosa.select %[[VAL_0]], %arg1, %arg2
-  %0 = tosa.select %arg0, %arg1, %arg2 : (tensor<i1>, tensor<1x32x32x8xf32>, tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32>
-  return %0 : tensor<1x32x32x8xf32>
-}
-
-// -----
-// CHECK-LABEL: broadcast_select_abc
-func.func @test_broadcast_select_abc(%arg0: tensor<i1>, %arg1: tensor<32x8xf32>, %arg2: tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32> {
-  // CHECK-DAG: %[[VAL_0:.*]] = tosa.reshape %arg0 {new_shape = array<i64: 1, 1, 1, 1>}
-  // CHECK-DAG: %[[VAL_1:.*]] = tosa.reshape %arg1 {new_shape = array<i64: 1, 1, 32, 8>}
-  // CHECK: %[[VAL_2:.*]] = tosa.select %[[VAL_0]], %[[VAL_1]], %arg2
-  %0 = tosa.select %arg0, %arg1, %arg2 : (tensor<i1>, tensor<32x8xf32>, tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32>
-  return %0 : tensor<1x32x32x8xf32>
-}
-
-// -----
-// CHECK-LABEL: broadcast_select_acb
-func.func @test_broadcast_select_acb(%arg0: tensor<i1>, %arg1: tensor<1x32x32x8xf32>, %arg2: tensor<32x8xf32>) -> tensor<1x32x32x8xf32> {
-  // CHECK-DAG: %[[VAL_0:.*]] = tosa.reshape %arg0 {new_shape = array<i64: 1, 1, 1, 1>}
-  // CHECK-DAG: %[[VAL_1:.*]] = tosa.reshape %arg2 {new_shape = array<i64: 1, 1, 32, 8>}
-  // CHECK: %[[VAL_2:.*]] = tosa.select %[[VAL_0]], %arg1, %[[VAL_1]]
-  %0 = tosa.select %arg0, %arg1, %arg2 : (tensor<i1>, tensor<1x32x32x8xf32>, tensor<32x8xf32>) -> tensor<1x32x32x8xf32>
-  return %0 : tensor<1x32x32x8xf32>
-}
-
-// -----
-// CHECK-LABEL: broadcast_select_bac
-func.func @test_broadcast_select_bac(%arg0: tensor<32x8xi1>, %arg1: tensor<f32>, %arg2: tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32> {
-  // CHECK-DAG: %[[VAL_0:.*]] = tosa.reshape %arg0 {new_shape = array<i64: 1, 1, 32, 8>}
-  // CHECK-DAG: %[[VAL_1:.*]] = tosa.reshape %arg1 {new_shape = array<i64: 1, 1, 1, 1>}
-  // CHECK: %[[VAL_2:.*]] = tosa.select %[[VAL_0]], %[[VAL_1]], %arg2
-  %0 = tosa.select %arg0, %arg1, %arg2 : (tensor<32x8xi1>, tensor<f32>, tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32>
-  return %0 : tensor<1x32x32x8xf32>
-}
-
-// -----
-// CHECK-LABEL: broadcast_select_bca
-func.func @test_broadcast_select_bca(%arg0: tensor<32x8xi1>, %arg1: tensor<1x32x32x8xf32>, %arg2: tensor<i1>) -> tensor<1x32x32x8xf32> {
-  // CHECK-DAG: %[[VAL_0:.*]] = tosa.reshape %arg0 {new_shape = array<i64: 1, 1, 32, 8>}
-  // CHECK-DAG: %[[VAL_1:.*]] = tosa.reshape %arg2 {new_shape = array<i64: 1, 1, 1, 1>}
-  // CHECK: %[[VAL_2:.*]] = tosa.select %[[VAL_0]], %arg1, %[[VAL_1]]
-  %0 = tosa.select %arg0, %arg1, %arg2 : (tensor<32x8xi1>, tensor<1x32x32x8xf32>, tensor<i1>) -> tensor<1x32x32x8xf32>
-  return %0 : tensor<1x32x32x8xf32>
-}
-
-// -----
-// CHECK-LABEL: broadcast_select_cab
-func.func @test_broadcast_select_cab(%arg0: tensor<1x32x32x8xi1>, %arg1: tensor<f32>, %arg2: tensor<32x8xf32>) -> tensor<1x32x32x8xf32> {
-  // CHECK-DAG: %[[VAL_0:.*]] = tosa.reshape %arg1 {new_shape = array<i64: 1, 1, 1, 1>}
-  // CHECK-DAG: %[[VAL_1:.*]] = tosa.reshape %arg2 {new_shape = array<i64: 1, 1, 32, 8>}
-  // CHECK: %[[VAL_2:.*]] = tosa.select %arg0, %[[VAL_0]], %[[VAL_1]]
-  %0 = tosa.select %arg0, %arg1, %arg2 : (tensor<1x32x32x8xi1>, tensor<f32>, tensor<32x8xf32>) -> tensor<1x32x32x8xf32>
-  return %0 : tensor<1x32x32x8xf32>
-}
-
-// -----
-// CHECK-LABEL: broadcast_select_cba
-func.func @test_broadcast_select_cba(%arg0: tensor<1x32x32x8xi1>, %arg1: tensor<32x8xf32>, %arg2: tensor<i1>) -> tensor<1x32x32x8xf32> {
-  // CHECK-DAG: %[[VAL_0:.*]] = tosa.reshape %arg1 {new_shape = array<i64: 1, 1, 32, 8>}
-  // CHECK-DAG: %[[VAL_1:.*]] = tosa.reshape %arg2 {new_shape = array<i64: 1, 1, 1, 1>}
-  // CHECK: %[[VAL_2:.*]] = tosa.select %arg0, %[[VAL_0]], %[[VAL_1]]
-  %0 = tosa.select %arg0, %arg1, %arg2 : (tensor<1x32x32x8xi1>, tensor<32x8xf32>, tensor<i1>) -> tensor<1x32x32x8xf32>
-  return %0 : tensor<1x32x32x8xf32>
-}
diff --git a/mlir/test/Dialect/Tosa/constant_folding.mlir b/mlir/test/Dialect/Tosa/constant_folding.mlir
index 869bce08a8c729..3ff3121348fcad 100644
--- a/mlir/test/Dialect/Tosa/constant_folding.mlir
+++ b/mlir/test/Dialect/Tosa/constant_folding.mlir
@@ -15,9 +15,9 @@ func.func @test_const_i64(%arg0 : index) -> tensor<4xi64> {
 }
 
 // CHECK-LABEL: func @try_fold_equal_with_unranked_tensor
-func.func @try_fold_equal_with_unranked_tensor(%arg0: tensor<4xi32>, %arg1: tensor<i32>) {
+func.func @try_fold_equal_with_unranked_tensor(%arg0: tensor<4xi32>, %arg1: tensor<1xi32>) {
   // CHECK: tosa.equal
   // CHECK-NEXT: return
-  %0 = tosa.equal %arg0, %arg1 : (tensor<4xi32>, tensor<i32>) -> tensor<*xi1>
+  %0 = tosa.equal %arg0, %arg1 : (tensor<4xi32>, tensor<1xi32>) -> tensor<*xi1>
   return
 }
diff --git a/mlir/test/Dialect/Tosa/inlining.mlir b/mlir/test/Dialect/Tosa/inlining.mlir
index d57b5cbcf475c8..e892fdaa277505 100644
--- a/mlir/test/Dialect/Tosa/inlining.mlir
+++ b/mlir/test/Dialect/Tosa/inlining.mlir
@@ -47,7 +47,8 @@ func.func @inlined_while_fn(%arg0: tensor<i32>, %arg1: tensor<i32>, %arg2: tenso
 }
 func.func pri...
[truncated]

@@ -1,285 +0,0 @@
// RUN: mlir-opt --tosa-make-broadcastable %s | FileCheck %s
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is the TosaMakeBroadcastable pass still needed?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the pass will be an no-op because all operands will have same ranks already,
pass has not be removed in case it is still used in some flows

@@ -303,6 +303,32 @@ void propagateShapesInRegion(Region &region, TypeModificationState &state) {
}
}

/// recursively validate tosa ops with SameOperandsAndResultRank trait in region
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: Start comment with capital

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fixed

%0 = tosa.add %arg0, %arg1 : (tensor<3x4xf32>, tensor<2x3x4xf32>) -> tensor<2x3x4xf32>

// CHECK: return %[[RESULT]] : tensor<2x3x4xf32>
// expected-error@+1 {{'tosa.add' op operands don't have matching ranks}}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Probably need to move to tosa-to-linalg-invalid ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

moved to tosa-to-linalg-invalid

This patch adds SameOperandsAndResultRank trait to TOSA operators
with ResultsBroadcastableShape trait. SameOperandsAndResultRank trait
requiring that all operands and results have matching ranks unless
the operand/result is unranked.

This also renders the TosaMakeBroadcastable pass unnecessary - but
this pass is left in for now just in case it is still used in some
flows. The lit test, broadcast.mlir, is removed.

This also adds verify of the SameOperandsAndResultRank trait in the
TosaInferShapes pass to validate inferred shapes.

Signed-off-by: Tai Ly <[email protected]>
Change-Id: I27bf16b31f15aa92d42ad5376b8791cf74e4f6ac
@GeorgeARM GeorgeARM merged commit 729f958 into llvm:main Jan 22, 2025
8 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants