Skip to content

Support FIPO#9328

Merged
hjh0119 merged 8 commits into
modelscope:mainfrom
li2zhi:fipo
May 19, 2026
Merged

Support FIPO#9328
hjh0119 merged 8 commits into
modelscope:mainfrom
li2zhi:fipo

Conversation

@li2zhi
Copy link
Copy Markdown
Contributor

@li2zhi li2zhi commented May 13, 2026

PR type

  • Bug Fix
  • New Feature
  • Document Updates
  • More Models or Datasets Support

PR information

FIPO is a value-free RL algorithm proposed in FIPO: Future-guided Importance Policy Optimization. It is designed to improve long-form reasoning training by refining token-level credit assignment in GRPO/DAPO-style optimization.

Algorithm Overview

FIPO keeps the standard DAPO scaffold, but changes how token-level updates are weighted. The local signal is the signed log-probability shift between the current and old policy:

$$ \Delta \log p_t = \log \pi_\theta(y_t \mid x, y_{1:t-1}) - \log \pi_{old}(y_t \mid x, y_{1:t-1}) $$

Positive values mean the token is being reinforced, while negative values mean it is being suppressed. Since reasoning is sequential, FIPO then accumulates this signal over the future trajectory:

$$ FutureKL_t = \sum_{k=t}^{T} M_k \cdot \gamma^{k-t} \cdot \Delta \log p_k $$

Positive FutureKL_t means the future following token t is being reinforced; negative FutureKL_t means it is being suppressed. The decay window keeps the signal local enough to stay stable, while the mask removes extreme-ratio outliers.

FIPO maps this future signal into a bounded influence weight:

$$ f_t = clip(\exp(FutureKL_t), 1-\epsilon_{f,low}, 1+\epsilon_{f,high}), \quad \tilde{A}_t = \hat{A}_t \cdot f_t $$

FutureKL_t = discounted_sum_of_future_logprob_shifts(t)
weighted_advantage_t = A_t * clip(exp(FutureKL_t), low, high)

The final token-level FIPO loss keeps the standard clipped PPO/DAPO form, but replaces the original advantage with the future-aware one:

$$ r_t = \frac{\pi_\theta(y_t \mid x, y_{1:t-1})}{\pi_{old}(y_t \mid x, y_{1:t-1})} $$

$$ L_t^{FIPO} = min(r_t \tilde{A}_t,; clip(r_t, 1-\epsilon, 1+\epsilon)\tilde{A}_t) $$

Tokens that lead into preferred futures are amplified, while tokens that lead into suppressed futures are attenuated. Clipping keeps this modulation stable. The final DAPO-style loss therefore stays clipped and simple, but the advantage term becomes future-aware rather than uniformly inherited from the final outcome.

Reference

Paper: FIPO: Future-guided Importance Policy Optimization

li2zhi added 5 commits May 9, 2026 16:56
# Conflicts:
#	swift/megatron/arguments/megatron_args.py
#	swift/megatron/trainers/grpo_trainer.py
Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request implements FIPO (Future-KL Influenced Policy Optimization), including documentation, training scripts, and integration into Megatron and RLHF trainers. Reviewers identified potential numerical stability issues in weight calculations and a discrepancy between the safety threshold logic and its documentation. Recommendations were also made to log FIPO-specific metrics and refactor the loss function to minimize code duplication.

Comment thread swift/rlhf_trainers/grpo_trainer.py Outdated
Comment thread swift/megatron/trainers/grpo_trainer.py Outdated
Comment thread swift/rlhf_trainers/grpo_trainer.py Outdated
Comment thread swift/megatron/trainers/grpo_trainer.py Outdated
@hjh0119 hjh0119 self-assigned this May 13, 2026
@hjh0119
Copy link
Copy Markdown
Collaborator

hjh0119 commented May 18, 2026

thanks for your contribution

is this pr ready to be merged?

@li2zhi
Copy link
Copy Markdown
Contributor Author

li2zhi commented May 19, 2026

Yes. Please let me know if there are any remaining concerns or changes needed.

Comment thread docs/source/Instruction/GRPO/AdvancedResearch/FIPO.md
Comment thread examples/train/grpo/internal/fipo.sh Outdated
Comment thread docs/source/Instruction/GRPO/AdvancedResearch/FIPO.md Outdated
Comment thread swift/megatron/trainers/grpo_trainer.py Outdated
@hjh0119
Copy link
Copy Markdown
Collaborator

