diff --git a/swift/megatron/arguments/megatron_args.py b/swift/megatron/arguments/megatron_args.py index 10cf90aed6..8cc7c6c180 100644 --- a/swift/megatron/arguments/megatron_args.py +++ b/swift/megatron/arguments/megatron_args.py @@ -706,10 +706,6 @@ def __post_init__(self): if self.tuner_type == 'lora_llm': if not self.is_multimodal: raise ValueError('`tuner_type="lora_llm"` is only supported for multimodal models.') - if not self.merge_lora: - raise ValueError('`merge_lora` must be True when using `--tuner_type lora_llm`') - if not self.no_save_optim: - raise ValueError('`no_save_optim` must be True when using `--tuner_type lora_llm`') if self.adapters or self.ref_adapters or self.mcore_adapter or self.mcore_ref_adapter: if self.tuner_type == 'full': self.tuner_type = 'lora' diff --git a/swift/megatron/trainers/base.py b/swift/megatron/trainers/base.py index 25400c2d48..e0530e5a5f 100644 --- a/swift/megatron/trainers/base.py +++ b/swift/megatron/trainers/base.py @@ -195,7 +195,7 @@ def _prepare_peft_model(self, models): if args.mcore_model is None: self.bridge.load_weights(models, args.model_dir) peft_models = [prepare_mcore_model(args, model) for model in models] - if args.tuner_type == 'lora' and args.adapters and args.mcore_adapter is None: + if args.tuner_type in {'lora', 'lora_llm'} and args.adapters and args.mcore_adapter is None: assert len(args.adapters) == 1, 'Currently only support one adapter.' self.bridge.load_weights(models, args.adapters[0], peft_format=True, adapter_name='default') return peft_models @@ -727,7 +727,7 @@ def save_checkpoint(self): os.makedirs(output_dir, exist_ok=True) args_path = os.path.join(os.path.dirname(output_dir), 'args.json') self.copy_path(args_path, os.path.join(output_dir, 'args.json')) - save_peft_format = args.tuner_type == 'lora' and not args.merge_lora + save_peft_format = args.tuner_type in {'lora', 'lora_llm'} and not args.merge_lora if args.save_safetensors and args.no_save_optim: model = [] else: @@ -739,7 +739,7 @@ def save_checkpoint(self): self.optimizer, self.opt_param_scheduler, iteration=iteration, - peft_format=args.tuner_type == 'lora', + peft_format=args.tuner_type in {'lora', 'lora_llm'}, output_dir=output_dir) state.last_model_checkpoint = output_dir if state.best_global_step is not None: