[New Feature] MOPD #9035

Open
doctorMcy wants to merge 17 commits into modelscope:main from doctorMcy:feature_MOPD

Conversation

@doctorMcy

PR type

  • Bug Fix
  • [ √ ] New Feature
  • Document Updates
  • More Models or Datasets Support

PR information

https://github.com/XiaomiMiMo/MiMo-V2-Flash/blob/main/paper.pdf
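The MiMo-V2 work from the Xiaomi team proposes a new post-training paradigm, Multi-Teacher On-Policy Distillation (MOPD), to address capability imbalance (for example, math ability improves while coding ability declines).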
Write the detail information belongs to this PR.

Experiment results

Paste your experiment result here(if needed).

Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Code Review

This pull request introduces the teacher_model_group argument and refactors the GKDTrainer to use a choose_teacher_model method for selecting teacher models during loss computation. Feedback highlights that the use_mopd flag is missing from the argument definitions, which will result in an AttributeError. Additionally, the choose_teacher_model implementation is currently incomplete and will cause crashes when MOPD is enabled, and the core logic for incorporating teacher weights into the JSD loss calculation is still pending.

remotely. When this is set, `teacher_model` is not required. Defaults to None.
"""
teacher_model: Optional[str] = None
teacher_model_group: List[str] = field(default_factory=list)
Contributor

high

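The use_mopd flag is referenced in GKDTrainer but is not declared in the argument definitions. It should be added here to avoid an AttributeError. In addition, consider updating the TeacherModelArguments docstring to describe teacher_model_group and use_mopd.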

Suggested change
- teacher_model_group: List[str] = field(default_factory=list)
+ teacher_model_group: List[str] = field(default_factory=list)
+ use_mopd: bool = False

Comment thread swift/rlhf_trainers/gkd_trainer.py Outdated
Comment on lines +447 to +450
def choose_teacher_model(self):
    if not self.args.use_mopd:
        return self.teacher_model
    # TODO: when MOPD is enabled, select the best model from the teacher model group
Contributor

high

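The implementation of choose_teacher_model is incomplete. When use_mopd is enabled, the function currently returns None implicitly, which will crash compute_loss wherever a valid model is expected (for example, line 311 or line 419). In addition, the models in teacher_model_group need to be loaded and prepared as module objects during trainer initialization so that they can be used here.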
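
One way to avoid the implicit None is to make every path either return a model or raise. The following is a minimal sketch, not the PR's implementation: the `TeacherSelector` class, the `domain` tag, and the idea of keying teachers by domain are all assumptions made for illustration, and plain strings stand in for loaded modules.

```python
from typing import Any, Dict, Optional


class TeacherSelector:
    """Hypothetical helper; not part of the PR or of ms-swift."""

    def __init__(self, default_teacher: Any, teachers_by_domain: Optional[Dict[str, Any]] = None):
        self.default_teacher = default_teacher
        # Assumption: teachers are already loaded and keyed by a domain tag.
        self.teachers_by_domain = teachers_by_domain or {}

    def choose_teacher_model(self, domain: Optional[str] = None) -> Any:
        # Single-teacher case: behave exactly like the existing code path.
        if not self.teachers_by_domain:
            return self.default_teacher
        # Multi-teacher case: fail loudly instead of falling through to None.
        if domain not in self.teachers_by_domain:
            raise KeyError(f'No teacher registered for domain {domain!r}')
        return self.teachers_by_domain[domain]


single = TeacherSelector(default_teacher='teacher-a')
multi = TeacherSelector(default_teacher='teacher-a',
                        teachers_by_domain={'math': 'teacher-math', 'code': 'teacher-code'})
```

With this shape, a missing teacher surfaces as a `KeyError` at selection time rather than as a crash deeper inside compute_loss.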

Comment thread swift/rlhf_trainers/gkd_trainer.py Outdated
t_log_probs = F.log_softmax(t_chunk, dim=-1)
del s_chunk, t_chunk

# TODO: use the MOPD loss function and add teacher model weights
Contributor

medium

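This TODO indicates that the core MOPD logic (incorporating teacher weights into the JSD loss computation) is not yet implemented. Without it, the MOPD feature will not work as expected.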
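
To make "teacher weights in the JSD loss" concrete, here is a plain-Python sketch for a single token position. It assumes the common generalized JSD form, beta * KL(teacher || m) + (1 - beta) * KL(student || m) with m = beta * teacher + (1 - beta) * student, and a simple normalized weighted sum over teachers. The function names and the weighting scheme are illustrative assumptions, not the PR's code or the paper's exact objective.

```python
import math
from typing import Optional, Sequence


def kl_divergence(p: Sequence[float], q: Sequence[float]) -> float:
    """KL(p || q) for discrete distributions; terms with p_i == 0 contribute 0."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0.0)


def generalized_jsd(student: Sequence[float], teacher: Sequence[float], beta: float = 0.5) -> float:
    """Generalized JSD between one student and one teacher distribution."""
    if not 0.0 < beta < 1.0:
        raise ValueError('beta must lie strictly between 0 and 1 in this sketch')
    mixture = [beta * t + (1.0 - beta) * s for s, t in zip(student, teacher)]
    return beta * kl_divergence(teacher, mixture) + (1.0 - beta) * kl_divergence(student, mixture)


def multi_teacher_jsd(student: Sequence[float],
                      teachers: Sequence[Sequence[float]],
                      weights: Optional[Sequence[float]] = None,
                      beta: float = 0.5) -> float:
    """Weighted average of per-teacher generalized JSD (weights are normalized)."""
    if not teachers:
        raise ValueError('at least one teacher distribution is required')
    if weights is None:
        weights = [1.0] * len(teachers)
    if len(weights) != len(teachers):
        raise ValueError('weights and teachers must have the same length')
    total = sum(weights)
    return sum(w / total * generalized_jsd(student, t, beta) for w, t in zip(weights, teachers))
```

A tensorized version would apply the same formula per token over log-probabilities and mask out prompt positions; the sketch only pins down the arithmetic.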

Comment thread swift/rlhf_trainers/gkd_trainer.py Outdated
unwrapped_info = f"Failed to unwrap model: {e}"

raise ValueError(f"Cannot determine teacher model path for tokenizer initialization. "
f"Model info: {model_info}. Unwrapped info: {unwrapped_info}")
Contributor

This section could be extracted into a function.
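
As a rough illustration of what such an extraction could look like, the sketch below collects the path lookups into one helper that tries a fixed list of candidates and raises the same kind of ValueError when none is set. The name `resolve_model_path` and the exact candidate order are assumptions; the example uses `SimpleNamespace` objects in place of real models, and the model name is made up.

```python
from types import SimpleNamespace
from typing import Any


def resolve_model_path(model: Any) -> str:
    """Return the first non-empty path-like attribute on the model or its config."""
    config = getattr(model, 'config', None)
    candidates = (
        getattr(model, 'name_or_path', None),
        getattr(config, '_name_or_path', None),
        getattr(config, 'name_or_path', None),
    )
    for candidate in candidates:
        if candidate:
            return candidate
    raise ValueError('Cannot determine teacher model path for tokenizer initialization. '
                     f'Model type: {type(model).__name__}')


# Stand-in object; a real caller would pass the (unwrapped) teacher model.
fake_model = SimpleNamespace(config=SimpleNamespace(_name_or_path='org/teacher-7b'))
resolved = resolve_model_path(fake_model)
```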

@Tohrusky
Contributor

Happy to help with testing. Is there a plan to merge this? @Jintao-Huang

@Tohrusky
Contributor

I think it would be better to add a shell script in the examples directory to show how to start training.

Specifically, it would show how to launch the vLLM rollout with multiple teachers, how to start MOPD training, and whether to use reward functions (ORM) as part of A in the paper.

@Jintao-Huang
Collaborator

cc @hjh0119

Happy to help with testing. Is there a plan to merge this? @Jintao-Huang

@hjh0119
Collaborator

hjh0119 commented May 18, 2026

Thanks for the contribution, feel free to let me know when it's ready to merge

remotely. When this is set, `teacher_model` is not required. Defaults to None.
"""
teacher_model: Optional[str] = None
teacher_model_group: List[str] = field(default_factory=list)
Collaborator

Consider changing teacher_model to Optional[List[str]] (similar to reward_model) to avoid introducing additional parameters

Collaborator

No need for an extra use_mopd parameter, MOPD can be determined by the number of teacher models
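
The two suggestions above can be sketched together: a single list-valued `teacher_model` field, with multi-teacher mode derived from its length. This is only an illustration under assumptions; `TeacherArgsSketch` and its properties are hypothetical names, and the model identifiers are placeholders.

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class TeacherArgsSketch:
    """Hypothetical stand-in for the teacher arguments; not the ms-swift class."""
    # One field covers both cases: None, one teacher, or several teachers.
    teacher_model: Optional[List[str]] = None

    @property
    def num_teachers(self) -> int:
        return len(self.teacher_model or [])

    @property
    def is_multi_teacher(self) -> bool:
        # MOPD is inferred from the number of teachers, so no use_mopd flag is needed.
        return self.num_teachers > 1


no_teacher = TeacherArgsSketch()
one_teacher = TeacherArgsSketch(teacher_model=['org/teacher-a'])
two_teachers = TeacherArgsSketch(teacher_model=['org/teacher-a', 'org/teacher-b'])
```

Deriving the mode from the list keeps the command-line surface unchanged for existing single-teacher GKD runs.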

Comment on lines +136 to +151
self.teacher_model_group_models = None
if args.rlhf_type == 'gkd' and hasattr(args, 'teacher_model_group') and args.teacher_model_group:
    logger.info(f'Loading teacher_model_group with {len(args.teacher_model_group)} models')
    self.teacher_model_group_models = []
    for idx, teacher_model_path in enumerate(args.teacher_model_group):
        logger.info(f'Loading teacher model group [{idx}]: {teacher_model_path}')
        # Use teacher_model_type and teacher_model_revision if available, otherwise infer
        model_type = getattr(args, 'teacher_model_type', None)
        model_revision = getattr(args, 'teacher_model_revision', None)

        result = self._prepare_single_model_for_teacher_group(teacher_model_path, model_type, model_revision)
        if result is not None:
            model, _ = result
            self.teacher_model_group_models.append(model)
            logger.info(f'Successfully loaded teacher model group [{idx}]: {model}')
    logger.info(f'Total teacher_model_group_models loaded: {len(self.teacher_model_group_models)}')
Collaborator

Similarly, handle this the same way as reward_model, without introducing additional logic

Comment on lines +187 to +223
def _prepare_single_model_for_teacher_group(self, model_id_or_path, model_type, model_revision):
    """Prepare a single model for teacher_model_group."""
    args = self.args

    if model_type is None:
        model_info, _ = get_model_info_meta(model_id_or_path)
        model_type = model_info.model_type

    model_dir = safe_snapshot_download(
        model_id_or_path=model_id_or_path,
        revision=model_revision,
        download_model=False,
        use_hf=args.use_hf,
        hub_token=args.hub_token,
    )
    task_type, num_labels = self._get_model_task_type(model_dir)

    context = nullcontext()
    if args.teacher_deepspeed:
        if args.teacher_deepspeed.get('zero_optimization', {}).get('stage') != 3:
            context = disable_deepspeed_zero3()
    with context:
        model, processor = args.get_model_processor(
            model=model_id_or_path,
            model_type=model_type,
            revision=model_revision,
            task_type=task_type,
            num_labels=num_labels)

    # For teacher models, set to eval mode and disable gradients
    if self.args.sequence_parallel_size > 1:
        sequence_parallel.prepare(
            self.args.sequence_parallel_size, model, processor, padding_free=args.padding_free)
    model.requires_grad_(False).eval()

    HfConfigFactory.set_config_attr(model.config, 'use_cache', False)
    return model, processor
Collaborator

Same as above

Comment on lines +297 to +306
# Pass pre-loaded teacher_model_group_models if available, otherwise pass the string list
if hasattr(self, 'teacher_model_group_models') and self.teacher_model_group_models:
    trainer_kwargs['teacher_model_group_models'] = self.teacher_model_group_models
else:
    trainer_kwargs['teacher_model_group'] = self.args.teacher_model_group
trainer_kwargs['teacher_use_disable_adapter'] = getattr(self.args, '_teacher_use_disable_adapter', False)
if self.args.use_mopd:
    trainer_kwargs['use_mopd'] = self.args.use_mopd
    # todo
    # trainer_kwargs['mopd_config'] = self.args.mopd_config
Collaborator

Same as above

Comment on lines +180 to +192
self.gold_adapter_group[adapter_key] = GOLDLossAdapter(
    config={
        'use_uld_loss': True,
        'use_extended_uld': True,
        'uld_use_hybrid_loss': True,
        'uld_crossentropy_weight': 0.0,
        'uld_distillation_weight': 1.0,
        'uld_student_temperature': 1.0,
        'uld_teacher_temperature': 1.0,
    },
    student_tokenizer=self.student_tokenizer,
    teacher_tokenizer=teacher_tokenizer,
)
Collaborator

For now, leave the cross-tokenizer case out of scope and don't introduce GOLDLoss, as it adds unnecessary complexity

Comment on lines +535 to +599
elif self.use_mopd:
    num_teacher_models = len(self.teacher_model_group)
    if num_teacher_models == 0:
        raise ValueError("teacher_model_group cannot be empty")
    loss = torch.tensor(0.0, device=model.device)
    prompt_texts = inputs['prompt_text']
    completion_texts = inputs['completion_texts']
    (
        student_input_ids,
        student_labels,
        student_attention_mask,
        student_prompt_length,
    ) = self.build_inputs_from_texts(
        self.student_tokenizer,
        prompt_texts,
        completion_texts
    )
    # Student model forward pass (WITH gradients for student parameters)
    outputs_student = model(
        input_ids=student_input_ids,
        attention_mask=student_attention_mask,
    )
    for teacher_model in teacher_model_group:
        print('-------self.use_generalized_jsd_loss')
        teacher_tokenizer = self.teacher_tokenizer_group[id(teacher_model)]
        # Add teacher model memory management like in liger branch
        load_context = self.load_teacher_model_context() if self.args.offload_teacher_model else nullcontext()

        # Get adapter for current teacher tokenizer
        adapter_key = id(teacher_tokenizer)
        gold_adapter = self.gold_adapter_group[adapter_key]
        with load_context:
            with torch.no_grad(), disable_gradient_checkpointing(teacher_model,
                                                                 self.args.gradient_checkpointing_kwargs):
                (
                    teacher_input_ids,
                    teacher_labels,
                    teacher_attention_mask,
                    teacher_prompt_length,
                ) = self.build_inputs_from_texts(
                    teacher_tokenizer,
                    prompt_texts,
                    completion_texts
                )

                # Teacher model forward pass (NO gradients)
                outputs_teacher = teacher_model(
                    input_ids=teacher_input_ids,
                    attention_mask=teacher_attention_mask,
                )
                # Ensure teacher_logits has gradient info but teacher model params don't participate
                teacher_logits = outputs_teacher.logits.detach().requires_grad_(True)

                # Release intermediate tensors to free memory
                del teacher_attention_mask
        loss_total = gold_adapter(
            student_logits=outputs_student.logits,
            teacher_logits=teacher_logits,
            student_labels=student_labels,
            teacher_labels=teacher_labels,
            student_input_ids=student_input_ids,
            teacher_input_ids=teacher_input_ids,
        )
        loss += loss_total / len(teacher_model_group)
# Separate teacher model provided
Collaborator

Same as above, please update the multi-teacher logic based on the existing single-teacher logic

Comment on lines +535 to +599
Collaborator

no return value here
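
For reference, the control flow the reviewer is asking for can be sketched in a few lines: average the per-teacher losses, then hand the result back in the same `(loss, outputs) if return_outputs else loss` shape that Hugging Face's `Trainer.compute_loss` uses. The helper name and the use of plain floats and a callable in place of tensors and models are assumptions made to keep the example self-contained.

```python
from typing import Any, Callable, Sequence, Tuple, Union


def compute_multi_teacher_loss(
    loss_fn: Callable[[Any], float],
    teachers: Sequence[Any],
    student_outputs: Any = None,
    return_outputs: bool = False,
) -> Union[float, Tuple[float, Any]]:
    """Average loss_fn over all teachers and return it explicitly."""
    if not teachers:
        raise ValueError('teacher list cannot be empty')
    loss = 0.0
    for teacher in teachers:
        # Each teacher contributes an equal share of the total loss.
        loss += loss_fn(teacher) / len(teachers)
    # Without this return, the multi-teacher branch would yield None to the caller.
    return (loss, student_outputs) if return_outputs else loss
```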

Comment on lines +629 to +730
def build_inputs_from_texts(
    self,
    tokenizer: PreTrainedTokenizerBase,
    prompt_texts: list[str],
    completion_texts: list[str],
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, int]:
    """Tokenize teacher prompts/completions and produce tensors ready for GOLD loss."""

    pad_token_id = tokenizer.pad_token_id
    eos_token_id = tokenizer.eos_token_id

    prompt_token_ids = tokenizer(prompt_texts, add_special_tokens=True)["input_ids"]
    completion_token_ids = tokenizer(completion_texts, add_special_tokens=False)["input_ids"]

    sequences: list[torch.Tensor] = []
    attention_masks: list[torch.Tensor] = []
    labels_list: list[torch.Tensor] = []
    prompt_lengths: list[int] = []
    # Get device using reliable detection method
    device = None
    try:
        # First try to get device from model parameters
        if hasattr(self, 'model') and self.model is not None:
            device = next(self.model.parameters()).device
        elif hasattr(self, 'teacher_model') and self.teacher_model is not None:
            device = next(self.teacher_model.parameters()).device
    except (AttributeError, StopIteration):
        pass

    # Fallback to default device detection
    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    for prompt_ids, completion_ids in zip(prompt_token_ids, completion_token_ids, strict=True):
        # Remove trailing EOS from prompt so completions can extend cleanly
        if eos_token_id is not None and prompt_ids and prompt_ids[-1] == eos_token_id:
            prompt_ids = prompt_ids[:-1]

        prompt_lengths.append(len(prompt_ids))
        sequence = list(prompt_ids)
        sequence.extend(completion_ids)
        if eos_token_id is not None:
            sequence.append(eos_token_id)

        seq_tensor = torch.tensor(sequence, dtype=torch.long, device=device)
        sequences.append(seq_tensor)
        attention_masks.append(torch.ones_like(seq_tensor))
        labels = seq_tensor.clone()
        labels[: len(prompt_ids)] = -100
        if pad_token_id is not None:
            labels[labels == pad_token_id] = -100
        labels_list.append(labels)

    teacher_input_ids = pad(
        sequences,
        padding_side="right",
        padding_value=pad_token_id if pad_token_id is not None else 0,
    )
    teacher_attention_mask = pad(attention_masks, padding_side="right", padding_value=0).bool()
    teacher_labels = pad(labels_list, padding_side="right", padding_value=-100)

    if eos_token_id is not None:
        for row in range(teacher_attention_mask.size(0)):
            valid = (
                teacher_input_ids[row] != pad_token_id
                if pad_token_id is not None
                else teacher_attention_mask[row].bool()
            )
            if valid.any():
                last_idx = valid.nonzero(as_tuple=True)[0][-1]
                teacher_attention_mask[row, last_idx + 1:] = False

    teacher_prompt_length = max(prompt_lengths) if prompt_lengths else 0

    return teacher_input_ids, teacher_labels, teacher_attention_mask, teacher_prompt_length

def get_model_path(self, model):
    model_path = getattr(model, 'name_or_path', None)
    if model_path is None:
        # Try to get path from config
        if hasattr(model, 'config') and hasattr(model.config, '_name_or_path'):
            model_path = model.config._name_or_path
    if model_path is None:
        # If still None, try to get from model's base model
        unwrapped_student = self.accelerator.unwrap_model(model)
        if hasattr(unwrapped_student, 'base_model_prefix'):
            base_model = getattr(unwrapped_student, unwrapped_student.base_model_prefix, unwrapped_student)
            if hasattr(base_model, 'config') and hasattr(base_model.config, '_name_or_path'):
                model_path = base_model.config._name_or_path
    # Additional fallback: try to get from model's config name_or_path attribute
    if model_path is None:
        if hasattr(model, 'config') and hasattr(model.config, 'name_or_path'):
            model_path = model.config.name_or_path
    # Additional fallback: try to get from unwrapped model's name_or_path
    if model_path is None:
        unwrapped_student = self.accelerator.unwrap_model(model)
        model_path = getattr(unwrapped_student, 'name_or_path', None)
    # Additional fallback: try to get from unwrapped model's config name_or_path
    if model_path is None:
        unwrapped_student = self.accelerator.unwrap_model(model)
        if hasattr(unwrapped_student, 'config') and hasattr(unwrapped_student.config, 'name_or_path'):
            model_path = unwrapped_student.config.name_or_path
    return model_path
Collaborator

Remove this and reuse the original logic

Comment on lines +1169 to +1178
self,
student_logits,
teacher_logits=None,
labels=None,
beta=0.5,
temperature=1.0,
chunk_size=512,
topk=None,
teacher_topk_logprobs=None,
teacher_topk_indices=None,
Collaborator

restore it

Comment on lines +1 to +524
from typing import List, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PreTrainedTokenizerBase


class GOLDLossAdapter(nn.Module):
"""
GOLD (General Online Logit Distillation) loss adapter.
Supports:
1. ULD loss (Universal Logit Distillation)
2. Extended ULD (cross-tokenizer alignment)
3. Hybrid loss (hybrid ULD + JSD)

Usage example:
adapter = GOLDLossAdapter(
config={
'use_uld_loss': True,
'use_extended_uld': True,
'uld_use_hybrid_loss': False,
'uld_crossentropy_weight': 0.0,
'uld_distillation_weight': 1.0,
'uld_student_temperature': 1.0,
'uld_teacher_temperature': 1.0,
},
student_tokenizer=student_tok,
teacher_tokenizer=teacher_tok,
)

loss = adapter(
student_logits=student_outputs.logits,
teacher_logits=teacher_outputs.logits,
student_labels=student_labels,
teacher_labels=teacher_labels,
student_input_ids=student_input_ids,
teacher_input_ids=teacher_input_ids,
)
"""

def __init__(
self,
config: dict,
student_tokenizer: Optional[PreTrainedTokenizerBase] = None,
teacher_tokenizer: Optional[PreTrainedTokenizerBase] = None,
device: Optional[torch.device] = None,
):
super().__init__()
self.device = device

# Basic configuration
self.use_uld_loss = config.get('use_uld_loss', True) # whether to enable universal logit distillation
self.crossentropy_weight = config.get('uld_crossentropy_weight', 0.0)
self.distillation_weight = config.get('uld_distillation_weight', 1.0)
self.student_temperature = config.get('uld_student_temperature', 0.9)
self.teacher_temperature = config.get('uld_teacher_temperature', 0.9)
self.skip_student_eos = config.get('uld_skip_student_eos', True)
self.skip_teacher_eos = config.get('uld_skip_teacher_eos', True)
self.use_extended_uld = config.get('use_extended_uld', True)
self.ignore_index = -100

# Tokenizers
self.student_tokenizer = student_tokenizer
self.teacher_tokenizer = teacher_tokenizer

# Hybrid ULD configuration
self.use_hybrid_loss = config.get('uld_use_hybrid_loss', True) # whether to handle exactly matching vocabulary entries separately; enabling this improves stability
self.hybrid_matched_weight = config.get('uld_hybrid_matched_weight', None)
self.hybrid_unmatched_weight = config.get('uld_hybrid_unmatched_weight', None)
self.beta = config.get('beta', 1.0)

# Initialize the vocabulary mapping (used for hybrid loss)
self._vocab_mapping = None
self._teacher_matched_ids = None
self._student_matched_ids = None
self.mapping_tensor = None

if self.use_hybrid_loss and student_tokenizer and teacher_tokenizer:
self._initialize_vocabulary_mapping()

# For logging
self.last_matched_loss = None
self.last_unmatched_loss = None

def _initialize_vocabulary_mapping(self):
"""Initialize the vocabulary mapping between the student and teacher tokenizers."""
student_vocab = self.student_tokenizer.get_vocab()
teacher_vocab = self.teacher_tokenizer.get_vocab()

student_token_to_id = dict(student_vocab.items())

vocab_mapping = {}
teacher_matched_ids = set()
student_matched_ids = set()

for token_str, teacher_id in teacher_vocab.items():
if token_str in student_token_to_id:
student_id = student_token_to_id[token_str]
vocab_mapping[teacher_id] = student_id
teacher_matched_ids.add(teacher_id)
student_matched_ids.add(student_id)

self._vocab_mapping = vocab_mapping
self._teacher_matched_ids = teacher_matched_ids
self._student_matched_ids = student_matched_ids

if self._vocab_mapping:
max_matched_teacher_id = max(self._vocab_mapping.keys())
self.mapping_tensor = torch.full(
(max_matched_teacher_id + 1,), -1, dtype=torch.long
)
for k, v in self._vocab_mapping.items():
self.mapping_tensor[k] = v
if self.device:
self.mapping_tensor = self.mapping_tensor.to(self.device)

def forward(
self,
student_logits: torch.Tensor,
teacher_logits: torch.Tensor,
student_labels: torch.Tensor,
teacher_labels: torch.Tensor,
student_input_ids: torch.Tensor,
teacher_input_ids: torch.Tensor,
) -> torch.Tensor:
"""
Compute the GOLD/ULD loss.

Args:
student_logits: [batch_size, seq_len, student_vocab_size]
teacher_logits: [batch_size, seq_len, teacher_vocab_size]
student_labels: [batch_size, seq_len], -100 means ignore
teacher_labels: [batch_size, seq_len], -100 means ignore
student_input_ids: [batch_size, seq_len]
teacher_input_ids: [batch_size, seq_len]

Returns:
loss: scalar tensor
"""

if not self.use_uld_loss:
return torch.tensor(0.0, device=student_logits.device, requires_grad=True)

# 1. Cross-entropy loss (optional; weighted via crossentropy_weight)
crossentropy_loss = self._compute_cross_entropy(student_logits, student_labels)

# 2. Distillation loss (ULD)
distillation_loss = self._compute_distillation_loss(
student_logits, teacher_logits,
student_labels, teacher_labels,
student_input_ids, teacher_input_ids
)
return crossentropy_loss + distillation_loss

def _compute_cross_entropy(
self,
student_logits: torch.Tensor,
student_labels: torch.Tensor
) -> torch.Tensor:
"""Compute the cross-entropy loss."""
if self.crossentropy_weight <= 0:
return torch.tensor(0.0, device=student_logits.device, requires_grad=True)

shift_logits = student_logits[..., :-1, :].contiguous()
shift_labels = student_labels[..., 1:].contiguous()

loss_fct = nn.CrossEntropyLoss(ignore_index=self.ignore_index)
ce_loss = loss_fct(
shift_logits.view(-1, shift_logits.size(-1)),
shift_labels.view(-1)
)
return self.crossentropy_weight * ce_loss

def _compute_distillation_loss(
self,
student_logits: torch.Tensor,
teacher_logits: torch.Tensor,
student_labels: torch.Tensor,
teacher_labels: torch.Tensor,
student_input_ids: torch.Tensor,
teacher_input_ids: torch.Tensor,
) -> torch.Tensor:
"""Compute the ULD distillation loss."""
# Locate the answer regions
student_answer_idx, student_answer_size = self._get_answer_regions(student_labels)
teacher_answer_idx, teacher_answer_size = self._get_answer_regions(teacher_labels)

if self.skip_student_eos:
student_answer_size = [s - 1 for s in student_answer_size]
if self.skip_teacher_eos:
teacher_answer_size = [t - 1 for t in teacher_answer_size]

# Boundary check
if not student_answer_size or not teacher_answer_size:
return torch.zeros(1, device=student_logits.device, requires_grad=True) * 1e-8

batch_size = student_logits.size(0)
distillation_losses = []

for i in range(batch_size):
s_start = student_answer_idx[i]
s_size = student_answer_size[i]
t_start = teacher_answer_idx[i]
t_size = teacher_answer_size[i]

if s_size <= 0 or t_size <= 0:
loss_i = student_logits[i].sum() * 0.0
# Ensure the loss tensor requires gradients
loss_i = loss_i.detach().requires_grad_(True)
distillation_losses.append(loss_i)
continue

# Extract the answer logits
student_ans_logits = student_logits[i, s_start:s_start + s_size]
teacher_ans_logits = teacher_logits[i, t_start:t_start + t_size]

# Convert to probabilities
student_probs = F.softmax(student_ans_logits / self.student_temperature, dim=-1)
teacher_probs = F.softmax(teacher_ans_logits / self.teacher_temperature, dim=-1)

student_token_ids = student_input_ids[i, s_start:s_start + s_size].tolist()
teacher_token_ids = teacher_input_ids[i, t_start:t_start + t_size].tolist()

# Token alignment
if self.use_extended_uld:
student_groups, teacher_groups = self._build_alignment_groups_from_ids(
student_token_ids, teacher_token_ids
)

student_aligned = self._merge_probabilities_with_groups(
student_probs, student_groups, student_token_ids
)
teacher_aligned = self._merge_probabilities_with_groups(
teacher_probs, teacher_groups, teacher_token_ids
)

else:
min_len = min(len(student_token_ids), len(teacher_token_ids))
student_aligned = student_probs[:min_len]
teacher_aligned = teacher_probs[:min_len]

# Compute the loss
if self.use_hybrid_loss and self._vocab_mapping:
aligned_loss = self._compute_hybrid_uld_loss(student_aligned, teacher_aligned)
else:
aligned_loss = self._compute_basic_uld_loss(student_aligned, teacher_aligned)

distillation_losses.append(aligned_loss)
distillation_loss = torch.stack(distillation_losses).mean()
return self.distillation_weight * distillation_loss

def _get_answer_regions(self, labels: torch.Tensor) -> Tuple[List[int], List[int]]:
"""Return the start position and size of each answer region."""
indices = []
sizes = []

for label in labels:
mask = label.ne(self.ignore_index)
if not mask.any():
indices.append(0)
sizes.append(0)
continue

valid_indices = mask.nonzero(as_tuple=True)[0]
indices.append(int(valid_indices[0].item()))
sizes.append(int(mask.sum().item()))

return indices, sizes

def _build_alignment_groups_from_ids(
self,
student_token_ids: List[int],
teacher_token_ids: List[int]
) -> Tuple[List[List[int]], List[List[int]]]:
"""
Build alignment groups based on the decoded text.
Uses a greedy substring-matching algorithm.
"""

def decode_tokens(tokenizer, token_ids):
pieces = []
prev = ""
for k in range(len(token_ids)):
cur = tokenizer.decode(token_ids[:k + 1], skip_special_tokens=False)
pieces.append(cur[len(prev):])
prev = cur
return pieces

student_pieces = decode_tokens(self.student_tokenizer, student_token_ids)
teacher_pieces = decode_tokens(self.teacher_tokenizer, teacher_token_ids)

# Greedy matching algorithm
student_groups = []
teacher_groups = []
s_idx = 0
t_idx = 0

while s_idx < len(student_pieces) and t_idx < len(teacher_pieces):
student_text = ""
teacher_text = ""
student_group = []
teacher_group = []

# Try to find the shortest contiguous matching sequence
while s_idx < len(student_pieces) and t_idx < len(teacher_pieces):
if not student_group:
student_group.append(s_idx)
student_text += student_pieces[s_idx]
s_idx += 1

if not teacher_group:
teacher_group.append(t_idx)
teacher_text += teacher_pieces[t_idx]
t_idx += 1

# Check whether the texts match
if student_text == teacher_text:
student_groups.append(student_group)
teacher_groups.append(teacher_group)
break
elif len(student_text) < len(teacher_text):
if s_idx < len(student_pieces):
student_group.append(s_idx)
student_text += student_pieces[s_idx]
s_idx += 1
else:
break
else:
if t_idx < len(teacher_pieces):
teacher_group.append(t_idx)
teacher_text += teacher_pieces[t_idx]
t_idx += 1
else:
break
else:
# Not fully matched; add the remaining part
if student_group and teacher_group:
student_groups.append(student_group)
teacher_groups.append(teacher_group)

return student_groups, teacher_groups

def _merge_probabilities_with_groups(
self,
probs: torch.Tensor,
alignment_groups: List[List[int]],
token_ids: List[int],
) -> torch.Tensor:
"""
Merge probability distributions according to the alignment groups.
Uses the chain rule: P_merged = P(y|x_0) * P(x_1|x_0) * P(x_2|x_0,x_1) * ...
"""
aligned_probs = []

for group in alignment_groups:
if len(group) > 1:
# Marginal probability of the first token
marginal_probs = probs[group[0]] # [vocab_size]

# Conditional probabilities of the subsequent tokens (scalars)
conditional_product = 1.0
for k in range(1, len(group)):
cond_prob = probs[group[k], token_ids[group[k - 1]]]
conditional_product *= cond_prob

merged_probs = marginal_probs * conditional_product
aligned_probs.append(merged_probs)
elif len(group) == 1:
aligned_probs.append(probs[group[0]])

if aligned_probs:
return torch.stack(aligned_probs)
else:
# Return an empty tensor that still requires gradients
empty_tensor = probs[:0].detach().requires_grad_(True)
return empty_tensor

def _compute_basic_uld_loss(
self,
student_aligned: torch.Tensor,
teacher_aligned: torch.Tensor,
) -> torch.Tensor:
"""Basic ULD loss: L1 distance between sorted distributions."""
student_sorted = student_aligned.sort(dim=-1, descending=True).values
teacher_sorted = teacher_aligned.sort(dim=-1, descending=True).values

# Pad to the same vocab size
s_vocab = student_sorted.size(-1)
t_vocab = teacher_sorted.size(-1)
max_vocab = max(s_vocab, t_vocab)

if s_vocab < max_vocab:
student_sorted = F.pad(student_sorted, (0, max_vocab - s_vocab))
if t_vocab < max_vocab:
teacher_sorted = F.pad(teacher_sorted, (0, max_vocab - t_vocab))

loss = F.l1_loss(student_sorted, teacher_sorted, reduction="sum")
loss /= student_aligned.size(0)

return loss

def _compute_hybrid_uld_loss(
self,
student_aligned: torch.Tensor,
teacher_aligned: torch.Tensor,
) -> torch.Tensor:
"""Hybrid ULD loss: JSD for matched tokens, sorted L1 for unmatched tokens."""
device = student_aligned.device
s_vocab = student_aligned.size(-1)
t_vocab = teacher_aligned.size(-1)

# Create the matched/unmatched masks
if self._teacher_matched_ids:
teacher_matched_idx = torch.tensor(
sorted(self._teacher_matched_ids), dtype=torch.long, device=device
)
student_matched_idx = self.mapping_tensor[teacher_matched_idx]
else:
teacher_matched_idx = torch.tensor([], dtype=torch.long, device=device)
student_matched_idx = torch.tensor([], dtype=torch.long, device=device)

teacher_matched_mask = torch.zeros(t_vocab, dtype=torch.bool, device=device)
student_matched_mask = torch.zeros(s_vocab, dtype=torch.bool, device=device)

if len(teacher_matched_idx) > 0:
teacher_matched_mask[teacher_matched_idx] = True
student_matched_mask[student_matched_idx] = True

# 1. JSD loss for matched tokens
matched_loss = torch.tensor(0.0, device=device, requires_grad=True)
matched_count = 0

if len(teacher_matched_idx) > 0:
teacher_matched_probs = teacher_aligned[:, teacher_matched_idx]
student_matched_probs = student_aligned[:, student_matched_idx]
matched_count = teacher_matched_probs.size(-1)

matched_loss = self._compute_jsd_for_matched(
student_matched_probs, teacher_matched_probs
)
# 2. Sorted L1 loss for unmatched tokens
teacher_unmatched = teacher_aligned[:, ~teacher_matched_mask]
student_unmatched = student_aligned[:, ~student_matched_mask]

unmatched_loss = torch.tensor(0.0, device=device, requires_grad=True)
if teacher_unmatched.size(-1) > 0 and student_unmatched.size(-1) > 0:
teacher_sorted = teacher_unmatched.sort(dim=-1, descending=True).values
student_sorted = student_unmatched.sort(dim=-1, descending=True).values

t_size = teacher_sorted.size(-1)
s_size = student_sorted.size(-1)
max_size = max(t_size, s_size)

if t_size < max_size:
teacher_sorted = F.pad(teacher_sorted, (0, max_size - t_size))
if s_size < max_size:
student_sorted = F.pad(student_sorted, (0, max_size - s_size))

unmatched_loss = F.l1_loss(student_sorted, teacher_sorted, reduction="sum")
unmatched_loss /= student_aligned.size(0)

# 3. Weighted combination
if self.hybrid_matched_weight is None:
w_matched = matched_count / max(1, t_vocab)
w_unmatched = 1.0 - w_matched
else:
w_matched = self.hybrid_matched_weight
w_unmatched = self.hybrid_unmatched_weight

total_loss = w_matched * matched_loss + w_unmatched * unmatched_loss

# Save for logging
self.last_matched_loss = matched_loss
self.last_unmatched_loss = unmatched_loss

return total_loss

def _compute_jsd_for_matched(
self,
student_probs: torch.Tensor,
teacher_probs: torch.Tensor,
epsilon: float = 1e-8
) -> torch.Tensor:
"""Compute the JSD loss for matched tokens, with numerical-stability handling."""
batch_seq_len, num_matched = student_probs.shape

# Check that the input probability distributions are valid
if torch.isnan(student_probs).any() or torch.isnan(teacher_probs).any():
return torch.tensor(0.0, device=student_probs.device, requires_grad=True)

# Clamp to epsilon to prevent numerical underflow and log(0)
student_probs = student_probs.clamp(min=epsilon)
teacher_probs = teacher_probs.clamp(min=epsilon)

# Renormalize the probability distributions
student_probs = student_probs / student_probs.sum(dim=-1, keepdim=True)
teacher_probs = teacher_probs / teacher_probs.sum(dim=-1, keepdim=True)

student_flat = student_probs.view(-1, num_matched)
teacher_flat = teacher_probs.view(-1, num_matched)

# JSD = 0.5 * KL(P||M) + 0.5 * KL(Q||M), where M = 0.5*(P+Q)
m = 0.5 * (student_flat + teacher_flat)

# Clamp the mixture distribution to epsilon
m = m.clamp(min=epsilon)
m = m / m.sum(dim=-1, keepdim=True)

# Take the log of the probabilities directly, adding epsilon to avoid numerical issues
log_m = torch.log(m + epsilon)
log_student = torch.log(student_flat + epsilon)
log_teacher = torch.log(teacher_flat + epsilon)

# Use log_target=True and pass log-probabilities
kl_p_m = F.kl_div(log_m, log_student, reduction='batchmean', log_target=True)
kl_q_m = F.kl_div(log_m, log_teacher, reduction='batchmean', log_target=True)
jsd = 0.5 * (kl_p_m + kl_q_m)

# Check that the result is valid
if torch.isnan(jsd) or torch.isinf(jsd):
return torch.tensor(0.0, device=student_probs.device, requires_grad=True)

return jsd
Collaborator

Same as above: leave GOLDLoss out and limit MOPD to the case where the student and all teachers use the same tokenizer
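
If MOPD is limited to the same-tokenizer case, a cheap guard at setup time could reject mismatched teachers early. The sketch below compares vocabularies through `get_vocab()`, which Hugging Face tokenizers expose; treating equal vocabularies as "same tokenizer" is a simplifying assumption (it is a necessary check, not a complete one), and `assert_same_vocab` and `_StubTokenizer` are hypothetical names used only for this example.

```python
from typing import Any, Dict, Sequence


def assert_same_vocab(student_tokenizer: Any, teacher_tokenizers: Sequence[Any]) -> None:
    """Raise ValueError if any teacher vocabulary differs from the student's."""
    student_vocab = student_tokenizer.get_vocab()
    for idx, teacher_tokenizer in enumerate(teacher_tokenizers):
        if teacher_tokenizer.get_vocab() != student_vocab:
            raise ValueError(f'Teacher tokenizer [{idx}] does not share the student vocabulary; '
                             'cross-tokenizer distillation is out of scope here.')


class _StubTokenizer:
    """Stand-in with the one method the check needs; real code would pass actual tokenizers."""

    def __init__(self, vocab: Dict[str, int]):
        self._vocab = vocab

    def get_vocab(self) -> Dict[str, int]:
        return dict(self._vocab)


student_tok = _StubTokenizer({'a': 0, 'b': 1})
same_tok = _StubTokenizer({'a': 0, 'b': 1})
other_tok = _StubTokenizer({'a': 0, 'c': 1})
```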

@hjh0119
Collaborator

hjh0119 commented May 19, 2026

Please keep the PR clean. For MOPD, I think we can reuse the existing GKD pipeline and simply extend the original teacher_model related logic to support multiple teachers

@Kagura-0001
Contributor

Keep it up!

7 participants