
deepakvijaykee/rl-experiments


rl-experiments

Five connected pieces of code for studying what RL update rules actually do to a policy. The base is rl_sandbox/, a small PyTorch sandbox of bandit and sequence tasks with a registry of update rules to compare on them. Next to it sit four companions. rlm_grpo/ trains small causal LMs (0.5-0.6B) with GRPO on a tree-of-rollouts protocol. opd_sandbox/ is an appendix that probes on-policy distillation mechanics on similar toy tasks. vpo_sandbox/ isolates Vector Policy Optimization on a set-valued, vector-reward toy where the question is whether training should preserve a pool of trade-offs for test-time search. pedagogy_sandbox/ isolates Pedagogical RL as a privileged rollout-acquisition protocol with separate teacher and student policies.

The question driving all five pieces is older than any particular objective family. When the reward signal is sparse, noisy, delayed, privileged, or vector-valued, which rollouts deserve gradient credit, how should they be found, and at what granularity should a policy spend its update budget? Benchmarks rank methods by who finished first. They cannot tell you why one method got there and another stalled, or whether a winner at small scale will still win at a larger one. I wanted to watch the training step itself: where gradient mass concentrates as the policy moves, when the importance ratio drifts past the band the clipped surrogate tracks, at what step entropy collapses past recovery, when privileged information produces learnable rather than merely correct rollouts, and when a policy has collapsed a useful candidate pool into one scalar compromise. These signals are the first casualty of scale. Distributed rollouts and large batches average them out, and you do not see a failure until several days of compute have been wasted on it. A single GPU with toy models and short horizons brings the signals back into view, and the rest of this setup is built to keep them readable.

How I think about the space of choices

Post-training RL papers usually organize themselves by objective family: PPO, GRPO, DPO, RLVR, rejection sampling, distillation. After I had spent enough time reading them, that framing started to feel too coarse for the decisions I actually had to make. Two methods inside the same family can disagree about the choice that determines the run, and two methods from very different lineages can agree on it. The framing that has been more useful to me sits a level below the objective.

Given a fixed budget of online rollouts, learner updates, and KL distance from the reference policy, which samples should receive gradient weight, at what granularity, and for what downstream use of the candidate pool?

The question splits into seven axes. The order matters. The first two ask how to use the reward signal you already collected. The third asks when that signal is still relevant to the current policy. The fourth asks when the reward function itself is trustworthy. The fifth asks how the optimizer responds to the choices upstream of it. The sixth asks how privileged information should help find rollouts before the student can find them on-policy. The seventh steps back and asks whether training should produce one answer at all, or a searchable set of alternatives. Each axis is also a failure mode I can reproduce cheaply, which is the whole reason I can stare at the gradient when something goes wrong.

Influence allocation is the first. Every rollout arrives with an advantage, a surprisal, and a sampled action. Turning those three signals into a per-sample weight produces structurally different gradients, not rescalings of the same direction. Uniform weighting anchors the axis as the no-effect baseline. Advantage-only weighting is the standard policy-gradient choice. The advantage * surprisal weighting that the delight family (DG, Kondo) uses is a third bet, and on noisy tasks the surprisal factor acts as a soft confidence weight that ordinary policy gradient does not have. GRPO sits on the same axis despite the different vocabulary, because its group-normalized advantage is just another choice of weighting. The one method that escapes the axis entirely is TPO. TPO does not reweight a sample-driven update at all. It changes what the update target is, treating the rollouts in a group as candidates for a soft target distribution rather than as one sampled action and seven noise reducers. The right comparison for TPO, in this framing, is not "does its weighting beat GRPO's?" but "does its candidate-target construction extract more from the same grouped rollouts than GRPO's normalization step does?".
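A minimal sketch of the four weighting choices on this axis, for one group of rollouts. It assumes a group-mean baseline and takes surprisal as -log pi(a) of the sampled action; the function name `influence_weights` and the scheme labels are hypothetical, not the rl_sandbox registry's API. TPO is deliberately absent because it changes the target rather than the weight.

```python
from statistics import mean, pstdev


def influence_weights(rewards, logps, scheme, eps=1e-8):
    """Per-sample gradient weights for one group of rollouts (sketch).

    rewards: scalar reward per rollout.
    logps:   log pi(a) of each rollout's sampled action.
    The baseline is the group mean, an assumption of this sketch.
    """
    baseline = mean(rewards)
    adv = [r - baseline for r in rewards]
    if scheme == "uniform":
        # No-effect anchor: every sample counts the same.
        return [1.0] * len(rewards)
    if scheme == "advantage":
        # Standard policy-gradient weighting.
        return adv
    if scheme == "advantage_surprisal":
        # Delight-style weighting: advantage times surprisal, -log pi(a).
        return [a * (-lp) for a, lp in zip(adv, logps)]
    if scheme == "group_normalized":
        # GRPO-style: divide the centered reward by the group std.
        sigma = pstdev(rewards)
        return [a / (sigma + eps) for a in adv]
    raise ValueError(f"unknown scheme: {scheme}")
```

For rewards [1, 0, 1, 0] with log-probs [-1, -1, -3, -3], the advantage weights are [0.5, -0.5, 0.5, -0.5] while the surprisal-weighted ones are [0.5, -0.5, 1.5, -1.5]. The rarer actions get three times the pull, which changes the direction of the summed gradient rather than rescaling it.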

Credit granularity is the second. Even when the weighting scheme is fixed, an update still has to attribute the reward somewhere, either to the whole trajectory, to a branch within a tree of rollouts, or to individual tokens. Sparse-reward tasks need a credit story even when the feedback is a single bit, because flattening one bit across an entire trajectory drowns the meaningful tokens in noise from the unmeaningful ones. Dense-reward tasks have the opposite problem. Collapsing positions that carry per-step signal back into a single sequence-level score throws away most of what the reward function was telling you. TPOToken and DGToken are the per-token bets in this sandbox, and masked_reversal is the task built to expose their trade-offs by giving the comparison both scored and unscored positions to read.
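To make the contrast concrete, here is a sketch on a masked-reversal-style sequence, assuming per-position correctness bits, a scored-position mask, and a fixed baseline of 0.5 (all illustrative; the function names are hypothetical). Sequence-level credit hands the same scalar to every token, including wrong tokens at unscored positions, while token-level credit only moves the scored positions.

```python
def sequence_level_credit(correct, scored, baseline=0.5):
    """Trajectory-level credit: one scalar broadcast to every position.

    The score is the accuracy over scored positions, minus a fixed baseline.
    """
    hits = [c for c, s in zip(correct, scored) if s]
    score = sum(hits) / max(len(hits), 1)
    return [score - baseline] * len(correct)


def token_level_credit(correct, scored, baseline=0.5):
    """Token-level credit: per-position advantage, zero where nothing is scored."""
    return [(float(c) - baseline) if s else 0.0 for c, s in zip(correct, scored)]
```

With correct = [1, 0, 1, 1] and only the last two positions scored, sequence-level credit gives every position +0.5, including the wrong token at position 1, while token-level credit gives [0, 0, 0.5, 0.5].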

Support and coverage is the third, and it is really the question of when logged data is still informative for the current policy. Replay reduces gradient variance by recycling samples across optimizer steps, but every recycled sample comes from a policy that has since moved, and the importance ratios meant to correct for that mismatch are themselves noisy. A rollout collected k steps ago is approximately on-policy only while the policy has not moved much in k steps. The useful capacity of any replay buffer is therefore whatever fits inside that freshness window. Outside the window, the gradient points at a policy that has already moved past where it was collected. ReplayDG and FreshDG are the diagnostics on this axis, and what the runs say is that the controlling variable is the effective age distribution of the buffer, not its capacity.
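One way to see the freshness window is to list the ages a full FIFO buffer holds. The sketch assumes samples land `delay` steps late at a fixed rate and treats the freshness window as a single hypothetical age cutoff; neither function is part of the repo.

```python
def fifo_ages(capacity, delay, samples_per_step=1):
    """Ages, in optimizer steps, of the samples held by a full FIFO buffer.

    Assumes the newest samples are already `delay` steps old when they land
    and that `samples_per_step` samples arrive per step.
    """
    return [delay + i // samples_per_step for i in range(capacity)]


def fraction_fresh(ages, window):
    """Share of the buffer whose age is inside a (hypothetical) freshness window."""
    return sum(age <= window for age in ages) / len(ages)
```

Under these assumptions a 32-slot buffer at delay 4 spans ages 4 to 35, and only 5 of its 32 samples fall inside an 8-step window, while a 5-slot buffer at the same delay sits entirely inside it.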

Reward uncertainty is the fourth, and it becomes load-bearing whenever the reward is a proxy, such as a verifier, a judge, or a process reward model. Any proxy makes mistakes. Some fraction of its high signals are false positives, and some fraction of its low signals are missed credit. The methods in the sandbox each pick a different uncertainty proxy. UncertaintyDG and RewardVarianceDG are gate-based. FilteredDG is a hard threshold. ASPO and R2VPO regularize on ratio variance. They fail differently from each other, but they share a structural property the noise sweep makes visible. Every method that downweights the update when the reward looks uncertain is choosing an operating point on the same true-versus-false-positive trade-off. There is no free filtering on this axis, only different points on the curve.
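The operating-point view can be made concrete with a toy filter. Each sample carries an uncertainty score from some proxy and a hidden flag saying whether its reward is actually correct, and the filter keeps samples below a threshold. The data and the name `operating_point` are illustrative, and the sketch assumes the batch contains at least one correct and one incorrect reward.

```python
def operating_point(samples, threshold):
    """Keep a sample if its proxy uncertainty is below `threshold`.

    samples: (uncertainty, reward_is_correct) pairs.
    Returns (share of true signal kept, share of false signal admitted).
    """
    kept_true = sum(1 for u, ok in samples if ok and u < threshold)
    kept_false = sum(1 for u, ok in samples if not ok and u < threshold)
    n_true = sum(1 for _, ok in samples if ok)
    n_false = len(samples) - n_true
    return kept_true / n_true, kept_false / n_false
```

For samples [(0.1, True), (0.3, True), (0.6, True), (0.2, False), (0.7, False), (0.9, False)], raising the threshold from 0.5 to 0.8 lifts the true-signal share from 2/3 to 1 but also lifts the false-signal share from 1/3 to 2/3: a move along the curve, not off it.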

Optimization geometry is the fifth, meaning clipping, ratio variance, normalization, KL terms, entropy penalties, and the way these interact. The wrong combination can collapse entropy inside a hundred steps while the loss curve still looks healthy, which is exactly the kind of failure a distributed pipeline tends to discover only after several days of compute have been spent on it. The entropy-collapse sweep is the cheapest setting in which to watch this happen, and it points clearly at where in the GRPO recipe the collapse is coming from.

Privileged rollout acquisition is the sixth. On-policy RL and on-policy distillation can hold privileged information such as an answer or verifier trace while still sampling as if blind. Pedagogical RL moves that decision upstream: a privileged self-teacher samples trajectories conditioned on (x, c), but its reward is the product of task success and a student-scored learnability term. That changes the rollout contract. A batch now has two views of the same trajectory, one for the teacher and one for the student, and the student update is a gated imitation step rather than a scalar reward toggle. That is why Pedagogical RL lives in pedagogy_sandbox/ rather than the scalar rl_sandbox/ registry.
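A rough sketch of the two pieces that change the contract, under loudly stated assumptions. The teacher reward is the product of task success and a learnability term, here taken (hypothetically) as the student's geometric-mean token probability under its blind prompt view. The student imitates only tokens that pass a surprisal gate, where the gate direction and threshold are also assumptions. The scoped pedagogy_sandbox/ implementation, including its spike-aware reward shaping, may differ.

```python
import math


def teacher_reward(success, student_logps):
    """Product-form teacher reward: task success times student learnability.

    Learnability here is a hypothetical choice: the student's geometric-mean
    token probability for the trajectory under its blind (no-privilege) view.
    """
    learnability = math.exp(sum(student_logps) / len(student_logps))
    return float(success) * learnability


def imitation_mask(student_logps, max_surprisal):
    """Hypothetical surprisal gate: imitate a token only if the student's
    surprisal -log p for it is at most `max_surprisal`."""
    return [-lp <= max_surprisal for lp in student_logps]
```

The product form means a correct trajectory the student finds implausible earns little teacher reward, and a failed one earns none however learnable it looks.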

Search coverage under vector rewards is the seventh, and VPO sits here. It starts from a different deployment assumption: if test-time search will choose among many candidates, then post-training should not always collapse probability mass onto the single best scalarized answer. When reward decomposes into components, the useful object is a candidate set that covers several reward trade-offs. That changes the rollout contract itself. A VPO rollout contains multiple candidates, each candidate has a reward vector, and the group advantage comes from a set-level score computed over sampled scalarizations. VPO lives in vpo_sandbox/ rather than the scalar rl_sandbox/ registry because that change of contract is not a loss-toggle.
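A minimal sketch of a set-level score, assuming one natural reading of the description: score a candidate set by averaging, over scalarization weights, the best candidate under each weight, then center the scores across the group. The function names are hypothetical, and the weights are passed in (sampled, or a fixed grid as in the example).

```python
def scalarize(weights, reward_vec):
    """Linear scalarization of one reward vector."""
    return sum(w * r for w, r in zip(weights, reward_vec))


def set_score(candidates, weight_samples):
    """Mean over scalarizations of the best candidate under each one."""
    best = [max(scalarize(w, r) for r in candidates) for w in weight_samples]
    return sum(best) / len(best)


def set_advantages(group, weight_samples):
    """Group advantage of each candidate set: its score minus the group mean."""
    scores = [set_score(c, weight_samples) for c in group]
    mu = sum(scores) / len(scores)
    return [s - mu for s in scores]
```

With weights (a, 1 - a) for a in 0, 0.25, 0.5, 0.75, 1, the set {(1, 0), (0, 1)} scores 0.8 while two copies of the compromise (0.5, 0.5) score 0.5, so the set that covers both trade-offs receives the positive advantage.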

Each method or companion sandbox makes a bet on one or two of these axes. The repo is built so I can compare those bets on tasks small enough to inspect the gradient directly. That shifts the methodological discussion away from "which method scored highest" and toward "which axis of disagreement does this comparison actually exercise?".

What is in here

  • rl_sandbox/ is the toy sandbox itself, with bandit and sequence tasks and a registry of update rules covering PG, GRPO, DG, TPO, and the smaller families that each probe one axis at a time. Its README lays out the method-family menu, the task list, and the experiment matrix I work from when deciding what to run next.
  • opd_sandbox/ is the on-policy distillation appendix. It isolates estimator variance, support truncation, warmup schedules, and teacher-entropy effects on a smaller toy task where each design knob can be moved one at a time. Its README walks the five appendix experiments and what each one isolates.
  • rlm_grpo/ is the one piece that scales beyond a toy: a standalone flow for training small causal LMs (0.5-0.6B) with GRPO under a recursive tree-of-rollouts sampling protocol. The design problem worth solving there is how the root reward should propagate through a tree whose shape the model itself decides at rollout time. Its README covers the credit-propagation rule and its rationale.
  • vpo_sandbox/ is the vector-reward companion. It keeps VPO separate because set-valued rollouts and reward vectors are a different batch contract, not a scalar loss toggle. Its README covers the scoped VPO estimator, the scalar baselines, and the best@k/diversity metrics.
  • pedagogy_sandbox/ is the privileged rollout-acquisition companion. It keeps Pedagogical RL separate because the method has two policies, two prompt views, teacher-side learnability scoring by a frozen student, and student-side surprisal-gated imitation. Its README covers the scoped implementation and baselines.
  • rl_sandbox/analysis/ is the evidence: reproduction commands, result tables, and the figures the README embeds, all reproducible end to end on a single GPU.

Run this first

pip install -e .

python -m rl_sandbox.train --task token_reversal --method TPO \
  --batch_size 96 --group_size 8 --inner_epochs 4 \
  --num_steps 300 --eval_every 20 --num_seeds 3

python -m vpo_sandbox.train --method VPOGRPO \
  --batch_size 128 --group_size 8 --num_candidates 3 \
  --num_steps 300 --eval_every 20 --num_seeds 3

python -m pedagogy_sandbox.train --method PedagogicalRL \
  --batch_size 96 --group_size 8 \
  --num_steps 300 --eval_every 20 --num_seeds 3

The base install covers everything the inferences below rely on. Add the [lm-bandit] extra (pip install -e ".[lm-bandit]") for the HuggingFace LM bandit task, or the [rlm] extra for the rlm_grpo/ flow.

Regenerate the figures with:

python rl_sandbox/analysis/plot_evidence.py

The reward-noise, replay, partial-credit, dense-correction, and entropy sweeps live in rl_sandbox/analysis/sweep_manifest.md. They run end to end on a single GPU.

Inferences from these runs

Below are five inferences I take from the scalar sandbox methods after running them through compact three-seed sweeps. The VPO sandbox is a newer vector-reward slice and is documented as a contract first, not folded into these scalar result claims yet. The runs below are short and the batches are small, so I read the absolute numbers as regime checks rather than benchmark claims. The orderings between methods and the qualitative shape of each failure mode have been stable across seeds and across small changes in horizon, which is the part I trust enough to draw mechanism from. The headlines below state each inference and the mechanism I think drives it. The sections after the headlines walk through the evidence and the prediction the toy makes about how each result would shift at scale.

The inferences map cleanly onto the axes from the previous section. TPO versus GRPO sits on influence allocation. Replay sits on support and coverage. Token-level credit sits on credit granularity. Reward-noise heuristics sit on reward uncertainty. GRPO entropy collapse sits on optimization geometry. Treating them as five readings of the same underlying map, rather than five independent findings, is how the rest of this section is organized.

  1. TPO extracts more from each rollout group than GRPO does. On clean token-reversal, TPO finishes ahead of GRPO by more than the seed-to-seed variance, and the gap is not faster entropy collapse; the entropy diagnostics rule that out. At fixed rollout cost, TPO produces K weighted gradient directions per group, one per candidate, while GRPO produces one direction (the sampled action's, scaled by the standardized advantage). The information-per-step difference shows up as the test-error gap.
  2. Replay's useful band is set by effective sample age, not buffer capacity. Capacity is the variance knob in the replay trade-off. Age is the bias knob. A replay comparison that reports buffer capacity without the age distribution it produces is tuning along one axis of a two-axis trade-off, and the conclusions transfer poorly to setups where the second knob takes a different value.
  3. Token-level credit redistributes existing reward signal but cannot create it. Anywhere the reward function is silent, the advantage is zero, so a per-token gradient contribution vanishes no matter how careful the upstream credit-assignment scheme is. The trap is reaching for finer credit assignment when the bottleneck is reward sparsity. The right tool in that regime is reward shaping, not routing.
  4. Reward-noise heuristics are operating points on a true-versus-false-positive trade-off, not free improvements. Every method that downweights the update when the reward looks uncertain is filtering some true signal as collateral and accepting some false signal as truth. The framing for comparing them is calibration at a specific noise rate, not "which method is most robust" in the abstract.
  5. GRPO's entropy collapse lives in the standardization, upstream of the clip. Within-group standardization computes (R_i - mu) / sigma, which grows as 1/sigma when the rollouts in a group mostly agree, the common case on an easy task. The PPO clip sits downstream of that amplification, so it bounds the magnitude of each step without constraining the direction the step is being pushed in. The variants that have actually slowed entropy decay in the literature (DrGRPO, DAPO) attack the normalization rather than tightening the clip.

TPO extracts more from each rollout group than GRPO does

The size of the gap on clean token-reversal made me look closer. Final test error sits at 0.24 +/- 0.06 for TPO and 0.35 +/- 0.01 for GRPO, comfortably more than the seed-to-seed variance. Before reading anything into the gap, I checked the most charitable alternative: maybe TPO is winning by collapsing exploration faster and converging on a sharper policy. The entropy diagnostics rule that out. TPO keeps final entropy near DG's and never crosses the 0.1 threshold within 300 steps. GRPO crosses it around step 53. Whatever TPO is doing, it is not buying its lead with reduced exploration.

The lead comes from using more of each rollout group. With group_size=8, GRPO reads the group as a noise-reduction device. It standardizes within the group, computes a clean advantage for the one action that was actually sampled, and applies that advantage to the score function of that action. The other seven rollouts contributed to the standardization but went unused for direction. TPO treats the same eight rollouts differently. It reads them as candidate next-actions, builds a soft target distribution over them weighted by relative rewards, and pulls the policy toward that soft target. So GRPO produces one weighted gradient direction per group, and TPO produces K of them, one per candidate. At fixed rollout cost the K-direction construction uses more of what the rollouts measured, and the information-per-step difference shows up as faster decrease in test error.
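The one-direction versus K-direction contrast can be seen on a softmax bandit whose K candidates are distinct actions. The soft target q = softmax(rewards / temp) is an illustrative assumption, not necessarily TPO's exact construction, and the helper names are hypothetical. The point is the identity in the docstring: q - p is a q-weighted sum of K per-candidate directions, while the single-sample update uses only the sampled action's direction.

```python
import math


def softmax(xs, temp=1.0):
    """Numerically stable softmax with a temperature."""
    m = max(xs)
    exps = [math.exp((x - m) / temp) for x in xs]
    z = sum(exps)
    return [e / z for e in exps]


def pg_logit_step(probs, action, advantage):
    """Single-sample policy-gradient ascent direction on the logits:
    advantage * (onehot(action) - probs)."""
    return [advantage * ((1.0 if k == action else 0.0) - p) for k, p in enumerate(probs)]


def soft_target_logit_step(probs, rewards, temp=1.0):
    """Hypothetical soft-target step: move probs toward q = softmax(rewards / temp).

    q - probs equals sum_k q_k * (onehot(k) - probs), one direction per candidate.
    """
    q = softmax(rewards, temp)
    return [qk - p for qk, p in zip(q, probs)]
```

On a uniform four-action policy with rewards [1, 0, 0, 0], the single-sample step for action 0 with advantage 1 is [0.75, -0.25, -0.25, -0.25]; the soft-target step instead mixes all four candidate directions in proportion to q.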

Whether the gap survives at scale is the open question, and the toy suggests it could go either way. With a stronger base policy whose action distribution is already concentrated near the right answer, the sampled action carries most of the signal TPO is currently extracting from K candidates, and the gap should close. With richer reward signals (per-token rewards, multi-criterion judges, per-segment scoring), the per-candidate weights themselves carry more structure, and the gap should widen. Either outcome would be informative, because the toy identifies the mechanism whose transferability is being tested rather than just an effect to chase. The cleanest experiment to settle it is a same-task TPO-versus-GRPO comparison with a 1B-class base under fixed rollout cost, and that is where I would put the next budget.

Replay's useful band is set by sample age, not buffer capacity

The replay sweep nearly fooled me. The table reads simply: capacity 5 at delay 4 is fine (test error 0.49-0.50, no entropy collapse), capacity 32 at the same delay is a disaster (test error 1.00 for vanilla ReplayDG, full entropy collapse). "Small buffers fine, big buffers not" fits the numbers, and it is what I almost wrote down. The trouble with that takeaway is that it is not the mechanism, and the next replay-augmented method to come along will be tuned differently enough that the wrong takeaway transfers badly.

Replay is really a variance-versus-bias trade with two separate knobs. On the variance side, recycling samples reduces the variance of the gradient, and the reduction scales with the buffer's effective sample count. On the bias side, every recycled sample comes from a policy that has since moved, and the bias each sample contributes scales with how stale it has become. The two knobs are not independent. Age is itself a function of capacity, freshness decay, and the speed at which the policy is moving. Reporting only capacity, the way most replay comparisons do, is reporting a single-axis tuning curve for a two-axis trade-off.

The capacity-5-at-delay-4 regime turns out to be the boring case for a freshness-aware method. By construction every sample in the buffer is exactly 4 steps old, so there is no age spread at all. FreshDG is only marginally more stable than delayed DG in those rows because its age weights have nothing to discriminate between. The capacity-32 rows are the stress test that decouples capacity from age and lets the age axis swing. Mean sample age climbs to roughly 16-18 steps, well past the nominal delay, and the gradient ends up averaging over policies the current one no longer resembles. Without freshness decay, the averaged direction is wrong often enough to collapse entropy and drive error to chance. With --replay_age_decay 0.5, the effective sample age drops to around 7.4, less than half its undecayed value, the bias falls back into the range importance ratios can absorb, and final error recovers to roughly the delayed-DG level.
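A sketch of how a freshness decay pulls effective age down. The exponential form w = exp(-decay * (age - min_age)) is an assumption; the exact semantics of --replay_age_decay are not given here, so this does not reproduce the 7.4 figure above, only the direction of the effect.

```python
import math


def age_weights(ages, decay):
    """Hypothetical freshness weights: exp(-decay * (age - youngest age))."""
    youngest = min(ages)
    return [math.exp(-decay * (a - youngest)) for a in ages]


def effective_age(ages, weights):
    """Weight-averaged sample age, the quantity worth reporting next to capacity."""
    return sum(w * a for w, a in zip(weights, ages)) / sum(weights)
```

With ages 4 through 35, no decay gives an effective age of 19.5; a decay of 0.5 per step under this assumed form pulls it down to about 5.5.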

The implication for any future replay comparison is that buffer capacity by itself is the wrong variable to report. The variable that decides whether replay is buying or costing is the effective age distribution the buffer produces in combination with the freshness-decay schedule. That distribution is what one should report alongside the final-error number. Comparisons that report only capacity are doing the right experiment on the wrong axis.

Token-level credit redistributes signal but cannot create it

The masked-reversal experiment was instrumented to ask whether token-level credit improves learning at scored positions without damaging unscored ones. The answer came back half encouraging and half cautionary, and only the cautionary half changes my future choices. On the encouraging side, TPOToken drives the scored suffix to zero error within 300 steps. That is about the cleanest possible per-position routing result anyone could hope for: the credit lands where the reward function placed signal, and the method does not pretend to learn at the positions the reward function declined to score. On the cautionary side, unscored positions sit near chance throughout, and that residual error is not a fixable failure mode. It is a faithful reflection of what the reward function asked the optimizer to do.

The token-level policy gradient itself shows why. Any per-token update has the form grad_theta log pi(a_t | s_t) * A_t. At any position where the reward function places no signal, A_t is zero, and zero advantage multiplies the score function back to zero no matter how clever the upstream credit-assignment scheme was. No token-level method can manufacture gradient at positions the reward function declined to score. Token-level credit redistributes whatever reward signal already exists across positions, concentrating it where the reward function placed it. It does not source new signal. The distinction between routing reward (credit assignment) and sourcing reward (reward shaping) is worth keeping in mind when picking which tool to reach for.
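The argument can be checked on a softmax policy, where the logit gradient of log pi(a) is onehot(a) - probs. This sketch (the helper name is hypothetical) multiplies that by each position's advantage; any position with A_t = 0 contributes an all-zero row, whatever the probabilities are.

```python
def token_logit_grads(probs_per_pos, actions, advantages):
    """Per-position score-function term A_t * grad_logits log pi(a_t | s_t).

    For a softmax policy the logit gradient of log pi(a) is onehot(a) - probs.
    """
    grads = []
    for probs, a, adv in zip(probs_per_pos, actions, advantages):
        grads.append([adv * ((1.0 if k == a else 0.0) - p) for k, p in enumerate(probs)])
    return grads
```

No choice of probabilities or actions at an unscored position can make its row nonzero; only the advantage can, and that is set by the reward function.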

Here is the trap an outside reader is most likely to walk into. Look at the masked-reversal scored-suffix accuracy in isolation and it is tempting to conclude that token-level credit "works", then reach for it on a different task where the reward signal is sparse rather than partially located. That conclusion picks the wrong tool. A sparse-reward task needs reward shaping, not finer credit assignment, because the bottleneck there is whether the reward function carries usable signal at all, not how the signal gets allocated across positions. The unscored column of the masked-reversal table exposes the distinction cleanly, and I would put it first in any future comparison of token-level methods.

Reward-noise heuristics are operating points, not free improvements

After the reward-noise sweep I stopped using the phrase "noise-robust" and started saying something more honest. Every method in the sweep downweights or drops updates when some uncertainty proxy fires. They are all making the same structural choice: an operating point on the trade-off between filtering out true reward signal as collateral and accepting false reward signal as truth. At a fixed noise rate (0.2 in this sweep), no proxy that operates without ground truth can avoid paying on one side of that trade.

FilteredDG was the clearest demonstration of how much the proxy can confound the comparison. In ungrouped runs its uncertainty signal is computed at the batch level, so the threshold (0.5 by default) sees the same value across the whole batch and ends up acting as all-or-nothing gating rather than per-sample filtering. At threshold 0.5 it accepts essentially every batch and reproduces the no-filter baseline. Drop the threshold to 0.2 and it rejects essentially every batch and learns nothing. Test error climbs from 0.38 to 0.65 with much wider variance, which is what one would expect from an optimizer that has almost no gradient signal. The proxy's granularity, not the threshold tuning, made the method behave as a single-bit gate. The failure was in the proxy, not in filtering as such.
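A toy reproduction of the granularity problem, with made-up uncertainty values and a batch-mean aggregation that is assumed rather than taken from FilteredDG's code. Once every sample shares one uncertainty value, the threshold can only keep the whole batch or none of it.

```python
def keep_mask(uncertainty, threshold):
    """Keep samples whose uncertainty is strictly below the threshold."""
    return [u < threshold for u in uncertainty]


def batch_level_uncertainty(per_sample):
    """One shared value for the whole batch (hypothetical: the batch mean)."""
    m = sum(per_sample) / len(per_sample)
    return [m] * len(per_sample)
```

With per-sample values [0.1, 0.2, 0.4, 0.8], a threshold of 0.3 keeps the first two samples. After batch-mean aggregation (0.375) the same threshold keeps none, and 0.5 keeps all four.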

ASPO and R2VPO are sturdier under false-positive rare-token noise, with final test error within a hair of DG's, but they pay for it with very low final entropy (0.15 and 0.18 against DG's 0.96). The conservative gates UncertaintyDG and RewardVarianceDG preserve entropy near DG's level and pay for it in slower learning instead. Reading the entropy column alongside the test-error column reveals the trade-off the test-error column alone hides. Every "robustness gain" in this sweep is being paid for in either exploration or speed. Any comparison that reports the gain without the cost has reported half the operating point.

So "noise robustness" turns out to be the wrong unit on this task. The right unit is calibration. At a given noise rate, how much true signal are you willing to sacrifice to keep out how much false signal? Different methods sit at different points on that curve, and a fair comparison plots them on the curve rather than picking a winner in the abstract.

GRPO's entropy collapse lives in the standardization, upstream of the clip

Of all the diagnostics in the sandbox, the one that has shifted my thinking the most is the entropy sweep on GRPO. I started with the framing the PPO heritage encourages: GRPO collapses entropy because the clipped surrogate lets the policy ratio swing too far per step, so the fix is to tighten the clip. The data does not support that framing. DGEntropyGuard, which acts at the sampled-action probability level downstream of the same standardization, barely moves the entropy needle. The variants in the literature that have actually slowed GRPO's entropy decay (DrGRPO, DAPO) get their traction from attacking the standardization rather than tightening the clip.

The mechanism lives in the advantage formula. GRPO standardizes rewards within each rollout group, computing A_i = (R_i - mu) / sigma. Standardization is a reasonable noise-reduction step when sigma is on the same order as the true reward spread. On an easy task it is not. Most rollouts in a group agree on the right answer, sigma shrinks toward zero, and the standardized advantage scales as 1/sigma regardless of how small the raw reward gap R_i - mu actually was. A reward gap of 0.01 between two near-identical rollouts produces a standardized advantage of order 1 when sigma is also around 0.01. The optimizer cannot tell from A_i alone that the underlying signal was tiny. The amplification is structural rather than pathological. Standardization is doing exactly what it was designed to do; the literature on the GRPO entropy problem has, I think, partly missed how it interacts with PPO-style clipping.
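The scale invariance is easy to check numerically. The sketch below (hypothetical helper names; the eps added to sigma is a common stabilizer, assumed here) standardizes two groups with the same pattern but a 100x different reward gap, and also computes the group-centered advantage without the std division, in the style of DrGRPO.

```python
from statistics import mean, pstdev


def grpo_advantages(rewards, eps=1e-6):
    """Within-group standardization: (R_i - mu) / (sigma + eps)."""
    mu, sigma = mean(rewards), pstdev(rewards)
    return [(r - mu) / (sigma + eps) for r in rewards]


def centered_advantages(rewards):
    """Group-centering only: R_i - mu, so the raw reward scale survives."""
    mu = mean(rewards)
    return [r - mu for r in rewards]
```

Rewards [1, 1, 1, 0.99] and [1, 1, 1, 0] both standardize to roughly [0.577, 0.577, 0.577, -1.732]. Only the centered advantages, [0.0025, ..., -0.0075] against [0.25, ..., -0.75], still carry the hundredfold difference in raw signal.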

The PPO clip bounds the policy ratio at 1 +/- eps per step. It bounds how far one update can move the policy. It does not bound the advantage that drives the update. Up to the clip boundary, the clipped surrogate's gradient is proportional to A_i, and A_i is the quantity the standardization has just amplified. Every update therefore saturates the trust region in the direction of whichever rollout happened to win the group, and the saturation continues for as many steps as it takes for entropy to decay below useful levels. The sweep shows entropy crossing 0.1 by step 53 +/- 12, which is a tighter horizon than the loss curve makes it look. Once entropy is gone, the policy stops exploring and the test-error curve flattens, regardless of what the clip is doing.
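A scalar sketch of the PPO clipped surrogate and its derivative with respect to the ratio shows what the clip does and does not bound. Inside the band the derivative is exactly A, so an amplified advantage pushes harder right up to the boundary; past the boundary on the favorable side the derivative is zero. The function names are hypothetical and eps = 0.2 is just a common default.

```python
def clipped_surrogate(ratio, adv, eps=0.2):
    """PPO objective for one sample: min(r * A, clip(r, 1 - eps, 1 + eps) * A)."""
    clipped_ratio = max(min(ratio, 1 + eps), 1 - eps)
    return min(ratio * adv, clipped_ratio * adv)


def surrogate_grad_wrt_ratio(ratio, adv, eps=0.2):
    """d/d(ratio) of the clipped surrogate (ignoring the kinks at the bounds)."""
    unclipped = ratio * adv
    clipped = max(min(ratio, 1 + eps), 1 - eps) * adv
    if unclipped <= clipped:
        return adv  # min selects the unclipped term, whose slope is A
    return 0.0      # min selects the clipped term, constant in ratio
```

At ratio 1 the derivative is 3.0 for A = 3.0 and 0.3 for A = 0.3: the clip leaves that tenfold difference untouched until the ratio reaches 1.2.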

Where the fix lives matters because it changes which interventions are worth running. The clip sits downstream of the standardization, so tightening it constrains the magnitude of each step but neither the direction the step is being pushed in nor the amplification that put it there. DGEntropyGuard sits downstream as well, at the sampled-action probability level, which is consistent with why it barely helps. A fix has to act upstream of the clip, or at least not by tightening it. DrGRPO acts upstream by dropping the reward-std normalization in favor of group-centering alone, reducing the sigma-driven amplification at its source. DAPO comes at it from the clip side, but in the opposite direction to tightening: it decouples the lower and upper clipping bounds and raises the upper one, so positive and negative updates use different bounds, attacking the asymmetry rather than the standardization itself. Neither tightens the clip. Both, on the toy mechanism, should show their advantage by slowing the entropy decay measured here even before any test-error gap opens up.

Predictions worth testing at scale

The toy runs make two predictions I would want to put to larger experiments.

The first is whether TPO's information-per-step advantage holds up at scale. The gap exists in the toy because grouped rollouts contain more candidate-target information than GRPO's normalization step extracts. With a stronger base policy whose action distribution already concentrates near the right answer, the one sampled action carries most of the signal TPO is currently squeezing out of K candidates, and the gap should close. With richer reward signals (per-token rewards, multi-criterion judges, per-segment scoring), TPO's per-candidate weights carry more structure, and the gap should widen. The cleanest experiment is a same-task TPO-versus-GRPO comparison with a 1B-class base under fixed rollout cost. The question reduces to whether the K-candidate extraction is still worth its compute at scale, and either outcome would be informative because the toy identifies the mechanism being tested, not just an effect.

The second is about where the entropy-collapse fix for the GRPO family lives. The toy mechanism says the standardization is the load-bearing part of the entropy decay, with the clip downstream of it. The direct test is an ablation. Hold the PPO clip fixed at its standard 1 +/- eps band, drop only the reward-std normalization (the core change in the DrGRPO arm), and measure whether the entropy decay slows. If the toy mechanism transfers, the slowdown should hold up across base-model sizes. As a control, tightening the clip without touching the normalization should not slow the entropy decay meaningfully, because the clip is not where the broken signal is being generated. If both ablations slow entropy comparably at scale, the toy diagnosis was wrong, and the clip is doing more of the work than the small-scale mechanism suggested. That outcome would itself be useful, because it would localize the disagreement between toy and scale instead of leaving it diffuse.

Evidence plots

Mean and standard error across three seeds. The plots show full trajectories rather than only final bars, so stalls and collapses remain visible.

  • Influence: clean token-reversal learning curves
  • Reward noise: false-positive reward-noise learning curves
  • Replay: replay freshness trajectories under delay
  • Partial credit: masked-reversal scored and unscored trajectories
  • Dense correction: reward-chain dense correction trajectories
  • Entropy: entropy and accuracy trajectories

The full numbers live in rl_sandbox/analysis/results_matrix.md. The per-method scope notes, including what each scoped version of a published method does and does not include, are in rl_sandbox/analysis/implementation_scope.md.

Scope

The sandboxes run on a single machine, with CUDA when a GPU is available and CPU otherwise. The methods themselves are scoped to their local batch and task contracts. I keep the part of each published method that decides the update, meaning the update rule, the credit-assignment scheme, and the normalization choices, and drop the distributed-system scaffolding that the toy setup does not need. Where the scoping would change a method's meaning rather than only its scale, for instance running advantage normalization in a regime where the original paper assumes a critic, the trainer rejects the config rather than running a variant that would be misleading to compare. VPO follows the same rule: the local implementation is faithful to the set-reward estimator and explicit about not reproducing the paper's large LM domains. Pedagogical RL follows it too: the local implementation keeps privileged teacher GRPO, product-form spike-aware rewards, and surprisal-gated assimilation, without claiming to reproduce the paper's large-model domains. Distributed rollout, learned critics, and production reward pipelines are out of scope by design. They cost the observability the sandbox is built around, and observability is the part that the small scale buys.

Verification

python -m compileall -q rl_sandbox
python -m compileall -q vpo_sandbox
python -m compileall -q pedagogy_sandbox
python -m rl_sandbox.train --task token_reversal --method DG \
  --batch_size 16 --num_steps 2 --eval_every 1 --num_seeds 1 \
  --output /tmp/rl_sandbox_smoke.csv --verbose false
python -m vpo_sandbox.train --num_steps 2 --eval_every 1 --num_seeds 1 \
  --batch_size 8 --group_size 4 --inner_epochs 1 \
  --output /tmp/vpo_sandbox_smoke.csv --verbose false
python -m pedagogy_sandbox.train --num_steps 2 --eval_every 1 --num_seeds 1 \
  --batch_size 8 --group_size 4 \
  --output /tmp/pedagogy_sandbox_smoke.csv --verbose false

About

A PyTorch sandbox for studying which samples and tokens deserve gradient weight in post-training RL.
