From 93620751cbf84dc2b4704aee5f2ec838cc54e65c Mon Sep 17 00:00:00 2001 From: Yongshan <57080732+yyswhsccc@users.noreply.github.com> Date: Sun, 17 May 2026 22:01:15 -0600 Subject: [PATCH] Fix Qwen3.5 conv1d zero3 gather --- swift/model/models/qwen.py | 67 ++++++++++++++--------- tests/models/test_qwen35_zero3.py | 88 +++++++++++++++++++++++++++++++ 2 files changed, 130 insertions(+), 25 deletions(-) create mode 100644 tests/models/test_qwen35_zero3.py diff --git a/swift/model/models/qwen.py b/swift/model/models/qwen.py index 4a6d30f43b..aeca4becd5 100644 --- a/swift/model/models/qwen.py +++ b/swift/model/models/qwen.py @@ -4,6 +4,7 @@ import os import torch import torch.nn.functional as F +from contextlib import nullcontext from packaging import version from PIL import Image from transformers import AutoTokenizer, BitsAndBytesConfig, PretrainedConfig, PreTrainedModel, PreTrainedTokenizerBase @@ -1191,6 +1192,21 @@ def _ensure_linear_attention_kernels(mod: torch.nn.Module) -> None: raise ImportError(_SP_LINEAR_KERNEL_IMPORT_ERROR) +def _gather_qwen3_5_conv1d_params_if_zero3(mod: torch.nn.Module): + conv1d = getattr(mod, 'conv1d', None) + if conv1d is None or not is_deepspeed_enabled(): + return nullcontext() + + # Qwen3.5 passes conv1d.weight directly to causal-conv kernels, bypassing + # DeepSpeed's submodule forward hooks that normally gather ZeRO-3 params. + params = [p for p in conv1d.parameters(recurse=False) if p is not None] + if not params or not any(hasattr(p, 'ds_tensor') or hasattr(p, 'ds_id') for p in params): + return nullcontext() + + import deepspeed + return deepspeed.zero.GatheredParameters(params) + + def _get_local_conv_weights(mod: torch.nn.Module, *, sp_rank: int, local_num_k_heads: int, local_num_v_heads: int): conv_weight = mod.conv1d.weight.squeeze(1) conv_bias = getattr(mod.conv1d, 'bias', None) @@ -1378,31 +1394,32 @@ def sp_linear_forward( ): from swift.sequence_parallel import sequence_parallel as sequence_parallel_context - if not _sp_is_enabled(sequence_parallel_context): - kwargs = {} - if 'cache_position' in parameters: - kwargs['cache_position'] = cache_position - return origin_forward( - mod, hidden_states, cache_params=cache_params, attention_mask=attention_mask, **kwargs) - - if int(getattr(sequence_parallel_context, 'rp_world_size', 1) or 1) > 1: - requested_sp_size = int(getattr(sequence_parallel_context, 'world_size', 1) or 1) - suggested_sp_size = int(getattr(sequence_parallel_context, 'sp_world_size', 1) or 1) - raise NotImplementedError( - 'Qwen3.5 linear attention sequence parallel does not support derived ring attention ' - f'(sequence_parallel_size={requested_sp_size}, ' - f'sp_world_size={getattr(sequence_parallel_context, "sp_world_size", None)}, ' - f'rp_world_size={getattr(sequence_parallel_context, "rp_world_size", None)}). ' - f'Please reduce --sequence_parallel_size to {suggested_sp_size} so that rp_world_size becomes 1.') - - return _run_qwen3_5_gated_delta_net_sequence_parallel_forward( - mod, - hidden_states, - cache_params=cache_params, - cache_position=cache_position, - attention_mask=attention_mask, - sequence_parallel_context=sequence_parallel_context, - ) + with _gather_qwen3_5_conv1d_params_if_zero3(mod): + if not _sp_is_enabled(sequence_parallel_context): + kwargs = {} + if 'cache_position' in parameters: + kwargs['cache_position'] = cache_position + return origin_forward( + mod, hidden_states, cache_params=cache_params, attention_mask=attention_mask, **kwargs) + + if int(getattr(sequence_parallel_context, 'rp_world_size', 1) or 1) > 1: + requested_sp_size = int(getattr(sequence_parallel_context, 'world_size', 1) or 1) + suggested_sp_size = int(getattr(sequence_parallel_context, 'sp_world_size', 1) or 1) + raise NotImplementedError( + 'Qwen3.5 linear attention sequence parallel does not support derived ring attention ' + f'(sequence_parallel_size={requested_sp_size}, ' + f'sp_world_size={getattr(sequence_parallel_context, "sp_world_size", None)}, ' + f'rp_world_size={getattr(sequence_parallel_context, "rp_world_size", None)}). ' + f'Please reduce --sequence_parallel_size to {suggested_sp_size} so that rp_world_size becomes 1.') + + return _run_qwen3_5_gated_delta_net_sequence_parallel_forward( + mod, + hidden_states, + cache_params=cache_params, + cache_position=cache_position, + attention_mask=attention_mask, + sequence_parallel_context=sequence_parallel_context, + ) Qwen3_5GatedDeltaNet.forward = sp_linear_forward Qwen3_5GatedDeltaNet._ms_swift_sp_linear_patched = True diff --git a/tests/models/test_qwen35_zero3.py b/tests/models/test_qwen35_zero3.py new file mode 100644 index 0000000000..d41a7e62db --- /dev/null +++ b/tests/models/test_qwen35_zero3.py @@ -0,0 +1,88 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +import sys +import torch +import types + +from swift.model.models import qwen + + +def _install_fake_deepspeed(monkeypatch, events): + + class GatheredParameters: + + def __init__(self, params): + self.params = list(params) + events.append(('init', self.params)) + + def __enter__(self): + events.append('enter') + return self + + def __exit__(self, exc_type, exc, tb): + events.append('exit') + + deepspeed = types.SimpleNamespace(zero=types.SimpleNamespace(GatheredParameters=GatheredParameters)) + monkeypatch.setitem(sys.modules, 'deepspeed', deepspeed) + + +def test_qwen35_conv1d_gather_context_collects_zero3_params(monkeypatch): + events = [] + _install_fake_deepspeed(monkeypatch, events) + monkeypatch.setattr(qwen, 'is_deepspeed_enabled', lambda: True) + + mod = types.SimpleNamespace(conv1d=torch.nn.Conv1d(4, 4, kernel_size=3, groups=4, bias=False)) + mod.conv1d.weight.ds_tensor = object() + + with qwen._gather_qwen3_5_conv1d_params_if_zero3(mod): + events.append('body') + + assert events == [('init', [mod.conv1d.weight]), 'enter', 'body', 'exit'] + + +def test_qwen35_conv1d_gather_context_ignores_non_zero3_params(monkeypatch): + events = [] + _install_fake_deepspeed(monkeypatch, events) + monkeypatch.setattr(qwen, 'is_deepspeed_enabled', lambda: True) + + mod = types.SimpleNamespace(conv1d=torch.nn.Conv1d(4, 4, kernel_size=3, groups=4, bias=False)) + + with qwen._gather_qwen3_5_conv1d_params_if_zero3(mod): + events.append('body') + + assert events == ['body'] + + +def test_qwen35_patched_forward_gathers_conv1d_params_for_origin_forward(monkeypatch): + events = [] + _install_fake_deepspeed(monkeypatch, events) + monkeypatch.setattr(qwen, 'is_deepspeed_enabled', lambda: True) + + qwen35_package = types.ModuleType('transformers.models.qwen3_5') + qwen35_package.__path__ = [] + qwen35_modeling = types.ModuleType('transformers.models.qwen3_5.modeling_qwen3_5') + + class Qwen3_5GatedDeltaNet(torch.nn.Module): + + def __init__(self): + super().__init__() + self.conv1d = torch.nn.Conv1d(4, 4, kernel_size=3, groups=4, bias=False) + self.conv1d.weight.ds_tensor = object() + + def forward(self, hidden_states, cache_params=None, cache_position=None, attention_mask=None): + events.append('origin') + return hidden_states + 1 + + qwen35_modeling.Qwen3_5GatedDeltaNet = Qwen3_5GatedDeltaNet + qwen35_package.modeling_qwen3_5 = qwen35_modeling + monkeypatch.setitem(sys.modules, 'transformers.models.qwen3_5', qwen35_package) + monkeypatch.setitem(sys.modules, 'transformers.models.qwen3_5.modeling_qwen3_5', qwen35_modeling) + + qwen._patch_qwen3_5_linear_attention_sequence_parallel() + + hidden_states = torch.zeros(1, 2, 4) + mod = Qwen3_5GatedDeltaNet() + output = mod(hidden_states) + + assert torch.equal(output, hidden_states + 1) + assert events[0] == ('init', [mod.conv1d.weight]) + assert events[1:] == ['enter', 'origin', 'exit']