@@ -94,6 +94,7 @@ func.func @test_add_0d(%arg0: tensor<f32>, %arg1: tensor<f32>) -> tensor<f32> {
94
94
// CHECK: } -> tensor<f32>
95
95
%0 = tosa.add %arg0 , %arg1 : (tensor <f32 >, tensor <f32 >) -> tensor <f32 >
96
96
97
+
97
98
// CHECK: return [[RESULT]] : tensor<f32>
98
99
return %0 : tensor <f32 >
99
100
}
@@ -103,20 +104,20 @@ func.func @test_add_0d(%arg0: tensor<f32>, %arg1: tensor<f32>) -> tensor<f32> {
103
104
// CHECK: #[[$MAP0:.+]] = affine_map<(d0, d1) -> (d0, d1)>
104
105
// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1) -> (0, d1)>
105
106
106
- // CHECK-LABEL: func.func @test_add_0d_broadcast (
107
+ // CHECK-LABEL: func.func @test_add_2d_broadcast (
107
108
// CHECK-SAME: %[[ARG0:.*]]: tensor<2x1xf32>,
108
- // CHECK-SAME: %[[ARG1:.*]]: tensor<f32>) -> tensor<2x1xf32> {
109
- // CHECK: %[[EXPANDED:.*]] = tensor.expand_shape %[[ARG1]] [] output_shape [1, 1] : tensor<f32> into tensor<1x1xf32>
109
+ // CHECK-SAME: %[[ARG1:.*]]: tensor<1x1xf32>) -> tensor<2x1xf32> {
110
110
// CHECK: %[[EMPTY_TENSOR:.*]] = tensor.empty() : tensor<2x1xf32>
111
- // CHECK: %[[RESULT:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP0]]], iterator_types = ["parallel", "parallel"]} ins(%[[ARG0]], %[[EXPANDED ]] : tensor<2x1xf32>, tensor<1x1xf32>) outs(%[[EMPTY_TENSOR]] : tensor<2x1xf32>) {
111
+ // CHECK: %[[RESULT:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP0]]], iterator_types = ["parallel", "parallel"]} ins(%[[ARG0]], %[[ARG1]] : tensor<2x1xf32>, tensor<1x1xf32>) outs(%[[EMPTY_TENSOR]] : tensor<2x1xf32>) {
112
112
// CHECK: ^bb0(%[[IN0:.*]]: f32, %[[IN1:.*]]: f32, %[[OUT:.*]]: f32):
113
113
// CHECK: %[[ADD:.*]] = arith.addf %[[IN0]], %[[IN1]] : f32
114
114
// CHECK: linalg.yield %[[ADD]] : f32
115
115
// CHECK: } -> tensor<2x1xf32>
116
116
// CHECK: return %[[RESULT]] : tensor<2x1xf32>
117
117
// CHECK: }
118
- func.func @test_add_0d_broadcast (%arg0: tensor <2 x1 xf32 >, %arg1: tensor <f32 >) -> tensor <2 x1 xf32 > {
119
- %0 = tosa.add %arg0 , %arg1 : (tensor <2 x1 xf32 >, tensor <f32 >) -> tensor <2 x1 xf32 >
118
+ func.func @test_add_2d_broadcast (%arg0: tensor <2 x1 xf32 >, %arg1: tensor <1 x1 xf32 >) -> tensor <2 x1 xf32 > {
119
+ // TOSA elementwise operators now require all operands to have equal rank.
120
+ %0 = tosa.add %arg0 , %arg1 : (tensor <2 x1 xf32 >, tensor <1 x1 xf32 >) -> tensor <2 x1 xf32 >
120
121
return %0 : tensor <2 x1 xf32 >
121
122
}
122
123
@@ -383,23 +384,9 @@ func.func @test_add_2d_all_dynamic(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32
383
384
384
385
// -----
385
386
386
- // CHECK: #[[$MAP0:.+]] = affine_map<(d0, d1, d2) -> (0, d1, d2)>
387
- // CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
388
- // CHECK-LABEL: @test_add_2d_different_ranks
389
- // CHECK-SAME: %[[ARG0:[0-9a-zA-Z_]*]]:
390
- // CHECK-SAME: %[[ARG1:[0-9a-zA-Z_]*]]:
391
387
func.func @test_add_2d_different_ranks (%arg0: tensor <3 x4 xf32 >, %arg1: tensor <2 x3 x4 xf32 >) -> tensor <2 x3 x4 xf32 > {
392
-
393
- // CHECK: %[[ARG0_EXPANDED:.*]] = tensor.expand_shape %[[ARG0]] {{\[\[}}0, 1], [2]] output_shape [1, 3, 4] : tensor<3x4xf32> into tensor<1x3x4xf32>
394
- // CHECK: %[[VAL_0:.*]] = tensor.empty() : tensor<2x3x4xf32>
395
- // CHECK: %[[RESULT:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel"]} ins(%[[ARG0_EXPANDED]], %[[ARG1]] : tensor<1x3x4xf32>, tensor<2x3x4xf32>) outs(%[[VAL_0]] : tensor<2x3x4xf32>) {
396
- // CHECK: ^bb0(%[[VAL_1:.*]]: f32, %[[VAL_2:.*]]: f32, %[[VAL_3:.*]]: f32):
397
- // CHECK: %[[VAL_4:.*]] = arith.addf %[[VAL_1]], %[[VAL_2]] : f32
398
- // CHECK: linalg.yield %[[VAL_4]] : f32
399
- // CHECK: } -> tensor<2x3x4xf32>
400
- %0 = tosa.add %arg0 , %arg1 : (tensor <3 x4 xf32 >, tensor <2 x3 x4 xf32 >) -> tensor <2 x3 x4 xf32 >
401
-
402
- // CHECK: return %[[RESULT]] : tensor<2x3x4xf32>
388
+ // expected-error@+1 {{'tosa.add' op operands don't have matching ranks}}
389
+ %0 = "tosa.add"(%arg0, %arg1) : (tensor<3x4xf32>, tensor<2x3x4xf32>) -> tensor<2x3x4xf32>
403
390
return %0 : tensor <2 x3 x4 xf32 >
404
391
}
405
392
0 commit comments