Migrate VPD to JAX: train + analyze in one framework; retire torch to oracle #560
Open
ocg-goodfire wants to merge 454 commits into
…-pool jax stage12/13 ran the per-site layerwise stoch (6 forwards) while the torch 2-pool recon_plan uses subset n_samples=1 (1 joint forward) — confounding the framework comparison. Switch jax to the joint estimator so the 2x2's stoch load is apples-to-apples. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
…ime HF dep) prestage_tokenized.py tokenizes a portion of a HF text dataset (fineweb sample/350BT) into local int32 seq-2048 parquet shards so training reads a local Arrow dataset instead of streaming/tokenizing from HF at run time. Removes the N-rank HF dataset thundering herd that stalls 80-GPU startup (the build_two_world straggler) + per-rank tokenization cost. Scavenge-safe: atomic shard writes (.tmp + rename), idempotent skip-existing resume, and SLURM-array fan-out (task t -> file indices t::num_tasks, shards by global index, disjoint). ~366 files ~= 256B tokens ~= 512GB (int32, ~2x parquet compression). Loader: dataset_name=parquet, data_files=<dir>/*.parquet, column_name=input_ids, is_tokenized=true, streaming=false. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
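The fan-out and resume behaviour described in this commit can be sketched with the standard library alone. This is an illustration under stated assumptions, not the code in prestage_tokenized.py: the helper names `task_file_indices` and `write_shard_atomic` and the raw-bytes payload are hypothetical; only the `t::num_tasks` striding, the `.tmp` + rename write, and the skip-existing check come from the commit message.

```python
import os
import tempfile


def task_file_indices(task: int, num_tasks: int, n_files: int) -> list:
    """Global file indices owned by array task `task` (the stride slice t::num_tasks)."""
    return list(range(n_files))[task::num_tasks]


def write_shard_atomic(path: str, payload: bytes) -> bool:
    """Write `payload` to `path` via a .tmp file; return False if the shard already exists."""
    if os.path.exists(path):
        # Idempotent resume: a finished shard is never rewritten.
        return False
    tmp_path = path + ".tmp"
    with open(tmp_path, "wb") as f:
        f.write(payload)
    # os.replace is atomic when source and destination share a filesystem,
    # so a preempted task leaves either no shard or a complete one.
    os.replace(tmp_path, path)
    return True


# Usage: 4 tasks over 10 input files give disjoint ownership that covers every file.
ownership = [task_file_indices(t, 4, 10) for t in range(4)]
out_dir = tempfile.mkdtemp()
first_write = write_shard_atomic(os.path.join(out_dir, "shard_00000.bin"), b"tokens")
second_write = write_shard_atomic(os.path.join(out_dir, "shard_00000.bin"), b"other")
```

Because shards are named by global index and each task owns a disjoint stride, a requeued task can simply re-run its whole index list and rely on the existence check to skip completed work.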
PD_COMPILE_SINGLE_POOL=1 compiles the core ComponentModel's masked forward, to test whether the single-pool path can get the pooled paths' compile win (it has zero compile wiring today -> torch 1-pool runs eager at 610ms). Env-gated probe; wire into RuntimeConfig if it proves out. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
probe_reshard.py - cross-sub-mesh jax.reshard latency (sub-mesh viability)
probe_zero.py - full-mesh ZeRO sharding: resident-memory freed + gather-on-use forward
probe_hetero.py - axis_index_groups even vs UNEVEN subgroups (single-mesh heterogeneous DDP)
cw_probe.sbatch - single-process/8-GPU probe launcher (JAX_VENV override)
cw_jax.sbatch - honor JAX_VENV (cuda12 vs cuda13)
Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
…compute-bound knee)
…replicated) Cartesian (group, dp) mesh: groups 0..Nc-1 = chunk groups, each running the stochastic masked forwards for its own static site partition concurrently; group Nc = main pool (importance-minimality + PPGD adversary recon + PGD source update). CI is computed only on the main pool and broadcast to all groups via a psum trick over the (Nc+1)-way group axis. V/U replicated (compute-parallel only). 8-GPU single-node runs verified for n_chunks=1 (group=2 x dp=4) and n_chunks=3 (group=4 x dp=2); finite, decreasing losses. 16-GPU multinode is blocked by a pre-existing CUDA_ILLEGAL_ADDRESS that hits the stage13 reference identically. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
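One common way to broadcast with a sum collective is for every group except the owner to contribute zeros, so the sum equals the owner's value on every participant. The commit does not spell out its psum trick, so the sketch below is an assumed reading, written in plain Python with lists standing in for per-group arrays; `psum_broadcast` is a hypothetical name and no JAX collective is actually called.

```python
def psum_broadcast(per_group_values, main_group):
    """Emulate 'zero everywhere except the main group, then sum over the group axis'.

    per_group_values[g] is the (possibly stale or garbage) CI vector held by group g.
    Returns what every group would hold after the sum.
    """
    width = len(per_group_values[main_group])
    contributions = [
        list(values) if g == main_group else [0.0] * width
        for g, values in enumerate(per_group_values)
    ]
    # The "psum": elementwise sum across the group axis.
    total = [sum(column) for column in zip(*contributions)]
    # Every group receives the same summed result.
    return [list(total) for _ in per_group_values]


# Usage: Nc = 2 chunk groups plus the main pool as group 2.
held = [[9.0, 9.0], [7.0, 7.0], [0.25, 0.75]]
after = psum_broadcast(held, main_group=2)
```

The appeal of this pattern is that it needs only the sum collective that the mesh already provides, at the cost of moving zeros from the non-owning groups.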
Consolidate the surviving learnings from the abandoned 2/3-pool work into a
clean single-process-group LM trainer that scales via FSDP2 (memory) +
torch.compile (speed) instead of hand-written pool transport. Pools and the
core DDP Trainer are left behaviourally untouched.
- param_decomp/train_step.py: extract the model-agnostic step pieces
(run_loss_step, run_eval_pass, scheduled_lrs, EvalLoop, sigterm, metric-ctx)
out of optimize.py; core Trainer.run now composes them (no behaviour change).
- param_decomp_lab/fsdp/: new subsystem
- component_adapter.py: present the vendored LMComponentModel through the core
ComponentModel surface metrics expect (batch is idx; cache_type mapping).
- config.py: FsdpRuntimeConfig (compile_model/compile_ci_fn/checkpoint_blocks/
shard_frozen_target).
- checkpoint.py + consolidate.py: on-loop sharded DCP save of trainable-only
state; off-loop async consolidation -> model_<step>.pth / training_<step>.pth.
- trainer.py: FsdpLMTrainer (FSDP2 fully_shard + compile + residual-start +
activation checkpointing); resume-in-place from DCP shards.
- sdpa_strict.py / grad_clip.py: cribbed from pd-nano-jax.
- experiments/lm/fsdp_run.py + fsdp_async_eval.py: composition root, pd-lm-fsdp
entry point, SLURM requeue + resume-in-place, async consolidate+slow-eval.
torch 2.11 (FSDP2 + DCP). make check green; fsdp unit tests pass. Runtime
distributed smokes still pending (de-risk phase).
Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Single-pool Parameter Decomposition training step in JAX, the research counterpart to the torch FSDP single-pool path. Reuses nano_pd_jax primitives (leaky-hard-sigmoid custom_vjp) and mirrors the torch PPGD semantics from param_decomp/metrics/persistent_pgd_state.py.
- forward.py: site-local masked decomposed forward (layerwise recon), with the weight-delta channel gating the residual per position.
- losses.py: the four VPD losses (faith, imp-min, stochastic recon, ppgd recon) over a stacked-site Decomposition; mask = ci + (1-ci)*source.
- scopes.py: PPGD source scopes (single/broadcast/repeat/per-batch-per-position).
- pgd.py: functional persistent adversary -- sources + Adam moments carried in state; n_warmup lax.scan ascent + a final post-update ascent.
- step.py: the whole step as one jax.jit fn -- fused multi-argnums grad over (V/U, CI), two functional Adam optimizers, frozen-target grad zeroed.
- experiments/toy_stacked_sites.py: CPU smoke (correct adversary signature).
- tests: 11 pure-fn tests incl. bit-exact PGD-scan-vs-python-loop. All green.
Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
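The mask formula quoted for losses.py, mask = ci + (1-ci)*source, interpolates between the source value and 1 as the causal importance rises. A minimal sketch in plain Python, with lists standing in for arrays; the function name `component_mask` is an assumption and this is not the code in losses.py.

```python
def component_mask(ci, source):
    """Per-component mask: ci + (1 - ci) * source, elementwise.

    ci == 1 pins the mask to 1 (the component is always kept);
    ci == 0 hands the mask entirely to the source (random or adversarial).
    """
    assert len(ci) == len(source), "ci and source must have one entry per component"
    return [c + (1.0 - c) * s for c, s in zip(ci, source)]


# Usage: three components with low, full, and intermediate importance.
mask = component_mask([0.0, 1.0, 0.5], [0.3, 0.3, 0.5])
```

With sources in [0, 1], the mask never drops below ci, which is why the same expression serves both the stochastic and the PPGD recon losses: only the origin of `source` differs.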
… invariance
The SPMD-collapse distribution layer (the FSDP analog): batch sharded P('dp'),
params + PGD sources replicated, jax.jit inserts every collective.
- sharding.py: dp_mesh / replicate / shard_leading / shard_batch + SLURM bring-up.
shard_batch uses make_array_from_process_local_data so it's correct for BOTH
topologies (single-process-many-devices and multi-process-1-device-each).
- experiments/distributed_stacked_sites.py: the single-pool step sharded over a
mesh; correctness = fixed-global-batch trajectory must be GPU-count-invariant.
- tests/test_sharding.py: shard_batch preserves the full global array + the
non-divisible-batch assertion. Pass at 1 and 4 simulated devices.
RESULT: bit-identical loss trajectories at 1 vs 4 simulated CPU devices -- the
full VPD+PGD step is GPU-count-invariant under GSPMD with zero manual
collectives, persistent adversary included. NOTES records the diagnosis.
Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
- experiments/transformer_qkv.py: decompose the square q/k/v sites of a real nano_pd_jax.TinyTransformer (real pre-weight acts) -- proves the step is model-agnostic. faith 0.043->0.005, stoch 0.30->0.064 on real attn projections.
- checkpoint.py + tests/test_checkpoint.py: flat-pytree save/resume of TrainState; resume continues the trajectory bit-identically (the persistent adversary survives). No torch-style state_dict plumbing -- the whole adversary state is in the pytree.
- README.md + CLAUDE.md: file map, run instructions, design + invariants.
- NOTES: residual-start analysis (it's implicit in the layerwise-recon factoring -- no masked re-forward through the frozen prefix), checkpoint/real-model TODOs resolved.
14 tests green at 1 and 4 simulated devices; basedpyright + ruff clean.
Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Replace the `cast(ComponentModel, cast(object, x))` double-cast idiom with a structural Protocol the concrete ComponentModel, the FSDP adapter, and the vendored LMComponentModel all satisfy. - param_decomp/component_model.py: add ComponentModelProtocol — the minimal surface metrics + train_step consume (__call__ with 3-way cache_type overloads so the 4-way concrete model and 3-way adapter both satisfy it via param contravariance; forward_with_output_acts; calc_causal_importances; calc_weight_deltas; module_to_c; target_module_paths; components). - Retype MetricContext.model, Metric.model/Metric.bind, instantiate_metrics, the train_step helpers, tie_component_weights and run_faithfulness_warmup to the Protocol; drop the double-casts in fsdp/trainer.py and three_pool/*. - Eval metrics that genuinely need the concrete ComponentModel (attn-patterns, autointerp) narrow via `assert isinstance(model, ComponentModel)` in an overridden bind — not a cast. - configs.py: make RuntimeConfig.dp's description strategy-agnostic (world size, replicated under DDP / sharded under FSDP); NOT renamed (preserves checkpoint experiment_config.yaml compat). FsdpRuntimeConfig documents the FSDP meaning. make check green; 186 tests pass. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
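The move from a double cast to a structural Protocol can be illustrated with a trimmed, hypothetical surface. The names `ComponentModelLike`, `ConcreteModel`, `FsdpAdapter`, and the single method shown are assumptions for the sketch; the real ComponentModelProtocol carries more members (overloaded `__call__`, `calc_causal_importances`, and so on).

```python
from typing import Protocol, runtime_checkable


@runtime_checkable
class ComponentModelLike(Protocol):
    """Hypothetical, trimmed stand-in for the minimal surface metrics consume."""

    def calc_weight_deltas(self) -> dict: ...


class ConcreteModel:
    def calc_weight_deltas(self) -> dict:
        return {"mlp.down_proj": 0.0}

    def attention_patterns(self) -> str:
        return "concrete-only feature"


class FsdpAdapter:
    # No inheritance: the adapter satisfies the Protocol by shape alone.
    def calc_weight_deltas(self) -> dict:
        return {"mlp.down_proj": 0.5}


def faithfulness(model: ComponentModelLike) -> float:
    """Works on anything with the Protocol's surface; no cast needed."""
    return sum(delta * delta for delta in model.calc_weight_deltas().values())


def bind_attention_metric(model: ComponentModelLike) -> str:
    # A metric that truly needs the concrete class narrows with isinstance,
    # which fails loudly at bind time instead of hiding behind a cast.
    assert isinstance(model, ConcreteModel)
    return model.attention_patterns()
```

The design point is that a cast silences the type checker without checking anything, whereas the Protocol lets the checker verify every implementer and the `isinstance` assert verifies the narrow cases at runtime.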
GPT2 block-6 MLP (c_fc/c_proj) on vendored GPT2 for the fast correctness smoke; Llama-8B L18 MLP (gate/up/down_proj) adapted from the proven 2-pool config to the FSDP schema (losses struct -> loss_metrics list, topology -> FsdpRuntimeConfig). All 4 losses; checkpoint_blocks off (PPGD); compile + shard_frozen_target on. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Add mem/peak_gb_per_rank to the FSDP train-log (torch.cuda.max_memory_allocated) so the shard_frozen_target memory A/B is directly observable. Add two short Llama-8B L18 memory-A/B configs differing only in shard_frozen_target (true vs false). Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
…gathers residual_at() runs the embedding outside the root model.forward, so the root fully_shard pre-hook never fires; under shard_frozen_target the embedding weight is a sharded DTensor and meets plain-tensor token ids -> 'aten.embedding got mixed torch.Tensor and DTensor'. Give each embedding its own fully_shard group so it all-gathers on every call (incl. residual_at). GPT-2 co-shards the tied wte/lm_head; Llama embed_tokens shards standalone. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
…arded
The prior embed-group fix was wrong twice over (validated on a 2-GPU repro): GPT2's
tied wte/lm_head can't co-shard as one fully_shard group ('requires a single root
module'), and a standalone-sharded embedding trips FSDP2's single-root lazy-init when
residual_at enters it before the parent forward. Root-sharding the embedding (the
original code) instead raises the DTensor-mixed error because residual_at bypasses the
root hook.
Fix: shard only the transformer blocks (which hold all trainable V/U + in-block frozen
target weights) and the ci_fn; drop the fully_shard over lm.model / lm so the frozen
embedding / final-norm / lm_head stay replicated. residual_at then sees a plain-tensor
embedding weight. The ~1GB replicated embedding is negligible vs the 32 sharded blocks.
Repro: embed_replicated mode passes; blocks_only/embed_group/no_root all fail.
Also fix the GPT2 smoke target pattern: the vendored GPT2 MLP second linear is
down_proj, not c_proj.
Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
…tion map) The apollo openwebtext dataset is tokenized at seq 1024; max_seq_len 512 triggered data.py's per-example truncation map over all 8.8M examples (~37min). Use native 1024 (map skipped) and bump CI attn max_len to match. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
The components' V/U are sharded by their transformer block's fully_shard, so accessed outside a forward (calc_weight_deltas, called by faithfulness warmup + every train step) they are DTensors; the frozen target_weight is a plain replicated buffer (GPT2) or a sharded param (Llama under shard_frozen_target). The vendored calc_weight_deltas subtracts them directly -> 'aten.sub got mixed torch.Tensor and DTensor'. Override it in the FSDP adapter to gather each operand via DTensor.full_tensor() before subtracting. The full per-site weight is materialised anyway (faithfulness needs the whole delta), so no asymptotic memory cost. Validated on a 2-GPU GPT2 repro: calc_weight_deltas + faithfulness loss + backward all succeed. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
residual-start runs the embedding on the raw batch via use_cached_residual, opened before run_loss_step does its own move_batch_to_device; the loader returns CPU tensors, so the embedding index_select mixed cpu ids with the cuda weight. Move the batch to device in the trainer loop before use_cached_residual (run_loss_step's move is then a cheap no-op). Mirrors the 3-pool path, which moves batch_local before use_cached_residual. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
…exist PPGD's before_backward runs torch.autograd.grad(retain_graph=True) through the compiled region; AOTAutograd's donated-buffer optimization asserts retain_graph=False on every backward of a compiled graph, raising 'compiled with non-empty donated buffers which requires create_graph=False and retain_graph=False'. Disable torch._functorch.config.donated_buffer when a PPGD loss is configured and compile is on. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
…tter mismatch Diagnostic: the compiled GPT2 backward hit 'size of tensor a (512) must match b (2048)' in FSDP foreach_reduce. This config disables compile to isolate whether the reduce-scatter mismatch is a compile+FSDP2 grad interaction or a pure sharding issue. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
The FSDP2 main-step blocker was a gradient-placement mismatch in calc_weight_deltas. The prior `full_tensor()` (= redistribute(Replicate).to_local()) detached the DTensor placement, so the faithfulness path emitted Shard(1)/Replicate grads for the components' V/U while the recon forward emits Shard(0); accumulating both into one param made FSDP2's gradient reduce-scatter fail (C vs C/world dim mismatch). Fix: redistribute V and U to Replicate (KEEPING them DTensors — no .to_local()) before the einsum, and subtract a Replicate target. redistribute is differentiable and its backward returns the grad in the source Shard(0) placement, so the faithfulness grads now match the recon path's. Verified on a 2-GPU repro: V.grad/U.grad both Shard(0), bit-matching the canonical recon placement (vs Shard(1)/Replicate for the old path). Also exclude **/_scratch/** from pyright + ruff (throwaway multi-GPU repro dir). Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
… wandb short-name table) (#849) test_validate_wrapper_rejects_unexpected_key: the validator's assertion message was reworded ("keys must include ...") but the test's pytest.raises regex still matched the old "keys must be". Stale test; update the regex. The validation itself is correct. test_config_short_name_table_matches_metric_registry: METRIC_SHORT_NAMES carried a stale "PersistentPGDReconSubsetLoss" entry whose metric class was removed in #824 (PPGD+subset removal). Real drift in the torch-free table; drop the dead entry so it matches the registry-derived names. Co-authored-by: Claude Opus 4.8 <noreply@anthropic.com>
…h) (#850)
Clustering already reads JAX runs natively via run_worker_jax.py. Shed torch from the whole subsystem:
- Delete torch-worker path superseded by run_worker_jax + run_merge: run_harvest.py, run_clustering.py, run_pipeline.py, calc_distances.py, plotting/ (only used by those), clustering_run_config.py, and the orphaned configs/ for the deleted ensemble pipeline.
- Convert the membership accumulator + merge stack to numpy: MembershipBuilder, flatten_lm_activations, GroupMerge/BatchedGroupMerge, compute_costs, merge, sample_membership coactivation, merge_pair_samplers (numpy + stdlib random), merge_history, matching_dist, the dense test-oracle path, and the type aliases.
- Drop the dense `preview` mechanism (only the deleted run_clustering consumed it).
- run_worker_jax feeds numpy directly and imports zero torch.
- Fail-fast fix surfaced by numpy strictness: MergeHistory stored group/pair indices as int16, which overflows above 32767 components (torch silently truncated). Now int32.
- Remove dead pd-clustering / pd-cluster-harvest console scripts (keep pd-cluster-merge); harvest is `python -m ...run_worker_jax`.
Acceptance: grep for torch in param_decomp_lab/clustering is empty. Clustering tests (58) green, make type clean, pre-commit passes. Validated e2e on JAX run p-761bc061: harvest -> membership snapshot -> merge -> history.zip.
Co-authored-by: Claude Opus 4.8 <noreply@anthropic.com>
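The int16 failure mode is easy to reproduce without any array library. The sketch below uses `ctypes` fixed-width integers, which wrap silently, as a stand-in for the silent truncation the commit attributes to the old path; it makes no claim about how torch or NumPy handle the conversion internally, and the helper names are hypothetical.

```python
import ctypes

INT16_MAX = 2**15 - 1  # 32767, the largest index a signed 16-bit slot can hold


def store_as_int16(index: int) -> int:
    """Value read back after storing `index` in a signed 16-bit slot (wraps, no error)."""
    return ctypes.c_int16(index).value


def store_as_int32(index: int) -> int:
    """Same round trip through a signed 32-bit slot."""
    return ctypes.c_int32(index).value


# Usage: a component index just past the int16 range comes back negative,
# while the int32 slot preserves it.
wrapped = store_as_int16(40000)
preserved = store_as_int32(40000)
```

A wrapped index is still a valid-looking integer, so downstream merges would silently reference the wrong group; widening the storage removes the hazard for any realistic component count.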
… builder) (#851) Replace the NotImplementedError TargetConfig branch in load_run.build_target with the llama8b build mirroring run.py::main: llama_decomposed_lm over llama_site_specs, replicate_target(load_target_from_hf(...)) for the frozen suffix, device_put(load_prefix_from_hf(...), P()) for the prefix, prefix_residual as the residual fn, and vocab_size from llama31_8b_config(). Returns the same (lm, target, prefix, prefix_residual_fn, vocab_size) tuple as the SimpleMLP branch, so open_jax_run / LoadedJaxRun.forward stay target-agnostic. Promote the builder to public build_target and route slow_eval through it, deleting its private SimpleMLP-only _build_simple_mlp. One loader now covers both targets; the SimpleMLP path's restore/forward behavior is unchanged. Unblocks jsp-slow-eval (and the app) on llama8b runs. Co-authored-by: Claude Opus 4.8 <noreply@anthropic.com>
…amends SPEC S31) (#853) Port CIHiddenActsReconLoss + StochasticHiddenActsReconLoss to JAX as standalone offline slow-eval metrics — explicitly NOT recon-grid terms. The parameterized recon loss stays KL-on-final-logits only (SPEC §2.3-2.5); these are eval diagnostics. Deliberate, Oli-approved amendment of SPEC S31, which had marked the hidden-acts seam "keep-on-bridge / refused". They now have a native JAX eval path. - Add a fifth DecomposedModel fn `masked_site_outputs(...) -> dict[site, (B,T,d_out)]`, factored out of `masked_output` via a shared `_run_masked_suffix` + a per-site `collect` dict in `_masked_site_out` (per-site output is an intermediate of the masked forward — no logic duplicated). Implemented for both targets (llama8b, llama_simple_mlp). - Clean (target) per-site output = frozen `x @ W`, obtained from the same seam by routing FALSE everywhere (`_site_out`'s frozen branch) — no separate frozen-W accessor. - New `hidden_acts_eval.py`: per-site MSE(masked, clean) reduction="sum", host-accumulated as (Sum sum_mse, Sum n_elements) (token-weighted, exact under micro-batching), divided once at compute. Log keys mirror torch: "<ClassName>/<site>" + combined "<ClassName>". Masked+clean in COMPUTE_DT (bf16), MSE reduction fp32. CIHiddenActsReconLoss = deterministic lower_leaky CI, no delta, one forward; StochasticHiddenActsReconLoss = n_mask_samples stochastic CI draws + deltas. Stochastic draws are not seed-aligned to torch (exact bitwise parity impossible there, expected). - Wire both into jsp-slow-eval via SlowEvalOutput(figures, hidden_acts); scalars logged under slow_eval/loss/<key> + written to hidden_acts_recon_step<N>.json. - Per-site unit test on SimpleMLP; SPEC S31 amended; CLAUDE.md/README updated. Validated: full JAX suite green at 1 and 4 sim devices (165 each); check-jax clean; equivalence/stacked-parity/simple_mlp_equivalence goldens untouched. Ran on p-761bc061 (24 sites): finite, non-negative per-site, combined within [min,max]. 
Co-authored-by: Claude Opus 4.8 <noreply@anthropic.com>
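The "(Sum sum_mse, Sum n_elements), divided once at compute" accumulation is what makes the metric exact under micro-batching. A minimal plain-Python sketch of that idea follows; the class name `SumMseAccumulator` and the flat-list inputs are assumptions, not the contents of hidden_acts_eval.py.

```python
from dataclasses import dataclass


@dataclass
class SumMseAccumulator:
    """Accumulate sum-reduced squared error and element counts; divide once at the end."""

    sum_sq_err: float = 0.0
    n_elements: int = 0

    def update(self, masked, clean) -> None:
        assert len(masked) == len(clean)
        # reduction="sum": no per-batch averaging happens here.
        self.sum_sq_err += sum((m - c) ** 2 for m, c in zip(masked, clean))
        self.n_elements += len(masked)

    def compute(self) -> float:
        assert self.n_elements > 0
        return self.sum_sq_err / self.n_elements


# Usage: one full batch versus two uneven micro-batches of the same data.
full = SumMseAccumulator()
full.update([1.0, 2.0, 3.0, 5.0], [0.0, 0.0, 0.0, 0.0])

micro = SumMseAccumulator()
micro.update([1.0], [0.0])
micro.update([2.0, 3.0, 5.0], [0.0, 0.0, 0.0])
```

Averaging per micro-batch and then averaging the averages would weight the one-element batch as heavily as the three-element one; carrying the two sums avoids that bias.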
…ew pass) (#854) Latent bug fix: - clustering/sample_membership.py: coactivation matmul cast int32 -> int64 (X.T @ X silently overflows above ~2.1B samples; same class as the int16 MergeHistory bug). Downcast to float32 right after is unchanged. Type hygiene: - app/backend/state.py: DatasetSearchState now holds typed DatasetSearchResult / DatasetSearchMetadata (moved to schemas.py) instead of list[dict[str,Any]] / dict[str,Any]; router builds typed objects directly. - clustering/merge_history.py: meta: dict[str,Any]|None -> typed MergeHistoryMeta(origin_path); removes the defensive isinstance(.,Path) scan. Dead code / honesty: - app/backend/dependencies.py: drop no-op try/except around StateManager.get(). - harvest/scripts/run_worker_jax.py: one host copy in _to_torch (np.array); single-source --batch_size / --activation_threshold defaults from the configs; activation_threshold read from method_config (no threaded duplicate). - harvest/reservoir.py: delete stale TODO on a load-bearing pad mask. - clustering/merge_history.py: delete dead TODOs + commented-out assert block. - clustering/merge.py: warnings.warn for expected terminus -> logger.info. Perf nit: - app/backend/inference.py: gather next-token probs in JAX before host transfer (avoid materializing the full (T, vocab) output_probs to read ~T floats). Co-authored-by: Claude Opus 4.8 <noreply@anthropic.com>
Behavior-preserving cleanups from a review pass over the JAX trainer core. Equivalence + stacked-parity goldens untouched and green. - load_run.py build_target: the two return Anys -> AnyFrozenTarget / AnyPrefix (new shared target_aliases.py, imported by both build_target and run.py::main, which dispatch on the identical match cfg.target). prefix_residual_fn keeps its Callable[[Any, Any], Array] (legit generic input edge). - load_run.py: inner forward's ci_fn Any -> CIFn; LoadedJaxRun._forward Any -> the honest Callable. - hidden_acts_eval.py / slow_eval.py: sampling: str -> SamplingType Literal, threaded through compute_hidden_acts_metrics; dropped the now-redundant runtime assert sampling in (...). recon.py StochasticSources.sampling + build_recon_terms also use SamplingType; build_recon_terms loss_metrics narrowed tuple|list -> tuple. - slow_eval.py: max(n_positions, 1) -> assert n_positions > 0. - recon.py: instance_key -> assert_unique_instance_key (reflects the uniqueness assert). - run.py: assert int(state.step) == ckpt_step after restore (codifies S22); _ensure_global made PEP-695 generic [T](tree, mesh) -> T, dropping the call-site isinstance assert. - config.py: _resolve_target uses the existing AnyTargetConfig alias (was a quoted forward-ref). Skipped (behavior-affecting / non-trivial, noted in review): losses.py annealed_pnorm span assert (would break the documented start_frac==end_frac==1.0 no-op config); build_experiment_config wandb default drop (several test call sites rely on the defaults). Co-authored-by: Claude Opus 4.8 <noreply@anthropic.com>
…alCiConfig (review pass) (#856) Schema-package code-review fixes (all in param_decomp_config/, plus consumer updates): Fail-fast: - EvalConfig: model_validator enforcing the documented slow_every % every == 0 - PGDConfig: step_size -> PositiveFloat, n_steps -> NonNegativeInt (zero = the no-ascent fresh-sample baseline, which the JAX adversary test exercises) Narrow types -> Positive*: - eval_metrics: n_heads, AutointerpLabels k/max_examples/context_tokens_per_side/ seq_len -> PositiveInt; CIHistograms n_batches_accum -> PositiveInt | None - autointerp: strategy max_examples / label_max_words -> PositiveInt Encode invariants in types: - GlobalCiConfig: replaced the two-independently-optional-fields-gated-by-a-validator shape with a fn_type-discriminated union of GlobalSharedMlpCiConfig (hidden_dims) vs GlobalSharedTransformerCiFnConfig (transformer cfg), nested under the mode discriminator. Impossible states unrepresentable; validator deleted. A BeforeValidator drops the inactive hidden_dims/simple_transformer_ci_cfg None keys that old single-class configs wrote, so stored runs still parse. Readers updated: ci_fns.py, component_model_io.py, jax_single_pool/config.py, test fixtures. - IdentityCIErrorConfig.identity_ci/dense_ci: list[dict[str, str|int]] -> typed IdentityCITargetSpec / DenseCITargetSpec; make_target_ci_solution + caller updated - PermutedCIPlotsConfig / UVPlotsConfig: shared _PermutationPlotsBaseConfig Honesty: torch-Trainer/.pth docstrings -> orbax ckpts/<step>/ + jsp-train (pd.py Cadence/keep_last_n_checkpoints/PDConfig, schedule.py, experiment.py). Surfaced (not forced): the two PGD scope vocabularies (MaskScope string-literal vs PersistentPGDSourceScope object-union) stay deliberately separate — converging them would change stored YAML shape; added a cross-reference comment explaining why. No new type: ignore / Any / cast; param_decomp_config stays torch-free. Co-authored-by: Claude Opus 4.8 <noreply@anthropic.com>
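Two of the ideas in this commit, a fail-fast cadence check and a discriminated union that makes invalid combinations unrepresentable, can be sketched with standard-library dataclasses. The real schema uses pydantic (model_validator, BeforeValidator); the class names `EvalCadence`, `SharedMlpCi`, `SharedTransformerCi`, the `n_layers` field, and `parse_global_ci` below are hypothetical stand-ins.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class EvalCadence:
    every: int
    slow_every: int

    def __post_init__(self) -> None:
        # Fail at construction time rather than at the first misaligned eval step.
        if self.every <= 0 or self.slow_every <= 0:
            raise ValueError("every and slow_every must be positive")
        if self.slow_every % self.every != 0:
            raise ValueError("slow_every must be a multiple of every")


@dataclass(frozen=True)
class SharedMlpCi:
    hidden_dims: list


@dataclass(frozen=True)
class SharedTransformerCi:
    n_layers: int


_VARIANTS = {"mlp": SharedMlpCi, "transformer": SharedTransformerCi}


def parse_global_ci(raw: dict):
    """Pick the variant from fn_type; each variant accepts only its own fields."""
    # Drop None-valued keys that an older single-class layout may have written.
    fields = {k: v for k, v in raw.items() if v is not None}
    variant = _VARIANTS[fields.pop("fn_type")]
    # A field belonging to the other variant raises TypeError here.
    return variant(**fields)


# Usage: an old-style record with an inactive None key still parses.
cfg = parse_global_ci({"fn_type": "mlp", "hidden_dims": [16], "n_layers": None})
```

Because each variant only has its own fields, there is no state in which both or neither payload is set, so no validator is needed to police it.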
…omp/ (#857)
Removes four top-level helpers that the retired torch trainer used and that nothing live imports after the JAX-trainer pivot:
- grad_clip.py (cross_pool_clip_grad_norm; n-pool lineage, dead per PARITY_MATRIX)
- sdpa_strict.py (verify_flash_attention_available)
- _trace.py (trace / dump_memory_stats)
- phase_timer.py (PhaseTimer / phase / format_phase_table)
Each verified: zero importers across param_decomp/, param_decomp_lab/, param_decomp_jax/, configs, and tests (symbol- and module-path searches; no dynamic imports; bare __init__). The metrics/ tree, train_step.py, training_state.py, distributed.py, torch_helpers.py, log.py and the bridge substrate stay; all are live via dispatch / offline-eval / eval_metrics.
Co-authored-by: Claude Opus 4.8 <noreply@anthropic.com>
…pe shape contract (#860) Generalize the trainer core from a fixed (B,T) activation waist to a generic [*leading, d] waist (masks/CI [*leading, C]), where leading = (batch,) + named position axes. CI is independent over every leading axis (broadcast_ci dropped), so there is no per-axis CI semantics — only axis NAMES. - DecomposedModel gains leading_axes; CIFn gains expects_axes; init_train_state asserts they're equal (early fail) so the CI fn stays per-domain (RoPE over `sequence`) without the generic loop adapting. ("sequence",) for LM. - Keystone train.py: leading = residual.shape[:-1], threaded to routing samplers, fresh-PGD source init, stochastic delta-mask shapes (leading_shape: tuple[int,...]). - losses.py reductions: math.prod(shape[:-1]) / axis=tuple(range(ndim-1)). - recon.py RoutingSampler + bodies: tuple[int,int] -> tuple[int,...]. - adversary.py init_persistent_sources takes leading_shape; init_fresh_pgd_sources spells c/bc/bsc over the model's leading. llama8b_sharding builds the per-scope leading_shape (sc->(1,T), bsc->(B,T)). - eval/slow_eval/hidden_acts_eval: same shape[:-1]/prod generalizations. - lm.py type aliases -> *leading. - jaxtyping+beartype: @jaxtyped(typechecker=beartype) on the core step, masked_forward (residual/masks/routes bind *leading), and the loss fns — the waist contract enforced at trace time. New dep beartype==0.22.2. - AXIS_SEMANTICS_DESIGN.md documents the contract + designed (unimplemented) extensions (scope as leading-axis subset; tying as shared vu leaf). LM byte-identical: equivalence + stacked-parity goldens UNMODIFIED and passing; only test call-sites adapt to the renamed source-init signature (arrays identical). Full suite green at 1 and 4 devices; check-jax clean. TMS not ported; scope stays LM's c/sc/bsc. Co-authored-by: Claude Opus 4.8 <noreply@anthropic.com>
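The `[*leading, d]` contract reduces to a few shape computations that no longer assume exactly two leading axes. The sketch below shows them on plain shape tuples; the helper names and the axis-name check are assumptions modelled on the description (leading = residual.shape[:-1], math.prod(shape[:-1]), axis=tuple(range(ndim-1)), leading_axes equal to expects_axes), not the trainer's code.

```python
import math


def leading_shape(shape: tuple) -> tuple:
    """Everything except the final feature axis: (batch, *position_axes)."""
    return tuple(shape[:-1])


def n_leading_elements(shape: tuple) -> int:
    """Number of independent positions a loss is averaged over."""
    return math.prod(shape[:-1])


def reduction_axes(ndim: int) -> tuple:
    """Axes to reduce over so that only the last axis survives."""
    return tuple(range(ndim - 1))


def check_axes(model_leading_axes: tuple, ci_expects_axes: tuple) -> None:
    """Early failure if the model and the CI function disagree on axis names."""
    if model_leading_axes != ci_expects_axes:
        raise ValueError(
            f"model has {model_leading_axes}, CI fn expects {ci_expects_axes}"
        )


# Usage: an LM residual (B, T, d) and a positionless toy input (B, d).
lm = leading_shape((4, 16, 64))
toy = leading_shape((4, 64))
check_axes(("sequence",), ("sequence",))
```

The same three helpers cover the LM case and a positionless model, which is the property the later TMS and ResidMLP targets rely on.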
…g,d] core (#861) Third decomposition target and the first positionless one (leading_axes=()), the proof the generic waist fits a non-LM model. The core needed ZERO change — only an added target (tms.py), CI-fn arch (ci_fn_mlp.py: layerwise per-site MLP, expects_axes=()), and config dispatch. - param_decomp_config/tms.py: torch-free TMSExperimentConfig/TMSTargetConfig/TMSDataConfig. - jax tms.py: Anthropic toy (tied weights) with an UNTIED 2-site decomposition, recon_loss_fn=tms_mse (post-ReLU MSE, not KL), no prefix, pretrained in-process. - ci_fn_mlp.py: LayerwiseMLPCIFn (vector-input per-site MLP, fits the ci_fn(site_inputs) waist; recovers identity in the non-superposed regime — validated in test_tms.py). - config dispatch: _is_tms_schema routes to build_tms_experiment_config; TMS*Config join AnyTargetConfig/AnyDataConfig; run.py builds+pretrains target and calls train_tms. - harvest/clustering run_worker_jax: assert isinstance(data, DataConfig) before reading .dir/.seq_len — AnyDataConfig now includes TMSDataConfig; these workers are LM-only (#12 localize-and-assert; caught by the torch-side make type). Co-authored-by: Claude Opus 4.8 <noreply@anthropic.com>
…erpLabels (#859) * shed(harvest): numpy accumulator + delete dead torch worker & AutointerpLabels Make the harvest subsystem fully torch-free. Convert the accumulator and its shared state to NumPy: - Harvester / ActivationExamplesReservoir / sampling / storage / analysis / HarvestBatch are NumPy. Counts, component co-occurrence, and input token cooccurrence/marginals use int64 (overflow-safe over a long stream); probability-mass accumulators and activation sums use float64. - Input token cooccurrence accumulates via np.add.at over np.nonzero(firing) entries (parallel 1-D integer indices) instead of torch scatter_add_. - Reservoir sampling RNG is an explicit np.random.Generator threaded from the Harvester (was a torch.Generator | None default). - Tensor artefacts move from torch.save .pt to np.savez .npz (component_correlations.npz, token_stats.npz, worker_states/worker_*.npz). Harvest output is regenerable intermediate, so no legacy shim. run_worker_jax.py is now the only worker and drops torch: np.asarray off the JAX forward straight into the NumPy Harvester. It gains --rank/--world_size sharding (via ShardServer process_index/process_count) + worker_states save + get_command, so the SLURM launcher (run_slurm.py) drives it as a rank-sharded array + dependent merge, exactly as the old torch worker did. Delete dead torch code: - harvest/scripts/run_worker.py + harvest/harvest_fn/* (torch worker path, superseded by the JAX worker). pipeline.harvest() went with it; merge_harvest stays (now NumPy). - eval_metrics/autointerp_labels.py + its config, registry entry, wandb short-name, and test. Verified dead: in zero eval configs, its async_eval driver does not exist, reachable only by its own definition + registry. Acceptance: grep for torch imports under param_decomp_lab/harvest is empty. 
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com> * chore(harvest): drop deleted harvest_fn from lab packages list harvest_fn was deleted with the dead torch worker; the setuptools packages list still referenced it. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com> --------- Co-authored-by: Claude Opus 4.8 <noreply@anthropic.com>
…atch (#862) JAX counterparts of torch CIMaskedAttnPatternsReconLoss / StochasticAttnPatternsReconLoss, run IN-LOOP over the eval pass's residuals (a few forwards, not minutes-slow), removed from OFFLINE_EVAL_METRIC_TYPES. - attn_patterns_eval.py: KL(target_pattern ‖ masked_pattern) per attention layer over every (B,H,T_query) distribution; matches torch F.kl_div(masked.clamp(1e-12).log(), target, reduction="sum"), divided once by Σ n_distributions. Combined + per-q_proj-site log keys mirror torch. - Q/K via the existing masked_site_outputs seam (clean = all-false routes → frozen x@W, the hidden_acts_eval trick); DecomposedModel gains NO new method. - attn_pattern_for(target): the RoPE/GQA/head-dim recipe dispatched on the concrete frozen target, reusing the target's OWN vendored RoPE + FrozenAttn config — NOT a method on DecomposedModel (the universal interface stays LM-agnostic). Non-attention target raises. - config: AttnPatternsEvalConfig (ci_masked/stochastic flags); _assert_separate_qk_attn_paths refuses combined c_attn (no JAX target produces a merged-QKV site). - run.py: build attn steps when configured; accumulate token-weighted KL into eval_record. Co-authored-by: Claude Opus 4.8 <noreply@anthropic.com>
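The reduction described here, a sum of KL(target ‖ masked) over attention distributions with the masked pattern clamped at 1e-12 and a single division by the number of distributions, can be written out in plain Python. This is a sketch of the arithmetic only, with rows of probabilities as lists; the function names are hypothetical and it is not the code in attn_patterns_eval.py.

```python
import math


def kl_sum(target_row, masked_row, eps: float = 1e-12) -> float:
    """sum_i p_i * (log p_i - log max(q_i, eps)), with zero-probability targets contributing 0."""
    total = 0.0
    for p, q in zip(target_row, masked_row):
        if p > 0.0:
            total += p * (math.log(p) - math.log(max(q, eps)))
    return total


def mean_pattern_kl(target_rows, masked_rows) -> float:
    """Sum the per-distribution KLs, then divide once by the number of distributions."""
    assert len(target_rows) == len(masked_rows) and len(target_rows) > 0
    total = sum(kl_sum(p, q) for p, q in zip(target_rows, masked_rows))
    return total / len(target_rows)


# Usage: one query row that matches exactly and one that differs.
targets = [[0.5, 0.5], [0.5, 0.5]]
masked = [[0.5, 0.5], [0.25, 0.75]]
score = mean_pattern_kl(targets, masked)
```

The clamp keeps the result finite when the masked pattern assigns zero probability to a key the target attends to, at the price of a large but bounded penalty.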
…orch models (#863) These were paper-comparison-only tooling (compare PD against CLTs/transcoders) with no live consumer in the JAX stack. The torch harvest_fn dispatch they fed was already removed when harvest went JAX-native (#859); only the adapter + config + vendored-model surface remained. Deleted: - adapters/{clt,transcoder}.py and adapters/_vendor/ (CrossLayerTranscoder, BatchTopKTranscoder) - CLTHarvestConfig / TranscoderHarvestConfig + the DecompositionMethodHarvestConfig union (collapsed to ParamDecompHarvestConfig — the sole remaining method) Kept (PD run-loaders, NOT comparison adapters): adapters/pd.py (PDAdapter), adapters/jax_pd.py (JaxPDAdapter), adapters/base.py. Collapsed autointerp DecompositionMethod Literal to "pd" and dropped the clt/transcoder prompt descriptions (only those adapters ever set the non-pd values). Removed the adapters._vendor package entry from pyproject. Co-authored-by: Claude Opus 4.8 <noreply@anthropic.com>
The fourth DecomposedModel and second non-LM bundle over the generic positionless (leading_axes=()) core: the SPD/APD residual-stream toy. Fixed W_E/W_U embeddings, n_layers MLP blocks, decomposition over the per-layer mlp_in/mlp_out matrices (untied V/U). W_E is the prefix (residual = x @ W_E); recon is MSE on the model output (not KL). Pretrained from scratch in-process on the read-off act_fn(x)+x objective. Reuses LayerwiseMLPCIFn (no new CI arch) and the identity-CI ground-truth metric.
Core needed ZERO change (train/losses/recon/lm): this only ADDED a target + config/run dispatch, exactly like TMS (PR #861); ResidMLPTarget joins AnyFrozenTarget. _is_tms_schema marker moved n_features -> n_hidden so it stays disjoint from ResidMLP's d_embed marker.
Validation: jax basedpyright clean; test_resid_mlp (14, incl slow e2e identity recovery) + test_tms + test_config + equivalence goldens green; torch make type clean; make test 407 passed.
Co-authored-by: Claude Opus 4.8 <noreply@anthropic.com>
The last LM eval metric (attn-patterns) moved to JAX in-loop (#862) and the CLT/transcoder adapters were deleted (#863), so the torch offline-eval bridge is now dead. Slow/plot eval is JAX-native via `jsp-slow-eval`; parity with `pd-offline-eval` was validated on llama8b `p-4301638c` (PASS) before this delete.
Deleted (solely torch-offline-eval):
- `param_decomp_lab/experiments/lm/offline_eval.py` (pd-offline-eval entry)
- `param_decomp/train_step.py` (its only live consumer was offline_eval)
- `jax_single_pool/export.py` (jsp-export; harvest/clustering read orbax via `load_run.open_jax_run`, never the safetensors export)
- `jax_single_pool/slurm/offline_eval_once.sbatch` (+ the now-empty slurm/ dir)
- `jax_single_pool/tests/test_export*.py` + `tools/export_fixtures/*`
- `torch_helpers.loop_dataloader` (offline_eval-only iterator)
Also: drop `pd-offline-eval` + `jsp-export` console scripts, the `offline_eval_submission_argv`/`_submit_offline_eval` push-trigger in run.py (+ its test + `subprocess` import), the stale `verify_export_torch.py` type-check exclude, and dead torch-bridge phrasing in docstrings/CLAUDE/README.
Kept (live non-eval consumers): the `eval_metrics/` torch Metric impls + `EVAL_METRIC_CLASSES` (their `wandb_config_dict`/`metric_short_names` plumbing is reused by the JAX trainer via run_sink), `OFFLINE_EVAL_METRIC_TYPES` (a live config-validation gate), `bf16_autocast`, `target_aliases.py`, and `convert_llama_simple_mlp_checkpoint.py` (a pretrain converter, not export).
Validation: make type (0), make test (407 passed/4 skipped), JAX basedpyright (0), equivalence goldens (12 passed, trajectory bit-identical), make test-jax (174 passed/1 skipped).
Co-authored-by: Claude Opus 4.8 <noreply@anthropic.com>
#866)
Generalize the localize-and-assert pattern (ci_fn↔target expects_axes in init_train_state; attn_pattern_for's non-attention raise) to the LM-only in-loop eval metric constructors, which consumed tokens/vocab/attention without declaring their leading-axes compatibility.
- eval.make_eval_step (CEandKLLosses / CI_L0 / fresh-PGD probe): reads tokens + vocab logits and a (1,1,C+1) source over a (batch, sequence) waist, so it asserts lm.leading_axes == ("sequence",) at construction.
- attn_patterns_eval: shared _assert_attention_sequence_axes guard on both make_ci_attn_patterns_step / make_stochastic_attn_patterns_step (causal (B,H,T,T) maps over a sequence axis), complementing attn_pattern_for's per-target dispatch.
Domain-generic metrics (hidden-acts MSE, slow-eval CI plots) correctly stay unguarded. Data sources already fail loud: train() / train_tms / train_resid_mlp assert their DataConfig subtype, ShardServer asserts shard width vs seq_len, the harvest/clustering workers assert isinstance(data, DataConfig).
Additive only; generic core (train/losses/recon/lm) unchanged; LM trajectory bit-identical (full JAX suite + goldens green). New guard tests assert each constructor raises against a positionless (leading_axes=()) TMS target.
Co-authored-by: Claude Opus 4.8 <noreply@anthropic.com>
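A minimal sketch of the construction-time guard described above, assuming a target object that exposes a `leading_axes` tuple. `ToyTarget` and `make_lm_eval_step` are hypothetical stand-ins for illustration, not the trainer's real classes or constructors.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ToyTarget:
    """Hypothetical stand-in for a frozen target; only leading_axes matters here."""
    leading_axes: tuple


def make_lm_eval_step(target):
    """Build an LM-only eval step, refusing incompatible targets up front."""
    # Fail at construction, not deep inside the first eval step.
    assert target.leading_axes == ("sequence",), (
        f"LM-only metric needs leading_axes == ('sequence',), got {target.leading_axes!r}"
    )

    def step(batch):
        # Placeholder body: count tokens in a (batch, sequence) nested list.
        return {"n_tokens": sum(len(row) for row in batch)}

    return step
```

A positionless target such as `ToyTarget(())` is rejected before any step function exists, which mirrors what the new guard tests check against a TMS target.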
Superseded by the JAX targets (param_decomp_jax/jax_single_pool/tms.py,
resid_mlp.py). The torch experiment dirs had no external importers or console
scripts — the torch training drivers were already retired.
- Delete param_decomp_lab/experiments/{tms,resid_mlp}/ (models/data/run/train/
feature_importances + their YAML configs)
- Delete their tests: tests/test_data.py (SparseFeatureDataset),
tests/test_feature_importances.py (compute_feature_importances)
- Drop the two package entries from param_decomp_lab/pyproject.toml
- Correct CLAUDE.md / README references: experiments bridge is now LM-only;
TMS/ResidMLP live as JAX targets
Kept: param_decomp_config/{tms,resid_mlp}.py (torch-free schemas the JAX
trainer reads); param_decomp_lab/toy_models/target_ci.py (live importers in
eval_metrics + tests); experiments/utils.py (live importers in lm/adapters/app).
Co-authored-by: Claude Opus 4.8 <noreply@anthropic.com>
)
Remove `param_decomp_lab/app/` (FastAPI backend + Svelte frontend) to shed torch/web surface ahead of the JAX-primary merge. This is a deliberate remove-now-re-add-later decision; a JAX-native viewer is slated to replace it. The PR body carries the full re-add log.
Kept-consumer fallout handled: harvest and autointerp imported the app's tokenizer-display helpers, so `AppTokenizer`, `escape_for_display`, and `delimit_tokens` are relocated to `param_decomp_lab/tokenizer_display.py` (out of the app, no behavior change). Their test moves to `tests/test_tokenizer_display.py`.
Also removed: `make app`/`install-app`/`check-app`/`install-all` targets, the `fastapi`/`uvicorn` lab deps (app-only; orjson kept because it is shared), the app package + package-data entries in `param_decomp_lab/pyproject.toml`, the `app/frontend` ruff/pyright excludes, and the CI `build-frontend` job. CLAUDE.md pointers and stale app mentions updated; a temporary-removal note added to the root CLAUDE.md.
Out of scope and intentionally untouched: `investigate/scripts/run_agent.py` still subprocess-launches the app backend, so `pd-investigate` is broken at runtime until the app returns (no import/type/test break). See PR re-add log.
Co-authored-by: Claude Opus 4.8 <noreply@anthropic.com>
…yout (#870)
The frozen C49k clone (jax-l18-C49k-200k, Llama-3.1-8B layer-18 MLP, C=49152) saved a TrainState whose orbax pytree keys differ from HEAD's in three places. tools/migrate_c49k_checkpoint.py restores the OLD tree single-device on CPU, remaps it onto the current layout, and save_states it so a fine-tune can restore_latest from it.
Remap:
- components.{Vg,Ug,Vu,Uu,Vd,Ud} (1,*,*) -> components.vu[<site>][0|1] (*,*): g->gate_proj, u->up_proj, d->down_proj; V->[0], U->[1]; squeeze leading 1.
- components_opt_state mu/nu: same squeeze; optax count scalars copy through.
- ci_fn / ci_fn_opt_state: identical leaf names -> straight copy.
- sources.<site> -> sources.<state_key>.<site> (state_key PersistentPGDReconLoss)
- sources_adam_state.{m,v}/step_count -> sources_opt_state.<state_key>.{...}
- step: copied (restores as 175000).
The reference TrainState is built abstract (jax.eval_shape) so the full fp32 reference never coexists with the 47 GB restored tree.
Verified independently: clean restore (no structure/shape/dtype mismatch), step==175000, component shapes + finiteness, sources in [0,1], CI-fn forward finite on a dummy batch.
Co-authored-by: Claude Opus 4.8 <noreply@anthropic.com>
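The component part of the remap can be illustrated with nested lists standing in for arrays. This is a hedged sketch of the key mapping only. The constant name `KIND_TO_SITE_SUFFIX` is the one the verifier commit attributes to the migration tool, but its contents here are inferred from the g/u/d description, and `FACTOR_INDEX`, `remap_components`, and `site_prefix` are hypothetical.

```python
# Assumed contents, inferred from "g->gate_proj, u->up_proj, d->down_proj".
KIND_TO_SITE_SUFFIX = {"g": "gate_proj", "u": "up_proj", "d": "down_proj"}
FACTOR_INDEX = {"V": 0, "U": 1}  # V -> [0], U -> [1]


def remap_components(old, site_prefix):
    """Map {'Vg': leaf, 'Ug': leaf, ...} onto {'vu': {site: [V, U]}}.

    Each old leaf carries a leading axis of size 1 (modelled here as a
    one-element list), which is squeezed away during the remap.
    """
    vu = {}
    for key, leaf in old.items():
        factor, kind = key[0], key[1]
        assert len(leaf) == 1, f"{key}: expected a leading axis of size 1"
        site = f"{site_prefix}.{KIND_TO_SITE_SUFFIX[kind]}"
        vu.setdefault(site, [None, None])[FACTOR_INDEX[factor]] = leaf[0]
    # Every site must end up with both factors filled in.
    assert all(None not in pair for pair in vu.values()), "missing V or U"
    return {"vu": vu}
```

Keeping the mapping in two small tables makes it easy for a separate verifier to invert the same constants instead of restating the rule.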
…m probe (#871)
* feat(jax): l18-26 9-layer chunkwise config + abstract-AOT mem probe
New from-scratch decomposition of Llama-3.1-8B MLP layers 18-26 (27 targets = 9 layers x gate/up/down, C=49152), chunkwise recon (3 chunks of 9 sites, coeff 1.5), remat on. Derived from the C49k single-layer-18 run.
AOT-probed per-device HBM at 64 B200 (remat on, no buffer donation, matching the trainer's non-donating @jax.jit step):
- B=128 (2/device): 210.2 GiB -> over the 180 GiB cap
- B=64 (1/device): 179.0 GiB -> at the cliff, ~0 headroom
B=128 does not fit, so the config carries B=64 with LRs 4th-root-scaled to ci_fn 4.2e-5 / components 1.26e-4.
mem_probe.py rewritten to lower+compile on ABSTRACT sharded avals (shapes via jit(init).lower().out_info, GSPMD shardings re-attached per leaf) instead of eagerly materializing the state: the 27-site C=49152 state is ~360 GB and OOMs the eager init's jit__identity_fn (128 GiB replicated V/U) at 64 GPU. Adds --sites_per_chunk / --recon_coeff / --last_layer args and the 64-GPU probe sbatch.
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
* feat(jax): finalize l18-26 9-layer R&D config (seq 512, B=128, 40k) + seq-512 probe
seq 2048→512 is what makes 27 sites fit on 64 GPU: AOT probe (seq 512, remat on) = 144.8 GiB (B=64) .. 150.3 GiB (B=256)/device, all under the 180 GiB B200 cap (seq 2048 was 179 GiB at B=64, at the cliff).
Final config: layers 18-26 (27 sites, C=49152), seq 512 / fineweb_llama_tok_512, B=128, 40k steps; comp 1.5e-4 / ci_fn 5e-5 (B=128 precedent); chunkwise recon sites_per_chunk=9 coeff 1.5; imp-min eps 1e-12→1e-6 (fractional-pnorm grad stability), coeff/beta/pnorm unchanged; PGD/faith unchanged; ci max_len→512. mem_probe gains --seq; seq-512 sweep sbatch added.
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
---------
Co-authored-by: Claude Opus 4.8 <noreply@anthropic.com>
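The "4th-root-scaled" learning rates can be checked with a little arithmetic. The sketch below assumes the reference point is the B=128 precedent quoted in the second commit (components 1.5e-4, ci_fn 5e-5) and that the rule is lr ∝ B^(1/4); both are readings of the commit text, and `scale_lr` is a hypothetical helper name.

```python
def scale_lr(lr_ref, batch_ref, batch_new, power=0.25):
    """Scale a learning rate by (batch_new / batch_ref) ** power."""
    return lr_ref * (batch_new / batch_ref) ** power


# Halving the batch from 128 to 64 multiplies each LR by 0.5 ** 0.25 (about 0.841).
ci_fn_lr = scale_lr(5e-5, batch_ref=128, batch_new=64)
components_lr = scale_lr(1.5e-4, batch_ref=128, batch_new=64)
```

Under these assumptions the results round to the 4.2e-5 and 1.26e-4 figures carried by the B=64 config.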
…#872)
Training is JAX now and torch-run loading is being deferred (re-add later as the #10 torch->jax adapter). This shears the torch-run-loading surface and everything that goes dead with it. No legacy/dual-format support: clean deletion.
Dropped (torch-run-loading surface):
- adapters/pd.py (torch PDAdapter); adapter_from_id now jax_pd-only (asserts JAX run)
- component_model_io.py + experiments/lm/vendored/ (VendoredLlama, LMComponentModel)
- SavedLMRun / build_lm_loader / make_run_batch from experiments/lm/run.py (kept build_target, since JaxPDAdapter derives topology from it)
- init_pd_run from experiments/utils.py (dead; kept EXPERIMENT_CONFIG_FILENAME)
Cascade-pruned (no kept consumer; grep-confirmed):
- param_decomp core: component_model, components, ci_fns, ci_nn_blocks, ci_sigmoids, masks, batch_and_loss_fns, torch_helpers, run_sink, training_state, metrics/*
- param_decomp_lab: eval_metrics/* (torch Metric impls), batch_and_loss_fns, run_sink
- dead tests for all of the above
Kept in param_decomp/ (consumer substrate the torch consumers still import): log.py, distributed.py, decomposition_targets.py.
Kept untouched: the eval-metric CONFIG classes in param_decomp_config (the JAX trainer matches on them); the whole param_decomp_jax/ trainer (zero edits; equivalence goldens bit-identical).
Validation: make type 0 errors; 224 passed (lab+core, not slow); import-smoke of every kept consumer entrypoint OK; check-jax 0 errors + equivalence goldens 12/12.
Co-authored-by: Claude Opus 4.8 <noreply@anthropic.com>
…e self-contained run yaml (#873)
Kill the vestigial wrapper/`torch_config:`/`configs/torch/` indirection. A JAX run config is now ONE self-contained file: the canonical `param_decomp_config` experiment schema plus the run-instance fields it now also carries.
- `param_decomp_config`: `runtime.remat_recon_forwards` (RuntimeConfig); top-level `run_name`/`run_id`/`out_dir` (ExperimentConfig, run_id/out_dir None pre-launch); `wandb.group`/`wandb.tags` (WandbConfig).
- `jax_single_pool.config`: `load_wrapper`/`_build_from_schema`/`_wandb_group_tags` → `load_config`/`build_from_schema`; build fns read run-instance fields off the schema (no per-call args); run-dir pinning collapses to a single `config.yaml`.
- `pd-jax-lm`: validates the single config (structural TMS/ResidMLP/LM dispatch), stamps run_id/out_dir/wandb-group/tags into the workspace copy.
- `is_jax_run`: detects orbax `ckpts/` beside `config.yaml` (the `torch_config:` key marker is gone); JaxPDAdapter/slow_eval read `config.yaml`.
- Migrated all 9 configs to single files; deleted `configs/torch/`, the `_from_torch` wrappers, and `param_decomp_config/jax_wrapper.py`. Added missing `weights_dtype: bfloat16` to pile_ppgd_bsc (pre-existing latent refusal; bf16-only target).
Pure config-plumbing: no training-semantics change. Equivalence goldens stay green. No legacy/dual-format shim; old saved runs' two-file pinned configs need manual migration if a consumer ever reads one.
Co-authored-by: Claude Opus 4.8 <noreply@anthropic.com>
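The `is_jax_run` discriminator described above (an orbax `ckpts/` directory beside `config.yaml`) could look roughly like the following. This is a guess at the shape of the check from the commit text alone; the real function may inspect more than the two paths.

```python
from pathlib import Path


def is_jax_run(run_dir):
    """Heuristic sketch: a JAX run dir has config.yaml with ckpts/ beside it."""
    run_dir = Path(run_dir)
    has_config = (run_dir / "config.yaml").is_file()
    has_ckpts = (run_dir / "ckpts").is_dir()
    return has_config and has_ckpts
```

Because the check is purely structural, it needs no YAML parsing and no key marker inside the config.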
…(SPEC S33) (#874)
A fresh run can now initialize its trained decomposition (V/U + ci_fn) from a PARENT checkpoint and fine-tune under a DIFFERENT config. When the run's own ckpts/ is empty and cfg.resume_provenance is set, the LM train() loads the parent's ckpts/<parent_step> onto the fresh reference TrainState, keeps only the components + ci_fn, and starts a clean schedule from step 0 (fresh optimizer / sources, no faith warmup). A subsequent SLURM requeue (own ckpts/ non-empty) resumes from its own dir and ignores provenance.
- checkpoint.py::init_from_parent: restore parent onto reference, keep V/U + ci_fn, reset step to 0; assert the parent step exists.
- run.py::assert_finetune_structural_compat: read the parent's pinned config.yaml and assert matching sites (names + C) + ci-fn arch before the restore. A changed C/sites/ci-fn arch is not a fine-tune (fail-fast, readable diff).
- config.py: ExperimentConfig carries resume_provenance; the three builders propagate it from the canonical schema.
- run.py: TMS/ResidMLP refuse resume_provenance (LM-only).
- SPEC S33 + experiment.py ResumeProvenance docstring (dropped dead init_pd_run ref) + jax CLAUDE.md training-pipeline fine-tune stanza.
Non-resume training semantics unchanged (equivalence goldens bit-identical).
Co-authored-by: Claude Opus 4.8 <noreply@anthropic.com>
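The start-up decision described above (own checkpoints win; otherwise provenance triggers a parent init; otherwise a fresh run) reduces to a small precedence rule. The sketch below is illustrative only: the `ResumeProvenance` fields and the `choose_init` helper are hypothetical, and the real train() does the actual restores.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ResumeProvenance:
    """Hypothetical fields; the real schema lives in param_decomp_config."""
    parent_run: str
    parent_step: int


def choose_init(own_ckpt_steps, provenance):
    """Return (mode, start_step) for a launching run."""
    if own_ckpt_steps:
        # A requeue: resume from the run's own latest checkpoint, ignore provenance.
        return ("resume_own", max(own_ckpt_steps))
    if provenance is not None:
        # Fine-tune: load V/U + ci_fn from the parent, clean schedule from step 0.
        return ("init_from_parent", 0)
    return ("fresh", 0)
```

Checking the run's own checkpoints first is what makes a requeued fine-tune continue its own schedule instead of reloading the parent.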
Cache compiled XLA executables to a shared-FS dir reused across runs, requeues, and future launches at the same config+topology. The ~24-min chunkwise-step compile is keyed by HLO + backend + topology + jax/xla version, so a matching re-compile loads from disk in seconds.
Set in run.py::main after init_distributed (the write gate reads the distributed state) and before the first compile, so it covers direct jsp-train too, not just pd-jax-lm. Cache dir is a SIBLING of runs/ ($PARAM_DECOMP_OUT_DIR/xla_compilation_cache), shared across all runs and all 8N ranks. Threshold 60s so only the big compiles cache.
Multi-host safe on jax 0.10.1: jax gates the cache WRITE on process_id == 0 (compiler.py), so all ranks read but only rank 0 writes, with no shared-FS race. Verified a cross-process cache HIT on CPU.
Co-authored-by: Claude Opus 4.8 <noreply@anthropic.com>
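A hedged sketch of how such a cache might be wired up. The two `jax.config` option names are standard JAX persistent-cache settings, but the helper names and the exact call site are assumptions; the commit only says the setup happens in run.py::main after distributed init. JAX is imported lazily so the path helper can be exercised without it.

```python
import os

MIN_COMPILE_SECS = 60.0  # only compiles slower than this are written to the cache


def xla_cache_dir(out_dir):
    """Cache dir as a sibling of runs/ under the shared output root."""
    return os.path.join(out_dir, "xla_compilation_cache")


def enable_compilation_cache(out_dir):
    """Point JAX's persistent compilation cache at the shared directory.

    Call after distributed init and before the first jit compile.
    """
    import jax  # lazy import: keeps xla_cache_dir usable without jax installed

    jax.config.update("jax_compilation_cache_dir", xla_cache_dir(out_dir))
    jax.config.update("jax_persistent_cache_min_compile_time_secs", MIN_COMPILE_SECS)
```

In use, `enable_compilation_cache(os.environ["PARAM_DECOMP_OUT_DIR"])` would be called once per process before any compilation.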
Tracks: orphaned eval metrics (UVPlots/PermutedCIPlots/general-IdentityCIError recomputed nowhere since the torch offline path retired), deferred #10 (torch→jax run adapter) + app re-add, pretrain as a TRANSITIONAL torch island (rewrite in JAX when next needed → fully torch-free), and the imp-min token-count reparam. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
… torch-shed batch
Points human review at the goldens-blind, silent-corruption-risk surface (ckpt migration remap, the build_target/topology torch-free swap, fine-tune partial-state load, the is_jax_run discriminator) and marks the trainer-semantics PRs as skim-able (equivalence goldens prove the training math is bit-identical).
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
The repo is now torch-free except `nano_param_decomp/run.py` — the standalone
single-file VPD reference impl for paper readers (self-contained, no lab/pretrain
deps, excluded from `make type`). Training has been JAX (`jsp-train` via `pd-jax-lm`)
since the torch trainer retirement; this sheds the remaining torch consumer/infra
layer. Zero `import torch` in `param_decomp` / `param_decomp_lab` / `param_decomp_jax`.
Deleted entirely:
- `param_decomp_lab/experiments/lm/pretrain/` (torch model defs, train loop,
`pd-pretrain` CLI, run_info, configs) — pretraining will be reimplemented in JAX
when next needed; the trainer loads target weights from the on-disk cache via its
own torch-free loaders, never through pretrain code.
- `param_decomp/{distributed,decomposition_targets}.py`,
`param_decomp_lab/{distributed,seed}.py`,
`param_decomp_lab/infra/{ddp_launch,wandb_tensor_info}.py`,
`param_decomp_lab/toy_models/`, `param_decomp_lab/topology/topology.py`,
`param_decomp_lab/experiments/lm/run.py` (the torch `build_target` bridge) —
all dead after de-torching their sole consumers.
- `jax_single_pool/tools/convert_llama_simple_mlp_checkpoint.py` (one-off torch
checkpoint->safetensors converter; existing caches already converted).
- `nano_param_decomp/{pile_4L,simplestories_2L}.py` — model-wiring entry points that
imported the deleted torch pretrain archs (now broken); `run.py` (the method) stays.
- Four tests covering deleted modules.
De-torched (relocated metadata path):
- `JaxPDAdapter` now derives target topology (n_blocks / vocab / per-site (name, C))
from the new torch-free `jax_single_pool.load_run.run_metadata` (config + pretrain
cache, no orbax restore) + the torch-free `path_schema_for_model_type`. No torch
model construction.
- `experiments/lm/data.py` keeps only `tokenize_and_concatenate` (numpy, for the
offline prestage tool); the torch DataLoader machinery is gone.
- `infra/run_files.py` drops the dead `save_file`/torch.save path.
Tooling: dropped torch + the pytorch-cu128 index from the root pyproject/lock; the
two-stack conflict is gone; the main venv bridges CPU jax + beartype + editable
`param_decomp_jax` (`make install-dev`). Docs updated (root/jax/experiments CLAUDE.md,
MIGRATION_HOLES, READMEs).
Validation: zero-torch grep empty; trainer still builds+loads the pile LlamaSimpleMLP
target from cache (CPU); `make type` 0 errors; lab+core 209 passed; `make check-jax`
0 errors; JAX suite 193 passed/2 skipped; equivalence goldens 12 passed (bit-identical
— training semantics unchanged).
Co-authored-by: Claude Opus 4.8 <noreply@anthropic.com>
Adds tools/verify_c49k_migration.py, a streaming leaf-by-leaf verifier that proves the migrated 175k checkpoint (p-bd3cd4d4) reproduces the frozen-clone source (jax-l18-C49k-200k) bit-for-bit under migrate_c49k_checkpoint.py's remap, closing the #1 correctness gap in MIGRATION_REVIEW.md (structure was verified, value was not).
The comparison table is the inverse of the migration's remap, built from that tool's own constants (KIND_TO_SITE_SUFFIX, SOURCE_STATE_KEY) so the mapping under test and the mapping applied are the same. Each leaf pair is restored single-device on CPU one at a time (all other leaves PLACEHOLDER, never read), so peak RAM is ~2x one V/U leaf (~1.6 GB) and the 47 GB trees never coexist. Asserts 1:1 coverage of both trees, then np.array_equal per leaf (squeezing the legacy leading singleton on the 6 component V/U leaves).
Run verdict: all 144 leaves bit-identical -> migration VALUE-EXACT. The g/u/d->gate/up/down mapping and V->[0]/U->[1] ordering are confirmed by the asymmetric down_proj shapes (d_in=14336), which would mismatch under any swap or mismap.
Co-authored-by: Claude Opus 4.8 <noreply@anthropic.com>
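The verification logic (1:1 coverage of both trees, then per-leaf equality with the legacy singleton squeezed) can be sketched over plain dictionaries. This is an illustration only: nested lists and `==` stand in for arrays and `np.array_equal`, leaves are compared in memory rather than streamed from disk, and `verify` and its argument names are hypothetical.

```python
def verify(old_leaves, new_leaves, new_to_old, squeeze_keys=frozenset()):
    """Check that every new leaf equals its mapped old leaf.

    new_to_old: the inverse remap table, {new_key: old_key}.
    squeeze_keys: old keys whose leading size-1 axis is dropped before comparing.
    Returns the number of verified leaves; raises AssertionError otherwise.
    """
    # Coverage: the table must name every leaf of both trees exactly once.
    assert set(new_to_old) == set(new_leaves), "table misses a new leaf"
    assert sorted(new_to_old.values()) == sorted(old_leaves), "table misses an old leaf"
    # Values: compare one pair at a time.
    for new_key, old_key in new_to_old.items():
        old = old_leaves[old_key]
        if old_key in squeeze_keys:
            assert len(old) == 1, f"{old_key}: expected a leading axis of size 1"
            old = old[0]
        assert old == new_leaves[new_key], f"value mismatch: {old_key} -> {new_key}"
    return len(new_to_old)
```

Checking coverage before values means a leaf silently dropped by the remap is reported as a coverage failure rather than going unnoticed.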
Top-line
Migrates VPD fully to JAX (train and analyze in one framework) and retires the torch stack to a git-tagged oracle. Squash-merges to `main` once, at the end. Net: ~430 commits, +24k / −50k LOC (mostly deletion).
Key decisions
- `torch-oracle`/`torch-oracle-npool`; JAX conforms to it (`SPEC.md` is the normative contract, numeric seams default to matching torch, goldens prove it).
- `DecomposedModel` via `open_jax_run`): harvest, clustering, autointerp, slow/offline eval, app. The JAX→torch export bridge is dead.
- `DecomposedModel`: ordered `sites` + pure fns `clean_output`/`site_inputs`/`masked_output`/`weight_deltas`/`masked_site_outputs`, generic over input/output/recon-loss with `[B,T,d]` as the fixed waist.
- `plan × mask-source strategy` (`make_plan` + chunking helpers); loss is KL on final logits. Hidden-acts recon is a separate eval diagnostic over the `masked_site_outputs` seam (amends SPEC S31), not a recon-grid variant.
- `# type: ignore`/`Any`/`cast` without sign-off; `make check-jax` gated in pre-commit; fail-fast, types-first.
Status: functionally complete; green
JAX trainer; all consumer ports; read-only app; dropped-feature deletion; `DecomposedLM` → `DecomposedModel` rename; hidden-acts eval port; llama8b loader; type-debt → 0 + pre-commit gate; a code-review + fix pass (#854/#855/#856) and a first dead-trainer deletion (#857).
Suites green: ~415 lab + ~166 jax at 1 and 4 devices; `make type`/`check-jax` clean; torch↔JAX per-term equivalence + stacked-parity trajectory goldens pass bit-unmodified; validated end-to-end on SimpleMLP pile run `p-761bc061`.
Remaining before main-merge (each gated; see commit history / memory)
- `offline_eval.py`/`pd-offline-eval`/`jsp-export`; rewire `_submit_offline_eval` → `jsp-slow-eval`). Gated on parity-validating JAX slow-eval vs torch on a real llama8b run. No current-format llama8b run is loadable, so this needs a fresh run.
- `param_decomp/` core deletion (`metrics/` tree, `train_step`, …). Gated on the above (the live offline-eval path still imports them). Bridge/capstone surface stays.
- `async_eval` (in-loop autointerp) decision.
Reading guide
`param_decomp_jax/jax_single_pool/SPEC.md` → `lm.py` → `recon.py` → `train.py`. `TRANSITION.md` = the settled plan; `LOSS_PARITY_DESIGN.md` = recon unification. VPD paper rides separately as #562.
🤖 Generated with Claude Code