diff --git a/docs/source/Instruction/Command-line-parameters.md b/docs/source/Instruction/Command-line-parameters.md index f0662778f7..7ad0db22aa 100644 --- a/docs/source/Instruction/Command-line-parameters.md +++ b/docs/source/Instruction/Command-line-parameters.md @@ -620,7 +620,11 @@ reward模型参数将在PPO、GRPO中使用。 - reward_model_plugin: 奖励模型逻辑,默认为orm逻辑, 详细见[自定义奖励模型](./GRPO/DeveloperGuide/reward_model.md#自定义奖励模型)。 - dataset_shuffle: 是否对dataset进行随机操作,默认为True。 - truncation_strategy: 用于处理输入长度超过 max_length 的样本,支持 delete 和 left 两种策略,分别表示删除该样本和从左侧裁剪。默认值为 left。若使用 delete 策略,被删除的超长样本或编码失败的样本将在原数据集中通过重采样进行替换。 -- loss_type: loss 归一化的类型,可选项为['grpo', 'bnpo', 'dr_grpo', 'dapo', 'cispo', 'sapo', 'real'], 默认为'grpo', 具体参考[文档](./GRPO/DeveloperGuide/loss_types.md) +- loss_type: loss 归一化的类型,可选项为['grpo', 'bnpo', 'dr_grpo', 'dapo', 'cispo', 'sapo', 'real', 'fipo'], 默认为'grpo', 具体参考[文档](./GRPO/DeveloperGuide/loss_types.md) +- fipo_decay_rate: FIPO Future-KL 折扣半衰参数,实际折扣为`2 ** (-1 / fipo_decay_rate)`,默认值为32.0。 +- fipo_clip_range: FIPO influence weight 裁剪范围,默认值为0.2;设置为None或0时不裁剪。 +- fipo_clip_high_only: 是否只将FIPO influence weight裁剪到`[1.0, 1.0 + fipo_clip_range]`,默认值为True。 +- fipo_safety_threshold: 当负advantage token的IS ratio超过该阈值时,将FIPO influence weight限制到`[0.8, 1.0]`,默认值为4.0。 - log_completions: 是否记录训练中的模型生成内容,搭配 `--report_to wandb/swanlab` 使用。默认为False。 - 提示:若没有设置`--report_to wandb/swanlab`,则会在checkpoint中创建`completions.jsonl`来存储生成内容。 - use_vllm: 是否使用 vLLM 作为 GRPO 生成的 infer_backend,默认为False。 diff --git a/docs/source/Instruction/GRPO/AdvancedResearch/FIPO.md b/docs/source/Instruction/GRPO/AdvancedResearch/FIPO.md new file mode 100644 index 0000000000..4d54b357a3 --- /dev/null +++ b/docs/source/Instruction/GRPO/AdvancedResearch/FIPO.md @@ -0,0 +1,52 @@ +# FIPO: Future-KL Influenced Policy Optimization + +[FIPO](https://arxiv.org/abs/2603.19835) 是一种面向长链推理的 value-free RL 方法。它保留 GRPO/DAPO 的整体训练框架,但改变 token 级策略更新的加权方式:不再让一个序列级 advantage 均匀作用到所有 token,而是用折扣累积的 Future-KL 信号判断“从当前 token 
开始的后续轨迹”整体是在被增强还是被削弱。 + +## 核心思想 + +GRPO/DAPO 中,每个 response 的 token 通常共享同一个序列级 advantage: + +$$ +\hat{A}_{i,t} = \hat{A}_{i} +$$ + +这种做法稳定且简单,但 credit assignment 粒度较粗。FIPO 引入当前策略与旧策略在每个 token 上的 log-prob shift: + +$$ +\Delta \log p_t = \log \pi_\theta(y_t \mid x, y_{<t}) - \log \pi_{\theta_{\text{old}}}(y_t \mid x, y_{<t}) +$$ + +如果 $\Delta \log p_t > 0$,说明当前训练正在提高该 token 的概率;如果小于 0,则说明该 token 正在被压低。FIPO 进一步从当前位置向后折扣累积该信号: + +$$ +\mathrm{FutureKL}_t = +\sum_{k=t}^{T}\gamma^{k-t} M_k \Delta \log p_k +$$ + +其中 $M_k$ 是 completion mask,$\gamma = 2^{-1 / \text{decay\_rate}}$。`decay_rate` 越大,越远的 future token 对当前位置的影响越强;`decay_rate` 越小,Future-KL 越偏局部。然后将 Future-KL 映射为 influence weight: + +$$ +f_t = \mathrm{clip}(\exp(\mathrm{FutureKL}_t), 1-\epsilon_f, 1+\epsilon_f) +$$ + +最终把原本的 advantage 改成 future-aware advantage: + +$$ +\tilde{A}_{i,t} = \hat{A}_{i} \cdot f_{i,t} +$$ + +## 参数 + +| 参数 | 类型 | 默认值 | 说明 | +|---------------------------|---------|--------|----------------------------------------------------------------------------------------------------------------| +| `--loss_type` | `str` | `grpo` | 设置为`fipo` 启用 FIPO loss | +| `--delta` | `float` | `None` | 启用后会同时用于 Future-KL 高 IS ratio token 过滤和主 loss 的 dual-clip 上限,应大于 `1 + epsilon_high`,对齐FIPO 32B训练脚本建议设置为 `10.0` | +| `--fipo_decay_rate` | `float` | `32.0` | Future-KL 折扣半衰参数,实际折扣为`2 ** (-1 / fipo_decay_rate)` | +| `--fipo_clip_range` | `float` | `0.2` | influence weight 裁剪范围;`0.2` 表示裁剪到 `[0.8, 1.2]`(`fipo_clip_high_only` 为 `true` 时为 `[1.0, 1.2]`) | +| `--fipo_clip_high_only` | `bool` | `true` | 若为`true`,权重只裁剪到 `[1.0, 1.0 + fipo_clip_range]`,更偏向放大正 Future-KL | +| `--fipo_safety_threshold` | `float` | `4.0` | 负 advantage 且 IS ratio 超过该阈值时,将 FIPO 权重限制到 `[0.8, 1.0]` 以避免过度惩罚 | + +## 训练示例 + +[swift](https://github.com/modelscope/ms-swift/tree/main/examples/train/grpo/internal/fipo.sh) diff --git a/docs/source/Instruction/GRPO/AdvancedResearch/index.rst b/docs/source/Instruction/GRPO/AdvancedResearch/index.rst index f270b14c09..69073d5e10 100644 --- a/docs/source/Instruction/GRPO/AdvancedResearch/index.rst +++ 
b/docs/source/Instruction/GRPO/AdvancedResearch/index.rst @@ -6,6 +6,7 @@ Advanced Research entropy_mask.md CISPO.md DAPO.md + FIPO.md deepeyes.md GSPO.md CHORD.md diff --git a/docs/source/Instruction/GRPO/DeveloperGuide/loss_types.md b/docs/source/Instruction/GRPO/DeveloperGuide/loss_types.md index a1bc9e1bbf..fed681229a 100644 --- a/docs/source/Instruction/GRPO/DeveloperGuide/loss_types.md +++ b/docs/source/Instruction/GRPO/DeveloperGuide/loss_types.md @@ -108,6 +108,20 @@ $$\mathcal{L}_{\text{DAPO}} = \frac{\sum_{i=1}^{N} \sum_{t=1}^{T_i} \mathcal{L}_ **归一化维度:** 全局token维度(跨所有进程的completion token总数) +## FIPO + +`--loss_type fipo` + +FIPO 在 DAPO/GRPO 的 clipped policy loss 上引入 Future-KL influence weight。每个 token 的序列级 advantage 会乘以从当前位置到后续 token 的折扣累积 KL 位移得到的权重: + +$$f_{i,t} = \text{clip}\left(\exp\left(\sum_{k=t}^{T_i} \gamma^{k-t} M_{i,k} \Delta \log p_{i,k}\right), 1-\epsilon_f, 1+\epsilon_f\right)$$ + +$$\mathcal{L}_{i,t}^{\text{FIPO}} = f_{i,t} \cdot \mathcal{L}_{i,t}$$ + +FIPO 的 influence weight 默认不参与梯度计算,并使用与 DAPO 相同的全局 token 归一化。 + +**归一化维度:** 全局 token 维度(所有进程的 completion token 总数) + ## SAPO `--loss_type sapo` diff --git a/docs/source_en/Instruction/Command-line-parameters.md b/docs/source_en/Instruction/Command-line-parameters.md index 98266d7596..c0e95bd215 100644 --- a/docs/source_en/Instruction/Command-line-parameters.md +++ b/docs/source_en/Instruction/Command-line-parameters.md @@ -635,7 +635,11 @@ The meanings of the following parameters can be referenced [here](https://huggin - reward_model_plugin: The logic for the reward model, which defaults to ORM logic. For more information, please refer to [Customized Reward Models](./GRPO/DeveloperGuide/reward_model.md#custom-reward-model). - dataset_shuffle: Whether to shuffle the dataset randomly. Default is True. - truncation_strategy: The method to handle inputs exceeding `max_length`. Supported values are `delete` and `left`, representing deletion and left-side truncation respectively. 
The default is `left`. With the delete strategy, over-long or encoding-failed samples are discarded, and new samples are resampled from the original dataset to maintain the intended batch size. -- loss_type: The type of loss normalization. Options are ['grpo', 'bnpo', 'dr_grpo', 'dapo', 'cispo', 'sapo', 'real'], default is 'grpo'. For details, refer to this [doc](./GRPO/DeveloperGuide/loss_types.md) +- loss_type: The type of loss normalization. Options are ['grpo', 'bnpo', 'dr_grpo', 'dapo', 'cispo', 'sapo', 'real', 'fipo'], default is 'grpo'. For details, refer to this [doc](./GRPO/DeveloperGuide/loss_types.md) +- fipo_decay_rate: Half-life parameter for FIPO Future-KL. The actual discount is `2 ** (-1 / fipo_decay_rate)`. Default is 32.0. +- fipo_clip_range: Clipping range for the FIPO influence weight. Default is 0.2; set to None or 0 to disable clipping. +- fipo_clip_high_only: Whether to clip the FIPO influence weight to `[1.0, 1.0 + fipo_clip_range]` only. Default is True. +- fipo_safety_threshold: Caps the FIPO influence weight to `[0.8, 1.0]` for negative-advantage tokens whose IS ratio exceeds this threshold. Default is 4.0. - log_completions: Whether to log the model-generated content during training, to be used in conjunction with `--report_to wandb/swanlab`, default is False. - Note: If `--report_to wandb/swanlab` is not set, a `completions.jsonl` will be created in the checkpoint to store the generated content. - use_vllm: Whether to use vLLM as the infer_backend for GRPO generation, default is False. diff --git a/docs/source_en/Instruction/GRPO/AdvancedResearch/FIPO.md b/docs/source_en/Instruction/GRPO/AdvancedResearch/FIPO.md new file mode 100644 index 0000000000..9469a9603d --- /dev/null +++ b/docs/source_en/Instruction/GRPO/AdvancedResearch/FIPO.md @@ -0,0 +1,53 @@ +# FIPO: Future-KL Influenced Policy Optimization + +[FIPO](https://arxiv.org/abs/2603.19835) is a value-free RL method for eliciting longer and deeper reasoning. 
It keeps the GRPO/DAPO training scaffold, but changes how token-level policy updates are weighted: instead of applying one sequence-level advantage uniformly to every token, FIPO uses a discounted Future-KL signal to estimate whether the future trajectory after each token is being reinforced or suppressed. + +## Core Idea + +In GRPO/DAPO, tokens in the same response usually share the same sequence-level advantage: + +$$ +\hat{A}_{i,t} = \hat{A}_{i} +$$ + +This is simple and stable, but the credit assignment is coarse. FIPO starts from the signed log-probability shift between the current policy and the old policy: + +$$ +\Delta \log p_t = \log \pi_\theta(y_t \mid x, y_{ Tuple[torch.Tensor, Dict[str, torch.Tensor]]: + """Compute FIPO token-level influence weight from discounted Future-KL.""" + future_kl_delta = log_ratio.masked_fill(~completion_mask, 0.0) + + if self.args.delta is not None: + delta = torch.as_tensor(self.args.delta, dtype=log_ratio.dtype, device=log_ratio.device) + high_ratio_mask = coef_1 > delta + future_kl_delta = torch.where(high_ratio_mask, torch.zeros_like(future_kl_delta), future_kl_delta) + + seq_len = future_kl_delta.shape[1] + future_kl = torch.zeros_like(future_kl_delta) + positions = torch.arange(seq_len, device=log_ratio.device).unsqueeze(1) + gamma = torch.as_tensor(self.fipo_gamma, dtype=log_ratio.dtype, device=log_ratio.device) + chunk_size = 128 + for chunk_start in range(0, seq_len, chunk_size): + chunk_end = min(seq_len, chunk_start + chunk_size) + chunk_positions = torch.arange(chunk_start, chunk_end, device=log_ratio.device).unsqueeze(0) + distance = chunk_positions - positions + future_mask = distance >= 0 + decay_block = torch.pow(gamma, distance.clamp(min=0)) * future_mask.to(log_ratio.dtype) + future_kl += torch.matmul(future_kl_delta[:, chunk_start:chunk_end], decay_block.t()) + future_kl = future_kl.masked_fill(~completion_mask, 0.0) + + influence_weight = torch.exp(future_kl) + + if self.fipo_clip_range: + high = 1 + 
self.fipo_clip_range + low = 1.0 if self.fipo_clip_high_only else 1 - self.fipo_clip_range + influence_weight = torch.clamp(influence_weight, min=low, max=high) + influence_weight = influence_weight.detach() + + safety_mask = torch.ones_like(completion_mask, dtype=torch.bool) + if self.fipo_safety_threshold is not None: + negative_advantage = advantages.unsqueeze(1) < 0 + high_is_ratio = coef_1 > self.fipo_safety_threshold + safety_mask = ~(negative_advantage & high_is_ratio) + influence_weight = torch.where(safety_mask, influence_weight, + torch.clamp(influence_weight, min=0.8, max=1.0)) + + metrics = { + 'future_kl': future_kl, + 'influence_weight': influence_weight, + 'safety_mask': safety_mask, + } + return influence_weight, metrics + @profiling_decorator def _generate_completions(self, batch): """ @@ -1256,6 +1309,7 @@ def loss_func(self, output_tensor: torch.Tensor, data: Dict[str, Any]): coef_1 = torch.exp(log_importance_weights) + fipo_metrics = None if self.loss_type == 'cispo': clamped_ratios = torch.clamp(coef_1, max=self.epsilon_high).detach() per_token_loss = -clamped_ratios * advantages.unsqueeze(1) * per_token_logps @@ -1265,7 +1319,10 @@ def loss_func(self, output_tensor: torch.Tensor, data: Dict[str, Any]): is_positive = advantages.unsqueeze(1) > 0 soft_gate = torch.where(is_positive, gate_pos, gate_neg) per_token_loss = -soft_gate * advantages.unsqueeze(1) - elif self.loss_type in ['grpo', 'bnpo', 'dr_grpo', 'dapo']: + elif self.loss_type in ['grpo', 'bnpo', 'dr_grpo', 'dapo', 'fipo']: + if self.loss_type == 'fipo': + fipo_weight, fipo_metrics = self._compute_fipo_influence(log_ratio, coef_1, advantages, completion_mask) + coef_2 = torch.clamp(coef_1, 1 - self.epsilon_low, 1 + self.epsilon_high) if self.args.delta is not None: coef_1 = torch.clamp(coef_1, max=self.args.delta) @@ -1273,6 +1330,8 @@ def loss_func(self, output_tensor: torch.Tensor, data: Dict[str, Any]): per_token_loss1 = coef_1 * advantages.unsqueeze(1) per_token_loss2 = coef_2 * 
advantages.unsqueeze(1) per_token_loss = -torch.min(per_token_loss1, per_token_loss2) + if self.loss_type == 'fipo': + per_token_loss = per_token_loss * fipo_weight elif self.loss_type == 'real': per_token_loss = torch.zeros_like(per_token_logps) else: @@ -1339,8 +1398,8 @@ def loss_func(self, output_tensor: torch.Tensor, data: Dict[str, Any]): loss = (per_token_loss * completion_mask).sum() / completion_mask.sum().clamp(min=1.0) elif self.loss_type == 'dr_grpo': loss = (per_token_loss * completion_mask).sum() / (micro_batch_size * self.max_completion_length) - elif self.loss_type in ['cispo', 'dapo']: - # CISPO and DAPO: Normalize by total completion tokens across all processes + elif self.loss_type in ['cispo', 'dapo', 'fipo']: + # CISPO, DAPO, and FIPO: Normalize by total completion tokens across all processes num_items_in_batch = data['num_items_in_batch'] dp_size = mpu.get_data_parallel_world_size() normalizer = num_items_in_batch / dp_size @@ -1409,6 +1468,14 @@ def loss_func(self, output_tensor: torch.Tensor, data: Dict[str, Any]): # Compute clipping metrics completion_token_count = completion_mask.sum().clamp(min=1.0) + if fipo_metrics is not None: + avg_metric['fipo/future_kl_mean'] = ((fipo_metrics['future_kl'] * completion_mask).sum() + / completion_token_count).clone().detach() + avg_metric['fipo/influence_weight_mean'] = ((fipo_metrics['influence_weight'] * completion_mask).sum() + / completion_token_count).clone().detach() + avg_metric['fipo/safety_keep_ratio'] = ((fipo_metrics['safety_mask'].float() * completion_mask).sum() + / completion_token_count).clone().detach() + if self.loss_type == 'cispo': # CISPO: Only track upper bound clipping # coef_1 is [batch_size, max_seq_len] or [batch_size, 1] depending on importance_sampling_level @@ -1419,7 +1486,7 @@ def loss_func(self, output_tensor: torch.Tensor, data: Dict[str, Any]): elif self.loss_type in ['sapo', 'real']: # SAPO / REAL: No hard clipping, skip clipping metrics pass - elif self.loss_type in 
['grpo', 'bnpo', 'dr_grpo', 'dapo']: + elif self.loss_type in ['grpo', 'bnpo', 'dr_grpo', 'dapo', 'fipo']: # coef_1 is [batch_size, max_seq_len] or [batch_size, 1] depending on importance_sampling_level # Use exp(log_importance_weights) to get the original ratios before clamping coef_1_for_metrics = torch.exp(log_importance_weights) diff --git a/swift/rlhf_trainers/args_mixin.py b/swift/rlhf_trainers/args_mixin.py index d5aa8f1bc8..52e5b0e0e8 100644 --- a/swift/rlhf_trainers/args_mixin.py +++ b/swift/rlhf_trainers/args_mixin.py @@ -322,6 +322,14 @@ class GRPOArgumentsMixin(RolloutTrainerArgumentsMixin): constraints on negative dominance. The default value is 1.05. real_tau (float): The temperature parameter. REAL induces monotonic and bounded gradient weighting with magnitude upper-bounded by 1/tau. The default value is 0.5. + fipo_decay_rate (float): Half-life used to derive `fipo_gamma`. Defaults to 32.0. + fipo_clip_range (Optional[float]): Clip range for the FIPO influence weight. `0.2` clips to + `[0.8, 1.2]`; `None` or `0` disables clipping. Defaults to 0.2. + fipo_clip_high_only (bool): If `True`, clips the FIPO influence weight to `[1, 1 + fipo_clip_range]`. + Defaults to True. + fipo_safety_threshold (Optional[float]): Safety threshold for negative advantages. Tokens with + `advantage < 0` and importance ratio above this value have their FIPO influence weight capped to + `[0.8, 1.0]` to avoid over-penalization. Defaults to 4.0. advantage_estimator (Literal['grpo', 'rloo', 'reinforce_plus_plus']): The advantage estimation function to use. 'grpo' calculates the relative advantage within a group. Options are 'grpo', 'rloo', 'reinforce_plus_plus'. Defaults to 'grpo'. 
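Review note: the interaction between `fipo_clip_range`, `fipo_clip_high_only` and `fipo_safety_threshold` documented in the docstring above is easiest to see on a single token. The following is a minimal pure-Python sketch of those semantics, not the trainer code: `fipo_weight` is a hypothetical helper, it works on scalars instead of tensors, and it assumes the clamp order used in `_compute_fipo_influence` (range clip first, safety cap second).

```python
import math


def fipo_weight(future_kl, advantage, is_ratio,
                clip_range=0.2, clip_high_only=True, safety_threshold=4.0):
    """Scalar illustration of the FIPO influence weight (hypothetical helper)."""
    weight = math.exp(future_kl)

    # A clip_range of None or 0 disables clipping, mirroring `if self.fipo_clip_range:`.
    if clip_range:
        low = 1.0 if clip_high_only else 1.0 - clip_range
        high = 1.0 + clip_range
        weight = min(max(weight, low), high)

    # Negative-advantage tokens whose IS ratio exceeds the threshold
    # are additionally capped to [0.8, 1.0].
    if safety_threshold is not None and advantage < 0 and is_ratio > safety_threshold:
        weight = min(max(weight, 0.8), 1.0)

    return weight
```

With the defaults, a strongly positive Future-KL saturates at `1.2`, a negative Future-KL is floored at `1.0` (or `0.8` once `clip_high_only=False`), and a negative-advantage token with an IS ratio above `4.0` never receives a weight above `1.0`.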
@@ -414,6 +422,12 @@ class GRPOArgumentsMixin(RolloutTrainerArgumentsMixin): # REAL https://arxiv.org/abs/2602.05630 real_tau: float = 0.5 + # FIPO https://arxiv.org/abs/2603.19835 + fipo_decay_rate: float = 32.0 + fipo_clip_range: Optional[float] = 0.2 + fipo_clip_high_only: bool = True + fipo_safety_threshold: Optional[float] = 4.0 + num_generations_eval: Optional[int] = None # dataset diff --git a/swift/rlhf_trainers/grpo_trainer.py b/swift/rlhf_trainers/grpo_trainer.py index 76da8e452b..c81008f0e3 100644 --- a/swift/rlhf_trainers/grpo_trainer.py +++ b/swift/rlhf_trainers/grpo_trainer.py @@ -1025,6 +1025,55 @@ def _compute_loss_single(self, model, inputs): self._update_metrics(metrics_data) return loss + def _compute_fipo_influence(self, log_ratio: torch.Tensor, coef_1: torch.Tensor, advantages: torch.Tensor, + completion_mask: torch.Tensor) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: + """Compute FIPO token-level influence weight from Future-KL divergence.""" + future_kl_delta = log_ratio.masked_fill(~completion_mask, 0.0) + + # Dual-Clip participation mask: high-ratio tokens do not contribute to Future-KL. 
+ if self.args.delta is not None: + delta = torch.as_tensor(self.args.delta, dtype=log_ratio.dtype, device=log_ratio.device) + high_ratio_mask = coef_1 > delta + future_kl_delta = torch.where(high_ratio_mask, torch.zeros_like(future_kl_delta), future_kl_delta) + + seq_len = future_kl_delta.shape[1] + future_kl = torch.zeros_like(future_kl_delta) + positions = torch.arange(seq_len, device=log_ratio.device).unsqueeze(1) + gamma = torch.as_tensor(self.fipo_gamma, dtype=log_ratio.dtype, device=log_ratio.device) + chunk_size = 128 + for chunk_start in range(0, seq_len, chunk_size): + chunk_end = min(seq_len, chunk_start + chunk_size) + chunk_positions = torch.arange(chunk_start, chunk_end, device=log_ratio.device).unsqueeze(0) + distance = chunk_positions - positions + future_mask = distance >= 0 + decay_block = torch.pow(gamma, distance.clamp(min=0)) * future_mask.to(log_ratio.dtype) + future_kl += torch.matmul(future_kl_delta[:, chunk_start:chunk_end], decay_block.t()) + future_kl = future_kl.masked_fill(~completion_mask, 0.0) + + influence_weight = torch.exp(future_kl) + + if self.fipo_clip_range: + high = 1 + self.fipo_clip_range + low = 1.0 if self.fipo_clip_high_only else 1 - self.fipo_clip_range + influence_weight = torch.clamp(influence_weight, min=low, max=high) + influence_weight = influence_weight.detach() + + # avoid amplifying negative-advantage tokens with very high IS ratios. 
+ safety_mask = torch.ones_like(completion_mask, dtype=torch.bool) + if self.fipo_safety_threshold is not None: + negative_advantage = advantages.unsqueeze(1) < 0 + high_is_ratio = coef_1 > self.fipo_safety_threshold + safety_mask = ~(negative_advantage & high_is_ratio) + influence_weight = torch.where(safety_mask, influence_weight, + torch.clamp(influence_weight, min=0.8, max=1.0)) + + metrics = { + 'future_kl': future_kl, + 'influence_weight': influence_weight, + 'safety_mask': safety_mask, + } + return influence_weight, metrics + def _compute_loss_and_metrics(self, model, inputs): """Core loss computation without metrics recording.""" mode = 'train' if self.model.training else 'eval' @@ -1126,6 +1175,7 @@ def _compute_loss_and_metrics(self, model, inputs): coef_1 = torch.exp(log_importance_weights) + fipo_metrics = None if self.loss_type == 'cispo': clamped_ratios = torch.clamp(coef_1, max=self.epsilon_high).detach() per_token_loss = -clamped_ratios * advantages.unsqueeze(1) * per_token_logps @@ -1139,7 +1189,10 @@ def _compute_loss_and_metrics(self, model, inputs): per_token_loss = -soft_gate * advantages_expanded elif self.loss_type == 'real': per_token_loss = torch.zeros_like(per_token_logps) - elif self.loss_type in ['grpo', 'bnpo', 'dr_grpo', 'dapo']: + elif self.loss_type in ['grpo', 'bnpo', 'dr_grpo', 'dapo', 'fipo']: + if self.loss_type == 'fipo': + fipo_weight, fipo_metrics = self._compute_fipo_influence(log_ratio, coef_1, advantages, completion_mask) + coef_2 = torch.clamp(coef_1, 1 - self.epsilon_low, 1 + self.epsilon_high) if self.args.delta is not None: coef_1 = torch.clamp(coef_1, max=self.args.delta) @@ -1147,6 +1200,8 @@ def _compute_loss_and_metrics(self, model, inputs): per_token_loss1 = coef_1 * advantages.unsqueeze(1) per_token_loss2 = coef_2 * advantages.unsqueeze(1) per_token_loss = -torch.min(per_token_loss1, per_token_loss2) + if self.loss_type == 'fipo': + per_token_loss = per_token_loss * fipo_weight if entropy_mask is not None: 
per_token_loss = per_token_loss * entropy_mask if per_token_kl is not None: @@ -1210,8 +1265,8 @@ def _compute_loss_and_metrics(self, model, inputs): if self.beta != 0.0: kl_loss = (per_token_kl * completion_mask).sum() / completion_mask.sum().clamp(min=1.0) loss = loss + kl_loss * self.beta - elif self.loss_type in ['cispo', 'dapo']: - # CISPO and DAPO: Normalize by total completion tokens across all processes + elif self.loss_type in ['cispo', 'dapo', 'fipo']: + # CISPO, DAPO, and FIPO: Normalize by total completion tokens across all processes normalizer = inputs['num_items_in_batch'] / self.accelerator.num_processes loss = (per_token_loss * completion_mask).sum() / normalizer else: @@ -1234,6 +1289,16 @@ def masked_batch_mean(x): 'completion_token_count': completion_token_count, } + if fipo_metrics is not None: + fipo_future_kl = masked_batch_mean(fipo_metrics['future_kl']) + fipo_influence_weight = masked_batch_mean(fipo_metrics['influence_weight']) + fipo_safety_keep = masked_batch_mean(fipo_metrics['safety_mask'].float()) + metrics_data['fipo'] = { + 'future_kl_mean': self.accelerator.gather_for_metrics(fipo_future_kl).nanmean().item(), + 'influence_weight_mean': self.accelerator.gather_for_metrics(fipo_influence_weight).nanmean().item(), + 'safety_keep_ratio': self.accelerator.gather_for_metrics(fipo_safety_keep).nanmean().item(), + } + if per_token_kl is not None: mean_kl = masked_batch_mean(per_token_kl) metrics_data['kl'] = self.accelerator.gather_for_metrics(mean_kl).nanmean().item() @@ -1301,6 +1366,11 @@ def _update_metrics(self, metrics_data): for key, value in rollout_metrics.items(): self._metrics[mode][f'rollout_correction/{key}'].append(value) + # Update FIPO metrics + if 'fipo' in metrics_data: + for key, value in metrics_data['fipo'].items(): + self._metrics[mode][f'fipo/{key}'].append(value) + # Update clipping metrics if 'clipping' in metrics_data: clipping = metrics_data['clipping'] @@ -1382,6 +1452,7 @@ def 
_aggregate_and_update_metrics(self, all_metrics_data, mode): clip_values = {'low': [], 'high': [], 'region': [], 'low_min': [], 'high_max': []} cispo_clip_values = [] entropy_thresholds = [] + fipo_values = {} for chunk_metrics, chunk_weight in all_metrics_data: chunk_tokens = chunk_metrics['completion_token_count'] @@ -1403,6 +1474,12 @@ def _aggregate_and_update_metrics(self, all_metrics_data, mode): if 'kl' in chunk_metrics: kl_values.append(chunk_metrics['kl']) + # Collect FIPO metrics (weighted by tokens) + if 'fipo' in chunk_metrics: + weight = chunk_tokens.item() if hasattr(chunk_tokens, 'item') else chunk_tokens + for key, value in chunk_metrics['fipo'].items(): + fipo_values.setdefault(key, []).append((value, weight)) + # Collect clipping metrics (weighted by tokens) if 'clipping' in chunk_metrics: clipping = chunk_metrics['clipping'] @@ -1452,6 +1529,9 @@ def weighted_avg(values): 'region_clip_mean': weighted_avg(clip_values['region']) } + if fipo_values: + aggregated_metrics['fipo'] = {key: weighted_avg(values) for key, values in fipo_values.items()} + # Update metrics self._update_metrics(aggregated_metrics) @@ -2174,6 +2254,12 @@ def _prepare_algorithm_params(self): # REAL, https://arxiv.org/abs/2602.05630 self.real_tau = args.real_tau + # FIPO, https://arxiv.org/abs/2603.19835 + self.fipo_gamma = 2**(-1 / args.fipo_decay_rate) + self.fipo_clip_range = args.fipo_clip_range + self.fipo_clip_high_only = args.fipo_clip_high_only + self.fipo_safety_threshold = args.fipo_safety_threshold + # RLOO, self.advantage_estimator = args.advantage_estimator self.kl_in_reward = args.kl_in_reward
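Review note: `fipo_gamma = 2**(-1 / args.fipo_decay_rate)` makes `fipo_decay_rate` a half-life measured in tokens, so a log-ratio shift `fipo_decay_rate` tokens ahead enters the sum with weight 0.5. The discounted sum that `_compute_fipo_influence` builds with chunked matmuls also satisfies the right-to-left recurrence `FutureKL_t = d_t + gamma * FutureKL_{t+1}`. Below is a minimal pure-Python sketch for one sequence that could serve as a reference when unit-testing the tensor version. `future_kl_reference` and `future_kl_bruteforce` are hypothetical names, and the `delta` high-ratio filter is left out for brevity.

```python
def future_kl_reference(log_ratio, completion_mask, decay_rate=32.0):
    """Discounted Future-KL for one sequence via a right-to-left recurrence."""
    gamma = 2.0 ** (-1.0 / decay_rate)
    # Positions outside the completion contribute nothing to the sum.
    deltas = [lr if m else 0.0 for lr, m in zip(log_ratio, completion_mask)]
    out = [0.0] * len(deltas)
    running = 0.0
    for t in range(len(deltas) - 1, -1, -1):
        running = deltas[t] + gamma * running
        # The result is also zeroed outside the completion, as in the trainer.
        out[t] = running if completion_mask[t] else 0.0
    return out


def future_kl_bruteforce(log_ratio, completion_mask, decay_rate=32.0):
    """Direct evaluation of sum_{k>=t} gamma^(k-t) * M_k * log_ratio_k."""
    gamma = 2.0 ** (-1.0 / decay_rate)
    n = len(log_ratio)
    return [
        sum(gamma ** (k - t) * log_ratio[k] for k in range(t, n) if completion_mask[k])
        if completion_mask[t] else 0.0
        for t in range(n)
    ]
```

The recurrence is O(T) per sequence but sequential in `t`; the chunked matmul in the patch does more arithmetic in exchange for staying vectorized over positions.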