diff --git a/swift/rlhf_trainers/rollout_mixin.py b/swift/rlhf_trainers/rollout_mixin.py index f4ea58e1d0..2d968a5bee 100644 --- a/swift/rlhf_trainers/rollout_mixin.py +++ b/swift/rlhf_trainers/rollout_mixin.py @@ -34,7 +34,7 @@ from swift.template import Template from swift.tuners import Swift from swift.utils import (get_current_device, get_logger, is_deepspeed_enabled, is_vllm_available, remove_response, - to_device) + synchronize, to_device) from .arguments import RolloutTrainerArgumentsMixin from .rlhf_mixin import RLHFTrainerMixin from .utils import (VLLM_LORA_INT_ID, VLLM_LORA_NAME, VLLM_LORA_PATH, FlattenedTensorBucket, TensorLoRARequest, @@ -570,6 +570,8 @@ def _load_state_dict_to_vllm(self, state_dict): # Patch MoE weight_loader if needed patch_vllm_moe_model_weight_loader(llm_model) llm_model.load_weights(state_dict.items()) + # Keep ZeRO-3 gathered tensors valid until vLLM finishes any queued device copies. + synchronize() del state_dict def _fix_param_name_to_vllm(self, name: str, extra_prefixes: Optional[List[str]] = None) -> str: