53 changes: 53 additions & 0 deletions docs/source/Instruction/GRPO/AdvancedResearch/FIPO.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
# FIPO: Future-KL Influenced Policy Optimization

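[FIPO](https://arxiv.org/abs/2603.19835) is a value-free RL method for long-chain reasoning. It keeps the overall GRPO/DAPO training framework but changes how token-level policy updates are weighted: instead of applying one sequence-level advantage uniformly to every token, it uses a discounted, accumulated Future-KL signal to judge whether the trajectory that follows the current token is, as a whole, being reinforced or suppressed.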

## Core Idea

In GRPO/DAPO, the tokens of each response usually share the same sequence-level advantage:

$$
\hat{A}_{i,t} = \hat{A}_{i}
$$

This approach is stable and simple, but its credit assignment is coarse-grained. FIPO introduces the per-token log-prob shift between the current policy and the old policy:

$$
\Delta \log p_t = \log \pi_\theta(y_t \mid x, y_{<t}) -
\log \pi_{\mathrm{old}}(y_t \mid x, y_{<t})
$$

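If $\Delta \log p_t > 0$, the current training step is raising the probability of that token; if it is less than 0, the token is being suppressed. FIPO then accumulates this signal from the current position onward, with a discount: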

$$
\mathrm{FutureKL}_t =
\sum_{k=t}^{T}\gamma^{k-t} M_k \Delta \log p_k
$$

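where $M_k$ is the completion mask and $\gamma = 2^{-1 / \text{decay\_rate}}$. The larger `decay_rate` is, the more strongly distant future tokens influence the current position; the smaller it is, the more local Future-KL becomes. Future-KL is then mapped to an influence weight: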

$$
f_t = \mathrm{clip}(\exp(\mathrm{FutureKL}_t), 1-\epsilon_f, 1+\epsilon_f)
$$

Finally, the original advantage is replaced by a future-aware advantage:

$$
\tilde{A}_{i,t} = \hat{A}_{i} \cdot f_{i,t}
$$
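
The formulas above can be traced by hand on a toy response. The following is a minimal pure-Python sketch for a single sequence, not the batched tensor code used by ms-swift; the helper names `future_kl` and `influence_weights` are hypothetical, the symmetric clip `[1 - clip_range, 1 + clip_range]` is assumed (the high-only option is not modeled), and `decay_rate=1.0` is chosen only so that $\gamma = 0.5$ keeps the arithmetic easy to follow.

```python
import math


def future_kl(delta_logp, mask, decay_rate=32.0):
    """Discounted suffix sum of masked log-prob shifts for one response."""
    gamma = 2.0 ** (-1.0 / decay_rate)
    out = [0.0] * len(delta_logp)
    running = 0.0
    # Walk backwards so each position reuses the sum already built for its successor.
    for t in range(len(delta_logp) - 1, -1, -1):
        running = (delta_logp[t] if mask[t] else 0.0) + gamma * running
        out[t] = running
    return out


def influence_weights(fkl, clip_range=0.2):
    """Map Future-KL to exp(FutureKL), clipped to [1 - clip_range, 1 + clip_range]."""
    low, high = 1.0 - clip_range, 1.0 + clip_range
    return [min(max(math.exp(v), low), high) for v in fkl]


# Toy inputs (illustrative values, not measured data).
delta_logp = [0.10, -0.05, 0.30, -0.40]
mask = [1, 1, 1, 1]

fkl = future_kl(delta_logp, mask, decay_rate=1.0)  # gamma = 0.5
weights = influence_weights(fkl)

advantage = 1.5  # one sequence-level advantage shared by all tokens
token_advantages = [advantage * w for w in weights]
```

The last token has a strongly negative Future-KL, so its weight is clipped at the lower bound, while the first token inherits a positive discounted sum and is weighted above 1.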

## Parameters

| Parameter | Type | Default | Description |
|---------------------------|---------|---------|-------------|
| `--loss_type` | `str` | `grpo` | Set to `fipo` to enable the FIPO loss |
| `--delta` | `float` | `None` | When set, it is used both to filter high-IS-ratio tokens out of Future-KL and as the dual-clip upper bound of the main loss; it should be greater than `1 + epsilon_high`. To align with the FIPO 32B training script, `10.0` is recommended |
| `--fipo_decay_rate` | `float` | `32.0` | Half-life parameter of the Future-KL discount; the actual discount is `2 ** (-1 / fipo_decay_rate)` |
| `--fipo_clip_range` | `float` | `0.2` | Clipping range of the influence weight; `0.2` clips to `[0.8, 1.2]` when `--fipo_clip_high_only` is `false` |
| `--fipo_clip_high_only` | `bool` | `true` | If `true`, the weight is clipped only to `[1.0, 1.0 + fipo_clip_range]`, which favors amplifying positive Future-KL |
| `--fipo_detach_weight` | `bool` | `true` | Whether to stop gradients through the influence weight |
| `--fipo_safety_threshold` | `float` | `4.0` | When the advantage is negative and the IS ratio exceeds this threshold, the FIPO weight is limited to `[0.8, 1.0]` to avoid over-penalization |

## Training Example

[swift](https://github.com/modelscope/ms-swift/tree/main/examples/train/grpo/internal/fipo.sh)
1 change: 1 addition & 0 deletions docs/source/Instruction/GRPO/AdvancedResearch/index.rst
@@ -6,6 +6,7 @@ Advanced Research
entropy_mask.md
CISPO.md
DAPO.md
FIPO.md
deepeyes.md
GSPO.md
CHORD.md
54 changes: 54 additions & 0 deletions docs/source_en/Instruction/GRPO/AdvancedResearch/FIPO.md
@@ -0,0 +1,54 @@
# FIPO: Future-KL Influenced Policy Optimization

[FIPO](https://arxiv.org/abs/2603.19835) is a value-free RL method for eliciting longer and deeper reasoning. It keeps the GRPO/DAPO training scaffold, but changes how token-level policy updates are weighted: instead of applying one sequence-level advantage uniformly to every token, FIPO uses a discounted Future-KL signal to estimate whether the future trajectory after each token is being reinforced or suppressed.

## Core Idea

In GRPO/DAPO, tokens in the same response usually share the same sequence-level advantage:

$$
\hat{A}_{i,t} = \hat{A}_{i}
$$

This is simple and stable, but the credit assignment is coarse. FIPO starts from the signed log-probability shift between the current policy and the old policy:

$$
\Delta \log p_t = \log \pi_\theta(y_t \mid x, y_{<t}) -
\log \pi_{\mathrm{old}}(y_t \mid x, y_{<t})
$$

A positive value means the token probability is being increased by the current update, while a negative value means it is being suppressed. FIPO then accumulates this signal, with a discount, from the current token to the end of the response:

$$
\mathrm{FutureKL}_t =
\sum_{k=t}^{T}\gamma^{k-t} M_k \Delta \log p_k
$$

where $M_k$ is the completion mask and $\gamma = 2^{-1 / \text{decay\_rate}}$. A larger `decay_rate` gives farther future tokens more influence; a smaller value makes the signal more local. FIPO maps the Future-KL value into a bounded influence weight:

$$
f_t = \mathrm{clip}(\exp(\mathrm{FutureKL}_t), 1-\epsilon_f, 1+\epsilon_f)
$$

The original advantage is then replaced by a future-aware advantage:

$$
\tilde{A}_{i,t} = \hat{A}_{i} \cdot f_{i,t}
$$
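
Two identities follow directly from the definitions above and may help when reading the formulas. They are written out here as a sketch; the boundary value $\mathrm{FutureKL}_{T+1} = 0$ is an assumed convention rather than something stated in this document. First, the discounted sum can be evaluated by a single backward pass:

$$
\mathrm{FutureKL}_t = M_t \Delta \log p_t + \gamma \, \mathrm{FutureKL}_{t+1},
\qquad \mathrm{FutureKL}_{T+1} = 0
$$

Second, `decay_rate` acts as a half-life measured in tokens:

$$
\gamma^{\text{decay\_rate}} = \left(2^{-1 / \text{decay\_rate}}\right)^{\text{decay\_rate}} = \tfrac{1}{2}
$$

so a token `decay_rate` positions ahead contributes with half the weight of the current token. With the default `decay_rate = 32`, $\gamma \approx 0.9786$.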

## Parameters


> Review comment (Contributor, medium): This extra blank line should be removed to maintain consistent formatting with the Chinese version of this document.

| Parameter | Type | Default | Description |
| ------------------------- | ------- | ------- | ----------- |
| `--loss_type` | `str` | `grpo` | Set to `fipo` to enable the FIPO loss |
| `--delta` | `float` | `None` | When set, it is used both to filter high-IS-ratio tokens out of Future-KL and as the dual-clip upper bound of the main loss; it should be greater than `1 + epsilon_high`. Set it to `10.0` to match the official 32B script |
| `--fipo_decay_rate` | `float` | `32.0` | Half-life parameter of the Future-KL discount; the actual discount is `2 ** (-1 / fipo_decay_rate)` |
| `--fipo_clip_range` | `float` | `0.2` | Influence weight clipping range; `0.2` clips to `[0.8, 1.2]` when `--fipo_clip_high_only` is `false` |
| `--fipo_clip_high_only` | `bool` | `true` | If `true`, clips the weight to `[1.0, 1.0 + fipo_clip_range]` |
| `--fipo_detach_weight` | `bool` | `true` | Whether to stop gradients through the influence weight |
| `--fipo_safety_threshold` | `float` | `4.0` | Caps the FIPO weight to `[0.8, 1.0]` for negative-advantage tokens whose IS ratio exceeds this threshold |
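
How the three clipping-related options in the table combine is easiest to see on scalars. The sketch below is an illustration under stated assumptions, not ms-swift code: `fipo_weight_bounds` and `apply_safety_cap` are hypothetical helper names, and the logic simply restates the table rows (a clip range of `None` or `0` is treated as "clipping disabled").

```python
def fipo_weight_bounds(clip_range=0.2, clip_high_only=True):
    """Return (low, high) for the influence weight, or None when clipping is off."""
    if not clip_range:  # None or 0 disables clipping
        return None
    high = 1.0 + clip_range
    low = 1.0 if clip_high_only else 1.0 - clip_range
    return low, high


def apply_safety_cap(weight, advantage, is_ratio, safety_threshold=4.0):
    """Limit the weight to [0.8, 1.0] for negative-advantage, high-IS-ratio tokens."""
    if safety_threshold is not None and advantage < 0 and is_ratio > safety_threshold:
        return min(max(weight, 0.8), 1.0)
    return weight


default_bounds = fipo_weight_bounds()                              # high-only clipping
symmetric_bounds = fipo_weight_bounds(0.2, clip_high_only=False)   # two-sided clipping
capped = apply_safety_cap(1.2, advantage=-0.5, is_ratio=5.0)       # safety cap applies
untouched = apply_safety_cap(1.2, advantage=0.5, is_ratio=5.0)     # positive advantage
```

With the defaults, the weight can only amplify an update (it never drops below 1.0), and the safety cap pulls that amplification back to at most 1.0 for tokens that are already being penalized under a large importance ratio.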

## Training Example

[swift](https://github.com/modelscope/ms-swift/tree/main/examples/train/grpo/internal/fipo.sh)
1 change: 1 addition & 0 deletions docs/source_en/Instruction/GRPO/AdvancedResearch/index.rst
@@ -6,6 +6,7 @@ Advanced Research
entropy_mask.md
CISPO.md
DAPO.md
FIPO.md
deepeyes.md
GSPO.md
CHORD.md
47 changes: 47 additions & 0 deletions examples/train/grpo/internal/fipo.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
CUDA_VISIBLE_DEVICES=2 \
swift rollout \
--model Qwen/Qwen2.5-1.5B-Instruct

# 2 GPUS for sequence parallel
NPROC_PER_NODE=2 \
CUDA_VISIBLE_DEVICES=0,1 \
swift rlhf \
--rlhf_type grpo \
--model Qwen/Qwen2.5-1.5B-Instruct \
--dataset 'AI-MO/NuminaMath-TIR' \
--reward_funcs accuracy \
--use_vllm true \
--vllm_mode server \
--vllm_server_host 127.0.0.1 \
--vllm_server_port 8000 \
--tuner_type full \
--torch_dtype bfloat16 \
--load_from_cache_file true \
--max_completion_length 4096 \
--num_train_epochs 1 \
--per_device_train_batch_size 8 \
--learning_rate 1e-6 \
--gradient_accumulation_steps 2 \
--save_total_limit 3 \
--save_steps 500 \
--logging_steps 1 \
--warmup_ratio 0.05 \
--dataloader_num_workers 8 \
--num_generations 8 \
--temperature 1.0 \
--system """You are a helpful math assistant. Solve the problem step by step and put your final answer within \\boxed{}.""" \
--log_completions true \
--num_iterations 3 \
--padding_free true \
--sequence_parallel_size 2 \
--attn_impl flash_attn \
--beta 0 \
--dynamic_sample true \
--loss_type fipo \
--delta 10.0 \
--epsilon_high 0.28 \
--fipo_decay_rate 32 \
--fipo_clip_range 0.2 \
--fipo_clip_high_only true \
--fipo_detach_weight true \
--fipo_safety_threshold 10.0
7 changes: 7 additions & 0 deletions swift/megatron/arguments/megatron_args.py
@@ -79,6 +79,13 @@ class RLHFMegatronArgumentsMixin:
# REAL https://arxiv.org/abs/2602.05630
real_tau: float = 0.5

# FIPO https://arxiv.org/abs/2603.19835
fipo_decay_rate: float = 32.0
fipo_clip_range: Optional[float] = 0.2
fipo_clip_high_only: bool = True
fipo_detach_weight: bool = True
fipo_safety_threshold: Optional[float] = 4.0

epsilon: float = 0.2
epsilon_high: Optional[float] = None
delta: Optional[float] = None
74 changes: 68 additions & 6 deletions swift/megatron/trainers/grpo_trainer.py
@@ -93,6 +93,13 @@ def _init_grpo_params(self):
# REAL, https://arxiv.org/abs/2602.05630
self.real_tau = args.real_tau

# FIPO, https://arxiv.org/abs/2603.19835
self.fipo_gamma = 2**(-1 / args.fipo_decay_rate)
self.fipo_clip_range = args.fipo_clip_range
self.fipo_clip_high_only = args.fipo_clip_high_only
self.fipo_detach_weight = args.fipo_detach_weight
self.fipo_safety_threshold = args.fipo_safety_threshold

# DAPO, https://arxiv.org/abs/2503.14476
self.dynamic_sample = args.dynamic_sample
self.max_resample_times = args.max_resample_times
@@ -467,7 +474,7 @@ def _generate_and_score_completions(self, batch):
micro_batch_advantages = total_advantages[start_idx:end_idx]
micro_batch_encoded['advantages'] = micro_batch_advantages

-if self.loss_type in ['cispo', 'dapo']:
+if self.loss_type in ['cispo', 'dapo', 'fipo']:
# Calculate num_items_in_batch
# Count completion tokens from all mini_batch_data (this includes gathered data from rollout_group)
# Use completion_mask.sum() for both padding_free and non-padding_free modes
@@ -484,12 +491,53 @@ def _generate_and_score_completions(self, batch):
mpu.get_tensor_model_parallel_world_size() * mpu.get_pipeline_model_parallel_world_size()
* mpu.get_context_parallel_world_size())
num_items_in_batch = total_token_count_tensor / rollout_group_size
-# Store num_items_in_batch in each mini_batch_data for CISPO/DAPO loss normalization
+# Store num_items_in_batch in each mini_batch_data for token-normalized losses
for batch_data in mini_batch_data:
batch_data['num_items_in_batch'] = num_items_in_batch

return mini_batch_data

def _compute_fipo_influence(self, log_ratio: torch.Tensor, coef_1: torch.Tensor, advantages: torch.Tensor,
                            completion_mask: torch.Tensor) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
    """Compute FIPO token-level influence weight from discounted Future-KL."""
    future_kl_delta = log_ratio.masked_fill(~completion_mask, 0.0)

    # Tokens whose IS ratio exceeds delta contribute nothing to Future-KL.
    if self.args.delta is not None:
        delta = torch.as_tensor(self.args.delta, dtype=log_ratio.dtype, device=log_ratio.device)
        high_ratio_mask = coef_1 > delta
        future_kl_delta = torch.where(high_ratio_mask, torch.zeros_like(future_kl_delta), future_kl_delta)

    # Backward scan: `running` holds the discounted sum over later completion tokens.
    future_kl = torch.zeros_like(future_kl_delta)
    running = torch.zeros_like(future_kl_delta[:, 0])
    gamma = torch.as_tensor(self.fipo_gamma, dtype=log_ratio.dtype, device=log_ratio.device)
    for token_idx in range(future_kl_delta.shape[1] - 1, -1, -1):
        valid_token = completion_mask[:, token_idx]
        running = torch.where(valid_token, future_kl_delta[:, token_idx] + gamma * running, running)
        future_kl[:, token_idx] = torch.where(valid_token, running, 0.0)

    influence_source = future_kl.detach() if self.fipo_detach_weight else future_kl
    influence_weight = torch.exp(influence_source)

    # A clip range of None or 0 disables clipping.
    if self.fipo_clip_range:
        high = 1 + self.fipo_clip_range
        low = 1.0 if self.fipo_clip_high_only else 1 - self.fipo_clip_range
        influence_weight = torch.clamp(influence_weight, min=low, max=high)

    # Negative-advantage tokens with a high IS ratio get their weight capped to [0.8, 1.0].
    safety_mask = torch.ones_like(completion_mask, dtype=torch.bool)
    if self.fipo_safety_threshold is not None:
        negative_advantage = advantages.unsqueeze(1) < 0
        high_is_ratio = coef_1 > self.fipo_safety_threshold
        safety_mask = ~(negative_advantage & high_is_ratio)
        influence_weight = torch.where(safety_mask, influence_weight,
                                       torch.clamp(influence_weight, min=0.8, max=1.0))

    metrics = {
        'future_kl': future_kl,
        'influence_weight': influence_weight,
        'safety_mask': safety_mask,
    }
    return influence_weight, metrics
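
For reviewers who want to check the backward scan by hand, the following is a pure-Python mirror of the loop above for a single row. It is a sketch under assumptions, not part of the change: `future_kl_row` is a hypothetical name, and the IS ratio is assumed to be token-level, i.e. `exp(log_ratio[t])`, whereas the trainer receives `coef_1` as a separate tensor. Note that masked positions are skipped without applying the discount, exactly as the `torch.where(valid_token, ..., running)` line does.

```python
import math


def future_kl_row(log_ratio, mask, gamma, delta=None):
    """Pure-Python mirror of the backward Future-KL scan for a single row."""
    out = [0.0] * len(log_ratio)
    running = 0.0
    for t in range(len(log_ratio) - 1, -1, -1):
        if not mask[t]:
            continue  # masked position: output stays 0, running sum is carried unchanged
        value = log_ratio[t]
        if delta is not None and math.exp(value) > delta:
            value = 0.0  # high-IS-ratio token contributes nothing to Future-KL
        running = value + gamma * running
        out[t] = running
    return out


# Toy row: the second token has an IS ratio of exp(3.0), roughly 20, above delta = 10.
row = [0.1, 3.0, -0.2, 0.0]
row_mask = [True, True, True, False]

filtered = future_kl_row(row, row_mask, gamma=0.5, delta=10.0)
unfiltered = future_kl_row(row, row_mask, gamma=0.5, delta=None)
```

Comparing `filtered` with `unfiltered` shows why the `delta` filter matters: without it, the single outlier token dominates the Future-KL of every earlier position.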

@profiling_decorator
def _generate_completions(self, batch):
"""
@@ -1256,6 +1304,7 @@ def loss_func(self, output_tensor: torch.Tensor, data: Dict[str, Any]):

coef_1 = torch.exp(log_importance_weights)

fipo_metrics = None
if self.loss_type == 'cispo':
clamped_ratios = torch.clamp(coef_1, max=self.epsilon_high).detach()
per_token_loss = -clamped_ratios * advantages.unsqueeze(1) * per_token_logps
@@ -1265,14 +1314,19 @@ def loss_func(self, output_tensor: torch.Tensor, data: Dict[str, Any]):
is_positive = advantages.unsqueeze(1) > 0
soft_gate = torch.where(is_positive, gate_pos, gate_neg)
per_token_loss = -soft_gate * advantages.unsqueeze(1)
-elif self.loss_type in ['grpo', 'bnpo', 'dr_grpo', 'dapo']:
+elif self.loss_type in ['grpo', 'bnpo', 'dr_grpo', 'dapo', 'fipo']:
if self.loss_type == 'fipo':
fipo_weight, fipo_metrics = self._compute_fipo_influence(log_ratio, coef_1, advantages, completion_mask)

coef_2 = torch.clamp(coef_1, 1 - self.epsilon_low, 1 + self.epsilon_high)
if self.args.delta is not None:
coef_1 = torch.clamp(coef_1, max=self.args.delta)

per_token_loss1 = coef_1 * advantages.unsqueeze(1)
per_token_loss2 = coef_2 * advantages.unsqueeze(1)
per_token_loss = -torch.min(per_token_loss1, per_token_loss2)
if self.loss_type == 'fipo':
per_token_loss = per_token_loss * fipo_weight
elif self.loss_type == 'real':
per_token_loss = torch.zeros_like(per_token_logps)
else:
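
The FIPO branch of this hunk reduces, per token, to a clipped surrogate multiplied by the influence weight. The scalar sketch below restates that order of operations; `fipo_token_loss` is a hypothetical name, and the default values for `eps_low` and `eps_high` are assumptions taken from the argument defaults and the example script rather than fixed constants.

```python
def fipo_token_loss(ratio, advantage, weight, eps_low=0.2, eps_high=0.28, delta=None):
    """Scalar version of the clipped surrogate scaled by the FIPO influence weight."""
    # The clipped branch is computed from the original ratio, before the delta cap.
    clipped = min(max(ratio, 1.0 - eps_low), 1.0 + eps_high)
    if delta is not None:
        ratio = min(ratio, delta)  # dual-clip upper bound on the unclipped branch
    surrogate = min(ratio * advantage, clipped * advantage)
    return -surrogate * weight


# Positive advantage: the upper clip at 1 + eps_high limits the gain.
loss_pos = fipo_token_loss(ratio=1.5, advantage=1.0, weight=1.2)

# Negative advantage with a very large ratio: delta bounds the penalty.
loss_neg_capped = fipo_token_loss(ratio=20.0, advantage=-1.0, weight=1.0, delta=10.0)
loss_neg_uncapped = fipo_token_loss(ratio=20.0, advantage=-1.0, weight=1.0)
```

Because the weight multiplies the already clipped loss, it rescales the update for a token without changing which branch of the `min` is active.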
@@ -1339,8 +1393,8 @@ def loss_func(self, output_tensor: torch.Tensor, data: Dict[str, Any]):
loss = (per_token_loss * completion_mask).sum() / completion_mask.sum().clamp(min=1.0)
elif self.loss_type == 'dr_grpo':
loss = (per_token_loss * completion_mask).sum() / (micro_batch_size * self.max_completion_length)
-elif self.loss_type in ['cispo', 'dapo']:
-# CISPO and DAPO: Normalize by total completion tokens across all processes
+elif self.loss_type in ['cispo', 'dapo', 'fipo']:
+# CISPO, DAPO, and FIPO: Normalize by total completion tokens across all processes
num_items_in_batch = data['num_items_in_batch']
dp_size = mpu.get_data_parallel_world_size()
normalizer = num_items_in_batch / dp_size
@@ -1409,6 +1463,14 @@ def loss_func(self, output_tensor: torch.Tensor, data: Dict[str, Any]):

# Compute clipping metrics
completion_token_count = completion_mask.sum().clamp(min=1.0)
if fipo_metrics is not None:
avg_metric['fipo/future_kl_mean'] = ((fipo_metrics['future_kl'] * completion_mask).sum()
/ completion_token_count).clone().detach()
avg_metric['fipo/influence_weight_mean'] = ((fipo_metrics['influence_weight'] * completion_mask).sum()
/ completion_token_count).clone().detach()
avg_metric['fipo/safety_keep_ratio'] = ((fipo_metrics['safety_mask'].float() * completion_mask).sum()
/ completion_token_count).clone().detach()

if self.loss_type == 'cispo':
# CISPO: Only track upper bound clipping
# coef_1 is [batch_size, max_seq_len] or [batch_size, 1] depending on importance_sampling_level
@@ -1419,7 +1481,7 @@ def loss_func(self, output_tensor: torch.Tensor, data: Dict[str, Any]):
elif self.loss_type in ['sapo', 'real']:
# SAPO / REAL: No hard clipping, skip clipping metrics
pass
-elif self.loss_type in ['grpo', 'bnpo', 'dr_grpo', 'dapo']:
+elif self.loss_type in ['grpo', 'bnpo', 'dr_grpo', 'dapo', 'fipo']:
# coef_1 is [batch_size, max_seq_len] or [batch_size, 1] depending on importance_sampling_level
# Use exp(log_importance_weights) to get the original ratios before clamping
coef_1_for_metrics = torch.exp(log_importance_weights)
16 changes: 16 additions & 0 deletions swift/rlhf_trainers/args_mixin.py
@@ -322,6 +322,15 @@ class GRPOArgumentsMixin(RolloutTrainerArgumentsMixin):
constraints on negative dominance. The default value is 1.05.
real_tau (float): The temperature parameter. REAL induces monotonic and bounded gradient weighting with
magnitude upper-bounded by 1/tau. The default value is 0.5.
fipo_decay_rate (float): Half-life used to derive `fipo_gamma`. Defaults to 32.0.
fipo_clip_range (Optional[float]): Clip range for the FIPO influence weight. `0.2` clips to
`[0.8, 1.2]`; `None` or `0` disables clipping. Defaults to 0.2.
fipo_clip_high_only (bool): If `True`, clips the FIPO influence weight to `[1, 1 + fipo_clip_range]`.
Defaults to True.
fipo_detach_weight (bool): If `True`, stops gradients through the Future-KL influence weight. Defaults to True.
fipo_safety_threshold (Optional[float]): Safety threshold for negative advantages. Tokens with
`advantage < 0` and importance ratio above this value have their FIPO influence weight capped to
`[0.8, 1.0]` to avoid over-penalization. Defaults to 4.0.
advantage_estimator (Literal['grpo', 'rloo', 'reinforce_plus_plus']): The advantage estimation
function to use. 'grpo' calculates the relative advantage within a group. Options are 'grpo', 'rloo',
'reinforce_plus_plus'. Defaults to 'grpo'.
@@ -414,6 +423,13 @@ class GRPOArgumentsMixin(RolloutTrainerArgumentsMixin):
# REAL https://arxiv.org/abs/2602.05630
real_tau: float = 0.5

# FIPO https://arxiv.org/abs/2603.19835
fipo_decay_rate: float = 32.0
fipo_clip_range: Optional[float] = 0.2
fipo_clip_high_only: bool = True
fipo_detach_weight: bool = True
fipo_safety_threshold: Optional[float] = 4.0

num_generations_eval: Optional[int] = None

# dataset