@@ -104,15 +104,15 @@ func.func @test_conv2d_q8xi4(%arg0: tensor<1x11x11x3xi8>) -> tensor<1x1x1x3xi8>
104
104
105
105
// -----
106
106
// CHECK-LABEL: conv3d
107
- func.func @test_conv3d (%arg0: tensor <1 x4 x8 x21 x17 xf32 >, %arg1: tensor <34 x1 x1 x1 x17 xf32 >, %arg2: tensor <21 x f32 >, %arg3: tensor <1 xf32 >, %arg4: tensor <1 xf32 >) -> tensor <1 x4 x8 x21 x34 xf32 > {
108
- %0 = tosa.conv3d %arg0 , %arg1 , %arg2 , %arg3 , %arg4 {acc_type = f32 , dilation = array<i64 : 1 , 1 , 1 >, pad = array<i64 : 0 , 0 , 0 , 0 , 0 , 0 >, stride = array<i64 : 1 , 1 , 1 >} : (tensor <1 x4 x8 x21 x17 xf32 >, tensor <34 x1 x1 x1 x17 xf32 >, tensor <21 x f32 >, tensor <1 xf32 >, tensor <1 xf32 >) -> tensor <1 x4 x8 x21 x34 xf32 >
107
+ func.func @test_conv3d (%arg0: tensor <1 x4 x8 x21 x17 xf32 >, %arg1: tensor <34 x1 x1 x1 x17 xf32 >, %arg2: tensor <34 x f32 >, %arg3: tensor <1 xf32 >, %arg4: tensor <1 xf32 >) -> tensor <1 x4 x8 x21 x34 xf32 > {
108
+ %0 = tosa.conv3d %arg0 , %arg1 , %arg2 , %arg3 , %arg4 {acc_type = f32 , dilation = array<i64 : 1 , 1 , 1 >, pad = array<i64 : 0 , 0 , 0 , 0 , 0 , 0 >, stride = array<i64 : 1 , 1 , 1 >} : (tensor <1 x4 x8 x21 x17 xf32 >, tensor <34 x1 x1 x1 x17 xf32 >, tensor <34 x f32 >, tensor <1 xf32 >, tensor <1 xf32 >) -> tensor <1 x4 x8 x21 x34 xf32 >
109
109
return %0 : tensor <1 x4 x8 x21 x34 xf32 >
110
110
}
111
111
112
112
// -----
113
113
// CHECK-LABEL: conv3d_with_local_bound
114
- func.func @test_conv3d_with_local_bound (%arg0: tensor <1 x4 x8 x21 x17 xf32 >, %arg1: tensor <34 x1 x1 x1 x17 xf32 >, %arg2: tensor <21 x f32 >, %arg3: tensor <1 xf32 >, %arg4: tensor <1 xf32 >) -> tensor <1 x4 x8 x21 x34 xf32 > {
115
- %0 = tosa.conv3d %arg0 , %arg1 , %arg2 , %arg3 , %arg4 {acc_type = f32 , dilation = array<i64 : 1 , 1 , 1 >, pad = array<i64 : 0 , 0 , 0 , 0 , 0 , 0 >, stride = array<i64 : 1 , 1 , 1 >, local_bound = true } : (tensor <1 x4 x8 x21 x17 xf32 >, tensor <34 x1 x1 x1 x17 xf32 >, tensor <21 x f32 >, tensor <1 xf32 >, tensor <1 xf32 >) -> tensor <1 x4 x8 x21 x34 xf32 >
114
+ func.func @test_conv3d_with_local_bound (%arg0: tensor <1 x4 x8 x21 x17 xf32 >, %arg1: tensor <34 x1 x1 x1 x17 xf32 >, %arg2: tensor <34 x f32 >, %arg3: tensor <1 xf32 >, %arg4: tensor <1 xf32 >) -> tensor <1 x4 x8 x21 x34 xf32 > {
115
+ %0 = tosa.conv3d %arg0 , %arg1 , %arg2 , %arg3 , %arg4 {acc_type = f32 , dilation = array<i64 : 1 , 1 , 1 >, pad = array<i64 : 0 , 0 , 0 , 0 , 0 , 0 >, stride = array<i64 : 1 , 1 , 1 >, local_bound = true } : (tensor <1 x4 x8 x21 x17 xf32 >, tensor <34 x1 x1 x1 x17 xf32 >, tensor <34 x f32 >, tensor <1 xf32 >, tensor <1 xf32 >) -> tensor <1 x4 x8 x21 x34 xf32 >
116
116
return %0 : tensor <1 x4 x8 x21 x34 xf32 >
117
117
}
118
118
@@ -823,8 +823,8 @@ func.func @test_conv2d_f8E5M2(%arg0: tensor<1x4x4x4xf8E5M2>, %arg1: tensor<8x1x1
823
823
824
824
// -----
825
825
// CHECK-LABEL: conv3d_f8E5M2
826
- func.func @test_conv3d_f8E5M2 (%arg0: tensor <1 x4 x8 x21 x17 xf8 E5 M2 >, %arg1: tensor <34 x1 x1 x1 x17 xf8 E5 M2 >, %arg2: tensor <21 x f16 >, %arg3: tensor <1 xf8 E5 M2 >, %arg4: tensor <1 xf8 E5 M2 >) -> tensor <1 x4 x8 x21 x34 xf16 > {
827
- %0 = tosa.conv3d %arg0 , %arg1 , %arg2 , %arg3 , %arg4 {acc_type = f16 , dilation = array<i64 : 1 , 1 , 1 >, pad = array<i64 : 0 , 0 , 0 , 0 , 0 , 0 >, stride = array<i64 : 1 , 1 , 1 >} : (tensor <1 x4 x8 x21 x17 xf8 E5 M2 >, tensor <34 x1 x1 x1 x17 xf8 E5 M2 >, tensor <21 x f16 >, tensor <1 xf8 E5 M2 >, tensor <1 xf8 E5 M2 >) -> tensor <1 x4 x8 x21 x34 xf16 >
826
+ func.func @test_conv3d_f8E5M2 (%arg0: tensor <1 x4 x8 x21 x17 xf8 E5 M2 >, %arg1: tensor <34 x1 x1 x1 x17 xf8 E5 M2 >, %arg2: tensor <34 x f16 >, %arg3: tensor <1 xf8 E5 M2 >, %arg4: tensor <1 xf8 E5 M2 >) -> tensor <1 x4 x8 x21 x34 xf16 > {
827
+ %0 = tosa.conv3d %arg0 , %arg1 , %arg2 , %arg3 , %arg4 {acc_type = f16 , dilation = array<i64 : 1 , 1 , 1 >, pad = array<i64 : 0 , 0 , 0 , 0 , 0 , 0 >, stride = array<i64 : 1 , 1 , 1 >} : (tensor <1 x4 x8 x21 x17 xf8 E5 M2 >, tensor <34 x1 x1 x1 x17 xf8 E5 M2 >, tensor <34 x f16 >, tensor <1 xf8 E5 M2 >, tensor <1 xf8 E5 M2 >) -> tensor <1 x4 x8 x21 x34 xf16 >
828
828
return %0 : tensor <1 x4 x8 x21 x34 xf16 >
829
829
}
830
830
@@ -968,8 +968,8 @@ func.func @test_conv2d_f8E4M3FN(%arg0: tensor<1x4x4x4xf8E4M3FN>, %arg1: tensor<8
968
968
969
969
// -----
970
970
// CHECK-LABEL: conv3d_f8E4M3FN
971
- func.func @test_conv3d_f8E4M3FN (%arg0: tensor <1 x4 x8 x21 x17 xf8 E4 M3 FN>, %arg1: tensor <34 x1 x1 x1 x17 xf8 E4 M3 FN>, %arg2: tensor <21 x f16 >, %arg3: tensor <1 xf8 E4 M3 FN>, %arg4: tensor <1 xf8 E4 M3 FN>) -> tensor <1 x4 x8 x21 x34 xf16 > {
972
- %0 = tosa.conv3d %arg0 , %arg1 , %arg2 , %arg3 , %arg4 {acc_type = f16 , dilation = array<i64 : 1 , 1 , 1 >, pad = array<i64 : 0 , 0 , 0 , 0 , 0 , 0 >, stride = array<i64 : 1 , 1 , 1 >} : (tensor <1 x4 x8 x21 x17 xf8 E4 M3 FN>, tensor <34 x1 x1 x1 x17 xf8 E4 M3 FN>, tensor <21 x f16 >, tensor <1 xf8 E4 M3 FN>, tensor <1 xf8 E4 M3 FN>) -> tensor <1 x4 x8 x21 x34 xf16 >
971
+ func.func @test_conv3d_f8E4M3FN (%arg0: tensor <1 x4 x8 x21 x17 xf8 E4 M3 FN>, %arg1: tensor <34 x1 x1 x1 x17 xf8 E4 M3 FN>, %arg2: tensor <34 x f16 >, %arg3: tensor <1 xf8 E4 M3 FN>, %arg4: tensor <1 xf8 E4 M3 FN>) -> tensor <1 x4 x8 x21 x34 xf16 > {
972
+ %0 = tosa.conv3d %arg0 , %arg1 , %arg2 , %arg3 , %arg4 {acc_type = f16 , dilation = array<i64 : 1 , 1 , 1 >, pad = array<i64 : 0 , 0 , 0 , 0 , 0 , 0 >, stride = array<i64 : 1 , 1 , 1 >} : (tensor <1 x4 x8 x21 x17 xf8 E4 M3 FN>, tensor <34 x1 x1 x1 x17 xf8 E4 M3 FN>, tensor <34 x f16 >, tensor <1 xf8 E4 M3 FN>, tensor <1 xf8 E4 M3 FN>) -> tensor <1 x4 x8 x21 x34 xf16 >
973
973
return %0 : tensor <1 x4 x8 x21 x34 xf16 >
974
974
}
975
975
0 commit comments