From f8bc07c5762d3d608df2ea5b81b094aa05ae7c88 Mon Sep 17 00:00:00 2001 From: Yongshan <57080732+yyswhsccc@users.noreply.github.com> Date: Mon, 18 May 2026 03:19:33 -0600 Subject: [PATCH] [bugfix] sync colocate vLLM weight load --- swift/rlhf_trainers/rollout_mixin.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/swift/rlhf_trainers/rollout_mixin.py b/swift/rlhf_trainers/rollout_mixin.py index f4ea58e1d0..2d968a5bee 100644 --- a/swift/rlhf_trainers/rollout_mixin.py +++ b/swift/rlhf_trainers/rollout_mixin.py @@ -34,7 +34,7 @@ from swift.template import Template from swift.tuners import Swift from swift.utils import (get_current_device, get_logger, is_deepspeed_enabled, is_vllm_available, remove_response, - to_device) + synchronize, to_device) from .arguments import RolloutTrainerArgumentsMixin from .rlhf_mixin import RLHFTrainerMixin from .utils import (VLLM_LORA_INT_ID, VLLM_LORA_NAME, VLLM_LORA_PATH, FlattenedTensorBucket, TensorLoRARequest, @@ -570,6 +570,8 @@ def _load_state_dict_to_vllm(self, state_dict): # Patch MoE weight_loader if needed patch_vllm_moe_model_weight_loader(llm_model) llm_model.load_weights(state_dict.items()) + # Keep ZeRO-3 gathered tensors valid until vLLM finishes any queued device copies. + synchronize() del state_dict def _fix_param_name_to_vllm(self, name: str, extra_prefixes: Optional[List[str]] = None) -> str: