R3 PR: Rollout Routing Replay#1273

Open
erictang000 wants to merge 27 commits into main from r3

Conversation


erictang000 commented Mar 4, 2026

Overview

This PR adds support for Rollout Routing Replay (R3) (see the R3 paper).

See #815 for tracking of future tasks to fully support routing replay in all settings.

We add the following flags to enable R3:

cfg.generator.inference_engine.enable_return_routed_experts=True
cfg.trainer.policy.megatron_config.moe_enable_routing_replay=True

cfg.generator.inference_engine.enable_return_routed_experts=True is a pass-through argument to vLLM, which records the expert router indices (returning a nested list with dimensions (batch_size, seq_len, num_layers, top_k)).
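The nesting of the returned structure can be illustrated with a small sketch (dummy indices constructed for illustration, not SkyRL code):

```python
# Hypothetical illustration of the routed-expert indices returned when
# enable_return_routed_experts=True: for each sequence, each token, and
# each MoE layer, the top_k expert ids the router selected.
batch_size, seq_len, num_layers, top_k = 2, 4, 3, 2
num_experts = 8  # assumed number of experts for this sketch

rollout_expert_indices = [
    [
        [[(s + t + l) % num_experts, (s * t + l + 1) % num_experts]
         for l in range(num_layers)]
        for t in range(seq_len)
    ]
    for s in range(batch_size)
]

# Verify the (batch_size, seq_len, num_layers, top_k) nesting.
assert len(rollout_expert_indices) == batch_size
assert len(rollout_expert_indices[0]) == seq_len
assert len(rollout_expert_indices[0][0]) == num_layers
assert len(rollout_expert_indices[0][0][0]) == top_k
```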

We then pass this rollout_expert_indices list through to Megatron's native RouterReplay feature (link).

When cfg.trainer.policy.megatron_config.moe_enable_routing_replay is set to true, Megatron initializes an instance of RouterReplay on each training worker rank. RouterReplay.set_replay_data(per_layer_data) sets the router decisions to replay, and RouterReplay.set_global_router_replay_action(RouterReplayAction.REPLAY_FORWARD) / RouterReplay.set_global_router_replay_action(RouterReplayAction.REPLAY_BACKWARD) switch replay between the forward and backward passes.
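The call sequence above can be sketched with a simplified stand-in class that mirrors only the API names mentioned in this PR; the real implementation lives in Megatron-LM and does much more (it hooks the routers themselves):

```python
from enum import Enum, auto

# Simplified stand-in mirroring the RouterReplay call sequence described
# above. This is NOT the actual Megatron-LM class; it only shows how the
# pieces are wired on a training worker rank.
class RouterReplayAction(Enum):
    REPLAY_FORWARD = auto()
    REPLAY_BACKWARD = auto()

class RouterReplay:
    _action = None  # global replay mode shared by all routers

    def __init__(self):
        self.replay_data = None  # per-layer routed-expert indices

    def set_replay_data(self, per_layer_data):
        # Store the router decisions recorded by the inference engine,
        # to be replayed instead of recomputing the top-k routing.
        self.replay_data = per_layer_data

    @classmethod
    def set_global_router_replay_action(cls, action):
        # Switch replay between the forward and backward passes.
        cls._action = action

# Usage: one instance per training worker rank, fed with the
# rollout_expert_indices captured during generation (dummy data here).
replay = RouterReplay()
replay.set_replay_data([[0, 3], [1, 2]])
RouterReplay.set_global_router_replay_action(RouterReplayAction.REPLAY_FORWARD)
```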

Supported Settings

Router Replay is supported for the following settings:

Generator Settings

  • use_conversation_multi_turn=True and use_conversation_multi_turn=False
  • batched=False and batched=True
  • async_engine=True and async_engine=False
  • NOT retokenize_chat_history mode - i.e. when both self.use_conversation_multi_turn and self.custom_chat_template are set
  • NOT self.generator_cfg.step_wise_trajectories - this should be possible to add, but is not currently implemented or tested
  • fully_async training - should technically work, but is not tested in this PR. Tracking in [skyrl-train] Enable routing replay in SkyRL #815
  • NOT step_wise_training

Inference Engine Settings

Trainer Settings

  • TP, CP, EP, DP are all supported. PP will be added in a follow up PR.

Custom Generator support

  • Custom generators using SkyRL's inference engine should just plumb the returned routed expert indices through

Rollout Routing Replay

image

Relevant resources:
vLLM PR: vllm-project/vllm#28284
Verl PR: verl-project/verl#4101
Mindlab blog: https://macaron.im/mindlab/research/router-replay-r3-why-it-failed-and-how-we-fixed-it
Megatron-LM API guide: https://github.com/NVIDIA/Megatron-LM/blob/main/docs/api-guide/router_replay.md



Co-authored-by: Dev Patel <dev.patel@berkeley.edu>
SumanthRH self-assigned this Mar 4, 2026
erictang000 commented:

Forward pass with router replay showing lower logprob diff!

image



erictang000 commented Mar 7, 2026

Current State:

For small scale tests, routing replay seems to be working as shown above - tested only for TP=8 serving with TP=4, EP=8 training.
image

What's not working:

  • Deepseek style models (like moonlight-16b-a3b) - the average logprobs both with and without routing replay are way off from what the inference engine reports (~8 average logprob vs 0.2) - maybe this is specific to moonlight plus some Megatron config? Not sure. Reproducible by running the test_logprobs tests
  • Running a real training batch hangs for all models - see [Bug]: Generation hangs until RAY_CGRAPH_get_timeout (300s) with Ray compiled DAG executor vllm-project/vllm#36237 for the relevant stack trace. This seems to happen only when router replay is enabled on the vLLM end. We could try upgrading vLLM to 0.17.0 to see if anything has been pushed that could help fix this. Reproducible by running a large enough training batch with router replay via the test_logprobs tests

TODOs:

  • fix above bugs
  • clean up code
  • add checks for settings where r3 won't be supported for now (async RL, batched generator, use_conversation_multi_turn=false)
  • run full end to end test showing r3 minimizes drift on DAPO (with at least one model family - others just need to pass tests)

erictang000 commented:

> Deepseek style models (like moonlight-16b-a3b) - the average logprobs for both with and without routing replay seem to be way off from what is expected from the inference engine (like ~8 average logprobs vs 0.2 average logprobs) - maybe this is specific to moonlight + some megatron config? not sure. Reproducible by running the test_logprobs tests

Solved! The issue was that in our test we were setting cfg.trainer.use_sample_packing=False and also setting NVTE_FUSED_ATTN=0. This caused flash-attn to be used without sample packing for the moonlight forward pass - flash-attn produces incorrect logprobs that differ greatly from vLLM:

vLLM logprobs     - mean: -2.564655
Megatron (replay) - mean: -9.231623
Megatron (no rep) - mean: -9.647593

After setting cfg.trainer.use_sample_packing=True and NVTE_FUSED_ATTN=1 (to allow setting transformer_config_kwargs["attention_backend"] to "fused" correctly), the logprobs match closely:

vLLM logprobs     - mean: -0.223607, std: 0.674102
Megatron (replay) - mean: -0.223626, std: 0.674850
Megatron (no rep) - mean: -0.224379, std: 0.677036
With replay    - logprob diff mean: 0.006648, std: 0.021737
Without replay - logprob diff mean: 0.011115, std: 0.035957
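The diff statistics above can be reproduced in spirit with a few lines; the logprob values below are hypothetical placeholders, whereas in the real test_logprobs test they come from vLLM and the Megatron forward pass over the same tokens:

```python
import math

# Hypothetical per-token logprobs for illustration only; in the actual
# test these come from vLLM and Megatron for identical token sequences.
vllm_logprobs = [-0.21, -0.25, -0.19, -0.30]
megatron_logprobs = [-0.22, -0.24, -0.20, -0.31]

# Per-token absolute difference, then mean/std as reported above.
diffs = [abs(a - b) for a, b in zip(vllm_logprobs, megatron_logprobs)]
mean = sum(diffs) / len(diffs)
std = math.sqrt(sum((d - mean) ** 2 for d in diffs) / len(diffs))
print(f"logprob diff mean: {mean:.6f}, std: {std:.6f}")
```

With routing replay enabled, the expectation is that this diff mean shrinks relative to the no-replay run, as in the numbers quoted above.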


erictang000 commented:

> Running a real training batch runs into hanging for all models - see vllm-project/vllm#36237 for the relevant stack trace. This seems to happen pretty much only when router replay is enabled on the vllm end. We could try upgrading vllm to 0.17.0 to see if anything has been pushed that could help fix this. Reproducible by running a large enough training batch with router replay on the test_logprobs tests

Verified that cherry-picking the changes from #1300 to use the mp backend allows us to work around the compiled graph timeout:

image

