@@ -979,28 +979,28 @@ func @torch.aten.softmax.int$specified_dtype(%t: !torch.tensor<[2,3],f32>, %dim:
// -----
- // CHECK-LABEL: func @aten_matmul_broadcast_matrix(
+ // CHECK-LABEL: func @torch.aten.Matmul.Broadcast.Matrix(
// CHECK-SAME: %[[LHS:.*]]: !torch.vtensor<[?,?,?,?,?],f32>,
// CHECK-SAME: %[[RHS:.*]]: !torch.vtensor<[?,?,?],f32>)
// CHECK-SAME: -> !torch.tensor {
- // CHECK: %[[MUL:.*]] = torch.aten.matmul %[[LHS]], %[[RHS]] : !torch.vtensor<[?,?,?,?,?],f32>, !torch.vtensor<[?,?,?],f32> -> !torch.tensor<[?,?,?,?,?],f32>
- // CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[MUL]] : !torch.tensor<[?,?,?,?,?],f32> to !torch.tensor
- // CHECK: return %[[CAST]] : !torch.tensor
- func @aten_matmul_broadcast_matrix(%arg0: !torch.vtensor<[?,?,?,?,?],f32>, %arg1: !torch.vtensor<[?,?,?],f32>) -> !torch.tensor {
+ // CHECK: %[[MUL:.*]] = torch.aten.matmul %[[LHS]], %[[RHS]] : !torch.vtensor<[?,?,?,?,?],f32>, !torch.vtensor<[?,?,?],f32> -> !torch.tensor<[?,?,?,?,?],f32>
+ // CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[MUL]] : !torch.tensor<[?,?,?,?,?],f32> to !torch.tensor
+ // CHECK: return %[[CAST]] : !torch.tensor
+ func @torch.aten.Matmul.Broadcast.Matrix(%arg0: !torch.vtensor<[?,?,?,?,?],f32>, %arg1: !torch.vtensor<[?,?,?],f32>) -> !torch.tensor {
  %0 = torch.aten.matmul %arg0, %arg1 : !torch.vtensor<[?,?,?,?,?],f32>, !torch.vtensor<[?,?,?],f32> -> !torch.tensor
  return %0 : !torch.tensor
}

// -----
- // CHECK-LABEL: func @aten_matmul_broadcast_vector(
+ // CHECK-LABEL: func @torch.aten.Matmul.Broadcast.Vector(
// CHECK-SAME: %[[LHS:.*]]: !torch.vtensor<[?,?,?,?,?],f32>,
// CHECK-SAME: %[[RHS:.*]]: !torch.vtensor<[?],f32>)
// CHECK-SAME: -> !torch.tensor {
- // CHECK: %[[MUL:.*]] = torch.aten.matmul %[[LHS]], %[[RHS]] : !torch.vtensor<[?,?,?,?,?],f32>, !torch.vtensor<[?],f32> -> !torch.tensor<[?,?,?,?],f32>
- // CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[MUL]] : !torch.tensor<[?,?,?,?],f32> to !torch.tensor
- // CHECK: return %[[CAST]] : !torch.tensor
- func @aten_matmul_broadcast_vector(%arg0: !torch.vtensor<[?,?,?,?,?],f32>, %arg1: !torch.vtensor<[?],f32>) -> !torch.tensor {
+ // CHECK: %[[MUL:.*]] = torch.aten.matmul %[[LHS]], %[[RHS]] : !torch.vtensor<[?,?,?,?,?],f32>, !torch.vtensor<[?],f32> -> !torch.tensor<[?,?,?,?],f32>
+ // CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[MUL]] : !torch.tensor<[?,?,?,?],f32> to !torch.tensor
+ // CHECK: return %[[CAST]] : !torch.tensor
+ func @torch.aten.Matmul.Broadcast.Vector(%arg0: !torch.vtensor<[?,?,?,?,?],f32>, %arg1: !torch.vtensor<[?],f32>) -> !torch.tensor {
  %0 = torch.aten.matmul %arg0, %arg1 : !torch.vtensor<[?,?,?,?,?],f32>, !torch.vtensor<[?],f32> -> !torch.tensor
  return %0 : !torch.tensor
}
@@ -1022,3 +1022,16 @@ func @torch.aten.to.dtype(%arg0: !torch.tensor<[?,?],f32>) -> !torch.tensor{
  %0 = torch.aten.to.dtype %arg0, %int4, %false, %false, %none : !torch.tensor<[?,?],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.tensor
  return %0 : !torch.tensor
}
+
+ // -----
+ // CHECK-LABEL: func @torch.prim.NumToTensor.Scalar(
+ // CHECK-SAME: %[[SELF:.*]]: !torch.int)
+ // CHECK-SAME: -> !torch.tensor {
+ // CHECK: %[[NTT:.*]] = torch.prim.NumToTensor.Scalar %[[SELF]] : !torch.int -> !torch.tensor<[],si64>
+ // CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[NTT]] : !torch.tensor<[],si64> to !torch.tensor
+ // CHECK: return %[[CAST]] : !torch.tensor
+
+ func @torch.prim.NumToTensor.Scalar(%arg0: !torch.int) -> !torch.tensor {
+   %0 = torch.prim.NumToTensor.Scalar %arg0 : !torch.int -> !torch.tensor
+   return %0 : !torch.tensor
+ }