Migrate Decoder to NNX #2831
Conversation
Force-pushed from 175fc07 to 48b0d7e
Force-pushed from e6a172f to ba9b74d
RissyRan left a comment
Thanks for the change! Unfortunately this part involves all models using scan/unscan, and some patterns are different. Please be careful and include a comprehensive test suite before the merge.
Sure, will do, and thanks for the heads-up!
Force-pushed from 505b519 to adf5a13
Force-pushed from 116fb81 to c653a9c
Codecov Report❌ Patch coverage is
Force-pushed from 768ff18 to 5b481be
Force-pushed from e4138b3 to c21ab5d
bvandermoon left a comment
@hsuan-lun-chiang thanks for the testing in the PR description. For the memory differences mentioned in the doc, could you collect before/after profiles so we can confirm the HLOs are the same?
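(For reference, one common way to collect such before/after profiles is JAX's built-in trace profiler; the workload and log directory below are illustrative stand-ins, not the actual training command used in this PR.)

```python
import jax
import jax.numpy as jnp

# Hedged sketch: capture a profile for one run, then repeat on the other
# branch and compare the dumped traces/HLOs. Log dir is illustrative.
def profile_once(log_dir: str):
    with jax.profiler.trace(log_dir):
        x = jnp.ones((512, 512))
        y = jnp.dot(x, x)
        y.block_until_ready()  # ensure the work is captured in the trace
    return y

out = profile_once("/tmp/nnx_profile_before")
print(out.shape)
```

Running the same snippet on each branch in the same environment yields traces that can be diffed side by side.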
bvandermoon left a comment
Why are these models being skipped for now? I think it's fine; I just want to understand whether there is additional context.
Note: The following models are currently not supported:
DeepSeek
Gemma3
Llama4
Hi @bvandermoon,
No major blockers; these models require different handling logic than the other models (as @RissyRan commented earlier). We are addressing them in a separate workstream to ensure they are integrated correctly.
Discussed offline, thanks @hsuan-lun-chiang. Looks like the profiles have different HLOs. Could you also collect profiles for the other runs in the first section of your doc? We can confirm whether it's just GPT3 or something happening everywhere.
Sure! I'm collecting those profiles and will leave a comment to let you know where I documented them (in the testing page, section 1). Thank you!
Thanks @hsuan-lun-chiang. For Llama2-7b and Llama3.1-70B, the before/after memory usage is almost identical, but for some reason the HLOs are not. Can you confirm the before/after runs were run in the same environment with the same JAX/libtpu versions, etc.?
Thank you @bvandermoon for checking this. Yes, I ran the before/after tests by only switching branches, so both the main branch code and our code were tested in the same environment.
Adding nnx_decoders.py in parallel with decoders.py
1. Duplicate decoders.py and modify it into the new file nnx_decoders.py
2. Add a new config option pure_nnx_decoder to control whether the model uses NNXDecoder (default false for now)
3. Modify related code to accommodate the change
4. Add/modify unit tests
Description
Migrate the Transformer decoder layer into NNX.
Note: The following models are currently not supported: DeepSeek, Gemma3, and Llama4.
Support for these models will be added in a follow-up PR.
Strategy:
A pure_nnx_decoder flag is added to control whether the NNX or Linen decoder is used. The initial migration doesn't include pipeline NNX support.
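The flag-gated selection can be pictured roughly as follows (class names and the config accessor are illustrative stand-ins, not MaxText's actual code):

```python
# Minimal sketch of gating the decoder choice on a config flag.
class LinenDecoder:
    kind = "linen"

class NNXDecoder:
    kind = "nnx"

def build_decoder(config: dict):
    # pure_nnx_decoder defaults to False, so existing Linen behavior
    # is preserved unless the flag is set explicitly.
    if config.get("pure_nnx_decoder", False):
        return NNXDecoder()
    return LinenDecoder()

print(build_decoder({}).kind)                          # linen
print(build_decoder({"pure_nnx_decoder": True}).kind)  # nnx
```

Defaulting the flag to false keeps the change opt-in while both decoder implementations coexist.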
Tests
Conducted these tests. Details are in the GDoc file.
TODOs:
Checklist
Before submitting this PR, please make sure (put X in square brackets):
gemini-review label.