Exploring Stable Diffusion in Torch Compile Path #2144

Closed
@gs-olive

Description

Torch-TRT was benchmarked on Stable Diffusion via TorchBench (see the TorchBench tutorial) using the following command:

```sh
HUGGING_FACE_HUB_TOKEN={YOUR HF TOKEN} python run.py stable_diffusion --backend torch_trt --precision fp32 -d cuda -t eval --ir torch_compile
```

### `set_module` Issue

The command above first fails at this line, because the `set_module` attribute does not appear to work for the Stable Diffusion model.

### Missing ATen Operators

If we intercept the compiled module before that line is reached, compilation succeeds but is very slow and generates over 100 TRT engines, since the graph must be segmented around every operator that Torch-TRT cannot convert. The operators that lacked converters are listed below. Some of these may have since been implemented.

- torch.ops.aten.var_mean.correction
- torch.ops.aten._unsafe_view.default
- torch.ops.aten.arange.start_step
- torch.ops.aten.bmm.default
- torch.ops.aten.amax.default
- torch.ops.aten.erf.default
- torch.ops.aten.full.default
- torch.ops.aten._to_copy.default
- torch.ops.aten.permute.default
- torch.ops.aten.sum.dim_IntList
- torch.ops.aten.exp.default
- torch.ops.aten.slice.Tensor
- _operator.getitem
- torch.ops.aten.argmax.default
- torch.ops.aten.unsqueeze.default
- torch.ops.aten.index.Tensor
- torch.ops.aten.sqrt.default
- torch.ops.aten.mm.default
- torch.ops.aten.embedding.default
- torch.ops.aten.reciprocal.default
- torch.ops.aten.clone.default
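To illustrate how a handful of unsupported operators can produce so many engines, here is a simplified sketch. It is not Torch-TRT's actual partitioner: it assumes the graph is a linear, topologically ordered list of op names, closes a segment at every unsupported op, and keeps only runs of at least `min_block_size` supported ops as TRT engines (smaller runs fall back to PyTorch). The function `partition_linear_graph` and the `supported` set are hypothetical; the `min_block_size` name mirrors the Torch-TRT compile option of the same name, but the logic here is only an approximation. Op names are abbreviated (`aten.erf.default` rather than `torch.ops.aten.erf.default`).

```python
from typing import List, Sequence, Set, Tuple


def partition_linear_graph(
    ops: Sequence[str], supported: Set[str], min_block_size: int = 1
) -> Tuple[List[List[str]], List[str]]:
    """Hypothetical, simplified partitioner for a linear op sequence.

    Returns (trt_segments, fallback_ops). Each maximal run of supported ops
    with at least `min_block_size` ops becomes one (hypothetical) TRT engine;
    every other op stays in PyTorch.
    """
    trt_segments: List[List[str]] = []
    fallback_ops: List[str] = []
    current: List[str] = []

    def flush() -> None:
        # Close the current run of supported ops, if any.
        if not current:
            return
        if len(current) >= min_block_size:
            trt_segments.append(list(current))
        else:
            fallback_ops.extend(current)  # run too small to justify an engine
        current.clear()

    for op in ops:
        if op in supported:
            current.append(op)
        else:
            flush()  # an unsupported op forces a segment boundary
            fallback_ops.append(op)
    flush()
    return trt_segments, fallback_ops


# Hypothetical example: two unsupported ops split the graph into 3 engines.
supported = {"aten.convolution.default", "aten.add.Tensor", "aten.mul.Tensor"}
ops = [
    "aten.convolution.default", "aten.add.Tensor",
    "aten.erf.default",            # unsupported -> segment boundary
    "aten.mul.Tensor", "aten.add.Tensor",
    "aten.var_mean.correction",    # unsupported -> segment boundary
    "aten.convolution.default",
]
segments, fallback = partition_linear_graph(ops, supported)
print(len(segments), fallback)
```

In a model the size of Stable Diffusion, unsupported ops such as `erf` or `var_mean` recur in every block, so in this simplified model the engine count grows with the number of boundaries rather than with the number of distinct missing operators.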

### Accuracy Issue

Additionally, there is an accuracy issue: the outputs of the `trt_model` have a very low cosine similarity (~0.14 on one test) relative to their PyTorch counterparts. The graph segmentation may be contributing to this.
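The issue does not state exactly how the score was computed; the sketch below assumes the TRT and PyTorch outputs are flattened to equal-length float sequences and compared with plain cosine similarity. The function name is hypothetical. A faithful conversion should score very close to 1.0, whereas ~0.14 means the outputs are nearly uncorrelated.

```python
import math
from typing import Sequence


def cosine_similarity(a: Sequence[float], b: Sequence[float]) -> float:
    """Cosine similarity of two flattened outputs (hypothetical helper).

    1.0 means same direction, 0.0 orthogonal, -1.0 opposite.
    """
    if len(a) != len(b):
        raise ValueError("outputs must have the same number of elements")
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0.0 or norm_b == 0.0:
        raise ValueError("cosine similarity is undefined for a zero vector")
    return dot / (norm_a * norm_b)


# Hypothetical flattened outputs: a reference and a slightly perturbed copy.
reference = [1.0, 2.0, 3.0]
perturbed = [1.01, 1.99, 3.0]
print(cosine_similarity(reference, perturbed))
```

With real models, `torch.nn.functional.cosine_similarity` on flattened tensors computes the same quantity; the pure-Python version is used here only to keep the sketch dependency-free.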

### Tasks
- [x] _operator.getitem
- [ ] torch.ops.aten.var_mean.correction
- [ ] torch.ops.aten._unsafe_view.default
- [ ] torch.ops.aten.amax.default
- [ ] torch.ops.aten.erf.default
- [ ] torch.ops.aten.full.default
- [x] torch.ops.aten._to_copy.default
- [x] torch.ops.aten.permute.default
- [ ] https://github.com/pytorch/TensorRT/issues/2244
- [ ] torch.ops.aten.exp.default
- [x] torch.ops.aten.slice.Tensor
- [ ] https://github.com/pytorch/TensorRT/issues/2245
- [x] torch.ops.aten.unsqueeze.default
- [ ] https://github.com/pytorch/TensorRT/issues/2231
- [ ] torch.ops.aten.sqrt.default
- [x] torch.ops.aten.mm.default
- [x] torch.ops.aten.embedding.default
- [ ] torch.ops.aten.reciprocal.default
- [x] torch.ops.aten.clone.default
- [ ] https://github.com/pytorch/TensorRT/issues/2236
- [ ] https://github.com/pytorch/TensorRT/issues/2243
- [ ] https://github.com/pytorch/TensorRT/issues/1795
- [x] https://github.com/pytorch/TensorRT/issues/2346
