Weight gradient kernels for dense and MoE models #95


Merged: 7 commits into main from wgrad-gemm on May 14, 2025
Conversation

@zheanxu (Collaborator) commented May 6, 2025

This pull request introduces `deep_gemm.wgrad_gemm_fp8_fp8_fp32_nt` and `deep_gemm.k_grouped_wgrad_gemm_fp8_fp8_fp32_nt`, optimized weight gradient kernels for dense and MoE models. These kernels achieve a ~20% speedup over the internal CUTLASS implementation.

For detailed usage, refer to the function documentation.
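A minimal usage sketch for the dense kernel is shown below. It assumes a `per_token_cast_to_fp8` helper that mirrors the 1×128-group FP8 quantization used in DeepGEMM's tests (not necessarily identical to it), and takes its shape from the benchmark table further down; the authoritative signature and layout requirements are in the function documentation.

```python
import torch
import deep_gemm


def per_token_cast_to_fp8(t: torch.Tensor):
    # 1x128-group FP8 (E4M3) quantization along the last dim; returns (fp8 tensor, per-group scales).
    # Sketch modeled on DeepGEMM's test utilities, not necessarily identical to them.
    m, n = t.shape
    assert n % 128 == 0
    t_view = t.view(m, -1, 128)
    amax = t_view.abs().float().amax(dim=2).clamp(1e-4)
    t_fp8 = (t_view * (448.0 / amax.unsqueeze(2))).to(torch.float8_e4m3fn).view(m, n)
    return t_fp8, amax / 448.0


m, n, k = 7168, 2112, 4096  # one of the benchmarked dense shapes below
x = torch.randn(m, k, device='cuda', dtype=torch.bfloat16)
y = torch.randn(n, k, device='cuda', dtype=torch.bfloat16)
out = torch.zeros(m, n, device='cuda', dtype=torch.float)  # FP32 output buffer

# Both operands are quantized to FP8 along K; the result is written into the FP32 buffer.
deep_gemm.wgrad_gemm_fp8_fp8_fp32_nt(per_token_cast_to_fp8(x), per_token_cast_to_fp8(y), out)
```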

Weight gradient GEMMs for dense models

| M | N | K | Opt. BM×BN | Computation (TFLOPS) | Memory bandwidth (GB/s) |
|-------|-------|------|------------|----------------------|-------------------------|
| 7168  | 2112  | 4096 | 128x152    | 920  | 507  |
| 1536  | 24576 | 4096 | 128x152    | 986  | 582  |
| 512   | 32768 | 4096 | 128x152    | 878  | 1086 |
| 16384 | 7168  | 4096 | 128x152    | 994  | 342  |
| 7168  | 4096  | 4096 | 128x152    | 942  | 411  |
| 2048  | 7168  | 4096 | 128x152    | 920  | 513  |
| 7168  | 2112  | 8192 | 128x152    | 1052 | 451  |
| 1536  | 24576 | 8192 | 128x152    | 1092 | 511  |
| 512   | 32768 | 8192 | 128x152    | 1014 | 1129 |
| 16384 | 7168  | 8192 | 128x152    | 1079 | 240  |
| 7168  | 4096  | 8192 | 128x152    | 1061 | 333  |
| 2048  | 7168  | 8192 | 128x152    | 1037 | 452  |

Grouped weight gradient GEMMs for MoE models

| Groups | M | N | K | Opt. BM×BN | Computation (TFLOPS) | Memory bandwidth (GB/s) |
|--------|------|------|------|------------|----------------------|-------------------------|
| 4 | 7168 | 4096 | 4096 | 128x152 | 939  | 409 |
| 4 | 2048 | 7168 | 4096 | 128x152 | 900  | 502 |
| 4 | 7168 | 4096 | 8192 | 128x152 | 1044 | 328 |
| 4 | 2048 | 7168 | 8192 | 128x152 | 1033 | 450 |
| 8 | 7168 | 4096 | 4096 | 128x152 | 942  | 411 |
| 8 | 2048 | 7168 | 4096 | 128x152 | 902  | 503 |
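To illustrate the k-grouped semantics, here is a reference computation of what the grouped kernel is expected to produce. This is only an illustration, not the kernel API: it assumes the per-expert token segments are concatenated along the K dimension and that `out` holds one M×N FP32 accumulator per group; consult the function documentation for the actual tensor layout and signature.

```python
import torch


def k_grouped_wgrad_reference(x: torch.Tensor, y: torch.Tensor,
                              out: torch.Tensor, k_sizes) -> torch.Tensor:
    # x: (M, sum(k_sizes)), y: (N, sum(k_sizes)), out: (num_groups, M, N) in FP32.
    # Each group g accumulates x_g @ y_g^T into its own M x N slice.
    k_offset = 0
    for g, k_g in enumerate(k_sizes):
        x_g = x[:, k_offset:k_offset + k_g].float()
        y_g = y[:, k_offset:k_offset + k_g].float()
        out[g] += x_g @ y_g.t()
        k_offset += k_g
    return out
```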

@zheanxu requested a review from LyricZhao on May 6, 2025 at 09:19
@zheanxu self-assigned this on May 6, 2025
@LyricZhao (Collaborator) commented May 6, 2025

I plan to merge it after #94, thanks!

@LyricZhao (Collaborator) commented

> These kernels achieve a ~20% speedup compared to the internal CUTLASS implementation.

To clarify, you can refer to the profile-data repo for the internal CUTLASS implementation's performance numbers used in this comparison.

@hxdtest commented May 13, 2025

```python
import torch
import deep_gemm
# per_token_cast_to_fp8: the 1x128-group FP8 quantization helper (as in DeepGEMM's tests)

m, n, k = 4096, 1024, 4096

x = torch.ones((m, k), device='cuda', dtype=torch.bfloat16)
y = torch.rand((n, k), device='cuda', dtype=torch.bfloat16)

ref_out = x.float() @ y.float().t()
out1 = ref_out.clone()
x_fp8 = per_token_cast_to_fp8(x)
y_fp8 = per_token_cast_to_fp8(y)

deep_gemm.wgrad_gemm_fp8_fp8_fp32_nt(x_fp8, y_fp8, out1)
print(out1 / ref_out)
```

Why is the result nearly twice the reference?

```
out1 / ref_out
tensor([[1.9995, 2.0002, 2.0004,  ..., 2.0006, 1.9999, 2.0009]
```

Update: it should be `out1 = torch.zeros_like(ref_out)`.
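With that fix, a compact version of the check might look like this, reusing the tensors from the snippet above (the tolerances are illustrative, chosen loosely to allow for FP8 quantization error):

```python
out1 = torch.zeros_like(ref_out)  # start from zeros instead of ref_out.clone()
deep_gemm.wgrad_gemm_fp8_fp8_fp32_nt(x_fp8, y_fp8, out1)
torch.testing.assert_close(out1, ref_out, rtol=2e-2, atol=2e-2)
```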

@zheanxu (Collaborator, Author) commented May 13, 2025

@hxdtest Thank you very much for your feedback.
During backpropagation, W needs to accumulate W_grad, so `deep_gemm.wgrad_gemm_fp8_fp8_fp32_nt(x, y, out)` was designed to perform `out += x @ y.t()` instead of `out = x @ y.t()`. This detail was omitted in the documentation.
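In other words, the kernel accumulates in place, which also makes gradient accumulation across micro-batches natural. A small illustration, reusing the FP8-cast tensors from the snippets above:

```python
# Fresh weight gradient: start from a zeroed FP32 buffer.
w_grad = torch.zeros_like(ref_out)
deep_gemm.wgrad_gemm_fp8_fp8_fp32_nt(x_fp8, y_fp8, w_grad)   # w_grad == x @ y.t()

# Accumulation over another micro-batch: reuse the same buffer.
deep_gemm.wgrad_gemm_fp8_fp8_fp32_nt(x_fp8, y_fp8, w_grad)   # w_grad == 2 * (x @ y.t())
```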

@hxdtest commented May 14, 2025

> @hxdtest Thank you very much for your feedback. During backpropagation, W needs to accumulate W_grad, so `deep_gemm.wgrad_gemm_fp8_fp8_fp32_nt(x, y, out)` was designed to perform `out += x @ y.t()` instead of `out = x @ y.t()`. This detail was omitted in the documentation.

Thank you for your reply. After fixing the test code, the results are close.

@LyricZhao merged commit 04278f6 into main on May 14, 2025
@LyricZhao deleted the wgrad-gemm branch on May 14, 2025 at 07:55
@hxdtest commented May 14, 2025

> @hxdtest Thank you very much for your feedback. During backpropagation, W needs to accumulate W_grad, so `deep_gemm.wgrad_gemm_fp8_fp8_fp32_nt(x, y, out)` was designed to perform `out += x @ y.t()` instead of `out = x @ y.t()`. This detail was omitted in the documentation.

Fantastic work! I used DeepGEMM to build an FP8 Linear layer that replaces torch.nn.Linear and ran an RL job. The evaluation scores with mixed FP8 precision appear to be close to those of the mixed BF16 experiment.

@ajWithNucleus commented May 17, 2025

@hxdtest Can you please share your Linear layer wrapper as a quick-start utility? It would be helpful.

@hxdtest commented May 29, 2025

> @hxdtest Can you please share your Linear layer wrapper as a quick-start utility? It would be helpful.

https://github.com/hxdtest/fp8_verl/blob/add_fp8/verl/third_party/deep_gemm/fp8_linear.py

@LyricZhao (Collaborator) commented
Thanks for your work on the FP8 linear module. However, the implementation has many unfused kernels, e.g. per_token_cast_to_fp8 and the kernels launched inside the DeepGEMM function call, which may lead to lower overall performance.

Just a reminder, in case you care about end-to-end performance :)
