
🐛 [Bug] torch.arange causes SpecViolationError during torch_tensorrt.save #3189


Closed
Qi-Zha0 opened this issue Sep 26, 2024 · 2 comments · Fixed by #3194
Labels
bug Something isn't working


Qi-Zha0 commented Sep 26, 2024

Bug Description

See example below

To Reproduce

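Minimal example:

```python
import torch
import torch_tensorrt


class Mod(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        # With a static input shape, x.shape[-1] is traced as the constant 128,
        # so this becomes arange(1, 129) and does not depend on the values of x.
        x_embed = torch.arange(1, x.shape[-1] + 1, dtype=torch.float32, device=x.device)
        return x_embed


# Compile with a static input shape, then serialize the compiled module.
ep = torch_tensorrt.compile(Mod(), ir="dynamo", inputs=(torch.randn(1, 1, 128, 128)))
torch_tensorrt.save(ep, "test.ep", inputs=(torch.randn(1, 1, 128, 128)))
```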

Error:

```
WARNING:torch_tensorrt.dynamo.conversion.aten_ops_converters:Unable to import quantization op. Please install modelopt library (https://github.com/NVIDIA/TensorRT-Model-Optimizer?tab=readme-ov-file#installation) to add support for compiling quantized models
INFO:torch_tensorrt.dynamo._compiler:Compilation Settings: CompilationSettings(enabled_precisions={<dtype.f32: 7>}, debug=False, workspace_size=0, min_block_size=5, torch_executed_ops=set(), pass_through_build_failures=False, max_aux_streams=None, version_compatible=False, optimization_level=None, use_python_runtime=False, truncate_double=False, use_fast_partitioner=True, enable_experimental_decompositions=False, device=Device(type=DeviceType.GPU, gpu_id=0), require_full_compilation=False, disable_tf32=False, assume_dynamic_shape_support=False, sparse_weights=False, refit=False, engine_capability=<EngineCapability.STANDARD: 1>, num_avg_timing_iters=1, dla_sram_size=1048576, dla_local_dram_size=1073741824, dla_global_dram_size=536870912, dryrun=False, hardware_compatible=False, timing_cache_path='/tmp/timing_cache.bin')

WARNING:torch_tensorrt.dynamo._compiler:0 supported operations detected in subgraph containing 0 computational nodes. Skipping this subgraph, since min_block_size was detected to be 5
Traceback (most recent call last):
  File "/home/user/project/project_subdirectory/scripts/debug.py", line 16, in <module>
    torch_tensorrt.save(ep, "test.ep", inputs=(torch.randn(1, 1, 128, 128)))
  File "/home/user/project/.venv/lib/python3.11/site-packages/torch_tensorrt/_compile.py", line 461, in save
    exp_program = export(module, inputs)
                  ^^^^^^^^^^^^^^^^^^^^^^
  File "/home/user/project/.venv/lib/python3.11/site-packages/torch_tensorrt/dynamo/_exporter.py", line 33, in export
    exp_program = create_trt_exp_program(patched_module)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/user/project/.venv/lib/python3.11/site-packages/torch_tensorrt/dynamo/_exporter.py", line 328, in create_trt_exp_program
    trt_exp_program = ExportedProgram(
                      ^^^^^^^^^^^^^^^^
  File "/home/user/project/.venv/lib/python3.11/site-packages/torch/export/exported_program.py", line 246, in __init__
    self.verifier().check(self)
  File "/home/user/project/.venv/lib/python3.11/site-packages/torch/_export/verifier.py", line 155, in check
    _verify_exported_program_signature(ep)
  File "/home/user/project/.venv/lib/python3.11/site-packages/torch/_export/verifier.py", line 421, in _verify_exported_program_signature
    raise SpecViolationError(
torch._export.verifier.SpecViolationError: User output _frozen_param0_1 is not in the correct order or is not found in the exported program's user_output list: ('_frozen_param0',).
WARNING:py.warnings:/usr/lib/python3.11/tempfile.py:1073: ResourceWarning: Implicitly cleaning up <TemporaryDirectory '/tmp/tmpn6njzoc7'>
  _warnings.warn(warn_message, ResourceWarning)
```

Expected behavior

Environment

Build information about Torch-TensorRT can be found by turning on debug messages

  • Torch-TensorRT Version (e.g. 1.0.0): 2.4
  • PyTorch Version (e.g. 1.0): 2.4
  • CPU Architecture:
  • OS (e.g., Linux): Linux
  • How you installed PyTorch (conda, pip, libtorch, source):
  • Build command you used (if compiling from source):
  • Are you using local sources or building from archives:
  • Python version: 3.11
  • CUDA version:
  • GPU models and configuration:
  • Any other relevant information:

Additional context

Qi-Zha0 added the bug label on Sep 26, 2024.
peri044 (Collaborator) commented Sep 27, 2024

Here's what is happening:

```
DEBUG:torch_tensorrt.dynamo._compiler:Input graph: graph():
    %x : [num_users=0] = placeholder[target=x]
    %arange : [num_users=1] = call_function[target=torch.ops.aten.arange.start_step](args = (1, 129), kwargs = {dtype: torch.float32, layout: torch.strided, device: cuda:0, pin_memory: False})
    return (arange,)
DEBUG:torch_tensorrt.dynamo.lowering.passes.constant_folding:Graph after constant folding:
graph():
    %x : [num_users=0] = placeholder[target=x]
    %_frozen_param0 : [num_users=1] = get_attr[target=_frozen_param0]
    return (_frozen_param0,)
DEBUG:torch_tensorrt.dynamo.lowering.passes.remove_assert_scalar:Removed 0 assert_scalar nodes:
graph():
    %x : [num_users=0] = placeholder[target=x]
    %_frozen_param0 : [num_users=1] = get_attr[target=_frozen_param0]
    return (_frozen_param0,)
DEBUG:torch_tensorrt.dynamo._compiler:Lowered Input graph: graph():
    %x : [num_users=0] = placeholder[target=x]
    %_frozen_param0 : [num_users=1] = get_attr[target=_frozen_param0]
    return (_frozen_param0,)
```

Since the input shapes are static, the arange call gets constant-folded, and its output becomes the constant float32 tensor [1., 2., ..., 128.]. This constant is registered in the graph as a get_attr node named _frozen_param0. When the exported program is created, get_attr nodes are lifted to inputs and converted into placeholder nodes; this one is renamed _frozen_param0_1. However, because the same node is also the graph output, the output name in graph_signature.output_specs was not being updated from _frozen_param0 to _frozen_param0_1, so the verifier found an output that the signature did not list and raised the SpecViolationError.
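To illustrate why folding happens here, the toy sketch below (plain Python, not the torch.fx or Torch-TensorRT API) models a graph as a list of nodes and folds any call_function node whose arguments are all literals into a get_attr node that points at a precomputed constant. `Node`, `toy_arange`, and `constant_fold` are hypothetical names, and `toy_arange` returns Python floats in place of a float32 CUDA tensor.

```python
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional, Tuple


@dataclass
class Node:
    name: str
    op: str                               # "placeholder", "call_function", "get_attr" or "output"
    target: Optional[Callable] = None
    args: Tuple[Any, ...] = ()


def toy_arange(start, end):
    """Hypothetical stand-in for aten.arange.start_step with dtype=float32."""
    return [float(i) for i in range(start, end)]


def constant_fold(nodes: List[Node], params: Dict[str, Any]) -> List[Node]:
    """Evaluate call_function nodes whose args are all literals (no graph inputs)
    and replace each one with a get_attr node pointing at a frozen constant."""
    replaced: Dict[str, Node] = {}
    folded: List[Node] = []
    for node in nodes:
        if node.op == "call_function" and not any(isinstance(a, Node) for a in node.args):
            pname = f"_frozen_param{len(params)}"
            params[pname] = node.target(*node.args)  # run the op once, ahead of time
            new_node = Node(pname, "get_attr")
            replaced[node.name] = new_node
            folded.append(new_node)
        elif node.op == "output":
            # Rewire the output to the folded constant where needed.
            new_args = tuple(replaced.get(a.name, a) for a in node.args)
            folded.append(Node(node.name, "output", args=new_args))
        else:
            folded.append(node)
    return folded


# x.shape[-1] == 128 is static, so the traced call is arange(1, 129): no dependence on x.
x = Node("x", "placeholder")
arange = Node("arange", "call_function", target=toy_arange, args=(1, 129))
out = Node("output", "output", args=(arange,))

params: Dict[str, Any] = {}
graph = constant_fold([x, arange, out], params)
print([f"{n.op}:{n.name}" for n in graph])
# → ['placeholder:x', 'get_attr:_frozen_param0', 'output:output']
print(params["_frozen_param0"][:3], params["_frozen_param0"][-1])
# → [1.0, 2.0, 3.0] 128.0
```

After folding, the graph has no computational nodes left, which is consistent with the "0 supported operations detected in subgraph containing 0 computational nodes" warning in the original log.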

#3194 fixes this. The example provided by @Qi-Zha0 now passes.
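The naming mismatch can be reproduced in a toy model without Torch-TensorRT. Everything here (`ToyProgram`, `lift_get_attr`, `verify`, `SpecViolation`) is hypothetical: the `_1` suffix only mirrors the names in the error message, `verify` approximates the real verifier's user-output check, and the `update_output_specs` flag stands in for the kind of fix described above. It is not the code from #3194.

```python
from dataclasses import dataclass
from typing import List


class SpecViolation(Exception):
    """Toy stand-in for torch._export.verifier.SpecViolationError."""


@dataclass
class ToyProgram:
    placeholders: List[str]   # placeholder node names, lifted constants first
    outputs: List[str]        # names of the nodes returned by the graph's output node
    output_specs: List[str]   # user-output names recorded in the toy graph signature


def lift_get_attr(get_attrs, user_inputs, outputs, output_specs, update_output_specs):
    """Lift get_attr nodes to placeholders under fresh names (modeled as a '_1'
    suffix, matching _frozen_param0 -> _frozen_param0_1 in the error)."""
    rename = {name: f"{name}_1" for name in get_attrs}
    placeholders = [rename[n] for n in get_attrs] + list(user_inputs)
    # The graph's output node is rewired to the new placeholder...
    new_outputs = [rename.get(o, o) for o in outputs]
    # ...but the signature is only renamed when update_output_specs is True.
    if update_output_specs:
        new_specs = [rename.get(s, s) for s in output_specs]
    else:
        new_specs = list(output_specs)
    return ToyProgram(placeholders, new_outputs, new_specs)


def verify(prog: ToyProgram) -> None:
    """Check that each graph output appears, in order, in the signature."""
    for i, out in enumerate(prog.outputs):
        if i >= len(prog.output_specs) or prog.output_specs[i] != out:
            raise SpecViolation(
                f"User output {out} is not in the correct order or is not found in the "
                f"exported program's user_output list: {tuple(prog.output_specs)}"
            )


args = (["_frozen_param0"], ["x"], ["_frozen_param0"], ["_frozen_param0"])

buggy = lift_get_attr(*args, update_output_specs=False)
try:
    verify(buggy)
    buggy_error = None
except SpecViolation as exc:
    buggy_error = str(exc)
print(buggy_error)
# → User output _frozen_param0_1 is not in the correct order or is not found in the exported program's user_output list: ('_frozen_param0',)

fixed = lift_get_attr(*args, update_output_specs=True)
verify(fixed)  # no exception: outputs and output_specs now agree
print(fixed.outputs, fixed.output_specs)
# → ['_frozen_param0_1'] ['_frozen_param0_1']
```

The point of the model is that renaming a lifted node must be applied consistently in two places: the graph's output node and the signature's output specs. Updating only the first is enough to trip the check.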

Qi-Zha0 (Author) commented Sep 27, 2024

@peri044 Thank you! Tested it on my end. It fixes the issue!
