Handle zero-amax per-channel activation scaling for MoE export #1265

AEON-7 wants to merge 1 commit into NVIDIA:main
Conversation
Actionable comments posted: 1
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@modelopt/torch/quantization/qtensor/nvfp4_tensor.py`:
- Around line 199-211: Restrict the repair to exact zeros: change zero_mask to
use activation_scaling_factor == 0, then compute positive =
activation_scaling_factor[~zero_mask] and further filter positive =
positive[positive > 0] (so negatives are not considered recoverable); if
positive.numel() > 0 replace zeros with positive.min(), else if there are only
zeros (no negatives present) fall back to torch.full_like(..., 1e-8) to keep the
tensor valid, but if negatives exist leave activation_scaling_factor untouched
so the existing assert can catch the error. Ensure these updates are applied
around the activation_scaling_factor / zero_mask logic in nvfp4_tensor.py.
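The suggested restriction can be sketched as a standalone helper. This is a hypothetical illustration of the reviewer's logic, not the code in `nvfp4_tensor.py`; the function name `repair_zero_scales` is an assumption:

```python
import torch

def repair_zero_scales(scale: torch.Tensor) -> torch.Tensor:
    """Hypothetical sketch of the suggested repair, not the merged code.

    Repairs only exact zeros; negatives are left in place so the existing
    positivity assertion still catches genuine upstream bugs.
    """
    zero_mask = scale == 0
    if not zero_mask.any():
        return scale
    positive = scale[scale > 0]
    if positive.numel() > 0:
        # Replace zeros with the quietest live channel's scale.
        return torch.where(zero_mask, positive.min(), scale)
    if (scale < 0).any():
        # Negatives present: leave untouched so the assert fires downstream.
        return scale
    # Every entry is zero: fall back to a small positive floor.
    return torch.full_like(scale, 1e-8)
```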
Force-pushed from 8b3a4eb to 7cb5851
NVFP4QTensor.get_activation_scaling_factor asserts `torch.all(activation_scaling_factor > 0)`, but on MoE models some per-channel activation amax entries are exactly zero: routing sparsity means certain input slots on rarely-routed experts never receive any tokens during calibration, so their observed amax stays at initialization (zero). The derived scaling factor (`amax / (maxbound * 448)`) is then zero too, and the assertion trips during `export_hf_checkpoint()`. In practice this fires immediately after the (separate) fused-linear fusion step completes, on the first expert whose calibration-time coverage left even a single channel dark. With 128 experts and ~6% activation rate per expert per token, this is routine rather than exceptional.

This change:

- Detects exact-zero entries in the computed scaling factor tensor via `== 0` (not `<= 0`), so that negative entries, which would indicate a genuine upstream bug rather than sparsity, remain untouched and continue to trip the existing positivity assertion instead of being silently masked.
- Replaces the zero entries with the minimum strictly-positive value in the same tensor (elementwise `torch.where`), preserving the per-channel shape and the positivity invariant downstream code relies on.
- Falls back to a small positive floor (1e-8) only when no positive entries exist (every channel in the tensor is zero).

Why this is numerically safe: a zero amax channel means no activation was ever observed there during calibration. Any value flowing through that channel at inference time is therefore statistically near-zero relative to the observed distribution. Scaling that near-zero value by the "quietest live channel's" scaling factor quantizes it to near-zero and dequantizes it back to near-zero: the same end result as with a genuinely zero scale, minus the NaN/division hazards.

Validated end-to-end on SuperGemma4 26B (128-expert Gemma 4 MoE) with `NVFP4_AWQ_FULL_CFG`: export completes, the serialized checkpoint loads into transformers via `mto.restore`, and sampled generation is semantically equivalent to the BF16 baseline on fact-recall, creative, and technical prompts.

Signed-off-by: AEON-7 <m2vgz48wpp@privaterelay.appleid.com>
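The zero-scale derivation described above can be reproduced in isolation. This is an illustrative sketch only; the `maxbound` value here is assumed for the example, not taken from the library:

```python
import torch

# Illustrative repro: one per-channel amax entry is zero because the
# corresponding expert input slot never received a token during calibration.
maxbound = 6.0  # assumed max-magnitude bound, for illustration only
amax = torch.tensor([1.5, 0.0, 2.0])

# Same derivation as in the text: amax / (maxbound * 448)
scale = amax / (maxbound * 448)

# The zero channel yields a zero scale, which is what trips
# torch.all(activation_scaling_factor > 0) at export time.
assert not torch.all(scale > 0)
```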
Force-pushed from 7cb5851 to c6edb16
Good catch, you're right. Fixed. Diff:

```diff
-        zero_mask = activation_scaling_factor <= 0
+        zero_mask = activation_scaling_factor == 0
         if zero_mask.any():
-            positive = activation_scaling_factor[~zero_mask]
-            if positive.numel() > 0:
-                activation_scaling_factor = torch.where(
-                    zero_mask, positive.min(), activation_scaling_factor
-                )
-            else:
-                activation_scaling_factor = torch.full_like(
-                    activation_scaling_factor, 1e-8
-                )
+            positive = activation_scaling_factor[activation_scaling_factor > 0]
+            replacement = (
+                positive.min()
+                if positive.numel() > 0
+                else torch.tensor(
+                    1e-8,
+                    device=activation_scaling_factor.device,
+                    dtype=activation_scaling_factor.dtype,
+                )
+            )
+            activation_scaling_factor = torch.where(
+                zero_mask, replacement, activation_scaling_factor
+            )
```

Thanks for the review.
What
`NVFP4QTensor.get_activation_scaling_factor` asserts `torch.all(activation_scaling_factor > 0)`. On MoE models, some per-channel activation amax entries are exactly zero because routing sparsity leaves certain input slots on rarely-routed experts un-activated during calibration. The derived scaling factor (`amax / (maxbound * 448)`) is then zero and the assertion trips.

How to reproduce
Any MoE model with per-expert-decomposed linears quantized using `NVFP4_AWQ_FULL_CFG`. On SuperGemma4 26B (128 experts, ~6% activation rate per expert per token), this fires on the first expert whose calibration-time coverage left even a single channel dark. It is the routine case, not the edge case.

The fix

Detect zero entries in the computed `activation_scaling_factor` tensor and replace them with the minimum positive value in the same tensor via `torch.where`. Fall back to a small positive floor (1e-8) for the pathological case where every channel in the tensor is zero (block entirely un-activated).

Why this is numerically safe
A zero amax channel means no activation was ever observed there during calibration. Any value flowing through that channel at inference is therefore statistically near-zero relative to the observed distribution. Scaling that near-zero value by the "quietest live channel's" scaling factor quantizes it to near-zero and dequantizes back to near-zero: the same end result as a genuinely zero scale, minus the NaN/division hazards.
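The argument can be checked with a toy round-trip. The integer code range and the specific scale values below are illustrative stand-ins, not the actual NVFP4 quantizer:

```python
import torch

# Toy round-trip: quantize a near-zero activation on a "dark" channel using
# the quietest live channel's scale (all values are illustrative).
scale = torch.tensor(3e-4)   # smallest strictly-positive per-channel scale
x = torch.tensor(1e-6)       # near-zero value on the never-calibrated channel

code = torch.round(x / scale).clamp(-6, 6)  # stand-in integer code range
deq = code * scale

# x / scale is about 0.003, which rounds to code 0, so the value dequantizes
# back to ~0: the same result a genuinely zero scale would give, without the
# division-by-zero hazard.
assert deq.abs().item() < 1e-5
```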
The assertion after the fix remains strict (`torch.all(scale > 0)`), so downstream code that relies on the positivity invariant is unaffected.

Validation
End-to-end on SuperGemma4 26B (Gemma 4 MoE, 128 experts, per-expert-decomposed plugin) with `NVFP4_AWQ_FULL_CFG`:

- Before the fix: `AssertionError: activation scaling factor tensor([...]) not positive.` on a per-channel tensor whose printed head hides the zeros in the `...` ellipsis.
- After the fix: export completes. The resulting quantized model ships at AEON-7/supergemma4-26b-abliterated-multimodal-nvfp4.
Companion PR
Depends on / pairs with #1264 (non-scalar input amax in `preprocess_linear_fusion`). Both are orthogonal bugs on the same NVFP4 + per-expert-MoE export path; this PR fixes the bug that fires after #1264's fix unblocks the fusion step.