Group all same-shape parameters into a single AsyncTask per shape group#28
Open
alint77 wants to merge 1 commit into microsoft:main from
When profiling NorMuon with FSDP2 on a 12-layer/768-hidden model across 2 GPUs, I noticed that optimizer.step was dominated by GPU idle time between dozens of small, sequential communication rounds. Each world_size-sized batch of same-shape parameters triggered its own pair of all-to-all calls and Newton-Schulz iteration, resulting in ~36 separate rounds per step with visible gaps on the Chrome trace between each one.
This PR collapses all same-shape parameters into a single "mega-batch" per shape group, bringing that down to ~3 rounds (one per unique weight shape in a standard transformer).
What changed
Mega-batched communication (normuon_update_megabatch_async)
Instead of processing world_size matrices at a time through all-to-all → Newton-Schulz → all-to-all, we now:
1. Stack all N same-shape local shards into 3D tensors
2. Do one all-to-all to redistribute the full stack
3. Run Newton-Schulz on a [N/world_size, rows, cols] batch (already supported by the existing NS implementations, since they use dim=(-2,-1) norms and @ broadcasting)
4. Do one all-to-all back
The non-sharded (DDP-style) path gets the same treatment with all-gather, and single-GPU also benefits from batched NS.
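The mega-batched flow can be sketched single-process, with the all-to-alls simulated by plain chunking so no process group is needed. The helper names, shard layout, and the quintic Newton-Schulz coefficients below are assumptions for illustration, not the PR's exact code:

```python
import torch

def newton_schulz_batched(X, steps=5, eps=1e-7):
    # Quintic Newton-Schulz orthogonalization, batched over the leading dim.
    # dim=(-2,-1) norms and @ both broadcast, so [B, rows, cols] works as-is.
    # Coefficients are the commonly used Muon values (an assumption here).
    a, b, c = 3.4445, -4.7750, 2.0315
    X = X / (X.norm(dim=(-2, -1), keepdim=True) + eps)
    for _ in range(steps):
        A = X @ X.transpose(-2, -1)
        B = b * A + c * (A @ A)
        X = a * X + B @ X
    return X

def megabatch_update_sim(local_shards, world_size):
    # local_shards: N same-shape 2D shards held by this rank.
    # 1) stack into one [N, rows, cols] tensor
    stacked = torch.stack(local_shards)
    # 2) one "all-to-all" to redistribute (simulated by chunking: each
    #    rank would keep N/world_size fully assembled matrices)
    chunks = list(stacked.chunk(world_size))
    # 3) one batched Newton-Schulz call per rank's chunk
    chunks = [newton_schulz_batched(c) for c in chunks]
    # 4) one "all-to-all" back to the original shard layout
    return list(torch.cat(chunks))

out = megabatch_update_sim([torch.randn(8, 4) for _ in range(6)], world_size=2)
```

With 6 matrices and world_size=2, the whole group goes through one redistribute / one batched NS / one return trip instead of three sequential rounds.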
Stacked normalization kernel (normuon_normalization_stacked)
The original normuon_normalization operated on a List[Tensor] using torch._foreach_* ops. With 48 attention weight matrices, that's a 48-element foreach call per op, each launching separate kernels per tensor. The new version stacks everything into a single [N, rows, cols] tensor and does plain tensor ops, which torch.compile fuses into far fewer kernels.
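The list-vs-stacked pattern can be shown with a toy per-row RMS normalization (the real NorMuon normalization math is different; only the restructuring is the point here):

```python
import torch

def normalize_list(tensors, eps=1e-8):
    # List path: one set of ops per tensor, i.e. N-element kernel
    # launches per step (what torch._foreach_* calls amortize only partly).
    return [t / (t.pow(2).mean(dim=-1, keepdim=True).sqrt() + eps) for t in tensors]

def normalize_stacked(tensors, eps=1e-8):
    # Stacked path: one [N, rows, cols] tensor and plain ops,
    # which torch.compile can fuse into a handful of kernels.
    X = torch.stack(tensors)
    X = X / (X.pow(2).mean(dim=-1, keepdim=True).sqrt() + eps)
    return list(X.unbind(0))

mats = [torch.randn(4, 6) for _ in range(48)]
a, b = normalize_list(mats), normalize_stacked(mats)
```

Both paths compute the same values; only the number of launched kernels differs.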
Refactored _create_normuon_tasks
Extracted _get_shard_info as a helper method (was duplicated inline for every batch). The task creation now groups parameters by (shape, sharding, dtype) and yields one AsyncTask per group rather than per world_size-chunk.
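The grouping itself is plain dictionary bucketing. A minimal sketch (the dict fields and names here are hypothetical, not the repo's actual parameter metadata):

```python
from collections import defaultdict

def group_params(params):
    # Bucket parameters by (shape, sharding, dtype); each bucket becomes
    # one AsyncTask instead of one task per world_size-sized chunk.
    groups = defaultdict(list)
    for p in params:
        groups[(p["shape"], p["is_sharded"], p["dtype"])].append(p["name"])
    return dict(groups)

params = [
    {"name": "attn.0", "shape": (768, 768), "is_sharded": True, "dtype": "bf16"},
    {"name": "attn.1", "shape": (768, 768), "is_sharded": True, "dtype": "bf16"},
    {"name": "mlp.0", "shape": (3072, 768), "is_sharded": True, "dtype": "bf16"},
]
groups = group_params(params)
# two groups: the two same-shape attention weights batch together
```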
Backward compatibility
The old normuon_update_batch_async and normuon_normalization are still present and used for the batch-sharded 3D tensor edge case
The async yield points are preserved at the same communication boundaries, so the AsyncRuntime overlap behavior is unchanged
All existing tests pass
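The preserved yield discipline looks roughly like this generator pattern. The AsyncTask/AsyncRuntime internals below are guessed from the description, not copied from the repo:

```python
def normuon_task_sketch(send_comm, recv_comm, compute):
    # An AsyncTask-style generator: yield at each communication boundary
    # so the runtime can run other tasks' compute while this task's
    # collective is in flight. The mega-batch keeps these same boundaries.
    handle = send_comm()      # issue the outbound all-to-all
    yield                     # runtime overlaps other work here
    x = handle()              # wait for the collective
    y = compute(x)            # batched Newton-Schulz
    handle2 = recv_comm(y)    # issue the return all-to-all
    yield
    return handle2()

# Toy runtime: drive the generator, doing "other work" at each yield.
log = []
def send(): log.append("a2a_out"); return lambda: 1
def back(v): log.append("a2a_back"); return lambda: v * 10
task = normuon_task_sketch(send, back, lambda x: x + 1)
result = None
try:
    while True:
        next(task)
        log.append("overlap")
except StopIteration as stop:
    result = stop.value
```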
Expected impact
For a typical transformer with 3 distinct weight shapes:
Communication rounds: O(num_params / world_size) → O(num_shapes), e.g. 36 → 3
Kernel launches: foreach over N-element lists → single stacked tensor ops
Newton-Schulz: N separate 2D calls → one batched 3D call with better occupancy
The actual speedup will depend on model size and world_size — larger models with more layers benefit more since there are more same-shape matrices to batch together.
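The 36 → 3 figure follows from simple counting. A sketch with made-up but representative shape counts (the real per-shape breakdown depends on the model's parameter layout):

```python
import math

def comm_rounds_before(shape_counts, world_size):
    # Old path: each world_size-sized chunk of same-shape params
    # is its own all-to-all + NS + all-to-all round.
    return sum(math.ceil(n / world_size) for n in shape_counts.values())

def comm_rounds_after(shape_counts):
    # New path: one mega-batched round per unique shape.
    return len(shape_counts)

# Hypothetical 12-layer model on 2 GPUs with three unique 2D weight
# shapes, 24 matrices each (illustrative numbers only).
counts = {"attn": 24, "mlp_up": 24, "mlp_down": 24}
before = comm_rounds_before(counts, world_size=2)  # 12 + 12 + 12 = 36
after = comm_rounds_after(counts)                  # 3
```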
BEFORE
AFTER
The loss curve also matches perfectly.