Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions swift/megatron/trainers/rollout_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@
check_vllm_version_ge, expand_vllm_param_name_aliases, patch_vllm_load_adapter,
patch_vllm_moe_model_weight_loader, profiling_context, profiling_decorator,
set_expandable_segments, vllm_supports_lora_load_inplace)
from swift.utils import (get_current_device, get_logger, is_last_rank, is_vllm_available, remove_response, synchronize,
to_device)
from swift.utils import (configure_vllm_allreduce_env, get_current_device, get_logger, is_last_rank, is_vllm_available,
remove_response, synchronize, to_device)
from .utils import (gather_object, load_megatron_model_to_gpu, load_megatron_optimizer, offload_megatron_model_to_cpu,
offload_megatron_optimizer)

Expand Down Expand Up @@ -246,6 +246,7 @@ def _init_rollout_engine(self):

def _prepare_vllm_engine(self):
"""Create and configure vLLM engine for colocate mode."""
configure_vllm_allreduce_env(self.vllm_tensor_parallel_size)
from vllm.distributed import parallel_state as vllm_ps

from swift.infer_engine import GRPOVllmEngine
Expand Down
4 changes: 3 additions & 1 deletion swift/pipelines/infer/infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@
from swift.dataset import DatasetLoader, load_dataset, sample_dataset
from swift.infer_engine import AdapterRequest, InferRequest, RequestConfig, TransformersEngine
from swift.metrics import InferStats, MeanMetric, compute_rouge_bleu
from swift.utils import JsonlWriter, get_dist_setting, get_logger, is_dist, is_master, read_from_jsonl
from swift.utils import (JsonlWriter, configure_vllm_allreduce_env, get_dist_setting, get_logger, is_dist, is_master,
read_from_jsonl)
from ..base import SwiftPipeline
from ..export import merge_lora
from ..utils import get_cached_dataset, prepare_model_template
Expand Down Expand Up @@ -65,6 +66,7 @@ def get_infer_engine(args: InferArguments, template=None, **extra_kwargs):
if hasattr(args, 'max_batch_size'):
kwargs.update({'max_batch_size': args.max_batch_size})
elif infer_backend == 'vllm':
configure_vllm_allreduce_env(args.vllm_tensor_parallel_size)
from swift.infer_engine import VllmEngine
infer_engine_cls = VllmEngine
kwargs.update(args.get_vllm_engine_kwargs())
Expand Down
5 changes: 3 additions & 2 deletions swift/rlhf_trainers/rollout_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,8 @@
from swift.sequence_parallel import sequence_parallel
from swift.template import Template
from swift.tuners import Swift
from swift.utils import (get_current_device, get_logger, is_deepspeed_enabled, is_vllm_available, remove_response,
to_device)
from swift.utils import (configure_vllm_allreduce_env, get_current_device, get_logger, is_deepspeed_enabled,
is_vllm_available, remove_response, to_device)
from .arguments import RolloutTrainerArgumentsMixin
from .rlhf_mixin import RLHFTrainerMixin
from .utils import (VLLM_LORA_INT_ID, VLLM_LORA_NAME, VLLM_LORA_PATH, FlattenedTensorBucket, TensorLoRARequest,
Expand Down Expand Up @@ -257,6 +257,7 @@ def _prepare_vllm(self):

def _prepare_vllm_engine(self):
"""Create and configure vLLM engine for colocate mode"""
configure_vllm_allreduce_env(self.vllm_tensor_parallel_size)
from swift.infer_engine import GRPOVllmEngine
args = self.args
model = self.model
Expand Down
5 changes: 3 additions & 2 deletions swift/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
# Copyright (c) ModelScope Contributors. All rights reserved.

from .env import (get_dist_setting, get_hf_endpoint, get_node_setting, get_pai_tensorboard_dir, is_deepspeed_enabled,
is_dist, is_last_rank, is_local_master, is_master, is_mp, is_mp_ddp, is_pai_training_job, use_hf_hub)
from .env import (configure_vllm_allreduce_env, get_dist_setting, get_hf_endpoint, get_node_setting,
get_pai_tensorboard_dir, is_deepspeed_enabled, is_dist, is_last_rank, is_local_master, is_master,
is_mp, is_mp_ddp, is_pai_training_job, use_hf_hub)
from .hf_config import HfConfigFactory
from .hub_utils import download_ms_file, git_clone_github, safe_snapshot_download
from .import_utils import (is_flash_attn_2_available, is_flash_attn_3_available, is_liger_available,
Expand Down
9 changes: 9 additions & 0 deletions swift/utils/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,15 @@ def is_mp_ddp() -> bool:
return False


def configure_vllm_allreduce_env(tensor_parallel_size: Optional[int]) -> None:
    """Default ``VLLM_ALLREDUCE_USE_SYMM_MEM`` to ``'0'`` for tensor-parallel vLLM runs.

    The variable is written only when ``tensor_parallel_size`` is greater than 1
    and the variable is not already present in the environment, so a value the
    user exported explicitly is never overwritten.

    Args:
        tensor_parallel_size: Requested vLLM tensor-parallel size. ``None``, ``0``
            and ``1`` leave the environment untouched.
    """
    env_key = 'VLLM_ALLREDUCE_USE_SYMM_MEM'
    uses_tensor_parallel = bool(tensor_parallel_size) and tensor_parallel_size > 1
    if uses_tensor_parallel and env_key not in os.environ:
        os.environ[env_key] = '0'
        logger.info_once('Setting VLLM_ALLREDUCE_USE_SYMM_MEM=0 for vLLM tensor-parallel execution. '
                         'Set the environment variable explicitly to override this stability default.')


def is_pai_training_job() -> bool:
    """Return True when the ``PAI_TRAINING_JOB_ID`` environment variable is set."""
    job_id = os.getenv('PAI_TRAINING_JOB_ID')
    return job_id is not None

Expand Down
35 changes: 35 additions & 0 deletions tests/utils/test_vllm_env.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
import os

from swift.utils.env import configure_vllm_allreduce_env


def test_configure_vllm_allreduce_env_sets_default_for_tensor_parallel(monkeypatch):
    """With a tensor-parallel size above 1 and no existing value, the helper sets '0'."""
    # The helper writes to os.environ directly, bypassing monkeypatch. A bare
    # delenv(..., raising=False) on an absent variable records nothing to undo,
    # so the value written below would leak into the rest of the test session.
    # Registering the key with setenv first makes monkeypatch restore the
    # pre-test state (present or absent) on teardown.
    monkeypatch.setenv('VLLM_ALLREDUCE_USE_SYMM_MEM', '')
    monkeypatch.delenv('VLLM_ALLREDUCE_USE_SYMM_MEM')

    configure_vllm_allreduce_env(2)

    assert os.environ['VLLM_ALLREDUCE_USE_SYMM_MEM'] == '0'


def test_configure_vllm_allreduce_env_preserves_explicit_value(monkeypatch):
    """A value already present in the environment must survive the call unchanged."""
    env_key = 'VLLM_ALLREDUCE_USE_SYMM_MEM'
    monkeypatch.setenv(env_key, '1')

    configure_vllm_allreduce_env(2)

    assert os.environ[env_key] == '1'


def test_configure_vllm_allreduce_env_skips_single_tensor_parallel(monkeypatch):
    """A tensor-parallel size of 1 must leave the variable unset."""
    env_key = 'VLLM_ALLREDUCE_USE_SYMM_MEM'
    monkeypatch.delenv(env_key, raising=False)

    configure_vllm_allreduce_env(1)

    assert os.getenv(env_key) is None


def test_configure_vllm_allreduce_env_skips_missing_tensor_parallel(monkeypatch):
    """A tensor-parallel size of ``None`` must leave the variable unset."""
    env_key = 'VLLM_ALLREDUCE_USE_SYMM_MEM'
    monkeypatch.delenv(env_key, raising=False)

    configure_vllm_allreduce_env(None)

    assert os.getenv(env_key) is None
Loading