add feature gate for tensorrt plugin #3518

Open · wants to merge 9 commits into main
32 changes: 30 additions & 2 deletions py/torch_tensorrt/_features.py
@@ -1,3 +1,4 @@
import importlib.util
import os
import sys
from collections import namedtuple
@@ -15,6 +16,7 @@
"dynamo_frontend",
"fx_frontend",
"refit",
"qdp_plugin",
],
)

@@ -39,14 +41,24 @@
_FX_FE_AVAIL = True
_REFIT_AVAIL = True

if importlib.util.find_spec("tensorrt.plugin"):
_QDP_PLUGIN_AVAIL = True
else:
_QDP_PLUGIN_AVAIL = False

ENABLED_FEATURES = FeatureSet(
_TS_FE_AVAIL, _TORCHTRT_RT_AVAIL, _DYNAMO_FE_AVAIL, _FX_FE_AVAIL, _REFIT_AVAIL
_TS_FE_AVAIL,
_TORCHTRT_RT_AVAIL,
_DYNAMO_FE_AVAIL,
_FX_FE_AVAIL,
_REFIT_AVAIL,
_QDP_PLUGIN_AVAIL,
)


def _enabled_features_str() -> str:
enabled = lambda x: "ENABLED" if x else "DISABLED"
out_str: str = f"Enabled Features:\n - Dynamo Frontend: {enabled(_DYNAMO_FE_AVAIL)}\n - Torch-TensorRT Runtime: {enabled(_TORCHTRT_RT_AVAIL)}\n - FX Frontend: {enabled(_FX_FE_AVAIL)}\n - TorchScript Frontend: {enabled(_TS_FE_AVAIL)}\n" # type: ignore[no-untyped-call]
out_str: str = f"Enabled Features:\n - Dynamo Frontend: {enabled(_DYNAMO_FE_AVAIL)}\n - Torch-TensorRT Runtime: {enabled(_TORCHTRT_RT_AVAIL)}\n - FX Frontend: {enabled(_FX_FE_AVAIL)}\n - TorchScript Frontend: {enabled(_TS_FE_AVAIL)}\n - Refit: {enabled(_REFIT_AVAIL)}\n - QDP Plugin: {enabled(_QDP_PLUGIN_AVAIL)}\n" # type: ignore[no-untyped-call]
return out_str
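
For reference, on a build where every feature is available, the updated helper renders:

Enabled Features:
 - Dynamo Frontend: ENABLED
 - Torch-TensorRT Runtime: ENABLED
 - FX Frontend: ENABLED
 - TorchScript Frontend: ENABLED
 - Refit: ENABLED
 - QDP Plugin: ENABLED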


@@ -64,6 +76,22 @@ def not_implemented(*args: List[Any], **kwargs: Dict[str, Any]) -> Any:
return wrapper


def needs_qdp_plugin(f: Callable[..., Any]) -> Callable[..., Any]:
def wrapper(*args: List[Any], **kwargs: Dict[str, Any]) -> Any:
if ENABLED_FEATURES.qdp_plugin:
return f(*args, **kwargs)
else:

def not_implemented(*args: List[Any], **kwargs: Dict[str, Any]) -> Any:
raise NotImplementedError(
"TensorRT QDP(Quick Deploy Plugins) not available, requires TensorRT 10.7.0 or higher"
)

return not_implemented(*args, **kwargs)

return wrapper


def needs_refit(f: Callable[..., Any]) -> Callable[..., Any]:
def wrapper(*args: List[Any], **kwargs: Dict[str, Any]) -> Any:
if ENABLED_FEATURES.refit:
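The gate above mirrors the existing needs_refit decorator. A minimal self-contained sketch of the behavior (toy flag and function names, not the module's API):

_QDP_PLUGIN_AVAIL = False  # pretend tensorrt.plugin is missing

def needs_qdp_plugin(f):
    def wrapper(*args, **kwargs):
        if _QDP_PLUGIN_AVAIL:
            return f(*args, **kwargs)
        raise NotImplementedError(
            "TensorRT QDP (Quick Deployable Plugins) not available; "
            "requires TensorRT 10.7.0 or higher"
        )
    return wrapper

@needs_qdp_plugin
def generate_plugin(plugin_name: str) -> None: ...

generate_plugin("my_ns::my_op")  # raises NotImplementedError
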
81 changes: 80 additions & 1 deletion py/torch_tensorrt/dynamo/conversion/impl/unsqueeze.py
@@ -1,14 +1,18 @@
from typing import List, Optional, Sequence
import logging
from typing import List, Optional, Sequence, cast

from torch.fx.node import Target
from torch_tensorrt.dynamo._SourceIR import SourceIR
from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
from torch_tensorrt.dynamo.conversion.converter_utils import (
get_positive_dim,
get_trt_tensor,
set_layer_name,
)
from torch_tensorrt.dynamo.types import TRTTensor

logger = logging.getLogger(__name__)


def unsqueeze(
ctx: ConversionContext,
@@ -18,12 +22,87 @@ def unsqueeze(
input: TRTTensor,
dim: int,
) -> TRTTensor:
from importlib.metadata import version

from packaging.version import Version

# compare versions numerically rather than as strings
if Version(version("tensorrt")) < Version("10.7.0"):
logger.warning(
f"IUnsqueezeLayer is only supported in TensorRT 10.7.0 and above; falling back to the old unsqueeze implementation on TensorRT {version('tensorrt')}"
)
return unsqueeze_old(ctx, target, source_ir, name, input, dim)
axes = get_trt_tensor(ctx, dim, f"{name}_axes")
layer = ctx.net.add_unsqueeze(input, axes)
set_layer_name(layer, target, name, source_ir)
return layer.get_output(0)
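
Note that a plain string comparison would mis-order double-digit version components (e.g. TensorRT 10.10), hence the packaging.version comparison above. A quick illustration:

from packaging.version import Version

print("10.10.0" < "10.7.0")                    # True  (lexicographic, wrong)
print(Version("10.10.0") < Version("10.7.0"))  # False (numeric, correct)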


# old implementation, kept for Jetson: IUnsqueezeLayer was not supported prior to TensorRT 10.7.0
def unsqueeze_old(
ctx: ConversionContext,
target: Target,
source_ir: Optional[SourceIR],
name: str,
input: TRTTensor,
dim: int,
) -> TRTTensor:
input_val = get_trt_tensor(ctx, input, f"{name}_input")
if not isinstance(input_val, TRTTensor):
raise RuntimeError(
f"unsqueeze received input {input_val} that is not part "
"of the TensorRT region!"
)

dim = cast(int, dim)

input_shape_size = len(input_val.shape)
dim = get_positive_dim(dim, input_shape_size + 1)

intermediate_dim = 0
dynamic_shape_cnt = 0
# if unsqueezing the last dimension, we can directly append to the shape
if dim == input_shape_size:
intermediate_dim = dim
else:
# since at most one dimension is permitted to be specified as -1,
# find an intermediate_dim whose suffix contains at most one dynamic (-1) dimension,
# and then add a transpose after the reshape if it is not the final shape we want
for i, s in reversed(list(enumerate(input_val.shape))):
if i >= dim:
if s == -1:
dynamic_shape_cnt += 1
if dynamic_shape_cnt > 1:
intermediate_dim = i + 1
break
if i == dim:
intermediate_dim = i
break
# calculate the new_shape for the shuffle layer's reshape_dims
new_shape = list(
tuple(input_val.shape)[:intermediate_dim]
+ (1,)
+ tuple(input_val.shape)[intermediate_dim:]
)
for i, s in enumerate(new_shape):
if i < intermediate_dim and s == -1:
new_shape[i] = 0
layer = ctx.net.add_shuffle(input_val)
layer.reshape_dims = tuple(new_shape)
# if the intermediate_dim is not the final dim we want to unsqueeze, add a second_transpose after reshape
if intermediate_dim != dim:
# calculate the second_transpose for the shuffle layer
permutation = [*range(0, len(new_shape))]
# for example: if the reshape_dims is (3, 3, 5, 1, 5) and the final shape we want is (3, 1, 3, 5, 5)
# here intermediate_dim=3, dim=1, we need to move intermediate_dim before [dim: intermediate_dim)
new_permutation = (
tuple(permutation[:dim])
+ (intermediate_dim,)
+ tuple(permutation[dim:intermediate_dim])
+ tuple(permutation[intermediate_dim + 1 :])
)
layer.second_transpose = new_permutation
set_layer_name(layer, target, name, source_ir)
return layer.get_output(0)


def broadcast_in_dim(
ctx: ConversionContext,
target: Target,
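To see the fallback's index arithmetic in action, here is a small sketch that replays the reshape/transpose plan outside TensorRT (hypothetical helper, for illustration only):

def unsqueeze_plan(shape, dim):
    # replays the reshape/transpose arithmetic from unsqueeze_old
    rank = len(shape)
    intermediate_dim = 0
    dynamic_cnt = 0
    if dim == rank:
        intermediate_dim = dim
    else:
        for i, s in reversed(list(enumerate(shape))):
            if i >= dim:
                if s == -1:
                    dynamic_cnt += 1
                if dynamic_cnt > 1:
                    intermediate_dim = i + 1
                    break
                if i == dim:
                    intermediate_dim = i
                    break
    new_shape = list(shape[:intermediate_dim]) + [1] + list(shape[intermediate_dim:])
    # dims before the insertion point may not stay -1; 0 means "copy from input"
    new_shape = [0 if (i < intermediate_dim and s == -1) else s
                 for i, s in enumerate(new_shape)]
    perm = None
    if intermediate_dim != dim:
        p = list(range(len(new_shape)))
        perm = (tuple(p[:dim]) + (intermediate_dim,)
                + tuple(p[dim:intermediate_dim]) + tuple(p[intermediate_dim + 1:]))
    return new_shape, perm

# Two trailing -1s force the 1 to be inserted at position 2 first,
# then the shuffle's second_transpose moves it to the front.
print(unsqueeze_plan((-1, -1, -1), 0))  # ([0, 0, 1, -1], (2, 0, 1, 3))
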
10 changes: 9 additions & 1 deletion py/torch_tensorrt/dynamo/conversion/plugins/_generate_plugin.py
@@ -3,12 +3,12 @@
from types import FunctionType
from typing import Any, Callable, Tuple

import tensorrt.plugin as trtp
import torch
from sympy import lambdify
from torch._dynamo.source import LocalSource
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx.experimental.symbolic_shapes import DimDynamic, ShapeEnv
from torch_tensorrt._features import needs_qdp_plugin

_LOGGER: logging.Logger = logging.getLogger(__name__)

@@ -28,6 +28,13 @@ def mksym(


def _generate_plugin(plugin_name: str) -> None:
try:
import tensorrt.plugin as trtp
except ImportError as e:
raise RuntimeError(
"Unable to import TensorRT plugin. TensorRT version must be 10.7.0 or higher to support Triton-based TensorRT plugins"
) from e

namespace, name = plugin_name.split("::")

# retrieve the corresponding torch operation using the passed in string
@@ -211,6 +218,7 @@ def _generic_plugin_impl(
trtp.impl(plugin_name)(plugin_impl)


@needs_qdp_plugin
def generate_plugin(plugin_name: str) -> None:
"""
Generate the Plugin using external kernels and TensorRT Quick Deployable Plugin APIs.
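As a usage sketch (assuming generate_plugin is re-exported from torch_tensorrt.dynamo.conversion.plugins, with a hypothetical op id), callers can probe the new feature flag instead of catching the error:

from torch_tensorrt._features import ENABLED_FEATURES
from torch_tensorrt.dynamo.conversion.plugins import generate_plugin

if ENABLED_FEATURES.qdp_plugin:
    # "my_ns::my_op" is a placeholder for a torch.library custom op
    generate_plugin("my_ns::my_op")
else:
    ...  # fall back to a non-plugin conversion path
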
py/torch_tensorrt/dynamo/conversion/plugins/_generate_plugin_converter.py
@@ -4,12 +4,9 @@

import numpy as np
import tensorrt as trt

# Seems like a bug in TensorRT
import tensorrt.plugin as trtp
import torch
from tensorrt.plugin._lib import QDP_REGISTRY
from torch.fx.node import Argument, Node, Target
from torch_tensorrt._features import needs_qdp_plugin
from torch_tensorrt.dynamo._settings import CompilationSettings
from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
from torch_tensorrt.dynamo.conversion._ConverterRegistry import (
@@ -32,6 +29,15 @@ def _generate_plugin_converter(
supports_dynamic_shapes: bool = False,
requires_output_allocator: bool = False,
) -> DynamoConverterImplSignature:
try:
import tensorrt.plugin as trtp
except ImportError as e:
raise RuntimeError(
"Unable to import TensorRT plugin. TensorRT version must be 10.7.0 or higher to support Triton-based TensorRT plugins"
) from e
from tensorrt.plugin._lib import QDP_REGISTRY

torch_target = getattr(getattr(torch.ops, namespace), op_name)
overload_str = overload if overload else ""
overload_name = overload_str if overload else "default"
@@ -101,6 +107,7 @@ def custom_kernel_converter(
return custom_kernel_converter


@needs_qdp_plugin
def generate_plugin_converter(
plugin_id: str,
capability_validator: Optional[Callable[[Node, CompilationSettings], bool]] = None,
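And the converter side, with the same hypothetical op id (keyword names follow the private helper's parameters shown above; the public wrapper may differ):

from torch_tensorrt.dynamo.conversion.plugins import generate_plugin_converter

generate_plugin_converter(
    "my_ns::my_op",               # placeholder plugin id
    supports_dynamic_shapes=True,
    requires_output_allocator=False,
)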