Skip to content
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions swift/arguments/rlhf_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ class TeacherModelArguments:
remotely. When this is set, `teacher_model` is not required. Defaults to None.
"""
teacher_model: Optional[str] = None
teacher_model_group: List[str] = field(default_factory=list)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

use_mopd 标志在 GKDTrainer 中被引用,但未在参数定义中声明。应在此处添加以避免 AttributeError。此外,建议更新 TeacherModelArguments 的 docstring 以包含 teacher_model_groupuse_mopd 的说明。

Suggested change
teacher_model_group: List[str] = field(default_factory=list)
teacher_model_group: List[str] = field(default_factory=list)
use_mopd: bool = False

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Consider changing teacher_model to Optional[List[str]] (similar to reward_model) to avoid introducing additional parameters

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No need for an extra use_mopd parameter, MOPD can be determined by the number of teacher models

teacher_adapters: List[str] = field(default_factory=list)
teacher_model_type: Optional[str] = field(
default=None, metadata={'help': f'model_type choices: {list(MODEL_MAPPING.keys())}'})
Expand Down
61 changes: 61 additions & 0 deletions swift/pipelines/train/rlhf.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,24 @@ def _prepare_model_tokenizer(self):
model, _ = result
setattr(self, f'{key}_model', model)

# Handle teacher_model_group for GKD
self.teacher_model_group_models = None
if args.rlhf_type == 'gkd' and hasattr(args, 'teacher_model_group') and args.teacher_model_group:
logger.info(f'Loading teacher_model_group with {len(args.teacher_model_group)} models')
self.teacher_model_group_models = []
for idx, teacher_model_path in enumerate(args.teacher_model_group):
logger.info(f'Loading teacher model group [{idx}]: {teacher_model_path}')
# Use teacher_model_type and teacher_model_revision if available, otherwise infer
model_type = getattr(args, 'teacher_model_type', None)
model_revision = getattr(args, 'teacher_model_revision', None)

result = self._prepare_single_model_for_teacher_group(teacher_model_path, model_type, model_revision)
if result is not None:
model, _ = result
self.teacher_model_group_models.append(model)
logger.info(f'Successfully loaded teacher model group [{idx}]: {model}')
logger.info(f'Total teacher_model_group_models loaded: {len(self.teacher_model_group_models)}')
Comment on lines +136 to +151
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Similarly, just like the handling of reward_model, without introducing additional logic


# Handle reward model(s)
self.reward_model = None
if hasattr(args, 'reward_model') and args.reward_model is not None:
Expand Down Expand Up @@ -166,6 +184,44 @@ def _prepare_model_tokenizer(self):

super()._prepare_model_tokenizer()

def _prepare_single_model_for_teacher_group(self, model_id_or_path, model_type, model_revision):
"""Prepare a single model for teacher_model_group."""
args = self.args

if model_type is None:
model_info, _ = get_model_info_meta(model_id_or_path)
model_type = model_info.model_type

model_dir = safe_snapshot_download(
model_id_or_path=model_id_or_path,
revision=model_revision,
download_model=False,
use_hf=args.use_hf,
hub_token=args.hub_token,
)
task_type, num_labels = self._get_model_task_type(model_dir)

context = nullcontext()
if args.teacher_deepspeed:
if args.teacher_deepspeed.get('zero_optimization', {}).get('stage') != 3:
context = disable_deepspeed_zero3()
with context:
model, processor = args.get_model_processor(
model=model_id_or_path,
model_type=model_type,
revision=model_revision,
task_type=task_type,
num_labels=num_labels)

# For teacher models, set to eval mode and disable gradients
if self.args.sequence_parallel_size > 1:
sequence_parallel.prepare(
self.args.sequence_parallel_size, model, processor, padding_free=args.padding_free)
model.requires_grad_(False).eval()

HfConfigFactory.set_config_attr(model.config, 'use_cache', False)
return model, processor
Comment on lines +187 to +223
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same as above


@classmethod
def prepare_model(cls, args, model, *, template=None, train_dataset=None, task_type=None):
model = super().prepare_model(args, model, template=template, train_dataset=train_dataset, task_type=task_type)
Expand Down Expand Up @@ -238,6 +294,11 @@ def _get_trainer_kwargs(self):
trainer_kwargs['gkd_logits_topk'] = self.args.gkd_logits_topk
if self.args.teacher_model_server:
trainer_kwargs['teacher_model_server'] = self.args.teacher_model_server
# Pass pre-loaded teacher_model_group_models if available, otherwise pass the string list
if hasattr(self, 'teacher_model_group_models') and self.teacher_model_group_models:
trainer_kwargs['teacher_model_group_models'] = self.teacher_model_group_models
else:
trainer_kwargs['teacher_model_group'] = self.args.teacher_model_group
trainer_kwargs['teacher_use_disable_adapter'] = getattr(self.args, '_teacher_use_disable_adapter', False)
return trainer_kwargs

Expand Down
Loading