Eval bug: b5335 break flash attention on 4070 #13430
Comments
Does this patch also fix the issue?

diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh
index 64fb4ff4c..fc231e97e 100644
--- a/ggml/src/ggml-cuda/common.cuh
+++ b/ggml/src/ggml-cuda/common.cuh
@@ -216,7 +216,7 @@ typedef float2 dfloat2;
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
-#define CP_ASYNC_AVAILABLE
+// #define CP_ASYNC_AVAILABLE
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
#if !defined(GGML_CUDA_NO_FA) && !(defined(GGML_USE_MUSA) && GGML_CUDA_MUSA_ARCH_IS_QY1)
@@ -258,7 +258,7 @@ static bool new_mma_available(const int cc) {
}
static bool cp_async_available(const int cc) {
- return cc < GGML_CUDA_CC_OFFSET_AMD && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_AMPERE;
+ return false;
}
static constexpr __device__ int ggml_cuda_get_physical_warp_size() {
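For context, the patch disables the cp.async path both at compile time (commenting out CP_ASYNC_AVAILABLE) and at run time (cp_async_available() returning false). As a rough, hedged illustration of the kind of path such a gate selects (assumed structure and helper name, not the actual ggml kernels):

#include <cuda_pipeline.h>

// Illustrative only: a CP_ASYNC_AVAILABLE-style gate picking between an
// asynchronous global->shared copy (Ampere+) and a plain synchronous fallback.
__device__ __forceinline__ void copy_f32_to_shared(float * dst_shared, const float * src_global) {
#ifdef CP_ASYNC_AVAILABLE
    // Fast path: the copy is only *issued* here; it must later be waited on
    // (__pipeline_wait_prior) and the block synchronized before dst_shared is read.
    __pipeline_memcpy_async(dst_shared, src_global, sizeof(float));
    __pipeline_commit();
#else
    // Fallback: an ordinary load/store, visible to other threads after __syncthreads().
    *dst_shared = *src_global;
#endif
}

With the patch, only the fallback branch is compiled and selected, which is why it sidesteps the bug.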
Looks OK with that patch too (with my patch removed first).
I was unable to reproduce the issue, but is it fixed by #13438?
Looks solid now. Most likely my 4070 wasn't getting done in time compared to the CPU thread that was spawning the async copy. Nice find, and thanks again for your fast responses!
This is not a race condition between CPU and GPU, it's a race condition between threads on one of the streaming multiprocessors on the GPU. Whether that race condition actually manifests as a bug is not guaranteed though (I was using a 4090).
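To make that concrete, here is a minimal hypothetical CUDA sketch (made-up kernel, not the ggml code) of such an intra-SM race: the asynchronous copy is only issued, so the block must both wait on the pipeline and synchronize before any thread reads another thread's shared-memory element, and whether skipping that actually corrupts the result depends on timing, i.e. on the GPU and driver combination.

#include <cuda_pipeline.h>

// Hypothetical demo kernel: launch with n <= blockDim.x and n <= 256.
__global__ void async_tile_demo(const float * src, float * dst, int n) {
    __shared__ float tile[256];

    // Each thread issues an asynchronous global->shared copy of one element.
    if (threadIdx.x < n) {
        __pipeline_memcpy_async(&tile[threadIdx.x], &src[threadIdx.x], sizeof(float));
    }
    __pipeline_commit();

    // Both of these are required before reading another thread's element;
    // dropping either one makes the read below racy.
    __pipeline_wait_prior(0); // wait for this thread's outstanding copies
    __syncthreads();          // make every thread's copies visible block-wide

    if (threadIdx.x < n) {
        dst[threadIdx.x] = tile[(threadIdx.x + 1) % n]; // read a neighbour's element
    }
}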
OK. Might be driver version + hardware related, hard to say. I guess I just drew the short straw and my hardware + driver combo happened to trigger the bug.
@JohannesGaessler Unfortunately there is still a problem with flash attention on, even after this fix. Generation is fine for a while, then if too many tokens get generated it starts to produce gibberish:
b5342 with FA on generation: (gibberish).
b5342 with FA off generation (no problems): ... Year 11: Elsie is 30 years old
The model is Llama 4 Scout, with the MoE experts fully offloaded to CPU and the rest of the tensors, including the MoE shared ones, on GPU. b5279 doesn't show this problem, but I do not know the commit that suddenly causes the FA gibberish problem to start.
Please tell me the exact model and command that you're using, then.
It fails with the Q3_K_H model that you downloaded, used the same way (after a certain number of tokens it starts generating garbage, always at the same point). This prompt can be used: "Count from 1 to 200. Spell out numbers on each new line." With FA on, garbage after 159: One Hundred Fifty-Six. With FA off, it will go to 200 correctly. The problem I am having is that it only does this when I am speculating it with Llama 3.2 1B, and upstream llama.cpp will not speculate this model since it doesn't support translating between the vocabularies of the speculator and the target, which my downstream does. The unique thing happening during speculation is that during generation the target is not evaluating with batch size 1, it is evaluating with batch size 4 to 5 (see the sketch right after this comment). So something in the short batch sizes seems to have been broken by recent changes. It works fine with FA off, or with b5279. I don't know the commit that broke it, but I suspect the same one that caused all the async bugs to start happening. I will see if I can find a way to trigger the bug with upstream. If there is a way to force it to do decodes with short batch sizes, that might help expose the problem.

UPDATE: The problem does not appear to be related to async copy. I applied the patch above to globally turn off async copy on release b5347 and the identical problem happens. The identical failure in the identical place also happens with the KV cache set to F16 or q8_0. So something else has changed since b5279 which is causing short-batch decodes to fail consistently after a certain number of tokens get evaluated.

UPDATE 2: The problem does not happen with Qwen3-30B-A3B speculated with Qwen3-0.6B in a similar config (experts offloaded to CPU, rest of tensors on GPU). That model does not have shared experts like Llama 4 Scout.
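As background for the batch-size remark above, here is a hypothetical sketch (made-up helper names, not the llama.cpp API) of why speculative decoding makes the target model decode batches of roughly n_draft + 1 tokens instead of single tokens:

#include <cstdio>
#include <vector>

// Stub: the draft model proposes a short continuation, one token at a time.
static std::vector<int> draft_propose(int n_draft) {
    return std::vector<int>(n_draft, /*dummy token id*/ 42);
}

// Stub: the target model then scores all proposed tokens in ONE decode call.
static void target_eval_batch(const std::vector<int> & tokens) {
    std::printf("target decode with batch size %zu\n", tokens.size());
}

int main() {
    const int n_draft = 4; // typical small draft length
    // One speculative step: the target sees a batch of 4-5 tokens rather than 1,
    // i.e. exactly the short-batch decode path discussed above.
    target_eval_batch(draft_propose(n_draft));
    return 0;
}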
I ran up to 1000 questions from MMLU, MMLU-Pro, GPQA, and GSM8K on a variety of models, once on a commit prior to my changes to MMQ and FlashAttention and once after those changes. Before: 45150/103440 correct answers. I included the LLaMA 4 quant you reported issues with; all other models are at FP16 precision. Before, LLaMA 4 q3_K_H only: 3729/6688. This does look like there could be a problem.
I am thinking it is some interaction between the SWA mask and FA on Llama 4 Scout. Qwen3 30B MoE is rock solid with an identical CPU/GPU offload config, while Llama 4 is crapping out after a number of tokens go by, which points to something going on with SWA + FA in the newer code (which recently started to give gibberish after a certain number of tokens; b5279 never gave the gibberish, but I think it was still running degraded, similar to your findings).
I noticed that the scores for MMLU-Pro in particular got worse. My suspicion was that models produce comparatively longer answers there, so the bug only appears once the context is sufficiently long. On my RTX 3090 I get bad perplexity results with
The PR that introduced the problem is #13438. The problem is that the
Yeah, that is what I found with the latest code. It was doing OK until the gens got too long, then went into complete meltdown with gibberish output, but I hadn't seen that complete meltdown until recently. Maybe it was degrading with the longer gens though. Soft degradation is much harder to detect given the statistical nature of the underlying models.
Sounds great, I look forward to testing it tomorrow. Strange that the hard fail point in my tests on the latest code was always at a precise number of tokens: 740 good generated tokens, then gibberish. With a prompt length of 26, the KV cache would be at 766 total tokens at the fail point in my test.
Name and Version
b5335 server
Operating systems
Linux
GGML backends
CUDA
Hardware
4070
Models
any (tested with Qwen3 8B)
Problem description & steps to reproduce
Gibberish is generated when FA is turned on.
The problem goes away when making the following change in the CUDA source file fattn-mma-f16.cuh, line 550 at b5335:
//constexpr bool use_cp_async = nstages == 1;
constexpr bool use_cp_async = 0;
First Bad Commit
Unknown
Relevant log output