
Add MAGIC CLI with runtime DTensor double-backward patch#174

Open
luciaquirke wants to merge 17 commits into magic from magic-dtensor-patch

Conversation


luciaquirke (Collaborator) commented Mar 6, 2026

My Changes Summary

  • Add MAGIC CLI. Brings in DataConfig to start giving us the features the other Bergson CLIs have; it will likely need more work over time to reach feature parity.
  • The query reduction currently only works if the number of queries divides evenly across the world size. I have a fix for this in a follow-up PR, but this one was getting ungainly.

Things which could go in this PR or a follow-up

  • Save config JSONs to disk
  • Use gradcheck to test mixed precision? Or any other ideas for correctness tests?
  • no_dist is only available on later torch versions (roughly 2.9+)
  • Support reading configuration from YAML so configurations are easier to document
  • Per-token scores (PR exists)
  • Padding for batches that don't divide evenly across the world size (PR exists)

Claude Changes Summary

  • Runtime DTensor patch (bergson/magic_patch.py): Monkey-patches Redistribute.backward and _ToTorchTensor.backward at runtime to make FSDP redistribution twice-differentiable (Add support for twice-differentiable DTensor redistribution pytorch/pytorch#160509). Replaces the old magic_wmdp_setup.sh that modified torch source files on disk. Idempotent — call apply_dtensor_patch() before any DTensor double-backward operations.

Test plan

  • End-to-end MAGIC attribution run on GPU cluster

luciaquirke changed the title from "Add runtime DTensor double-backward patch and per-token weights" to "Add runtime DTensor double-backward patch and per-token weights; Add MAGIC CLI" on Mar 7, 2026


@dataclass
class DoubleBackwardConfig:

This is the original RunConfig, except that it uses DataConfig both for the query (rather than taking the first item of the dataset) and for the training data, and save_dir is renamed to run_path.

"""Random seed for subset permutation."""


def compute_query_gradients(

luciaquirke commented Mar 7, 2026


We could technically use build here with all the Trackstar bells and whistles turned off, but this seems more readable, even though it's not strictly DRY. It currently lacks TRL-style tokenization/masking support.

luciaquirke force-pushed the magic-dtensor-patch branch 3 times, most recently from 77497a6 to 1be9172 on March 7, 2026 00:27
…ight support

- Add bergson/magic_patch.py: runtime monkey-patch for twice-differentiable
  DTensor redistribution (pytorch/pytorch#160509), replacing the old
  magic_wmdp_setup.sh that modified torch source files on disk
- Add per_token mode to DataStream for [n_examples, max_length] weight tensors
- Support 2D [B, T] per-token weights in weighted_causal_lm_ce
- Fix backward weight_grads accumulation when autograd returns None
luciaquirke force-pushed the magic-dtensor-patch branch from 1be9172 to 97fe18f on March 7, 2026 00:28
luciaquirke requested a review from norabelrose on March 7, 2026 00:30
luciaquirke and others added 3 commits March 7, 2026 12:03
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
The weight gradient from autograd.grad should always be a tensor since
data.weights participates in the computation graph via weighted_causal_lm_ce.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
luciaquirke removed the request for review from norabelrose on March 7, 2026 01:40
luciaquirke and others added 6 commits March 7, 2026 16:33
Multiple concurrent DCP async_save calls each create their own Gloo
process group. With consecutive saves at steps 20-24 (last_start logic),
up to 5 saves were in-flight simultaneously. Background threads from these
saves may call distributed operations that conflict, causing all ranks to
deadlock in fut.result() until the NCCL watchdog times out.

Limit to one concurrent save at a time: wait for the previous save to
complete before starting the next one. Each save still overlaps with at
least one training step, so async I/O benefit is preserved.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
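The "one concurrent save" rule this commit describes can be sketched as follows, using a plain thread pool as a toy stand-in for torch.distributed.checkpoint's async_save (all names here are illustrative, not the actual training-loop code):

```python
import time
from concurrent.futures import ThreadPoolExecutor

# Toy stand-in for DCP async_save: a background writer thread, plus the
# key fix — block on the previous save's future before launching the next,
# so at most one save is ever in flight.
_pool = ThreadPoolExecutor(max_workers=1)
_pending = None  # Future for the in-flight save, if any


def async_save(step, state, saved):
    """Kick off a background save, first waiting out any in-flight save."""
    global _pending
    if _pending is not None:
        _pending.result()  # wait for the previous save to complete

    def _write():
        time.sleep(0.01)  # simulate slow checkpoint I/O
        saved.append(step)

    _pending = _pool.submit(_write)
    return _pending
```

Each save still overlaps the training step that follows it, so the async I/O benefit is preserved; the wait only bites when saves are issued faster than they complete.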
Raises a clear ValueError at init time when the dataset doesn't have
enough examples for the requested number of batches, instead of crashing
with an IndexError mid-training.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
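A minimal sketch of this fail-fast validation (the class and argument names are assumed for illustration, not the actual identifiers in this PR):

```python
# Hypothetical sampler showing the init-time check described above: raise
# a descriptive ValueError up front instead of an IndexError mid-training.
class SubsetBatchSampler:
    def __init__(self, dataset_size: int, num_batches: int, batch_size: int):
        required = num_batches * batch_size
        if dataset_size < required:
            raise ValueError(
                f"Dataset has {dataset_size} examples, but {num_batches} "
                f"batches of size {batch_size} require {required}. "
                f"Reduce num_batches or batch_size."
            )
        self.dataset_size = dataset_size
        self.num_batches = num_batches
        self.batch_size = batch_size
```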
PyTorch's Future.result() waits for done callbacks to complete before
returning. The destroy_process_group callback was invoked from DCP's
background thread after each save, but destroy_process_group may do
a barrier on the Gloo group. Since ranks complete their I/O at different
times, the fast rank would deadlock waiting for the slow rank to also
call destroy_process_group, while the slow rank was still in fut.result().

DCP holds its own reference to the process group, keeping it alive for
the duration of the background I/O. GC will clean it up afterwards.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
luciaquirke changed the title from "Add runtime DTensor double-backward patch and per-token weights; Add MAGIC CLI" to "Add MAGIC CLI, runtime DTensor double-backward patch, and per-token scores" on Mar 7, 2026
luciaquirke requested a review from norabelrose on March 7, 2026 07:52
luciaquirke and others added 2 commits March 7, 2026 07:55
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Strip per_token parameter from DataStream and 2D weight path from
weighted_causal_lm_ce to keep the merge scope minimal. The per-token
code is preserved on the magic-per-token branch.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
luciaquirke changed the title from "Add MAGIC CLI, runtime DTensor double-backward patch, and per-token scores" to "Add MAGIC CLI and runtime DTensor double-backward patch" on Mar 7, 2026
luciaquirke and others added 2 commits March 7, 2026 10:45
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
luciaquirke changed the title from "Add MAGIC CLI and runtime DTensor double-backward patch" to "Add MAGIC CLI with runtime DTensor double-backward patch" on Mar 7, 2026
luciaquirke and others added 2 commits March 7, 2026 11:08