From 21499c2066be18ac6299e9ee2e13115137ce4f07 Mon Sep 17 00:00:00 2001 From: lqzxt <997260326@qq.com> Date: Tue, 21 Apr 2026 20:50:19 +0800 Subject: [PATCH 1/2] feat: add trajectory-level batching, loss aggregation, and dynamic batching support --- agent_r1/config/agent_ppo_trainer.yaml | 26 ++- agent_r1/ray_agent_trainer.py | 207 ++++++++++++++++- agent_r1/utils/__init__.py | 13 ++ agent_r1/utils/step_batching.py | 201 ++++++++++++++++ agent_r1/utils/traj_loss.py | 278 +++++++++++++++++++++++ agent_r1/utils/trajectory_batching.py | 303 +++++++++++++++++++++++++ 6 files changed, 1019 insertions(+), 9 deletions(-) create mode 100644 agent_r1/utils/__init__.py create mode 100644 agent_r1/utils/step_batching.py create mode 100644 agent_r1/utils/traj_loss.py create mode 100644 agent_r1/utils/trajectory_batching.py diff --git a/agent_r1/config/agent_ppo_trainer.yaml b/agent_r1/config/agent_ppo_trainer.yaml index 1579c18..ed68ceb 100644 --- a/agent_r1/config/agent_ppo_trainer.yaml +++ b/agent_r1/config/agent_ppo_trainer.yaml @@ -17,10 +17,24 @@ hydra: searchpath: - pkg://verl.trainer.config -# 如果还需要覆盖其他配置,可以直接在这里添加 -# 例如: -# actor_rollout_ref: -# rollout: -# agent: -# num_workers: 16 +# Trajectory-level loss reweighting is derived against seq-mean-token-sum. 
+actor_rollout_ref: + actor: + loss_agg_mode: seq-mean-token-sum +trainer: + + # --- Mini-batch partitioning strategy (three mutually-exclusive modes) --- + # "default" : verl-native fixed-size contiguous split (no trajectory awareness) + # "trajectory" : keep every trajectory in a single PPO mini-batch and apply + # trajectory-level loss weighting + # TODO: the "step" mode below is implemented but still awaiting testing. + # "step" : step-index-based grouping with reverse-order update (see SeeUPO: Sequence-Level Agentic-RL with Convergence Guarantees, https://arxiv.org/abs/2602.06554) + minibatch_mode: "trajectory" + + # TODO: implemented but still awaiting testing. + # --- Step-based mini-batch parameters (active when minibatch_mode == "step"; see SeeUPO: Sequence-Level Agentic-RL with Convergence Guarantees, https://arxiv.org/abs/2602.06554) --- + step_minibatch: + # Maximum number of interaction steps. null = auto-infer from data + # (max step_index + 1 in the current batch). + max_steps: null diff --git a/agent_r1/ray_agent_trainer.py b/agent_r1/ray_agent_trainer.py index ab0b639..c751954 100644 --- a/agent_r1/ray_agent_trainer.py +++ b/agent_r1/ray_agent_trainer.py @@ -35,6 +35,9 @@ from tqdm import tqdm from agent_r1.metric_utils import compute_data_metrics +from agent_r1.utils.step_batching import group_batch_by_step_index +from agent_r1.utils.traj_loss import apply_step_pool_loss_weights, apply_traj_loss_weights +from agent_r1.utils.trajectory_batching import pack_trajectories_into_minibatches from verl import DataProto from verl.experimental.dataset.sampler import AbstractCurriculumSampler from verl.protocol import pad_dataproto_to_divisor @@ -58,6 +61,7 @@ from verl.utils.config import omega_conf_to_dataclass from verl.utils.debug import marked_timer from verl.utils.metric import reduce_metrics +from verl.utils.py_functional import append_to_dict from verl.utils.rollout_skip import RolloutSkip @@ -164,9 +168,6 @@ def compute_advantage( norm_adv_by_std_in_grpo: bool = True, config: Optional[AlgoConfig] = 
None, ) -> DataProto: - # TODO: 重写所有 core_algos 中的 advantage 函数,适配新型的 agent flow 数据结构 - # 多行 data 对应一条完整轨迹,通过 non_tensor_batch["trajectory_uids"] 来区分不同轨迹,每条轨迹包含多行 data。 - # 通过 non_tensor_batch["step_indices"] 来区分同一条轨迹内的不同 step 的顺序。 """Compute advantage estimates for policy optimization. This function computes advantage estimates using various estimators like GAE, GRPO, REINFORCE++, etc. @@ -1018,6 +1019,206 @@ def fit(self): # The dataset may be changed after each training batch self.train_dataset.on_batch_end(batch=batch) + def _prepare_trajectory_minibatches(self, batch: DataProto, ppo_mini_batch_size: int) -> DataProto: + """Thin wrapper around pack_trajectories_into_minibatches. + + Reorders and pads the step-level batch so that + DataProto.split(ppo_mini_batch_size) — as performed inside the verl + actor/critic workers — yields mini-batches where every trajectory is + fully contained (no cross-trajectory truncation). + + Args: + batch: Full step-level DataProto after advantage computation. + ppo_mini_batch_size: The mini-batch size used by the worker's + internal split() call. + + Returns: + Reordered DataProto ready for trajectory-aligned split(). 
+ """ + return pack_trajectories_into_minibatches(batch, ppo_mini_batch_size) + + def _validate_effective_mini_batch_size( + self, + batch: DataProto, + ppo_mini_batch_size: int, + component: str, + ) -> None: + """Validate that PPO mini-batch size does not exceed valid train rows.""" + if ppo_mini_batch_size <= 0: + raise ValueError(f"{component}: ppo_mini_batch_size must be positive, got {ppo_mini_batch_size}.") + + is_pad = batch.non_tensor_batch.get("is_pad", None) + if is_pad is None: + effective_train_batch_size = len(batch) + else: + effective_train_batch_size = int((~is_pad.astype(bool)).sum()) + + if ppo_mini_batch_size > effective_train_batch_size: + raise ValueError( + f"{component}: ppo_mini_batch_size={ppo_mini_batch_size} exceeds the current " + f"effective train batch size ({effective_train_batch_size} valid rows). " + "Please reduce ppo_mini_batch_size or increase train_batch_size." + ) + + def _resolve_minibatch_mode(self) -> str: + """Resolve the effective mini-batch partitioning mode. + + On the new-engine path (``use_legacy_worker_impl == "disable"``), + ``"trajectory"`` and ``"step"`` modes are not yet supported and fall + back to ``"default"`` with a warning. + + Returns: + One of ``"default"``, ``"trajectory"``, ``"step"``. + """ + mode = str(self.config.trainer.get("minibatch_mode", "trajectory")) + + if mode not in ("default", "trajectory", "step"): + raise ValueError( + f"Invalid minibatch_mode '{mode}'. " + f"Must be one of: 'default', 'trajectory', 'step'." + ) + + if mode in ("trajectory", "step") and self.use_legacy_worker_impl == "disable": + print( + f"[WARNING] minibatch_mode='{mode}' is not supported on the " + f"new-engine path (use_legacy_worker_impl='disable'). " + f"Falling back to 'default'." 
+ ) + mode = "default" + + return mode + + # ------------------------------------------------------------------ + # Actor update: three-way dispatch + # ------------------------------------------------------------------ + + def _update_actor(self, batch: DataProto) -> DataProto: + """Override: dispatches to the appropriate mini-batch strategy before + forwarding to the parent update. + + Modes (controlled by ``trainer.minibatch_mode``): + - ``"default"``: verl-native contiguous split (no trajectory awareness). + - ``"trajectory"``: trajectory-aware packing plus per-trajectory loss weighting. + - ``"step"``: step-index-based pools, updated in reverse order. + """ + mode = self._resolve_minibatch_mode() + + if mode == "trajectory": + return self._update_actor_trajectory(batch) + elif mode == "step": + return self._update_actor_step_based(batch) + else: + return super()._update_actor(batch) + + def _update_actor_trajectory(self, batch: DataProto) -> DataProto: + """Pack full trajectories into mini-batches, then delegate to the parent update.""" + mini_batch_size = self.config.actor_rollout_ref.actor.ppo_mini_batch_size + self._validate_effective_mini_batch_size(batch, mini_batch_size, component="actor trajectory mode") + batch = self._prepare_trajectory_minibatches(batch, mini_batch_size) + + actor_loss_agg_mode = str(self.config.actor_rollout_ref.actor.loss_agg_mode) + required_mode = "seq-mean-token-sum" + if actor_loss_agg_mode != required_mode: + raise ValueError( + "Trajectory mini-batching requires " + "actor_rollout_ref.actor.loss_agg_mode='seq-mean-token-sum'. " + f"Got '{actor_loss_agg_mode}'." + ) + + batch = apply_traj_loss_weights(batch, mini_batch_size) + + return super()._update_actor(batch) + + def _update_actor_step_based(self, batch: DataProto) -> DataProto: + """Step-based mini-batch update: iterate step pools in reverse order. + + 1. ``group_batch_by_step_index(batch)`` -> ``step_pools[0..T-1]`` + 2. 
For ``t`` in ``T-1, T-2, ..., 0``: + - Send ``step_pools[t]`` to the parent ``_update_actor`` which + handles the actual worker dispatch / split / forward-backward. + 3. Aggregate metrics from all step updates. + """ + step_cfg = self.config.trainer.get("step_minibatch", {}) + max_steps = step_cfg.get("max_steps", None) + + mini_batch_size = self.config.actor_rollout_ref.actor.ppo_mini_batch_size + self._validate_effective_mini_batch_size(batch, mini_batch_size, component="actor step mode") + step_pools = group_batch_by_step_index( + batch, max_steps=max_steps, ppo_mini_batch_size=mini_batch_size, + ) + + aggregated_metrics: dict[str, list] = {} + + for t in reversed(range(len(step_pools))): + pool = step_pools[t] + if len(pool) == 0: + continue + + # Apply step-pool loss weights so that placeholder rows produce + # zero gradient and valid rows are uniformly weighted. + pool = apply_step_pool_loss_weights( + pool, mini_batch_size, mode="uniform-valid", + ) + + # Delegate to the verl base-class _update_actor which handles + # worker dispatch, split into mini/micro batches, forward-backward. + step_output = super()._update_actor(pool) + + # Preserve verl's collected metric structure and let the caller's + # reduce_metrics() perform the final aggregation. 
+ step_metrics = step_output.meta_info.get("metrics", {}) + append_to_dict(aggregated_metrics, step_metrics) + + return DataProto.from_single_dict( + data={}, meta_info={"metrics": aggregated_metrics} + ) + + # ------------------------------------------------------------------ + # Critic update: three-way dispatch + # ------------------------------------------------------------------ + + def _update_critic(self, batch: DataProto) -> DataProto: + """Override: dispatches to the appropriate mini-batch strategy before + forwarding to the parent critic update.""" + mode = self._resolve_minibatch_mode() + + if mode == "trajectory": + mini_batch_size = self.config.critic.ppo_mini_batch_size + self._validate_effective_mini_batch_size(batch, mini_batch_size, component="critic trajectory mode") + batch = self._prepare_trajectory_minibatches(batch, mini_batch_size) + return super()._update_critic(batch) + elif mode == "step": + return self._update_critic_step_based(batch) + else: + return super()._update_critic(batch) + + def _update_critic_step_based(self, batch: DataProto) -> DataProto: + """Step-based critic update: iterate step pools in reverse order, + symmetric to ``_update_actor_step_based``.""" + step_cfg = self.config.trainer.get("step_minibatch", {}) + max_steps = step_cfg.get("max_steps", None) + + mini_batch_size = self.config.critic.ppo_mini_batch_size + self._validate_effective_mini_batch_size(batch, mini_batch_size, component="critic step mode") + step_pools = group_batch_by_step_index( + batch, max_steps=max_steps, ppo_mini_batch_size=mini_batch_size, + ) + + aggregated_metrics: dict[str, list] = {} + + for t in reversed(range(len(step_pools))): + pool = step_pools[t] + if len(pool) == 0: + continue + step_output = super()._update_critic(pool) + + step_metrics = step_output.meta_info.get("metrics", {}) + append_to_dict(aggregated_metrics, step_metrics) + + return DataProto.from_single_dict( + data={}, meta_info={"metrics": aggregated_metrics} + ) + def 
_pad_dataproto_to_world_size(self, batch): world_sizes = [] if self.use_critic and self.critic_wg.world_size != 0: diff --git a/agent_r1/utils/__init__.py b/agent_r1/utils/__init__.py new file mode 100644 index 0000000..07dbea5 --- /dev/null +++ b/agent_r1/utils/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2025 ModelBest Inc. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/agent_r1/utils/step_batching.py b/agent_r1/utils/step_batching.py new file mode 100644 index 0000000..aaed2c4 --- /dev/null +++ b/agent_r1/utils/step_batching.py @@ -0,0 +1,201 @@ +# Copyright 2025 ModelBest Inc. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Step-based mini-batch utilities. + +Provides functions for reorganising a step-level DataProto batch (where each +row = one agent step) into per-step-index *step pools*, so that the PPO update +loop can iterate over step indices in reverse order (T -> T-1 -> ... -> 0). 
+ +Design constraints +------------------ +- Each step pool ``D_t`` contains **all** rows whose ``step_indices == t``, + plus placeholder rows for trajectories shorter than ``t + 1`` steps. +- Placeholder rows have ``response_mask == 0``, ``attention_mask == 0``, and + ``advantages == 0`` so they never contribute to policy loss, advantage + normalisation, or transformer attention. +- ``trajectory_uids`` is preserved on every row (including placeholders) so + that downstream trajectory-level credit / metrics remain valid. +- Pure functions only: no trainer state, no global side effects. +""" + +from __future__ import annotations + +import numpy as np +import torch + +from verl import DataProto +from verl.protocol import pad_dataproto_to_divisor + +_MASK_FIELDS_TO_ZERO = ("response_mask", "attention_mask") +_ADVANTAGE_FIELDS_TO_ZERO = ("advantages",) + + +def group_batch_by_step_index( + batch: DataProto, + max_steps: int | None = None, + ppo_mini_batch_size: int | None = None, +) -> list[DataProto]: + """Partition a step-level batch into per-step-index pools. + + Args: + batch: Full rollout batch where each row is one agent step. + Must contain ``non_tensor_batch["step_indices"]`` and + ``non_tensor_batch["trajectory_uids"]``. + max_steps: Number of step pools to create (0 .. max_steps-1). + ``None`` auto-infers from the data (``max(step_indices) + 1``). + ppo_mini_batch_size: If provided, each step pool is padded to be + divisible by this value (needed by the downstream worker split). + + Returns: + A list of ``max_steps`` :class:`DataProto` objects. ``step_pools[t]`` + contains all rows with ``step_indices == t`` plus placeholder rows for + trajectories that have fewer than ``t + 1`` steps. Placeholder rows + are marked with ``non_tensor_batch["is_placeholder"] == True``. + + Raises: + ValueError: If required non-tensor fields are missing. 
+ """ + if "step_indices" not in batch.non_tensor_batch: + raise ValueError("batch.non_tensor_batch must contain 'step_indices'") + if "trajectory_uids" not in batch.non_tensor_batch: + raise ValueError("batch.non_tensor_batch must contain 'trajectory_uids'") + + step_indices = batch.non_tensor_batch["step_indices"].astype(np.int32) + traj_uids = batch.non_tensor_batch["trajectory_uids"] + is_pad = batch.non_tensor_batch.get( + "is_pad", np.zeros(len(batch), dtype=bool) + ).astype(bool) + + if max_steps is None: + valid_steps = step_indices[~is_pad] + max_steps = int(valid_steps.max()) + 1 if len(valid_steps) > 0 else 1 + + # Build trajectory -> max_step_index mapping (valid rows only). + traj_max_step: dict[object, int] = {} + for i in range(len(batch)): + if is_pad[i]: + continue + uid = traj_uids[i] + s = int(step_indices[i]) + if uid not in traj_max_step or s > traj_max_step[uid]: + traj_max_step[uid] = s + + all_traj_uids = list(traj_max_step.keys()) + + # Pick a template row for placeholder creation (prefer an existing pad row). + pad_indices = np.where(is_pad)[0] + if len(pad_indices) > 0: + template_idx = int(pad_indices[0]) + else: + template_idx = 0 + + step_pools: list[DataProto] = [] + + for t in range(max_steps): + # Collect global indices of valid rows at step t. + real_indices: list[int] = [] + present_uids: set[object] = set() + for i in range(len(batch)): + if is_pad[i]: + continue + if int(step_indices[i]) == t: + real_indices.append(i) + present_uids.add(traj_uids[i]) + + # Trajectories that need a placeholder at step t: those with + # max_step < t (they don't have a real row for this step index). 
+ need_placeholder_uids = [ + uid for uid in all_traj_uids + if uid not in present_uids and traj_max_step[uid] < t + ] + + n_real = len(real_indices) + n_placeholder = len(need_placeholder_uids) + pool_indices = real_indices + [template_idx] * n_placeholder + is_placeholder_flags = [False] * n_real + [True] * n_placeholder + + pool = batch.select_idxs(np.array(pool_indices, dtype=np.int64)) + + # Stamp metadata on the pool. + pool.non_tensor_batch["is_placeholder"] = np.array( + is_placeholder_flags, dtype=bool + ) + # Overwrite trajectory_uids on placeholder rows so they still link + # back to their originating trajectory. + for k, uid in enumerate(need_placeholder_uids): + pool.non_tensor_batch["trajectory_uids"][n_real + k] = uid + # Overwrite step_indices on placeholder rows. + for k in range(n_placeholder): + pool.non_tensor_batch["step_indices"][n_real + k] = np.int32(t) + + # Zero out masks and advantages on placeholder rows. + if pool.batch is not None and n_placeholder > 0: + ph_mask = torch.zeros(len(pool), dtype=torch.bool) + ph_mask[n_real:] = True + for field in _MASK_FIELDS_TO_ZERO: + if field in pool.batch.keys(): + pool.batch[field][ph_mask] = 0 + for field in _ADVANTAGE_FIELDS_TO_ZERO: + if field in pool.batch.keys(): + pool.batch[field][ph_mask] = 0 + + # Pad to be divisible by ppo_mini_batch_size if requested. + if ppo_mini_batch_size is not None and len(pool) > 0: + pool = _pad_step_pool(pool, ppo_mini_batch_size) + + step_pools.append(pool) + + return step_pools + + +def _pad_step_pool(pool: DataProto, ppo_mini_batch_size: int) -> DataProto: + """Pad a step pool so its length is divisible by ``ppo_mini_batch_size``. + + Uses the same pattern as trajectory_batching: copy a template row and zero + out masks. Padding rows are marked ``is_placeholder = True``. 
+ """ + remainder = len(pool) % ppo_mini_batch_size + if remainder == 0: + return pool + + n_pad = ppo_mini_batch_size - remainder + original_len = len(pool) + + pool, _ = pad_dataproto_to_divisor(pool, ppo_mini_batch_size) + + # Extend is_placeholder array. + old_flags = pool.non_tensor_batch.get( + "is_placeholder", + np.zeros(original_len, dtype=bool), + ) + new_flags = np.concatenate([ + old_flags[:original_len], + np.ones(n_pad, dtype=bool), + ]) + pool.non_tensor_batch["is_placeholder"] = new_flags + + # Zero out mask / advantage fields on the newly padded rows. + if pool.batch is not None and n_pad > 0: + pad_mask = torch.zeros(len(pool), dtype=torch.bool) + pad_mask[original_len:] = True + for field in _MASK_FIELDS_TO_ZERO: + if field in pool.batch.keys(): + pool.batch[field][pad_mask] = 0 + for field in _ADVANTAGE_FIELDS_TO_ZERO: + if field in pool.batch.keys(): + pool.batch[field][pad_mask] = 0 + + return pool + diff --git a/agent_r1/utils/traj_loss.py b/agent_r1/utils/traj_loss.py new file mode 100644 index 0000000..9214611 --- /dev/null +++ b/agent_r1/utils/traj_loss.py @@ -0,0 +1,278 @@ +# Copyright 2025 ModelBest Inc. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Loss weighting helpers for trajectory-aware and step-aware mini-batching. 
+ +For trajectory mode we pre-scale per-step advantages so verl's +``seq-mean-token-sum`` aggregation gives every trajectory equal weight: + + L = (1 / N_traj) * sum_T (1 / |T|) * sum_{i in T} token_sum(i) + +instead of the default per-step mean: + + L = (1 / ppo_mini_batch_size) * sum_i token_sum(i) + +The input batch must already be packed so each PPO mini-batch contains whole +trajectories. +""" + +from __future__ import annotations + +from collections import Counter + +import numpy as np +import torch + +from verl import DataProto + +_STEP_POOL_SUPPORTED_MODES = frozenset(["uniform-valid"]) + + +def compute_traj_step_weights( + trajectory_uids: np.ndarray, + is_pad: np.ndarray, + ppo_mini_batch_size: int, +) -> torch.Tensor: + """Compute per-step advantage scaling weights for trajectory-level loss. + + For each valid step i in trajectory T (with |T| steps) inside a mini-batch + that contains N_traj distinct trajectories: + + w_i = ppo_mini_batch_size / (N_traj * |T|) + + Pad rows receive weight 0. + + The weights satisfy ``sum_i w_i = ppo_mini_batch_size`` within every + mini-batch chunk, which is the same total weight that verl's unweighted + ``seq-mean-token-sum`` would apply. The total weight assigned to each + trajectory is therefore ``ppo_mini_batch_size / N_traj`` regardless of + how many steps it contains. + + Args: + trajectory_uids: Array of trajectory uid strings/objects, shape (bsz,). + Must come from ``batch.non_tensor_batch["trajectory_uids"]`` after + trajectory-aware packing. + is_pad: Boolean array, shape (bsz,). True for padding rows. + ppo_mini_batch_size: Actor's ``ppo_mini_batch_size`` config value. + The packed batch length must be divisible by this. + + Returns: + Float tensor of shape (bsz,) with per-step weights. + + Raises: + ValueError: If the batch length is not divisible by + ``ppo_mini_batch_size``. + Chunks with zero valid rows (all padding) do not raise; their + weights are left at 0. 
+ """ + bsz = len(trajectory_uids) + if bsz % ppo_mini_batch_size != 0: + raise ValueError( + f"Batch size {bsz} is not divisible by ppo_mini_batch_size " + f"{ppo_mini_batch_size}. Trajectory packing must be applied " + "before trajectory loss weighting." + ) + + weights = torch.zeros(bsz, dtype=torch.float32) + n_chunks = bsz // ppo_mini_batch_size + + for k in range(n_chunks): + start = k * ppo_mini_batch_size + end = start + ppo_mini_batch_size + + chunk_is_pad = is_pad[start:end] + chunk_uids = trajectory_uids[start:end] + + valid_mask = ~chunk_is_pad.astype(bool) + valid_uids = chunk_uids[valid_mask] + + if len(valid_uids) == 0: + # All-pad chunk (can occur if world-size padding exceeds one full + # mini-batch). Leave weights as 0 for the whole chunk. + continue + + uid_counts = Counter(valid_uids) + n_traj = len(uid_counts) + + for local_idx in range(ppo_mini_batch_size): + global_idx = start + local_idx + if chunk_is_pad[local_idx]: + continue + uid = chunk_uids[local_idx] + traj_len = uid_counts[uid] + weights[global_idx] = ppo_mini_batch_size / (n_traj * traj_len) + + return weights + + +def apply_traj_loss_weights( + batch: DataProto, + ppo_mini_batch_size: int, +) -> DataProto: + """Return a copy of *batch* with ``advantages`` pre-scaled for + trajectory-level loss aggregation. + + The returned batch is a shallow copy with only the ``advantages`` tensor + replaced; all other fields (including the original ``response_mask``) are + shared references and are not modified. + + Args: + batch: Full packed batch. Must have + ``batch["advantages"]`` (shape ``(bsz, response_length)``) and + ``non_tensor_batch["trajectory_uids"]`` / ``["is_pad"]``. + ppo_mini_batch_size: Actor's ``ppo_mini_batch_size`` config value. + + Returns: + DataProto with scaled ``advantages``. + + Raises: + KeyError: If required batch fields are absent. 
+ """ + trajectory_uids = batch.non_tensor_batch["trajectory_uids"] + is_pad = batch.non_tensor_batch.get( + "is_pad", np.zeros(len(batch), dtype=bool) + ).astype(bool) + + weights = compute_traj_step_weights( + trajectory_uids=trajectory_uids, + is_pad=is_pad, + ppo_mini_batch_size=ppo_mini_batch_size, + ) + + # Move weights to the same device as advantages before broadcasting. + advantages = batch.batch["advantages"] + weights = weights.to(advantages.device) + + # advantages shape: (bsz, response_length); broadcast weight per row. + scaled_advantages = advantages * weights.unsqueeze(1) + + # Shallow-copy the DataProto so the caller's original batch is untouched. + new_batch = DataProto( + batch=batch.batch.clone(recurse=False), + non_tensor_batch=batch.non_tensor_batch, + meta_info=batch.meta_info, + ) + new_batch.batch["advantages"] = scaled_advantages + + return new_batch + + +# --------------------------------------------------------------------------- +# Step-pool-level loss weighting (used by step-based mini-batch mode) +# --------------------------------------------------------------------------- + + +def compute_step_pool_weights( + is_placeholder: np.ndarray, + ppo_mini_batch_size: int, +) -> torch.Tensor: + """Compute per-row weights for a single step pool. + + In step-based mode every row in a step pool corresponds to the same step + index across different trajectories. Placeholder rows (for trajectories + shorter than the current step) must contribute **zero** to the loss. + Valid rows are weighted uniformly: ``w_i = ppo_mini_batch_size / N_valid`` + within each mini-batch chunk, so the total weight matches what verl's + ``seq-mean-token-sum`` expects. + + Args: + is_placeholder: Boolean array, shape ``(pool_size,)``. + ``True`` for placeholder / padding rows. + ppo_mini_batch_size: The mini-batch size used by the downstream worker + split. Pool size must be divisible by this. + + Returns: + Float tensor of shape ``(pool_size,)`` with per-row weights. 
+ """ + pool_size = len(is_placeholder) + if pool_size == 0: + return torch.zeros(0, dtype=torch.float32) + + if pool_size % ppo_mini_batch_size != 0: + raise ValueError( + f"Step pool size {pool_size} is not divisible by " + f"ppo_mini_batch_size {ppo_mini_batch_size}." + ) + + weights = torch.zeros(pool_size, dtype=torch.float32) + n_chunks = pool_size // ppo_mini_batch_size + + for k in range(n_chunks): + start = k * ppo_mini_batch_size + end = start + ppo_mini_batch_size + + chunk_ph = is_placeholder[start:end].astype(bool) + n_valid = int((~chunk_ph).sum()) + if n_valid == 0: + continue + + w = ppo_mini_batch_size / n_valid + for local_idx in range(ppo_mini_batch_size): + if not chunk_ph[local_idx]: + weights[start + local_idx] = w + + return weights + + +def apply_step_pool_loss_weights( + pool: DataProto, + ppo_mini_batch_size: int, + mode: str = "uniform-valid", +) -> DataProto: + """Pre-scale advantages in a step pool so that placeholders produce zero + gradient and valid rows are uniformly weighted. + + Args: + pool: A single step-pool DataProto (output of + :func:`group_batch_by_step_index`). Must have + ``batch["advantages"]`` and + ``non_tensor_batch["is_placeholder"]``. + ppo_mini_batch_size: Actor's ``ppo_mini_batch_size`` config value. + mode: Weighting mode. Currently only ``"uniform-valid"`` is supported. + + Returns: + DataProto with scaled ``advantages``. Placeholder rows will have + ``advantages == 0`` regardless of their original value. + """ + if not mode or mode == "disabled": + return pool + + if mode not in _STEP_POOL_SUPPORTED_MODES: + raise ValueError( + f"Unsupported step pool loss mode: '{mode}'. 
" + f"Supported: {sorted(_STEP_POOL_SUPPORTED_MODES)}" + ) + + is_placeholder = pool.non_tensor_batch.get( + "is_placeholder", np.zeros(len(pool), dtype=bool) + ).astype(bool) + + weights = compute_step_pool_weights( + is_placeholder=is_placeholder, + ppo_mini_batch_size=ppo_mini_batch_size, + ) + + advantages = pool.batch["advantages"] + weights = weights.to(advantages.device) + + scaled_advantages = advantages * weights.unsqueeze(1) + + new_pool = DataProto( + batch=pool.batch.clone(recurse=False), + non_tensor_batch=pool.non_tensor_batch, + meta_info=pool.meta_info, + ) + new_pool.batch["advantages"] = scaled_advantages + + return new_pool diff --git a/agent_r1/utils/trajectory_batching.py b/agent_r1/utils/trajectory_batching.py new file mode 100644 index 0000000..de9e93f --- /dev/null +++ b/agent_r1/utils/trajectory_batching.py @@ -0,0 +1,303 @@ +# Copyright 2025 ModelBest Inc. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Trajectory-aware mini-batch packing utilities. + +These helpers reorder step-level DataProto rows so each PPO mini-batch +contains whole trajectories. The core invariant is: + + Every step that belongs to trajectory T appears in the same mini-batch. + +This keeps multi-step agent rollouts intact during forward and backward passes +and avoids splitting a single trajectory across separate PPO updates. 
+""" + +from __future__ import annotations + +from collections import OrderedDict +from dataclasses import dataclass + +import numpy as np +import torch + +from verl import DataProto + +# Tensor fields that must be zeroed for padding rows so that they never +# contribute to masked-loss computations in actor/critic workers. +# response_mask is the primary loss gate; attention_mask controls the +# transformer's attention - zeroing it prevents the model from attending +# to padding tokens and producing non-zero log-probs that could leak into loss. +_MASK_FIELDS_TO_ZERO = ("response_mask", "attention_mask") + + +@dataclass(frozen=True) +class _TrajectorySummary: + uid: object + indices: tuple[int, ...] + step_lens: tuple[int, ...] + num_steps: int + sum_step_lens: int + max_step_len: int + min_step_len: int + arrival_order: int + + +@dataclass +class _MiniBatchState: + row_indices: list[int] + used_steps: int + sum_step_lens: int + max_step_len: int + min_step_len: int + creation_order: int + + +def _compute_step_lengths(attention_mask: torch.Tensor, is_pad_arr: np.ndarray) -> np.ndarray: + step_lens = attention_mask.sum(dim=1).detach().cpu().numpy().astype(np.int32) + step_lens[is_pad_arr] = 0 + return step_lens + + +def _compute_pad_cost(used_steps: int, sum_step_lens: int, max_step_len: int) -> int: + if used_steps == 0: + return 0 + return used_steps * max_step_len - sum_step_lens + + +def _build_trajectory_summaries( + traj_uids: np.ndarray, + step_idxs: np.ndarray, + step_lens: np.ndarray, + valid_global_indices: np.ndarray, + ppo_mini_batch_size: int, +) -> list[_TrajectorySummary]: + traj_to_sorted_indices: OrderedDict[object, list[int]] = OrderedDict() + for global_idx in valid_global_indices: + uid = traj_uids[global_idx] + if uid not in traj_to_sorted_indices: + traj_to_sorted_indices[uid] = [] + traj_to_sorted_indices[uid].append(int(global_idx)) + + summaries = [] + for arrival_order, (uid, indices) in enumerate(traj_to_sorted_indices.items()): + 
ordered_indices = tuple(sorted(indices, key=lambda idx: int(step_idxs[idx]))) + ordered_lens = tuple(int(step_lens[idx]) for idx in ordered_indices) + num_steps = len(ordered_indices) + if num_steps > ppo_mini_batch_size: + raise ValueError( + f"Trajectory '{uid}' has {num_steps} steps, which exceeds " + f"ppo_mini_batch_size={ppo_mini_batch_size}. " + f"Please increase ppo_mini_batch_size to at least {num_steps}." + ) + summaries.append( + _TrajectorySummary( + uid=uid, + indices=ordered_indices, + step_lens=ordered_lens, + num_steps=num_steps, + sum_step_lens=sum(ordered_lens), + max_step_len=max(ordered_lens), + min_step_len=min(ordered_lens), + arrival_order=arrival_order, + ) + ) + + return sorted( + summaries, + key=lambda item: ( + -item.max_step_len, + -item.sum_step_lens, + -item.num_steps, + item.arrival_order, + ), + ) + + +def _candidate_key( + minibatch: _MiniBatchState, + trajectory: _TrajectorySummary, + ppo_mini_batch_size: int, +) -> tuple[int, int, int, int]: + used_after = minibatch.used_steps + trajectory.num_steps + max_after = max(minibatch.max_step_len, trajectory.max_step_len) + min_after = min(minibatch.min_step_len, trajectory.min_step_len) + sum_after = minibatch.sum_step_lens + trajectory.sum_step_lens + + delta_pad = _compute_pad_cost(used_after, sum_after, max_after) - _compute_pad_cost( + minibatch.used_steps, + minibatch.sum_step_lens, + minibatch.max_step_len, + ) + range_after = max_after - min_after + remaining_after = ppo_mini_batch_size - used_after + + return (delta_pad, range_after, remaining_after, minibatch.creation_order) + + +def _place_trajectory( + minibatches: list[_MiniBatchState], + trajectory: _TrajectorySummary, + ppo_mini_batch_size: int, +) -> None: + best_index = None + best_key = None + + for idx, minibatch in enumerate(minibatches): + if minibatch.used_steps + trajectory.num_steps > ppo_mini_batch_size: + continue + current_key = _candidate_key(minibatch, trajectory, ppo_mini_batch_size) + if best_key is None 
or current_key < best_key: + best_index = idx + best_key = current_key + + if best_index is None: + minibatches.append( + _MiniBatchState( + row_indices=list(trajectory.indices), + used_steps=trajectory.num_steps, + sum_step_lens=trajectory.sum_step_lens, + max_step_len=trajectory.max_step_len, + min_step_len=trajectory.min_step_len, + creation_order=len(minibatches), + ) + ) + return + + target = minibatches[best_index] + target.row_indices.extend(trajectory.indices) + target.used_steps += trajectory.num_steps + target.sum_step_lens += trajectory.sum_step_lens + target.max_step_len = max(target.max_step_len, trajectory.max_step_len) + target.min_step_len = min(target.min_step_len, trajectory.min_step_len) + + +def pack_trajectories_into_minibatches( + batch: DataProto, + ppo_mini_batch_size: int, +) -> DataProto: + """Reorder and pad a step-level batch so that split(ppo_mini_batch_size) + yields mini-batches where every trajectory is fully contained. + + The function is the single point of truth for trajectory-level packing. + It operates on the flattened step-row layout produced by AgentFlowWorker, + where each row is one agent step and trajectories are identified by + non_tensor_batch["trajectory_uids"]. + + Algorithm + --------- + 1. Separate valid rows (is_pad=False) from existing world-size padding rows. + 2. Compute per-step effective lengths from attention_mask and build + trajectory summaries sorted by packing difficulty. + 3. Dynamically place each full trajectory into the candidate mini-batch + that minimizes incremental padding cost; break ties by resulting + step-length range, then remaining capacity, then creation order. + 4. Pad each mini-batch to exactly ppo_mini_batch_size rows by appending + copies of an existing pad row (response_mask all-zeros). + 5. Rebuild via DataProto.select_idxs(reorder_indices) which propagates all + tensor and non_tensor fields; overwrite is_pad with the new mask. 
+ + Output guarantees + ----------------- + - All original valid rows are present exactly once, with all fields intact + (trajectory_uids, step_indices, response_mask, advantages, old_log_probs, ...). + - Padding rows have response_mask == 0, so they contribute zero to any + masked loss aggregation (seq-mean-token-sum, token-mean, seq-mean-token-mean). + - Output batch size == num_mini_batches * ppo_mini_batch_size, which is + trivially divisible by ppo_mini_batch_size for DataProto.split(). + + Args: + batch: Step-level DataProto with non_tensor_batch fields + "trajectory_uids", "step_indices", and optionally "is_pad". + ppo_mini_batch_size: Maximum number of rows per mini-batch. Must be + >= the longest single trajectory's step count. + + Returns: + A new DataProto with rows reordered and padded for trajectory-aligned + splitting. meta_info is forwarded unchanged. + + Raises: + ValueError: If "trajectory_uids" or "step_indices" are absent. + ValueError: If ``ppo_mini_batch_size`` exceeds the number of valid rows + in the current batch. + ValueError: If any single trajectory has more steps than ppo_mini_batch_size. 
+ """ + if "trajectory_uids" not in batch.non_tensor_batch: + raise ValueError("batch.non_tensor_batch must contain 'trajectory_uids'") + if "step_indices" not in batch.non_tensor_batch: + raise ValueError("batch.non_tensor_batch must contain 'step_indices'") + if batch.batch is None or "attention_mask" not in batch.batch.keys(): + raise ValueError("batch.batch must contain 'attention_mask'") + + is_pad_arr = batch.non_tensor_batch.get("is_pad", np.zeros(len(batch), dtype=bool)).astype(bool) + traj_uids = batch.non_tensor_batch["trajectory_uids"] + step_idxs = batch.non_tensor_batch["step_indices"] + attention_mask = batch.batch["attention_mask"] + + valid_global_indices = np.where(~is_pad_arr)[0] + pad_global_indices = np.where(is_pad_arr)[0] + + if ppo_mini_batch_size > len(valid_global_indices): + raise ValueError( + f"ppo_mini_batch_size={ppo_mini_batch_size} exceeds the current " + f"effective train batch size ({len(valid_global_indices)} valid rows)." + ) + + # Padding template: reuse an existing pad row (response_mask guaranteed all-zero + # by _pad_dataproto_to_world_size). Fallback: last valid row (rare, edge case). 
+ if len(pad_global_indices) > 0: + pad_template_idx = int(pad_global_indices[0]) + elif len(valid_global_indices) > 0: + pad_template_idx = int(valid_global_indices[-1]) + else: + raise ValueError("batch contains no rows to pack.") + + step_lens = _compute_step_lengths(attention_mask, is_pad_arr) + summaries = _build_trajectory_summaries( + traj_uids=traj_uids, + step_idxs=step_idxs, + step_lens=step_lens, + valid_global_indices=valid_global_indices, + ppo_mini_batch_size=ppo_mini_batch_size, + ) + + minibatches: list[_MiniBatchState] = [] + for trajectory in summaries: + _place_trajectory(minibatches, trajectory, ppo_mini_batch_size) + + # --- Build final index array: real rows + padding up to ppo_mini_batch_size --- + final_indices: list[int] = [] + final_is_pad: list[bool] = [] + + for minibatch in minibatches: + n_real = len(minibatch.row_indices) + n_pad = ppo_mini_batch_size - n_real + final_indices.extend(minibatch.row_indices) + final_is_pad.extend([False] * n_real) + final_indices.extend([pad_template_idx] * n_pad) + final_is_pad.extend([True] * n_pad) + + # --- Rebuild DataProto with the new row order --- + new_batch = batch.select_idxs(np.array(final_indices, dtype=np.int64)) + new_batch.non_tensor_batch["is_pad"] = np.array(final_is_pad, dtype=bool) + + # Zero out mask fields for all padding rows. pad_dataproto_to_divisor creates + # pad rows by copying real rows verbatim, so their masks are not guaranteed + # to be zero unless we fix them here. 
+ if new_batch.batch is not None: + pad_row_mask = torch.from_numpy(np.array(final_is_pad, dtype=bool)) + for field in _MASK_FIELDS_TO_ZERO: + if field in new_batch.batch.keys(): + new_batch.batch[field][pad_row_mask] = 0 + + return new_batch From c35c521a8796871b46be8f99e361693e55e4d15b Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 21 Apr 2026 13:02:21 +0000 Subject: [PATCH 2/2] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- agent_r1/ray_agent_trainer.py | 25 ++++++++++++------------- agent_r1/utils/step_batching.py | 24 +++++++++--------------- agent_r1/utils/traj_loss.py | 18 ++++-------------- 3 files changed, 25 insertions(+), 42 deletions(-) diff --git a/agent_r1/ray_agent_trainer.py b/agent_r1/ray_agent_trainer.py index c751954..8efe0bd 100644 --- a/agent_r1/ray_agent_trainer.py +++ b/agent_r1/ray_agent_trainer.py @@ -1073,10 +1073,7 @@ def _resolve_minibatch_mode(self) -> str: mode = str(self.config.trainer.get("minibatch_mode", "trajectory")) if mode not in ("default", "trajectory", "step"): - raise ValueError( - f"Invalid minibatch_mode '{mode}'. " - f"Must be one of: 'default', 'trajectory', 'step'." - ) + raise ValueError(f"Invalid minibatch_mode '{mode}'. 
Must be one of: 'default', 'trajectory', 'step'.") if mode in ("trajectory", "step") and self.use_legacy_worker_impl == "disable": print( @@ -1144,7 +1141,9 @@ def _update_actor_step_based(self, batch: DataProto) -> DataProto: mini_batch_size = self.config.actor_rollout_ref.actor.ppo_mini_batch_size self._validate_effective_mini_batch_size(batch, mini_batch_size, component="actor step mode") step_pools = group_batch_by_step_index( - batch, max_steps=max_steps, ppo_mini_batch_size=mini_batch_size, + batch, + max_steps=max_steps, + ppo_mini_batch_size=mini_batch_size, ) aggregated_metrics: dict[str, list] = {} @@ -1157,7 +1156,9 @@ def _update_actor_step_based(self, batch: DataProto) -> DataProto: # Apply step-pool loss weights so that placeholder rows produce # zero gradient and valid rows are uniformly weighted. pool = apply_step_pool_loss_weights( - pool, mini_batch_size, mode="uniform-valid", + pool, + mini_batch_size, + mode="uniform-valid", ) # Delegate to the verl base-class _update_actor which handles @@ -1169,9 +1170,7 @@ def _update_actor_step_based(self, batch: DataProto) -> DataProto: step_metrics = step_output.meta_info.get("metrics", {}) append_to_dict(aggregated_metrics, step_metrics) - return DataProto.from_single_dict( - data={}, meta_info={"metrics": aggregated_metrics} - ) + return DataProto.from_single_dict(data={}, meta_info={"metrics": aggregated_metrics}) # ------------------------------------------------------------------ # Critic update: three-way dispatch @@ -1201,7 +1200,9 @@ def _update_critic_step_based(self, batch: DataProto) -> DataProto: mini_batch_size = self.config.critic.ppo_mini_batch_size self._validate_effective_mini_batch_size(batch, mini_batch_size, component="critic step mode") step_pools = group_batch_by_step_index( - batch, max_steps=max_steps, ppo_mini_batch_size=mini_batch_size, + batch, + max_steps=max_steps, + ppo_mini_batch_size=mini_batch_size, ) aggregated_metrics: dict[str, list] = {} @@ -1215,9 +1216,7 @@ def 
_update_critic_step_based(self, batch: DataProto) -> DataProto: step_metrics = step_output.meta_info.get("metrics", {}) append_to_dict(aggregated_metrics, step_metrics) - return DataProto.from_single_dict( - data={}, meta_info={"metrics": aggregated_metrics} - ) + return DataProto.from_single_dict(data={}, meta_info={"metrics": aggregated_metrics}) def _pad_dataproto_to_world_size(self, batch): world_sizes = [] diff --git a/agent_r1/utils/step_batching.py b/agent_r1/utils/step_batching.py index aaed2c4..73f8292 100644 --- a/agent_r1/utils/step_batching.py +++ b/agent_r1/utils/step_batching.py @@ -74,9 +74,7 @@ def group_batch_by_step_index( step_indices = batch.non_tensor_batch["step_indices"].astype(np.int32) traj_uids = batch.non_tensor_batch["trajectory_uids"] - is_pad = batch.non_tensor_batch.get( - "is_pad", np.zeros(len(batch), dtype=bool) - ).astype(bool) + is_pad = batch.non_tensor_batch.get("is_pad", np.zeros(len(batch), dtype=bool)).astype(bool) if max_steps is None: valid_steps = step_indices[~is_pad] @@ -116,10 +114,7 @@ def group_batch_by_step_index( # Trajectories that need a placeholder at step t: those with # max_step < t (they don't have a real row for this step index). - need_placeholder_uids = [ - uid for uid in all_traj_uids - if uid not in present_uids and traj_max_step[uid] < t - ] + need_placeholder_uids = [uid for uid in all_traj_uids if uid not in present_uids and traj_max_step[uid] < t] n_real = len(real_indices) n_placeholder = len(need_placeholder_uids) @@ -129,9 +124,7 @@ def group_batch_by_step_index( pool = batch.select_idxs(np.array(pool_indices, dtype=np.int64)) # Stamp metadata on the pool. - pool.non_tensor_batch["is_placeholder"] = np.array( - is_placeholder_flags, dtype=bool - ) + pool.non_tensor_batch["is_placeholder"] = np.array(is_placeholder_flags, dtype=bool) # Overwrite trajectory_uids on placeholder rows so they still link # back to their originating trajectory. 
for k, uid in enumerate(need_placeholder_uids): @@ -180,10 +173,12 @@ def _pad_step_pool(pool: DataProto, ppo_mini_batch_size: int) -> DataProto: "is_placeholder", np.zeros(original_len, dtype=bool), ) - new_flags = np.concatenate([ - old_flags[:original_len], - np.ones(n_pad, dtype=bool), - ]) + new_flags = np.concatenate( + [ + old_flags[:original_len], + np.ones(n_pad, dtype=bool), + ] + ) pool.non_tensor_batch["is_placeholder"] = new_flags # Zero out mask / advantage fields on the newly padded rows. @@ -198,4 +193,3 @@ def _pad_step_pool(pool: DataProto, ppo_mini_batch_size: int) -> DataProto: pool.batch[field][pad_mask] = 0 return pool - diff --git a/agent_r1/utils/traj_loss.py b/agent_r1/utils/traj_loss.py index 9214611..e5237e8 100644 --- a/agent_r1/utils/traj_loss.py +++ b/agent_r1/utils/traj_loss.py @@ -140,9 +140,7 @@ def apply_traj_loss_weights( KeyError: If required batch fields are absent. """ trajectory_uids = batch.non_tensor_batch["trajectory_uids"] - is_pad = batch.non_tensor_batch.get( - "is_pad", np.zeros(len(batch), dtype=bool) - ).astype(bool) + is_pad = batch.non_tensor_batch.get("is_pad", np.zeros(len(batch), dtype=bool)).astype(bool) weights = compute_traj_step_weights( trajectory_uids=trajectory_uids, @@ -200,10 +198,7 @@ def compute_step_pool_weights( return torch.zeros(0, dtype=torch.float32) if pool_size % ppo_mini_batch_size != 0: - raise ValueError( - f"Step pool size {pool_size} is not divisible by " - f"ppo_mini_batch_size {ppo_mini_batch_size}." - ) + raise ValueError(f"Step pool size {pool_size} is not divisible by ppo_mini_batch_size {ppo_mini_batch_size}.") weights = torch.zeros(pool_size, dtype=torch.float32) n_chunks = pool_size // ppo_mini_batch_size @@ -249,14 +244,9 @@ def apply_step_pool_loss_weights( return pool if mode not in _STEP_POOL_SUPPORTED_MODES: - raise ValueError( - f"Unsupported step pool loss mode: '{mode}'. 
" - f"Supported: {sorted(_STEP_POOL_SUPPORTED_MODES)}" - ) + raise ValueError(f"Unsupported step pool loss mode: '{mode}'. Supported: {sorted(_STEP_POOL_SUPPORTED_MODES)}") - is_placeholder = pool.non_tensor_batch.get( - "is_placeholder", np.zeros(len(pool), dtype=bool) - ).astype(bool) + is_placeholder = pool.non_tensor_batch.get("is_placeholder", np.zeros(len(pool), dtype=bool)).astype(bool) weights = compute_step_pool_weights( is_placeholder=is_placeholder,