feat: Implement FP32 accumulation for matmul #3110
Merged
Commits (81):
- ``2ea181a`` chore: add gpt2 example (peri044)
- ``37b65a5`` chore: add llama2 example (peri044)
- ``bd12b12`` Merge branch 'main' into llm_examples_main (peri044)
- ``4a9f73e`` Merge branch 'main' into llm_examples_main (peri044)
- ``0387d0b`` Merge branch 'main' into llm_examples_main (peri044)
- ``6193939`` chore: updates (peri044)
- ``9d3296e`` Merge branch 'main' into llm_examples_main (peri044)
- ``84fc49c`` Merge branch 'main' into llm_examples_main (peri044)
- ``ff17d91`` chore: rebase (peri044)
- ``8e6ba26`` Merge branch 'llm_examples_main' of github.com:pytorch/TensorRT into … (peri044)
- ``67ec408`` Merge branch 'main' into llm_examples_main (peri044)
- ``9af8e39`` chore: remove aten.full decomposition (peri044)
- ``50d4096`` chore: fix expand DS support (peri044)
- ``59febf5`` chore: minor fix (peri044)
- ``c3e4382`` chore: updates (peri044)
- ``0673db4`` chore: add testcase (peri044)
- ``0b62f8f`` Merge branch 'main' into full (peri044)
- ``54f6410`` Merge branch 'full' into fix_expand_ds (peri044)
- ``ae3d6b2`` Merge branch 'fix_expand_ds' into llm_examples_main (peri044)
- ``4464fd5`` chore: updates (peri044)
- ``63b13cf`` chore: updates (peri044)
- ``3d10b92`` Merge branch 'main' into llm_examples_main (peri044)
- ``e97a94f`` chore: updates (peri044)
- ``4f503a8`` chore: updates (peri044)
- ``5ecf63e`` chore: rebase (peri044)
- ``0d00d8c`` chore: updates (peri044)
- ``8099003`` chore: updates (peri044)
- ``457f706`` chore: updates (peri044)
- ``ce3b2f8`` chore: updates (peri044)
- ``d8acadc`` chore: updates (peri044)
- ``262c87d`` chore: updates (peri044)
- ``bb94dfd`` chore: rebase (peri044)
- ``736b839`` chore: updates (peri044)
- ``313380e`` chore: bug fixes (peri044)
- ``1057d83`` chore: updates (peri044)
- ``bfd0cf2`` chore: fixes (peri044)
- ``17ddb31`` chore: updates (peri044)
- ``88be4fa`` chore: add torch compile gpt2 example (peri044)
- ``df825ab`` chore: updates (peri044)
- ``ff07295`` chore: add timing calculation (peri044)
- ``857b0aa`` Merge branch 'main' into llm_examples_main (peri044)
- ``8fae56b`` Merge branch 'main' into llm_examples_main (peri044)
- ``d483718`` chore: rebase (peri044)
- ``397e4bc`` Merge branch 'main' into llm_examples_main (peri044)
- ``6c9b9fe`` chore: updates (peri044)
- ``6313b1c`` chore: updates (peri044)
- ``d608cc5`` chore: rebase (peri044)
- ``1327782`` chore: rebase fixes (peri044)
- ``0980778`` chore: updates (peri044)
- ``94b2ba1`` chore: updates (peri044)
- ``2b1db29`` chore: updates (peri044)
- ``9f606fc`` chore: updates (peri044)
- ``0cf23be`` Merge branch 'main' into llm_examples_main (peri044)
- ``3228c57`` chore: Update perf tooling with support for HF models (#3034) (peri044)
- ``6786f0e`` chore: updates
- ``e4873d0`` chore: updates (peri044)
- ``a725ce0`` Merge branch 'main' into llm_examples_main (peri044)
- ``bb10de4`` feat: lowering replace aten.full_like with aten.full (chohk88)
- ``1527aa0`` chore: minor linting (chohk88)
- ``67e33c3`` chore: updates (peri044)
- ``5627c1a`` Merge branch 'llm_examples_main' of github.com:pytorch/TensorRT into … (peri044)
- ``7be8604`` chore: updates (peri044)
- ``4d75a2e`` Merge branch 'main' into llm_examples_main (peri044)
- ``0ab0dbf`` feat: add fp32 accumulation option for matmul layer (peri044)
- ``3c815f8`` chore: updates
- ``5617c0a`` chore: Bump TRT version to 10.3.0.26 (#3071) (zewenli98)
- ``213526e`` chore: updates (peri044)
- ``c193593`` chore: updates (peri044)
- ``0de0b16`` chore: updates (peri044)
- ``a90191d`` chore: rebase with main (peri044)
- ``71e33cb`` chore: updates (peri044)
- ``4257b1e`` chore: updates (peri044)
- ``619a39a`` chore: updates (peri044)
- ``8c0b9c6`` chore: trunc_fiv fix (peri044)
- ``b6261f9`` chore: update result (peri044)
- ``ebdfe8f`` fix: add model.half() for llama2 (peri044)
- ``61ec948`` chore: address review comments (peri044)
- ``dd27a54`` chore: address review comments (peri044)
- ``b2e5244`` chore: add docs (peri044)
- ``7ddd637`` chore: updates (peri044)
- ``4529717`` chore: sign bug fix (peri044)
.. _mixed_precision:

Compile Mixed Precision models with Torch-TensorRT
==================================================

.. currentmodule:: torch_tensorrt.dynamo

.. automodule:: torch_tensorrt.dynamo
    :members:
    :undoc-members:
    :show-inheritance:

Consider the following PyTorch model, which explicitly casts an intermediate layer to run in FP16.
.. code-block:: python

    class MyModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear1 = torch.nn.Linear(10, 10)
            self.linear2 = torch.nn.Linear(10, 30).half()
            self.linear3 = torch.nn.Linear(30, 40)

        def forward(self, x):
            x = self.linear1(x)
            x = x.to(torch.float16)
            x = self.linear2(x)
            x = x.to(torch.float32)
            x = self.linear3(x)
            return x
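To make the intended precision flow concrete, here is a minimal NumPy sketch of the same cast pattern (an illustration only, not the Torch-TensorRT API; the weight arrays are hypothetical stand-ins for the three ``Linear`` layers above):

```python
import numpy as np

# Stand-ins for the three linear layers' weights:
# linear1 in FP32 -> cast to FP16 -> linear2 in FP16 -> cast to FP32 -> linear3 in FP32.
rng = np.random.default_rng(0)
x = rng.standard_normal((1, 10), dtype=np.float32)
w1 = rng.standard_normal((10, 10), dtype=np.float32)                     # linear1 (FP32)
w2 = rng.standard_normal((10, 30), dtype=np.float32).astype(np.float16)  # linear2 (FP16)
w3 = rng.standard_normal((30, 40), dtype=np.float32)                     # linear3 (FP32)

h = x @ w1                       # FP32 matmul
h = h.astype(np.float16) @ w2    # cast down, FP16 matmul (like linear2)
out = h.astype(np.float32) @ w3  # cast back up, FP32 matmul (like linear3)
print(out.dtype, out.shape)      # float32 (1, 40)
```

The intermediate ``h`` really is FP16 between the two casts, which is exactly the behavior the user expects the compiled engine to preserve.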
If we compile the above model using Torch-TensorRT, layer profiling logs indicate that all of the layers are
run in FP32. This is because TensorRT picks, for each layer, the kernels that result in the best overall performance.
.. code-block:: python

    inputs = [torch.randn((1, 10), dtype=torch.float32).cuda()]
    mod = MyModule().eval().cuda()
    ep = torch.export.export(mod, tuple(inputs))
    with torch_tensorrt.logging.debug():
        trt_gm = torch_tensorrt.dynamo.compile(
            ep,
            inputs=inputs,
            debug=True,
        )

    # Debug log info
    # Layers:
    # Name: __myl_MulSum_myl0_0, LayerType: kgen, Inputs: [ { Name: __mye116_dconst, Dimensions: [10,10], Format/Datatype: Float }, { Name: x, Dimensions: [10,1], Format/Datatype: Float }], Outputs: [ { Name: __myln_k_arg__bb1_2, Dimensions: [1,10], Format/Datatype: Float }], TacticName: __myl_MulSum_0xfa6c1858aea1b13b03f90165d7149ec6, StreamId: 0, Metadata:
    # Name: __myl_AddResMulSum_myl0_1, LayerType: kgen, Inputs: [ { Name: __mye131_dconst, Dimensions: [10,30], Format/Datatype: Float }, { Name: __myln_k_arg__bb1_2, Dimensions: [1,10], Format/Datatype: Float }, { Name: linear1/addmm_constant_0 _ linear1/addmm_add_broadcast_to_same_shape_lhs_broadcast_constantFloat, Dimensions: [1,10], Format/Datatype: Float }], Outputs: [ { Name: __myln_k_arg__bb1_3, Dimensions: [1,30], Format/Datatype: Float }], TacticName: __myl_AddResMulSum_0xb3915d7ebfe48be45b6d49083479e12f, StreamId: 0, Metadata:
    # Name: __myl_AddResMulSumAdd_myl0_2, LayerType: kgen, Inputs: [ { Name: __mye146_dconst, Dimensions: [30,40], Format/Datatype: Float }, { Name: linear3/addmm_2_constant_0 _ linear3/addmm_2_add_broadcast_to_same_shape_lhs_broadcast_constantFloat, Dimensions: [1,40], Format/Datatype: Float }, { Name: __myln_k_arg__bb1_3, Dimensions: [1,30], Format/Datatype: Float }, { Name: linear2/addmm_1_constant_0 _ linear2/addmm_1_add_broadcast_to_same_shape_lhs_broadcast_constantFloat, Dimensions: [1,30], Format/Datatype: Float }], Outputs: [ { Name: output0, Dimensions: [1,40], Format/Datatype: Float }], TacticName: __myl_AddResMulSumAdd_0xcdd0085ad25f5f45ac5fafb72acbffd6, StreamId: 0, Metadata:
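One quick way to check which precision each kernel ran in is to scan the debug log for the ``Format/Datatype`` fields. A minimal sketch, using an abridged log line in the style shown above (the line content here is a hypothetical sample, not captured output):

```python
import re

# Abridged layer-profiling line in the style of the TensorRT debug logs above.
log_line = (
    "Name: __myl_MulSum_myl0_0, LayerType: kgen, "
    "Inputs: [ { Name: x, Dimensions: [10,1], Format/Datatype: Float }], "
    "Outputs: [ { Name: out, Dimensions: [1,10], Format/Datatype: Half }]"
)

# Collect every Format/Datatype value mentioned on the line.
dtypes = re.findall(r"Format/Datatype: (\w+)", log_line)
print(dtypes)  # ['Float', 'Half']
```

If every entry comes back ``Float``, as in the logs above, the engine ignored the model's FP16 casts.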
In order to respect the types specified by the user in the model (e.g., in this case, the ``linear2`` layer running in FP16), users can enable
the compilation setting ``use_explicit_typing=True``. Compiling with this option results in the following TensorRT logs:

.. note:: If you enable ``use_explicit_typing=True``, only ``torch.float32`` is supported in ``enabled_precisions``.
.. code-block:: python

    inputs = [torch.randn((1, 10), dtype=torch.float32).cuda()]
    mod = MyModule().eval().cuda()
    ep = torch.export.export(mod, tuple(inputs))
    with torch_tensorrt.logging.debug():
        trt_gm = torch_tensorrt.dynamo.compile(
            ep,
            inputs=inputs,
            use_explicit_typing=True,
            debug=True,
        )

    # Debug log info
    # Layers:
    # Name: __myl_MulSumAddCas_myl0_0, LayerType: kgen, Inputs: [ { Name: linear1/addmm_constant_0 _ linear1/addmm_add_broadcast_to_same_shape_lhs_broadcast_constantFloat, Dimensions: [1,10], Format/Datatype: Float }, { Name: __mye112_dconst, Dimensions: [10,10], Format/Datatype: Float }, { Name: x, Dimensions: [10,1], Format/Datatype: Float }], Outputs: [ { Name: __myln_k_arg__bb1_2, Dimensions: [1,10], Format/Datatype: Half }], TacticName: __myl_MulSumAddCas_0xacf8f5dd9be2f3e7bb09cdddeac6c936, StreamId: 0, Metadata:
    # Name: __myl_ResMulSumAddCas_myl0_1, LayerType: kgen, Inputs: [ { Name: __mye127_dconst, Dimensions: [10,30], Format/Datatype: Half }, { Name: linear2/addmm_1_constant_0 _ linear2/addmm_1_add_broadcast_to_same_shape_lhs_broadcast_constantHalf, Dimensions: [1,30], Format/Datatype: Half }, { Name: __myln_k_arg__bb1_2, Dimensions: [1,10], Format/Datatype: Half }], Outputs: [ { Name: __myln_k_arg__bb1_3, Dimensions: [1,30], Format/Datatype: Float }], TacticName: __myl_ResMulSumAddCas_0x5a3b318b5a1c97b7d5110c0291481337, StreamId: 0, Metadata:
    # Name: __myl_ResMulSumAdd_myl0_2, LayerType: kgen, Inputs: [ { Name: __mye142_dconst, Dimensions: [30,40], Format/Datatype: Float }, { Name: linear3/addmm_2_constant_0 _ linear3/addmm_2_add_broadcast_to_same_shape_lhs_broadcast_constantFloat, Dimensions: [1,40], Format/Datatype: Float }, { Name: __myln_k_arg__bb1_3, Dimensions: [1,30], Format/Datatype: Float }], Outputs: [ { Name: output0, Dimensions: [1,40], Format/Datatype: Float }], TacticName: __myl_ResMulSumAdd_0x3fad91127c640fd6db771aa9cde67db0, StreamId: 0, Metadata:
Now the ``linear2`` layer runs in FP16, as shown in the above logs.
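The headline feature of this PR is an FP32 accumulation option for matmul layers (commit ``0ab0dbf``). The motivation is numerical: when many FP16 products are accumulated in an FP16 register, low-order bits are lost and the running sum can stall. A minimal NumPy sketch of the effect (illustrative only, not the Torch-TensorRT API):

```python
import numpy as np

# 10,000 small FP16 values whose exact sum is (almost) 1.0.
vals = np.full(10_000, 1e-4, dtype=np.float16)

def running_sum(values, acc_dtype):
    """Accumulate one value at a time in the given accumulator precision."""
    acc = acc_dtype(0.0)
    for v in values:
        acc = acc_dtype(acc + v)
    return float(acc)

# FP16 accumulator: once the sum grows large enough, adding 1e-4 rounds
# back to the same value and the accumulation stalls (near 0.25 here).
fp16_sum = running_sum(vals, np.float16)

# FP32 accumulator over the same FP16 inputs: close to the true value 1.0.
fp32_sum = running_sum(vals, np.float32)
print(fp16_sum, fp32_sum)
```

This is why keeping FP16 inputs but accumulating the matmul in FP32 recovers most of the accuracy while retaining the FP16 memory savings.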
.. _torch_compile:

Torch-TensorRT Examples
=======================

Please refer to the following examples, which demonstrate the usage of different features of Torch-TensorRT. We also provide
examples of Torch-TensorRT compilation of select computer vision and language models.

Dependencies
------------

Please install the following external dependencies (assuming you already have the correct ``torch``, ``torch_tensorrt``, and ``tensorrt`` libraries installed; see `dependencies <https://github.com/pytorch/TensorRT?tab=readme-ov-file#dependencies>`_):

.. code-block:: sh

    pip install -r requirements.txt
Compiler Features
-----------------

* :ref:`torch_compile_advanced_usage`: Advanced usage including making a custom backend to use directly with the ``torch.compile`` API
* :ref:`torch_compile_stable_diffusion`: Compiling a Stable Diffusion model using ``torch.compile``
* :ref:`torch_export_cudagraphs`: Using the Cudagraphs integration with ``ir="dynamo"``
* :ref:`custom_kernel_plugins`: Creating a plugin to use a custom kernel inside TensorRT engines
* :ref:`refit_engine_example`: Refitting a compiled TensorRT Graph Module with updated weights
* :ref:`mutable_torchtrt_module_example`: Compiling, using, and modifying a TensorRT Graph Module with ``MutableTorchTensorRTModule``
* :ref:`vgg16_fp8_ptq`: Compiling a VGG16 model with FP8 and PTQ using ``torch.compile``
* :ref:`engine_caching_example`: Utilizing engine caching to speed up compilation times
* :ref:`engine_caching_bert_example`: Demonstrating engine caching on BERT
Model Zoo
---------

* :ref:`torch_compile_resnet`: Compiling a ResNet model using the Torch Compile frontend for ``torch_tensorrt.compile``
* :ref:`torch_compile_transformer`: Compiling a Transformer model using ``torch.compile``
* :ref:`torch_compile_stable_diffusion`: Compiling a Stable Diffusion model using ``torch.compile``
* :ref:`torch_export_gpt2`: Compiling a GPT2 model using the AOT workflow (``ir="dynamo"``)
* :ref:`torch_export_llama2`: Compiling a Llama2 model using the AOT workflow (``ir="dynamo"``)
Updated dependency pins:

.. code-block:: text

    cupy==13.1.0
    torch>=2.4.0.dev20240503+cu121
    torch-tensorrt>=2.4.0.dev20240503+cu121
    triton==2.3.0
    diffusers==0.30.3
    transformers==4.44.2