[New Feature] MOPD#9035
Conversation
There was a problem hiding this comment.
Code Review
This pull request introduces the teacher_model_group argument and refactors the GKDTrainer to use a choose_teacher_model method for selecting teacher models during loss computation. Feedback highlights that the use_mopd flag is missing from the argument definitions, which will result in an AttributeError. Additionally, the choose_teacher_model implementation is currently incomplete and will cause crashes when MOPD is enabled, and the core logic for incorporating teacher weights into the JSD loss calculation is still pending.
| remotely. When this is set, `teacher_model` is not required. Defaults to None. | ||
| """ | ||
| teacher_model: Optional[str] = None | ||
| teacher_model_group: List[str] = field(default_factory=list) |
There was a problem hiding this comment.
use_mopd 标志在 GKDTrainer 中被引用,但未在参数定义中声明。应在此处添加以避免 AttributeError。此外,建议更新 TeacherModelArguments 的 docstring 以包含 teacher_model_group 和 use_mopd 的说明。
| teacher_model_group: List[str] = field(default_factory=list) | |
| teacher_model_group: List[str] = field(default_factory=list) | |
| use_mopd: bool = False |
| def choose_teacher_model(self): | ||
| if not self.args.use_mopd: | ||
| return self.teacher_model | ||
| #todo 使用mopd时从教师模型组选择最佳模型 |
| t_log_probs = F.log_softmax(t_chunk, dim=-1) | ||
| del s_chunk, t_chunk | ||
|
|
||
| #todo 使用mopd的计算函数,增加教师模型权重 |
| unwrapped_info = f"Failed to unwrap model: {e}" | ||
|
|
||
| raise ValueError(f"Cannot determine teacher model path for tokenizer initialization. " | ||
| f"Model info: {model_info}. Unwrapped info: {unwrapped_info}") |
|
很乐意帮忙测试 有合并的计划吗~ @Jintao-Huang |
|
cc @hjh0119
|
# Conflicts: # swift/rlhf_trainers/gkd_trainer.py
|
Thanks for the contribution, feel free to let me know when it's ready to merge |
| remotely. When this is set, `teacher_model` is not required. Defaults to None. | ||
| """ | ||
| teacher_model: Optional[str] = None | ||
| teacher_model_group: List[str] = field(default_factory=list) |
There was a problem hiding this comment.
Consider changing teacher_model to Optional[List[str]] (similar to reward_model) to avoid introducing additional parameters
There was a problem hiding this comment.
No need for an extra use_mopd parameter, MOPD can be determined by the number of teacher models
| self.teacher_model_group_models = None | ||
| if args.rlhf_type == 'gkd' and hasattr(args, 'teacher_model_group') and args.teacher_model_group: | ||
| logger.info(f'Loading teacher_model_group with {len(args.teacher_model_group)} models') | ||
| self.teacher_model_group_models = [] | ||
| for idx, teacher_model_path in enumerate(args.teacher_model_group): | ||
| logger.info(f'Loading teacher model group [{idx}]: {teacher_model_path}') | ||
| # Use teacher_model_type and teacher_model_revision if available, otherwise infer | ||
| model_type = getattr(args, 'teacher_model_type', None) | ||
| model_revision = getattr(args, 'teacher_model_revision', None) | ||
|
|
||
| result = self._prepare_single_model_for_teacher_group(teacher_model_path, model_type, model_revision) | ||
| if result is not None: | ||
| model, _ = result | ||
| self.teacher_model_group_models.append(model) | ||
| logger.info(f'Successfully loaded teacher model group [{idx}]: {model}') | ||
| logger.info(f'Total teacher_model_group_models loaded: {len(self.teacher_model_group_models)}') |
There was a problem hiding this comment.
Similarly, just like the handling of reward_model, without introducing additional logic
| def _prepare_single_model_for_teacher_group(self, model_id_or_path, model_type, model_revision): | ||
| """Prepare a single model for teacher_model_group.""" | ||
| args = self.args | ||
|
|
||
| if model_type is None: | ||
| model_info, _ = get_model_info_meta(model_id_or_path) | ||
| model_type = model_info.model_type | ||
|
|
||
| model_dir = safe_snapshot_download( | ||
| model_id_or_path=model_id_or_path, | ||
| revision=model_revision, | ||
| download_model=False, | ||
| use_hf=args.use_hf, | ||
| hub_token=args.hub_token, | ||
| ) | ||
| task_type, num_labels = self._get_model_task_type(model_dir) | ||
|
|
||
| context = nullcontext() | ||
| if args.teacher_deepspeed: | ||
| if args.teacher_deepspeed.get('zero_optimization', {}).get('stage') != 3: | ||
| context = disable_deepspeed_zero3() | ||
| with context: | ||
| model, processor = args.get_model_processor( | ||
| model=model_id_or_path, | ||
| model_type=model_type, | ||
| revision=model_revision, | ||
| task_type=task_type, | ||
| num_labels=num_labels) | ||
|
|
||
| # For teacher models, set to eval mode and disable gradients | ||
| if self.args.sequence_parallel_size > 1: | ||
| sequence_parallel.prepare( | ||
| self.args.sequence_parallel_size, model, processor, padding_free=args.padding_free) | ||
| model.requires_grad_(False).eval() | ||
|
|
||
| HfConfigFactory.set_config_attr(model.config, 'use_cache', False) | ||
| return model, processor |
| # Pass pre-loaded teacher_model_group_models if available, otherwise pass the string list | ||
| if hasattr(self, 'teacher_model_group_models') and self.teacher_model_group_models: | ||
| trainer_kwargs['teacher_model_group_models'] = self.teacher_model_group_models | ||
| else: | ||
| trainer_kwargs['teacher_model_group'] = self.args.teacher_model_group | ||
| trainer_kwargs['teacher_use_disable_adapter'] = getattr(self.args, '_teacher_use_disable_adapter', False) | ||
| if self.args.use_mopd: | ||
| trainer_kwargs['use_mopd'] = self.args.use_mopd | ||
| # todo | ||
| # trainer_kwargs['mopd_config'] = self.args.mopd_config |
| self.gold_adapter_group[adapter_key] = GOLDLossAdapter( | ||
| config={ | ||
| 'use_uld_loss': True, | ||
| 'use_extended_uld': True, | ||
| 'uld_use_hybrid_loss': True, | ||
| 'uld_crossentropy_weight': 0.0, | ||
| 'uld_distillation_weight': 1.0, | ||
| 'uld_student_temperature': 1.0, | ||
| 'uld_teacher_temperature': 1.0, | ||
| }, | ||
| student_tokenizer=self.student_tokenizer, | ||
| teacher_tokenizer=teacher_tokenizer, | ||
| ) |
There was a problem hiding this comment.
For now, without considering the cross-tokenizer case, don't introduce GOLDLoss as it adds unnecessary complexity
| elif self.use_mopd: | ||
| num_teacher_models = len(self.teacher_model_group) | ||
| if num_teacher_models == 0: | ||
| raise ValueError("teacher_model_group cannot be empty") | ||
| loss = torch.tensor(0.0, device=model.device) | ||
| prompt_texts = inputs['prompt_text'] | ||
| completion_texts = inputs['completion_texts'] | ||
| ( | ||
| student_input_ids, | ||
| student_labels, | ||
| student_attention_mask, | ||
| student_prompt_length, | ||
| ) = self.build_inputs_from_texts( | ||
| self.student_tokenizer, | ||
| prompt_texts, | ||
| completion_texts | ||
| ) | ||
| # Student model forward pass (WITH gradients for student parameters) | ||
| outputs_student = model( | ||
| input_ids=student_input_ids, | ||
| attention_mask=student_attention_mask, | ||
| ) | ||
| for teacher_model in teacher_model_group: | ||
| print('-------self.use_generalized_jsd_loss') | ||
| teacher_tokenizer = self.teacher_tokenizer_group[id(teacher_model)] | ||
| # Add teacher model memory management like in liger branch | ||
| load_context = self.load_teacher_model_context() if self.args.offload_teacher_model else nullcontext() | ||
|
|
||
| # Get adapter for current teacher tokenizer | ||
| adapter_key = id(teacher_tokenizer) | ||
| gold_adapter = self.gold_adapter_group[adapter_key] | ||
| with load_context: | ||
| with torch.no_grad(), disable_gradient_checkpointing(teacher_model, | ||
| self.args.gradient_checkpointing_kwargs): | ||
| ( | ||
| teacher_input_ids, | ||
| teacher_labels, | ||
| teacher_attention_mask, | ||
| teacher_prompt_length, | ||
| ) = self.build_inputs_from_texts( | ||
| teacher_tokenizer, | ||
| prompt_texts, | ||
| completion_texts | ||
| ) | ||
|
|
||
| # Teacher model forward pass (NO gradients) | ||
| outputs_teacher = teacher_model( | ||
| input_ids=teacher_input_ids, | ||
| attention_mask=teacher_attention_mask, | ||
| ) | ||
| # Ensure teacher_logits has gradient info but teacher model params don't participate | ||
| teacher_logits = outputs_teacher.logits.detach().requires_grad_(True) | ||
|
|
||
| # Release intermediate tensors to free memory | ||
| del teacher_attention_mask | ||
| loss_total = gold_adapter( | ||
| student_logits=outputs_student.logits, | ||
| teacher_logits=teacher_logits, | ||
| student_labels=student_labels, | ||
| teacher_labels=teacher_labels, | ||
| student_input_ids=student_input_ids, | ||
| teacher_input_ids=teacher_input_ids, | ||
| ) | ||
| loss += loss_total / len(teacher_model_group) | ||
| # Separate teacher model provided |
There was a problem hiding this comment.
Same as above, please update the multi-teacher logic based on the existing single-teacher logic
| elif self.use_mopd: | ||
| num_teacher_models = len(self.teacher_model_group) | ||
| if num_teacher_models == 0: | ||
| raise ValueError("teacher_model_group cannot be empty") | ||
| loss = torch.tensor(0.0, device=model.device) | ||
| prompt_texts = inputs['prompt_text'] | ||
| completion_texts = inputs['completion_texts'] | ||
| ( | ||
| student_input_ids, | ||
| student_labels, | ||
| student_attention_mask, | ||
| student_prompt_length, | ||
| ) = self.build_inputs_from_texts( | ||
| self.student_tokenizer, | ||
| prompt_texts, | ||
| completion_texts | ||
| ) | ||
| # Student model forward pass (WITH gradients for student parameters) | ||
| outputs_student = model( | ||
| input_ids=student_input_ids, | ||
| attention_mask=student_attention_mask, | ||
| ) | ||
| for teacher_model in teacher_model_group: | ||
| print('-------self.use_generalized_jsd_loss') | ||
| teacher_tokenizer = self.teacher_tokenizer_group[id(teacher_model)] | ||
| # Add teacher model memory management like in liger branch | ||
| load_context = self.load_teacher_model_context() if self.args.offload_teacher_model else nullcontext() | ||
|
|
||
| # Get adapter for current teacher tokenizer | ||
| adapter_key = id(teacher_tokenizer) | ||
| gold_adapter = self.gold_adapter_group[adapter_key] | ||
| with load_context: | ||
| with torch.no_grad(), disable_gradient_checkpointing(teacher_model, | ||
| self.args.gradient_checkpointing_kwargs): | ||
| ( | ||
| teacher_input_ids, | ||
| teacher_labels, | ||
| teacher_attention_mask, | ||
| teacher_prompt_length, | ||
| ) = self.build_inputs_from_texts( | ||
| teacher_tokenizer, | ||
| prompt_texts, | ||
| completion_texts | ||
| ) | ||
|
|
||
| # Teacher model forward pass (NO gradients) | ||
| outputs_teacher = teacher_model( | ||
| input_ids=teacher_input_ids, | ||
| attention_mask=teacher_attention_mask, | ||
| ) | ||
| # Ensure teacher_logits has gradient info but teacher model params don't participate | ||
| teacher_logits = outputs_teacher.logits.detach().requires_grad_(True) | ||
|
|
||
| # Release intermediate tensors to free memory | ||
| del teacher_attention_mask | ||
| loss_total = gold_adapter( | ||
| student_logits=outputs_student.logits, | ||
| teacher_logits=teacher_logits, | ||
| student_labels=student_labels, | ||
| teacher_labels=teacher_labels, | ||
| student_input_ids=student_input_ids, | ||
| teacher_input_ids=teacher_input_ids, | ||
| ) | ||
| loss += loss_total / len(teacher_model_group) | ||
| # Separate teacher model provided |
| def build_inputs_from_texts( | ||
| self, | ||
| tokenizer: PreTrainedTokenizerBase, | ||
| prompt_texts: list[str], | ||
| completion_texts: list[str], | ||
| ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, int]: | ||
| """Tokenize teacher prompts/completions and produce tensors ready for GOLD loss.""" | ||
|
|
||
| pad_token_id = tokenizer.pad_token_id | ||
| eos_token_id = tokenizer.eos_token_id | ||
|
|
||
| prompt_token_ids = tokenizer(prompt_texts, add_special_tokens=True)["input_ids"] | ||
| completion_token_ids = tokenizer(completion_texts, add_special_tokens=False)["input_ids"] | ||
|
|
||
| sequences: list[torch.Tensor] = [] | ||
| attention_masks: list[torch.Tensor] = [] | ||
| labels_list: list[torch.Tensor] = [] | ||
| prompt_lengths: list[int] = [] | ||
| # Get device using reliable detection method | ||
| device = None | ||
| try: | ||
| # First try to get device from model parameters | ||
| if hasattr(self, 'model') and self.model is not None: | ||
| device = next(self.model.parameters()).device | ||
| elif hasattr(self, 'teacher_model') and self.teacher_model is not None: | ||
| device = next(self.teacher_model.parameters()).device | ||
| except (AttributeError, StopIteration): | ||
| pass | ||
|
|
||
| # Fallback to default device detection | ||
| if device is None: | ||
| device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') | ||
| for prompt_ids, completion_ids in zip(prompt_token_ids, completion_token_ids, strict=True): | ||
| # Remove trailing EOS from prompt so completions can extend cleanly | ||
| if eos_token_id is not None and prompt_ids and prompt_ids[-1] == eos_token_id: | ||
| prompt_ids = prompt_ids[:-1] | ||
|
|
||
| prompt_lengths.append(len(prompt_ids)) | ||
| sequence = list(prompt_ids) | ||
| sequence.extend(completion_ids) | ||
| if eos_token_id is not None: | ||
| sequence.append(eos_token_id) | ||
|
|
||
| seq_tensor = torch.tensor(sequence, dtype=torch.long, device=device) | ||
| sequences.append(seq_tensor) | ||
| attention_masks.append(torch.ones_like(seq_tensor)) | ||
| labels = seq_tensor.clone() | ||
| labels[: len(prompt_ids)] = -100 | ||
| if pad_token_id is not None: | ||
| labels[labels == pad_token_id] = -100 | ||
| labels_list.append(labels) | ||
|
|
||
| teacher_input_ids = pad( | ||
| sequences, | ||
| padding_side="right", | ||
| padding_value=pad_token_id if pad_token_id is not None else 0, | ||
| ) | ||
| teacher_attention_mask = pad(attention_masks, padding_side="right", padding_value=0).bool() | ||
| teacher_labels = pad(labels_list, padding_side="right", padding_value=-100) | ||
|
|
||
| if eos_token_id is not None: | ||
| for row in range(teacher_attention_mask.size(0)): | ||
| valid = ( | ||
| teacher_input_ids[row] != pad_token_id | ||
| if pad_token_id is not None | ||
| else teacher_attention_mask[row].bool() | ||
| ) | ||
| if valid.any(): | ||
| last_idx = valid.nonzero(as_tuple=True)[0][-1] | ||
| teacher_attention_mask[row, last_idx + 1:] = False | ||
|
|
||
| teacher_prompt_length = max(prompt_lengths) if prompt_lengths else 0 | ||
|
|
||
| return teacher_input_ids, teacher_labels, teacher_attention_mask, teacher_prompt_length | ||
|
|
||
| def get_model_path(self, model): | ||
| model_path = getattr(model, 'name_or_path', None) | ||
| if model_path is None: | ||
| # Try to get path from config | ||
| if hasattr(model, 'config') and hasattr(model.config, '_name_or_path'): | ||
| model_path = model.config._name_or_path | ||
| if model_path is None: | ||
| # If still None, try to get from model's base model | ||
| unwrapped_student = self.accelerator.unwrap_model(model) | ||
| if hasattr(unwrapped_student, 'base_model_prefix'): | ||
| base_model = getattr(unwrapped_student, unwrapped_student.base_model_prefix, unwrapped_student) | ||
| if hasattr(base_model, 'config') and hasattr(base_model.config, '_name_or_path'): | ||
| model_path = base_model.config._name_or_path | ||
| # Additional fallback: try to get from model's config name_or_path attribute | ||
| if model_path is None: | ||
| if hasattr(model, 'config') and hasattr(model.config, 'name_or_path'): | ||
| model_path = model.config.name_or_path | ||
| # Additional fallback: try to get from unwrapped model's name_or_path | ||
| if model_path is None: | ||
| unwrapped_student = self.accelerator.unwrap_model(model) | ||
| model_path = getattr(unwrapped_student, 'name_or_path', None) | ||
| # Additional fallback: try to get from unwrapped model's config name_or_path | ||
| if model_path is None: | ||
| unwrapped_student = self.accelerator.unwrap_model(model) | ||
| if hasattr(unwrapped_student, 'config') and hasattr(unwrapped_student.config, 'name_or_path'): | ||
| model_path = unwrapped_student.config.name_or_path | ||
| return model_path |
There was a problem hiding this comment.
remove, reuse origin logic
| self, | ||
| student_logits, | ||
| teacher_logits=None, | ||
| labels=None, | ||
| beta=0.5, | ||
| temperature=1.0, | ||
| chunk_size=512, | ||
| topk=None, | ||
| teacher_topk_logprobs=None, | ||
| teacher_topk_indices=None, |
| from typing import List, Optional, Tuple | ||
|
|
||
| import torch | ||
| import torch.nn as nn | ||
| import torch.nn.functional as F | ||
| from transformers import PreTrainedTokenizerBase | ||
|
|
||
|
|
||
| class GOLDLossAdapter(nn.Module): | ||
| """ | ||
| - GOLD (General Online Logit Distillation) 损失函数适配器 | ||
| 支持: | ||
| 1. ULD损失 (Universal Logit Distillation) | ||
| 2. 扩展ULD (跨tokenizer对齐) | ||
| 3. 混合损失 (Hybrid ULD + JSD) | ||
|
|
||
| 使用示例: | ||
| adapter = GOLDLossAdapter( | ||
| config={ | ||
| 'use_uld_loss': True, | ||
| 'use_extended_uld': True, | ||
| 'uld_use_hybrid_loss': False, | ||
| 'uld_crossentropy_weight': 0.0, | ||
| 'uld_distillation_weight': 1.0, | ||
| 'uld_student_temperature': 1.0, | ||
| 'uld_teacher_temperature': 1.0, | ||
| }, | ||
| student_tokenizer=student_tok, | ||
| teacher_tokenizer=teacher_tok, | ||
| ) | ||
|
|
||
| loss = adapter( | ||
| student_logits=student_outputs.logits, | ||
| teacher_logits=teacher_outputs.logits, | ||
| student_labels=student_labels, | ||
| teacher_labels=teacher_labels, | ||
| student_input_ids=student_input_ids, | ||
| teacher_input_ids=teacher_input_ids, | ||
| ) | ||
| """ | ||
|
|
||
| def __init__( | ||
| self, | ||
| config: dict, | ||
| student_tokenizer: Optional[PreTrainedTokenizerBase] = None, | ||
| teacher_tokenizer: Optional[PreTrainedTokenizerBase] = None, | ||
| device: Optional[torch.device] = None, | ||
| ): | ||
| super().__init__() | ||
| self.device = device | ||
|
|
||
| # 基础配置 | ||
| self.use_uld_loss = config.get('use_uld_loss', True) # 是否开启通用蒸馏 | ||
| self.crossentropy_weight = config.get('uld_crossentropy_weight', 0.0) | ||
| self.distillation_weight = config.get('uld_distillation_weight', 1.0) | ||
| self.student_temperature = config.get('uld_student_temperature', 0.9) | ||
| self.teacher_temperature = config.get('uld_teacher_temperature', 0.9) | ||
| self.skip_student_eos = config.get('uld_skip_student_eos', True) | ||
| self.skip_teacher_eos = config.get('uld_skip_teacher_eos', True) | ||
| self.use_extended_uld = config.get('use_extended_uld', True) | ||
| self.ignore_index = -100 | ||
|
|
||
| # Tokenizers | ||
| self.student_tokenizer = student_tokenizer | ||
| self.teacher_tokenizer = teacher_tokenizer | ||
|
|
||
| # Hybrid ULD配置 | ||
| self.use_hybrid_loss = config.get('uld_use_hybrid_loss', True) # 是否对完全匹配的词汇进行匹配,开启提高稳定性 | ||
| self.hybrid_matched_weight = config.get('uld_hybrid_matched_weight', None) | ||
| self.hybrid_unmatched_weight = config.get('uld_hybrid_unmatched_weight', None) | ||
| self.beta = config.get('beta', 1.0) | ||
|
|
||
| # 初始化词汇映射(用于hybrid loss) | ||
| self._vocab_mapping = None | ||
| self._teacher_matched_ids = None | ||
| self._student_matched_ids = None | ||
| self.mapping_tensor = None | ||
|
|
||
| if self.use_hybrid_loss and student_tokenizer and teacher_tokenizer: | ||
| self._initialize_vocabulary_mapping() | ||
|
|
||
| # 用于logging | ||
| self.last_matched_loss = None | ||
| self.last_unmatched_loss = None | ||
|
|
||
| def _initialize_vocabulary_mapping(self): | ||
| """初始化学生-教师tokenizer的词汇映射""" | ||
| student_vocab = self.student_tokenizer.get_vocab() | ||
| teacher_vocab = self.teacher_tokenizer.get_vocab() | ||
|
|
||
| student_token_to_id = dict(student_vocab.items()) | ||
|
|
||
| vocab_mapping = {} | ||
| teacher_matched_ids = set() | ||
| student_matched_ids = set() | ||
|
|
||
| for token_str, teacher_id in teacher_vocab.items(): | ||
| if token_str in student_token_to_id: | ||
| student_id = student_token_to_id[token_str] | ||
| vocab_mapping[teacher_id] = student_id | ||
| teacher_matched_ids.add(teacher_id) | ||
| student_matched_ids.add(student_id) | ||
|
|
||
| self._vocab_mapping = vocab_mapping | ||
| self._teacher_matched_ids = teacher_matched_ids | ||
| self._student_matched_ids = student_matched_ids | ||
|
|
||
| if self._vocab_mapping: | ||
| max_matched_teacher_id = max(self._vocab_mapping.keys()) | ||
| self.mapping_tensor = torch.full( | ||
| (max_matched_teacher_id + 1,), -1, dtype=torch.long | ||
| ) | ||
| for k, v in self._vocab_mapping.items(): | ||
| self.mapping_tensor[k] = v | ||
| if self.device: | ||
| self.mapping_tensor = self.mapping_tensor.to(self.device) | ||
|
|
||
| def forward( | ||
| self, | ||
| student_logits: torch.Tensor, | ||
| teacher_logits: torch.Tensor, | ||
| student_labels: torch.Tensor, | ||
| teacher_labels: torch.Tensor, | ||
| student_input_ids: torch.Tensor, | ||
| teacher_input_ids: torch.Tensor, | ||
| ) -> torch.Tensor: | ||
| """ | ||
| 计算GOLD/ULD损失 | ||
|
|
||
| Args: | ||
| student_logits: [batch_size, seq_len, student_vocab_size] | ||
| teacher_logits: [batch_size, seq_len, teacher_vocab_size] | ||
| student_labels: [batch_size, seq_len], -100表示忽略 | ||
| teacher_labels: [batch_size, seq_len], -100表示忽略 | ||
| student_input_ids: [batch_size, seq_len] | ||
| teacher_input_ids: [batch_size, seq_len] | ||
|
|
||
| Returns: | ||
| loss: scalar tensor | ||
| """ | ||
|
|
||
| if not self.use_uld_loss: | ||
| return torch.tensor(0.0, device=student_logits.device, requires_grad=True) | ||
|
|
||
| # 1. Cross-entropy loss (可选,通过crossentropy_weight设置权重) | ||
| crossentropy_loss = self._compute_cross_entropy(student_logits, student_labels) | ||
|
|
||
| # 2. Distillation loss (ULD) | ||
| distillation_loss = self._compute_distillation_loss( | ||
| student_logits, teacher_logits, | ||
| student_labels, teacher_labels, | ||
| student_input_ids, teacher_input_ids | ||
| ) | ||
| return crossentropy_loss + distillation_loss | ||
|
|
||
| def _compute_cross_entropy( | ||
| self, | ||
| student_logits: torch.Tensor, | ||
| student_labels: torch.Tensor | ||
| ) -> torch.Tensor: | ||
| """计算cross-entropy loss""" | ||
| if self.crossentropy_weight <= 0: | ||
| return torch.tensor(0.0, device=student_logits.device, requires_grad=True) | ||
|
|
||
| shift_logits = student_logits[..., :-1, :].contiguous() | ||
| shift_labels = student_labels[..., 1:].contiguous() | ||
|
|
||
| loss_fct = nn.CrossEntropyLoss(ignore_index=self.ignore_index) | ||
| ce_loss = loss_fct( | ||
| shift_logits.view(-1, shift_logits.size(-1)), | ||
| shift_labels.view(-1) | ||
| ) | ||
| return self.crossentropy_weight * ce_loss | ||
|
|
||
| def _compute_distillation_loss( | ||
| self, | ||
| student_logits: torch.Tensor, | ||
| teacher_logits: torch.Tensor, | ||
| student_labels: torch.Tensor, | ||
| teacher_labels: torch.Tensor, | ||
| student_input_ids: torch.Tensor, | ||
| teacher_input_ids: torch.Tensor, | ||
| ) -> torch.Tensor: | ||
| """计算ULD蒸馏损失""" | ||
| # 获取答案区域 | ||
| student_answer_idx, student_answer_size = self._get_answer_regions(student_labels) | ||
| teacher_answer_idx, teacher_answer_size = self._get_answer_regions(teacher_labels) | ||
|
|
||
| if self.skip_student_eos: | ||
| student_answer_size = [s - 1 for s in student_answer_size] | ||
| if self.skip_teacher_eos: | ||
| teacher_answer_size = [t - 1 for t in teacher_answer_size] | ||
|
|
||
| # 边界检查 | ||
| if not student_answer_size or not teacher_answer_size: | ||
| return torch.zeros(1, device=student_logits.device, requires_grad=True) * 1e-8 | ||
|
|
||
| batch_size = student_logits.size(0) | ||
| distillation_losses = [] | ||
|
|
||
| for i in range(batch_size): | ||
| s_start = student_answer_idx[i] | ||
| s_size = student_answer_size[i] | ||
| t_start = teacher_answer_idx[i] | ||
| t_size = teacher_answer_size[i] | ||
|
|
||
| if s_size <= 0 or t_size <= 0: | ||
| loss_i = student_logits[i].sum() * 0.0 | ||
| # Ensure the loss tensor requires gradients | ||
| loss_i = loss_i.detach().requires_grad_(True) | ||
| distillation_losses.append(loss_i) | ||
| continue | ||
|
|
||
| # 提取答案logits | ||
| student_ans_logits = student_logits[i, s_start:s_start + s_size] | ||
| teacher_ans_logits = teacher_logits[i, t_start:t_start + t_size] | ||
|
|
||
| # 转换为概率 | ||
| student_probs = F.softmax(student_ans_logits / self.student_temperature, dim=-1) | ||
| teacher_probs = F.softmax(teacher_ans_logits / self.teacher_temperature, dim=-1) | ||
|
|
||
| student_token_ids = student_input_ids[i, s_start:s_start + s_size].tolist() | ||
| teacher_token_ids = teacher_input_ids[i, t_start:t_start + t_size].tolist() | ||
|
|
||
| # Token对齐 | ||
| if self.use_extended_uld: | ||
| student_groups, teacher_groups = self._build_alignment_groups_from_ids( | ||
| student_token_ids, teacher_token_ids | ||
| ) | ||
|
|
||
| student_aligned = self._merge_probabilities_with_groups( | ||
| student_probs, student_groups, student_token_ids | ||
| ) | ||
| teacher_aligned = self._merge_probabilities_with_groups( | ||
| teacher_probs, teacher_groups, teacher_token_ids | ||
| ) | ||
|
|
||
| else: | ||
| min_len = min(len(student_token_ids), len(teacher_token_ids)) | ||
| student_aligned = student_probs[:min_len] | ||
| teacher_aligned = teacher_probs[:min_len] | ||
|
|
||
| # 计算损失 | ||
| if self.use_hybrid_loss and self._vocab_mapping: | ||
| aligned_loss = self._compute_hybrid_uld_loss(student_aligned, teacher_aligned) | ||
| else: | ||
| aligned_loss = self._compute_basic_uld_loss(student_aligned, teacher_aligned) | ||
|
|
||
| distillation_losses.append(aligned_loss) | ||
| distillation_loss = torch.stack(distillation_losses).mean() | ||
| return self.distillation_weight * distillation_loss | ||
|
|
||
| def _get_answer_regions(self, labels: torch.Tensor) -> Tuple[List[int], List[int]]: | ||
| """获取答案区域的起始位置和大小""" | ||
| indices = [] | ||
| sizes = [] | ||
|
|
||
| for label in labels: | ||
| mask = label.ne(self.ignore_index) | ||
| if not mask.any(): | ||
| indices.append(0) | ||
| sizes.append(0) | ||
| continue | ||
|
|
||
| valid_indices = mask.nonzero(as_tuple=True)[0] | ||
| indices.append(int(valid_indices[0].item())) | ||
| sizes.append(int(mask.sum().item())) | ||
|
|
||
| return indices, sizes | ||
|
|
||
| def _build_alignment_groups_from_ids( | ||
| self, | ||
| student_token_ids: List[int], | ||
| teacher_token_ids: List[int] | ||
| ) -> Tuple[List[List[int]], List[List[int]]]: | ||
| """ | ||
| 基于文本内容构建对齐组 | ||
| 使用贪心子串匹配算法 | ||
| """ | ||
|
|
||
| def decode_tokens(tokenizer, token_ids): | ||
| pieces = [] | ||
| prev = "" | ||
| for k in range(len(token_ids)): | ||
| cur = tokenizer.decode(token_ids[:k + 1], skip_special_tokens=False) | ||
| pieces.append(cur[len(prev):]) | ||
| prev = cur | ||
| return pieces | ||
|
|
||
| student_pieces = decode_tokens(self.student_tokenizer, student_token_ids) | ||
| teacher_pieces = decode_tokens(self.teacher_tokenizer, teacher_token_ids) | ||
|
|
||
| # 贪心匹配算法 | ||
| student_groups = [] | ||
| teacher_groups = [] | ||
| s_idx = 0 | ||
| t_idx = 0 | ||
|
|
||
| while s_idx < len(student_pieces) and t_idx < len(teacher_pieces): | ||
| student_text = "" | ||
| teacher_text = "" | ||
| student_group = [] | ||
| teacher_group = [] | ||
|
|
||
| # 尝试找到最短的连续匹配序列 | ||
| while s_idx < len(student_pieces) and t_idx < len(teacher_pieces): | ||
| if not student_group: | ||
| student_group.append(s_idx) | ||
| student_text += student_pieces[s_idx] | ||
| s_idx += 1 | ||
|
|
||
| if not teacher_group: | ||
| teacher_group.append(t_idx) | ||
| teacher_text += teacher_pieces[t_idx] | ||
| t_idx += 1 | ||
|
|
||
| # 检查是否匹配 | ||
| if student_text == teacher_text: | ||
| student_groups.append(student_group) | ||
| teacher_groups.append(teacher_group) | ||
| break | ||
| elif len(student_text) < len(teacher_text): | ||
| if s_idx < len(student_pieces): | ||
| student_group.append(s_idx) | ||
| student_text += student_pieces[s_idx] | ||
| s_idx += 1 | ||
| else: | ||
| break | ||
| else: | ||
| if t_idx < len(teacher_pieces): | ||
| teacher_group.append(t_idx) | ||
| teacher_text += teacher_pieces[t_idx] | ||
| t_idx += 1 | ||
| else: | ||
| break | ||
| else: | ||
| # 未完全匹配,添加剩余部分 | ||
| if student_group and teacher_group: | ||
| student_groups.append(student_group) | ||
| teacher_groups.append(teacher_group) | ||
|
|
||
| return student_groups, teacher_groups | ||
|
|
||
| def _merge_probabilities_with_groups( | ||
| self, | ||
| probs: torch.Tensor, | ||
| alignment_groups: List[List[int]], | ||
| token_ids: List[int], | ||
| ) -> torch.Tensor: | ||
| """ | ||
| 根据对齐组合并概率分布 | ||
| 使用链式法则: P_merged = P(y|x_0) * P(x_1|x_0) * P(x_2|x_0,x_1) * ... | ||
| """ | ||
| aligned_probs = [] | ||
|
|
||
| for group in alignment_groups: | ||
| if len(group) > 1: | ||
| # 第一个token的边际概率 | ||
| marginal_probs = probs[group[0]] # [vocab_size] | ||
|
|
||
| # 后续token的条件概率(标量) | ||
| conditional_product = 1.0 | ||
| for k in range(1, len(group)): | ||
| cond_prob = probs[group[k], token_ids[group[k - 1]]] | ||
| conditional_product *= cond_prob | ||
|
|
||
| merged_probs = marginal_probs * conditional_product | ||
| aligned_probs.append(merged_probs) | ||
| elif len(group) == 1: | ||
| aligned_probs.append(probs[group[0]]) | ||
|
|
||
| if aligned_probs: | ||
| return torch.stack(aligned_probs) | ||
| else: | ||
| # 返回一个空的但需要梯度的张量 | ||
| empty_tensor = probs[:0].detach().requires_grad_(True) | ||
| return empty_tensor | ||
|
|
||
| def _compute_basic_uld_loss( | ||
| self, | ||
| student_aligned: torch.Tensor, | ||
| teacher_aligned: torch.Tensor, | ||
| ) -> torch.Tensor: | ||
| """基础ULD损失:排序后的L1距离""" | ||
| student_sorted = student_aligned.sort(dim=-1, descending=True).values | ||
| teacher_sorted = teacher_aligned.sort(dim=-1, descending=True).values | ||
|
|
||
| # Padding到相同vocab size | ||
| s_vocab = student_sorted.size(-1) | ||
| t_vocab = teacher_sorted.size(-1) | ||
| max_vocab = max(s_vocab, t_vocab) | ||
|
|
||
| if s_vocab < max_vocab: | ||
| student_sorted = F.pad(student_sorted, (0, max_vocab - s_vocab)) | ||
| if t_vocab < max_vocab: | ||
| teacher_sorted = F.pad(teacher_sorted, (0, max_vocab - t_vocab)) | ||
|
|
||
| loss = F.l1_loss(student_sorted, teacher_sorted, reduction="sum") | ||
| loss /= student_aligned.size(0) | ||
|
|
||
| return loss | ||
|
|
||
| def _compute_hybrid_uld_loss( | ||
| self, | ||
| student_aligned: torch.Tensor, | ||
| teacher_aligned: torch.Tensor, | ||
| ) -> torch.Tensor: | ||
| """混合ULD损失:matched用JSD,unmatched用排序L1""" | ||
| device = student_aligned.device | ||
| s_vocab = student_aligned.size(-1) | ||
| t_vocab = teacher_aligned.size(-1) | ||
|
|
||
| # 创建matched/unmatched masks | ||
| if self._teacher_matched_ids: | ||
| teacher_matched_idx = torch.tensor( | ||
| sorted(self._teacher_matched_ids), dtype=torch.long, device=device | ||
| ) | ||
| student_matched_idx = self.mapping_tensor[teacher_matched_idx] | ||
| else: | ||
| teacher_matched_idx = torch.tensor([], dtype=torch.long, device=device) | ||
| student_matched_idx = torch.tensor([], dtype=torch.long, device=device) | ||
|
|
||
| teacher_matched_mask = torch.zeros(t_vocab, dtype=torch.bool, device=device) | ||
| student_matched_mask = torch.zeros(s_vocab, dtype=torch.bool, device=device) | ||
|
|
||
| if len(teacher_matched_idx) > 0: | ||
| teacher_matched_mask[teacher_matched_idx] = True | ||
| student_matched_mask[student_matched_idx] = True | ||
|
|
||
| # 1. Matched tokens的JSD损失 | ||
| matched_loss = torch.tensor(0.0, device=device, requires_grad=True) | ||
| matched_count = 0 | ||
|
|
||
| if len(teacher_matched_idx) > 0: | ||
| teacher_matched_probs = teacher_aligned[:, teacher_matched_idx] | ||
| student_matched_probs = student_aligned[:, student_matched_idx] | ||
| matched_count = teacher_matched_probs.size(-1) | ||
|
|
||
| matched_loss = self._compute_jsd_for_matched( | ||
| student_matched_probs, teacher_matched_probs | ||
| ) | ||
| # 2. Unmatched tokens的排序L1损失 | ||
| teacher_unmatched = teacher_aligned[:, ~teacher_matched_mask] | ||
| student_unmatched = student_aligned[:, ~student_matched_mask] | ||
|
|
||
| unmatched_loss = torch.tensor(0.0, device=device, requires_grad=True) | ||
| if teacher_unmatched.size(-1) > 0 and student_unmatched.size(-1) > 0: | ||
| teacher_sorted = teacher_unmatched.sort(dim=-1, descending=True).values | ||
| student_sorted = student_unmatched.sort(dim=-1, descending=True).values | ||
|
|
||
| t_size = teacher_sorted.size(-1) | ||
| s_size = student_sorted.size(-1) | ||
| max_size = max(t_size, s_size) | ||
|
|
||
| if t_size < max_size: | ||
| teacher_sorted = F.pad(teacher_sorted, (0, max_size - t_size)) | ||
| if s_size < max_size: | ||
| student_sorted = F.pad(student_sorted, (0, max_size - s_size)) | ||
|
|
||
| unmatched_loss = F.l1_loss(student_sorted, teacher_sorted, reduction="sum") | ||
| unmatched_loss /= student_aligned.size(0) | ||
|
|
||
| # 3. 加权组合 | ||
| if self.hybrid_matched_weight is None: | ||
| w_matched = matched_count / max(1, t_vocab) | ||
| w_unmatched = 1.0 - w_matched | ||
| else: | ||
| w_matched = self.hybrid_matched_weight | ||
| w_unmatched = self.hybrid_unmatched_weight | ||
|
|
||
| total_loss = w_matched * matched_loss + w_unmatched * unmatched_loss | ||
|
|
||
| # 保存用于logging | ||
| self.last_matched_loss = matched_loss | ||
| self.last_unmatched_loss = unmatched_loss | ||
|
|
||
| return total_loss | ||
|
|
||
| def _compute_jsd_for_matched( | ||
| self, | ||
| student_probs: torch.Tensor, | ||
| teacher_probs: torch.Tensor, | ||
| epsilon: float = 1e-8 | ||
| ) -> torch.Tensor: | ||
| """计算matched tokens的JSD损失,添加数值稳定性处理""" | ||
| batch_seq_len, num_matched = student_probs.shape | ||
|
|
||
| # 检查输入概率分布是否有效 | ||
| if torch.isnan(student_probs).any() or torch.isnan(teacher_probs).any(): | ||
| return torch.tensor(0.0, device=student_probs.device, requires_grad=True) | ||
|
|
||
| # 添加epsilon防止数值下溢和log(0) | ||
| student_probs = student_probs.clamp(min=epsilon) | ||
| teacher_probs = teacher_probs.clamp(min=epsilon) | ||
|
|
||
| # 重新归一化概率分布 | ||
| student_probs = student_probs / student_probs.sum(dim=-1, keepdim=True) | ||
| teacher_probs = teacher_probs / teacher_probs.sum(dim=-1, keepdim=True) | ||
|
|
||
| student_flat = student_probs.view(-1, num_matched) | ||
| teacher_flat = teacher_probs.view(-1, num_matched) | ||
|
|
||
| # JSD = 0.5 * KL(P||M) + 0.5 * KL(Q||M), where M = 0.5*(P+Q) | ||
| m = 0.5 * (student_flat + teacher_flat) | ||
|
|
||
| # 添加epsilon到中间分布 | ||
| m = m.clamp(min=epsilon) | ||
| m = m / m.sum(dim=-1, keepdim=True) | ||
|
|
||
| # 直接对概率分布取对数,添加epsilon防止数值问题 | ||
| log_m = torch.log(m + epsilon) | ||
| log_student = torch.log(student_flat + epsilon) | ||
| log_teacher = torch.log(teacher_flat + epsilon) | ||
|
|
||
| # 使用log_target=True,传入log概率 | ||
| kl_p_m = F.kl_div(log_m, log_student, reduction='batchmean', log_target=True) | ||
| kl_q_m = F.kl_div(log_m, log_teacher, reduction='batchmean', log_target=True) | ||
| jsd = 0.5 * (kl_p_m + kl_q_m) | ||
|
|
||
| # 检查结果是否有效 | ||
| if torch.isnan(jsd) or torch.isinf(jsd): | ||
| return torch.tensor(0.0, device=student_probs.device, requires_grad=True) | ||
|
|
||
| return jsd |
There was a problem hiding this comment.
Same as above, without considering GOLDLoss, limit MOPD to the premise of using the same tokenizer
|
Please keep the PR clean. For MOPD, I think we can reuse the existing GKD pipeline and simply extend the original teacher_model related logic to support multiple teachers |
|
加油 |

PR type
PR information
https://github.com/XiaomiMiMo/MiMo-V2-Flash/blob/main/paper.pdf
小米团队推出的MiMo-V2中提出了一种新的后训练范式Multi-Teacher On-Policy Distillation (MOPD),用于解决能力不平衡的问题(例如:提升了数学,代码能力下降)。
Write the detail information belongs to this PR.
Experiment results
Paste your experiment result here(if needed).