CUDA 11.8's syevd solver can cause an illegal memory access error when called through torch.linalg.eigh #655

Closed
@hjmshi


We are preparing a PyTorch submission for AlgoPerf that relies on the torch.linalg.eigh operator, which calls the linalg_eigh_cusolver_syevd solver from cuSOLVER. While running this operator with PyTorch 2.1.0 + CUDA 11.8, we have observed that it can trigger a CUDA illegal memory access error in our AlgoPerf runs. This failure is not recoverable.

Description

We have observed previous issues with CUDA 11.8 where the torch.linalg.eigh operator can cause a CUDA illegal memory access error, which is unrecoverable; see, for example, pytorch/pytorch#105359 and pytorch/pytorch#94772 (comment).
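For reference, a minimal sketch of the call pattern that exercises this code path (the matrix below is a stand-in, not the AlgoPerf workload; the size and construction are assumptions on our part):

```python
import torch

# Stand-in for a Shampoo preconditioner factor matrix (symmetric PSD).
A = torch.randn(512, 512, device="cuda")
A = A @ A.T + 1e-4 * torch.eye(512, device="cuda")

# On CUDA, torch.linalg.eigh dispatches to cuSOLVER's syevd by default
# (linalg_eigh_cusolver_syevd), which is where the illegal memory access arises.
eigvals, eigvecs = torch.linalg.eigh(A)
```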

We have now observed this problem arise in our experiments for AlgoPerf for the OGBG model:

W0226 19:28:24.702791 140677981660992 shampoo_preconditioner_list.py:629] Matrix inverse root computation failed for factor matrix 52.block_0.1 with exception CUDA error: an illegal memory access was encountered                                                                                                                                                                               
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.                                                                          
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.                                                                                                                                           
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.                                                                                                                              
. Using previous inv_factor_matrix and continuing...                                                                                                                                             
    timing, metrics = train_once(workload, workload_name,                                                                                                                                        
  File "submission_runner.py", line 336, in train_once                                                                                                                                           
    optimizer_state, model_params, model_state = update_params(                                                                                                                                  
  File "/algorithmic-efficiency/submissions/shampoo_submission/pytorch_shampoo.py", line 178, in update_params                                                                                   
    optimizer_state['optimizer'].step()                                                                                                                                                          
  File "/usr/local/lib/python3.8/dist-packages/torch/optim/lr_scheduler.py", line 68, in wrapper                                                                                                 
    return wrapped(*args, **kwargs)                                                                                                                                                              
  File "/usr/local/lib/python3.8/dist-packages/torch/optim/optimizer.py", line 373, in wrapper                                                                                                   
    out = func(*args, **kwargs)                                                                                                                                                                  
  File "/usr/local/lib/python3.8/dist-packages/torch/utils/_contextlib.py", line 115, in decorate_context                                                                                        
    return func(*args, **kwargs)                                                                                                                                                                 
  File "/algorithmic-efficiency/submissions/shampoo_submission/optimizers/distributed_shampoo/distributed_shampoo.py", line 905, in step                                                         
    self._per_group_step(                                                                                                                                                                        
  File "/usr/local/lib/python3.8/dist-packages/torch/utils/_contextlib.py", line 115, in decorate_context                                                                                        
    return func(*args, **kwargs)                                                                                                                                                                 
  File "/algorithmic-efficiency/submissions/shampoo_submission/optimizers/distributed_shampoo/distributed_shampoo.py", line 753, in _per_group_step_impl                                         
W0226 19:28:24.704183 140547811231552 matrix_functions.py:218] Failed to compute eigendecomposition in torch.float32 precision with exception cusolver error: CUSOLVER_STATUS_EXECUTION_FAILED, when calling `cusolverDnXsyevd( handle, params, jobz, uplo, n, CUDA_R_32F, reinterpret_cast<void*>(A), lda, CUDA_R_32F, reinterpret_cast<void*>(W), CUDA_R_32F, reinterpret_cast<void*>(bufferOnDevice), workspaceInBytesOnDevice, reinterpret_cast<void*>(bufferOnHost), workspaceInBytesOnHost, info)`. This error may appear if the input matrix contains NaN. If you keep seeing this error, you may use `torch.backends.cuda.preferred_linalg_library()` to try linear algebra operators with other supported backends. See https://pytorch.org/docs/stable/backends.html#torch.backends.cuda.preferred_linalg_library! Retrying in double precision...
    torch._foreach_mul_(state_lists[MASKED_FILTERED_GRAD_LIST], beta1)                                                                                                                           
RuntimeError: CUDA error: an illegal memory access was encountered                                                                                                                               
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.                                                                          
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.                                                                                                                                           
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.                                                                                                                              

Notice that in this case, we attempt to bypass the error (our script catches it), but subsequent CUDA kernel launches then also hit illegal memory accesses.
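For completeness, the warning text itself suggests trying a different linear-algebra backend. A sketch of that mitigation is below; whether MAGMA is available depends on the PyTorch build, and this does not help once an illegal memory access has already poisoned the CUDA context:

```python
import torch

# Ask PyTorch to prefer MAGMA over cuSOLVER for linear-algebra ops such as eigh.
# This must be set before the first failing call; once an illegal memory access
# has occurred, the CUDA context cannot be recovered.
torch.backends.cuda.preferred_linalg_library("magma")
```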

Consistent with the GitHub issues linked above, we have verified that the cuSOLVER library in use is /usr/local/cuda/lib64/libcusolver.so.11.4.1.48, which is the problematic version.
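A sketch of how the loaded library can be confirmed from inside the container (paths are environment-specific; on some wheel-based installs the library may instead live under a bundled nvidia/cusolver directory):

```python
import torch

# Run a small CUDA eigh so the cuSOLVER-backed path actually loads the library.
x = torch.eye(4, device="cuda")
torch.linalg.eigh(x)

# Inspect the shared objects mapped into the current process (Linux only).
with open("/proc/self/maps") as f:
    cusolver_paths = sorted({line.split()[-1] for line in f if "libcusolver" in line})

print(torch.__version__, torch.version.cuda)
print(cusolver_paths)  # e.g. .../libcusolver.so.11.4.1.48 on the affected build
```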

Steps to Reproduce

Follow the steps in pytorch/pytorch#105359 (comment).

Source or Possible Fix

If possible, we would suggest using CUDA 12.1.1 instead of CUDA 11.8 for the contest in order to avoid these instabilities.

Is there a way that we can change the Dockerfile to use PyTorch 2.1.0 with CUDA 12.1.1? (We have tried changing the first line of the Dockerfile to docker.io/nvidia/cuda:12.1.1-cudnn8-devel-ubuntu20.04, but torch.__version__ and torch.version.cuda still report PyTorch 2.1 + CUDA 11.8.)
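Our understanding (an assumption on our side) is that the PyTorch wheel bundles its own CUDA runtime, so changing the base image alone does not change torch.version.cuda; a cu121 wheel would need to be installed explicitly, e.g.:

```python
# Hypothetical install command (version pin is an assumption):
#   pip install torch==2.1.0 --index-url https://download.pytorch.org/whl/cu121
import torch

print(torch.__version__)   # wheels from the cu121 index typically report "+cu121"
print(torch.version.cuda)  # expected to be "12.1"
```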

cc @anana10c @mikerabbat @tsunghsienlee @yuchenhao @shintaro-iwasaki
