Skip to content

Commit d344168

Browse files
committed
chore: move functions to organize code better
1 parent b4fce3e commit d344168

File tree

1 file changed

+34
-34
lines changed

1 file changed

+34
-34
lines changed

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 34 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -800,6 +800,40 @@ def aten_ops_tile(
800800
)
801801

802802

803+
def zero_output_validator(node: Node) -> bool:
804+
if 0 in node.args[1]:
805+
_LOGGER.debug(
806+
f"We do not support output tensor {node.args[1]} tensors with zero-sized dimensions for this operation."
807+
)
808+
return False
809+
else:
810+
return True
811+
812+
813+
@dynamo_tensorrt_converter(
814+
torch.ops.aten.as_strided.default,
815+
capability_validator=zero_output_validator,
816+
)
817+
@dynamo_tensorrt_converter(torch.ops.aten.as_strided.default)
818+
def aten_ops_as_strided(
819+
ctx: ConversionContext,
820+
target: Target,
821+
args: Tuple[Argument, ...],
822+
kwargs: Dict[str, Argument],
823+
name: str,
824+
) -> Union[TRTTensor, Sequence[TRTTensor]]:
825+
return impl.slice.as_strided(
826+
ctx,
827+
target,
828+
source_ir=SourceIR.ATEN,
829+
name=name,
830+
input=args[0],
831+
size=args[1],
832+
stride=args[2],
833+
storage_offset=args_bounds_check(args, 3, None),
834+
)
835+
836+
803837
@dynamo_tensorrt_converter(torch.ops.aten.permute.default)
804838
@enforce_tensor_types(
805839
{
@@ -2185,7 +2219,6 @@ def aten_ops_linear(
21852219
bias=args_bounds_check(args, 2, None),
21862220
)
21872221

2188-
21892222
@dynamo_tensorrt_converter(torch.ops.aten._cdist_forward.default)
21902223
def aten_ops_cdist_forward(
21912224
ctx: ConversionContext,
@@ -2206,39 +2239,6 @@ def aten_ops_cdist_forward(
22062239
)
22072240

22082241

2209-
def zero_output_validator(node: Node) -> bool:
2210-
if 0 in node.args[1]:
2211-
_LOGGER.debug(
2212-
f"We do not support output tensor {node.args[1]} tensors with zero-sized dimensions for this operation."
2213-
)
2214-
return False
2215-
else:
2216-
return True
2217-
2218-
2219-
@dynamo_tensorrt_converter(
2220-
torch.ops.aten.as_strided.default,
2221-
capability_validator=zero_output_validator,
2222-
)
2223-
def aten_ops_as_strided(
2224-
ctx: ConversionContext,
2225-
target: Target,
2226-
args: Tuple[Argument, ...],
2227-
kwargs: Dict[str, Argument],
2228-
name: str,
2229-
) -> Union[TRTTensor, Sequence[TRTTensor]]:
2230-
return impl.slice.as_strided(
2231-
ctx,
2232-
target,
2233-
source_ir=SourceIR.ATEN,
2234-
name=name,
2235-
input=args[0],
2236-
size=args[1],
2237-
stride=args[2],
2238-
storage_offset=args_bounds_check(args, 3, None),
2239-
)
2240-
2241-
22422242
def avg_pool_param_validator(pool_node: Node) -> bool:
22432243
ceil_mode = args_bounds_check(pool_node.args, 4, False)
22442244
divisor_override = args_bounds_check(pool_node.args, 6)

0 commit comments

Comments
 (0)