NCCL hangs during model training when CUDA graphs are enabled within NeMo's docker container (ver 25.02.01) #13140

Open
jiuqiant opened this issue Apr 21, 2025 · 1 comment
Labels
bug Something isn't working

jiuqiant commented Apr 21, 2025

Describe the bug

NCCL hangs during model training when CUDA graphs are enabled inside NeMo's Docker container (version 25.02.01). On a single node with 8 H100 80GB GPUs, enabling the CUDA graph feature in NeMo causes NCCL to hang, and the training processes are eventually terminated with a core dump.

Steps/Code to reproduce bug


The issue can be reproduced by running llama3_1_8b.py (see below) with the following command:

docker run --gpus all -it --rm -v  /home/test_run:/workspace/test_run --shm-size=8g --ulimit memlock=-1 --ulimit stack=67108864 -e CUDA_DEVICE_MAX_CONNECTIONS=1  nvcr.io/nvidia/nemo:25.02.01  /bin/bash -c "python test_run/llama3_1_8b.py"

llama3_1_8b.py:

"""Llama 3.1 8B training recipe."""

import os

from lightning.pytorch.loggers import TensorBoardLogger
from megatron.core.distributed import DistributedDataParallelConfig
from megatron.core.optimizer import OptimizerConfig
from nemo import lightning as nl
from nemo.collections import llm
from nemo.collections.llm.gpt.model.llama import Llama31Config8B
from nemo.collections.llm.gpt.model.llama import LlamaModel
from nemo.lightning.pytorch.optim import CosineAnnealingScheduler
from nemo.lightning.pytorch.optim import MegatronOptimizerModule
from nemo.utils.exp_manager import TimingCallback
import torch


def main():
  data = llm.MockDataModule(
      num_train_samples=1_000_000,
      seq_length=8192,
      global_batch_size=128,
      micro_batch_size=1,
  )

  model_config = Llama31Config8B()
  model_config.enable_cuda_graph = True
  model = LlamaModel(model_config)

  strategy = nl.MegatronStrategy(
      tensor_model_parallel_size=8,
      pipeline_model_parallel_size=1,
      pipeline_dtype=torch.bfloat16,
      virtual_pipeline_model_parallel_size=None,
      context_parallel_size=1,
      expert_model_parallel_size=1,
      sequence_parallel=True,
      account_for_embedding_in_pipeline_split=True,
      account_for_loss_in_pipeline_split=True,
      gradient_as_bucket_view=True,
      ckpt_async_save=True,
      ckpt_parallel_save=True,
      ckpt_parallel_load=True,
      ckpt_parallel_save_optim=True,
      ckpt_load_strictness="log_all",
      ddp=DistributedDataParallelConfig(
          check_for_nan_in_grad=True,
          grad_reduce_in_fp32=True,
          overlap_grad_reduce=True,
          overlap_param_gather=True,
          average_in_collective=True,
      ),
      use_te_rng_tracker=True,
  )

  # Combine everything into the trainer
  trainer = nl.Trainer(
      accelerator="gpu",
      devices=8,
      num_nodes=1,
      max_steps=10,
      limit_val_batches=1,
      val_check_interval=5,
      log_every_n_steps=1,
      strategy=strategy,
      # Will let nemo tune automatically
      accumulate_grad_batches=1,
      # Will use nemo's sampler
      use_distributed_sampler=False,
      plugins=nl.MegatronMixedPrecision(precision="bf16-mixed"),
      # Let NeMoLogger set up checkpointing
      enable_checkpointing=False,
      callbacks=[TimingCallback()],
  )

  # Configure the optimizer
  opt_config = OptimizerConfig(
      optimizer="adam",
      lr=3e-4,
      weight_decay=0.1,
      bf16=True,
      fp16=False,
      adam_beta1=0.9,
      adam_beta2=0.95,
      adam_eps=1e-5,
      use_distributed_optimizer=True,
      clip_grad=1.0,
  )
  lr_scheduler = CosineAnnealingScheduler(
      warmup_steps=2000,
      constant_steps=0,
      min_lr=3e-5,
  )
  opt = MegatronOptimizerModule(config=opt_config, lr_scheduler=lr_scheduler)

  # Setup checkpoint and tensorboard for logger
  ckpt = nl.ModelCheckpoint(
      save_top_k=1,
      # Generate a *-last ckpt copy (link) whenever a ckpt is saved.
      # This is required when using auto resume.
      save_last=True,
      # Set to True if the final ckpt will be used by auto resume
      save_optim_on_train_end=False,
      filename="{val_loss:.2f}-{step}-{consumed_samples}",
  )
  tb = TensorBoardLogger(
      save_dir="tensorboard",  # The name of tfevents folder
      name="",  # No need further subfolder
  )
  logger = nl.NeMoLogger(
      # The centralized dir for loggings, tensorboard, checkpoints
      explicit_log_dir="/logs",
      log_global_rank_0_only=True,
      update_logger_directory=True,
      # Remove this argument to disable checkpointing
      ckpt=ckpt,
      tensorboard=tb,
  )

  # Configure auto resume
  resume = nl.AutoResume(
      # Force the training to resume from the last ckpt in log_dir if exists
      resume_if_exists=True,
      # Do not raise error if ckpt does not exist
      resume_ignore_no_checkpoint=True,
  )

  # Call nl.trainer.fit
  llm.pretrain(
      model=model,
      data=data,
      trainer=trainer,
      log=logger,
      resume=resume,
      optim=opt,
  )


if __name__ == "__main__":
  main()

The output log:

Training epoch 0, iteration 7/9 | lr: 1.199e-06 | global_batch_size: 128 | global_step: 7 | reduced_train_loss: 11.66 | train_step_timing in s: 19.74 | consumed_samples: 1024
[NeMo I 2025-04-21 20:55:08 nemo_logging:393] Async finalization time took 0.000 s
Training epoch 0, iteration 8/9 | lr: 1.349e-06 | global_batch_size: 128 | global_step: 8 | reduced_train_loss: 11.66 | train_step_timing in s: 19.74 | consumed_samples: 1152
[NeMo I 2025-04-21 20:55:28 nemo_logging:393] Async finalization time took 0.001 s
Training epoch 0, iteration 9/9 | lr: 1.499e-06 | global_batch_size: 128 | global_step: 9 | reduced_train_loss: 11.66 | train_step_timing in s: 19.71 | consumed_samples: 1280
[NeMo I 2025-04-21 20:55:48 nemo_logging:393] Successfully saved checkpoint from iteration       5 to /logs/checkpoints/val_loss=0.00-step=4-consumed_samples=640.0-last.ckpt
[NeMo I 2025-04-21 20:55:48 nemo_logging:393] Async checkpoint save for step 5 (/logs/checkpoints/val_loss=0.00-step=4-consumed_samples=640.0-last.ckpt) finalized successfully.
[NeMo I 2025-04-21 20:55:48 nemo_logging:393] Async finalization time took 0.135 s
[INFO     | pytorch_lightning.utilities.rank_zero]: Epoch 0, global step 9: 'val_loss' was not in top 1
[NeMo I 2025-04-21 20:55:49 nemo_logging:393] Global Checkpoint Save : Rank: 0 : Iteration: 10 : Start time: 1745268948.958s : Save duration: 1.016s
[NeMo I 2025-04-21 20:55:50 nemo_logging:393] Scheduled async checkpoint save for /logs/checkpoints/val_loss=0.00-step=9-consumed_samples=1280.0-last.ckpt
[NeMo I 2025-04-21 20:55:50 nemo_logging:393] Async finalization time took 0.001 s
[INFO     | pytorch_lightning.utilities.rank_zero]: `Trainer.fit` stopped: `max_steps=10` reached.
[NeMo I 2025-04-21 20:55:50 nemo_logging:393] Pending async checkpoint saves. Finalizing them synchronously now
[NeMo I 2025-04-21 20:56:03 nemo_logging:393] Successfully saved checkpoint from iteration      10 to /logs/checkpoints/val_loss=0.00-step=9-consumed_samples=1280.0-last.ckpt
[NeMo I 2025-04-21 20:56:03 nemo_logging:393] Async checkpoint save for step 10 (/logs/checkpoints/val_loss=0.00-step=9-consumed_samples=1280.0-last.ckpt) finalized successfully.
[NeMo I 2025-04-21 20:56:14 nemo_logging:393] Async finalization time took 24.807 s
[rank3]:[E421 21:07:05.801837010 ProcessGroupNCCL.cpp:1663] [PG ID 5 PG GUID 42(TENSOR_MODEL_PARALLEL_GROUP) Rank 3] ProcessGroupNCCL's watchdog got stuck for 480 seconds without making progress in monitoring enqueued collectives. This typically indicates a NCCL/CUDA API (e.g., CudaEventDestroy) hang blocking the watchdog, and could be triggered by another thread holding the GIL inside a CUDA api (for example, CudaEventDestroy), or other deadlock-prone behaviors.If you suspect the watchdog is not actually stuck and a longer timeout would help, you can either increase the timeout (TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC) to a larger value or disable the heartbeat monitor (TORCH_NCCL_ENABLE_MONITORING=0).If either of aforementioned helps, feel free to file an issue to PyTorch about the short timeout or false positive abort; otherwise, please attempt to debug the hang. 
[rank4]:[E421 21:07:06.826057483 ProcessGroupNCCL.cpp:1663] [PG ID 5 PG GUID 42(TENSOR_MODEL_PARALLEL_GROUP) Rank 4] ProcessGroupNCCL's watchdog got stuck for 480 seconds without making progress in monitoring enqueued collectives. This typically indicates a NCCL/CUDA API (e.g., CudaEventDestroy) hang blocking the watchdog, and could be triggered by another thread holding the GIL inside a CUDA api (for example, CudaEventDestroy), or other deadlock-prone behaviors.If you suspect the watchdog is not actually stuck and a longer timeout would help, you can either increase the timeout (TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC) to a larger value or disable the heartbeat monitor (TORCH_NCCL_ENABLE_MONITORING=0).If either of aforementioned helps, feel free to file an issue to PyTorch about the short timeout or false positive abort; otherwise, please attempt to debug the hang. 
[rank5]:[E421 21:07:07.257568212 ProcessGroupNCCL.cpp:1663] [PG ID 5 PG GUID 42(TENSOR_MODEL_PARALLEL_GROUP) Rank 5] ProcessGroupNCCL's watchdog got stuck for 480 seconds without making progress in monitoring enqueued collectives. This typically indicates a NCCL/CUDA API (e.g., CudaEventDestroy) hang blocking the watchdog, and could be triggered by another thread holding the GIL inside a CUDA api (for example, CudaEventDestroy), or other deadlock-prone behaviors.If you suspect the watchdog is not actually stuck and a longer timeout would help, you can either increase the timeout (TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC) to a larger value or disable the heartbeat monitor (TORCH_NCCL_ENABLE_MONITORING=0).If either of aforementioned helps, feel free to file an issue to PyTorch about the short timeout or false positive abort; otherwise, please attempt to debug the hang. 
[rank7]:[E421 21:07:07.453226545 ProcessGroupNCCL.cpp:1663] [PG ID 5 PG GUID 42(TENSOR_MODEL_PARALLEL_GROUP) Rank 7] ProcessGroupNCCL's watchdog got stuck for 480 seconds without making progress in monitoring enqueued collectives. This typically indicates a NCCL/CUDA API (e.g., CudaEventDestroy) hang blocking the watchdog, and could be triggered by another thread holding the GIL inside a CUDA api (for example, CudaEventDestroy), or other deadlock-prone behaviors.If you suspect the watchdog is not actually stuck and a longer timeout would help, you can either increase the timeout (TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC) to a larger value or disable the heartbeat monitor (TORCH_NCCL_ENABLE_MONITORING=0).If either of aforementioned helps, feel free to file an issue to PyTorch about the short timeout or false positive abort; otherwise, please attempt to debug the hang. 
[rank2]:[E421 21:07:07.664917138 ProcessGroupNCCL.cpp:1663] [PG ID 5 PG GUID 42(TENSOR_MODEL_PARALLEL_GROUP) Rank 2] ProcessGroupNCCL's watchdog got stuck for 480 seconds without making progress in monitoring enqueued collectives. This typically indicates a NCCL/CUDA API (e.g., CudaEventDestroy) hang blocking the watchdog, and could be triggered by another thread holding the GIL inside a CUDA api (for example, CudaEventDestroy), or other deadlock-prone behaviors.If you suspect the watchdog is not actually stuck and a longer timeout would help, you can either increase the timeout (TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC) to a larger value or disable the heartbeat monitor (TORCH_NCCL_ENABLE_MONITORING=0).If either of aforementioned helps, feel free to file an issue to PyTorch about the short timeout or false positive abort; otherwise, please attempt to debug the hang. 
[rank1]:[E421 21:07:07.774540179 ProcessGroupNCCL.cpp:1663] [PG ID 5 PG GUID 42(TENSOR_MODEL_PARALLEL_GROUP) Rank 1] ProcessGroupNCCL's watchdog got stuck for 480 seconds without making progress in monitoring enqueued collectives. This typically indicates a NCCL/CUDA API (e.g., CudaEventDestroy) hang blocking the watchdog, and could be triggered by another thread holding the GIL inside a CUDA api (for example, CudaEventDestroy), or other deadlock-prone behaviors.If you suspect the watchdog is not actually stuck and a longer timeout would help, you can either increase the timeout (TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC) to a larger value or disable the heartbeat monitor (TORCH_NCCL_ENABLE_MONITORING=0).If either of aforementioned helps, feel free to file an issue to PyTorch about the short timeout or false positive abort; otherwise, please attempt to debug the hang. 
[rank0]:[E421 21:07:07.791114957 ProcessGroupNCCL.cpp:1663] [PG ID 5 PG GUID 42(TENSOR_MODEL_PARALLEL_GROUP) Rank 0] ProcessGroupNCCL's watchdog got stuck for 480 seconds without making progress in monitoring enqueued collectives. This typically indicates a NCCL/CUDA API (e.g., CudaEventDestroy) hang blocking the watchdog, and could be triggered by another thread holding the GIL inside a CUDA api (for example, CudaEventDestroy), or other deadlock-prone behaviors.If you suspect the watchdog is not actually stuck and a longer timeout would help, you can either increase the timeout (TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC) to a larger value or disable the heartbeat monitor (TORCH_NCCL_ENABLE_MONITORING=0).If either of aforementioned helps, feel free to file an issue to PyTorch about the short timeout or false positive abort; otherwise, please attempt to debug the hang. 
[rank6]:[E421 21:07:07.791466567 ProcessGroupNCCL.cpp:1663] [PG ID 5 PG GUID 42(TENSOR_MODEL_PARALLEL_GROUP) Rank 6] ProcessGroupNCCL's watchdog got stuck for 480 seconds without making progress in monitoring enqueued collectives. This typically indicates a NCCL/CUDA API (e.g., CudaEventDestroy) hang blocking the watchdog, and could be triggered by another thread holding the GIL inside a CUDA api (for example, CudaEventDestroy), or other deadlock-prone behaviors.If you suspect the watchdog is not actually stuck and a longer timeout would help, you can either increase the timeout (TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC) to a larger value or disable the heartbeat monitor (TORCH_NCCL_ENABLE_MONITORING=0).If either of aforementioned helps, feel free to file an issue to PyTorch about the short timeout or false positive abort; otherwise, please attempt to debug the hang. 
[rank3]:[F421 21:15:05.802702510 ProcessGroupNCCL.cpp:1485] [PG ID 5 PG GUID 42(TENSOR_MODEL_PARALLEL_GROUP) Rank 3] [PG ID 5 PG GUID 42(TENSOR_MODEL_PARALLEL_GROUP) Rank 3] Terminating the process after attempting to dump debug info, due to ProcessGroupNCCL watchdog hang.
[rank4]:[F421 21:15:06.826913006 ProcessGroupNCCL.cpp:1485] [PG ID 5 PG GUID 42(TENSOR_MODEL_PARALLEL_GROUP) Rank 4] [PG ID 5 PG GUID 42(TENSOR_MODEL_PARALLEL_GROUP) Rank 4] Terminating the process after attempting to dump debug info, due to ProcessGroupNCCL watchdog hang.
[rank5]:[F421 21:15:07.258583592 ProcessGroupNCCL.cpp:1485] [PG ID 5 PG GUID 42(TENSOR_MODEL_PARALLEL_GROUP) Rank 5] [PG ID 5 PG GUID 42(TENSOR_MODEL_PARALLEL_GROUP) Rank 5] Terminating the process after attempting to dump debug info, due to ProcessGroupNCCL watchdog hang.
[rank7]:[F421 21:15:07.454088414 ProcessGroupNCCL.cpp:1485] [PG ID 5 PG GUID 42(TENSOR_MODEL_PARALLEL_GROUP) Rank 7] [PG ID 5 PG GUID 42(TENSOR_MODEL_PARALLEL_GROUP) Rank 7] Terminating the process after attempting to dump debug info, due to ProcessGroupNCCL watchdog hang.
[rank2]:[F421 21:15:07.665376648 ProcessGroupNCCL.cpp:1485] [PG ID 5 PG GUID 42(TENSOR_MODEL_PARALLEL_GROUP) Rank 2] [PG ID 5 PG GUID 42(TENSOR_MODEL_PARALLEL_GROUP) Rank 2] Terminating the process after attempting to dump debug info, due to ProcessGroupNCCL watchdog hang.
[rank1]:[F421 21:15:07.775612890 ProcessGroupNCCL.cpp:1485] [PG ID 5 PG GUID 42(TENSOR_MODEL_PARALLEL_GROUP) Rank 1] [PG ID 5 PG GUID 42(TENSOR_MODEL_PARALLEL_GROUP) Rank 1] Terminating the process after attempting to dump debug info, due to ProcessGroupNCCL watchdog hang.
[rank0]:[F421 21:15:07.791741488 ProcessGroupNCCL.cpp:1485] [PG ID 5 PG GUID 42(TENSOR_MODEL_PARALLEL_GROUP) Rank 0] [PG ID 5 PG GUID 42(TENSOR_MODEL_PARALLEL_GROUP) Rank 0] Terminating the process after attempting to dump debug info, due to ProcessGroupNCCL watchdog hang.
[rank6]:[F421 21:15:07.792264960 ProcessGroupNCCL.cpp:1485] [PG ID 5 PG GUID 42(TENSOR_MODEL_PARALLEL_GROUP) Rank 6] [PG ID 5 PG GUID 42(TENSOR_MODEL_PARALLEL_GROUP) Rank 6] Terminating the process after attempting to dump debug info, due to ProcessGroupNCCL watchdog hang.
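
For reference, the watchdog message above names two environment variables for distinguishing a genuine hang from a false-positive abort. Below is a minimal sketch of setting them; the values are illustrative assumptions, and this only relaxes or disables the watchdog rather than addressing the underlying hang. They need to be set before the NCCL process groups are created, e.g. at the top of llama3_1_8b.py or via additional -e flags on the docker run command.

import os

# Variable names are taken verbatim from the watchdog message above; values are illustrative.
os.environ["TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC"] = "1800"  # raise above the 480 s seen in the log
# os.environ["TORCH_NCCL_ENABLE_MONITORING"] = "0"       # or disable the heartbeat monitor entirely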

Expected behavior

Training completes successfully when CUDA graphs are disabled (by commenting out model_config.enable_cuda_graph = True and use_te_rng_tracker=True in llama3_1_8b.py). We expect enabling CUDA graphs to lead to the same successful outcome, which is not currently the case.
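
For clarity, the failing and passing runs differ only in the two settings shown in this reduced excerpt of the script above; the remaining MegatronStrategy arguments are unchanged from the script and are omitted here for brevity.

from nemo import lightning as nl
from nemo.collections.llm.gpt.model.llama import Llama31Config8B, LlamaModel

model_config = Llama31Config8B()
model_config.enable_cuda_graph = True  # commenting this out gives the passing run
model = LlamaModel(model_config)

strategy = nl.MegatronStrategy(
    tensor_model_parallel_size=8,
    use_te_rng_tracker=True,  # comment this out together with the flag above
)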

Environment overview (please complete the following information)

Running in NeMo's Docker container (version 25.02.01) on GCP A3 High.

jiuqiant added the bug label on Apr 21, 2025
jiuqiant (Author) commented:

Providing more details:

  • NVIDIA-SMI 550.90.07
  • Driver Version: 550.90.07
  • CUDA Version: 12.8
  • Docker image: nvcr.io/nvidia/nemo:25.04.rc2

This hang can lead to a core dump and a program crash, potentially causing subsequent job failures. Fixing it would be beneficial, but it is not a top priority for us.
