
Conversation

@Jayce-Ping

Description

This PR introduces a Reinforcement Learning (RL) training pipeline for flow-matching models, implementing the algorithms proposed in Flow-GRPO (arXiv:2505.05470) and Dance-GRPO (arXiv:2505.07818).

Traditional flow-matching models (like FLUX or Stable Diffusion 3) rely on Supervised Fine-Tuning (SFT) using paired data. This PR enables Online RL directly on the flow-matching vector field, allowing the model to optimize towards non-differentiable objectives (such as human preference, aesthetic scores, or structural constraints) in a data-efficient manner without requiring ground-truth target images.

Key technical contributions include:

  1. ODE-to-SDE Transformation: Converts the deterministic ODE sampling process into a stochastic SDE during training to facilitate the exploration required for RL (a sketch of this step follows the list).
  2. Group Relative Policy Optimization (GRPO): Adapts the GRPO algorithm (popularized by LLMs like DeepSeek-R1) for continuous flow-matching policies.
  3. Efficiency: Implements the "fewer steps for training" strategy mentioned in the Flow-GRPO paper, significantly reducing the computational cost of online sampling.
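
As a rough illustration of contribution 1, the sketch below contrasts a deterministic Euler step with a stochastic variant that injects Gaussian noise and records the transition log-probability needed for policy-gradient training. This is a minimal sketch only: the function names are illustrative, and the Flow-GRPO/Dance-GRPO papers add a drift correction so the SDE preserves the ODE's marginals, which is omitted here.

```python
import torch

def euler_ode_step(x, v, dt):
    # Deterministic flow-matching update: x_{t+dt} = x_t + v(x_t, t) * dt
    return x + v * dt

def euler_sde_step(x, v, dt, noise_level):
    # Stochastic variant for RL rollouts (simplified): same drift plus Gaussian noise,
    # so repeated rollouts of one prompt explore different trajectories.
    # x is assumed to be a batch of latents, e.g. shape (B, C, H, W).
    mean = x + v * dt
    std = noise_level * torch.sqrt(torch.abs(torch.as_tensor(dt, dtype=x.dtype)))
    sample = mean + std * torch.randn_like(x)
    # Per-step Gaussian log-probability, averaged over all non-batch dimensions;
    # the RL loss later reuses it to form importance ratios.
    log_prob = (
        -((sample.detach() - mean) ** 2) / (2 * std**2)
        - torch.log(std)
        - 0.5 * torch.log(torch.tensor(2 * torch.pi, dtype=x.dtype))
    )
    return sample, log_prob.mean(dim=tuple(range(1, log_prob.ndim)))
```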

Features

  • Data-Efficient Fine-Tuning: Enables model alignment using only prompts and a reward function (no paired image data required).

  • Flow-GRPO Algorithm: Full implementation of the GRPO loss specifically derived for flow-matching vector fields (a loss sketch follows this feature list).

  • Stochastic Scheduler for Exploration: Added a new scheduler that supports ODE-SDE mixed sampling to inject noise for exploration during the RL rollout phase.

  • RL Trainer Wrapper: A modular FlowGRPOTrainer that integrates seamlessly with the existing DiffSynth-Studio training loop.

  • Reward Model Registry:

    • Integrated common reward backends (e.g., Aesthetic Scorer, HPSv2).
    • Added a base class for custom, user-defined reward functions.
  • GRPO Sampler: A DistributedSampler object designed for diverse and evenly distributed RL sampling.
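
The loss sketch referenced in the Flow-GRPO Algorithm bullet above: each prompt is sampled several times (a "group"), rewards are normalized within the group to form advantages (so no critic network is needed), and a PPO-style clipped importance ratio is applied to the per-step SDE log-probabilities. The function name and the clip_range value are illustrative assumptions, not the exact implementation in this PR.

```python
import torch

def flow_grpo_loss(logp_new, logp_old, rewards, clip_range=0.2):
    # logp_new / logp_old: (group_size, num_steps) per-step log-probs of the SDE
    #   transitions under the current policy and the rollout (old) policy.
    # rewards: (group_size,) one scalar reward per sampled image in the group.
    advantages = (rewards - rewards.mean()) / (rewards.std() + 1e-8)   # group-relative advantage
    adv = advantages[:, None]                                          # broadcast over denoising steps
    ratio = torch.exp(logp_new - logp_old)                             # per-step importance ratio
    unclipped = ratio * adv
    clipped = torch.clamp(ratio, 1.0 - clip_range, 1.0 + clip_range) * adv
    return -torch.minimum(unclipped, clipped).mean()                   # maximize the clipped surrogate
```

Because advantages are computed relative to the group mean, a single scalar reward per image is enough, which is what makes training from prompts plus a reward function data-efficient.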

TODO

  • New Scheduler: Add a new scheduler for RL sampling.
  • New Distributed DataLoader: Implement a new distributed sampler and dataloader for even, group-wise sampling.
  • New Module for Reward Models: Add a new folder containing Reward Model classes.
  • RL Wrapper: Implement the new pipeline/wrapper object.
  • Documentation: Update the docs with a guide on "Reinforcement Learning for Flow Models".
  • Model Compatibility: Verify compatibility with Flux.1 and Qwen-Image backbones.

Type of Change

  • New feature (non-breaking change which adds functionality)
  • Bug fix (non-breaking change which fixes an issue)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Documentation Update

References

  • Flow-GRPO (arXiv:2505.05470)
  • Dance-GRPO (arXiv:2505.07818)

Relevant Issues

#1111 #1110

@gemini-code-assist
Contributor

Summary of Changes

Hello @Jayce-Ping, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request integrates an online Reinforcement Learning (RL) training pipeline into the system, specifically supporting Flow-GRPO and Dance-GRPO algorithms for flow-matching models. This enhancement allows models to be fine-tuned using non-differentiable objectives like human preferences, eliminating the need for traditional supervised fine-tuning with paired image data and significantly improving data efficiency.

Highlights

  • Online Reinforcement Learning for Flow-Matching: Introduces an RL training pipeline for flow-matching models, specifically implementing Flow-GRPO and Dance-GRPO, enabling optimization towards non-differentiable objectives without requiring paired data.
  • Stochastic Scheduler for Exploration: Adds a new FlowMatchSDEScheduler that supports ODE-SDE mixed sampling to inject noise, crucial for exploration during RL rollouts.
  • GRPO Algorithm Adaptation: Adapts the Group Relative Policy Optimization (GRPO) algorithm for continuous flow-matching policies, a technique previously popularized in LLMs.
  • Efficiency Improvements: Incorporates a "fewer steps for training" strategy to significantly reduce the computational cost of online sampling.
  • Modular RL Trainer and Reward System: Provides a modular FlowGRPOTrainer and a Reward Model Registry with integrated common reward backends and support for custom user-defined reward functions (a hypothetical interface sketch follows).
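
To make the reward-system highlight concrete, here is a purely hypothetical sketch of what a user-defined reward backend could look like; the class and method names (RewardModel, score, BrightnessReward) are illustrative assumptions and not the API introduced by this PR.

```python
import torch

class RewardModel:
    # Hypothetical base class: map a batch of decoded images and their prompts
    # to one scalar reward per image.
    def score(self, images: torch.Tensor, prompts: list[str]) -> torch.Tensor:
        raise NotImplementedError

class BrightnessReward(RewardModel):
    # Toy reward that prefers brighter images; a real backend (aesthetic scorer,
    # HPSv2, ...) would run a learned preference model here instead.
    def score(self, images, prompts):
        return images.float().mean(dim=(1, 2, 3))

# Usage: score one group of samples generated from the same prompt.
images = torch.rand(4, 3, 64, 64)                       # (group_size, C, H, W)
rewards = BrightnessReward().score(images, ["a sunny beach"] * 4)
```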


@gemini-code-assist bot left a comment


Code Review

This pull request introduces a FlowMatchSDEScheduler to support Reinforcement Learning based training for flow-matching models, as described in the Flow-GRPO and Dance-GRPO papers. The new scheduler implements stochastic sampling to enable exploration. My review of the new scheduler implementation has identified several critical issues related to correctness and numerical stability, including a type error that would cause a crash, potential division-by-zero errors, and an incorrect log-probability calculation. I have provided detailed comments and code suggestions to address these problems.

Comment on lines +194 to +197

```python
if noise_window is None:
    self.noise_window = list(range(self.num_train_timesteps))
else:
    self.noise_window = list(noise_window)
```
Severity: critical

self.noise_window is initialized as a list. However, in the current_noise_steps property (line 208), it is indexed with a torch.Tensor on line 212 (self.noise_window[selected_indices]), which will raise a TypeError. To fix this, self.noise_window should be a torch.Tensor. This change will also make the return type of current_noise_steps consistent with its type hint (torch.Tensor).

Suggested change

```python
# before
if noise_window is None:
    self.noise_window = list(range(self.num_train_timesteps))
else:
    self.noise_window = list(noise_window)

# after
if noise_window is None:
    self.noise_window = torch.arange(self.num_train_timesteps)
else:
    self.noise_window = torch.tensor(noise_window)
```

```python
elif self.sde_type == 'Dance-SDE':
    pred_original_sample = sample - sigma * model_output
    std_dev_t = current_noise_level * torch.sqrt(-1 * dt)
    log_term = 0.5 * current_noise_level**2 * (sample - pred_original_sample * (1 - sigma)) / sigma**2
```
Severity: critical

This line involves a division by sigma**2. Some scheduler configurations, like set_timesteps_wan, can produce sigma values of 0. This will lead to a division-by-zero error. Please add a small epsilon to the denominator for numerical stability.

Suggested change

```python
# before
log_term = 0.5 * current_noise_level**2 * (sample - pred_original_sample * (1 - sigma)) / sigma**2

# after
log_term = 0.5 * current_noise_level**2 * (sample - pred_original_sample * (1 - sigma)) / (sigma**2 + 1e-9)
```

Comment on lines +284 to +286

```python
(-((prev_sample.detach() - prev_sample_mean) ** 2) / (2 * (std_dev_t**2)))
- math.log(std_dev_t)
- torch.log(torch.sqrt(2 * torch.as_tensor(math.pi)))
```
Severity: critical

There are two issues in this block:

  1. math.log is used on a tensor std_dev_t (line 285), which will cause a TypeError. You should use torch.log instead.
  2. If dt is zero, std_dev_t will be zero, leading to division by zero in the first term and log(0) in the second term. Adding a small epsilon will improve numerical stability.
Suggested change

```python
# before
(-((prev_sample.detach() - prev_sample_mean) ** 2) / (2 * (std_dev_t**2)))
- math.log(std_dev_t)
- torch.log(torch.sqrt(2 * torch.as_tensor(math.pi)))

# after
(-((prev_sample.detach() - prev_sample_mean) ** 2) / (2 * (std_dev_t**2) + 1e-9))
- torch.log(std_dev_t + 1e-9)
- torch.log(torch.sqrt(2 * torch.as_tensor(math.pi)))
```

Comment on lines +263 to +268

```python
log_prob = (
    -((prev_sample.detach() - prev_sample_mean) ** 2) / (2 * ((std_dev_t * torch.sqrt(-1 * dt)) ** 2))
    - torch.log(std_dev_t * torch.sqrt(-1 * dt))
    - torch.log(torch.sqrt(2 * torch.as_tensor(math.pi)))
)
log_prob = log_prob.mean(dim=tuple(range(1, log_prob.ndim)))
```
Severity: high

The log probability calculation for Flow-SDE can be unstable if dt is zero, which would cause division by zero or log(0). It's safer to add a small epsilon for numerical stability. Additionally, torch.sqrt(-1 * dt) is computed multiple times; pre-calculating it can improve clarity and slightly improve efficiency.

Suggested change

```python
# before
log_prob = (
    -((prev_sample.detach() - prev_sample_mean) ** 2) / (2 * ((std_dev_t * torch.sqrt(-1 * dt)) ** 2))
    - torch.log(std_dev_t * torch.sqrt(-1 * dt))
    - torch.log(torch.sqrt(2 * torch.as_tensor(math.pi)))
)
log_prob = log_prob.mean(dim=tuple(range(1, log_prob.ndim)))

# after
sqrt_neg_dt = torch.sqrt(-dt)
variance = (std_dev_t * sqrt_neg_dt) ** 2
log_prob = (
    -((prev_sample.detach() - prev_sample_mean) ** 2) / (2 * variance + 1e-9)
    - torch.log(std_dev_t * sqrt_neg_dt + 1e-9)
    - torch.log(torch.sqrt(2 * torch.as_tensor(math.pi)))
)
log_prob = log_prob.mean(dim=tuple(range(1, log_prob.ndim)))
```

Comment on lines +307 to +308

```python
log_prob = -((prev_sample.detach() - prev_sample_mean) ** 2)
log_prob = log_prob.mean(dim=tuple(range(1, log_prob.ndim)))
```
Severity: high

The calculation for log_prob for the CPS SDE type appears to be incomplete. It's currently calculating only the negative squared error. A proper log probability for a Gaussian distribution should also include terms for the variance and the normalization constant, following the formula: log P(x) = - (x - μ)² / (2σ²) - log(σ) - 0.5 * log(2π). Here, the variance is std_dev_t**2.

Suggested change

```python
# before
log_prob = -((prev_sample.detach() - prev_sample_mean) ** 2)
log_prob = log_prob.mean(dim=tuple(range(1, log_prob.ndim)))

# after
variance = std_dev_t**2
log_prob = (
    -((prev_sample.detach() - prev_sample_mean) ** 2) / (2 * variance + 1e-9)
    - torch.log(std_dev_t + 1e-9)
    - torch.log(torch.sqrt(2 * torch.as_tensor(math.pi)))
)
log_prob = log_prob.mean(dim=tuple(range(1, log_prob.ndim)))
```
