
Group all same-shape parameters into a single AsyncTask per shape group#28

Open
alint77 wants to merge 1 commit into microsoft:main from alint77:dev/megabatching

Conversation


@alint77 alint77 commented Feb 23, 2026

When profiling NorMuon with FSDP2 on a 12-layer/768-hidden model across 2 GPUs, I noticed that optimizer.step was dominated by GPU idle time between dozens of small, sequential communication rounds. Each world_size-sized batch of same-shape parameters triggered its own pair of all-to-all calls and Newton-Schulz iteration, resulting in ~36 separate rounds per step with visible gaps on the Chrome trace between each one.

This PR collapses all same-shape parameters into a single "mega-batch" per shape group, bringing that down to ~3 rounds (one per unique weight shape in a standard transformer).

What changed
Mega-batched communication (normuon_update_megabatch_async)
Instead of processing world_size matrices at a time through all-to-all → Newton-Schulz → all-to-all, we now:

1. Stack all N same-shape local shards into 3D tensors
2. Do one all-to-all to redistribute the full stack
3. Run Newton-Schulz on a [N/world_size, rows, cols] batch (already supported by the existing NS implementations, since they use dim=(-2,-1) norms and @ broadcasting)
4. Do one all-to-all back
The non-sharded (DDP-style) path gets the same treatment with all-gather, and single-GPU also benefits from batched NS.
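The batched Newton-Schulz call those steps rely on can be sketched as follows. This is a minimal illustration, not code from this PR: the function name is made up, and the quintic coefficients are the ones popularized by Muon. The point is that `dim=(-2, -1)` norms and `@` broadcasting let the same code run on a stacked `[N, rows, cols]` tensor with no per-matrix loop:

```python
import torch

torch.manual_seed(0)

def newton_schulz_batched(G: torch.Tensor, steps: int = 5) -> torch.Tensor:
    """Approximately orthogonalize a stacked [N, rows, cols] batch.

    The leading batch dimension rides along for free: the norm reduces
    over dim=(-2, -1) and the matmuls broadcast over leading dims.
    Coefficients are the quintic Newton-Schulz constants from Muon
    (illustrative, not taken from this repository).
    """
    a, b, c = 3.4445, -4.7750, 2.0315
    X = G / (G.norm(dim=(-2, -1), keepdim=True) + 1e-7)
    transposed = X.shape[-2] > X.shape[-1]
    if transposed:
        X = X.mT  # iterate on the wide orientation
    for _ in range(steps):
        A = X @ X.mT            # [N, rows, rows], one batched matmul
        B = b * A + c * A @ A
        X = a * X + B @ X
    return X.mT if transposed else X

# One 3D call replaces N separate 2D Newton-Schulz calls:
batch = torch.randn(8, 64, 128)
orth = newton_schulz_batched(batch)
```

Since each matrix in the stack is processed independently, the batched result matches running the iteration matrix by matrix; the win is purely in kernel count and occupancy.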

Stacked normalization kernel (normuon_normalization_stacked)
The original normuon_normalization operated on a List[Tensor] using torch.foreach* ops. With 48 attention weight matrices, that's 48-element foreach calls — each one launching separate kernels per tensor. The new version stacks everything into a single [N, rows, cols] tensor and does plain tensor ops, which torch.compile fuses into far fewer kernels.
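The foreach-vs-stacked distinction can be illustrated with a simple norm-based normalization (the real NorMuon normalization uses second-moment statistics; this sketch only demonstrates the dispatch pattern, and both function names are hypothetical):

```python
import torch

def normalization_foreach(tensors):
    """List-based pattern: torch._foreach_* ops over a List[Tensor].

    Each foreach call still dispatches work per tensor in the list,
    so a 48-element list means many small kernels.
    """
    norms = torch._foreach_norm(tensors)            # one norm per tensor
    norms = [n.clamp_min(1e-7) for n in norms]
    return torch._foreach_div(tensors, norms)

def normalization_stacked(stacked: torch.Tensor) -> torch.Tensor:
    """Stacked pattern: one [N, rows, cols] tensor, plain tensor ops
    that torch.compile can fuse into a handful of kernels."""
    norms = stacked.norm(dim=(-2, -1), keepdim=True).clamp_min(1e-7)
    return stacked / norms

ts = [torch.randn(4, 4) for _ in range(3)]
old = normalization_foreach(ts)
new = normalization_stacked(torch.stack(ts))
```

Both paths compute the same values; only the launch structure differs.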

Refactored _create_normuon_tasks
Extracted _get_shard_info as a helper method (was duplicated inline for every batch). The task creation now groups parameters by (shape, sharding, dtype) and yields one AsyncTask per group rather than per world_size-chunk.
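The grouping logic amounts to bucketing by a composite key. A hypothetical sketch (the helper name and the `is_sharded` predicate are illustrative, not this repository's API):

```python
from collections import defaultdict

import torch

def group_params(params, is_sharded):
    """Bucket parameters by (shape, sharding, dtype) so that each
    bucket can become one AsyncTask, instead of one task per
    world_size-sized chunk of parameters."""
    groups = defaultdict(list)
    for p in params:
        key = (tuple(p.shape), bool(is_sharded(p)), p.dtype)
        groups[key].append(p)
    return dict(groups)

# e.g. 12 attention-shaped and 12 MLP-shaped weights collapse to two keys:
params = [torch.empty(768, 768) for _ in range(12)] + \
         [torch.empty(3072, 768) for _ in range(12)]
groups = group_params(params, is_sharded=lambda p: False)
```

With one task per key, the number of communication rounds scales with the number of distinct shapes rather than with the number of parameters.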

Backward compatibility
- The old normuon_update_batch_async and normuon_normalization are still present and used for the batch-sharded 3D tensor edge case
- The async yield points are preserved at the same communication boundaries, so the AsyncRuntime overlap behavior is unchanged
- All existing tests pass
Expected impact
For a typical transformer with 3 distinct weight shapes:

- Communication rounds: O(num_params / world_size) → O(num_shapes), e.g. 36 → 3
- Kernel launches: foreach over N-element lists → single stacked tensor ops
- Newton-Schulz: N separate 2D calls → one batched 3D call with better occupancy
The actual speedup will depend on model size and world_size — larger models with more layers benefit more since there are more same-shape matrices to batch together.
BEFORE
[Screenshot 2026-02-23 at 21 13 34]
AFTER
[Screenshot 2026-02-23 at 21 09 59]

Also the loss curve matches perfectly.


alint77 commented Feb 23, 2026

| Metric | OLD | NEW | Speedup |
| --- | --- | --- | --- |
| Optimizer.step total GPU time | 128ms | 18ms | 7x |
| Optimizer.step CPU time per step | 78ms | 12ms | 6x |
| NCCL all-to-all calls/step | 100 | 6 | 17x fewer |
| Total NCCL events/step | 172 | 8 | 21x fewer |
| CPU step time | 194ms | 171ms | -12% |
| GPU step time | 164ms | 164ms | same |

