Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 42 additions & 25 deletions swift/model/models/qwen.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import os
import torch
import torch.nn.functional as F
from contextlib import nullcontext
from packaging import version
from PIL import Image
from transformers import AutoTokenizer, BitsAndBytesConfig, PretrainedConfig, PreTrainedModel, PreTrainedTokenizerBase
Expand Down Expand Up @@ -1191,6 +1192,21 @@ def _ensure_linear_attention_kernels(mod: torch.nn.Module) -> None:
raise ImportError(_SP_LINEAR_KERNEL_IMPORT_ERROR)


def _gather_qwen3_5_conv1d_params_if_zero3(mod: torch.nn.Module):
conv1d = getattr(mod, 'conv1d', None)
if conv1d is None or not is_deepspeed_enabled():
return nullcontext()

# Qwen3.5 passes conv1d.weight directly to causal-conv kernels, bypassing
# DeepSpeed's submodule forward hooks that normally gather ZeRO-3 params.
params = [p for p in conv1d.parameters(recurse=False) if p is not None]
if not params or not any(hasattr(p, 'ds_tensor') or hasattr(p, 'ds_id') for p in params):
return nullcontext()

import deepspeed
return deepspeed.zero.GatheredParameters(params)


def _get_local_conv_weights(mod: torch.nn.Module, *, sp_rank: int, local_num_k_heads: int, local_num_v_heads: int):
conv_weight = mod.conv1d.weight.squeeze(1)
conv_bias = getattr(mod.conv1d, 'bias', None)
Expand Down Expand Up @@ -1378,31 +1394,32 @@ def sp_linear_forward(
):
from swift.sequence_parallel import sequence_parallel as sequence_parallel_context

if not _sp_is_enabled(sequence_parallel_context):
kwargs = {}
if 'cache_position' in parameters:
kwargs['cache_position'] = cache_position
return origin_forward(
mod, hidden_states, cache_params=cache_params, attention_mask=attention_mask, **kwargs)

if int(getattr(sequence_parallel_context, 'rp_world_size', 1) or 1) > 1:
requested_sp_size = int(getattr(sequence_parallel_context, 'world_size', 1) or 1)
suggested_sp_size = int(getattr(sequence_parallel_context, 'sp_world_size', 1) or 1)
raise NotImplementedError(
'Qwen3.5 linear attention sequence parallel does not support derived ring attention '
f'(sequence_parallel_size={requested_sp_size}, '
f'sp_world_size={getattr(sequence_parallel_context, "sp_world_size", None)}, '
f'rp_world_size={getattr(sequence_parallel_context, "rp_world_size", None)}). '
f'Please reduce --sequence_parallel_size to {suggested_sp_size} so that rp_world_size becomes 1.')

return _run_qwen3_5_gated_delta_net_sequence_parallel_forward(
mod,
hidden_states,
cache_params=cache_params,
cache_position=cache_position,
attention_mask=attention_mask,
sequence_parallel_context=sequence_parallel_context,
)
with _gather_qwen3_5_conv1d_params_if_zero3(mod):
if not _sp_is_enabled(sequence_parallel_context):
kwargs = {}
if 'cache_position' in parameters:
kwargs['cache_position'] = cache_position
return origin_forward(
mod, hidden_states, cache_params=cache_params, attention_mask=attention_mask, **kwargs)

if int(getattr(sequence_parallel_context, 'rp_world_size', 1) or 1) > 1:
requested_sp_size = int(getattr(sequence_parallel_context, 'world_size', 1) or 1)
suggested_sp_size = int(getattr(sequence_parallel_context, 'sp_world_size', 1) or 1)
raise NotImplementedError(
'Qwen3.5 linear attention sequence parallel does not support derived ring attention '
f'(sequence_parallel_size={requested_sp_size}, '
f'sp_world_size={getattr(sequence_parallel_context, "sp_world_size", None)}, '
f'rp_world_size={getattr(sequence_parallel_context, "rp_world_size", None)}). '
f'Please reduce --sequence_parallel_size to {suggested_sp_size} so that rp_world_size becomes 1.')

