-
Notifications
You must be signed in to change notification settings - Fork 1.4k
[New Feature] MOPD #9035
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
[New Feature] MOPD #9035
Changes from 1 commit
e9e0f38
a542b7a
54822f2
664a4e1
dc4be10
a1fcda7
8923a54
0bd9b55
7c70a3a
d8e95c9
5667586
a227202
4f75226
e1e58af
b53a0f3
1d2cd01
479ddf7
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -65,6 +65,7 @@ class TeacherModelArguments: | |
| remotely. When this is set, `teacher_model` is not required. Defaults to None. | ||
| """ | ||
| teacher_model: Optional[str] = None | ||
| teacher_model_group: List[str] = field(default_factory=list) | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Consider changing teacher_model to Optional[List[str]] (similar to reward_model) to avoid introducing additional parameters
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. No need for an extra use_mopd parameter, MOPD can be determined by the number of teacher models |
||
| teacher_adapters: List[str] = field(default_factory=list) | ||
| teacher_model_type: Optional[str] = field( | ||
| default=None, metadata={'help': f'model_type choices: {list(MODEL_MAPPING.keys())}'}) | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -298,6 +298,7 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N | |
| self.prepare_logits_to_keep(inputs) | ||
| model_inputs['logits_to_keep'] = inputs['logits_to_keep'] | ||
|
|
||
| teacher_model = self.choose_teacher_model() | ||
| if self.use_liger_gkd_loss: | ||
| # Liger fused JSD loss for memory efficiency | ||
| # Get base models (exclude lm_head to save memory) | ||
|
|
@@ -307,7 +308,7 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N | |
| base_student = getattr(unwrapped_student, getattr(unwrapped_student, 'base_model_prefix', 'model'), | ||
| unwrapped_student) | ||
|
|
||
| unwrapped_teacher = self.accelerator.unwrap_model(self.teacher_model) | ||
| unwrapped_teacher = self.accelerator.unwrap_model(teacher_model) | ||
| base_teacher = getattr(unwrapped_teacher, getattr(unwrapped_teacher, 'base_model_prefix', 'model'), | ||
| unwrapped_teacher) | ||
|
|
||
|
|
@@ -316,7 +317,7 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N | |
|
|
||
| load_context = self.load_teacher_model_context() if self.args.offload_teacher_model else nullcontext() | ||
| with load_context: | ||
| with torch.no_grad(), disable_gradient_checkpointing(self.teacher_model, | ||
| with torch.no_grad(), disable_gradient_checkpointing(teacher_model, | ||
| self.args.gradient_checkpointing_kwargs): | ||
| teacher_outputs = base_teacher(**model_inputs, use_cache=False) | ||
|
|
||
|
|
@@ -415,7 +416,7 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N | |
| loss = loss + self.args.sft_alpha * outputs_student.loss | ||
| # Separate teacher model provided | ||
| else: | ||
| assert self.teacher_model is not None | ||
| assert teacher_model is not None | ||
| if self.args.sft_alpha > 0: | ||
| model_inputs['labels'] = inputs['labels'] | ||
| outputs_student = model(**model_inputs) | ||
|
|
@@ -426,9 +427,9 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N | |
| } | ||
|
|
||
| load_context = self.load_teacher_model_context() if self.args.offload_teacher_model else nullcontext() | ||
| with torch.no_grad(), load_context, disable_gradient_checkpointing(self.teacher_model, | ||
| with torch.no_grad(), load_context, disable_gradient_checkpointing(teacher_model, | ||
| self.args.gradient_checkpointing_kwargs): | ||
| outputs_teacher = self.teacher_model(**t_fwd) | ||
| outputs_teacher = teacher_model(**t_fwd) | ||
|
|
||
| opsd_labels = opsd_teacher_inputs.get('labels') if opsd_teacher_inputs is not None else None | ||
| teacher_out = TeacherOutput(full_logits=outputs_teacher.logits, opsd_teacher_labels=opsd_labels) | ||
|
|
@@ -443,6 +444,11 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N | |
| else: | ||
| return loss | ||
|
|
||
| def choose_teacher_model(self): | ||
| if not self.args.use_mopd: | ||
| return self.teacher_model | ||
| #todo 使用mopd时从教师模型组选择最佳模型 | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
|
|
||
| def _prepare_batch_inputs(self, inputs: list, encode_prompt_only: bool = False) -> Dict[str, torch.Tensor]: | ||
| """Prepare batch inputs for training. | ||
|
|
||
|
|
@@ -788,6 +794,7 @@ def generalized_jsd_loss( | |
| t_log_probs = F.log_softmax(t_chunk, dim=-1) | ||
| del s_chunk, t_chunk | ||
|
|
||
| #todo 使用mopd的计算函数,增加教师模型权重 | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
| if beta == 0: | ||
| jsd_chunk = F.kl_div(s_log_probs, t_log_probs, reduction='none', log_target=True) | ||
| elif beta == 1: | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
use_mopd标志在GKDTrainer中被引用,但未在参数定义中声明。应在此处添加以避免AttributeError。此外,建议更新TeacherModelArguments的 docstring 以包含teacher_model_group和use_mopd的说明。