Skip to content

Migrate VPD to JAX: train + analyze in one framework; retire torch to oracle#560

Open
ocg-goodfire wants to merge 454 commits into
mainfrom
feature/jax
Open

Migrate VPD to JAX: train + analyze in one framework; retire torch to oracle#560
ocg-goodfire wants to merge 454 commits into
mainfrom
feature/jax

Conversation

@ocg-goodfire

@ocg-goodfire ocg-goodfire commented Jun 11, 2026

Copy link
Copy Markdown
Collaborator

Top-line

Migrates VPD fully to JAX — train and analyze in one framework — and retires the torch stack to a git-tagged oracle. Squash-merges to main once, at the end. Net: ~430 commits, +24k / −50k LOC (mostly deletion).

Key decisions

  1. One framework: JAX. Training and all analysis run in JAX. Torch is the battle-tested oracle, preserved at tags torch-oracle / torch-oracle-npool; JAX conforms to it (SPEC.md is the normative contract, numeric seams default to matching torch, goldens prove it).
  2. Consumers read JAX runs natively (orbax + DecomposedModel via open_jax_run) — harvest, clustering, autointerp, slow/offline eval, app. The JAX→torch export bridge is dead.
  3. App is a read-only viewer. Attribution graphs, circuit-opt / editing, PGD intervention are dropped (recoverable from git). App backend imports zero torch.
  4. Generic model interface. DecomposedModel — ordered sites + pure fns clean_output / site_inputs / masked_output / weight_deltas / masked_site_outputs — generic over input/output/recon-loss with [B,T,d] as the fixed waist.
  5. Recon unified as plan × mask-source strategy (make_plan + chunking helpers); loss is KL on final logits. Hidden-acts recon is a separate eval diagnostic over the masked_site_outputs seam (amends SPEC S31), not a recon-grid variant.
  6. Strict bar: no # type: ignore / Any / cast without sign-off; make check-jax gated in pre-commit; fail-fast, types-first.

Status — functionally complete; green

JAX trainer; all consumer ports; read-only app; dropped-feature deletion; DecomposedLM → DecomposedModel rename; hidden-acts eval port; llama8b loader; type-debt → 0 + pre-commit gate; a code-review + fix pass (#854/#855/#856) and a first dead-trainer deletion (#857). Suites green: ~415 lab + ~166 jax at 1 and 4 devices; make type / check-jax clean; torch↔JAX per-term equivalence + stacked-parity trajectory goldens pass bit-unmodified; validated end-to-end on SimpleMLP pile run p-761bc061.

Remaining before main-merge (each gated — see commit history / memory)

  • Retire torch offline-eval (offline_eval.py / pd-offline-eval / jsp-export; rewire _submit_offline_evaljsp-slow-eval). Gated on parity-validating JAX slow-eval vs torch on a real llama8b run — no current-format llama8b run is loadable, so this needs a fresh run.
  • Bulk param_decomp/ core deletion (metrics/ tree, train_step, …). Gated on the above (the live offline-eval path still imports them). Bridge/capstone surface stays.
  • Harvest accumulator → numpy (last torch in harvest). Gated on the async_eval (in-loop autointerp) decision.
  • Capstone: torch→jax run adapter so old torch runs load — the deliberate final step.
  • Scope call: TMS / ResidMLP / vendored / pretrain still torch (separate domains; likely not this PR).
  • Surfaced: PGD scope-vocabulary convergence (deferred — stored-run compat).

Reading guide

param_decomp_jax/jax_single_pool/SPEC.mdlm.pyrecon.pytrain.py. TRANSITION.md = the settled plan; LOSS_PARITY_DESIGN.md = recon unification. VPD paper rides separately as #562.

🤖 Generated with Claude Code

ocg-goodfire and others added 30 commits June 9, 2026 11:50
…-pool

jax stage12/13 ran the per-site layerwise stoch (6 forwards) while the torch 2-pool
recon_plan uses subset n_samples=1 (1 joint forward) — confounding the framework
comparison. Switch jax to the joint estimator so the 2x2's stoch load is apples-to-apples.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
…ime HF dep)

prestage_tokenized.py tokenizes a portion of a HF text dataset (fineweb sample/350BT)
into local int32 seq-2048 parquet shards so training reads a local Arrow dataset instead
of streaming/tokenizing from HF at run time. Removes the N-rank HF dataset thunderherd
that stalls 80-GPU startup (the build_two_world straggler) + per-rank tokenization cost.