hjh0119 commented May 19, 2026

Thanks for your contribution! I've left a few comments

@li2zhi
Copy link
Copy Markdown
Contributor Author

li2zhi commented May 19, 2026

Thanks for your review!

I've addressed the comments and pushed the updates. Please let me know if anything else needs to be changed.

@hjh0119
Copy link
Copy Markdown
Collaborator

hjh0119 commented May 19, 2026

/gemini review

Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces FIPO (Future-KL Influenced Policy Optimization), a value-free RL method designed for long-chain reasoning, along with comprehensive documentation and training examples. The implementation adds FIPO-specific arguments and loss logic to both the Megatron and standard RLHF trainers, including metrics tracking for future-KL and influence weights. Review feedback suggests removing an extra blank line in the English documentation for consistency and recommends refactoring the duplicated _compute_fipo_influence logic into a shared utility to improve maintainability and reduce code duplication.


## Parameters


Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

This extra blank line should be removed to maintain consistent formatting with the Chinese version of this document.

Comment on lines +1028 to +1075
def _compute_fipo_influence(self, log_ratio: torch.Tensor, coef_1: torch.Tensor, advantages: torch.Tensor,
completion_mask: torch.Tensor) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
"""Compute FIPO token-level influence weight from Future-KL divergence."""
future_kl_delta = log_ratio.masked_fill(~completion_mask, 0.0)

# Dual-Clip participation mask: high-ratio tokens do not contribute to Future-KL.
if self.args.delta is not None:
delta = torch.as_tensor(self.args.delta, dtype=log_ratio.dtype, device=log_ratio.device)
high_ratio_mask = coef_1 > delta
future_kl_delta = torch.where(high_ratio_mask, torch.zeros_like(future_kl_delta), future_kl_delta)

seq_len = future_kl_delta.shape[1]
future_kl = torch.zeros_like(future_kl_delta)
positions = torch.arange(seq_len, device=log_ratio.device).unsqueeze(1)
gamma = torch.as_tensor(self.fipo_gamma, dtype=log_ratio.dtype, device=log_ratio.device)
chunk_size = 128
for chunk_start in range(0, seq_len, chunk_size):
chunk_end = min(seq_len, chunk_start + chunk_size)
chunk_positions = torch.arange(chunk_start, chunk_end, device=log_ratio.device).unsqueeze(0)
distance = chunk_positions - positions
future_mask = distance >= 0
decay_block = torch.pow(gamma, distance.clamp(min=0)) * future_mask.to(log_ratio.dtype)
future_kl += torch.matmul(future_kl_delta[:, chunk_start:chunk_end], decay_block.t())
future_kl = future_kl.masked_fill(~completion_mask, 0.0)

influence_weight = torch.exp(future_kl)

if self.fipo_clip_range:
high = 1 + self.fipo_clip_range
low = 1.0 if self.fipo_clip_high_only else 1 - self.fipo_clip_range
influence_weight = torch.clamp(influence_weight, min=low, max=high)
influence_weight = influence_weight.detach()

# avoid amplifying negative-advantage tokens with very high IS ratios.
safety_mask = torch.ones_like(completion_mask, dtype=torch.bool)
if self.fipo_safety_threshold is not None:
negative_advantage = advantages.unsqueeze(1) < 0
high_is_ratio = coef_1 > self.fipo_safety_threshold
safety_mask = ~(negative_advantage & high_is_ratio)
influence_weight = torch.where(safety_mask, influence_weight,
torch.clamp(influence_weight, min=0.8, max=1.0))

metrics = {
'future_kl': future_kl,
'influence_weight': influence_weight,
'safety_mask': safety_mask,
}
return influence_weight, metrics
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

This function _compute_fipo_influence is nearly identical to the one in swift/megatron/trainers/grpo_trainer.py. To improve maintainability and avoid code duplication, consider refactoring this logic into a shared utility function that both trainers can call. This will make future updates easier and reduce the risk of inconsistencies.

@hjh0119
Copy link
Copy Markdown
Collaborator

hjh0119 commented May 19, 2026

LGTM thanks!

@hjh0119 hjh0119 merged commit dafd45c into modelscope:main May 19, 2026
2 of 3 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants