Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions swift/arguments/rlhf_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ class TeacherModelArguments:
remotely. When this is set, `teacher_model` is not required. Defaults to None.
"""
teacher_model: Optional[str] = None
teacher_model_group: List[str] = field(default_factory=list)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

use_mopd 标志在 GKDTrainer 中被引用,但未在参数定义中声明。应在此处添加以避免 AttributeError。此外,建议更新 TeacherModelArguments 的 docstring 以包含 teacher_model_groupuse_mopd 的说明。

Suggested change
teacher_model_group: List[str] = field(default_factory=list)
teacher_model_group: List[str] = field(default_factory=list)
use_mopd: bool = False

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Consider changing teacher_model to Optional[List[str]] (similar to reward_model) to avoid introducing additional parameters

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No need for an extra use_mopd parameter, MOPD can be determined by the number of teacher models

teacher_adapters: List[str] = field(default_factory=list)
teacher_model_type: Optional[str] = field(
default=None, metadata={'help': f'model_type choices: {list(MODEL_MAPPING.keys())}'})
Expand Down
17 changes: 12 additions & 5 deletions swift/rlhf_trainers/gkd_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,6 +298,7 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N
self.prepare_logits_to_keep(inputs)
model_inputs['logits_to_keep'] = inputs['logits_to_keep']

teacher_model = self.choose_teacher_model()
if self.use_liger_gkd_loss:
# Liger fused JSD loss for memory efficiency
# Get base models (exclude lm_head to save memory)
Expand All @@ -307,7 +308,7 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N
base_student = getattr(unwrapped_student, getattr(unwrapped_student, 'base_model_prefix', 'model'),
unwrapped_student)

unwrapped_teacher = self.accelerator.unwrap_model(self.teacher_model)
unwrapped_teacher = self.accelerator.unwrap_model(teacher_model)
base_teacher = getattr(unwrapped_teacher, getattr(unwrapped_teacher, 'base_model_prefix', 'model'),
unwrapped_teacher)

Expand All @@ -316,7 +317,7 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N

load_context = self.load_teacher_model_context() if self.args.offload_teacher_model else nullcontext()
with load_context:
with torch.no_grad(), disable_gradient_checkpointing(self.teacher_model,
with torch.no_grad(), disable_gradient_checkpointing(teacher_model,
self.args.gradient_checkpointing_kwargs):
teacher_outputs = base_teacher(**model_inputs, use_cache=False)

Expand Down Expand Up @@ -415,7 +416,7 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N
loss = loss + self.args.sft_alpha * outputs_student.loss
# Separate teacher model provided
else:
assert self.teacher_model is not None
assert teacher_model is not None
if self.args.sft_alpha > 0:
model_inputs['labels'] = inputs['labels']
outputs_student = model(**model_inputs)
Expand All @@ -426,9 +427,9 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N
}

load_context = self.load_teacher_model_context() if self.args.offload_teacher_model else nullcontext()
with torch.no_grad(), load_context, disable_gradient_checkpointing(self.teacher_model,
with torch.no_grad(), load_context, disable_gradient_checkpointing(teacher_model,
self.args.gradient_checkpointing_kwargs):
outputs_teacher = self.teacher_model(**t_fwd)
outputs_teacher = teacher_model(**t_fwd)

opsd_labels = opsd_teacher_inputs.get('labels') if opsd_teacher_inputs is not None else None
teacher_out = TeacherOutput(full_logits=outputs_teacher.logits, opsd_teacher_labels=opsd_labels)
Expand All @@ -443,6 +444,11 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N
else:
return loss

def choose_teacher_model(self):
if not self.args.use_mopd:
return self.teacher_model
#todo 使用mopd时从教师模型组选择最佳模型
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

choose_teacher_model 的实现不完整。当启用 use_mopd 时,该函数目前隐式返回 None,这将导致 compute_loss 在预期有效模型的地方(例如第 311 行或第 419 行)发生崩溃。此外,teacher_model_group 中的模型需要在训练器初始化期间加载并准备为 module 对象,以便在此处使用。


def _prepare_batch_inputs(self, inputs: list, encode_prompt_only: bool = False) -> Dict[str, torch.Tensor]:
"""Prepare batch inputs for training.

Expand Down Expand Up @@ -788,6 +794,7 @@ def generalized_jsd_loss(
t_log_probs = F.log_softmax(t_chunk, dim=-1)
del s_chunk, t_chunk

#todo 使用mopd的计算函数,增加教师模型权重
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

此 TODO 表明 MOPD 的核心逻辑(将教师权重纳入 JSD 损失计算)尚未实现。如果没有这部分实现,MOPD 功能将无法按预期工作。

if beta == 0:
jsd_chunk = F.kl_div(s_log_probs, t_log_probs, reduction='none', log_target=True)
elif beta == 1:
Expand Down