Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 92 additions & 0 deletions swift/callbacks/dynamic_mix.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
# Copyright (c) ModelScope Contributors. All rights reserved.
import collections

import torch

from swift.utils import get_logger
from .base import TrainerCallback

logger = get_logger()


class DynamicMixingCallback(TrainerCallback):
    """Callback that dynamically adjusts data sampling weights based on per-domain loss.

    Per-domain losses are read from the ``loss_{domain}`` entries of the logs passed to
    ``on_log``, averaged over an update window, and turned into sampling probabilities
    with ``softmax(mean_loss / temperature)``. The result is pushed to the
    ``DynamicMixBatchSampler`` found on the trainer's training dataloader.
    """

    def __init__(self, args, trainer):
        """Read the dynamic-mix settings from ``args`` and reset internal state.

        Raises:
            ValueError: If ``args.dynamic_mix_temperature`` is zero, because the
                softmax input ``loss / temperature`` would be undefined.
        """
        super().__init__(args, trainer)
        self.update_steps = args.dynamic_mix_update_steps
        self.temperature = args.dynamic_mix_temperature
        self.warmup_steps = args.dynamic_mix_warmup_steps
        if self.temperature == 0:
            # Dividing the losses by zero yields inf/nan and therefore NaN probabilities.
            raise ValueError('dynamic_mix_temperature must be non-zero.')
        self._sampler = None
        self._domain_names = None
        self._loss_buffer = collections.defaultdict(list)
        self._last_update_step = 0

    def on_train_begin(self, args, state, control, **kwargs):
        """Locate the ``DynamicMixBatchSampler`` on the trainer's dataloader.

        If the dataloader or a matching sampler cannot be found, a warning is logged and
        ``self._sampler`` stays ``None``, which turns ``on_log`` into a no-op.
        """
        # Reuse the already-created training dataloader instead of re-calling the getter
        dataloader = getattr(self.trainer, 'train_dataloader', None)
        if dataloader is None:
            logger.warning('DynamicMixingCallback: train_dataloader not found, dynamic mixing disabled.')
            return
        sampler = getattr(dataloader, 'batch_sampler', None)
        from swift.dataloader import DynamicMixBatchSampler
        # Unwrap SkipBatchSampler if present
        if hasattr(sampler, 'batch_sampler'):
            sampler = sampler.batch_sampler
        if not isinstance(sampler, DynamicMixBatchSampler):
            logger.warning('DynamicMixingCallback: sampler is not '
                           'DynamicMixBatchSampler, dynamic mixing disabled.')
            return
        self._sampler = sampler
        self._domain_names = sampler.domain_names
        domain_sizes = {n: len(sampler.domain_indices[n]) for n in self._domain_names}
        logger.info(f'Dynamic mixing initialized. Domains: {domain_sizes}')
        logger.info(f'Initial probabilities: {sampler.probabilities}')

    def on_log(self, args, state, control, logs=None, **kwargs):
        """Buffer per-domain losses from ``logs`` and trigger an update when due."""
        if self._sampler is None or logs is None:
            return
        # Capture loss_{channel} values from logs.
        # NOTE(review): the original comment here said channel=None samples are logged
        # under "loss_None". If the sampler uses a different name for that domain, its
        # loss is never captured and it falls back to the mean loss - confirm against
        # the code that builds the sampler's domain_indices.
        for name in self._domain_names:
            log_key = f'loss_{name}'
            if log_key in logs and logs[log_key] is not None:
                self._loss_buffer[name].append(logs[log_key])

        # Check if it's time to update weights
        if (state.global_step >= self.warmup_steps
                and state.global_step - self._last_update_step >= self.update_steps):
            self._update_probabilities(state.global_step)

    def _update_probabilities(self, global_step):
        """Recompute the sampling probabilities from the buffered losses.

        Domains without usable loss samples in the current window are assigned the mean
        loss of the domains that have data. Non-finite loss samples are ignored, and if
        the resulting probabilities are not finite the sampler is left unchanged.
        """
        domain_losses = {}
        for name in self._domain_names:
            samples = torch.tensor([float(v) for v in self._loss_buffer.get(name, [])], dtype=torch.float64)
            # A single NaN/inf sample would otherwise poison the whole softmax.
            samples = samples[torch.isfinite(samples)]
            if samples.numel() > 0:
                domain_losses[name] = samples.mean().item()

        if not domain_losses:
            logger.info(f'Step {global_step}: no channel loss data yet, '
                        'skipping dynamic mix update.')
            return

        # Use global mean for domains without loss data
        mean_loss = sum(domain_losses.values()) / len(domain_losses)
        for name in self._domain_names:
            if name not in domain_losses:
                domain_losses[name] = mean_loss

        # softmax(loss / T)
        loss_tensor = torch.tensor([domain_losses[n] for n in self._domain_names])
        probs = torch.softmax(loss_tensor / self.temperature, dim=0)
        if not bool(torch.isfinite(probs).all()):
            # Start a fresh window rather than handing NaN probabilities to the sampler.
            logger.warning(f'Step {global_step}: non-finite mix probabilities computed, '
                           'keeping previous probabilities.')
            self._loss_buffer.clear()
            self._last_update_step = global_step
            return
        probs_dict = {name: probs[i].item() for i, name in enumerate(self._domain_names)}

        self._sampler.set_probabilities(probs_dict)
        self._loss_buffer.clear()
        self._last_update_step = global_step

        # Log new weights to metrics (will appear in tensorboard/wandb)
        for name, prob in probs_dict.items():
            self.trainer.custom_metrics['train'][f'mix_prob_{name}'].update(torch.tensor([prob]))

        logger.info(f'Step {global_step}: updated mix probabilities: {probs_dict}')
