
fix(gemma4_moe): use vision-aware attention mask when use_bidirection…#1901

Draft
akoumpa wants to merge 1 commit into NVIDIA-NeMo:main from jQizhang:fix/gemma4-moe-vision-aware-mask

Conversation

@akoumpa
Contributor

@akoumpa akoumpa commented Apr 17, 2026

…al_attention="vision"

Gemma4 multimodal variants with `use_bidirectional_attention="vision"` in their text config (e.g. gemma-4-26B-A4B-it, gemma-4-31B-it) require a vision-aware attention mask that makes tokens inside the same vision group visible to each other bidirectionally. HF's `Gemma4Model.forward` builds this mask via `create_causal_mask_mapping`.
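
To make "vision group" concrete, here is a toy sketch of such a mask (an illustration only, not HF's `create_causal_mask_mapping`; treating each contiguous run of vision tokens as one group is an assumption):

```python
import torch

def vision_aware_mask(token_type_ids: torch.Tensor) -> torch.Tensor:
    """Toy vision-aware mask. token_type_ids: (S,) with 1 = vision, 0 = text.
    Returns an (S, S) bool matrix where True means "query may attend to key"."""
    s = token_type_ids.shape[0]
    allowed = torch.tril(torch.ones(s, s, dtype=torch.bool))  # causal base
    is_vision = token_type_ids == 1
    # Assumption: each contiguous run of vision tokens forms one group.
    starts = is_vision & ~torch.cat([is_vision.new_zeros(1), is_vision[:-1]])
    group = torch.cumsum(starts.long(), dim=0) * is_vision.long()
    # Tokens sharing a non-zero group id become mutually visible.
    same_group = (group[:, None] == group[None, :]) & (group[:, None] > 0)
    return allowed | same_group
```

For `token_type_ids = [0, 1, 1, 0]`, positions 1 and 2 can attend to each other in both directions, while the text positions stay strictly causal.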

The MoE backend `Gemma4MoETextModelBackend.forward` was always building plain `create_causal_mask` + `create_sliding_window_causal_mask` regardless of the config flag. For `gemma-4-26B-A4B-it` this makes the MoE forward numerically diverge from HF on multimodal inputs (vision token logprobs can differ by 20+ in log-space), and increases `train/gen_kl_error` by roughly an order of magnitude during GRPO training (~0.01 vs ~0.001 on text-only).
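
For context, a hypothetical harness for observing this divergence might look like the following (nothing here is code from the PR; model loading and batch construction are assumed):

```python
import torch

@torch.no_grad()
def max_vision_logprob_gap(hf_model, moe_model, batch, vision_mask):
    """Largest absolute next-token logprob gap at vision positions.
    batch: dict with input_ids, pixel_values, etc.; vision_mask: (B, S) bool."""
    hf_lp = torch.log_softmax(hf_model(**batch).logits.float(), dim=-1)
    moe_lp = torch.log_softmax(moe_model(**batch).logits.float(), dim=-1)
    targets = batch["input_ids"][:, 1:].unsqueeze(-1)           # (B, S-1, 1)
    hf_tok = hf_lp[:, :-1, :].gather(-1, targets).squeeze(-1)   # (B, S-1)
    moe_tok = moe_lp[:, :-1, :].gather(-1, targets).squeeze(-1)
    return (hf_tok - moe_tok).abs()[vision_mask[:, 1:]].max()
```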

Fix:

  • Accept `mm_token_type_ids` and `pixel_values` in `Gemma4MoETextModelBackend.forward`.
  • When `config.use_bidirectional_attention == "vision"`, call HF's `create_causal_mask_mapping` (matches `Gemma4Model.forward`); otherwise keep the existing plain causal-mask path (sketched after this list).
  • Plumb `mm_token_type_ids` / `pixel_values` from `Gemma4ForConditionalGeneration.forward` down to the text backend.
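
A minimal sketch of the resulting mask selection, with helper names taken from this description (their import path and exact signatures here are assumptions):

```python
# Mask helpers are named in this PR; the import path is assumed.
from transformers.masking_utils import (
    create_causal_mask,
    create_causal_mask_mapping,
    create_sliding_window_causal_mask,
)

def build_attention_masks(config, mm_token_type_ids=None, **mask_kwargs):
    """Sketch of the gating added to Gemma4MoETextModelBackend.forward."""
    if (getattr(config, "use_bidirectional_attention", None) == "vision"
            and mm_token_type_ids is not None):
        # Vision-aware path, mirroring HF's Gemma4Model.forward.
        return create_causal_mask_mapping(token_type_ids=mm_token_type_ids,
                                          **mask_kwargs)
    # Existing text-only path: plain causal + sliding-window causal masks.
    return {
        "full_attention": create_causal_mask(**mask_kwargs),
        "sliding_attention": create_sliding_window_causal_mask(**mask_kwargs),
    }
```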

Measured impact on gemma-4-26B-A4B-it multimodal forward (HF as ground truth, single synthetic image, 341-token sequence):

| metric | before | after (expected) |
| --- | --- | --- |
| HF vs Automodel gen_kl | 0.034 | ~0.01 (FSDP noise floor) |
| HF vs vLLM gen_kl | 0.092 | 0.092 (unchanged) |
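
For reference on the metric: gen_kl-style numbers in RL loops are typically per-token KL estimates between two backends' logprobs for the sampled tokens; a plausible sketch (this repo's exact definition of `gen_kl` is an assumption):

```python
import torch

def gen_kl(sampler_logprobs: torch.Tensor, other_logprobs: torch.Tensor) -> torch.Tensor:
    """k3-style KL estimate from per-token logprobs of the sampled tokens,
    where sampler_logprobs come from the backend that generated the tokens.
    Non-negative, and ~0 when the two backends agree. Shapes: (num_tokens,)."""
    log_ratio = other_logprobs - sampler_logprobs
    return (log_ratio.exp() - 1.0 - log_ratio).mean()
```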

What does this PR do?

Makes the Gemma4 MoE text backend build the vision-aware attention mask when the text config sets `use_bidirectional_attention="vision"`, so multimodal forwards numerically match the HF reference.

Changelog

  • `Gemma4MoETextModelBackend.forward`: accept `mm_token_type_ids` / `pixel_values` and select HF's `create_causal_mask_mapping` on the vision-aware path; keep the plain causal-mask path otherwise.
  • `Gemma4ForConditionalGeneration.forward`: plumb `mm_token_type_ids` / `pixel_values` down to the text backend.

Before your PR is "Ready for review"

Pre checks:

  • Make sure you read and followed Contributor guidelines
  • Did you write any new necessary tests?
  • Did you add or update any necessary documentation?

If you haven't finished some of the above items, you can still open a "Draft" PR.

Additional Information

  • Related to # (issue)

@copy-pr-bot

copy-pr-bot bot commented Apr 17, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@akoumpa
Contributor Author

akoumpa commented Apr 17, 2026

/ok to test ce4b225

