Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions swift/model/constant.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,7 @@ class MLLMModelType:
ovis2 = 'ovis2'
ovis2_5 = 'ovis2_5'
midashenglm = 'midashenglm'
mimo_v2 = 'mimo_v2'

chatglm4v = 'chatglm4v'
glm4v = 'glm4v'
Expand Down
9 changes: 9 additions & 0 deletions swift/model/model_arch.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ class MLLMModelArch:
keye_vl = 'keye_vl'

midashenglm = 'midashenglm'
mimo_v2 = 'mimo_v2'
step_audio2_mini = 'step_audio2_mini'
hunyuan_vl = 'hunyuan_vl'
step3_vl = 'step3_vl'
Expand Down Expand Up @@ -787,6 +788,14 @@ def register_model_arch(model_arch: ModelKeys, *, exist_ok: bool = False) -> Non
vision_tower='model.visual',
))

# Register the module layout for the MiMo-V2 architecture key.
# The language model is addressed by the 'model' and 'lm_head' prefixes, the
# aligner by 'visual.merger', and the audio modules ('audio_encoder',
# 'speech_embeddings') are grouped together with 'visual' under vision_tower.
# NOTE(review): how these groups are consumed (e.g. freezing / target-module
# selection) is not visible here — confirm against MultiModelKeys users.
register_model_arch(
MultiModelKeys(
MLLMModelArch.mimo_v2,
language_model=['model', 'lm_head'],
aligner='visual.merger',
vision_tower=['visual', 'audio_encoder', 'speech_embeddings'],
))


def get_model_arch(arch_name: Optional[str]) -> Optional[MultiModelKeys]:
return MODEL_ARCH_MAPPING.get(arch_name)
22 changes: 22 additions & 0 deletions swift/model/models/qwen.py
Original file line number Diff line number Diff line change
Expand Up @@ -867,6 +867,28 @@ def get_model(self, model_dir: str, *args, **kwargs) -> PreTrainedModel:
tags=['vision', 'video']))


class MiMoV2Loader(Qwen2VLLoader):
    """Loader for MiMo-V2 checkpoints.

    Inherits from ``Qwen2VLLoader`` but overrides ``get_model`` so that loading
    goes straight through ``ModelLoader.get_model`` rather than the parent's
    implementation.
    """

    def get_model(self, model_dir: str, config, processor, model_kwargs) -> PreTrainedModel:
        """Load the model and patch the visual module's input embeddings.

        Args:
            model_dir: Directory containing the checkpoint.
            config: Model config forwarded unchanged to ``ModelLoader.get_model``.
            processor: Processor forwarded unchanged to ``ModelLoader.get_model``.
            model_kwargs: Extra keyword arguments forwarded unchanged.

        Returns:
            The object returned by ``ModelLoader.get_model`` (any quantization
            wrapper is left in place).
        """
        # Call ModelLoader.get_model explicitly to skip Qwen2VLLoader.get_model.
        # NOTE(review): per the review discussion this avoids the parent's
        # hard-coded auto_model_cls — confirm against Qwen2VLLoader.
        model = ModelLoader.get_model(self, model_dir, config, processor, model_kwargs)
        # An AWQ wrapper class holds the real network under ``.model``; unwrap
        # it so ``.visual`` is looked up on the module that actually owns it.
        base_model = model.model if 'AWQ' in model.__class__.__name__ else model
        patch_get_input_embeddings(base_model.visual, 'patch_embed')
        return model
Comment on lines +872 to +875
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The current implementation of get_model for MiMoV2Loader bypasses the logic in Qwen2VLLoader.get_model to avoid the hardcoded auto_model_cls, but it also misses the necessary check for AWQ-wrapped models. When a model is quantized with AWQ, the actual model components are often nested under a .model attribute. It is safer to replicate the base_model logic to ensure patch_get_input_embeddings is applied to the correct module.

Suggested change
def get_model(self, model_dir: str, config, processor, model_kwargs) -> PreTrainedModel:
model = ModelLoader.get_model(self, model_dir, config, processor, model_kwargs)
patch_get_input_embeddings(model.visual, 'patch_embed')
return model
def get_model(self, model_dir: str, config, processor, model_kwargs) -> PreTrainedModel:
model = ModelLoader.get_model(self, model_dir, config, processor, model_kwargs)
base_model = model.model if 'AWQ' in model.__class__.__name__ else model
patch_get_input_embeddings(base_model.visual, 'patch_embed')
return model



# Register the MiMo-V2 model type: one model group containing
# 'XiaomiMiMo/MiMo-V2.5', bound to the mimo_v2 template, loaded through
# MiMoV2Loader and matched on the 'MiMoV2ForCausalLM' architecture name.
# NOTE(review): the two identical ids passed to Model(...) are presumably the
# ModelScope and Hugging Face ids — confirm against the Model signature.
# NOTE(review): tags list only 'vision' and 'video' while the mimo_v2 arch
# registration also names audio modules — confirm whether 'audio' is intended.
register_model(
ModelMeta(
MLLMModelType.mimo_v2, [
ModelGroup([
Model('XiaomiMiMo/MiMo-V2.5', 'XiaomiMiMo/MiMo-V2.5'),
], TemplateType.mimo_v2),
],
MiMoV2Loader,
model_arch=ModelArch.mimo_v2,
architectures=['MiMoV2ForCausalLM'],
requires=['transformers>=4.49', 'qwen_vl_utils>=0.0.6', 'decord'],
tags=['vision', 'video']))


