diff --git a/py/torch_tensorrt/dynamo/backend/__init__.py b/py/torch_tensorrt/dynamo/backend/__init__.py
index 0846dec144..6247373b1f 100644
--- a/py/torch_tensorrt/dynamo/backend/__init__.py
+++ b/py/torch_tensorrt/dynamo/backend/__init__.py
@@ -46,6 +46,8 @@ def compile(
     torch_executed_modules=[],
     **kwargs,
 ):
+    if debug:
+        logger.setLevel(logging.DEBUG)
 
     logger.warn(
         "The Dynamo backend is an experimental feature, for which only the "
diff --git a/py/torch_tensorrt/dynamo/backend/backends.py b/py/torch_tensorrt/dynamo/backend/backends.py
index 962cbe8eba..4c2c5fdcc4 100644
--- a/py/torch_tensorrt/dynamo/backend/backends.py
+++ b/py/torch_tensorrt/dynamo/backend/backends.py
@@ -52,6 +52,7 @@ def aot_torch_tensorrt_aten_backend(
     )
 
 
+@fake_tensor_unsupported
 def _pretraced_backend(
     gm: torch.fx.GraphModule,
     sample_inputs: Sequence[torch.Tensor],
@@ -120,9 +121,7 @@ def _compile_module(
         trt_mod = convert_module(
             submodule,
             submodule_inputs,
-            debug=settings.debug,
-            workspace_size=settings.workspace_size,
-            precision=settings.precision,
+            settings=settings,
         )
 
         # Replace FX Module with TRT Module
diff --git a/py/torch_tensorrt/dynamo/backend/conversion.py b/py/torch_tensorrt/dynamo/backend/conversion.py
index 4f495dad4b..1644dea547 100644
--- a/py/torch_tensorrt/dynamo/backend/conversion.py
+++ b/py/torch_tensorrt/dynamo/backend/conversion.py
@@ -2,11 +2,11 @@
 import torch
 from torch_tensorrt.fx.trt_module import TRTModule
 from torch_tensorrt import TRTModuleNext
+from torch_tensorrt.dynamo.backend._settings import CompilationSettings
 from torch_tensorrt.fx.fx2trt import (
     InputTensorSpec,
     TRTInterpreter,
 )
-from torch_tensorrt.fx.utils import LowerPrecision
 
 import tensorrt as trt
 
@@ -14,17 +14,13 @@
 def convert_module(
     module: torch.fx.GraphModule,
     inputs: Sequence[torch.Tensor],
-    debug: bool = False,
-    workspace_size: int = 20 << 30,
-    precision: LowerPrecision = LowerPrecision.FP32,
+    settings: CompilationSettings = CompilationSettings(),
 ) -> Union[TRTModuleNext, TRTModule]:
     """Convert an FX module to a TRT module
     Args:
         module: FX GraphModule to convert
         inputs: Sequence of Tensors representing inputs to the module
-        debug: Whether to print out verbose debugging information
-        workspace_size: Maximum workspace TRT is allowed to use for the module
-        precision: Model Layer precision
+        settings: Compilation settings
     Returns:
         TRTModule or TRTModuleNext
     """
@@ -32,15 +28,15 @@ def convert_module(
         module,
         InputTensorSpec.from_tensors(inputs),
         explicit_batch_dimension=True,
-        logger_level=(trt.Logger.VERBOSE if debug else trt.Logger.WARNING),
+        logger_level=(trt.Logger.VERBOSE if settings.debug else trt.Logger.WARNING),
     )
 
     r = interp.run(
-        max_workspace_size=workspace_size,
-        lower_precision=precision,
+        max_workspace_size=settings.workspace_size,
+        lower_precision=settings.precision,
         profiling_verbosity=(
             trt.ProfilingVerbosity.VERBOSE
-            if debug
+            if settings.debug
             else trt.ProfilingVerbosity.LAYER_NAMES_ONLY
         ),
     )
diff --git a/py/torch_tensorrt/dynamo/backend/lowering/_partition.py b/py/torch_tensorrt/dynamo/backend/lowering/_partition.py
index b4d1b18db9..5cd83d768c 100644
--- a/py/torch_tensorrt/dynamo/backend/lowering/_partition.py
+++ b/py/torch_tensorrt/dynamo/backend/lowering/_partition.py
@@ -136,15 +136,18 @@ def print_support_overview(self, num_trt_blocks: Optional[int] = None):
             f"\nNumber of TensorRT-Accelerated Engines Generated: {num_trt_blocks}"
         )
 
-        logger.debug("\nSupported Nodes:")
+        # Reformat support messages for debugger to print node overview as a single string
+        supported_nodes_str = "\nSupported Nodes:\n"
         for node_name in self.supported_operators:
-            logger.debug("-", node_name)
+            supported_nodes_str += f"- {node_name}\n"
+
+        logger.debug(supported_nodes_str)
 
         if len(self.unsupported_operators) != 0:
-            logger.debug("\nUnsupported or Excluded Nodes:")
+            unsupported_nodes_str = "\nUnsupported or Excluded Nodes:\n"
             for node_name in self.unsupported_operators:
-                logger.debug("-", node_name)
-            logger.debug("\n")
+                unsupported_nodes_str += f"- {node_name}\n"
+            logger.debug(unsupported_nodes_str)
         else:
             logger.debug("\nAll Nodes Supported\n")
 