Weight gradient kernels for dense and MoE models #95
Conversation
I plan to merge it after #94, thanks!
To clarify: you can refer to the profile-data repo for a performance comparison against the internal CUTLASS implementation.
Why is the difference nearly a factor of two?
--update--
@hxdtest Thank you very much for your feedback.
Thank you for your reply. After fixing the test code, the results are close.
Fantastic work! I used DeepGEMM and built an FP8 Linear layer to replace
@hxdtest Can you please share your Linear layer wrapper as a quick-start util? It would be helpful.
https://github.com/hxdtest/fp8_verl/blob/add_fp8/verl/third_party/deep_gemm/fp8_linear.py |
Thanks for your work on the FP8 linear module, but the implementation has a lot of unfused kernels. Just a reminder, in case you care about end-to-end performance :)
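For readers who want the general shape of such a wrapper, here is a minimal sketch of an FP8 Linear forward pass. This is an illustration under stated assumptions, not the linked wrapper: it uses DeepGEMM's existing `gemm_fp8_fp8_bf16_nt` API with `(tensor, scales)` tuples, and the two casting helpers are simplified stand-ins for the library's test utilities (real use may additionally require scale-layout alignment, e.g. via `deep_gemm.get_col_major_tma_aligned_tensor`). The `FP8Linear` class name is hypothetical.

```python
import torch
import deep_gemm

def per_token_cast_to_fp8(x: torch.Tensor):
    # One scale per 1x128 block along the last dim (DeepGEMM's activation
    # scaling granularity); simplified from the library's test utilities.
    m, k = x.shape
    xv = x.view(m, k // 128, 128)
    amax = xv.abs().float().amax(dim=2, keepdim=True).clamp(min=1e-4)
    return (xv * (448.0 / amax)).to(torch.float8_e4m3fn).view(m, k), \
           (amax / 448.0).view(m, k // 128)

def per_block_cast_to_fp8(w: torch.Tensor):
    # One scale per 128x128 block (DeepGEMM's weight scaling granularity).
    n, k = w.shape
    wv = w.view(n // 128, 128, k // 128, 128)
    amax = wv.abs().float().amax(dim=(1, 3), keepdim=True).clamp(min=1e-4)
    return (wv * (448.0 / amax)).to(torch.float8_e4m3fn).view(n, k), \
           (amax / 448.0).view(n // 128, k // 128)

class FP8Linear(torch.nn.Module):
    # Hypothetical wrapper; shapes are assumed to be multiples of 128,
    # and the backward pass (where this PR's wgrad kernels apply) is omitted.
    def __init__(self, in_features: int, out_features: int):
        super().__init__()
        self.weight = torch.nn.Parameter(
            torch.randn(out_features, in_features, dtype=torch.bfloat16) * 0.02)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x_fp8 = per_token_cast_to_fp8(x)           # (values, scales)
        w_fp8 = per_block_cast_to_fp8(self.weight)
        out = torch.empty(x.shape[0], self.weight.shape[0],
                          device=x.device, dtype=torch.bfloat16)
        # NT layout: x is (m, k), weight is (n, k), out is (m, n).
        deep_gemm.gemm_fp8_fp8_bf16_nt(x_fp8, w_fp8, out)
        return out
```

Note that the quantize-cast steps above are exactly the kind of unfused kernels the previous comment warns about; a production wrapper would fuse them into neighboring ops where possible.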
This Pull Request introduces `deep_gemm.wgrad_gemm_fp8_fp8_fp32_nt` and `k_grouped_wgrad_gemm_fp8_fp8_fp32_nt`, optimized weight gradient kernels for dense and MoE models. These kernels achieve a ~20% speedup compared to the internal CUTLASS implementation. For detailed usage, refer to the function documentation; a hedged usage sketch follows the list below.
- Weight gradient GEMMs for dense models
- Grouped weight gradient GEMMs for MoE models
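As a quick illustration of the dense kernel, here is a sketch of how the weight gradient GEMM might be called. It assumes the kernel follows the `(tensor, scales)` tuple convention of the library's other GEMMs, takes both operands in NT layout with the shared k dimension reduced, and accumulates into an FP32 output; the casting helper is the simplified one from the FP8 Linear sketch above, and the function documentation in this PR is authoritative for exact argument layouts.

```python
import torch
import deep_gemm

def per_token_cast_to_fp8(t: torch.Tensor):
    # Simplified 1x128-block FP8 cast (see the FP8 Linear sketch above).
    r, c = t.shape
    tv = t.view(r, c // 128, 128)
    amax = tv.abs().float().amax(dim=2, keepdim=True).clamp(min=1e-4)
    return (tv * (448.0 / amax)).to(torch.float8_e4m3fn).view(r, c), \
           (amax / 448.0).view(r, c // 128)

m, n, k = 4096, 4096, 8192   # output is (m, n); k is the reduction dim

# NT layout: both operands are (rows, k) row-major, e.g. output gradients
# and activations arranged with the token dimension along k.
lhs = torch.randn(m, k, device='cuda', dtype=torch.bfloat16)
rhs = torch.randn(n, k, device='cuda', dtype=torch.bfloat16)

lhs_fp8 = per_token_cast_to_fp8(lhs)
rhs_fp8 = per_token_cast_to_fp8(rhs)

# FP32 output: accumulating weight gradients in full precision lets results
# be summed across micro-batches without FP8/BF16 rounding loss. Assuming
# the kernel accumulates into `out`, it is initialized to zeros here.
grad_w = torch.zeros(m, n, device='cuda', dtype=torch.float32)
deep_gemm.wgrad_gemm_fp8_fp8_fp32_nt(lhs_fp8, rhs_fp8, grad_w)
```

The grouped MoE variant presumably follows the same calling pattern with per-group segments along k; again, see the function documentation for the exact interface.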