diff --git a/docsrc/index.rst b/docsrc/index.rst index e7d5250e52..cf52fba2d5 100644 --- a/docsrc/index.rst +++ b/docsrc/index.rst @@ -141,6 +141,7 @@ Model Zoo * :ref:`torch_export_gpt2` * :ref:`torch_export_llama2` * :ref:`torch_export_sam2` +* :ref:`torch_export_flux_dev` * :ref:`notebooks` .. toctree:: @@ -157,6 +158,7 @@ Model Zoo tutorials/_rendered_examples/dynamo/torch_export_gpt2 tutorials/_rendered_examples/dynamo/torch_export_llama2 tutorials/_rendered_examples/dynamo/torch_export_sam2 + tutorials/_rendered_examples/dynamo/torch_export_flux_dev tutorials/notebooks Python API Documentation diff --git a/docsrc/tutorials/_rendered_examples/dog_code.png b/docsrc/tutorials/_rendered_examples/dog_code.png new file mode 100644 index 0000000000..18b272e025 Binary files /dev/null and b/docsrc/tutorials/_rendered_examples/dog_code.png differ diff --git a/examples/dynamo/README.rst b/examples/dynamo/README.rst index bb789e77b6..a9eab6d698 100644 --- a/examples/dynamo/README.rst +++ b/examples/dynamo/README.rst @@ -20,4 +20,5 @@ Model Zoo * :ref:`_torch_compile_gpt2`: Compiling a GPT2 model using ``torch.compile`` * :ref:`_torch_export_gpt2`: Compiling a GPT2 model using AOT workflow (`ir=dynamo`) * :ref:`_torch_export_llama2`: Compiling a Llama2 model using AOT workflow (`ir=dynamo`) -* :ref:`_torch_export_sam2`: Compiling SAM2 model using AOT workflow (`ir=dynamo`) \ No newline at end of file +* :ref:`_torch_export_sam2`: Compiling SAM2 model using AOT workflow (`ir=dynamo`) +* :ref:`_torch_export_flux_dev`: Compiling FLUX.1-dev model using AOT workflow (`ir=dynamo`) \ No newline at end of file diff --git a/examples/dynamo/torch_export_flux_dev.py b/examples/dynamo/torch_export_flux_dev.py new file mode 100644 index 0000000000..25b2fc6d2e --- /dev/null +++ b/examples/dynamo/torch_export_flux_dev.py @@ -0,0 +1,150 @@ +""" +.. _torch_export_flux_dev: + +Compiling FLUX.1-dev model using the Torch-TensorRT dynamo backend +=================================================================== + +This example illustrates the state of the art model `FLUX.1-dev `_ optimized using +Torch-TensorRT. + +**FLUX.1 [dev]** is a 12 billion parameter rectified flow transformer capable of generating images from text descriptions. It is an open-weight, guidance-distilled model for non-commercial applications. + +Install the following dependencies before compilation + +.. code-block:: python + + pip install sentencepiece=="0.2.0" transformers=="4.48.2" accelerate=="1.3.0" diffusers=="0.32.2" + +There are different components of the ``FLUX.1-dev`` pipeline such as ``transformer``, ``vae``, ``text_encoder``, ``tokenizer`` and ``scheduler``. In this example, +we demonstrate optimizing the ``transformer`` component of the model (which typically consumes >95% of the e2e diffusion latency) +""" + +# %% +# Import the following libraries +# ----------------------------- +import torch +import torch_tensorrt +from diffusers import FluxPipeline +from torch.export._trace import _export + +# %% +# Define the FLUX-1.dev model +# ----------------------------- +# Load the ``FLUX-1.dev`` pretrained pipeline using ``FluxPipeline`` class. +# ``FluxPipeline`` includes different components such as ``transformer``, ``vae``, ``text_encoder``, ``tokenizer`` and ``scheduler`` necessary +# to generate an image. We load the weights in ``FP16`` precision using ``torch_dtype`` argument +DEVICE = "cuda:0" +pipe = FluxPipeline.from_pretrained( + "black-forest-labs/FLUX.1-dev", + torch_dtype=torch.float16, +) +pipe.to(DEVICE).to(torch.float16) +# Store the config and transformer backbone +config = pipe.transformer.config +backbone = pipe.transformer + + +# %% +# Export the backbone using torch.export +# -------------------------------------------------- +# Define the dummy inputs and their respective dynamic shapes. We export the transformer backbone with dynamic shapes with a ``batch_size=2`` +# due to `0/1 specialization `_ +batch_size = 2 +BATCH = torch.export.Dim("batch", min=1, max=2) +SEQ_LEN = torch.export.Dim("seq_len", min=1, max=512) +# This particular min, max values for img_id input are recommended by torch dynamo during the export of the model. +# To see this recommendation, you can try exporting using min=1, max=4096 +IMG_ID = torch.export.Dim("img_id", min=3586, max=4096) +dynamic_shapes = { + "hidden_states": {0: BATCH}, + "encoder_hidden_states": {0: BATCH, 1: SEQ_LEN}, + "pooled_projections": {0: BATCH}, + "timestep": {0: BATCH}, + "txt_ids": {0: SEQ_LEN}, + "img_ids": {0: IMG_ID}, + "guidance": {0: BATCH}, +} +# The guidance factor is of type torch.float32 +dummy_inputs = { + "hidden_states": torch.randn((batch_size, 4096, 64), dtype=torch.float16).to( + DEVICE + ), + "encoder_hidden_states": torch.randn( + (batch_size, 512, 4096), dtype=torch.float16 + ).to(DEVICE), + "pooled_projections": torch.randn((batch_size, 768), dtype=torch.float16).to( + DEVICE + ), + "timestep": torch.tensor([1.0, 1.0], dtype=torch.float16).to(DEVICE), + "txt_ids": torch.randn((512, 3), dtype=torch.float16).to(DEVICE), + "img_ids": torch.randn((4096, 3), dtype=torch.float16).to(DEVICE), + "guidance": torch.tensor([1.0, 1.0], dtype=torch.float32).to(DEVICE), +} +# This will create an exported program which is going to be compiled with Torch-TensorRT +ep = _export( + backbone, + args=(), + kwargs=dummy_inputs, + dynamic_shapes=dynamic_shapes, + strict=False, + allow_complex_guards_as_runtime_asserts=True, +) + +# %% +# Torch-TensorRT compilation +# --------------------------- +# .. note:: +# The compilation requires a GPU with high memory (> 80GB) since TensorRT is storing the weights in FP32 precision. This is a known issue and will be resolved in the future. +# +# +# We enable ``FP32`` matmul accumulation using ``use_fp32_acc=True`` to ensure accuracy is preserved by introducing cast to ``FP32`` nodes. +# We also enable explicit typing to ensure TensorRT respects the datatypes set by the user which is a requirement for FP32 matmul accumulation. +# Since this is a 12 billion parameter model, it takes around 20-30 min to compile on H100 GPU. The model is completely convertible and results in +# a single TensorRT engine. +trt_gm = torch_tensorrt.dynamo.compile( + ep, + inputs=dummy_inputs, + enabled_precisions={torch.float32}, + truncate_double=True, + min_block_size=1, + use_fp32_acc=True, + use_explicit_typing=True, +) + +# %% +# Post Processing +# --------------------------- +# Release the GPU memory occupied by the exported program and the pipe.transformer +# Set the transformer in the Flux pipeline to the Torch-TRT compiled model +backbone.to("cpu") +del ep +pipe.transformer = trt_gm +pipe.transformer.config = config + +# %% +# Image generation using prompt +# --------------------------- +# Provide a prompt and the file name of the image to be generated. Here we use the +# prompt ``A golden retriever holding a sign to code``. + + +# Function which generates images from the flux pipeline +def generate_image(pipe, prompt, image_name): + seed = 42 + image = pipe( + prompt, + output_type="pil", + num_inference_steps=20, + generator=torch.Generator("cuda").manual_seed(seed), + ).images[0] + image.save(f"{image_name}.png") + print(f"Image generated using {image_name} model saved as {image_name}.png") + + +generate_image(pipe, ["A golden retriever holding a sign to code"], "dog_code") + +# %% +# The generated image is as shown below +# +# .. image:: dog_code.png +# diff --git a/py/torch_tensorrt/dynamo/lowering/passes/remove_assert_scalar.py b/py/torch_tensorrt/dynamo/lowering/passes/remove_assert_scalar.py index 28f71f78a6..67d2ba6690 100644 --- a/py/torch_tensorrt/dynamo/lowering/passes/remove_assert_scalar.py +++ b/py/torch_tensorrt/dynamo/lowering/passes/remove_assert_scalar.py @@ -15,7 +15,10 @@ def remove_assert_scalar( """Remove assert_scalar ops in the graph""" count = 0 for node in gm.graph.nodes: - if node.target == torch.ops.aten._assert_scalar.default: + if ( + node.target == torch.ops.aten._assert_scalar.default + or node == torch.ops.aten._assert_tensor_metadata.default + ): gm.graph.erase_node(node) count += 1 diff --git a/py/torch_tensorrt/dynamo/utils.py b/py/torch_tensorrt/dynamo/utils.py index 2d3cb2924d..f3d297a01c 100644 --- a/py/torch_tensorrt/dynamo/utils.py +++ b/py/torch_tensorrt/dynamo/utils.py @@ -243,12 +243,16 @@ def prepare_inputs( inputs: Input | torch.Tensor | Sequence[Any] | Dict[Any, Any], disable_memory_format_check: bool = False, ) -> Any: - if isinstance(inputs, Input): + if inputs is None: + return None + + elif isinstance(inputs, Input): return inputs - elif isinstance(inputs, torch.Tensor): + elif isinstance(inputs, (torch.Tensor, int, float, bool)): return Input.from_tensor( - inputs, disable_memory_format_check=disable_memory_format_check + torch.tensor(inputs), + disable_memory_format_check=disable_memory_format_check, ) elif isinstance(inputs, (list, tuple)): @@ -395,10 +399,13 @@ def unwrap_tensor_dtype(tensor: Union[torch.Tensor, FakeTensor, torch.SymInt]) - """ Returns the dtype of torch.tensor or FakeTensor. For symbolic integers, we return int64 """ - if isinstance(tensor, (torch.Tensor, FakeTensor)): - return tensor.dtype + if isinstance(tensor, (torch.Tensor, FakeTensor, int, float, bool)): + return torch.tensor(tensor).dtype elif isinstance(tensor, torch.SymInt): return torch.int64 + elif tensor is None: + # Case where we explicitly pass one of the inputs to be None (eg: FLUX.1-dev) + return None else: raise ValueError(f"Found invalid tensor type {type(tensor)}")