@@ -800,6 +800,40 @@ def aten_ops_tile(
800
800
)
801
801
802
802
803
+ def zero_output_validator (node : Node ) -> bool :
804
+ if 0 in node .args [1 ]:
805
+ _LOGGER .debug (
806
+ f"We do not support output tensor { node .args [1 ]} tensors with zero-sized dimensions for this operation."
807
+ )
808
+ return False
809
+ else :
810
+ return True
811
+
812
+
813
+ @dynamo_tensorrt_converter (
814
+ torch .ops .aten .as_strided .default ,
815
+ capability_validator = zero_output_validator ,
816
+ )
817
+ @dynamo_tensorrt_converter (torch .ops .aten .as_strided .default )
818
+ def aten_ops_as_strided (
819
+ ctx : ConversionContext ,
820
+ target : Target ,
821
+ args : Tuple [Argument , ...],
822
+ kwargs : Dict [str , Argument ],
823
+ name : str ,
824
+ ) -> Union [TRTTensor , Sequence [TRTTensor ]]:
825
+ return impl .slice .as_strided (
826
+ ctx ,
827
+ target ,
828
+ source_ir = SourceIR .ATEN ,
829
+ name = name ,
830
+ input = args [0 ],
831
+ size = args [1 ],
832
+ stride = args [2 ],
833
+ storage_offset = args_bounds_check (args , 3 , None ),
834
+ )
835
+
836
+
803
837
@dynamo_tensorrt_converter (torch .ops .aten .permute .default )
804
838
@enforce_tensor_types (
805
839
{
@@ -2185,7 +2219,6 @@ def aten_ops_linear(
2185
2219
bias = args_bounds_check (args , 2 , None ),
2186
2220
)
2187
2221
2188
-
2189
2222
@dynamo_tensorrt_converter (torch .ops .aten ._cdist_forward .default )
2190
2223
def aten_ops_cdist_forward (
2191
2224
ctx : ConversionContext ,
@@ -2206,39 +2239,6 @@ def aten_ops_cdist_forward(
2206
2239
)
2207
2240
2208
2241
2209
- def zero_output_validator (node : Node ) -> bool :
2210
- if 0 in node .args [1 ]:
2211
- _LOGGER .debug (
2212
- f"We do not support output tensor { node .args [1 ]} tensors with zero-sized dimensions for this operation."
2213
- )
2214
- return False
2215
- else :
2216
- return True
2217
-
2218
-
2219
- @dynamo_tensorrt_converter (
2220
- torch .ops .aten .as_strided .default ,
2221
- capability_validator = zero_output_validator ,
2222
- )
2223
- def aten_ops_as_strided (
2224
- ctx : ConversionContext ,
2225
- target : Target ,
2226
- args : Tuple [Argument , ...],
2227
- kwargs : Dict [str , Argument ],
2228
- name : str ,
2229
- ) -> Union [TRTTensor , Sequence [TRTTensor ]]:
2230
- return impl .slice .as_strided (
2231
- ctx ,
2232
- target ,
2233
- source_ir = SourceIR .ATEN ,
2234
- name = name ,
2235
- input = args [0 ],
2236
- size = args [1 ],
2237
- stride = args [2 ],
2238
- storage_offset = args_bounds_check (args , 3 , None ),
2239
- )
2240
-
2241
-
2242
2242
def avg_pool_param_validator (pool_node : Node ) -> bool :
2243
2243
ceil_mode = args_bounds_check (pool_node .args , 4 , False )
2244
2244
divisor_override = args_bounds_check (pool_node .args , 6 )
0 commit comments