diff --git a/swift/model/npu_patch/fsdp.py b/swift/model/npu_patch/fsdp.py index bb43a05704..79e0f00164 100644 --- a/swift/model/npu_patch/fsdp.py +++ b/swift/model/npu_patch/fsdp.py @@ -26,7 +26,17 @@ def _cast_module_to_fp32_for_npu_if_needed(module: torch.nn.Module, accelerator: # entering that path with bf16/fp16 parameters can fail before mixed # precision policy has a chance to manage runtime compute dtype. Cast early # while parameters are still on CPU or meta, so only dtype changes here. + + # GRPO with vLLM colocate mode may preload the model onto NPU before + # Accelerator.prepare() is called. In that case, casting to fp32 on NPU + # would temporarily duplicate the full model (bf16 + fp32), causing OOM. + # We move the model back to CPU first to free NPU memory, then cast. try: + if param.device.type == 'npu': + import torch_npu + module = module.cpu() + torch_npu.npu.synchronize() + torch_npu.npu.empty_cache() return module.to(torch.float32) except Exception as exc: raise NPUCastError(f'Failed to cast {module.__class__.__name__} to fp32.') from exc