2 changes: 2 additions & 0 deletions swift/callbacks/mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from .activation_cpu_offload import ActivationCpuOffloadCallBack
from .adalora import AdaloraCallback
from .deepspeed_elastic import DeepspeedElasticCallback, GracefulExitCallback
from .dynamic_mix import DynamicMixingCallback
from .early_stop import EarlyStopCallback
from .lisa import LISACallback
from .perf_log import PerfMetricsLogCallback
Expand All @@ -10,6 +11,7 @@
'activation_cpu_offload': ActivationCpuOffloadCallBack,
'adalora': AdaloraCallback,
'deepspeed_elastic': DeepspeedElasticCallback,
'dynamic_mix': DynamicMixingCallback,
'early_stop': EarlyStopCallback,
'graceful_exit': GracefulExitCallback,
'lisa': LISACallback,
Expand Down
2 changes: 1 addition & 1 deletion swift/dataloader/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
# Copyright (c) ModelScope Contributors. All rights reserved.
from .dispatcher import DataLoaderDispatcher
from .shard import BatchSamplerShard, DataLoaderShard
from .shard import BatchSamplerShard, DataLoaderShard, DynamicMixBatchSampler
93 changes: 92 additions & 1 deletion swift/dataloader/shard.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import torch
import torch.distributed as dist
from torch.utils.data import DataLoader
from typing import Optional
from typing import Dict, List, Optional

from swift.utils import to_device

Expand Down Expand Up @@ -94,3 +94,94 @@ def __iter__(self):
if self.device:
item = to_device(item, self.device)
yield item


class DynamicMixBatchSampler:
    """Batch sampler that samples indices weighted by per-domain probabilities.

    Supports runtime probability updates for dynamic data mixing.

    Every rank builds the same global batch of ``batch_size * world_size`` indices from
    a generator seeded with ``curr_seed`` and keeps the strided slice that belongs to
    its own rank.
    """

    def __init__(
        self,
        domain_indices: Dict[str, List[int]],
        batch_size: int,
        shuffle: bool,
        drop_last: bool,
        data_seed: Optional[int],
        tp_size: int = 1,
        num_batches: Optional[int] = None,
    ):
        """Store the configuration and set initial size-proportional probabilities.

        Args:
            domain_indices: Mapping from domain name to the dataset indices in it.
            batch_size: Number of indices each rank receives per batch.
            shuffle: Whether to shuffle the indices inside each domain.
            drop_last: Used only to derive the batch count when ``num_batches`` is
                ``None``; decides whether a trailing partial batch is counted.
            data_seed: Base seed for the sampling generator; ``None`` is treated as 0.
            tp_size: Divisor applied to the distributed rank and world size.
            num_batches: Batches yielded per iteration. When ``None`` it is derived
                from the total number of indices, the world size and ``batch_size``.

        Raises:
            ValueError: If every domain is empty.
        """
        self.tp_size = tp_size
        self.domain_indices = domain_indices
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.drop_last = drop_last
        self.base_seed = data_seed or 0
        self.curr_seed = self.base_seed
        self.num_batches = num_batches
        # Sort domain names to ensure consistent ordering across all ranks
        self.domain_names = sorted(domain_indices.keys())
        # Initial weights proportional to domain sizes
        total = sum(len(domain_indices[n]) for n in self.domain_names)
        if total == 0:
            raise ValueError('DynamicMixBatchSampler requires at least one non-empty domain.')
        self.probabilities = {n: len(domain_indices[n]) / total for n in self.domain_names}

    @property
    def rank(self):
        """Rank divided by ``tp_size``, or 0 when torch.distributed is not initialized."""
        return (dist.get_rank() // self.tp_size) if dist.is_initialized() else 0

    @property
    def world_size(self):
        """World size divided by ``tp_size``, or 1 when torch.distributed is not initialized."""
        return (dist.get_world_size() // self.tp_size) if dist.is_initialized() else 1

    def _resolve_num_batches(self) -> int:
        """Return the configured batch count, deriving it when none was given."""
        if self.num_batches is not None:
            return self.num_batches
        total = sum(len(self.domain_indices[n]) for n in self.domain_names)
        per_rank = total // self.world_size
        if self.drop_last:
            return per_rank // self.batch_size
        return (per_rank + self.batch_size - 1) // self.batch_size

    def _shuffled(self, name, generator):
        """Return a freshly shuffled copy of the indices of domain ``name``."""
        indices = self.domain_indices[name]
        perm = torch.randperm(len(indices), generator=generator).tolist()
        return [indices[p] for p in perm]

    def set_probabilities(self, probs):
        """Update sampling probabilities (must be called with the same values on all ranks).

        Domains missing from ``probs`` keep their current weight, empty domains are
        pinned to 0 because they have nothing to sample, and the result is normalized
        to sum to 1. On error the current probabilities are left untouched.

        Raises:
            ValueError: If a supplied value is negative or not finite, or if no
                non-empty domain ends up with a positive weight.
        """
        updated = dict(self.probabilities)
        for name in self.domain_names:
            if name in probs:
                value = float(probs[name])
                # The chained comparison is False for NaN, so NaN is rejected too.
                if not 0.0 <= value < float('inf'):
                    raise ValueError(f'Invalid probability for domain {name!r}: {probs[name]!r}')
                updated[name] = value
            if len(self.domain_indices[name]) == 0:
                updated[name] = 0.0
        # Normalize
        total = sum(updated[n] for n in self.domain_names)
        if total <= 0:
            raise ValueError('Sampling probabilities must have a positive sum over non-empty domains.')
        self.probabilities = {n: updated[n] / total for n in self.domain_names}

    def __iter__(self):
        """Yield this rank's slice of each globally sampled batch."""
        generator = torch.Generator()
        generator.manual_seed(self.curr_seed)
        # Shuffle indices within each domain
        domain_shuffled = {}
        for name in self.domain_names:
            if self.shuffle:
                domain_shuffled[name] = self._shuffled(name, generator)
            else:
                domain_shuffled[name] = list(self.domain_indices[name])
        domain_cursors = {name: 0 for name in self.domain_names}

        rank, world_size = self.rank, self.world_size
        batch_total = self.batch_size * world_size
        for _ in range(self._resolve_num_batches()):
            # Re-read probabilities each batch so runtime updates take effect
            prob_tensor = torch.tensor([self.probabilities[n] for n in self.domain_names])
            global_batch = []
            sampled_domains = torch.multinomial(
                prob_tensor, batch_total, replacement=True, generator=generator).tolist()
            for domain_idx in sampled_domains:
                domain_name = self.domain_names[domain_idx]
                cursor = domain_cursors[domain_name]
                if cursor >= len(domain_shuffled[domain_name]):
                    # Domain exhausted, reshuffle and reset
                    if self.shuffle:
                        domain_shuffled[domain_name] = self._shuffled(domain_name, generator)
                    cursor = 0
                global_batch.append(domain_shuffled[domain_name][cursor])
                domain_cursors[domain_name] = cursor + 1
            # Distributed sharding: each rank takes its slice
            yield global_batch[rank::world_size]

    def set_epoch(self, epoch):
        """Offset the sampling seed by ``epoch`` for the next iteration."""
        self.curr_seed = self.base_seed + epoch

    def __len__(self):
        """Return the number of batches produced per iteration."""
        return self._resolve_num_batches()
20 changes: 20 additions & 0 deletions swift/trainers/arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,13 @@ class TrainArgumentsMixin:
router_aux_loss_coef: float = 0.
enable_dft_loss: bool = False # https://arxiv.org/abs/2508.05629
enable_channel_loss: bool = False

# dynamic data mixing
dynamic_mix: bool = False
dynamic_mix_update_steps: int = 100
dynamic_mix_temperature: float = 1.0
dynamic_mix_warmup_steps: int = 0

safe_serialization: bool = True
max_shard_size: str = '5GB'

Expand Down Expand Up @@ -234,6 +241,10 @@ def _init_callbacks(self):
fsdp_config = getattr(self, 'fsdp_config', {})
if isinstance(fsdp_config, dict) and fsdp_config.get('activation_cpu_offload', False):
self.callbacks.append('activation_cpu_offload')
if self.dynamic_mix:
self.enable_channel_loss = True
if 'dynamic_mix' not in self.callbacks:
self.callbacks.append('dynamic_mix')

def __post_init__(self):
if hasattr(self, 'output_dir'):
Expand All @@ -245,6 +256,15 @@ def __post_init__(self):
if self.optimizer is None and (self.vit_lr is not None or self.aligner_lr is not None):
self.optimizer = 'multimodal'
self._init_callbacks()
if self.dynamic_mix:
if getattr(self, 'streaming', False):
raise ValueError('dynamic_mix does not support streaming mode.')
if getattr(self, 'interleave_prob', None) is not None:
raise ValueError('dynamic_mix and interleave_prob are mutually exclusive.')
if getattr(self, 'packing', False):
raise ValueError('dynamic_mix is not compatible with packing mode.')
if self.group_by_length:
raise ValueError('dynamic_mix is not compatible with group_by_length.')
if self.gradient_accumulation_steps is None:
world_size = get_dist_setting()[2]
self.gradient_accumulation_steps = max(1, math.ceil(16 / self.per_device_train_batch_size / world_size))
Expand Down
46 changes: 40 additions & 6 deletions swift/trainers/mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
from typing import Callable, Dict, List, Optional

from swift.callbacks import callbacks_map
from swift.dataloader import BatchSamplerShard, DataLoaderDispatcher, DataLoaderShard
from swift.dataloader import BatchSamplerShard, DataLoaderDispatcher, DataLoaderShard, DynamicMixBatchSampler
from swift.hub import get_hub
from swift.loss import loss_map
from swift.metrics import MeanMetric, compute_acc, eval_metrics_map
Expand Down Expand Up @@ -1221,11 +1221,26 @@ def get_train_dataloader(self, skip_batches=0):
}

if hasattr(train_dataset, '__len__'):
if args.group_by_length:
batch_sampler_params['group_by_length'] = args.group_by_length
batch_sampler_params['lengths'] = train_dataset['lengths']
batch_sampler = BatchSamplerShard(
len(train_dataset), batch_size=self._train_batch_size, **batch_sampler_params)
if getattr(args, 'dynamic_mix', False):
domain_indices = self._build_domain_indices(train_dataset)
ws = (dist.get_world_size() // batch_sampler_params.get('tp_size', 1)
) if dist.is_initialized() else 1
total_per_rank = len(train_dataset) // ws
if args.dataloader_drop_last:
num_batches = total_per_rank // self._train_batch_size
else:
num_batches = (total_per_rank + self._train_batch_size - 1) // self._train_batch_size
batch_sampler = DynamicMixBatchSampler(
domain_indices=domain_indices,
batch_size=self._train_batch_size,
num_batches=num_batches,
**batch_sampler_params)
else:
if args.group_by_length:
batch_sampler_params['group_by_length'] = args.group_by_length
batch_sampler_params['lengths'] = train_dataset['lengths']
batch_sampler = BatchSamplerShard(
len(train_dataset), batch_size=self._train_batch_size, **batch_sampler_params)
dataloader_params['worker_init_fn'] = partial(
seed_worker, num_workers=self.args.dataloader_num_workers, rank=self.args.process_index)
if skip_batches > 0:
Expand All @@ -1241,6 +1256,25 @@ def get_train_dataloader(self, skip_batches=0):
dataloader = DataLoaderDispatcher(dataloader, self.accelerator.device, skip_batches=skip_batches)
return dataloader

def _build_domain_indices(self, dataset):
    """Group the row indices of the training dataset by their ``channel`` value.

    Rows whose channel is ``None`` are collected under ``'default'``. When no
    ``channel`` feature is found, every index of ``dataset`` goes into a single
    ``'default'`` domain and a warning is logged.

    Returns:
        A plain dict mapping each domain name to its list of row indices.
    """
    # Peel off one wrapper layer (e.g. LazyLLMDataset) to reach the underlying HfDataset.
    source = getattr(dataset, 'dataset', dataset)
    has_channel = hasattr(source, 'features') and 'channel' in source.features

    groups = {}
    if has_channel:
        for row, channel in enumerate(source['channel']):
            key = 'default' if channel is None else channel
            groups.setdefault(key, []).append(row)
    else:
        groups['default'] = list(range(len(dataset)))
        logger.warning('No "channel" column found in dataset. '
                       'All data treated as single domain "default".')

    summary = ", ".join(f"{k}: {len(v)}" for k, v in sorted(groups.items()))
    logger.info(f'Dynamic mix domains: '
                f'{{{summary}}}')
    return dict(groups)

@contextmanager
def _disable_group_by_length(self):
group_by_length = getattr(self.args, 'group_by_length', False)
Expand Down
Loading