Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions swift/megatron/trainers/rollout_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@
check_vllm_version_ge, expand_vllm_param_name_aliases, patch_vllm_load_adapter,
patch_vllm_moe_model_weight_loader, profiling_context, profiling_decorator,
set_expandable_segments, vllm_supports_lora_load_inplace)
from swift.utils import (get_current_device, get_logger, is_last_rank, is_vllm_available, remove_response, synchronize,
to_device)
from swift.utils import (configure_vllm_allreduce_env, get_current_device, get_logger, is_last_rank, is_vllm_available,
remove_response, synchronize, to_device)
from .utils import (gather_object, load_megatron_model_to_gpu, load_megatron_optimizer, offload_megatron_model_to_cpu,
offload_megatron_optimizer)

Expand Down Expand Up @@ -246,6 +246,7 @@ def _init_rollout_engine(self):

def _prepare_vllm_engine(self):
"""Create and configure vLLM engine for colocate mode."""
configure_vllm_allreduce_env(self.vllm_tensor_parallel_size)
from vllm.distributed import parallel_state as vllm_ps

from swift.infer_engine import GRPOVllmEngine
Expand Down
5 changes: 3 additions & 2 deletions swift/rlhf_trainers/rollout_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,8 @@
from swift.sequence_parallel import sequence_parallel
from swift.template import Template
from swift.tuners import Swift
from swift.utils import (get_current_device, get_logger, is_deepspeed_enabled, is_vllm_available, remove_response,
to_device)
from swift.utils import (configure_vllm_allreduce_env, get_current_device, get_logger, is_deepspeed_enabled,
is_vllm_available, remove_response, to_device)
from .arguments import RolloutTrainerArgumentsMixin
from .rlhf_mixin import RLHFTrainerMixin
from .utils import (VLLM_LORA_INT_ID, VLLM_LORA_NAME, VLLM_LORA_PATH, FlattenedTensorBucket, TensorLoRARequest,
Expand Down Expand Up @@ -257,6 +257,7 @@ def _prepare_vllm(self):

def _prepare_vllm_engine(self):
"""Create and configure vLLM engine for colocate mode"""
configure_vllm_allreduce_env(self.vllm_tensor_parallel_size)
from swift.infer_engine import GRPOVllmEngine
args = self.args
model = self.model
Expand Down
5 changes: 3 additions & 2 deletions swift/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
# Copyright (c) ModelScope Contributors. All rights reserved.

from .env import (get_dist_setting, get_hf_endpoint, get_node_setting, get_pai_tensorboard_dir, is_deepspeed_enabled,
is_dist, is_last_rank, is_local_master, is_master, is_mp, is_mp_ddp, is_pai_training_job, use_hf_hub)
from .env import (configure_vllm_allreduce_env, get_dist_setting, get_hf_endpoint, get_node_setting,
get_pai_tensorboard_dir, is_deepspeed_enabled, is_dist, is_last_rank, is_local_master, is_master,
is_mp, is_mp_ddp, is_pai_training_job, use_hf_hub)
from .hf_config import HfConfigFactory
from .hub_utils import download_ms_file, git_clone_github, safe_snapshot_download
from .import_utils import (is_flash_attn_2_available, is_flash_attn_3_available, is_liger_available,
Expand Down
9 changes: 9 additions & 0 deletions swift/utils/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,15 @@ def is_mp_ddp() -> bool:
return False


def configure_vllm_allreduce_env(tensor_parallel_size: int) -> None:
    """Default ``VLLM_ALLREDUCE_USE_SYMM_MEM`` to ``'0'`` for tensor-parallel rollout.

    Nothing is changed when ``tensor_parallel_size`` is 1 or less, or when the
    variable is already present in the environment (whatever its value), so an
    explicit user setting always wins.

    Args:
        tensor_parallel_size: Tensor-parallel degree of the vLLM rollout engine.
    """
    # Only tensor-parallel runs get the default; single-device runs are untouched.
    if tensor_parallel_size <= 1:
        return
    # Presence check, not truthiness: any explicitly exported value is preserved.
    if 'VLLM_ALLREDUCE_USE_SYMM_MEM' in os.environ:
        return

    os.environ['VLLM_ALLREDUCE_USE_SYMM_MEM'] = '0'
    logger.info_once('Setting VLLM_ALLREDUCE_USE_SYMM_MEM=0 for vLLM tensor-parallel rollout. '
                     'Set the environment variable explicitly to override this stability default.')


def is_pai_training_job() -> bool:
    """Return True when ``PAI_TRAINING_JOB_ID`` is present in the environment."""
    return 'PAI_TRAINING_JOB_ID' in os.environ

Expand Down
27 changes: 27 additions & 0 deletions tests/utils/test_vllm_env.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
import os

from swift.utils.env import configure_vllm_allreduce_env


def test_configure_vllm_allreduce_env_sets_default_for_tensor_parallel(monkeypatch):
    """With the variable unset and tensor parallel size 2, the default '0' is written."""
    # Start from a clean environment so the default path is the one exercised.
    monkeypatch.delenv('VLLM_ALLREDUCE_USE_SYMM_MEM', raising=False)

    configure_vllm_allreduce_env(2)

    assert os.environ['VLLM_ALLREDUCE_USE_SYMM_MEM'] == '0'


def test_configure_vllm_allreduce_env_preserves_explicit_value(monkeypatch):
    """An explicitly set value is left untouched even under tensor parallelism."""
    monkeypatch.setenv('VLLM_ALLREDUCE_USE_SYMM_MEM', '1')

    configure_vllm_allreduce_env(2)

    assert os.environ['VLLM_ALLREDUCE_USE_SYMM_MEM'] == '1'


def test_configure_vllm_allreduce_env_skips_single_tensor_parallel(monkeypatch):
    """With tensor parallel size 1 the variable is not introduced."""
    # Ensure the variable is absent so its absence afterwards is meaningful.
    monkeypatch.delenv('VLLM_ALLREDUCE_USE_SYMM_MEM', raising=False)

    configure_vllm_allreduce_env(1)

    assert 'VLLM_ALLREDUCE_USE_SYMM_MEM' not in os.environ