
Fix non-scalar input amax in preprocess_linear_fusion for MoE export#1264

Open
AEON-7 wants to merge 1 commit into NVIDIA:main from AEON-7:aeon7/fix-nonscalar-input-amax-moe-export

Conversation


@AEON-7 AEON-7 commented Apr 15, 2026

What

modelopt/torch/export/quant_utils.py::preprocess_linear_fusion unconditionally asserts:

assert modules[0].input_quantizer.amax.numel() == 1, (
    "Only support scalar input quant amax"
)

This breaks NVFP4 quantization for models whose MoE experts are decomposed into per-expert gate_proj/up_proj/down_proj nn.Linear modules (the standard pattern for HuggingFace-compatible export). NVFP4's activation quantizer is per-channel, so input_quantizer.amax is a vector — not a scalar — and the assertion trips immediately on the first expert during export_hf_checkpoint().
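For illustration, the failure mode reduces to the following (the 4096-channel shape is made up; any per-channel activation quantizer produces a vector amax and trips the same assertion):

```python
import torch

# NVFP4-style per-channel activation amax: one value per input channel,
# so numel() is the channel count, not 1.
amax = torch.rand(4096)

try:
    # This mirrors the assertion in preprocess_linear_fusion.
    assert amax.numel() == 1, "Only support scalar input quant amax"
except AssertionError as e:
    print(e)  # prints: Only support scalar input quant amax
```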

How to reproduce

Any MoE model where the modelopt plugin decomposes fused-expert storage (e.g. Gemma 4's [E, 2I, H] + [E, H, I]) into per-expert nn.Linear modules, quantized with NVFP4_AWQ_FULL_CFG. On SuperGemma4 26B (128 experts, 30 layers) the assertion fires during the requantize_resmooth_fused_llm_layers pass, immediately after 2h 24min of successful calibration — calibration state is in-memory only, so all work is lost.

The fix

Branch on amax.numel() == 1:

  • Scalar path (unchanged): torch.max(torch.stack(...)) — behaviour identical to today for dense FP8/INT8.
  • Non-scalar path (new): torch.stack(...).amax(dim=0) — elementwise max across the stacked per-channel amax tensors.

Why this is numerically safe

The modules being fused here (e.g. gate_proj and up_proj of a single expert) consume the same input tensor by construction — modelopt.torch.export.unified_export_hf._fuse_shared_input_modules groups them precisely because they share an input. Their per-channel input amax tensors are therefore identical (up to float accumulation noise), and elementwise max is a no-op. If they ever differ for numerical reasons, elementwise max is the correct unification rule — exactly analogous to the scalar max the existing code uses.
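A quick self-contained check of that argument (the tensor values are invented for illustration): identical per-channel amax tensors pass through elementwise max unchanged, and if noise makes them differ, the result is still an upper bound per channel, which is what amax semantics require.

```python
import torch

# Two per-channel input amax tensors from modules sharing the same input.
a = torch.tensor([0.5, 2.0, 1.25])
b = a.clone()  # identical by construction

# Identical inputs: elementwise max is a no-op.
assert torch.equal(torch.stack([a, b]).amax(dim=0), a)

# Small float-accumulation differences: the unified amax dominates both,
# so no channel's observed range is underestimated.
b_noisy = a + torch.tensor([0.0, 1e-6, -1e-6])
unified = torch.stack([a, b_noisy]).amax(dim=0)
assert torch.all(unified >= a) and torch.all(unified >= b_noisy)
```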

The scalar path is untouched, so dense models and FP8/INT8 MoE paths are unchanged.

Validation

End-to-end on SuperGemma4 26B (Gemma 4 MoE, 128 experts, per-expert-decomposed plugin) with NVFP4_AWQ_FULL_CFG:

  • Before: AssertionError: Only support scalar input quant amax on first expert during export, after 2h 24min of successful calibration.
  • After: preprocess_linear_fusion completes; export produces a valid NVFP4 checkpoint that loads + generates coherent output.

The resulting quantized model ships at AEON-7/supergemma4-26b-abliterated-multimodal-nvfp4.

Follow-up

A companion fix for NVFP4QTensor.get_activation_scaling_factor (handling zero-amax channels from MoE routing sparsity) is coming in a separate PR — it's an orthogonal bug on the same export path.

Summary by CodeRabbit

  • New Features
    • Quantization now handles both scalar and non-scalar activation ranges during linear-layer fusion. This enables per-channel and per-expert activation scaling to be unified correctly across fused modules, improving quantized model behavior and consistency during optimization and inference.

@AEON-7 AEON-7 requested a review from a team as a code owner April 15, 2026 04:25
@AEON-7 AEON-7 requested a review from Edwardf0t1 April 15, 2026 04:25

copy-pr-bot bot commented Apr 15, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.


coderabbitai bot commented Apr 15, 2026

No actionable comments were generated in the recent review. 🎉

ℹ️ Recent review info
⚙️ Run configuration

Configuration used: .coderabbit.yaml

Review profile: CHILL

Plan: Pro Plus

Run ID: aab1e9b3-373a-4415-bafb-b669695114aa

📥 Commits

Reviewing files that changed from the base of the PR and between 3fcc5a7 and 6929ecb.

📒 Files selected for processing (1)
  • modelopt/torch/export/quant_utils.py
🚧 Files skipped from review as they are similar to previous changes (1)
  • modelopt/torch/export/quant_utils.py

📝 Walkthrough


The preprocess_linear_fusion function in modelopt/torch/export/quant_utils.py now handles both scalar and non-scalar activation quantization maximum (amax) values by replacing a scalar-only assertion with conditional branching that aggregates amax either via scalar max or elementwise max across stacked tensors.

Changes

  • Activation amax unification (modelopt/torch/export/quant_utils.py): Replaced the scalar-only amax assertion in preprocess_linear_fusion with conditional handling: when amax is scalar, compute the unified value with torch.max; when non-scalar, compute the elementwise unified amax via torch.stack(...).amax(dim=0) and assign it back to each module's input_quantizer.amax.

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~20 minutes

🚥 Pre-merge checks: ✅ 4 passed
  • Description Check: ✅ Passed. Check skipped because CodeRabbit's high-level summary is enabled.
  • Title Check: ✅ Passed. The title accurately and concisely describes the main change: fixing non-scalar input amax handling in preprocess_linear_fusion for MoE export scenarios.
  • Docstring Coverage: ✅ Passed. Docstring coverage is 100.00%, above the required 80.00% threshold.
  • Security Anti-Patterns: ✅ Passed. No security anti-patterns detected; the changes are algorithmic modifications handling scalar and non-scalar amax tensors using standard PyTorch operations only.


preprocess_linear_fusion unconditionally asserts
`modules[0].input_quantizer.amax.numel() == 1`, which breaks for NVFP4
quantization when the model has per-expert-decomposed MoE linears
(gate_proj/up_proj pairs per expert). NVFP4's per-channel input quantizer
produces a vector amax, not a scalar, so the assertion trips immediately
on the first expert during `export_hf_checkpoint()`.

Root cause: the function was written assuming fused linears have per-tensor
scalar input amax. That's true for dense FP8/INT8 paths but false for
NVFP4's per-channel activation statistics, which modelopt's own
NVFP4_AWQ_FULL_CFG produces.

This change:
- Keeps the existing scalar-amax path (dense + FP8/INT8 unchanged)
- Adds a non-scalar path using elementwise max (`.amax(dim=0)`) across the
  stacked per-channel amax tensors of the modules being fused

Numerical correctness for the MoE case: the modules being fused here
(e.g. gate_proj and up_proj of one expert) consume the *same* input
tensor by construction, so their per-channel input amax tensors are
identical. Elementwise max is therefore a no-op, and is the correct
unification rule if they ever differ due to floating-point accumulation.

Validated end-to-end on SuperGemma4 26B (128-expert MoE) with
NVFP4_AWQ_FULL_CFG; export now completes and the serialized checkpoint
loads + generates correctly. Before: export failed with
`AssertionError: Only support scalar input quant amax` after 2h 24min of
successful calibration.

Signed-off-by: AEON-7 <m2vgz48wpp@privaterelay.appleid.com>
@AEON-7 AEON-7 force-pushed the aeon7/fix-nonscalar-input-amax-moe-export branch from 3fcc5a7 to 6929ecb Compare April 15, 2026 04:32
