[Bug fix] Avoid OOM when casting fp32 on NPU for GRPO with vLLM colocate #9335
Code Review
This pull request modifies the NPU FSDP patching logic to move modules to the CPU and clear the NPU cache before casting to FP32, which prevents OOM errors in scenarios like GRPO with vLLM colocation where the model is preloaded on the NPU. Review feedback suggests making this CPU move and cache clearing conditional on whether the parameters are actually residing on the NPU to avoid unnecessary synchronization overhead and potential issues with meta-device initialization.
    import torch_npu
    module = module.cpu()
    torch_npu.npu.synchronize()
    torch_npu.npu.empty_cache()
It is recommended to only move the module to CPU and clear the NPU cache if the parameters are currently residing on the NPU. This avoids unnecessary synchronization and cache clearing overhead for standard training paths (like SFT or LoRA) where the model is already on the CPU or meta device. It also ensures better compatibility with meta-device initialization, as calling .cpu() on meta-parameters might lead to unexpected behavior depending on the PyTorch version.
Suggested change:
    - import torch_npu
    - module = module.cpu()
    - torch_npu.npu.synchronize()
    - torch_npu.npu.empty_cache()
    + if param.device.type == 'npu':
    +     import torch_npu
    +     module = module.cpu()
    +     torch_npu.npu.synchronize()
    +     torch_npu.npu.empty_cache()
ys2025-AI left a comment
Add a param.device.type == 'npu' check.
PR type
PR information
Problem
When running GRPO with --vllm_mode colocate on Ascend NPU (8 x A2), Accelerator.prepare() triggers _cast_module_to_fp32_for_npu_if_needed(), which casts the model to fp32 before FSDP2 sharding. In colocate mode, however, the model has already been preloaded onto the NPU by vLLM. Calling module.to(torch.float32) on the NPU temporarily keeps both the bf16 and the fp32 copy of the full model in device memory, causing OOM on large models such as Qwen3-30B-A3B.
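As a rough illustration of why the duplication described above matters, the sketch below computes an upper bound on memory when a bf16 copy and an fp32 copy coexist. The helper name cast_peak_bytes and the parameter count of about 30e9 (read off the model name only) are assumptions for illustration. The real per-device peak depends on how the weights are distributed across devices and on allocator behavior, neither of which this PR states.

```python
# Bytes per element for the two dtypes discussed in the Problem section.
BYTES_PER_ELEMENT = {"bf16": 2, "fp32": 4}


def cast_peak_bytes(num_params: int, src: str = "bf16", dst: str = "fp32") -> int:
    """Upper bound on memory if the source and destination copies coexist."""
    return num_params * (BYTES_PER_ELEMENT[src] + BYTES_PER_ELEMENT[dst])


# ASSUMPTION: roughly 30e9 parameters, inferred from the model name only.
num_params = 30_000_000_000
gib = 1024 ** 3

print(f"bf16 only : {num_params * BYTES_PER_ELEMENT['bf16'] / gib:.1f} GiB")
print(f"fp32 only : {num_params * BYTES_PER_ELEMENT['fp32'] / gib:.1f} GiB")
print(f"both alive: {cast_peak_bytes(num_params) / gib:.1f} GiB")
```

If both copies are resident at once, the bound is three times the bf16 footprint, which is the headroom the on-device cast would need on top of whatever else vLLM has allocated.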
Root Cause
The patch assumes the model resides on CPU/meta before prepare(). GRPO colocate breaks this assumption because vLLM initializes the model on NPU first.
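One way to check the assumption described above before prepare() runs is to count where the parameters currently live. This is a minimal diagnostic sketch; param_device_types is a hypothetical helper name, not part of the patch, and it uses only standard torch.nn APIs.

```python
from collections import Counter

from torch import nn


def param_device_types(module: nn.Module) -> Counter:
    """Count parameter tensors per device type, e.g. 'cpu', 'meta', 'npu'."""
    return Counter(p.device.type for p in module.parameters())


# Stand-in model for illustration; a real run would pass the model
# that is about to be handed to Accelerator.prepare().
model = nn.Linear(4, 2)
print(param_device_types(model))
```

In the standard SFT / LoRA paths the counter would be expected to show only 'cpu' or 'meta'; under GRPO colocate it would show 'npu', which is the case the patch did not anticipate.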
Fix
When param.device.type == 'npu', move the model back to CPU, call synchronize() and then empty_cache() to release NPU memory, and only then perform the fp32 cast on CPU. FSDP2 shards the fp32 parameters back to NPU during prepare().
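The guarded version of the fix can be sketched as follows. This is not the actual body of _cast_module_to_fp32_for_npu_if_needed(); cast_to_fp32_off_npu is a hypothetical standalone helper, and checking any parameter (rather than a single param) is an assumption made here for illustration. The torch_npu calls are the ones shown in the diff.

```python
import torch
from torch import nn


def cast_to_fp32_off_npu(module: nn.Module) -> nn.Module:
    """Hypothetical helper: cast to fp32, leaving the NPU first if needed."""
    # ASSUMPTION: treat the module as NPU-resident if any parameter is on NPU.
    on_npu = any(p.device.type == "npu" for p in module.parameters())
    if on_npu:
        # Imported lazily so CPU-only and meta-device paths never need torch_npu.
        import torch_npu

        module = module.cpu()          # move weights off the device
        torch_npu.npu.synchronize()    # wait for pending device work
        torch_npu.npu.empty_cache()    # release cached NPU blocks
    # The cast now runs wherever the parameters live (CPU or meta).
    return module.to(torch.float32)


# CPU example: the NPU branch is skipped and only the dtype changes.
model = nn.Linear(4, 2).to(torch.bfloat16)
model = cast_to_fp32_off_npu(model)
print({p.dtype for p in model.parameters()})
```

Keeping the device check outside the cast means CPU and meta-device modules take the same path as before, which matches the reviewer's concern about avoiding unnecessary synchronization.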
Compatibility
No impact on standard SFT / LoRA / Full fine-tuning where the model stays on CPU before prepare().
Only affects the NPU colocate code path.