return _run_qwen3_5_gated_delta_net_sequence_parallel_forward(
mod,
hidden_states,
cache_params=cache_params,
cache_position=cache_position,
attention_mask=attention_mask,
sequence_parallel_context=sequence_parallel_context,
)

Qwen3_5GatedDeltaNet.forward = sp_linear_forward
Qwen3_5GatedDeltaNet._ms_swift_sp_linear_patched = True
Expand Down
88 changes: 88 additions & 0 deletions tests/models/test_qwen35_zero3.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
# Copyright (c) ModelScope Contributors. All rights reserved.
import sys
import torch
import types

from swift.model.models import qwen


def _install_fake_deepspeed(monkeypatch, events):

class GatheredParameters:

def __init__(self, params):
self.params = list(params)
events.append(('init', self.params))

def __enter__(self):
events.append('enter')
return self

def __exit__(self, exc_type, exc, tb):
events.append('exit')

deepspeed = types.SimpleNamespace(zero=types.SimpleNamespace(GatheredParameters=GatheredParameters))
monkeypatch.setitem(sys.modules, 'deepspeed', deepspeed)


def test_qwen35_conv1d_gather_context_collects_zero3_params(monkeypatch):
events = []
_install_fake_deepspeed(monkeypatch, events)
monkeypatch.setattr(qwen, 'is_deepspeed_enabled', lambda: True)

mod = types.SimpleNamespace(conv1d=torch.nn.Conv1d(4, 4, kernel_size=3, groups=4, bias=False))
mod.conv1d.weight.ds_tensor = object()

with qwen._gather_qwen3_5_conv1d_params_if_zero3(mod):
events.append('body')

assert events == [('init', [mod.conv1d.weight]), 'enter', 'body', 'exit']


def test_qwen35_conv1d_gather_context_ignores_non_zero3_params(monkeypatch):
events = []
_install_fake_deepspeed(monkeypatch, events)
monkeypatch.setattr(qwen, 'is_deepspeed_enabled', lambda: True)

mod = types.SimpleNamespace(conv1d=torch.nn.Conv1d(4, 4, kernel_size=3, groups=4, bias=False))

with qwen._gather_qwen3_5_conv1d_params_if_zero3(mod):
events.append('body')

assert events == ['body']


def test_qwen35_patched_forward_gathers_conv1d_params_for_origin_forward(monkeypatch):
events = []
_install_fake_deepspeed(monkeypatch, events)
monkeypatch.setattr(qwen, 'is_deepspeed_enabled', lambda: True)

qwen35_package = types.ModuleType('transformers.models.qwen3_5')
qwen35_package.__path__ = []
qwen35_modeling = types.ModuleType('transformers.models.qwen3_5.modeling_qwen3_5')

class Qwen3_5GatedDeltaNet(torch.nn.Module):

def __init__(self):
super().__init__()
self.conv1d = torch.nn.Conv1d(4, 4, kernel_size=3, groups=4, bias=False)
self.conv1d.weight.ds_tensor = object()

def forward(self, hidden_states, cache_params=None, cache_position=None, attention_mask=None):
events.append('origin')
return hidden_states + 1

qwen35_modeling.Qwen3_5GatedDeltaNet = Qwen3_5GatedDeltaNet
qwen35_package.modeling_qwen3_5 = qwen35_modeling
monkeypatch.setitem(sys.modules, 'transformers.models.qwen3_5', qwen35_package)
monkeypatch.setitem(sys.modules, 'transformers.models.qwen3_5.modeling_qwen3_5', qwen35_modeling)

qwen._patch_qwen3_5_linear_attention_sequence_parallel()

hidden_states = torch.zeros(1, 2, 4)
mod = Qwen3_5GatedDeltaNet()
output = mod(hidden_states)

assert torch.equal(output, hidden_states + 1)
assert events[0] == ('init', [mod.conv1d.weight])
assert events[1:] == ['enter', 'origin', 'exit']