
Migrate Decoder to NNX #2831

Open

hsuan-lun-chiang wants to merge 1 commit into AI-Hypercomputer:main from CIeNET-International:feat/Migrate-Decoder-to-NNX

Conversation

Collaborator

@hsuan-lun-chiang hsuan-lun-chiang commented Dec 15, 2025

Description

Migrate the Transformer decoder layer into NNX.

Note: The following models are currently not supported:

  • DeepSeek
  • Gemma3
  • Llama4

Support for these models will be added in a follow-up PR.

Strategy:
A pure_nnx_decoder flag is added to control whether the NNX or Linen decoder is used.
The initial migration does not include NNX pipeline support.
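A rough sketch of the flag-gated dispatch this strategy describes. Only the `pure_nnx_decoder` config key comes from this PR; `LinenDecoder`, `NNXDecoder`, and `select_decoder` are hypothetical stand-ins, not MaxText code:

```python
# Hypothetical sketch of flag-gated decoder selection.
# Only the `pure_nnx_decoder` key is taken from this PR; the classes
# and function names below are illustrative stand-ins.

class LinenDecoder:
    """Stand-in for the existing Linen-based decoder."""
    framework = "linen"

class NNXDecoder:
    """Stand-in for the new NNX-based decoder."""
    framework = "nnx"

def select_decoder(config: dict):
    """Pick the decoder implementation from the config flag.

    Defaults to the Linen decoder when the flag is absent, matching
    the PR's "default false for now" behavior.
    """
    if config.get("pure_nnx_decoder", False):
        return NNXDecoder()
    return LinenDecoder()

print(select_decoder({"pure_nnx_decoder": True}).framework)  # nnx
print(select_decoder({}).framework)                          # linen
```

Keeping the Linen path as the default lets both implementations coexist until the NNX decoder covers all models.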

Tests

Conducted the following tests; details are in the GDoc file:

  1. Training with different models, compared against Linen training
  2. Golden logits comparison
  3. Inference
  4. Checkpoint comparison (including tree-structure comparison)
  5. Sharding comparison
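The checkpoint tree-structure comparison in test 4 can be sketched as follows. This is a minimal stand-in, not the PR's actual test code: checkpoints are modeled as plain nested dicts whose leaves are shape tuples, whereas real MaxText checkpoints are pytrees of arrays:

```python
# Hypothetical sketch of a checkpoint tree-structure comparison:
# two checkpoints match structurally when they contain the same
# nested key paths with the same array shape at every leaf.

def tree_structure(ckpt, prefix=()):
    """Flatten a nested dict into {key-path: leaf} pairs."""
    if isinstance(ckpt, dict):
        out = {}
        for key, value in ckpt.items():
            out.update(tree_structure(value, prefix + (key,)))
        return out
    return {prefix: ckpt}

def same_structure(linen_ckpt, nnx_ckpt):
    """True if both checkpoints have identical key paths and shapes."""
    return tree_structure(linen_ckpt) == tree_structure(nnx_ckpt)

linen = {"decoder": {"layer_0": {"kernel": (4096, 4096)}}}
nnx   = {"decoder": {"layer_0": {"kernel": (4096, 4096)}}}
print(same_structure(linen, nnx))  # True
```

A mismatch in either key paths or leaf shapes would indicate the NNX decoder writes a different checkpoint layout than the Linen one.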

TODOs:

  • NNX version of unit tests (future PRs)

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.

@hsuan-lun-chiang force-pushed the feat/Migrate-Decoder-to-NNX branch 4 times, most recently from 175fc07 to 48b0d7e on December 22, 2025 09:40
@hsuan-lun-chiang force-pushed the feat/Migrate-Decoder-to-NNX branch 9 times, most recently from e6a172f to ba9b74d on January 7, 2026 12:01
Collaborator

@RissyRan RissyRan left a comment


Thanks for the change! Unfortunately this part involves all models using scan/unscan, and some patterns are different. Please be careful and include a comprehensive test suite before the merge.

@hsuan-lun-chiang
Collaborator Author

Thanks for the change! Unfortunately this part involves all models using scan/unscan, and some patterns are different. Please be careful and include a comprehensive test suite before the merge.

Sure, will do, and thanks for the heads-up!

@hsuan-lun-chiang force-pushed the feat/Migrate-Decoder-to-NNX branch 4 times, most recently from 505b519 to adf5a13 on January 15, 2026 08:41
@hsuan-lun-chiang force-pushed the feat/Migrate-Decoder-to-NNX branch 5 times, most recently from 116fb81 to c653a9c on January 20, 2026 06:41
@codecov

codecov bot commented Jan 20, 2026

Codecov Report

❌ Patch coverage is 40.27149% with 264 lines in your changes missing coverage. Please review.

| Files with missing lines | Patch % | Lines |
|---|---|---|
| src/maxtext/layers/nnx_decoders.py | 40.28% | 212 Missing and 43 partials ⚠️ |
| src/maxtext/models/models.py | 25.00% | 9 Missing ⚠️ |


@hsuan-lun-chiang force-pushed the feat/Migrate-Decoder-to-NNX branch from 768ff18 to 5b481be on January 29, 2026 03:16
@hsuan-lun-chiang force-pushed the feat/Migrate-Decoder-to-NNX branch 9 times, most recently from e4138b3 to c21ab5d on February 3, 2026 04:01
Collaborator

@bvandermoon bvandermoon left a comment


@hsuan-lun-chiang thanks for the testing in the PR description. For the memory differences mentioned in the doc, could you collect before/after profiles so we can confirm the HLOs are the same?

Collaborator

@bvandermoon bvandermoon left a comment


Why are these models being skipped for now? I think it's fine, just want to understand if there is additional context.

Note: The following models are currently not supported:
  • DeepSeek
  • Gemma3
  • Llama4

@hsuan-lun-chiang
Collaborator Author

hsuan-lun-chiang commented Feb 5, 2026

@hsuan-lun-chiang thanks for the testing in the PR description. For the memory differences mentioned in the doc, could you collect before/after profiles so we can confirm the HLOs are the same?

Hi @bvandermoon ,
Yes, we already have the profiles (for GPT3, where the memory difference was most significant):
Before (Linen)
After (NNX)
Thank you for checking the work!

@hsuan-lun-chiang
Collaborator Author

Why are these models being skipped for now? I think it's fine, just want to understand if there is additional context.

Note: The following models are currently not supported:
  • DeepSeek
  • Gemma3
  • Llama4

No major blockers; these models require different handling logic compared to the other models (as @RissyRan commented earlier). We are addressing them in a separate workstream to ensure they are integrated correctly.

@bvandermoon
Collaborator

@hsuan-lun-chiang thanks for the testing in the PR description. For the memory differences mentioned in the doc, could you collect before/after profiles so we can confirm the HLOs are the same?

Hi @bvandermoon , Yes, we already have the profiles (for Gpt3, where the memory difference was most significant) Before (Linen) After (NNX) Thank you for checking the work!

Discussed offline, thanks @hsuan-lun-chiang. Looks like the profiles have different HLOs. Could you also collect profiles for the other ones in the first section of your doc? We can confirm if it's just GPT3 or something happening everywhere

@hsuan-lun-chiang
Collaborator Author

hsuan-lun-chiang commented Feb 6, 2026

Discussed offline, thanks @hsuan-lun-chiang. Looks like the profiles have different HLOs. Could you also collect profiles for the other ones in the first section of your doc? We can confirm if it's just GPT3 or something happening everywhere

Sure! I'm collecting those profiles and will leave a comment to let you know where I documented them (in the testing page, section 1). Thank you!

@bvandermoon
Collaborator

Discussed offline, thanks @hsuan-lun-chiang. Looks like the profiles have different HLOs. Could you also collect profiles for the other ones in the first section of your doc? We can confirm if it's just GPT3 or something happening everywhere

Sure! I'm collecting those profiles and will leave a comment to let you know where I documented them (in the testing page, section 1). Thank you!

Thanks @hsuan-lun-chiang. For Llama2-7b and Llama3.1-70B, the before/after memory usage is almost identical. But for some reason the HLOs are not identical. Can you confirm the before/after runs were run in the same env with the same JAX/libtpu version, etc.?

@hsuan-lun-chiang
Collaborator Author

hsuan-lun-chiang commented Feb 9, 2026

Thanks @hsuan-lun-chiang. For Llama2-7b and Llama3.1-70B, the before/after memory usage is almost identical. But for some reason the HLOs are not identical. Can you confirm the before/after runs were run in the same env with the same JAX/libtpu version, etc.?

Thank you @bvandermoon for checking this. Yes, I ran the before/after tests by only switching branches like:

```shell
git checkout main
python3 -m MaxText.train ...
git checkout Migrate-Decoder-to-NNX
python3 -m MaxText.train ... nnx_enable=True pure_nnx_decoder=True
```

So both the main branch code and our code were tested in the same environment.
I will double check the implementation of both models to see if I can find anything.
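For what it's worth, one way to make the HLO comparison mechanical: collect text dumps from each run (XLA supports `XLA_FLAGS=--xla_dump_to=<dir>`) and diff the two dump directories byte-for-byte. This is a generic sketch, not project code; paths and file names are illustrative:

```python
# Sketch: compare two XLA HLO dump directories byte-for-byte.
# Assumes dumps were collected with XLA_FLAGS=--xla_dump_to=<dir>
# in each run; directory layout is illustrative.

import hashlib
import pathlib

def dump_digest(dump_dir):
    """Map each dump file name to the SHA-256 of its contents."""
    return {
        p.name: hashlib.sha256(p.read_bytes()).hexdigest()
        for p in sorted(pathlib.Path(dump_dir).iterdir())
        if p.is_file()
    }

def hlo_identical(before_dir, after_dir):
    """True if both dump dirs hold the same files with the same bytes."""
    return dump_digest(before_dir) == dump_digest(after_dir)
```

Hashing the full dump is strict (even a renamed computation shows up as a difference), so a mismatch here is a prompt to diff the specific files, not necessarily a functional divergence.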

    Adding nnx_decoders.py in parallel with decoders.py

    1. Duplicate and modify decoders.py into the new file nnx_decoders.py
    2. Add a new config pure_nnx_decoder to control whether the model uses NNXDecoder; default false for now
    3. Modify related code to accommodate the change
    4. Add/modify unit tests


6 participants