Description
We are preparing a PyTorch submission for AlgoPerf that relies on the torch.linalg.eigh operator, which calls the linalg_eigh_cusolver_syevd solver from cuSOLVER. While running this operator with PyTorch 2.1.0 + CUDA 11.8, we have observed that it can trigger an illegal memory access error in our AlgoPerf runs. This failure is not recoverable.
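For context, the operation in question boils down to the following (a minimal sketch, not the exact AlgoPerf/Shampoo code path; the matrix size and values are illustrative):

```python
import torch

# Minimal sketch of the failing operation: an eigendecomposition of a symmetric
# preconditioner factor matrix on the GPU. With PyTorch 2.1.0 + CUDA 11.8 this
# dispatches to cuSOLVER's syevd routine.
A = torch.randn(1024, 1024, device="cuda", dtype=torch.float32)
A = A @ A.T  # symmetrize, as Shampoo's factor matrices are
eigenvalues, eigenvectors = torch.linalg.eigh(A)
torch.cuda.synchronize()  # an asynchronous illegal memory access would surface here
```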
Previous issues have been reported with CUDA 11.8 where the torch.linalg.eigh operator can trigger a CUDA illegal memory access error, which is unrecoverable; see, for example, pytorch/pytorch#105359 and pytorch/pytorch#94772 (comment).
We have now observed this problem arise in our AlgoPerf experiments for the OGBG model:
W0226 19:28:24.702791 140677981660992 shampoo_preconditioner_list.py:629] Matrix inverse root computation failed for factor matrix 52.block_0.1 with exception CUDA error: an illegal memory access was encountered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
. Using previous inv_factor_matrix and continuing...
timing, metrics = train_once(workload, workload_name,
File "submission_runner.py", line 336, in train_once
optimizer_state, model_params, model_state = update_params(
File "/algorithmic-efficiency/submissions/shampoo_submission/pytorch_shampoo.py", line 178, in update_params
optimizer_state['optimizer'].step()
File "/usr/local/lib/python3.8/dist-packages/torch/optim/lr_scheduler.py", line 68, in wrapper
return wrapped(*args, **kwargs)
File "/usr/local/lib/python3.8/dist-packages/torch/optim/optimizer.py", line 373, in wrapper
out = func(*args, **kwargs)
File "/usr/local/lib/python3.8/dist-packages/torch/utils/_contextlib.py", line 115, in decorate_context
return func(*args, **kwargs)
File "/algorithmic-efficiency/submissions/shampoo_submission/optimizers/distributed_shampoo/distributed_shampoo.py", line 905, in step
self._per_group_step(
File "/usr/local/lib/python3.8/dist-packages/torch/utils/_contextlib.py", line 115, in decorate_context
return func(*args, **kwargs)
File "/algorithmic-efficiency/submissions/shampoo_submission/optimizers/distributed_shampoo/distributed_shampoo.py", line 753, in _per_group_step_impl
W0226 19:28:24.704183 140547811231552 matrix_functions.py:218] Failed to compute eigendecomposition in torch.float32 precision with exception cusolver error: CUSOLVER_STATUS_EXECUTION_FAILED, when calling `cusolverDnXsyevd( handle, params, jobz, uplo, n, CUDA_R_32F, reinterpret_cast<void*>(A), lda, CUDA_R_32F, reinterpret_cast<void*>(W), CUDA_R_32F, reinterpret_cast<void*>(bufferOnDevice), workspaceInBytesOnDevice, reinterpret_cast<void*>(bufferOnHost), workspaceInBytesOnHost, info)`. This error may appear if the input matrix contains NaN. If you keep seeing this error, you may use `torch.backends.cuda.preferred_linalg_library()` to try linear algebra operators with other supported backends. See https://pytorch.org/docs/stable/backends.html#torch.backends.cuda.preferred_linalg_library! Retrying in double precision...
torch._foreach_mul_(state_lists[MASKED_FILTERED_GRAD_LIST], beta1)
RuntimeError: CUDA error: an illegal memory access was encountered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
Notice that in this case we aim to bypass the error (our script catches it), but subsequent CUDA kernel launches then also hit illegal memory accesses.
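For reference, the bypass is essentially the following (a simplified sketch with an illustrative name, not the actual matrix_functions.py code):

```python
import torch

def eigh_with_fp64_retry(factor_matrix: torch.Tensor):
    # Simplified sketch: retry the eigendecomposition in double precision when
    # the float32 call fails.
    try:
        return torch.linalg.eigh(factor_matrix)
    except RuntimeError:
        # "Retrying in double precision..." -- but once an illegal memory access
        # has occurred, the CUDA context is corrupted, so this retry (and the
        # later torch._foreach_mul_ in the optimizer step) fails with the same error.
        return torch.linalg.eigh(factor_matrix.double())
```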
Consistent with the GitHub issues linked above, we have checked that the cuSOLVER library in the container is /usr/local/cuda/lib64/libcusolver.so.11.4.1.48, which is the problematic version of the solver.
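The check amounts to the following (the path assumes the AlgoPerf Docker image layout):

```python
import glob

# List the cuSOLVER shared libraries shipped with the CUDA toolkit in the container.
print(glob.glob("/usr/local/cuda/lib64/libcusolver.so*"))
# e.g. ['/usr/local/cuda/lib64/libcusolver.so.11.4.1.48', ...]
```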
Steps to Reproduce
Follow the steps in pytorch/pytorch#105359 (comment).
Source or Possible Fix
If possible, we would suggest using CUDA 12.1.1 instead of CUDA 11.8 for the contest in order to avoid these instabilities.
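As a possible stopgap (suggested by the error message above, and untested on our end for AlgoPerf), the linear algebra backend can be switched away from cuSOLVER:

```python
import torch

# Untested stopgap: prefer the MAGMA backend for torch.linalg operators instead
# of cuSOLVER on this CUDA 11.8 build.
torch.backends.cuda.preferred_linalg_library("magma")
```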
Is there a way that we can change the Dockerfile to use PyTorch 2.1.0 with CUDA 12.1.1? (We have tried changing the first line of the Dockerfile to docker.io/nvidia/cuda:12.1.1-cudnn8-devel-ubuntu20.04, but we still observe PyTorch 2.1 + CUDA 11.8 being used when calling torch.__version__ and torch.version.cuda.)
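For completeness, this is how we check what the rebuilt image actually uses:

```python
import torch

# Check which PyTorch build and CUDA toolkit the container actually uses.
print(torch.__version__)   # currently still reports a +cu118 build
print(torch.version.cuda)  # currently still reports 11.8
```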
cc @anana10c @mikerabbat @tsunghsienlee @yuchenhao @shintaro-iwasaki