Scavenge-safe: atomic shard writes (.tmp + rename), idempotent skip-existing resume, and
SLURM-array fan-out (task t -> file indices t::num_tasks, shards by global index, disjoint).
~366 files ~= 256B tokens ~= 512GB (int32, ~2x parquet compression). Loader: dataset_name=
parquet, data_files=<dir>/*.parquet, column_name=input_ids, is_tokenized=true, streaming=false.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
PD_COMPILE_SINGLE_POOL=1 compiles the core ComponentModel's masked forward, to test
whether the single-pool path can get the pooled paths' compile win (it has zero compile
wiring today -> torch 1-pool runs eager at 610ms). Env-gated probe; wire into
RuntimeConfig if it proves out.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
probe_reshard.py  - cross-sub-mesh jax.reshard latency (sub-mesh viability)
probe_zero.py     - full-mesh ZeRO sharding: resident-memory freed + gather-on-use forward
probe_hetero.py   - axis_index_groups even vs UNEVEN subgroups (single-mesh heterogeneous DDP)
cw_probe.sbatch   - single-process/8-GPU probe launcher (JAX_VENV override)
cw_jax.sbatch     - honor JAX_VENV (cuda12 vs cuda13)

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
…replicated)

Cartesian (group, dp) mesh: group 0..Nc-1 = chunk groups running each their
own static site-partition's stochastic masked forwards concurrently; group Nc =
main pool (importance-minimality + PPGD adversary recon + PGD source update).
CI computed only on the main pool, broadcast to all groups via psum trick over
the (Nc+1)-way group axis. V/U replicated (compute-parallel only).

8 GPU single-node verified for n_chunks=1 (group=2 x dp=4) and n_chunks=3
(group=4 x dp=2); finite, decreasing losses. 16 GPU multinode blocked by a
pre-existing CUDA_ILLEGAL_ADDRESS that hits the stage13 reference identically.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Consolidate the surviving learnings from the abandoned 2/3-pool work into a
clean single-process-group LM trainer that scales via FSDP2 (memory) +
torch.compile (speed) instead of hand-written pool transport. Pools and the
core DDP Trainer are left behaviourally untouched.

- param_decomp/train_step.py: extract the model-agnostic step pieces
  (run_loss_step, run_eval_pass, scheduled_lrs, EvalLoop, sigterm, metric-ctx)
  out of optimize.py; core Trainer.run now composes them (no behaviour change).
- param_decomp_lab/fsdp/: new subsystem
  - component_adapter.py: present the vendored LMComponentModel through the core
    ComponentModel surface metrics expect (batch is idx; cache_type mapping).
  - config.py: FsdpRuntimeConfig (compile_model/compile_ci_fn/checkpoint_blocks/
    shard_frozen_target).
  - checkpoint.py + consolidate.py: on-loop sharded DCP save of trainable-only
    state; off-loop async consolidation -> model_<step>.pth / training_<step>.pth.
  - trainer.py: FsdpLMTrainer (FSDP2 fully_shard + compile + residual-start +
    activation checkpointing); resume-in-place from DCP shards.
  - sdpa_strict.py / grad_clip.py: cribbed from pd-nano-jax.
- experiments/lm/fsdp_run.py + fsdp_async_eval.py: composition root, pd-lm-fsdp
  entry point, SLURM requeue + resume-in-place, async consolidate+slow-eval.

torch 2.11 (FSDP2 + DCP). make check green; fsdp unit tests pass. Runtime
distributed smokes still pending (de-risk phase).

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Single-pool Parameter Decomposition training step in JAX, the research
counterpart to the torch FSDP single-pool path. Reuses nano_pd_jax primitives
(leaky-hard-sigmoid custom_vjp) and mirrors the torch PPGD semantics from
param_decomp/metrics/persistent_pgd_state.py.

- forward.py: site-local masked decomposed forward (layerwise recon), with the
  weight-delta channel gating the residual per position.
- losses.py: the four VPD losses (faith, imp-min, stochastic recon, ppgd recon)
  over a stacked-site Decomposition; mask = ci + (1-ci)*source.
- scopes.py: PPGD source scopes (single/broadcast/repeat/per-batch-per-position).
- pgd.py: functional persistent adversary -- sources + Adam moments carried in
  state; n_warmup lax.scan ascent + a final post-update ascent.
- step.py: the whole step as one jax.jit fn -- fused multi-argnums grad over
  (V/U, CI), two functional Adam optimizers, frozen-target grad zeroed.
- experiments/toy_stacked_sites.py: CPU smoke (correct adversary signature).
- tests: 11 pure-fn tests incl. bit-exact PGD-scan-vs-python-loop. All green.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
… invariance

The SPMD-collapse distribution layer (the FSDP analog): batch sharded P('dp'),
params + PGD sources replicated, jax.jit inserts every collective.

- sharding.py: dp_mesh / replicate / shard_leading / shard_batch + SLURM bring-up.
  shard_batch uses make_array_from_process_local_data so it's correct for BOTH
  topologies (single-process-many-devices and multi-process-1-device-each).
- experiments/distributed_stacked_sites.py: the single-pool step sharded over a
  mesh; correctness = fixed-global-batch trajectory must be GPU-count-invariant.
- tests/test_sharding.py: shard_batch preserves the full global array + the
  non-divisible-batch assertion. Pass at 1 and 4 simulated devices.

RESULT: bit-identical loss trajectories at 1 vs 4 simulated CPU devices -- the
full VPD+PGD step is GPU-count-invariant under GSPMD with zero manual
collectives, persistent adversary included. NOTES records the diagnosis.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
- experiments/transformer_qkv.py: decompose the square q/k/v sites of a real
  nano_pd_jax.TinyTransformer (real pre-weight acts) -- proves the step is
  model-agnostic. faith 0.043->0.005, stoch 0.30->0.064 on real attn projections.
- checkpoint.py + tests/test_checkpoint.py: flat-pytree save/resume of TrainState;
  resume continues the trajectory bit-identically (the persistent adversary
  survives). No torch-style state_dict plumbing -- the whole adversary state is
  in the pytree.
- README.md + CLAUDE.md: file map, run instructions, design + invariants.
- NOTES: residual-start analysis (it's implicit in the layerwise-recon factoring
  -- no masked re-forward through the frozen prefix), checkpoint/real-model TODOs
  resolved.

14 tests green at 1 and 4 simulated devices; basedpyright + ruff clean.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Replace the `cast(ComponentModel, cast(object, x))` double-cast idiom with a
structural Protocol the concrete ComponentModel, the FSDP adapter, and the
vendored LMComponentModel all satisfy.

- param_decomp/component_model.py: add ComponentModelProtocol — the minimal
  surface metrics + train_step consume (__call__ with 3-way cache_type overloads
  so the 4-way concrete model and 3-way adapter both satisfy it via param
  contravariance; forward_with_output_acts; calc_causal_importances;
  calc_weight_deltas; module_to_c; target_module_paths; components).
- Retype MetricContext.model, Metric.model/Metric.bind, instantiate_metrics, the
  train_step helpers, tie_component_weights and run_faithfulness_warmup to the
  Protocol; drop the double-casts in fsdp/trainer.py and three_pool/*.
- Eval metrics that genuinely need the concrete ComponentModel (attn-patterns,
  autointerp) narrow via `assert isinstance(model, ComponentModel)` in an
  overridden bind — not a cast.
- configs.py: make RuntimeConfig.dp's description strategy-agnostic (world size,
  replicated under DDP / sharded under FSDP); NOT renamed (preserves checkpoint
  experiment_config.yaml compat). FsdpRuntimeConfig documents the FSDP meaning.

make check green; 186 tests pass.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
GPT2 block-6 MLP (c_fc/c_proj) on vendored GPT2 for the fast correctness smoke;
Llama-8B L18 MLP (gate/up/down_proj) adapted from the proven 2-pool config to the
FSDP schema (losses struct -> loss_metrics list, topology -> FsdpRuntimeConfig).
All 4 losses; checkpoint_blocks off (PPGD); compile + shard_frozen_target on.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Add mem/peak_gb_per_rank to the FSDP train-log (torch.cuda.max_memory_allocated)
so the shard_frozen_target memory A/B is directly observable. Add two short Llama-8B
L18 memory-A/B configs differing only in shard_frozen_target (true vs false).

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
…gathers

residual_at() runs the embedding outside the root model.forward, so the root
fully_shard pre-hook never fires; under shard_frozen_target the embedding weight is a
sharded DTensor and meets plain-tensor token ids -> 'aten.embedding got mixed
torch.Tensor and DTensor'. Give each embedding its own fully_shard group so it
all-gathers on every call (incl. residual_at). GPT-2 co-shards the tied wte/lm_head;
Llama embed_tokens shards standalone.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
…arded

The prior embed-group fix was wrong twice over (validated on a 2-GPU repro): GPT2's
tied wte/lm_head can't co-shard as one fully_shard group ('requires a single root
module'), and a standalone-sharded embedding trips FSDP2's single-root lazy-init when
residual_at enters it before the parent forward. Root-sharding the embedding (the
original code) instead raises the DTensor-mixed error because residual_at bypasses the
root hook.

Fix: shard only the transformer blocks (which hold all trainable V/U + in-block frozen
target weights) and the ci_fn; drop the fully_shard over lm.model / lm so the frozen
embedding / final-norm / lm_head stay replicated. residual_at then sees a plain-tensor
embedding weight. The ~1GB replicated embedding is negligible vs the 32 sharded blocks.
Repro: embed_replicated mode passes; blocks_only/embed_group/no_root all fail.

Also fix the GPT2 smoke target pattern: the vendored GPT2 MLP second linear is
down_proj, not c_proj.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
…tion map)

The apollo openwebtext dataset is tokenized at seq 1024; max_seq_len 512 triggered
data.py's per-example truncation map over all 8.8M examples (~37min). Use native 1024
(map skipped) and bump CI attn max_len to match.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
The components' V/U are sharded by their transformer block's fully_shard, so accessed
outside a forward (calc_weight_deltas, called by faithfulness warmup + every train step)
they are DTensors; the frozen target_weight is a plain replicated buffer (GPT2) or a
sharded param (Llama under shard_frozen_target). The vendored calc_weight_deltas subtracts
them directly -> 'aten.sub got mixed torch.Tensor and DTensor'. Override it in the FSDP
adapter to gather each operand via DTensor.full_tensor() before subtracting. The full
per-site weight is materialised anyway (faithfulness needs the whole delta), so no
asymptotic memory cost. Validated on a 2-GPU GPT2 repro: calc_weight_deltas + faithfulness
loss + backward all succeed.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
residual-start runs the embedding on the raw batch via use_cached_residual, opened
before run_loss_step does its own move_batch_to_device; the loader returns CPU tensors,
so the embedding index_select mixed cpu ids with the cuda weight. Move the batch to
device in the trainer loop before use_cached_residual (run_loss_step's move is then a
cheap no-op). Mirrors the 3-pool path, which moves batch_local before use_cached_residual.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
…exist

PPGD's before_backward runs torch.autograd.grad(retain_graph=True) through the compiled
region; AOTAutograd's donated-buffer optimization asserts retain_graph=False on every
backward of a compiled graph, raising 'compiled with non-empty donated buffers which
requires create_graph=False and retain_graph=False'. Disable
torch._functorch.config.donated_buffer when a PPGD loss is configured and compile is on.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
…tter mismatch

Diagnostic: the compiled GPT2 backward hit 'size of tensor a (512) must match b (2048)'
in FSDP foreach_reduce. This config disables compile to isolate whether the reduce-scatter
mismatch is a compile+FSDP2 grad interaction or a pure sharding issue.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
The FSDP2 main-step blocker was a gradient-placement mismatch in calc_weight_deltas.
The prior `full_tensor()` (= redistribute(Replicate).to_local()) detached the DTensor
placement, so the faithfulness path emitted Shard(1)/Replicate grads for the components'
V/U while the recon forward emits Shard(0); accumulating both into one param made FSDP2's
gradient reduce-scatter fail (C vs C/world dim mismatch).

Fix: redistribute V and U to Replicate (KEEPING them DTensors — no .to_local()) before the
einsum, and subtract a Replicate target. redistribute is differentiable and its backward
returns the grad in the source Shard(0) placement, so the faithfulness grads now match the
recon path's. Verified on a 2-GPU repro: V.grad/U.grad both Shard(0), bit-matching the
canonical recon placement (vs Shard(1)/Replicate for the old path).

Also exclude **/_scratch/** from pyright + ruff (throwaway multi-GPU repro dir).

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
ocg-goodfire and others added 4 commits June 16, 2026 19:44
… wandb short-name table) (#849)

test_validate_wrapper_rejects_unexpected_key: the validator's assertion
message was reworded ("keys must include ...") but the test's pytest.raises
regex still matched the old "keys must be". Stale test; update the regex.
The validation itself is correct.

test_config_short_name_table_matches_metric_registry: METRIC_SHORT_NAMES
carried a stale "PersistentPGDReconSubsetLoss" entry whose metric class was
removed in #824 (PPGD+subset removal). Real drift in the torch-free table;
drop the dead entry so it matches the registry-derived names.

Co-authored-by: Claude Opus 4.8 <noreply@anthropic.com>
…h) (#850)

Clustering already reads JAX runs natively via run_worker_jax.py. Shed torch
from the whole subsystem:

- Delete torch-worker path superseded by run_worker_jax + run_merge:
  run_harvest.py, run_clustering.py, run_pipeline.py, calc_distances.py,
  plotting/ (only used by those), clustering_run_config.py, and the orphaned
  configs/ for the deleted ensemble pipeline.
- Convert the membership accumulator + merge stack to numpy: MembershipBuilder,
  flatten_lm_activations, GroupMerge/BatchedGroupMerge, compute_costs, merge,
  sample_membership coactivation, merge_pair_samplers (numpy + stdlib random),
  merge_history, matching_dist, the dense test-oracle path, and the type aliases.
- Drop the dense `preview` mechanism (only the deleted run_clustering consumed it).
- run_worker_jax feeds numpy directly and imports zero torch.
- Fail-fast fix surfaced by numpy strictness: MergeHistory stored group/pair
  indices as int16, which overflows above 32767 components (torch silently
  truncated). Now int32.
- Remove dead pd-clustering / pd-cluster-harvest console scripts (keep
  pd-cluster-merge); harvest is `python -m ...run_worker_jax`.

Acceptance: grep for torch in param_decomp_lab/clustering is empty. Clustering
tests (58) green, make type clean, pre-commit passes. Validated e2e on JAX run
p-761bc061: harvest -> membership snapshot -> merge -> history.zip.

Co-authored-by: Claude Opus 4.8 <noreply@anthropic.com>
… builder) (#851)

Replace the NotImplementedError TargetConfig branch in load_run.build_target
with the llama8b build mirroring run.py::main: llama_decomposed_lm over
llama_site_specs, replicate_target(load_target_from_hf(...)) for the frozen
suffix, device_put(load_prefix_from_hf(...), P()) for the prefix, prefix_residual
as the residual fn, and vocab_size from llama31_8b_config(). Returns the same
(lm, target, prefix, prefix_residual_fn, vocab_size) tuple as the SimpleMLP
branch, so open_jax_run / LoadedJaxRun.forward stay target-agnostic.

Promote the builder to public build_target and route slow_eval through it,
deleting its private SimpleMLP-only _build_simple_mlp. One loader now covers
both targets; the SimpleMLP path's restore/forward behavior is unchanged.

Unblocks jsp-slow-eval (and the app) on llama8b runs.

Co-authored-by: Claude Opus 4.8 <noreply@anthropic.com>
…amends SPEC S31) (#853)

Port CIHiddenActsReconLoss + StochasticHiddenActsReconLoss to JAX as standalone
offline slow-eval metrics — explicitly NOT recon-grid terms. The parameterized
recon loss stays KL-on-final-logits only (SPEC §2.3-2.5); these are eval diagnostics.

Deliberate, Oli-approved amendment of SPEC S31, which had marked the hidden-acts seam
"keep-on-bridge / refused". They now have a native JAX eval path.

- Add a fifth DecomposedModel fn `masked_site_outputs(...) -> dict[site, (B,T,d_out)]`,
  factored out of `masked_output` via a shared `_run_masked_suffix` + a per-site `collect`
  dict in `_masked_site_out` (per-site output is an intermediate of the masked forward —
  no logic duplicated). Implemented for both targets (llama8b, llama_simple_mlp).
- Clean (target) per-site output = frozen `x @ W`, obtained from the same seam by routing
  FALSE everywhere (`_site_out`'s frozen branch) — no separate frozen-W accessor.
- New `hidden_acts_eval.py`: per-site MSE(masked, clean) reduction="sum", host-accumulated
  as (Sum sum_mse, Sum n_elements) (token-weighted, exact under micro-batching), divided
  once at compute. Log keys mirror torch: "<ClassName>/<site>" + combined "<ClassName>".
  Masked+clean in COMPUTE_DT (bf16), MSE reduction fp32. CIHiddenActsReconLoss =
  deterministic lower_leaky CI, no delta, one forward; StochasticHiddenActsReconLoss =
  n_mask_samples stochastic CI draws + deltas. Stochastic draws are not seed-aligned to
  torch (exact bitwise parity impossible there, expected).
- Wire both into jsp-slow-eval via SlowEvalOutput(figures, hidden_acts); scalars logged
  under slow_eval/loss/<key> + written to hidden_acts_recon_step<N>.json.
- Per-site unit test on SimpleMLP; SPEC S31 amended; CLAUDE.md/README updated.

Validated: full JAX suite green at 1 and 4 sim devices (165 each); check-jax clean;
equivalence/stacked-parity/simple_mlp_equivalence goldens untouched. Ran on p-761bc061
(24 sites): finite, non-negative per-site, combined within [min,max].

Co-authored-by: Claude Opus 4.8 <noreply@anthropic.com>
@ocg-goodfire ocg-goodfire changed the title JAX single-pool VPD trainer + shared config package Migrate VPD to JAX: train + analyze in one framework; retire torch to oracle Jun 16, 2026
ocg-goodfire and others added 25 commits June 16, 2026 22:38
…ew pass) (#854)

Latent bug fix:
- clustering/sample_membership.py: coactivation matmul cast int32 -> int64
  (X.T @ X silently overflows above ~2.1B samples; same class as the int16
  MergeHistory bug). Downcast to float32 right after is unchanged.

Type hygiene:
- app/backend/state.py: DatasetSearchState now holds typed
  DatasetSearchResult / DatasetSearchMetadata (moved to schemas.py) instead of
  list[dict[str,Any]] / dict[str,Any]; router builds typed objects directly.
- clustering/merge_history.py: meta: dict[str,Any]|None -> typed
  MergeHistoryMeta(origin_path); removes the defensive isinstance(.,Path) scan.

Dead code / honesty:
- app/backend/dependencies.py: drop no-op try/except around StateManager.get().
- harvest/scripts/run_worker_jax.py: one host copy in _to_torch (np.array);
  single-source --batch_size / --activation_threshold defaults from the configs;
  activation_threshold read from method_config (no threaded duplicate).
- harvest/reservoir.py: delete stale TODO on a load-bearing pad mask.
- clustering/merge_history.py: delete dead TODOs + commented-out assert block.
- clustering/merge.py: warnings.warn for expected terminus -> logger.info.

Perf nit:
- app/backend/inference.py: gather next-token probs in JAX before host transfer
  (avoid materializing the full (T, vocab) output_probs to read ~T floats).

Co-authored-by: Claude Opus 4.8 <noreply@anthropic.com>
Behavior-preserving cleanups from a review pass over the JAX trainer core.
Equivalence + stacked-parity goldens untouched and green.

- load_run.py build_target: the two return Anys -> AnyFrozenTarget / AnyPrefix
  (new shared target_aliases.py, imported by both build_target and run.py::main,
  which dispatch on the identical match cfg.target). prefix_residual_fn keeps its
  Callable[[Any, Any], Array] (legit generic input edge).
- load_run.py: inner forward's ci_fn Any -> CIFn; LoadedJaxRun._forward Any ->
  the honest Callable.
- hidden_acts_eval.py / slow_eval.py: sampling: str -> SamplingType Literal,
  threaded through compute_hidden_acts_metrics; dropped the now-redundant runtime
  assert sampling in (...). recon.py StochasticSources.sampling + build_recon_terms
  also use SamplingType; build_recon_terms loss_metrics narrowed tuple|list -> tuple.
- slow_eval.py: max(n_positions, 1) -> assert n_positions > 0.
- recon.py: instance_key -> assert_unique_instance_key (reflects the uniqueness assert).
- run.py: assert int(state.step) == ckpt_step after restore (codifies S22);
  _ensure_global made PEP-695 generic [T](tree, mesh) -> T, dropping the call-site
  isinstance assert.
- config.py: _resolve_target uses the existing AnyTargetConfig alias (was a quoted
  forward-ref).

Skipped (behavior-affecting / non-trivial, noted in review): losses.py
annealed_pnorm span assert (would break the documented start_frac==end_frac==1.0
no-op config); build_experiment_config wandb default drop (several test call sites
rely on the defaults).

Co-authored-by: Claude Opus 4.8 <noreply@anthropic.com>
…alCiConfig (review pass) (#856)

Schema-package code-review fixes (all in param_decomp_config/, plus consumer updates):

Fail-fast:
- EvalConfig: model_validator enforcing the documented slow_every % every == 0
- PGDConfig: step_size -> PositiveFloat, n_steps -> NonNegativeInt (zero = the
  no-ascent fresh-sample baseline, which the JAX adversary test exercises)

Narrow types -> Positive*:
- eval_metrics: n_heads, AutointerpLabels k/max_examples/context_tokens_per_side/
  seq_len -> PositiveInt; CIHistograms n_batches_accum -> PositiveInt | None
- autointerp: strategy max_examples / label_max_words -> PositiveInt

Encode invariants in types:
- GlobalCiConfig: replaced the two-independently-optional-fields-gated-by-a-validator
  shape with a fn_type-discriminated union of GlobalSharedMlpCiConfig (hidden_dims) vs
  GlobalSharedTransformerCiFnConfig (transformer cfg), nested under the mode
  discriminator. Impossible states unrepresentable; validator deleted. A
  BeforeValidator drops the inactive hidden_dims/simple_transformer_ci_cfg None keys
  that old single-class configs wrote, so stored runs still parse. Readers updated:
  ci_fns.py, component_model_io.py, jax_single_pool/config.py, test fixtures.
- IdentityCIErrorConfig.identity_ci/dense_ci: list[dict[str, str|int]] -> typed
  IdentityCITargetSpec / DenseCITargetSpec; make_target_ci_solution + caller updated
- PermutedCIPlotsConfig / UVPlotsConfig: shared _PermutationPlotsBaseConfig

Honesty: torch-Trainer/.pth docstrings -> orbax ckpts/<step>/ + jsp-train
(pd.py Cadence/keep_last_n_checkpoints/PDConfig, schedule.py, experiment.py).

Surfaced (not forced): the two PGD scope vocabularies (MaskScope string-literal vs
PersistentPGDSourceScope object-union) stay deliberately separate — converging them
would change stored YAML shape; added a cross-reference comment explaining why.

No new type: ignore / Any / cast; param_decomp_config stays torch-free.

Co-authored-by: Claude Opus 4.8 <noreply@anthropic.com>
…omp/ (#857)

Removes four top-level helpers that the retired torch trainer used and that
nothing live imports after the JAX-trainer pivot:

- grad_clip.py        (cross_pool_clip_grad_norm; n-pool lineage, dead per PARITY_MATRIX)
- sdpa_strict.py      (verify_flash_attention_available)
- _trace.py           (trace / dump_memory_stats)
- phase_timer.py      (PhaseTimer / phase / format_phase_table)

Each verified: zero importers across param_decomp/, param_decomp_lab/,
param_decomp_jax/, configs, and tests (symbol- and module-path searches; no
dynamic imports; bare __init__). The metrics/ tree, train_step.py,
training_state.py, distributed.py, torch_helpers.py, log.py and the bridge
substrate stay — all live via dispatch / offline-eval / eval_metrics.

Co-authored-by: Claude Opus 4.8 <noreply@anthropic.com>
…pe shape contract (#860)

Generalize the trainer core from a fixed (B,T) activation waist to a generic
[*leading, d] waist (masks/CI [*leading, C]), where leading = (batch,) + named
position axes. CI is independent over every leading axis (broadcast_ci dropped),
so there is no per-axis CI semantics — only axis NAMES.

- DecomposedModel gains leading_axes; CIFn gains expects_axes; init_train_state
  asserts they're equal (early fail) so the CI fn stays per-domain (RoPE over
  `sequence`) without the generic loop adapting. ("sequence",) for LM.
- Keystone train.py: leading = residual.shape[:-1], threaded to routing samplers,
  fresh-PGD source init, stochastic delta-mask shapes (leading_shape: tuple[int,...]).
- losses.py reductions: math.prod(shape[:-1]) / axis=tuple(range(ndim-1)).
- recon.py RoutingSampler + bodies: tuple[int,int] -> tuple[int,...].
- adversary.py init_persistent_sources takes leading_shape; init_fresh_pgd_sources
  spells c/bc/bsc over the model's leading. llama8b_sharding builds the per-scope
  leading_shape (sc->(1,T), bsc->(B,T)).
- eval/slow_eval/hidden_acts_eval: same shape[:-1]/prod generalizations.
- lm.py type aliases -> *leading.
- jaxtyping+beartype: @jaxtyped(typechecker=beartype) on the core step,
  masked_forward (residual/masks/routes bind *leading), and the loss fns — the
  waist contract enforced at trace time. New dep beartype==0.22.2.
- AXIS_SEMANTICS_DESIGN.md documents the contract + designed (unimplemented)
  extensions (scope as leading-axis subset; tying as shared vu leaf).

LM byte-identical: equivalence + stacked-parity goldens UNMODIFIED and passing;
only test call-sites adapt to the renamed source-init signature (arrays identical).
Full suite green at 1 and 4 devices; check-jax clean. TMS not ported; scope stays
LM's c/sc/bsc.

Co-authored-by: Claude Opus 4.8 <noreply@anthropic.com>
…g,d] core (#861)

Third decomposition target and the first positionless one (leading_axes=()), the
proof the generic waist fits a non-LM model. The core needed ZERO change — only an
added target (tms.py), CI-fn arch (ci_fn_mlp.py: layerwise per-site MLP,
expects_axes=()), and config dispatch.

- param_decomp_config/tms.py: torch-free TMSExperimentConfig/TMSTargetConfig/TMSDataConfig.
- jax tms.py: Anthropic toy (tied weights) with an UNTIED 2-site decomposition,
  recon_loss_fn=tms_mse (post-ReLU MSE, not KL), no prefix, pretrained in-process.
- ci_fn_mlp.py: LayerwiseMLPCIFn (vector-input per-site MLP, fits the ci_fn(site_inputs)
  waist; recovers identity in the non-superposed regime — validated in test_tms.py).
- config dispatch: _is_tms_schema routes to build_tms_experiment_config; TMS*Config join
  AnyTargetConfig/AnyDataConfig; run.py builds+pretrains target and calls train_tms.
- harvest/clustering run_worker_jax: assert isinstance(data, DataConfig) before reading
  .dir/.seq_len — AnyDataConfig now includes TMSDataConfig; these workers are LM-only
  (#12 localize-and-assert; caught by the torch-side make type).

Co-authored-by: Claude Opus 4.8 <noreply@anthropic.com>
…erpLabels (#859)

* shed(harvest): numpy accumulator + delete dead torch worker & AutointerpLabels

Make the harvest subsystem fully torch-free.

Convert the accumulator and its shared state to NumPy:
- Harvester / ActivationExamplesReservoir / sampling / storage / analysis /
  HarvestBatch are NumPy. Counts, component co-occurrence, and input token
  cooccurrence/marginals use int64 (overflow-safe over a long stream);
  probability-mass accumulators and activation sums use float64.
- Input token cooccurrence accumulates via np.add.at over np.nonzero(firing)
  entries (parallel 1-D integer indices) instead of torch scatter_add_.
- Reservoir sampling RNG is an explicit np.random.Generator threaded from the
  Harvester (was a torch.Generator | None default).
- Tensor artefacts move from torch.save .pt to np.savez .npz
  (component_correlations.npz, token_stats.npz, worker_states/worker_*.npz).
  Harvest output is regenerable intermediate, so no legacy shim.

run_worker_jax.py is now the only worker and drops torch: np.asarray off the
JAX forward straight into the NumPy Harvester. It gains --rank/--world_size
sharding (via ShardServer process_index/process_count) + worker_states save +
get_command, so the SLURM launcher (run_slurm.py) drives it as a rank-sharded
array + dependent merge, exactly as the old torch worker did.

Delete dead torch code:
- harvest/scripts/run_worker.py + harvest/harvest_fn/* (torch worker path,
  superseded by the JAX worker). pipeline.harvest() went with it; merge_harvest
  stays (now NumPy).
- eval_metrics/autointerp_labels.py + its config, registry entry, wandb
  short-name, and test. Verified dead: in zero eval configs, its async_eval
  driver does not exist, reachable only by its own definition + registry.

Acceptance: grep for torch imports under param_decomp_lab/harvest is empty.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>

* chore(harvest): drop deleted harvest_fn from lab packages list

harvest_fn was deleted with the dead torch worker; the setuptools packages list
still referenced it.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>

---------

Co-authored-by: Claude Opus 4.8 <noreply@anthropic.com>
…atch (#862)

JAX counterparts of torch CIMaskedAttnPatternsReconLoss / StochasticAttnPatternsReconLoss,
run IN-LOOP over the eval pass's residuals (a few forwards, not minutes-slow), removed
from OFFLINE_EVAL_METRIC_TYPES.

- attn_patterns_eval.py: KL(target_pattern ‖ masked_pattern) per attention layer over
  every (B,H,T_query) distribution; matches torch F.kl_div(masked.clamp(1e-12).log(),
  target, reduction="sum"), divided once by Σ n_distributions. Combined + per-q_proj-site
  log keys mirror torch.
- Q/K via the existing masked_site_outputs seam (clean = all-false routes → frozen x@W,
  the hidden_acts_eval trick); DecomposedModel gains NO new method.
- attn_pattern_for(target): the RoPE/GQA/head-dim recipe dispatched on the concrete frozen
  target, reusing the target's OWN vendored RoPE + FrozenAttn config — NOT a method on
  DecomposedModel (the universal interface stays LM-agnostic). Non-attention target raises.
- config: AttnPatternsEvalConfig (ci_masked/stochastic flags); _assert_separate_qk_attn_paths
  refuses combined c_attn (no JAX target produces a merged-QKV site).
- run.py: build attn steps when configured; accumulate token-weighted KL into eval_record.

Co-authored-by: Claude Opus 4.8 <noreply@anthropic.com>
…orch models (#863)

These were paper-comparison-only tooling (compare PD against CLTs/transcoders) with
no live consumer in the JAX stack. The torch harvest_fn dispatch they fed was already
removed when harvest went JAX-native (#859); only the adapter + config + vendored-model
surface remained.

Deleted:
- adapters/{clt,transcoder}.py and adapters/_vendor/ (CrossLayerTranscoder,
  BatchTopKTranscoder)
- CLTHarvestConfig / TranscoderHarvestConfig + the DecompositionMethodHarvestConfig
  union (collapsed to ParamDecompHarvestConfig — the sole remaining method)

Kept (PD run-loaders, NOT comparison adapters): adapters/pd.py (PDAdapter),
adapters/jax_pd.py (JaxPDAdapter), adapters/base.py.

Collapsed autointerp DecompositionMethod Literal to "pd" and dropped the clt/transcoder
prompt descriptions (only those adapters ever set the non-pd values). Removed the
adapters._vendor package entry from pyproject.

Co-authored-by: Claude Opus 4.8 <noreply@anthropic.com>
The fourth DecomposedModel and second non-LM bundle over the generic
positionless (leading_axes=()) core: the SPD/APD residual-stream toy.
Fixed W_E/W_U embeddings, n_layers MLP blocks, decomposition over the
per-layer mlp_in/mlp_out matrices (untied V/U). W_E is the prefix
(residual = x @ W_E); recon is MSE on the model output (not KL).
Pretrained from scratch in-process on the read-off act_fn(x)+x objective.

Reuses LayerwiseMLPCIFn (no new CI arch) and the identity-CI ground-truth
metric. Core needed ZERO change (train/losses/recon/lm) — only ADDED a
target + config/run dispatch, exactly like TMS (PR #861); ResidMLPTarget
joins AnyFrozenTarget. _is_tms_schema marker moved n_features -> n_hidden
so it stays disjoint from ResidMLP's d_embed marker.

Validation: jax basedpyright clean; test_resid_mlp (14, incl slow e2e
identity recovery) + test_tms + test_config + equivalence goldens green;
torch make type clean; make test 407 passed.

Co-authored-by: Claude Opus 4.8 <noreply@anthropic.com>
The last LM eval metric (attn-patterns) moved to JAX in-loop (#862) and the
CLT/transcoder adapters were deleted (#863), so the torch offline-eval bridge
is now dead. Slow/plot eval is JAX-native via `jsp-slow-eval`; parity with
`pd-offline-eval` was validated on llama8b `p-4301638c` (PASS) before this delete.

Deleted (solely torch-offline-eval):
- `param_decomp_lab/experiments/lm/offline_eval.py` (pd-offline-eval entry)
- `param_decomp/train_step.py` (its only live consumer was offline_eval)
- `jax_single_pool/export.py` (jsp-export; harvest/clustering read orbax via
  `load_run.open_jax_run`, never the safetensors export)
- `jax_single_pool/slurm/offline_eval_once.sbatch` (+ the now-empty slurm/ dir)
- `jax_single_pool/tests/test_export*.py` + `tools/export_fixtures/*`
- `torch_helpers.loop_dataloader` (offline_eval-only iterator)

Also: drop `pd-offline-eval` + `jsp-export` console scripts, the
`offline_eval_submission_argv`/`_submit_offline_eval` push-trigger in run.py
(+ its test + `subprocess` import), the stale `verify_export_torch.py`
type-check exclude, and dead torch-bridge phrasing in docstrings/CLAUDE/README.

Kept (live non-eval consumers): the `eval_metrics/` torch Metric impls +
`EVAL_METRIC_CLASSES` (their `wandb_config_dict`/`metric_short_names` plumbing
is reused by the JAX trainer via run_sink), `OFFLINE_EVAL_METRIC_TYPES` (a live
config-validation gate), `bf16_autocast`, `target_aliases.py`, and
`convert_llama_simple_mlp_checkpoint.py` (a pretrain converter, not export).

Validation: make type (0), make test (407 passed/4 skipped), JAX basedpyright
(0), equivalence goldens (12 passed, trajectory bit-identical), make test-jax
(174 passed/1 skipped).

Co-authored-by: Claude Opus 4.8 <noreply@anthropic.com>
#866)

Generalize the localize-and-assert pattern (ci_fn↔target expects_axes in
init_train_state; attn_pattern_for's non-attention raise) to the LM-only
in-loop eval metric constructors, which consumed tokens/vocab/attention
without declaring their leading-axes compatibility.

- eval.make_eval_step (CEandKLLosses / CI_L0 / fresh-PGD probe): reads tokens
  + vocab logits and a (1,1,C+1) source over a (batch, sequence) waist — assert
  lm.leading_axes == ("sequence",) at construction.
- attn_patterns_eval: shared _assert_attention_sequence_axes guard on both
  make_ci_attn_patterns_step / make_stochastic_attn_patterns_step (causal
  (B,H,T,T) maps over a sequence axis), complementing attn_pattern_for's
  per-target dispatch.

Domain-generic metrics (hidden-acts MSE, slow-eval CI plots) correctly stay
unguarded. Data sources already fail loud: train() / train_tms / train_resid_mlp
assert their DataConfig subtype, ShardServer asserts shard width vs seq_len, the
harvest/clustering workers assert isinstance(data, DataConfig).

Additive only; generic core (train/losses/recon/lm) unchanged; LM trajectory
bit-identical (full JAX suite + goldens green). New guard tests assert each
constructor raises against a positionless (leading_axes=()) TMS target.

Co-authored-by: Claude Opus 4.8 <noreply@anthropic.com>
Superseded by the JAX targets (param_decomp_jax/jax_single_pool/tms.py,
resid_mlp.py). The torch experiment dirs had no external importers or console
scripts — the torch training drivers were already retired.

- Delete param_decomp_lab/experiments/{tms,resid_mlp}/ (models/data/run/train/
  feature_importances + their YAML configs)
- Delete their tests: tests/test_data.py (SparseFeatureDataset),
  tests/test_feature_importances.py (compute_feature_importances)
- Drop the two packages entries from param_decomp_lab/pyproject.toml
- Correct CLAUDE.md / README references: experiments bridge is now LM-only;
  TMS/ResidMLP live as JAX targets

Kept: param_decomp_config/{tms,resid_mlp}.py (torch-free schemas the JAX
trainer reads); param_decomp_lab/toy_models/target_ci.py (live importers in
eval_metrics + tests); experiments/utils.py (live importers in lm/adapters/app).

Co-authored-by: Claude Opus 4.8 <noreply@anthropic.com>
)

Remove `param_decomp_lab/app/` (FastAPI backend + Svelte frontend) to shed
torch/web surface ahead of the JAX-primary merge. This is a deliberate
remove-now-re-add-later decision; a JAX-native viewer is slated to replace it.
The PR body carries the full re-add log.

Kept-consumer fallout handled: harvest and autointerp imported the app's
tokenizer-display helpers, so `AppTokenizer`, `escape_for_display`, and
`delimit_tokens` are relocated to `param_decomp_lab/tokenizer_display.py`
(out of the app, no behavior change). Their test moves to
`tests/test_tokenizer_display.py`.

Also removed: `make app`/`install-app`/`check-app`/`install-all` targets, the
`fastapi`/`uvicorn` lab deps (app-only; orjson kept — shared), the app package
+ package-data entries in `param_decomp_lab/pyproject.toml`, the `app/frontend`
ruff/pyright excludes, and the CI `build-frontend` job. CLAUDE.md pointers and
stale app mentions updated; a temporary-removal note added to the root CLAUDE.md.

Out of scope and intentionally untouched: `investigate/scripts/run_agent.py`
still subprocess-launches the app backend, so `pd-investigate` is broken at
runtime until the app returns (no import/type/test break). See PR re-add log.

Co-authored-by: Claude Opus 4.8 <noreply@anthropic.com>
…yout (#870)

The frozen C49k clone (jax-l18-C49k-200k, Llama-3.1-8B layer-18 MLP, C=49152)
saved a TrainState whose orbax pytree keys differ from HEAD's in three places.
tools/migrate_c49k_checkpoint.py restores the OLD tree single-device on CPU,
remaps it onto the current layout, and save_states it so a fine-tune can
restore_latest from it. Remap:

  components.{Vg,Ug,Vu,Uu,Vd,Ud} (1,*,*)  ->  components.vu[<site>][0|1] (*,*)
    g->gate_proj, u->up_proj, d->down_proj; V->[0], U->[1]; squeeze leading 1.
  components_opt_state mu/nu: same squeeze; optax count scalars copy through.
  ci_fn / ci_fn_opt_state: identical leaf names -> straight copy.
  sources.<site>  ->  sources.<state_key>.<site>  (state_key PersistentPGDReconLoss)
  sources_adam_state.{m,v}/step_count -> sources_opt_state.<state_key>.{...}
  step: copied (restores as 175000).

The reference TrainState is built abstract (jax.eval_shape) so the full fp32
reference never coexists with the 47 GB restored tree. Verified independently:
clean restore (no structure/shape/dtype mismatch), step==175000, component
shapes + finiteness, sources in [0,1], CI-fn forward finite on a dummy batch.

Co-authored-by: Claude Opus 4.8 <noreply@anthropic.com>
…m probe (#871)

* feat(jax): l18-26 9-layer chunkwise config + abstract-AOT mem probe

New from-scratch decomposition of Llama-3.1-8B MLP layers 18-26 (27 targets =
9 layers x gate/up/down, C=49152), chunkwise recon (3 chunks of 9 sites,
coeff 1.5), remat on. Derived from the C49k single-layer-18 run.

AOT-probed per-device HBM at 64 B200 (remat on, no buffer donation — matching
the trainer's non-donating @jax.jit step):
  B=128 (2/device): 210.2 GiB  -> over the 180 GiB cap
  B=64  (1/device): 179.0 GiB  -> at the cliff, ~0 headroom
B=128 does not fit, so the config carries B=64 with LRs 4th-root-scaled to
ci_fn 4.2e-5 / components 1.26e-4.

mem_probe.py rewritten to lower+compile on ABSTRACT sharded avals (shapes via
jit(init).lower().out_info, GSPMD shardings re-attached per leaf) instead of
eagerly materializing the state — the 27-site C=49152 state is ~360 GB and
OOMs the eager init's jit__identity_fn (128 GiB replicated V/U) at 64 GPU.
Adds --sites_per_chunk / --recon_coeff / --last_layer args and the 64-GPU
probe sbatch.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>

* feat(jax): finalize l18-26 9-layer R&D config (seq 512, B=128, 40k) + seq-512 probe

seq 2048→512 is what makes 27 sites fit on 64 GPU: AOT probe (seq 512, remat on) =
144.8 GiB (B=64) .. 150.3 GiB (B=256)/device, all under the 180 GiB B200 cap (seq 2048
was 179 GiB at B=64, over the cliff). Final config: layers 18-26 (27 sites, C=49152),
seq 512 / fineweb_llama_tok_512, B=128, 40k steps; comp 1.5e-4 / ci_fn 5e-5 (B=128
precedent); chunkwise recon sites_per_chunk=9 coeff 1.5; imp-min eps 1e-12→1e-6
(fractional-pnorm grad stability), coeff/beta/pnorm unchanged; PGD/faith unchanged;
ci max_len→512. mem_probe gains --seq; seq-512 sweep sbatch added.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>

---------

Co-authored-by: Claude Opus 4.8 <noreply@anthropic.com>
…#872)

Training is JAX now and torch-run loading is being deferred (re-add later as the
#10 torch->jax adapter). This shears the torch-run-loading surface and everything
that goes dead with it. No legacy/dual-format support — clean deletion.

Dropped (torch-run-loading surface):
- adapters/pd.py (torch PDAdapter); adapter_from_id now jax_pd-only (asserts JAX run)
- component_model_io.py + experiments/lm/vendored/ (VendoredLlama, LMComponentModel)
- SavedLMRun / build_lm_loader / make_run_batch from experiments/lm/run.py
  (kept build_target — JaxPDAdapter derives topology from it)
- init_pd_run from experiments/utils.py (dead; kept EXPERIMENT_CONFIG_FILENAME)

Cascade-pruned (no kept consumer; grep-confirmed):
- param_decomp core: component_model, components, ci_fns, ci_nn_blocks, ci_sigmoids,
  masks, batch_and_loss_fns, torch_helpers, run_sink, training_state, metrics/*
- param_decomp_lab: eval_metrics/* (torch Metric impls), batch_and_loss_fns, run_sink
- dead tests for all of the above

Kept in param_decomp/ (consumer substrate the torch consumers still import):
log.py, distributed.py, decomposition_targets.py.

Kept untouched: the eval-metric CONFIG classes in param_decomp_config (the JAX
trainer matches on them); the whole param_decomp_jax/ trainer (zero edits —
equivalence goldens bit-identical).

Validation: make type 0 errors; 224 passed (lab+core, not slow); import-smoke of
every kept consumer entrypoint OK; check-jax 0 errors + equivalence goldens 12/12.

Co-authored-by: Claude Opus 4.8 <noreply@anthropic.com>
…e self-contained run yaml (#873)

Kill the vestigial wrapper/`torch_config:`/`configs/torch/` indirection. A JAX run
config is now ONE self-contained file: the canonical `param_decomp_config` experiment
schema plus the run-instance fields it now also carries.

- `param_decomp_config`: `runtime.remat_recon_forwards` (RuntimeConfig); top-level
  `run_name`/`run_id`/`out_dir` (ExperimentConfig, run_id/out_dir None pre-launch);
  `wandb.group`/`wandb.tags` (WandbConfig).
- `jax_single_pool.config`: `load_wrapper`/`_build_from_schema`/`_wandb_group_tags`
  → `load_config`/`build_from_schema`; build fns read run-instance fields off the
  schema (no per-call args); run-dir pinning collapses to a single `config.yaml`.
- `pd-jax-lm`: validates the single config (structural TMS/ResidMLP/LM dispatch),
  stamps run_id/out_dir/wandb-group/tags into the workspace copy.
- `is_jax_run`: detects orbax `ckpts/` beside `config.yaml` (the `torch_config:` key
  marker is gone); JaxPDAdapter/slow_eval read `config.yaml`.
- Migrated all 9 configs to single files; deleted `configs/torch/`, the `_from_torch`
  wrappers, and `param_decomp_config/jax_wrapper.py`. Added missing `weights_dtype:
  bfloat16` to pile_ppgd_bsc (pre-existing latent refusal; bf16-only target).

Pure config-plumbing — no training-semantics change. Equivalence goldens stay green.
No legacy/dual-format shim; old saved runs' two-file pinned configs need manual
migration if a consumer ever reads one.

Co-authored-by: Claude Opus 4.8 <noreply@anthropic.com>
…(SPEC S33) (#874)

A fresh run can now initialize its trained decomposition (V/U + ci_fn) from a
PARENT checkpoint and fine-tune under a DIFFERENT config. When the run's own
ckpts/ is empty and cfg.resume_provenance is set, the LM train() loads the
parent's ckpts/<parent_step> onto the fresh reference TrainState, keeps only the
components + ci_fn, and starts a clean schedule from step 0 (fresh optimizer /
sources, no faith warmup). A subsequent SLURM requeue (own ckpts/ non-empty)
resumes from its own dir and ignores provenance.

- checkpoint.py::init_from_parent: restore parent onto reference, keep V/U +
  ci_fn, reset step to 0; assert the parent step exists.
- run.py::assert_finetune_structural_compat: read the parent's pinned config.yaml
  and assert matching sites (names + C) + ci-fn arch before the restore — a
  changed C/sites/ci-fn arch is not a fine-tune (fail-fast, readable diff).
- config.py: ExperimentConfig carries resume_provenance; the three builders
  propagate it from the canonical schema.
- run.py: TMS/ResidMLP refuse resume_provenance (LM-only).
- SPEC S33 + experiment.py ResumeProvenance docstring (dropped dead init_pd_run
  ref) + jax CLAUDE.md training-pipeline fine-tune stanza.

Non-resume training semantics unchanged (equivalence goldens bit-identical).

Co-authored-by: Claude Opus 4.8 <noreply@anthropic.com>
Cache compiled XLA executables to a shared-FS dir reused across runs,
requeues, and future launches at the same config+topology. The ~24-min
chunkwise-step compile is keyed by HLO + backend + topology + jax/xla
version, so a matching re-compile loads from disk in seconds.

Set in run.py::main after init_distributed (the write gate reads the
distributed state) and before the first compile, so it covers direct
jsp-train too — not just pd-jax-lm. Cache dir is a SIBLING of runs/
($PARAM_DECOMP_OUT_DIR/xla_compilation_cache), shared across all runs and
all 8N ranks. Threshold 60s so only the big compiles cache.

Multi-host safe on jax 0.10.1: jax gates the cache WRITE on process_id == 0
(compiler.py), so all ranks read but only rank 0 writes — no shared-FS
race. Verified a cross-process cache HIT on CPU.

Co-authored-by: Claude Opus 4.8 <noreply@anthropic.com>
Tracks: orphaned eval metrics (UVPlots/PermutedCIPlots/general-IdentityCIError
recomputed nowhere since the torch offline path retired), deferred #10 (torch→jax
run adapter) + app re-add, pretrain as a TRANSITIONAL torch island (rewrite in JAX
when next needed → fully torch-free), and the imp-min token-count reparam.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
… torch-shed batch

Points human review at the goldens-blind, silent-corruption-risk surface (ckpt migration
remap, the build_target/topology torch-free swap, fine-tune partial-state load, the
is_jax_run discriminator) and marks the trainer-semantics PRs as skim-able (equivalence
goldens prove the training math is bit-identical).

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
The repo is now torch-free except `nano_param_decomp/run.py` — the standalone
single-file VPD reference impl for paper readers (self-contained, no lab/pretrain
deps, excluded from `make type`). Training has been JAX (`jsp-train` via `pd-jax-lm`)
since the torch trainer retirement; this sheds the remaining torch consumer/infra
layer. Zero `import torch` in `param_decomp` / `param_decomp_lab` / `param_decomp_jax`.

Deleted entirely:
- `param_decomp_lab/experiments/lm/pretrain/` (torch model defs, train loop,
  `pd-pretrain` CLI, run_info, configs) — pretraining will be reimplemented in JAX
  when next needed; the trainer loads target weights from the on-disk cache via its
  own torch-free loaders, never through pretrain code.
- `param_decomp/{distributed,decomposition_targets}.py`,
  `param_decomp_lab/{distributed,seed}.py`,
  `param_decomp_lab/infra/{ddp_launch,wandb_tensor_info}.py`,
  `param_decomp_lab/toy_models/`, `param_decomp_lab/topology/topology.py`,
  `param_decomp_lab/experiments/lm/run.py` (the torch `build_target` bridge) —
  all dead after de-torching their sole consumers.
- `jax_single_pool/tools/convert_llama_simple_mlp_checkpoint.py` (one-off torch
  checkpoint->safetensors converter; existing caches already converted).
- `nano_param_decomp/{pile_4L,simplestories_2L}.py` — model-wiring entry points that
  imported the deleted torch pretrain archs (now broken); `run.py` (the method) stays.
- Four tests covering deleted modules.

De-torched (relocated metadata path):
- `JaxPDAdapter` now derives target topology (n_blocks / vocab / per-site (name, C))
  from the new torch-free `jax_single_pool.load_run.run_metadata` (config + pretrain
  cache, no orbax restore) + the torch-free `path_schema_for_model_type`. No torch
  model construction.
- `experiments/lm/data.py` keeps only `tokenize_and_concatenate` (numpy, for the
  offline prestage tool); the torch DataLoader machinery is gone.
- `infra/run_files.py` drops the dead `save_file`/torch.save path.

Tooling: dropped torch + the pytorch-cu128 index from the root pyproject/lock; the
two-stack conflict is gone; the main venv bridges CPU jax + beartype + editable
`param_decomp_jax` (`make install-dev`). Docs updated (root/jax/experiments CLAUDE.md,
MIGRATION_HOLES, READMEs).

Validation: zero-torch grep empty; trainer still builds+loads the pile LlamaSimpleMLP
target from cache (CPU); `make type` 0 errors; lab+core 209 passed; `make check-jax`
0 errors; JAX suite 193 passed/2 skipped; equivalence goldens 12 passed (bit-identical
— training semantics unchanged).

Co-authored-by: Claude Opus 4.8 <noreply@anthropic.com>
Adds tools/verify_c49k_migration.py, a streaming leaf-by-leaf verifier that
proves the migrated 175k checkpoint (p-bd3cd4d4) reproduces the frozen-clone
source (jax-l18-C49k-200k) bit-for-bit under migrate_c49k_checkpoint.py's
remap, closing the #1 correctness gap in MIGRATION_REVIEW.md (structure was
verified, value was not).

The comparison table is the inverse of the migration's remap, built from that
tool's own constants (KIND_TO_SITE_SUFFIX, SOURCE_STATE_KEY) so the mapping
under test and the mapping applied are the same. Each leaf pair is restored
single-device on CPU one at a time (all other leaves PLACEHOLDER, never read),
so peak RAM is ~2x one V/U leaf (~1.6 GB) and the 47 GB trees never coexist.
Asserts 1:1 coverage of both trees, then np.array_equal per leaf (squeezing
the legacy leading singleton on the 6 component V/U leaves).

Run verdict: all 144 leaves bit-identical -> migration VALUE-EXACT.
The g/u/d->gate/up/down mapping and V->[0]/U->[1] ordering are confirmed by
the asymmetric down_proj shapes (d_in=14336), which would mismatch under any
swap or mismap.

Co-authored-by: Claude Opus 4.8 <noreply@anthropic.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant