Skip to content

Add skip_update and magma support#27

Open
alint77 wants to merge 2 commits intomicrosoft:mainfrom
alint77:add/skipUpdate_magma
Open

Add skip_update and magma support#27
alint77 wants to merge 2 commits intomicrosoft:mainfrom
alint77:add/skipUpdate_magma

Conversation

@alint77
Copy link
Contributor

@alint77 alint77 commented Feb 18, 2026

Implements two stochastic update masking techniques from arxiv.org/abs/2602.15322 for both Muon and NorMuon.

SkipUpdate (skip_update_prob): at each step, each parameter matrix is independently kept with probability p or zeroed out with probability 1-p. Surviving updates are rescaled by 1/p to stay unbiased in expectation. Moment buffers always update densely regardless of the skip.

Magma (magma_tau): replaces the fixed 1/p rescaling with an adaptive EMA scale driven by momentum-gradient cosine similarity:

ẽ_t = sigmoid(cossim(μ_t_before, g_t) / τ)
s_t  = 0.9 * s_{t-1} + 0.1 * ẽ_t

The scale is intentionally biased (no 1/s_t correction), the paper found unbiased variants to be unstable. Bernoulli masking is still applied on top.

Both features are opt-in and off by default (None). For NorMuon, the mask is applied after the neuron-normalization step so both moment buffers (momentum and variance_neuron) always update densely, consistent with the paper's intent.

@alint77
Copy link
Contributor Author

alint77 commented Feb 18, 2026

@microsoft-github-policy-service agree

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant