class DynamicMixingCallback(TrainerCallback):
    """Callback that dynamically adjusts data sampling weights based on per-domain loss.

    Per-domain losses are read from the ``loss_{domain}`` entries of the training
    logs and buffered. Once ``dynamic_mix_warmup_steps`` has been reached and at
    least ``dynamic_mix_update_steps`` steps have passed since the previous
    update, the buffered values are averaged per domain and converted into new
    sampling probabilities with ``softmax(mean_loss / dynamic_mix_temperature)``.
    The result is pushed to the ``DynamicMixBatchSampler`` found on the
    trainer's training dataloader.

    NOTE(review): the sampler expects identical probabilities on every rank.
    This callback does not synchronise them itself, so it presumably relies on
    the logged losses already being identical across ranks -- confirm.
    """

    # Additional log keys to try for a domain when ``loss_{domain}`` is absent.
    # Samples whose channel is None are logged under "loss_None", while the
    # trainer files those samples under the "default" domain.
    _LOG_KEY_ALIASES = {'default': ('loss_None', )}

    def __init__(self, args, trainer):
        """Read the dynamic-mix settings from *args* and reset internal state.

        Raises:
            ValueError: If ``args.dynamic_mix_temperature`` is zero, because the
                losses are divided by it.
        """
        super().__init__(args, trainer)
        self.update_steps = args.dynamic_mix_update_steps
        self.temperature = args.dynamic_mix_temperature
        self.warmup_steps = args.dynamic_mix_warmup_steps
        if self.temperature == 0:
            raise ValueError('dynamic_mix_temperature must be non-zero.')
        self._sampler = None
        self._domain_names = None
        self._loss_buffer = collections.defaultdict(list)
        self._last_update_step = 0

    def on_train_begin(self, args, state, control, **kwargs):
        """Locate the ``DynamicMixBatchSampler`` on the trainer's dataloader.

        When no such sampler is found a warning is logged and the callback
        stays inactive (``self._sampler`` remains ``None``).
        """
        # Reuse the already-created training dataloader instead of re-calling the getter
        dataloader = getattr(self.trainer, 'train_dataloader', None)
        if dataloader is None:
            logger.warning('DynamicMixingCallback: train_dataloader not found, dynamic mixing disabled.')
            return
        from swift.dataloader import DynamicMixBatchSampler
        sampler = getattr(dataloader, 'batch_sampler', None)
        # Unwrap SkipBatchSampler if present
        if hasattr(sampler, 'batch_sampler'):
            sampler = sampler.batch_sampler
        if not isinstance(sampler, DynamicMixBatchSampler):
            logger.warning('DynamicMixingCallback: sampler is not '
                           'DynamicMixBatchSampler, dynamic mixing disabled.')
            return
        self._sampler = sampler
        self._domain_names = sampler.domain_names
        domain_sizes = {n: len(sampler.domain_indices[n]) for n in self._domain_names}
        logger.info(f'Dynamic mixing initialized. Domains: {domain_sizes}')
        logger.info(f'Initial probabilities: {sampler.probabilities}')

    def _read_domain_loss(self, logs, name):
        """Return the finite loss logged for domain *name* as a float, or ``None``.

        ``loss_{name}`` is tried first, then any alias listed in
        ``_LOG_KEY_ALIASES``. Missing, ``None``, non-numeric and non-finite
        entries are ignored: a single NaN/inf in the buffer would otherwise turn
        every softmax output into NaN.
        """
        import math
        candidate_keys = (f'loss_{name}', ) + self._LOG_KEY_ALIASES.get(name, ())
        for key in candidate_keys:
            raw = logs.get(key)
            if raw is None:
                continue
            try:
                value = float(raw)
            except (TypeError, ValueError):
                continue
            if math.isfinite(value):
                return value
        return None

    def on_log(self, args, state, control, logs=None, **kwargs):
        """Buffer per-domain losses from *logs* and update the mix when due."""
        if self._sampler is None or logs is None:
            return
        # Capture loss_{channel} values from logs
        for name in self._domain_names:
            value = self._read_domain_loss(logs, name)
            if value is not None:
                self._loss_buffer[name].append(value)

        # Check if it's time to update weights
        warmed_up = state.global_step >= self.warmup_steps
        interval_elapsed = state.global_step - self._last_update_step >= self.update_steps
        if warmed_up and interval_elapsed:
            self._update_probabilities(state.global_step)

    def _update_probabilities(self, global_step):
        """Turn the buffered losses into new sampling probabilities.

        Domains without buffered data are assigned the mean of the domains that
        have data. If no domain has data, nothing is changed and the buffer and
        ``_last_update_step`` are left untouched.
        """
        domain_losses = {}
        for name in self._domain_names:
            values = self._loss_buffer.get(name, [])
            if values:
                domain_losses[name] = sum(values) / len(values)

        if not domain_losses:
            logger.info(f'Step {global_step}: no channel loss data yet, '
                        'skipping dynamic mix update.')
            return

        # Use global mean for domains without loss data
        mean_loss = sum(domain_losses.values()) / len(domain_losses)
        ordered_losses = [domain_losses.get(name, mean_loss) for name in self._domain_names]

        # softmax(loss / T)
        loss_tensor = torch.tensor(ordered_losses)
        probs = torch.softmax(loss_tensor / self.temperature, dim=0)
        probs_dict = {name: probs[i].item() for i, name in enumerate(self._domain_names)}

        self._sampler.set_probabilities(probs_dict)
        self._loss_buffer.clear()
        self._last_update_step = global_step

        # Log new weights to metrics (will appear in tensorboard/wandb)
        for name, prob in probs_dict.items():
            self.trainer.custom_metrics['train'][f'mix_prob_{name}'].update(torch.tensor([prob]))

        logger.info(f'Step {global_step}: updated mix probabilities: {probs_dict}')
class DynamicMixBatchSampler:
    """Batch sampler that samples indices weighted by per-domain probabilities.

    Supports runtime probability updates for dynamic data mixing.

    For every batch, ``batch_size * world_size`` domains are drawn (with
    replacement) from the current probabilities, one index is taken from each
    drawn domain, and each rank keeps the slice ``global_batch[rank::world_size]``.
    All ranks therefore have to use the same seed and the same probabilities.

    Args:
        domain_indices: Mapping from domain name to the dataset indices of that domain.
        batch_size: Number of indices per rank and batch; must be positive.
        shuffle: Whether the indices inside each domain are shuffled.
        drop_last: Only used when ``num_batches`` is ``None``: round the derived
            number of batches down instead of up.
        data_seed: Base seed of the sampling generator (``None`` is treated as 0).
        tp_size: Size by which the distributed rank and world size are divided.
        num_batches: Number of batches per iteration. When ``None`` it is derived
            from the total number of indices, the world size and ``batch_size``.

    Raises:
        ValueError: If ``batch_size`` is not positive or every domain is empty.
    """

    def __init__(
        self,
        domain_indices: Dict[str, List[int]],
        batch_size: int,
        shuffle: bool,
        drop_last: bool,
        data_seed: Optional[int],
        tp_size: int = 1,
        num_batches: Optional[int] = None,
    ):
        if batch_size <= 0:
            raise ValueError(f'batch_size must be positive, got {batch_size}.')
        self.tp_size = tp_size
        self.domain_indices = domain_indices
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.drop_last = drop_last
        self.base_seed = data_seed or 0
        self.curr_seed = self.base_seed
        # Sort domain names to ensure consistent ordering across all ranks
        self.domain_names = sorted(domain_indices.keys())
        sizes = {name: len(domain_indices[name]) for name in self.domain_names}
        total = sum(sizes.values())
        if total == 0:
            raise ValueError('DynamicMixBatchSampler requires at least one non-empty domain.')
        # An empty domain can never yield an index, so its probability is pinned to 0.
        self._empty_domains = frozenset(name for name, size in sizes.items() if size == 0)
        # Initial weights proportional to domain sizes
        self.probabilities = {name: sizes[name] / total for name in self.domain_names}
        if num_batches is None:
            # Roughly one pass over the data per rank.
            per_rank = total // self.world_size
            if drop_last:
                num_batches = per_rank // batch_size
            else:
                num_batches = (per_rank + batch_size - 1) // batch_size
        self.num_batches = num_batches

    @property
    def rank(self):
        """Rank divided by ``tp_size``; 0 when torch.distributed is not initialized."""
        return (dist.get_rank() // self.tp_size) if dist.is_initialized() else 0

    @property
    def world_size(self):
        """World size divided by ``tp_size``; 1 when torch.distributed is not initialized."""
        return (dist.get_world_size() // self.tp_size) if dist.is_initialized() else 1

    def set_probabilities(self, probs):
        """Update sampling probabilities (must be called with the same values on all ranks).

        Domains missing from *probs* keep their current weight, empty domains
        are forced to 0, and the result is normalized to sum to 1. The stored
        probabilities are replaced only after validation succeeded.

        Raises:
            ValueError: If a weight is negative or non-finite, or if all weights are zero.
        """
        merged = []
        for name in self.domain_names:
            if name in self._empty_domains:
                merged.append(0.0)
            elif name in probs:
                merged.append(float(probs[name]))
            else:
                merged.append(self.probabilities[name])
        weights = torch.tensor(merged, dtype=torch.float64)
        if not bool(torch.isfinite(weights).all()) or bool((weights < 0).any()):
            raise ValueError(f'Sampling probabilities must be finite and non-negative, got {merged}.')
        # Normalize
        total = sum(merged)
        if total <= 0:
            raise ValueError('Sampling probabilities must not all be zero.')
        self.probabilities = {name: weight / total for name, weight in zip(self.domain_names, merged)}

    def _ordered_indices(self, name, generator):
        """Return the indices of domain *name*, permuted with *generator* when shuffling is on."""
        indices = self.domain_indices[name]
        if not self.shuffle:
            return list(indices)
        perm = torch.randperm(len(indices), generator=generator).tolist()
        return [indices[p] for p in perm]

    def __iter__(self):
        generator = torch.Generator()
        generator.manual_seed(self.curr_seed)
        # Shuffle indices within each domain
        domain_shuffled = {name: self._ordered_indices(name, generator) for name in self.domain_names}
        domain_cursors = {name: 0 for name in self.domain_names}
        rank, world_size = self.rank, self.world_size
        batch_total = self.batch_size * world_size

        for _ in range(self.num_batches):
            # Re-read probabilities each batch so runtime updates take effect
            prob_tensor = torch.tensor([self.probabilities[n] for n in self.domain_names])
            sampled_domains = torch.multinomial(
                prob_tensor, batch_total, replacement=True, generator=generator).tolist()
            global_batch = []
            for domain_idx in sampled_domains:
                domain_name = self.domain_names[domain_idx]
                cursor = domain_cursors[domain_name]
                if cursor >= len(domain_shuffled[domain_name]):
                    # Domain exhausted, reshuffle and reset
                    if self.shuffle:
                        domain_shuffled[domain_name] = self._ordered_indices(domain_name, generator)
                    cursor = 0
                global_batch.append(domain_shuffled[domain_name][cursor])
                domain_cursors[domain_name] = cursor + 1
            # Distributed sharding: each rank takes its slice
            yield global_batch[rank::world_size]

    def set_epoch(self, epoch):
        """Derive the seed of the next iteration from the base seed and *epoch*."""
        self.curr_seed = self.base_seed + epoch

    def __len__(self):
        return self.num_batches
enable_dft_loss: bool = False # https://arxiv.org/abs/2508.05629 enable_channel_loss: bool = False + + # dynamic data mixing + dynamic_mix: bool = False + dynamic_mix_update_steps: int = 100 + dynamic_mix_temperature: float = 1.0 + dynamic_mix_warmup_steps: int = 0 + safe_serialization: bool = True max_shard_size: str = '5GB' @@ -234,6 +241,10 @@ def _init_callbacks(self): fsdp_config = getattr(self, 'fsdp_config', {}) if isinstance(fsdp_config, dict) and fsdp_config.get('activation_cpu_offload', False): self.callbacks.append('activation_cpu_offload') + if self.dynamic_mix: + self.enable_channel_loss = True + if 'dynamic_mix' not in self.callbacks: + self.callbacks.append('dynamic_mix') def __post_init__(self): if hasattr(self, 'output_dir'): @@ -245,6 +256,15 @@ def __post_init__(self): if self.optimizer is None and (self.vit_lr is not None or self.aligner_lr is not None): self.optimizer = 'multimodal' self._init_callbacks() + if self.dynamic_mix: + if getattr(self, 'streaming', False): + raise ValueError('dynamic_mix does not support streaming mode.') + if getattr(self, 'interleave_prob', None) is not None: + raise ValueError('dynamic_mix and interleave_prob are mutually exclusive.') + if getattr(self, 'packing', False): + raise ValueError('dynamic_mix is not compatible with packing mode.') + if self.group_by_length: + raise ValueError('dynamic_mix is not compatible with group_by_length.') if self.gradient_accumulation_steps is None: world_size = get_dist_setting()[2] self.gradient_accumulation_steps = max(1, math.ceil(16 / self.per_device_train_batch_size / world_size)) diff --git a/swift/trainers/mixin.py b/swift/trainers/mixin.py index 8310e35c80..fbf5d937a6 100644 --- a/swift/trainers/mixin.py +++ b/swift/trainers/mixin.py @@ -37,7 +37,7 @@ from typing import Callable, Dict, List, Optional from swift.callbacks import callbacks_map -from swift.dataloader import BatchSamplerShard, DataLoaderDispatcher, DataLoaderShard +from swift.dataloader import 
def _build_domain_indices(self, dataset):
    """Group the positions of the training samples by their ``channel`` value.

    A dataset that exposes a ``dataset`` attribute is unwrapped once before
    the ``channel`` column is looked up. Samples whose channel is ``None``
    are filed under ``'default'``. Without a ``channel`` column, every
    position of *dataset* goes into the single domain ``'default'`` and a
    warning is logged.

    Returns:
        A plain dict mapping each domain name to its list of sample positions.
    """
    # Unwrap LazyLLMDataset to get underlying HfDataset
    inner = dataset.dataset if hasattr(dataset, 'dataset') else dataset
    grouped = collections.defaultdict(list)
    has_channel_column = hasattr(inner, 'features') and 'channel' in inner.features
    if not has_channel_column:
        grouped['default'] = list(range(len(dataset)))
        logger.warning('No "channel" column found in dataset. All data treated as single domain "default".')
    else:
        for position, channel in enumerate(inner['channel']):
            key = 'default' if channel is None else channel
            grouped[key].append(position)
    summary = ', '.join(f'{name}: {len(members)}' for name, members in sorted(grouped.items()))
    logger.info(f'Dynamic mix domains: {{{summary}}}')
    return dict(grouped)
import collections
import unittest

import torch


class TestDynamicMixBatchSampler(unittest.TestCase):
    """Checks for ``DynamicMixBatchSampler`` in a single-process run."""

    @staticmethod
    def _default_domains():
        """Three domains with 100, 100 and 200 indices."""
        return {
            'math': list(range(0, 100)),
            'code': list(range(100, 200)),
            'general': list(range(200, 400)),
        }

    def _make_sampler(self, domain_indices=None, batch_size=4, shuffle=True,
                      drop_last=False, data_seed=42, num_batches=10):
        """Build a sampler, falling back to the default three-domain layout."""
        from swift.dataloader.shard import DynamicMixBatchSampler
        if domain_indices is None:
            domain_indices = self._default_domains()
        settings = dict(
            domain_indices=domain_indices,
            batch_size=batch_size,
            shuffle=shuffle,
            drop_last=drop_last,
            data_seed=data_seed,
            num_batches=num_batches,
        )
        return DynamicMixBatchSampler(**settings)

    def test_initial_probabilities_proportional_to_size(self):
        probabilities = self._make_sampler().probabilities
        # math:100, code:100, general:200 => 0.25, 0.25, 0.5
        for domain, expected in (('math', 0.25), ('code', 0.25), ('general', 0.5)):
            self.assertAlmostEqual(probabilities[domain], expected)

    def test_yields_correct_number_of_batches(self):
        produced = list(self._make_sampler(num_batches=20, batch_size=8))
        self.assertEqual(len(produced), 20)
        for one_batch in produced:
            self.assertEqual(len(one_batch), 8)

    def test_len_equals_num_batches(self):
        self.assertEqual(len(self._make_sampler(num_batches=15)), 15)

    def test_all_indices_belong_to_domains(self):
        two_domains = {'a': list(range(0, 50)), 'b': list(range(50, 100))}
        sampler = self._make_sampler(domain_indices=two_domains, num_batches=30)
        valid = set(range(100))
        for one_batch in sampler:
            for index in one_batch:
                self.assertIn(index, valid)

    def test_set_probabilities_changes_distribution(self):
        sampler = self._make_sampler(num_batches=200, batch_size=8, data_seed=123)
        # Set extreme probabilities: almost all samples from 'math'
        sampler.set_probabilities({'math': 0.98, 'code': 0.01, 'general': 0.01})

        math_indices = set(range(0, 100))
        drawn = [index for one_batch in sampler for index in one_batch]
        from_math = sum(1 for index in drawn if index in math_indices)

        # With 98% probability, math should dominate
        self.assertGreater(from_math / len(drawn), 0.85)

    def test_set_probabilities_normalizes(self):
        sampler = self._make_sampler()
        sampler.set_probabilities({'math': 3.0, 'code': 1.0, 'general': 1.0})
        self.assertAlmostEqual(sum(sampler.probabilities.values()), 1.0, places=6)
        self.assertAlmostEqual(sampler.probabilities['math'], 0.6, places=6)

    def test_deterministic_with_same_seed(self):
        """Same seed should produce identical batches (important for distributed consistency)."""
        runs = [list(self._make_sampler(data_seed=99, num_batches=10, batch_size=4)) for _ in range(2)]
        self.assertEqual(runs[0], runs[1])

    def test_different_seed_different_batches(self):
        with_seed_1 = list(self._make_sampler(data_seed=1, num_batches=5))
        with_seed_2 = list(self._make_sampler(data_seed=2, num_batches=5))
        self.assertNotEqual(with_seed_1, with_seed_2)

    def test_set_epoch_changes_seed(self):
        sampler = self._make_sampler(data_seed=42, num_batches=5)
        epoch_0 = list(sampler)
        sampler.set_epoch(1)
        epoch_1 = list(sampler)
        self.assertNotEqual(epoch_0, epoch_1)

    def test_no_shuffle_deterministic(self):
        sampler = self._make_sampler(shuffle=False, num_batches=5)
        first_pass = list(sampler)
        second_pass = list(sampler)
        self.assertEqual(first_pass, second_pass)

    def test_domain_exhaustion_reshuffle(self):
        """When a small domain is exhausted, it should reshuffle and continue."""
        uneven = {'small': [0, 1], 'large': list(range(2, 102))}
        sampler = self._make_sampler(domain_indices=uneven, num_batches=50, batch_size=4)
        # Set probability heavily towards the small domain
        sampler.set_probabilities({'small': 0.9, 'large': 0.1})
        # Should not raise, even though 'small' only has 2 samples
        self.assertEqual(len(list(sampler)), 50)


class TestDynamicMixingCallback(unittest.TestCase):
    """Checks for the softmax re-weighting done by ``DynamicMixingCallback``."""

    @staticmethod
    def _build_callback(temperature, names):
        """Create a callback wired to minimal stand-ins for args, trainer and sampler."""
        from swift.callbacks.dynamic_mix import DynamicMixingCallback
        first, second = names

        class FakeArgs:

            def __init__(self):
                self.dynamic_mix_update_steps = 10
                self.dynamic_mix_temperature = temperature
                self.dynamic_mix_warmup_steps = 0

        class FakeMeanMetric:

            def __init__(self):
                self.values = []

            def update(self, v):
                self.values.append(v)

        class FakeTrainer:

            def __init__(self):
                self.custom_metrics = {'train': collections.defaultdict(FakeMeanMetric)}

        class FakeSampler:

            def __init__(self):
                self.domain_names = [first, second]
                self.domain_indices = {first: list(range(50)), second: list(range(50, 100))}
                self.probabilities = {first: 0.5, second: 0.5}

            def set_probabilities(self, probs):
                self.probabilities = probs

        callback = DynamicMixingCallback(FakeArgs(), FakeTrainer())
        callback._sampler = FakeSampler()
        callback._domain_names = [first, second]
        return callback

    def test_update_probabilities_softmax(self):
        """Verify that _update_probabilities applies softmax(loss/T) correctly."""
        callback = self._build_callback(1.0, ('code', 'math'))

        # Simulate loss values: math has higher loss
        callback._loss_buffer['code'] = [1.0, 1.0]
        callback._loss_buffer['math'] = [3.0, 3.0]

        callback._update_probabilities(global_step=10)
        updated = callback._sampler.probabilities

        # math (loss=3) should get higher probability than code (loss=1)
        self.assertGreater(updated['math'], updated['code'])

        # With T=1: softmax([1,3]) = [exp(1)/(exp(1)+exp(3)), exp(3)/(exp(1)+exp(3))]
        expected_math = torch.softmax(torch.tensor([1.0, 3.0]), dim=0)[1].item()
        self.assertAlmostEqual(updated['math'], expected_math, places=5)

    def test_high_temperature_more_uniform(self):
        """Higher temperature should produce more uniform distribution."""
        # Very high T
        callback = self._build_callback(100.0, ('a', 'b'))

        callback._loss_buffer['a'] = [1.0]
        callback._loss_buffer['b'] = [10.0]

        callback._update_probabilities(global_step=10)

        # With very high T, both should be close to 0.5
        for domain in ('a', 'b'):
            self.assertAlmostEqual(callback._sampler.probabilities[domain], 0.5, places=1)


if __name__ == '__main__':
    unittest.main()