diff --git a/examples/dynamo/auto_generate_plugin.py b/examples/dynamo/auto_generate_plugin.py index 4290745bfd..fc31df8c58 100644 --- a/examples/dynamo/auto_generate_plugin.py +++ b/examples/dynamo/auto_generate_plugin.py @@ -108,14 +108,17 @@ def _(x: torch.Tensor, y: torch.Tensor, b: float = 0.2, a: int = 2) -> torch.Ten # # ------------------------------------------------------------------- # # Given that we have defined the custom operator in PyTorch and TensorRT, we can now generate the converter for the operation. # # As long as the namespace and names match, the following function will automatically generate the converter for the operation. +# # If plugins require an output allocator to dynamically allocate output buffers, such as data-dependent operators, please set requires_output_allocator to True. torch_tensorrt.dynamo.conversion.plugins.generate_plugin_converter( - "torchtrt_ex::elementwise_scale_mul", supports_dynamic_shapes=True + "torchtrt_ex::elementwise_scale_mul", + supports_dynamic_shapes=True, + requires_output_allocator=False, ) # # %% # # Above two commands can be replaced with the following single one line: -# torch_tensorrt.dynamo.conversion.plugins.custom_op("torchtrt_ex::elementwise_scale_mul", supports_dynamic_shapes=True) +# torch_tensorrt.dynamo.conversion.plugins.custom_op("torchtrt_ex::elementwise_scale_mul", supports_dynamic_shapes=True, requires_output_allocator=False) # %% diff --git a/py/torch_tensorrt/dynamo/conversion/plugins/_custom_op.py b/py/torch_tensorrt/dynamo/conversion/plugins/_custom_op.py index ef5ed59a56..c936308fd5 100644 --- a/py/torch_tensorrt/dynamo/conversion/plugins/_custom_op.py +++ b/py/torch_tensorrt/dynamo/conversion/plugins/_custom_op.py @@ -14,6 +14,7 @@ def custom_op( capability_validator: Optional[Callable[[Node, CompilationSettings], bool]] = None, priority: ConverterPriority = ConverterPriority.STANDARD, supports_dynamic_shapes: bool = False, + requires_output_allocator: bool = False, ) -> None: """ 
Generate the Plugin and corresponding Plugin Converter using external kernels and TensorRT Quick Deployable Plugin APIs. @@ -26,8 +27,13 @@ def custom_op( partitioner will make sure this Node is run in PyTorch in the compiled graph. priority: Allows developers to override existing converters in the converter registry supports_dynamic_shapes: if dynamic shape is supported + requires_output_allocator: if the converter creates operators which require an Output Allocator to run (e.g. data-dependent operators) """ generate_plugin(op_name) generate_plugin_converter( - op_name, capability_validator, priority, supports_dynamic_shapes + op_name, + capability_validator, + priority, + supports_dynamic_shapes, + requires_output_allocator, ) diff --git a/py/torch_tensorrt/dynamo/conversion/plugins/_generate_plugin_converter.py b/py/torch_tensorrt/dynamo/conversion/plugins/_generate_plugin_converter.py index f6343fdb34..8b0e60881a 100644 --- a/py/torch_tensorrt/dynamo/conversion/plugins/_generate_plugin_converter.py +++ b/py/torch_tensorrt/dynamo/conversion/plugins/_generate_plugin_converter.py @@ -29,6 +29,7 @@ def _generate_plugin_converter( capability_validator: Optional[Callable[[Node, CompilationSettings], bool]] = None, priority: ConverterPriority = ConverterPriority.STANDARD, supports_dynamic_shapes: bool = False, + requires_output_allocator: bool = False, ) -> DynamoConverterImplSignature: torch_target = getattr(getattr(torch.ops, namespace), op_name) overload_str = overload if overload else "" @@ -87,6 +88,7 @@ def custom_kernel_converter( capability_validator=capability_validator, priority=priority, supports_dynamic_shapes=supports_dynamic_shapes, + requires_output_allocator=requires_output_allocator, )(custom_kernel_converter) assert ( torch_overload in DYNAMO_CONVERTERS @@ -99,6 +101,7 @@ def generate_plugin_converter( capability_validator: Optional[Callable[[Node, CompilationSettings], bool]] = None, priority: ConverterPriority = ConverterPriority.STANDARD, 
supports_dynamic_shapes: bool = False, + requires_output_allocator: bool = False, ) -> DynamoConverterImplSignature: plugin_ns, plugin_name = plugin_id.split("::") return _generate_plugin_converter( @@ -107,4 +110,5 @@ def generate_plugin_converter( capability_validator=capability_validator, priority=priority, supports_dynamic_shapes=supports_dynamic_shapes, + requires_output_allocator=requires_output_allocator, )