diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td index 47cda3c9f481e..4975530a9588c 100644 --- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td +++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td @@ -231,7 +231,7 @@ def TosaResolvableShapeOperands : NativeOpTrait<"TosaResolvableShapeOperands"> { //===----------------------------------------------------------------------===// class Tosa_Op traits = []> : - Op { } @@ -241,6 +241,7 @@ class Tosa_ElementwiseOp traits = []> : ["inferReturnTypeComponents"]>, ResultsBroadcastableShape, TosaElementwiseOperator, + SameOperandsAndResultRank, Pure])> { let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)"; diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp index d08d5fea66310..3c1ca6aac9096 100644 --- a/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp @@ -303,6 +303,32 @@ void propagateShapesInRegion(Region ®ion, TypeModificationState &state) { } } +/// Recursively validate tosa ops with SameOperandsAndResultRank trait in region +/// and all nested regions +void validateSameOperandsAndResultRankTrait(Region ®ion) { + int errs = 0; + for (auto &block : region) { + for (auto &op : block) { + if (!op.getDialect() || + op.getDialect()->getNamespace() != TosaDialect::getDialectNamespace()) + continue; + if (op.hasTrait()) { + if (OpTrait::impl::verifySameOperandsAndResultRank(&op).failed()) { + errs++; + } + } + WhileOp whileOp = dyn_cast(op); + IfOp ifOp = dyn_cast(op); + if (whileOp || ifOp) { + // recurse into whileOp's regions + for (auto &next : op.getRegions()) { + validateSameOperandsAndResultRankTrait(next); + } + } + } + } +} + /// Pass that performs shape propagation across TOSA operations. This includes /// migrating to within the regions of if/while operations. struct TosaInferShapes @@ -313,6 +339,8 @@ struct TosaInferShapes TypeModificationState state; propagateShapesInRegion(func.getBody(), state); state.commit(); + + validateSameOperandsAndResultRankTrait(func.getBody()); } }; } // namespace diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir index ea1b79cbd9507..75b48f2b06d89 100644 --- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir +++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir @@ -45,3 +45,11 @@ func.func @rescale_unsupported_type(%arg0: tensor<13x21x3x!quant.uniform, output_zp = -1 : i32, per_channel = false, scale32 = true, shift = array} : (tensor<13x21x3x!quant.uniform>) -> tensor<13x21x3x!quant.uniform> return %0 : tensor<13x21x3x!quant.uniform> } + +// ----- + +func.func @test_add_2d_different_ranks(%arg0: tensor<3x4xf32>, %arg1: tensor<2x3x4xf32>) -> tensor<2x3x4xf32> { + // expected-error@+1 {{'tosa.add' op operands don't have matching ranks}} + %0 = "tosa.add"(%arg0, %arg1) : (tensor<3x4xf32>, tensor<2x3x4xf32>) -> tensor<2x3x4xf32> + return %0 : tensor<2x3x4xf32> +} diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir index e2c95ba7a0c6b..f860dca85c9e9 100644 --- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir +++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir @@ -94,6 +94,7 @@ func.func @test_add_0d(%arg0: tensor, %arg1: tensor) -> tensor { // CHECK: } -> tensor %0 = tosa.add %arg0, %arg1 : (tensor, tensor) -> tensor + // CHECK: return [[RESULT]] : tensor return %0 : tensor } @@ -103,20 +104,20 @@ func.func @test_add_0d(%arg0: tensor, %arg1: tensor) -> tensor { // CHECK: #[[$MAP0:.+]] = affine_map<(d0, d1) -> (d0, d1)> // CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1) -> (0, d1)> -// CHECK-LABEL: func.func @test_add_0d_broadcast( +// CHECK-LABEL: func.func @test_add_2d_broadcast( // CHECK-SAME: %[[ARG0:.*]]: tensor<2x1xf32>, -// CHECK-SAME: %[[ARG1:.*]]: tensor) -> tensor<2x1xf32> { -// CHECK: %[[EXPANDED:.*]] = tensor.expand_shape %[[ARG1]] [] output_shape [1, 1] : tensor into tensor<1x1xf32> +// CHECK-SAME: %[[ARG1:.*]]: tensor<1x1xf32>) -> tensor<2x1xf32> { // CHECK: %[[EMPTY_TENSOR:.*]] = tensor.empty() : tensor<2x1xf32> -// CHECK: %[[RESULT:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP0]]], iterator_types = ["parallel", "parallel"]} ins(%[[ARG0]], %[[EXPANDED]] : tensor<2x1xf32>, tensor<1x1xf32>) outs(%[[EMPTY_TENSOR]] : tensor<2x1xf32>) { +// CHECK: %[[RESULT:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP0]]], iterator_types = ["parallel", "parallel"]} ins(%[[ARG0]], %[[ARG1]] : tensor<2x1xf32>, tensor<1x1xf32>) outs(%[[EMPTY_TENSOR]] : tensor<2x1xf32>) { // CHECK: ^bb0(%[[IN0:.*]]: f32, %[[IN1:.*]]: f32, %[[OUT:.*]]: f32): // CHECK: %[[ADD:.*]] = arith.addf %[[IN0]], %[[IN1]] : f32 // CHECK: linalg.yield %[[ADD]] : f32 // CHECK: } -> tensor<2x1xf32> // CHECK: return %[[RESULT]] : tensor<2x1xf32> // CHECK: } -func.func @test_add_0d_broadcast(%arg0: tensor<2x1xf32>, %arg1: tensor) -> tensor<2x1xf32> { - %0 = tosa.add %arg0, %arg1 : (tensor<2x1xf32>, tensor) -> tensor<2x1xf32> +func.func @test_add_2d_broadcast(%arg0: tensor<2x1xf32>, %arg1: tensor<1x1xf32>) -> tensor<2x1xf32> { + // tosa element-wise operators now require operands of equal ranks + %0 = tosa.add %arg0, %arg1 : (tensor<2x1xf32>, tensor<1x1xf32>) -> tensor<2x1xf32> return %0 : tensor<2x1xf32> } @@ -383,28 +384,6 @@ func.func @test_add_2d_all_dynamic(%arg0: tensor, %arg1: tensor (0, d1, d2)> -// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)> -// CHECK-LABEL: @test_add_2d_different_ranks -// CHECK-SAME: %[[ARG0:[0-9a-zA-Z_]*]]: -// CHECK-SAME: %[[ARG1:[0-9a-zA-Z_]*]]: -func.func @test_add_2d_different_ranks(%arg0: tensor<3x4xf32>, %arg1: tensor<2x3x4xf32>) -> tensor<2x3x4xf32> { - - // CHECK: %[[ARG0_EXPANDED:.*]] = tensor.expand_shape %[[ARG0]] {{\[\[}}0, 1], [2]] output_shape [1, 3, 4] : tensor<3x4xf32> into tensor<1x3x4xf32> - // CHECK: %[[VAL_0:.*]] = tensor.empty() : tensor<2x3x4xf32> - // CHECK: %[[RESULT:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel"]} ins(%[[ARG0_EXPANDED]], %[[ARG1]] : tensor<1x3x4xf32>, tensor<2x3x4xf32>) outs(%[[VAL_0]] : tensor<2x3x4xf32>) { - // CHECK: ^bb0(%[[VAL_1:.*]]: f32, %[[VAL_2:.*]]: f32, %[[VAL_3:.*]]: f32): - // CHECK: %[[VAL_4:.*]] = arith.addf %[[VAL_1]], %[[VAL_2]] : f32 - // CHECK: linalg.yield %[[VAL_4]] : f32 - // CHECK: } -> tensor<2x3x4xf32> - %0 = tosa.add %arg0, %arg1 : (tensor<3x4xf32>, tensor<2x3x4xf32>) -> tensor<2x3x4xf32> - - // CHECK: return %[[RESULT]] : tensor<2x3x4xf32> - return %0 : tensor<2x3x4xf32> -} - -// ----- - // CHECK: #[[$MAP0:.+]] = affine_map<(d0, d1) -> (d0, 0)> // CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1) -> (d0, d1)> // CHECK-LABEL: @test_select_2d_one_dynamic diff --git a/mlir/test/Dialect/Tosa/broadcast.mlir b/mlir/test/Dialect/Tosa/broadcast.mlir deleted file mode 100644 index 7613aa3b8dd03..0000000000000 --- a/mlir/test/Dialect/Tosa/broadcast.mlir +++ /dev/null @@ -1,285 +0,0 @@ -// RUN: mlir-opt --tosa-make-broadcastable %s | FileCheck %s - -// ----- -// CHECK-LABEL: broadcast0 -func.func @test_broadcast0(%arg0: tensor<1xf32>, %arg1: tensor<1xf32>) -> tensor<1xf32> { - // CHECK-NOT: reshape - %0 = tosa.add %arg0, %arg1 : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32> - return %0 : tensor<1xf32> -} - -// ----- -// CHECK-LABEL: broadcast1 -func.func @test_broadcast1(%arg0: tensor<1xf32>, %arg1: tensor<2x1xf32>) -> tensor<2x1xf32> { - // CHECK-DAG: %[[VAR0:.*]] = tosa.reshape %arg0 {new_shape = array} - // CHECK: %[[VAR1:.*]] = tosa.add %[[VAR0]], %arg1 - %0 = tosa.add %arg0, %arg1 : (tensor<1xf32>, tensor<2x1xf32>) -> tensor<2x1xf32> - return %0 : tensor<2x1xf32> -} - -// ----- -// CHECK-LABEL: broadcast2 -func.func @test_broadcast2(%arg0: tensor<2x1xf32>, %arg1: tensor<1xf32>) -> tensor<2x1xf32> { - // CHECK-DAG: %[[VAR0:.*]] = tosa.reshape %arg1 {new_shape = array} - // CHECK: %[[VAR1:.*]] = tosa.add %arg0, %[[VAR0]] - %0 = tosa.add %arg0, %arg1 : (tensor<2x1xf32>, tensor<1xf32>) -> tensor<2x1xf32> - return %0 : tensor<2x1xf32> -} - -// ----- -// CHECK-LABEL: broadcast3 -func.func @test_broadcast3(%arg0: tensor<2x1x1x1xf32>, %arg1: tensor<1xf32>) -> tensor<2x1x1x1xf32> { - // CHECK-DAG: %[[VAR0:.*]] = tosa.reshape %arg1 {new_shape = array} - // CHECK: %[[VAR1:.*]] = tosa.add %arg0, %[[VAR0]] - %0 = tosa.add %arg0, %arg1 : (tensor<2x1x1x1xf32>, tensor<1xf32>) -> tensor<2x1x1x1xf32> - return %0 : tensor<2x1x1x1xf32> -} - -// ----- -// CHECK-LABEL: broadcast4 -func.func @test_broadcast4(%arg0: tensor<1x1x1x2xf32>, %arg1: tensor<1xf32>) -> tensor<1x1x1x2xf32> { - // CHECK-DAG: %[[VAR0:.*]] = tosa.reshape %arg1 {new_shape = array} - // CHECK: %[[VAR1:.*]] = tosa.add %arg0, %[[VAR0]] - %0 = tosa.add %arg0, %arg1 : (tensor<1x1x1x2xf32>, tensor<1xf32>) -> tensor<1x1x1x2xf32> - return %0 : tensor<1x1x1x2xf32> -} - -// ----- -// CHECK-LABEL: broadcast5 -func.func @test_broadcast5(%arg0: tensor<1x1x2x1xf32>, %arg1: tensor<1xf32>) -> tensor<1x1x2x1xf32> { - // CHECK-DAG: %[[VAR0:.*]] = tosa.reshape %arg1 {new_shape = array} - // CHECK: %[[VAR1:.*]] = tosa.add %arg0, %[[VAR0]] - %0 = tosa.add %arg0, %arg1 : (tensor<1x1x2x1xf32>, tensor<1xf32>) -> tensor<1x1x2x1xf32> - return %0 : tensor<1x1x2x1xf32> -} - -// ----- -// CHECK-LABEL: broadcast6 -func.func @test_broadcast6(%arg0: tensor<17x16x15x14xf32>, %arg1: tensor<1xf32>) -> tensor<17x16x15x14xf32> { - // CHECK-DAG: %[[VAR0:.*]] = tosa.reshape %arg1 {new_shape = array} - // CHECK: %[[VAR1:.*]] = tosa.add %arg0, %[[VAR0]] - %0 = tosa.add %arg0, %arg1 : (tensor<17x16x15x14xf32>, tensor<1xf32>) -> tensor<17x16x15x14xf32> - return %0 : tensor<17x16x15x14xf32> -} - -// ----- -// CHECK-LABEL: broadcast7 -func.func @test_broadcast7(%arg0: tensor<17x16x1x14xf32>, %arg1: tensor<1x1xf32>) -> tensor<17x16x1x14xf32> { - // CHECK-DAG: %[[VAR0:.*]] = tosa.reshape %arg1 {new_shape = array} - // CHECK: %[[VAR1:.*]] = tosa.add %arg0, %[[VAR0]] - %0 = tosa.add %arg0, %arg1 : (tensor<17x16x1x14xf32>, tensor<1x1xf32>) -> tensor<17x16x1x14xf32> - return %0 : tensor<17x16x1x14xf32> -} - -// ----- -// CHECK-LABEL: broadcast8 -func.func @test_broadcast8(%arg0: tensor<17x16x15x14xf32>, %arg1: tensor<1x1xf32>) -> tensor<17x16x15x14xf32> { - // CHECK-DAG: %[[VAR0:.*]] = tosa.reshape %arg1 {new_shape = array} - // CHECK: %[[VAR1:.*]] = tosa.add %arg0, %[[VAR0]] - %0 = tosa.add %arg0, %arg1 : (tensor<17x16x15x14xf32>, tensor<1x1xf32>) -> tensor<17x16x15x14xf32> - return %0 : tensor<17x16x15x14xf32> -} - -// ----- -// CHECK-LABEL: broadcast9 -func.func @test_broadcast9(%arg0: tensor<17x16x15x14xf32>, %arg1: tensor<15x1xf32>) -> tensor<17x16x15x14xf32> { - // CHECK-DAG: %[[VAR0:.*]] = tosa.reshape %arg1 {new_shape = array} - // CHECK: %[[VAR1:.*]] = tosa.add %arg0, %[[VAR0]] - %0 = tosa.add %arg0, %arg1 : (tensor<17x16x15x14xf32>, tensor<15x1xf32>) -> tensor<17x16x15x14xf32> - return %0 : tensor<17x16x15x14xf32> -} - -// ----- -// CHECK-LABEL: broadcast10 -func.func @test_broadcast10(%arg0: tensor<17x16x15x14xf32>, %arg1: tensor<15x14xf32>) -> tensor<17x16x15x14xf32> { - // CHECK-DAG: %[[VAR0:.*]] = tosa.reshape %arg1 {new_shape = array} - // CHECK: %[[VAR1:.*]] = tosa.add %arg0, %[[VAR0]] - %0 = tosa.add %arg0, %arg1 : (tensor<17x16x15x14xf32>, tensor<15x14xf32>) -> tensor<17x16x15x14xf32> - return %0 : tensor<17x16x15x14xf32> -} - -// ----- -// CHECK-LABEL: broadcast13 -func.func @test_broadcast13(%arg0: tensor<1xf32>, %arg1: tensor<17x16x15x14xf32>) -> tensor<17x16x15x14xf32> { - // CHECK-DAG: %[[VAR0:.*]] = tosa.reshape %arg0 {new_shape = array} - // CHECK: %[[VAR1:.*]] = tosa.add %[[VAR0]], %arg1 - %0 = tosa.add %arg0, %arg1 : (tensor<1xf32>, tensor<17x16x15x14xf32>) -> tensor<17x16x15x14xf32> - return %0 : tensor<17x16x15x14xf32> -} - -// ----- -// CHECK-LABEL: broadcast14 -func.func @test_broadcast14(%arg0: tensor<1x1xf32>, %arg1: tensor<17x16x1x14xf32>) -> tensor<17x16x1x14xf32> { - // CHECK-DAG: %[[VAR0:.*]] = tosa.reshape %arg0 {new_shape = array} - // CHECK: %[[VAR1:.*]] = tosa.add %[[VAR0]], %arg1 - %0 = tosa.add %arg0, %arg1 : (tensor<1x1xf32>, tensor<17x16x1x14xf32>) -> tensor<17x16x1x14xf32> - return %0 : tensor<17x16x1x14xf32> -} - -// ----- -// CHECK-LABEL: broadcast15 -func.func @test_broadcast15(%arg0: tensor<1x1xf32>, %arg1: tensor<17x16x15x14xf32>) -> tensor<17x16x15x14xf32> { - // CHECK-DAG: %[[VAR0:.*]] = tosa.reshape %arg0 {new_shape = array} - // CHECK: %[[VAR1:.*]] = tosa.add %[[VAR0]], %arg1 - %0 = tosa.add %arg0, %arg1 : (tensor<1x1xf32>, tensor<17x16x15x14xf32>) -> tensor<17x16x15x14xf32> - return %0 : tensor<17x16x15x14xf32> -} - -// ----- -// CHECK-LABEL: broadcast16 -func.func @test_broadcast16(%arg0: tensor<15x1xf32>, %arg1: tensor<17x16x15x14xf32>) -> tensor<17x16x15x14xf32> { - // CHECK-DAG: %[[VAR0:.*]] = tosa.reshape %arg0 {new_shape = array} - // CHECK: %[[VAR1:.*]] = tosa.add %[[VAR0]], %arg1 - %0 = tosa.add %arg0, %arg1 : (tensor<15x1xf32>, tensor<17x16x15x14xf32>) -> tensor<17x16x15x14xf32> - return %0 : tensor<17x16x15x14xf32> -} - -// ----- -// CHECK-LABEL: broadcast17 -func.func @test_broadcast17(%arg0: tensor<15x14xf32>, %arg1: tensor<17x16x15x14xf32>) -> tensor<17x16x15x14xf32> { - // CHECK-DAG: %[[VAR0:.*]] = tosa.reshape %arg0 {new_shape = array} - // CHECK: %[[VAR1:.*]] = tosa.add %[[VAR0]], %arg1 - %0 = tosa.add %arg0, %arg1 : (tensor<15x14xf32>, tensor<17x16x15x14xf32>) -> tensor<17x16x15x14xf32> - return %0 : tensor<17x16x15x14xf32> -} - -// ----- -// CHECK-LABEL: broadcast18 -func.func @test_broadcast18(%arg0: tensor<14x1xf32>, %arg1: tensor<1x15xf32>) -> tensor<14x15xf32> { - // CHECK: %[[VAR1:.*]] = tosa.add %arg0, %arg1 - %0 = tosa.add %arg0, %arg1 : (tensor<14x1xf32>, tensor<1x15xf32>) -> tensor<14x15xf32> - return %0 : tensor<14x15xf32> -} - -// ----- -// CHECK-LABEL: broadcast19 -func.func @test_broadcast19(%arg0: tensor<64x64x1xf32>, %arg1: tensor<1x17xf32>) -> (tensor<64x64x17xf32> ) { - // CHECK-DAG: %[[VAR0:.*]] = tosa.reshape %arg1 {new_shape = array} - // CHECK: %[[VAR1:.*]] = tosa.sub %arg0, %[[VAR0]] - %0 = tosa.sub %arg0, %arg1 : (tensor<64x64x1xf32>, tensor<1x17xf32>) -> tensor<64x64x17xf32> - return %0 : tensor<64x64x17xf32> -} - -// ----- -// CHECK-LABEL: broadcast20 -func.func @test_broadcast20(%arg0: tensor<3x3x4x1xf32>, %arg1: tensor<4x5xf32>) -> (tensor<3x3x4x5xf32> ) { - // CHECK-DAG: %[[VAR0:.*]] = tosa.reshape %arg1 {new_shape = array} - // CHECK: %[[VAR1:.*]] = tosa.add %arg0, %[[VAR0]] - %0 = tosa.add %arg0, %arg1 : (tensor<3x3x4x1xf32>, tensor<4x5xf32>) -> tensor<3x3x4x5xf32> - return %0 : tensor<3x3x4x5xf32> -} - -// ----- -// CHECK-LABEL: broadcast_mul -func.func @test_broadcast_mul(%arg0: tensor<15x14xi32>, %arg1: tensor<17x16x15x14xi32>) -> tensor<17x16x15x14xi32> { - // CHECK-DAG: %[[VAR0:.*]] = tosa.reshape %arg0 {new_shape = array} - // CHECK: %[[VAR1:.*]] = tosa.mul %[[VAR0]], %arg1 - %0 = tosa.mul %arg0, %arg1 {shift = 1 : i8 } : (tensor<15x14xi32>, tensor<17x16x15x14xi32>) -> tensor<17x16x15x14xi32> - return %0 : tensor<17x16x15x14xi32> -} - -// ----- -// CHECK-LABEL: broadcast_arithmetic_right_shift -func.func @test_broadcast_arithmetic_right_shift(%arg0: tensor<15x14xi32>, %arg1: tensor<17x16x15x14xi32>) -> tensor<17x16x15x14xi32> { - // CHECK-DAG: %[[VAR0:.*]] = tosa.reshape %arg0 {new_shape = array} - // CHECK: %[[VAR1:.*]] = tosa.arithmetic_right_shift %[[VAR0]], %arg1 - %0 = tosa.arithmetic_right_shift %arg0, %arg1 { round = true } : (tensor<15x14xi32>, tensor<17x16x15x14xi32>) -> tensor<17x16x15x14xi32> - return %0 : tensor<17x16x15x14xi32> -} - -// ----- -// CHECK-LABEL: broadcast_scalar -func.func @test_broadcast_scalar(%arg0: tensor, %arg1: tensor<17x16x15x14xi32>) -> tensor<17x16x15x14xi32> { - // CHECK-DAG: %[[VAR0:.*]] = tosa.reshape %arg0 {new_shape = array} - // CHECK: %[[VAR1:.*]] = tosa.add %[[VAR0]], %arg1 - %0 = tosa.add %arg0, %arg1 : (tensor, tensor<17x16x15x14xi32>) -> tensor<17x16x15x14xi32> - return %0 : tensor<17x16x15x14xi32> -} - -// ----- -// CHECK-LABEL: broadcast_select_both_input -func.func @test_broadcast_select_both_input(%arg0: tensor<1x16x16xi1>, %arg1: tensor, %arg2: tensor) -> tensor<1x16x16xf32> { - // CHECK-DAG: %[[VAL_0:.*]] = tosa.reshape %arg1 {new_shape = array} - // CHECK-DAG: %[[VAL_1:.*]] = tosa.reshape %arg2 {new_shape = array} - // CHECK: %[[VAL_2:.*]] = tosa.select %arg0, %[[VAL_0]], %[[VAL_1]] - %0 = tosa.select %arg0, %arg1, %arg2 : (tensor<1x16x16xi1>, tensor, tensor) -> tensor<1x16x16xf32> - return %0 : tensor<1x16x16xf32> -} - -// ----- -// CHECK-LABEL: broadcast_select_one_input -func.func @test_broadcast_select_one_input(%arg0: tensor<17x16x15x14xi1>, %arg1: tensor<17x16x15x14xf32>, %arg2: tensor) -> tensor<17x16x15x14xf32> { - // CHECK-DAG: %[[VAL_0:.*]] = tosa.reshape %arg2 {new_shape = array} - // CHECK: %[[VAL_1:.*]] = tosa.select %arg0, %arg1, %[[VAL_0]] - %0 = tosa.select %arg0, %arg1, %arg2 : (tensor<17x16x15x14xi1>, tensor<17x16x15x14xf32>, tensor) -> tensor<17x16x15x14xf32> - return %0 : tensor<17x16x15x14xf32> -} - -// ----- -// CHECK-LABEL: broadcast_select_predicate -func.func @test_broadcast_select_predicate(%arg0: tensor, %arg1: tensor<1x32x32x8xf32>, %arg2: tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32> { - // CHECK-DAG: %[[VAL_0:.*]] = tosa.reshape %arg0 {new_shape = array} - // CHECK: %[[VAL_1:.*]] = tosa.select %[[VAL_0]], %arg1, %arg2 - %0 = tosa.select %arg0, %arg1, %arg2 : (tensor, tensor<1x32x32x8xf32>, tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32> - return %0 : tensor<1x32x32x8xf32> -} - -// ----- -// CHECK-LABEL: broadcast_select_abc -func.func @test_broadcast_select_abc(%arg0: tensor, %arg1: tensor<32x8xf32>, %arg2: tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32> { - // CHECK-DAG: %[[VAL_0:.*]] = tosa.reshape %arg0 {new_shape = array} - // CHECK-DAG: %[[VAL_1:.*]] = tosa.reshape %arg1 {new_shape = array} - // CHECK: %[[VAL_2:.*]] = tosa.select %[[VAL_0]], %[[VAL_1]], %arg2 - %0 = tosa.select %arg0, %arg1, %arg2 : (tensor, tensor<32x8xf32>, tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32> - return %0 : tensor<1x32x32x8xf32> -} - -// ----- -// CHECK-LABEL: broadcast_select_acb -func.func @test_broadcast_select_acb(%arg0: tensor, %arg1: tensor<1x32x32x8xf32>, %arg2: tensor<32x8xf32>) -> tensor<1x32x32x8xf32> { - // CHECK-DAG: %[[VAL_0:.*]] = tosa.reshape %arg0 {new_shape = array} - // CHECK-DAG: %[[VAL_1:.*]] = tosa.reshape %arg2 {new_shape = array} - // CHECK: %[[VAL_2:.*]] = tosa.select %[[VAL_0]], %arg1, %[[VAL_1]] - %0 = tosa.select %arg0, %arg1, %arg2 : (tensor, tensor<1x32x32x8xf32>, tensor<32x8xf32>) -> tensor<1x32x32x8xf32> - return %0 : tensor<1x32x32x8xf32> -} - -// ----- -// CHECK-LABEL: broadcast_select_bac -func.func @test_broadcast_select_bac(%arg0: tensor<32x8xi1>, %arg1: tensor, %arg2: tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32> { - // CHECK-DAG: %[[VAL_0:.*]] = tosa.reshape %arg0 {new_shape = array} - // CHECK-DAG: %[[VAL_1:.*]] = tosa.reshape %arg1 {new_shape = array} - // CHECK: %[[VAL_2:.*]] = tosa.select %[[VAL_0]], %[[VAL_1]], %arg2 - %0 = tosa.select %arg0, %arg1, %arg2 : (tensor<32x8xi1>, tensor, tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32> - return %0 : tensor<1x32x32x8xf32> -} - -// ----- -// CHECK-LABEL: broadcast_select_bca -func.func @test_broadcast_select_bca(%arg0: tensor<32x8xi1>, %arg1: tensor<1x32x32x8xf32>, %arg2: tensor) -> tensor<1x32x32x8xf32> { - // CHECK-DAG: %[[VAL_0:.*]] = tosa.reshape %arg0 {new_shape = array} - // CHECK-DAG: %[[VAL_1:.*]] = tosa.reshape %arg2 {new_shape = array} - // CHECK: %[[VAL_2:.*]] = tosa.select %[[VAL_0]], %arg1, %[[VAL_1]] - %0 = tosa.select %arg0, %arg1, %arg2 : (tensor<32x8xi1>, tensor<1x32x32x8xf32>, tensor) -> tensor<1x32x32x8xf32> - return %0 : tensor<1x32x32x8xf32> -} - -// ----- -// CHECK-LABEL: broadcast_select_cab -func.func @test_broadcast_select_cab(%arg0: tensor<1x32x32x8xi1>, %arg1: tensor, %arg2: tensor<32x8xf32>) -> tensor<1x32x32x8xf32> { - // CHECK-DAG: %[[VAL_0:.*]] = tosa.reshape %arg1 {new_shape = array} - // CHECK-DAG: %[[VAL_1:.*]] = tosa.reshape %arg2 {new_shape = array} - // CHECK: %[[VAL_2:.*]] = tosa.select %arg0, %[[VAL_0]], %[[VAL_1]] - %0 = tosa.select %arg0, %arg1, %arg2 : (tensor<1x32x32x8xi1>, tensor, tensor<32x8xf32>) -> tensor<1x32x32x8xf32> - return %0 : tensor<1x32x32x8xf32> -} - -// ----- -// CHECK-LABEL: broadcast_select_cba -func.func @test_broadcast_select_cba(%arg0: tensor<1x32x32x8xi1>, %arg1: tensor<32x8xf32>, %arg2: tensor) -> tensor<1x32x32x8xf32> { - // CHECK-DAG: %[[VAL_0:.*]] = tosa.reshape %arg1 {new_shape = array} - // CHECK-DAG: %[[VAL_1:.*]] = tosa.reshape %arg2 {new_shape = array} - // CHECK: %[[VAL_2:.*]] = tosa.select %arg0, %[[VAL_0]], %[[VAL_1]] - %0 = tosa.select %arg0, %arg1, %arg2 : (tensor<1x32x32x8xi1>, tensor<32x8xf32>, tensor) -> tensor<1x32x32x8xf32> - return %0 : tensor<1x32x32x8xf32> -} diff --git a/mlir/test/Dialect/Tosa/constant_folding.mlir b/mlir/test/Dialect/Tosa/constant_folding.mlir index 869bce08a8c72..3ff3121348fca 100644 --- a/mlir/test/Dialect/Tosa/constant_folding.mlir +++ b/mlir/test/Dialect/Tosa/constant_folding.mlir @@ -15,9 +15,9 @@ func.func @test_const_i64(%arg0 : index) -> tensor<4xi64> { } // CHECK-LABEL: func @try_fold_equal_with_unranked_tensor -func.func @try_fold_equal_with_unranked_tensor(%arg0: tensor<4xi32>, %arg1: tensor) { +func.func @try_fold_equal_with_unranked_tensor(%arg0: tensor<4xi32>, %arg1: tensor<1xi32>) { // CHECK: tosa.equal // CHECK-NEXT: return - %0 = tosa.equal %arg0, %arg1 : (tensor<4xi32>, tensor) -> tensor<*xi1> + %0 = tosa.equal %arg0, %arg1 : (tensor<4xi32>, tensor<1xi32>) -> tensor<*xi1> return } diff --git a/mlir/test/Dialect/Tosa/inlining.mlir b/mlir/test/Dialect/Tosa/inlining.mlir index d57b5cbcf475c..e892fdaa27750 100644 --- a/mlir/test/Dialect/Tosa/inlining.mlir +++ b/mlir/test/Dialect/Tosa/inlining.mlir @@ -47,7 +47,8 @@ func.func @inlined_while_fn(%arg0: tensor, %arg1: tensor, %arg2: tenso } func.func private @while_body_50(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg3: tensor<10xi32>) -> (tensor, tensor, tensor, tensor<10xi32>) { %1 = "tosa.add"(%arg0, %arg1) : (tensor, tensor) -> tensor - %2 = "tosa.add"(%arg3, %1) : (tensor<10xi32>, tensor) -> tensor<10xi32> + %3 = "tosa.reshape"(%1) {new_shape = array} : (tensor) -> tensor<1xi32> + %2 = "tosa.add"(%arg3, %3) : (tensor<10xi32>, tensor<1xi32>) -> tensor<10xi32> return %1, %arg1, %arg2, %2: tensor, tensor, tensor, tensor<10xi32> } func.func private @while_cond_40(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg3: tensor<10xi32>) -> tensor { diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir index cc7fd009f01fa..deaa8e2442337 100644 --- a/mlir/test/Dialect/Tosa/invalid.mlir +++ b/mlir/test/Dialect/Tosa/invalid.mlir @@ -1018,3 +1018,19 @@ func.func @test_const_shape_value() -> !tosa.shape<5> { %cst = tosa.const_shape {value = dense<[1, 2, 3, 4]> : tensor<4xindex>} : () -> !tosa.shape<5> return %cst : !tosa.shape<5> } + +// ----- + +func.func @test_sub_with_unequal_operand_ranks(%arg0: tensor<1x21x3xf32>, %arg1: tensor<1x13x21x3xf32>) -> tensor<1x13x21x3xf32> { + // expected-error@+1 {{'tosa.sub' op operands don't have matching ranks}} + %0 = tosa.sub %arg0, %arg1 : (tensor<1x21x3xf32>, tensor<1x13x21x3xf32>) -> tensor<1x13x21x3xf32> + return %0 : tensor<1x13x21x3xf32> +} + +// ----- + +func.func @test_sub_with_unequal_result_ranks(%arg0: tensor<1x21x3xf32>, %arg1: tensor<13x21x3xf32>) -> tensor<1x13x21x3xf32> { + // expected-error@+1 {{'tosa.sub' op result type has different rank than operands}} + %0 = tosa.sub %arg0, %arg1 : (tensor<1x21x3xf32>, tensor<13x21x3xf32>) -> tensor<1x13x21x3xf32> + return %0 : tensor<1x13x21x3xf32> +} diff --git a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir index f4da66ef561b2..8ab7284019f96 100644 --- a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir +++ b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir @@ -11,15 +11,15 @@ func.func @test_return(%arg0 : tensor<4xf32>) -> tensor<*xf32> { // ----- // CHECK-LABEL: @test_multiple -func.func @test_multiple(%arg0 : tensor<4xf32>, %arg1 : tensor<1xf32>, %arg2 : tensor) -> tensor<*xf32> { +func.func @test_multiple(%arg0 : tensor<4xf32>, %arg1 : tensor<1xf32>, %arg2 : tensor<1xf32>) -> tensor<*xf32> { // CHECK: [[ADD:%.+]] = tosa.add %arg0, %arg1 : (tensor<4xf32>, tensor<1xf32>) -> tensor<4xf32> %0 = tosa.add %arg0, %arg1 : (tensor<4xf32>, tensor<1xf32>) -> tensor<*xf32> // CHECK: [[LOG:%.+]] = tosa.log %0 : (tensor<4xf32>) -> tensor<4xf32> %1 = tosa.log %0 : (tensor<*xf32>) -> tensor<*xf32> - // CHECK: [[SUB:%.+]] = tosa.sub %0, %arg2 : (tensor<4xf32>, tensor) -> tensor<4xf32> - %2 = tosa.sub %0, %arg2 : (tensor<*xf32>, tensor) -> tensor<*xf32> + // CHECK: [[SUB:%.+]] = tosa.sub %0, %arg2 : (tensor<4xf32>, tensor<1xf32>) -> tensor<4xf32> + %2 = tosa.sub %0, %arg2 : (tensor<*xf32>, tensor<1xf32>) -> tensor<*xf32> return %0 : tensor<*xf32> } @@ -104,33 +104,33 @@ func.func @test_unary_i32(%arg0 : tensor<4xi32>) -> () { // ----- // CHECK-LABEL: @test_binary_scalar_f32 -func.func @test_binary_scalar_f32(%arg0 : tensor<4xf32>, %arg1 : tensor) -> () { - // CHECK: tosa.add %arg0, %arg1 : (tensor<4xf32>, tensor) -> tensor<4xf32> - %0 = tosa.add %arg0, %arg1 : (tensor<4xf32>, tensor) -> tensor<*xf32> +func.func @test_binary_scalar_f32(%arg0 : tensor<4xf32>, %arg1 : tensor<1xf32>) -> () { + // CHECK: tosa.add %arg0, %arg1 : (tensor<4xf32>, tensor<1xf32>) -> tensor<4xf32> + %0 = tosa.add %arg0, %arg1 : (tensor<4xf32>, tensor<1xf32>) -> tensor<*xf32> - // CHECK: tosa.maximum %arg0, %arg1 : (tensor<4xf32>, tensor) -> tensor<4xf32> - %1 = tosa.maximum %arg0, %arg1 : (tensor<4xf32>, tensor) -> tensor<*xf32> + // CHECK: tosa.maximum %arg0, %arg1 : (tensor<4xf32>, tensor<1xf32>) -> tensor<4xf32> + %1 = tosa.maximum %arg0, %arg1 : (tensor<4xf32>, tensor<1xf32>) -> tensor<*xf32> - // CHECK: tosa.minimum %arg0, %arg1 : (tensor<4xf32>, tensor) -> tensor<4xf32> - %2 = tosa.minimum %arg0, %arg1 : (tensor<4xf32>, tensor) -> tensor<*xf32> + // CHECK: tosa.minimum %arg0, %arg1 : (tensor<4xf32>, tensor<1xf32>) -> tensor<4xf32> + %2 = tosa.minimum %arg0, %arg1 : (tensor<4xf32>, tensor<1xf32>) -> tensor<*xf32> - // CHECK: tosa.mul %arg0, %arg1 {shift = 0 : i8} : (tensor<4xf32>, tensor) -> tensor<4xf32> - %3 = tosa.mul %arg0, %arg1 { shift = 0 : i8 } : (tensor<4xf32>, tensor) -> tensor<*xf32> + // CHECK: tosa.mul %arg0, %arg1 {shift = 0 : i8} : (tensor<4xf32>, tensor<1xf32>) -> tensor<4xf32> + %3 = tosa.mul %arg0, %arg1 { shift = 0 : i8 } : (tensor<4xf32>, tensor<1xf32>) -> tensor<*xf32> - // CHECK: tosa.pow %arg0, %arg1 : (tensor<4xf32>, tensor) -> tensor<4xf32> - %4 = tosa.pow %arg0, %arg1 : (tensor<4xf32>, tensor) -> tensor<*xf32> + // CHECK: tosa.pow %arg0, %arg1 : (tensor<4xf32>, tensor<1xf32>) -> tensor<4xf32> + %4 = tosa.pow %arg0, %arg1 : (tensor<4xf32>, tensor<1xf32>) -> tensor<*xf32> - // CHECK: tosa.sub %arg0, %arg1 : (tensor<4xf32>, tensor) -> tensor<4xf32> - %5 = tosa.sub %arg0, %arg1 : (tensor<4xf32>, tensor) -> tensor<*xf32> + // CHECK: tosa.sub %arg0, %arg1 : (tensor<4xf32>, tensor<1xf32>) -> tensor<4xf32> + %5 = tosa.sub %arg0, %arg1 : (tensor<4xf32>, tensor<1xf32>) -> tensor<*xf32> - // CHECK: tosa.equal %arg0, %arg1 : (tensor<4xf32>, tensor) -> tensor<4xi1> - %6 = tosa.equal %arg0, %arg1 : (tensor<4xf32>, tensor) -> tensor<*xi1> + // CHECK: tosa.equal %arg0, %arg1 : (tensor<4xf32>, tensor<1xf32>) -> tensor<4xi1> + %6 = tosa.equal %arg0, %arg1 : (tensor<4xf32>, tensor<1xf32>) -> tensor<*xi1> - // CHECK: tosa.greater %arg0, %arg1 : (tensor<4xf32>, tensor) -> tensor<4xi1> - %7 = tosa.greater %arg0, %arg1 : (tensor<4xf32>, tensor) -> tensor<*xi1> + // CHECK: tosa.greater %arg0, %arg1 : (tensor<4xf32>, tensor<1xf32>) -> tensor<4xi1> + %7 = tosa.greater %arg0, %arg1 : (tensor<4xf32>, tensor<1xf32>) -> tensor<*xi1> - // CHECK: tosa.greater_equal %arg0, %arg1 : (tensor<4xf32>, tensor) -> tensor<4xi1> - %8 = tosa.greater_equal %arg0, %arg1 : (tensor<4xf32>, tensor) -> tensor<*xi1> + // CHECK: tosa.greater_equal %arg0, %arg1 : (tensor<4xf32>, tensor<1xf32>) -> tensor<4xi1> + %8 = tosa.greater_equal %arg0, %arg1 : (tensor<4xf32>, tensor<1xf32>) -> tensor<*xi1> return } @@ -172,48 +172,48 @@ func.func @test_binary_broadcast_f32(%arg0 : tensor<4xf32>, %arg1 : tensor<1xf32 // ----- // CHECK-LABEL: @test_binary_i32 -func.func @test_binary_i32(%arg0 : tensor<4xi32>, %arg1 : tensor) -> () { - // CHECK: tosa.add %arg0, %arg1 : (tensor<4xi32>, tensor) -> tensor<4xi32> - %0 = tosa.add %arg0, %arg1 : (tensor<4xi32>, tensor) -> tensor<*xi32> +func.func @test_binary_i32(%arg0 : tensor<4xi32>, %arg1 : tensor<1xi32>) -> () { + // CHECK: tosa.add %arg0, %arg1 : (tensor<4xi32>, tensor<1xi32>) -> tensor<4xi32> + %0 = tosa.add %arg0, %arg1 : (tensor<4xi32>, tensor<1xi32>) -> tensor<*xi32> - // CHECK: tosa.bitwise_and %arg0, %arg1 : (tensor<4xi32>, tensor) -> tensor<4xi32> - %1 = tosa.bitwise_and %arg0, %arg1: (tensor<4xi32>, tensor) -> tensor<*xi32> + // CHECK: tosa.bitwise_and %arg0, %arg1 : (tensor<4xi32>, tensor<1xi32>) -> tensor<4xi32> + %1 = tosa.bitwise_and %arg0, %arg1: (tensor<4xi32>, tensor<1xi32>) -> tensor<*xi32> - // CHECK: tosa.bitwise_or %arg0, %arg1 : (tensor<4xi32>, tensor) -> tensor<4xi32> - %2 = tosa.bitwise_or %arg0, %arg1 : (tensor<4xi32>, tensor) -> tensor<*xi32> + // CHECK: tosa.bitwise_or %arg0, %arg1 : (tensor<4xi32>, tensor<1xi32>) -> tensor<4xi32> + %2 = tosa.bitwise_or %arg0, %arg1 : (tensor<4xi32>, tensor<1xi32>) -> tensor<*xi32> - // CHECK: tosa.bitwise_xor %arg0, %arg1 : (tensor<4xi32>, tensor) -> tensor<4xi32> - %3 = tosa.bitwise_xor %arg0, %arg1: (tensor<4xi32>, tensor) -> tensor<*xi32> + // CHECK: tosa.bitwise_xor %arg0, %arg1 : (tensor<4xi32>, tensor<1xi32>) -> tensor<4xi32> + %3 = tosa.bitwise_xor %arg0, %arg1 : (tensor<4xi32>, tensor<1xi32>) -> tensor<*xi32> - // CHECK: tosa.equal %arg0, %arg1 : (tensor<4xi32>, tensor) -> tensor<4xi1> - %4 = tosa.equal %arg0, %arg1 : (tensor<4xi32>, tensor) -> tensor<*xi1> + // CHECK: tosa.equal %arg0, %arg1 : (tensor<4xi32>, tensor<1xi32>) -> tensor<4xi1> + %4 = tosa.equal %arg0, %arg1 : (tensor<4xi32>, tensor<1xi32>) -> tensor<*xi1> - // CHECK: tosa.greater %arg0, %arg1 : (tensor<4xi32>, tensor) -> tensor<4xi1> - %5 = tosa.greater %arg0, %arg1 : (tensor<4xi32>, tensor) -> tensor<*xi1> + // CHECK: tosa.greater %arg0, %arg1 : (tensor<4xi32>, tensor<1xi32>) -> tensor<4xi1> + %5 = tosa.greater %arg0, %arg1 : (tensor<4xi32>, tensor<1xi32>) -> tensor<*xi1> - // CHECK: tosa.greater_equal %arg0, %arg1 : (tensor<4xi32>, tensor) -> tensor<4xi1> - %6 = tosa.greater_equal %arg0, %arg1 : (tensor<4xi32>, tensor) -> tensor<*xi1> + // CHECK: tosa.greater_equal %arg0, %arg1 : (tensor<4xi32>, tensor<1xi32>) -> tensor<4xi1> + %6 = tosa.greater_equal %arg0, %arg1 : (tensor<4xi32>, tensor<1xi32>) -> tensor<*xi1> - // CHECK: tosa.logical_left_shift %arg0, %arg1 {shift = 0 : i32} : (tensor<4xi32>, tensor) -> tensor<4xi32> - %7 = tosa.logical_left_shift %arg0, %arg1 { shift = 0 : i32 }: (tensor<4xi32>, tensor) -> tensor<*xi32> + // CHECK: tosa.logical_left_shift %arg0, %arg1 {shift = 0 : i32} : (tensor<4xi32>, tensor<1xi32>) -> tensor<4xi32> + %7 = tosa.logical_left_shift %arg0, %arg1 { shift = 0 : i32 }: (tensor<4xi32>, tensor<1xi32>) -> tensor<*xi32> - // CHECK: tosa.logical_right_shift %arg0, %arg1 {shift = 0 : i32} : (tensor<4xi32>, tensor) -> tensor<4xi32> - %8 = tosa.logical_right_shift %arg0, %arg1 { shift = 0 : i32 }: (tensor<4xi32>, tensor) -> tensor<*xi32> + // CHECK: tosa.logical_right_shift %arg0, %arg1 {shift = 0 : i32} : (tensor<4xi32>, tensor<1xi32>) -> tensor<4xi32> + %8 = tosa.logical_right_shift %arg0, %arg1 { shift = 0 : i32 }: (tensor<4xi32>, tensor<1xi32>) -> tensor<*xi32> - // CHECK: tosa.maximum %arg0, %arg1 : (tensor<4xi32>, tensor) -> tensor<4xi32> - %9 = tosa.maximum %arg0, %arg1 : (tensor<4xi32>, tensor) -> tensor<*xi32> + // CHECK: tosa.maximum %arg0, %arg1 : (tensor<4xi32>, tensor<1xi32>) -> tensor<4xi32> + %9 = tosa.maximum %arg0, %arg1 : (tensor<4xi32>, tensor<1xi32>) -> tensor<*xi32> - // CHECK: tosa.minimum %arg0, %arg1 : (tensor<4xi32>, tensor) -> tensor<4xi32> - %10 = tosa.minimum %arg0, %arg1 : (tensor<4xi32>, tensor) -> tensor<*xi32> + // CHECK: tosa.minimum %arg0, %arg1 : (tensor<4xi32>, tensor<1xi32>) -> tensor<4xi32> + %10 = tosa.minimum %arg0, %arg1 : (tensor<4xi32>, tensor<1xi32>) -> tensor<*xi32> - // CHECK: tosa.mul %arg0, %arg1 {shift = 0 : i8} : (tensor<4xi32>, tensor) -> tensor<4xi32> - %11 = tosa.mul %arg0, %arg1 { shift = 0 : i8 }: (tensor<4xi32>, tensor) -> tensor<*xi32> + // CHECK: tosa.mul %arg0, %arg1 {shift = 0 : i8} : (tensor<4xi32>, tensor<1xi32>) -> tensor<4xi32> + %11 = tosa.mul %arg0, %arg1 { shift = 0 : i8 }: (tensor<4xi32>, tensor<1xi32>) -> tensor<*xi32> - // CHECK: tosa.pow %arg0, %arg1 : (tensor<4xi32>, tensor) -> tensor<4xi32> - %12 = tosa.pow %arg0, %arg1 : (tensor<4xi32>, tensor) -> tensor<*xi32> + // CHECK: tosa.pow %arg0, %arg1 : (tensor<4xi32>, tensor<1xi32>) -> tensor<4xi32> + %12 = tosa.pow %arg0, %arg1 : (tensor<4xi32>, tensor<1xi32>) -> tensor<*xi32> - // CHECK: tosa.sub %arg0, %arg1 : (tensor<4xi32>, tensor) -> tensor<4xi32> - %13 = tosa.sub %arg0, %arg1 : (tensor<4xi32>, tensor) -> tensor<*xi32> + // CHECK: tosa.sub %arg0, %arg1 : (tensor<4xi32>, tensor<1xi32>) -> tensor<4xi32> + %13 = tosa.sub %arg0, %arg1 : (tensor<4xi32>, tensor<1xi32>) -> tensor<*xi32> return } @@ -221,15 +221,15 @@ func.func @test_binary_i32(%arg0 : tensor<4xi32>, %arg1 : tensor) -> () { // ----- // CHECK-LABEL: @test_binary_i1 -func.func @test_binary_i1(%arg0 : tensor<4xi1>, %arg1 : tensor) -> () { - // CHECK: tosa.logical_and %arg0, %arg1 : (tensor<4xi1>, tensor) -> tensor<4xi1> - %0 = tosa.logical_and %arg0, %arg1 : (tensor<4xi1>, tensor) -> tensor<*xi1> +func.func @test_binary_i1(%arg0 : tensor<4xi1>, %arg1 : tensor<1xi1>) -> () { + // CHECK: tosa.logical_and %arg0, %arg1 : (tensor<4xi1>, tensor<1xi1>) -> tensor<4xi1> + %0 = tosa.logical_and %arg0, %arg1 : (tensor<4xi1>, tensor<1xi1>) -> tensor<*xi1> - // CHECK: tosa.logical_or %arg0, %arg1 : (tensor<4xi1>, tensor) -> tensor<4xi1> - %1 = tosa.logical_or %arg0, %arg1 : (tensor<4xi1>, tensor) -> tensor<*xi1> + // CHECK: tosa.logical_or %arg0, %arg1 : (tensor<4xi1>, tensor<1xi1>) -> tensor<4xi1> + %1 = tosa.logical_or %arg0, %arg1 : (tensor<4xi1>, tensor<1xi1>) -> tensor<*xi1> - // CHECK: tosa.logical_xor %arg0, %arg1 : (tensor<4xi1>, tensor) -> tensor<4xi1> - %2 = tosa.logical_xor %arg0, %arg1 : (tensor<4xi1>, tensor) -> tensor<*xi1> + // CHECK: tosa.logical_xor %arg0, %arg1 : (tensor<4xi1>, tensor<1xi1>) -> tensor<4xi1> + %2 = tosa.logical_xor %arg0, %arg1 : (tensor<4xi1>, tensor<1xi1>) -> tensor<*xi1> return } @@ -237,9 +237,9 @@ func.func @test_binary_i1(%arg0 : tensor<4xi1>, %arg1 : tensor) -> () { // ----- // CHECK-LABEL: @test_select_i32 -func.func @test_select_i32(%arg0 : tensor<4xi1>, %arg1 : tensor, %arg2 : tensor<4xi32>) -> () { - // CHECK: tosa.select %arg0, %arg1, %arg2 : (tensor<4xi1>, tensor, tensor<4xi32>) -> tensor<4xi32> - %0 = tosa.select %arg0, %arg1, %arg2 : (tensor<4xi1>, tensor, tensor<4xi32>) -> tensor<*xi32> +func.func @test_select_i32(%arg0 : tensor<4xi1>, %arg1 : tensor<1xi32>, %arg2 : tensor<4xi32>) -> () { + // CHECK: tosa.select %arg0, %arg1, %arg2 : (tensor<4xi1>, tensor<1xi32>, tensor<4xi32>) -> tensor<4xi32> + %0 = tosa.select %arg0, %arg1, %arg2 : (tensor<4xi1>, tensor<1xi32>, tensor<4xi32>) -> tensor<*xi32> return }