def patch_Qwen3VLMoeTextExperts_dtype():
from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import Qwen3VLMoeTextExperts
if hasattr(Qwen3VLMoeTextExperts, '_patch'):
Expand Down
1 change: 1 addition & 0 deletions swift/template/constant.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,7 @@ class MLLMTemplateType:
ovis2 = 'ovis2'
ovis2_5 = 'ovis2_5'
mimo_vl = 'mimo_vl'
mimo_v2 = 'mimo_v2'
midashenglm = 'midashenglm'

llama3_1_omni = 'llama3_1_omni'
Expand Down
44 changes: 44 additions & 0 deletions swift/template/templates/qwen.py
Original file line number Diff line number Diff line change
Expand Up @@ -505,6 +505,50 @@ class Qwen2_5VLTemplate(Qwen2VLTemplate):
default_system='You are MiMo, an AI assistant developed by Xiaomi.'))


class MiMoV2Template(Qwen2_5VLTemplate):
    """Chat template for XiaomiMiMo/MiMo-V2.5.

    What this class changes relative to ``Qwen2_5VLTemplate``:
    - ``_get_position_ids`` yields nothing, because MiMo-V2.5 relies on standard
      rotary position embeddings instead of 3D rope (no ``get_rope_index``).
    - ``forward_context`` falls back to the base ``Template`` implementation.
    - At inference time the video tensor is passed as ``'video_pixel_values'``
      rather than ``'pixel_values_videos'``.
    - Thinking mode (``<think>...</think>``) is supported.
    """

    def _get_position_ids(self, inputs: Dict[str, Any]):
        """Return an empty dict: no special position IDs are required."""
        # Standard rotary embeddings are used, unlike Qwen2VL's 3D rope.
        return {}

    def forward_context(self, model, inputs):
        """Delegate to ``Template.forward_context`` directly.

        This bypasses the Qwen2VL-specific flash attention patching.
        """
        return Template.forward_context(self, model, inputs)

    def _post_encode(self, model, inputs: Dict[str, Any]) -> Dict[str, Any]:
        """Adapt encoded inputs to what the MiMo-V2.5 forward pass expects.

        Training: embeddings are built manually and returned as a dict holding
        only ``'inputs_embeds'``.
        Inference: ``inputs`` is returned (mutated in place) with
        ``'pixel_values_videos'`` renamed to ``'video_pixel_values'`` if present.
        """
        if self.is_training:
            # Embed the token ids, then hand off to the shared HF helper
            # together with the visual module, processor and model config.
            token_ids = inputs['input_ids']
            embed_layer = self.get_base_model(model).model.embed_tokens
            text_embeds = embed_layer(token_ids)
            merged_embeds = self._get_inputs_embeds_hf(text_embeds, inputs, model.visual, self.processor,
                                                       model.config)
            return {'inputs_embeds': merged_embeds}

        # Inference: rename the key to match the MiMo-V2.5 forward signature.
        if 'pixel_values_videos' in inputs:
            inputs['video_pixel_values'] = inputs.pop('pixel_values_videos')
        return inputs
Comment on lines +526 to +537
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The _post_encode implementation for MiMoV2Template has several improvement opportunities in the training branch:

  1. Robustness: It assumes base_model.model.embed_tokens exists. Using a check for language_model (similar to the parent Qwen2VLTemplate) makes it more resilient to different model architectures.
  2. Consistency: The inference branch returns the full inputs dictionary, while the training branch returns a new dictionary containing only inputs_embeds. While the framework might merge these, it is safer and more consistent to update the inputs dictionary in place and return it.
  3. Ambiguity: When inputs_embeds is provided, input_ids should ideally be removed from the inputs to avoid ambiguity in the model's forward call.
    def _post_encode(self, model, inputs: Dict[str, Any]) -> Dict[str, Any]:
        if not self.is_training:
            # During inference, rename key to match MiMo-V2.5 forward signature
            if 'pixel_values_videos' in inputs:
                inputs['video_pixel_values'] = inputs.pop('pixel_values_videos')
            return inputs
        # For training, compute embeddings manually
        input_ids = inputs['input_ids']
        base_model = self.get_base_model(model)
        if hasattr(base_model.model, 'embed_tokens'):
            inputs_embeds = base_model.model.embed_tokens(input_ids)
        else:
            inputs_embeds = base_model.model.language_model.embed_tokens(input_ids)
        inputs_embeds = self._get_inputs_embeds_hf(inputs_embeds, inputs, model.visual, self.processor, model.config)
        inputs['inputs_embeds'] = inputs_embeds
        inputs.pop('input_ids', None)
        return inputs



# Register the mimo_v2 template, implemented by MiMoV2Template, with a
# MiMo-specific default system prompt and thinking mode enabled.
# thinking_prefix opens a '<think>' block; non_thinking_prefix and
# history_thinking_prefix are set to the same empty '<think>\n</think>\n\n' block.
# NOTE(review): exactly where each prefix is inserted is decided by
# QwenTemplateMeta / the base template, not visible here — confirm.
register_template(
QwenTemplateMeta(
MLLMTemplateType.mimo_v2,
template_cls=MiMoV2Template,
default_system='You are MiMo, a helpful AI assistant engineered by Xiaomi.',
is_thinking=True,
thinking_prefix='<think>\n',
non_thinking_prefix='<think>\n</think>\n\n',
history_thinking_prefix='<think>\n</think>\n\n',
))


class Qwen3VLTemplate(Qwen2VLTemplate):
version = 'v3'

Expand Down
Loading