Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 20 additions & 6 deletions agent_r1/config/agent_ppo_trainer.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,24 @@ hydra:
searchpath:
- pkg://verl.trainer.config

# If you need to override other configs, you can add them directly here.
# For example:
# actor_rollout_ref:
#   rollout:
#     agent:
#       num_workers: 16
# Trajectory-level loss reweighting is derived against seq-mean-token-sum.
actor_rollout_ref:
actor:
loss_agg_mode: seq-mean-token-sum

trainer:

# --- Mini-batch partitioning strategy (three mutually-exclusive modes) ---
# "default" : verl-native fixed-size contiguous split (no trajectory awareness)
# "trajectory" : keep every trajectory in a single PPO mini-batch and apply
# trajectory-level loss weighting
# TODO Implemented, awaiting testing
# "step"       : step-index-based grouping with reverse-order update (TODO; see UPO: Sequence-Level Agentic-RL with Convergence Guarantees, https://arxiv.org/abs/2602.06554)
minibatch_mode: "trajectory"

# TODO Implemented, awaiting testing
# --- Step-based mini-batch parameters (active when minibatch_mode == "step") (see UPO: Sequence-Level Agentic-RL with Convergence Guarantees, https://arxiv.org/abs/2602.06554)
step_minibatch:
# Maximum number of interaction steps. null = auto-infer from data
# (max step_index + 1 in the current batch).
max_steps: null
206 changes: 203 additions & 3 deletions agent_r1/ray_agent_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@
from tqdm import tqdm

from agent_r1.metric_utils import compute_data_metrics
from agent_r1.utils.step_batching import group_batch_by_step_index
from agent_r1.utils.traj_loss import apply_step_pool_loss_weights, apply_traj_loss_weights
from agent_r1.utils.trajectory_batching import pack_trajectories_into_minibatches
from verl import DataProto
from verl.experimental.dataset.sampler import AbstractCurriculumSampler
from verl.protocol import pad_dataproto_to_divisor
Expand All @@ -58,6 +61,7 @@
from verl.utils.config import omega_conf_to_dataclass
from verl.utils.debug import marked_timer
from verl.utils.metric import reduce_metrics
from verl.utils.py_functional import append_to_dict
from verl.utils.rollout_skip import RolloutSkip


Expand Down Expand Up @@ -164,9 +168,6 @@ def compute_advantage(
norm_adv_by_std_in_grpo: bool = True,
config: Optional[AlgoConfig] = None,
) -> DataProto:
    # TODO: Rewrite all advantage functions in core_algos to fit the new agent-flow data structure.
    # Multiple data rows correspond to one complete trajectory; trajectories are distinguished via
    # non_tensor_batch["trajectory_uids"], and each trajectory spans multiple rows.
    # non_tensor_batch["step_indices"] orders the steps within a single trajectory.
"""Compute advantage estimates for policy optimization.

This function computes advantage estimates using various estimators like GAE, GRPO, REINFORCE++, etc.
Expand Down Expand Up @@ -1018,6 +1019,205 @@ def fit(self):
# The dataset may be changed after each training batch
self.train_dataset.on_batch_end(batch=batch)

def _prepare_trajectory_minibatches(self, batch: DataProto, ppo_mini_batch_size: int) -> DataProto:
    """Delegate to ``pack_trajectories_into_minibatches``.

    The step-level batch is reordered (and padded where needed) so that the
    ``DataProto.split(ppo_mini_batch_size)`` call performed inside the verl
    actor/critic workers never cuts a trajectory across two mini-batches.

    Args:
        batch: Full step-level DataProto after advantage computation.
        ppo_mini_batch_size: Mini-batch size used by the worker's internal
            split() call.

    Returns:
        A reordered DataProto whose contiguous split is trajectory-aligned.
    """
    packed = pack_trajectories_into_minibatches(batch, ppo_mini_batch_size)
    return packed

def _validate_effective_mini_batch_size(
    self,
    batch: DataProto,
    ppo_mini_batch_size: int,
    component: str,
) -> None:
    """Check that the PPO mini-batch size fits within the valid rows of ``batch``.

    Rows flagged via ``non_tensor_batch["is_pad"]`` are padding and do not
    count toward the effective train batch size.

    Args:
        batch: Step-level batch to be split into mini-batches.
        ppo_mini_batch_size: Configured mini-batch size to validate.
        component: Human-readable tag used as the error-message prefix.

    Raises:
        ValueError: If the size is non-positive or exceeds the number of
            valid (non-pad) rows.
    """
    if ppo_mini_batch_size <= 0:
        raise ValueError(f"{component}: ppo_mini_batch_size must be positive, got {ppo_mini_batch_size}.")

    pad_flags = batch.non_tensor_batch.get("is_pad", None)
    # No pad mask present -> every row counts as a valid training row.
    valid_rows = len(batch) if pad_flags is None else int((~pad_flags.astype(bool)).sum())

    if ppo_mini_batch_size > valid_rows:
        raise ValueError(
            f"{component}: ppo_mini_batch_size={ppo_mini_batch_size} exceeds the current "
            f"effective train batch size ({valid_rows} valid rows). "
            "Please reduce ppo_mini_batch_size or increase train_batch_size."
        )

def _resolve_minibatch_mode(self) -> str:
    """Determine which mini-batch partitioning strategy is in effect.

    ``"trajectory"`` and ``"step"`` are only supported on the legacy worker
    path; on the new engine (``use_legacy_worker_impl == "disable"``) they
    degrade to ``"default"`` after printing a warning.

    Returns:
        One of ``"default"``, ``"trajectory"``, ``"step"``.

    Raises:
        ValueError: If the configured mode is not one of the three names.
    """
    mode = str(self.config.trainer.get("minibatch_mode", "trajectory"))

    recognized = ("default", "trajectory", "step")
    if mode not in recognized:
        raise ValueError(f"Invalid minibatch_mode '{mode}'. Must be one of: 'default', 'trajectory', 'step'.")

    # Any non-default mode requires the legacy worker implementation.
    if mode != "default" and self.use_legacy_worker_impl == "disable":
        print(
            f"[WARNING] minibatch_mode='{mode}' is not supported on the "
            f"new-engine path (use_legacy_worker_impl='disable'). "
            f"Falling back to 'default'."
        )
        return "default"

    return mode

# ------------------------------------------------------------------
# Actor update: three-way dispatch
# ------------------------------------------------------------------

def _update_actor(self, batch: DataProto) -> DataProto:
    """Override: route the actor update through the configured mini-batch strategy.

    Modes (controlled by ``trainer.minibatch_mode``):
        - ``"default"``: verl-native contiguous split (no trajectory awareness).
        - ``"trajectory"``: trajectory-aware packing plus per-trajectory loss weighting.
        - ``"step"``: step-index-based pools, updated in reverse order.
    """
    mode = self._resolve_minibatch_mode()

    handlers = {
        "trajectory": self._update_actor_trajectory,
        "step": self._update_actor_step_based,
    }
    handler = handlers.get(mode)
    if handler is not None:
        return handler(batch)
    # "default": fall through to the verl-native contiguous split.
    return super()._update_actor(batch)

def _update_actor_trajectory(self, batch: DataProto) -> DataProto:
    """Pack full trajectories into mini-batches, then delegate to the parent update.

    Args:
        batch: Full step-level DataProto after advantage computation.

    Returns:
        The DataProto returned by the parent ``_update_actor``.

    Raises:
        ValueError: If ``loss_agg_mode`` is not ``seq-mean-token-sum`` (the
            mode the trajectory-level reweighting is derived against), or if
            the mini-batch size is invalid for the current batch.
    """
    # Fail fast: check the cheap configuration invariant BEFORE performing
    # the (potentially expensive) trajectory packing / batch reordering.
    actor_loss_agg_mode = str(self.config.actor_rollout_ref.actor.loss_agg_mode)
    required_mode = "seq-mean-token-sum"
    if actor_loss_agg_mode != required_mode:
        raise ValueError(
            "Trajectory mini-batching requires "
            "actor_rollout_ref.actor.loss_agg_mode='seq-mean-token-sum'. "
            f"Got '{actor_loss_agg_mode}'."
        )

    mini_batch_size = self.config.actor_rollout_ref.actor.ppo_mini_batch_size
    self._validate_effective_mini_batch_size(batch, mini_batch_size, component="actor trajectory mode")
    batch = self._prepare_trajectory_minibatches(batch, mini_batch_size)

    # Attach per-trajectory loss weights (defined relative to the
    # seq-mean-token-sum aggregation validated above).
    batch = apply_traj_loss_weights(batch, mini_batch_size)

    return super()._update_actor(batch)

def _update_actor_step_based(self, batch: DataProto) -> DataProto:
    """Step-based mini-batch update: iterate step pools in reverse order.

    1. ``group_batch_by_step_index(batch)`` -> ``step_pools[0..T-1]``
    2. For ``t`` in ``T-1, T-2, ..., 0``: send ``step_pools[t]`` to the
       parent ``_update_actor``, which handles the actual worker dispatch,
       split, and forward-backward.
    3. Aggregate metrics across all step updates.
    """
    step_cfg = self.config.trainer.get("step_minibatch", {})
    mini_batch_size = self.config.actor_rollout_ref.actor.ppo_mini_batch_size
    self._validate_effective_mini_batch_size(batch, mini_batch_size, component="actor step mode")

    step_pools = group_batch_by_step_index(
        batch,
        max_steps=step_cfg.get("max_steps", None),
        ppo_mini_batch_size=mini_batch_size,
    )

    metrics_accum: dict[str, list] = {}

    # Later interaction steps are updated first (reverse step order).
    for pool in reversed(step_pools):
        if len(pool) == 0:
            continue

        # Step-pool loss weights: placeholder rows produce zero gradient,
        # valid rows are uniformly weighted.
        weighted_pool = apply_step_pool_loss_weights(pool, mini_batch_size, mode="uniform-valid")

        # The verl base-class _update_actor performs worker dispatch, the
        # mini/micro-batch split, and the forward-backward pass.
        step_output = super()._update_actor(weighted_pool)

        # Keep verl's collected metric structure; the caller's
        # reduce_metrics() performs the final aggregation.
        append_to_dict(metrics_accum, step_output.meta_info.get("metrics", {}))

    return DataProto.from_single_dict(data={}, meta_info={"metrics": metrics_accum})

# ------------------------------------------------------------------
# Critic update: three-way dispatch
# ------------------------------------------------------------------

def _update_critic(self, batch: DataProto) -> DataProto:
    """Override: route the critic update through the configured mini-batch strategy."""
    mode = self._resolve_minibatch_mode()

    if mode == "step":
        return self._update_critic_step_based(batch)

    if mode == "trajectory":
        # Repack so each trajectory stays inside a single critic mini-batch.
        mini_batch_size = self.config.critic.ppo_mini_batch_size
        self._validate_effective_mini_batch_size(batch, mini_batch_size, component="critic trajectory mode")
        batch = self._prepare_trajectory_minibatches(batch, mini_batch_size)

    # "default" passes the batch through unchanged; "trajectory" passes
    # the repacked batch.
    return super()._update_critic(batch)

def _update_critic_step_based(self, batch: DataProto) -> DataProto:
    """Step-based critic update: iterate step pools in reverse order,
    symmetric to ``_update_actor_step_based``."""
    step_cfg = self.config.trainer.get("step_minibatch", {})
    mini_batch_size = self.config.critic.ppo_mini_batch_size
    self._validate_effective_mini_batch_size(batch, mini_batch_size, component="critic step mode")

    step_pools = group_batch_by_step_index(
        batch,
        max_steps=step_cfg.get("max_steps", None),
        ppo_mini_batch_size=mini_batch_size,
    )

    metrics_accum: dict[str, list] = {}

    # Update later steps first, mirroring the actor's reverse-order schedule.
    for pool in reversed(step_pools):
        if len(pool) == 0:
            continue
        step_output = super()._update_critic(pool)
        append_to_dict(metrics_accum, step_output.meta_info.get("metrics", {}))

    return DataProto.from_single_dict(data={}, meta_info={"metrics": metrics_accum})

def _pad_dataproto_to_world_size(self, batch):
world_sizes = []
if self.use_critic and self.critic_wg.world_size != 0:
Expand Down
13 changes: 13 additions & 0 deletions agent_r1/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Copyright 2025 ModelBest Inc. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Loading