
vLLM fakequant export update for AWQ checkpoint#1242

Open
kinjalpatel27 wants to merge 19 commits into main from kinjal/vllm_fix_prequant_scale

Conversation

@kinjalpatel27
Contributor

@kinjalpatel27 kinjalpatel27 commented Apr 13, 2026

What does this PR do?

Type of change: Bug fix

Enables end-to-end AWQ checkpoint export and reload in the vLLM fake-quant serving path (MODELOPT_STATE_PATH). Previously, the input_quantizer used an incorrect pre_quant_scale, especially with grouped quantizers such as qkv_proj, where simply the first input_quantizer.pre_quant_scale was taken. This PR adds _resmooth_experts_for_export, which non-mutatively averages pre_quant_scale across MoE experts and unifies the input _amax; this is required because vLLM uses a single input quantizer per expert group. It also adds merge_amax_tensors_for_group (element-wise max for same-shape tensors, concatenation for GQA, scalar-max fallback), replacing the scalar-collapsing torch.stack().max() that dropped per-channel _amax structure.
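The merging rule can be sketched as follows. This is an illustrative reconstruction of the behaviour described above, not the actual merge_amax_tensors_for_group implementation:

```python
import torch

def merge_amax_tensors_for_group(amaxes: list[torch.Tensor]) -> torch.Tensor:
    """Merge per-quantizer _amax tensors for a fused vLLM group (sketch).

    - Same shape: element-wise max, preserving per-channel structure.
    - Same trailing dims, different leading dim (e.g. GQA q/k/v with
      different output dims): concatenate along dim 0 to match the fused
      layer layout.
    - Anything else: fall back to a scalar max.
    """
    shapes = {tuple(a.shape) for a in amaxes}
    if len(shapes) == 1:
        merged = amaxes[0]
        for a in amaxes[1:]:
            merged = torch.maximum(merged, a)
        return merged
    trailing = {tuple(a.shape[1:]) for a in amaxes}
    if len(trailing) == 1 and all(a.dim() >= 1 for a in amaxes):
        return torch.cat(amaxes, dim=0)
    return torch.stack([a.max() for a in amaxes]).max()
```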

Usage

# Export AWQ checkpoint from HF model
from modelopt.torch.export.plugins.vllm_fakequant_hf import export_hf_vllm_fq_checkpoint

export_hf_vllm_fq_checkpoint(model, export_dir="./awq_vllm_checkpoint")

Testing

Step 1 — Export the quantized checkpoint:

python examples/llm_ptq/hf_ptq.py \
  --pyt_ckpt_path <MODEL_PATH> \
  --recipe <AWQ_RECIPE> \
  --calib_size 512 \
  --export_path <EXPORT_DIR> \
  --vllm_fakequant_export

This produces <EXPORT_DIR>/vllm_fq_modelopt_state.pth with the averaged per-expert
pre_quant_scale and unified _amax now included.
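Conceptually, the per-expert resmoothing preserves each expert's output: replacing expert i's per-channel scale s_i with the group average s_avg is compensated by rescaling that expert's weight columns by s_i / s_avg. A minimal sketch under that assumption (the helper name and signature here are hypothetical, not the actual _resmooth_experts_for_export):

```python
import torch

def resmooth_experts(
    weights: list[torch.Tensor], pre_quant_scales: list[torch.Tensor]
) -> tuple[list[torch.Tensor], torch.Tensor]:
    """Average pre_quant_scale across experts and fold the ratio into weights.

    vLLM keeps one input quantizer per expert group, so each expert's
    per-in-feature scale s_i is replaced by the group average s_avg.
    Rescaling expert weight columns by s_i / s_avg keeps
    (x * s_avg) @ W_new.T == (x * s_i) @ W.T for every expert.
    """
    avg = torch.stack(pre_quant_scales).mean(dim=0)
    new_weights = [w * (s / avg) for w, s in zip(weights, pre_quant_scales)]
    return new_weights, avg
```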

Step 2 — Serve via vLLM fakequant worker:

MODELOPT_STATE_PATH=<EXPORT_DIR>/vllm_fq_modelopt_state.pth \
  python examples/vllm_serve/vllm_serve_fakequant.py \
    <EXPORT_DIR> --tensor-parallel-size <TP>

Tested with the following quantization configurations:

FP8_DEFAULT_CFG
FP8_DEFAULT_CFG (input_q disabled)
INT8_SMOOTHQUANT_CFG
INT8_WEIGHT_ONLY_CFG
NVFP4_DEFAULT_CFG
NVFP4_AWQ_LITE_CFG
INT4_AWQ_CFG
NVFP4_AWQ_CFG
NVFP4_DEFAULT_CFG (input_q disabled)

Before your PR is "Ready for review"

Make sure you read and follow Contributor guidelines and your commits are signed (git commit -s -S).

Make sure you read and follow the Security Best Practices (e.g. avoiding hardcoded trust_remote_code=True, torch.load(..., weights_only=False), pickle, etc.).

  • Is this change backward compatible?: ✅
  • If you copied code from any other sources or added a new PIP dependency, did you follow guidance in CONTRIBUTING.md: N/A
  • Did you write any new necessary tests?: N/A
  • Did you update Changelog?: N/A

Additional Information

Summary by CodeRabbit

  • New Features

    • Added Nemotron-style MoE export support and group-aware AWQ resmoothing with optional requantization during export.
    • Improved handling for shared-input / expert groups and tensor-parallel sharding of pre-quantization scales.
  • Bug Fixes

    • Removed AWQ reload limitation from known issues; improved checkpoint validation and safer save/load behavior.
    • Better detection and handling of enabled weight-quantizers and clearer warnings for mismatched checkpoint keys.

@copy-pr-bot

copy-pr-bot bot commented Apr 13, 2026

Auto-sync is disabled for draft pull requests in this repository. Workflows must be run manually.

Contributors can view more details about this message here.

@coderabbitai
Contributor

coderabbitai bot commented Apr 13, 2026

Note

Reviews paused

It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the reviews.auto_review.auto_pause_after_reviewed_commits setting.

Use the following commands to manage reviews:

  • @coderabbitai resume to resume automatic reviews.
  • @coderabbitai review to trigger a single review.


Walkthrough

Updates enable AWQ-related reload/resmoothing and TP-sharding support across vLLM serve examples and HF export plugins, including checkpoint loading changes, quantizer-state prefix remapping and merging, pre-quant-scale sharding, NemotronHMOE support, and export-time AWQ resmoothing and requantization.

Changes

Cohort / File(s) Summary
Docs
examples/vllm_serve/README.md
Removed the "AWQ reload is not supported yet" item from Known Problems.
Fakequant worker / checkpoint flow
examples/vllm_serve/fakequant_worker.py
Use safe_load(...) for checkpoint deserialization; added rank-0 validation comparing checkpoint _pre_quant_scale keys to model quantizers; call shard_pre_quant_scale_for_tp(model) after vLLM parallel-linear restoration; perform a dummy run before applying quantizer state; replace suffix check with is_weight_quantizer_state_key and raise RuntimeError when a weight quantizer stays enabled.
vLLM reload & TP sharding utils
examples/vllm_serve/vllm_reload_utils.py
Added logging and helpers is_weight_quantizer_state_key, merge_amax_tensors_for_group; replaced inline _amax merges with the helper; added prefix-remap inference and rewritten key remapping in convert_dict_to_vllm; ensure missing weight-quantizer state is forced _disabled=True; added TP-sharding helpers (_tp_concat_shard_dims, _narrow_tensor_to_tp_local_shard, _pqs_local_expected_shape, _expected_in_features_for_input_quantizer, shard_pre_quant_scale_for_tp); centralized tensor narrowing and consolidated warnings for extra checkpoint keys.
Export layer utils (MoE support)
modelopt/torch/export/layer_utils.py
Expanded docstring and examples; added NemotronHMOE detection in is_moe and get_expert_linear_names, mapping expert linears to ["up_proj", "down_proj"]; reformatted get_experts_list signature layout.
vLLM fakequant HF export plugin
modelopt/torch/export/plugins/vllm_fakequant_hf.py
Exported is_weight_quantizer_state_key and merge_amax_tensors_for_group; added AWQ resmoothing that averages _pre_quant_scale across grouped linears (including MoE experts), optionally merges _amax, applies patched quantizer-state overrides, and performs export-time requantization when weights changed; switched _disabled checks to is_enabled and use safe_save for final modelopt state.
Unified export helper rename
modelopt/torch/export/unified_export_hf.py
Renamed internal _collect_shared_input_modules(...) → collect_shared_input_modules(...) and updated call sites (no behaviour changes).
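The TP-sharding of pre-quant scales listed above amounts to narrowing the full per-in-feature scale vector to each rank's local slice. A hypothetical sketch of that narrowing step (the actual helpers are shard_pre_quant_scale_for_tp and friends, which operate on the whole model; the function name here is made up for illustration):

```python
import torch

def narrow_pqs_to_tp_local(
    pre_quant_scale: torch.Tensor, tp_rank: int, tp_size: int
) -> torch.Tensor:
    """Narrow a full per-in-feature pre_quant_scale to a rank's TP shard.

    Row-parallel linears split in_features across tensor-parallel ranks,
    so the input quantizer on each rank only sees
    in_features // tp_size channels.
    """
    full = pre_quant_scale.shape[-1]
    assert full % tp_size == 0, "in_features must divide evenly across TP ranks"
    local = full // tp_size
    return pre_quant_scale.narrow(-1, tp_rank * local, local)
```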

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~75 minutes

🚥 Pre-merge checks | ✅ 4
✅ Passed checks (4 passed)
Check name Status Explanation
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.
Title check ✅ Passed The title 'vLLM fakequant export update for AWQ checkpoint' directly addresses the main objective: fixing AWQ checkpoint export and reload via a fakequant serving export path update.
Docstring Coverage ✅ Passed Docstring coverage is 84.85% which is sufficient. The required threshold is 80.00%.
Security Anti-Patterns ✅ Passed PR changes adhere to security practices with safe deserialization, proper trust_remote_code configuration, and no unsafe patterns.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.

✨ Finishing Touches
🧪 Generate unit tests (beta)
  • Create PR with unit tests
  • Commit unit tests in branch kinjal/vllm_fix_prequant_scale

Comment @coderabbitai help to get the list of available commands and usage tips.

Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

🧹 Nitpick comments (2)
examples/vllm_serve/fakequant_worker.py (2)

64-71: Adequate security justification for weights_only=False.

The inline comment correctly explains why weights_only=False is required (ModelOpt state contains metadata, nested dicts, dtypes that PyTorch's weights_only=True rejects) and appropriately warns users to only load trusted paths. This satisfies the coding guideline requirement for documented justification.

However, consider adding a note that this file format is internally-generated by export_hf_vllm_fq_checkpoint to strengthen the justification.

💡 Optional: Strengthen the security comment
         # map_location="cpu": load tensors on CPU so device ids in the file need not match this worker.
         # weights_only=False: ``vllm_fq_modelopt_state.pth`` is a full ModelOpt pickle (metadata,
         # nested dicts, dtypes, etc.); PyTorch's ``weights_only=True`` rejects that and only
         # allows tensor-only checkpoints. Loading arbitrary pickles can execute stored code—use
-        # paths you trust (your own exports or verified checkpoints).
+        # paths you trust. This file format is internally-generated by export_hf_vllm_fq_checkpoint;
+        # only load checkpoints you created or from trusted sources.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/vllm_serve/fakequant_worker.py` around lines 64 - 71, The comment
justifying weights_only=False is fine but should explicitly state that the
ModelOpt pickle format is produced by our exporter; update the inline comment
near the torch.load call that reads modelopt_state =
torch.load(quant_config["modelopt_state_path"], weights_only=False,
map_location="cpu") to mention that the checkpoint file is the internal format
created by export_hf_vllm_fq_checkpoint, so only trusted exporter outputs should
be loaded; keep the existing security warning about arbitrary pickle execution
and retain the explanation why weights_only=False is required.

136-136: Consider documenting why _dummy_run is needed here.

The _dummy_run(1) call was added before load_state_dict_from_path, but there's no comment explaining why this is necessary. If this initializes lazy modules or allocates buffers needed for state dict loading, a brief comment would help maintainability.

📝 Optional: Add explanatory comment
         if quantizer_file_path:
+            # Ensure model buffers/lazy modules are initialized before loading state dict
             self.model_runner._dummy_run(1)
             current_state_dict = load_state_dict_from_path(self, quantizer_file_path, model)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/vllm_serve/fakequant_worker.py` at line 136, Add a brief inline
comment above the self.model_runner._dummy_run(1) call explaining why the dummy
run is required (e.g., it triggers lazy module initialization or allocates
buffers so load_state_dict_from_path can succeed), referencing the involved
symbols (_dummy_run, load_state_dict_from_path, model_runner) so future
maintainers understand the dependency between the dummy run and subsequent state
dict loading.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@modelopt/torch/export/plugins/vllm_fakequant_hf.py`:
- Around line 56-70: In _collect_expert_pre_quant_scales replace the private
attribute check getattr(iq, "_disabled", False) with the public API by testing
iq.is_enabled (negated) so the condition becomes: if iq is None or not
iq.is_enabled or iq.pre_quant_scale is None: return None; update the check in
that function to use the public is_enabled property of the input_quantizer for
consistency with other code (e.g., other uses of is_enabled in the file).

---

Nitpick comments:
In `@examples/vllm_serve/fakequant_worker.py`:
- Around line 64-71: The comment justifying weights_only=False is fine but
should explicitly state that the ModelOpt pickle format is produced by our
exporter; update the inline comment near the torch.load call that reads
modelopt_state = torch.load(quant_config["modelopt_state_path"],
weights_only=False, map_location="cpu") to mention that the checkpoint file is
the internal format created by export_hf_vllm_fq_checkpoint, so only trusted
exporter outputs should be loaded; keep the existing security warning about
arbitrary pickle execution and retain the explanation why weights_only=False is
required.
- Line 136: Add a brief inline comment above the self.model_runner._dummy_run(1)
call explaining why the dummy run is required (e.g., it triggers lazy module
initialization or allocates buffers so load_state_dict_from_path can succeed),
referencing the involved symbols (_dummy_run, load_state_dict_from_path,
model_runner) so future maintainers understand the dependency between the dummy
run and subsequent state dict loading.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: 74cd1761-c237-4334-a1d7-e995964be2e1

📥 Commits

Reviewing files that changed from the base of the PR and between 687ceea and b83afd8.

📒 Files selected for processing (5)
  • examples/vllm_serve/README.md
  • examples/vllm_serve/fakequant_worker.py
  • examples/vllm_serve/vllm_reload_utils.py
  • modelopt/torch/export/layer_utils.py
  • modelopt/torch/export/plugins/vllm_fakequant_hf.py
💤 Files with no reviewable changes (1)
  • examples/vllm_serve/README.md

Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
modelopt/torch/export/plugins/vllm_fakequant_hf.py (1)

282-344: ⚠️ Potential issue | 🟠 Major

Restore quantizer state in a finally block.

Lines 291-295 mutate the in-memory model, but restoration only happens on Lines 342-344. If get_quantizer_state_dict, safe_save, or save_pretrained raises, the caller is left with disabled weight quantizers and patched _rotate values. This should be guarded with try/finally.

💡 Proposed fix
     wqs_to_restore: list[tuple[TensorQuantizer, Any]] = []
-    for _, module in model.named_modules():
-        if isinstance(module, QuantModule):
-            for attr_name, quantizer in module.named_children():
-                if (
-                    attr_name.endswith("weight_quantizer")
-                    and isinstance(quantizer, TensorQuantizer)
-                    and quantizer.is_enabled
-                ):
-                    quantizer.disable()
-                    orig_rotate = quantizer._rotate
-                    if quantizer.rotate_is_enabled:
-                        quantizer._rotate = disable_rotate(quantizer)
-                    wqs_to_restore.append((quantizer, orig_rotate))
-
-    quantizer_state_dict = get_quantizer_state_dict(model)
-    for key in list(quantizer_state_dict):
-        if is_weight_quantizer_state_key(key):
-            quantizer_state_dict.pop(key)
-        elif key in input_quantizers_folded_pqs:
-            qstate_val = quantizer_state_dict[key]
-            if isinstance(qstate_val, dict) and "_pre_quant_scale" in qstate_val:
-                quantizer_state_dict[key]["_pre_quant_scale"] = torch.ones_like(
-                    qstate_val["_pre_quant_scale"]
-                )
-
-    for iq_key, (avg_pqs, max_input_amax) in expert_pqs_overrides.items():
-        if iq_key in quantizer_state_dict:
-            qstate_val = quantizer_state_dict[iq_key]
-            if isinstance(qstate_val, dict):
-                if "_pre_quant_scale" in qstate_val:
-                    qstate_val["_pre_quant_scale"] = avg_pqs
-                if max_input_amax is not None and "_amax" in qstate_val:
-                    qstate_val["_amax"] = max_input_amax
-
-    modelopt_state = mto.modelopt_state(model)
-    qstate = quantizer_state(model)
-    for key in list(qstate):
-        if is_weight_quantizer_state_key(key) and qstate[key].get("_disabled"):
-            qstate.pop(key)
-
-    for mode_str, m_state in modelopt_state.get("modelopt_state_dict", []):
-        if mode_str == "quantize" and "metadata" in m_state:
-            m_state["metadata"]["quantizer_state"] = qstate
-            break
-
-    modelopt_state["modelopt_state_weights"] = quantizer_state_dict
-    safe_save(modelopt_state, export_dir / "vllm_fq_modelopt_state.pth")
-
-    model.save_pretrained(export_dir, state_dict=clean_sd, save_modelopt_state=False)
-
-    for wq, orig_rotate in wqs_to_restore:
-        wq.enable()
-        wq._rotate = orig_rotate
+    try:
+        for _, module in model.named_modules():
+            if isinstance(module, QuantModule):
+                for attr_name, quantizer in module.named_children():
+                    if (
+                        attr_name.endswith("weight_quantizer")
+                        and isinstance(quantizer, TensorQuantizer)
+                        and quantizer.is_enabled
+                    ):
+                        quantizer.disable()
+                        orig_rotate = quantizer._rotate
+                        if quantizer.rotate_is_enabled:
+                            quantizer._rotate = disable_rotate(quantizer)
+                        wqs_to_restore.append((quantizer, orig_rotate))
+
+        quantizer_state_dict = get_quantizer_state_dict(model)
+        # ... existing quantizer_state_dict/modelopt_state/save logic ...
+        safe_save(modelopt_state, export_dir / "vllm_fq_modelopt_state.pth")
+        model.save_pretrained(export_dir, state_dict=clean_sd, save_modelopt_state=False)
+    finally:
+        for wq, orig_rotate in reversed(wqs_to_restore):
+            wq.enable()
+            wq._rotate = orig_rotate
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/export/plugins/vllm_fakequant_hf.py` around lines 282 - 344,
Collecting disabled weight quantizers into wqs_to_restore then performing
operations (get_quantizer_state_dict, safe_save, model.save_pretrained) can
leave the model mutated if an exception occurs; wrap the mutation scope in a
try/finally: after the loop that builds wqs_to_restore and disables quantizers
(refer to QuantModule, TensorQuantizer, disable_rotate and wqs_to_restore), run
the subsequent state-building, patching, safe_save, and model.save_pretrained
calls inside a try block and in the finally iterate wqs_to_restore to restore
each quantizer by calling wq.enable() and resetting wq._rotate to orig_rotate so
restoration always occurs even if get_quantizer_state_dict, safe_save, or
model.save_pretrained raises.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@modelopt/torch/export/plugins/vllm_fakequant_hf.py`:
- Around line 173-178: The current merge reduces per-expert amax tensors to a
scalar via torch.stack(...).max(), losing per-channel/block layouts and failing
for differing-but-mergeable shapes; replace that reduction by calling the
existing helper merge_amax_tensors_for_vllm_group() to compute max_in_amax.
Specifically, where iq0 = experts[0].input_quantizer and you build amaxes =
[e.input_quantizer.amax for e in experts], pass that amaxes list into
merge_amax_tensors_for_vllm_group() and assign its result to max_in_amax instead
of using torch.stack(...).max(), preserving layout and handling shape
differences.

---

Outside diff comments:
In `@modelopt/torch/export/plugins/vllm_fakequant_hf.py`:
- Around line 282-344: Collecting disabled weight quantizers into wqs_to_restore
then performing operations (get_quantizer_state_dict, safe_save,
model.save_pretrained) can leave the model mutated if an exception occurs; wrap
the mutation scope in a try/finally: after the loop that builds wqs_to_restore
and disables quantizers (refer to QuantModule, TensorQuantizer, disable_rotate
and wqs_to_restore), run the subsequent state-building, patching, safe_save, and
model.save_pretrained calls inside a try block and in the finally iterate
wqs_to_restore to restore each quantizer by calling wq.enable() and resetting
wq._rotate to orig_rotate so restoration always occurs even if
get_quantizer_state_dict, safe_save, or model.save_pretrained raises.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro Plus

Run ID: 45fc4919-94c3-4dd3-81f8-965e3196b707

📥 Commits

Reviewing files that changed from the base of the PR and between b83afd8 and 0acc835.

📒 Files selected for processing (3)
  • examples/vllm_serve/fakequant_worker.py
  • examples/vllm_serve/vllm_reload_utils.py
  • modelopt/torch/export/plugins/vllm_fakequant_hf.py
🚧 Files skipped from review as they are similar to previous changes (2)
  • examples/vllm_serve/fakequant_worker.py
  • examples/vllm_serve/vllm_reload_utils.py

Base automatically changed from kinjal/vllm_super_nano_support to main April 14, 2026 23:40
@kinjalpatel27 kinjalpatel27 force-pushed the kinjal/vllm_fix_prequant_scale branch from 0acc835 to 6806f16 Compare April 15, 2026 00:51
Contributor

@coderabbitai coderabbitai bot left a comment


🧹 Nitpick comments (2)
examples/vllm_serve/vllm_reload_utils.py (1)

200-229: Consider narrowing the exception handler.

The except Exception at line 227 catches all exceptions, including unexpected ones that may indicate bugs rather than benign mapping failures. Consider catching more specific exceptions (e.g., KeyError, ValueError, AttributeError) or at least logging at a higher level (warning) for unexpected exception types.

♻️ Suggested refinement
         try:
             result = map_fun({probe_key: probe_weight})
             if result:
                 new_key = next(iter(result))
                 new_first = new_key.split(".")[0]
                 if new_first != first_component:
                     prefix_remap[first_component] = new_first
-        except Exception as e:
+        except (KeyError, ValueError, AttributeError, StopIteration) as e:
             logging.getLogger(__name__).debug("prefix-remap probe failed for %r: %s", probe_key, e)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/vllm_serve/vllm_reload_utils.py` around lines 200 - 229, The except
Exception in _infer_prefix_remap is too broad; narrow it to handle only expected
mapping failures (e.g., catch KeyError, ValueError, AttributeError) around the
map_fun({probe_key: probe_weight}) call and handle other exceptions by
re-raising or logging them at a higher severity; specifically, replace the
blanket except Exception as e with targeted except blocks for
KeyError/ValueError/AttributeError that emit a debug log mentioning probe_key
and continue, and add a fallback except Exception that logs a warning or
re-raises so unexpected errors aren’t silently swallowed.
examples/vllm_serve/fakequant_worker.py (1)

133-133: Add a comment explaining the purpose of the dummy run.

The _dummy_run(1) call was added but its purpose isn't documented. Is this needed to trigger lazy parameter initialization before loading the quantizer state? A brief comment would help future maintainers understand why this is necessary.

📝 Suggested documentation
+            # Trigger lazy initialization so model parameters are materialized before loading state.
             self.model_runner._dummy_run(1)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/vllm_serve/fakequant_worker.py` at line 133, The call to
self.model_runner._dummy_run(1) lacks documentation; add an inline comment next
to the call explaining that _dummy_run(1) forces a single inference step to
trigger lazy parameter/module initialization (so all tensors/shapes are created)
before loading the quantizer state, preventing mismatched or missing parameters
when the quantizer is applied; reference the _dummy_run method on model_runner
and mention it intentionally runs exactly one step for this initialization
side-effect.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Nitpick comments:
In `@examples/vllm_serve/fakequant_worker.py`:
- Line 133: The call to self.model_runner._dummy_run(1) lacks documentation; add
an inline comment next to the call explaining that _dummy_run(1) forces a single
inference step to trigger lazy parameter/module initialization (so all
tensors/shapes are created) before loading the quantizer state, preventing
mismatched or missing parameters when the quantizer is applied; reference the
_dummy_run method on model_runner and mention it intentionally runs exactly one
step for this initialization side-effect.

In `@examples/vllm_serve/vllm_reload_utils.py`:
- Around line 200-229: The except Exception in _infer_prefix_remap is too broad;
narrow it to handle only expected mapping failures (e.g., catch KeyError,
ValueError, AttributeError) around the map_fun({probe_key: probe_weight}) call
and handle other exceptions by re-raising or logging them at a higher severity;
specifically, replace the blanket except Exception as e with targeted except
blocks for KeyError/ValueError/AttributeError that emit a debug log mentioning
probe_key and continue, and add a fallback except Exception that logs a warning
or re-raises so unexpected errors aren’t silently swallowed.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro Plus

Run ID: d9d3fc38-66bf-4e1a-9251-2e678af69a2f

📥 Commits

Reviewing files that changed from the base of the PR and between 0acc835 and 6806f16.

📒 Files selected for processing (5)
  • examples/vllm_serve/README.md
  • examples/vllm_serve/fakequant_worker.py
  • examples/vllm_serve/vllm_reload_utils.py
  • modelopt/torch/export/layer_utils.py
  • modelopt/torch/export/plugins/vllm_fakequant_hf.py
💤 Files with no reviewable changes (1)
  • examples/vllm_serve/README.md
🚧 Files skipped from review as they are similar to previous changes (2)
  • modelopt/torch/export/layer_utils.py
  • modelopt/torch/export/plugins/vllm_fakequant_hf.py

@github-actions
Contributor

github-actions bot commented Apr 15, 2026

PR Preview Action v1.8.1


🚀 View preview at
https://NVIDIA.github.io/Model-Optimizer/pr-preview/pr-1242/

Built to branch gh-pages at 2026-04-17 00:55 UTC.
Preview will be ready when the GitHub Pages deployment is complete.

@codecov

codecov bot commented Apr 15, 2026

Codecov Report

❌ Patch coverage is 64.48598% with 76 lines in your changes missing coverage. Please review.
✅ Project coverage is 72.49%. Comparing base (f238d93) to head (9a6f5d2).
⚠️ Report is 3 commits behind head on main.

Files with missing lines Patch % Lines
modelopt/torch/export/plugins/vllm_fakequant_hf.py 64.87% 72 Missing ⚠️
modelopt/torch/export/layer_utils.py 33.33% 4 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1242      +/-   ##
==========================================
- Coverage   75.58%   72.49%   -3.10%     
==========================================
  Files         459      459              
  Lines       48613    48817     +204     
==========================================
- Hits        36745    35388    -1357     
- Misses      11868    13429    +1561     
Flag Coverage Δ
examples 41.27% <12.14%> (+11.43%) ⬆️
gpu 51.70% <49.06%> (-8.80%) ⬇️
unit 52.14% <26.16%> (-0.07%) ⬇️

Flags with carried forward coverage won't be shown. Click here to find out more.

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

@kinjalpatel27 kinjalpatel27 changed the title Kinjal/vllm fix prequant scale vLLM fakequant export update for AWQ checkpoint Apr 15, 2026
@kinjalpatel27 kinjalpatel27 marked this pull request as ready for review April 15, 2026 01:37
@kinjalpatel27 kinjalpatel27 requested review from a team as code owners April 15, 2026 01:37
@kinjalpatel27 kinjalpatel27 requested a review from meenchen April 15, 2026 01:37
Contributor

@coderabbitai coderabbitai bot left a comment


🧹 Nitpick comments (1)
modelopt/torch/export/plugins/vllm_fakequant_hf.py (1)

165-175: Consider extracting the device/dtype conversion outside the loop for clarity.

The reassignment of avg_pre_quant_scale on each iteration (lines 165-167) works correctly when all experts share the same device/dtype (expected for MoE), but the pattern can be confusing. A minor clarification would make the intent clearer:

🔧 Optional refactor for clarity
+            # Move to common device/dtype (all experts expected to be on same device)
+            ref_pqs = experts[0].input_quantizer._pre_quant_scale
+            avg_pre_quant_scale = avg_pre_quant_scale.to(device=ref_pqs.device, dtype=ref_pqs.dtype)
+
             for ex in experts:
                 nm = id_to_name.get(id(ex))
                 if nm is None or f"{nm}.weight" not in state_dict:
                     continue
                 old_pre_quant_scale = ex.input_quantizer._pre_quant_scale
-                avg_pre_quant_scale = avg_pre_quant_scale.to(
-                    device=old_pre_quant_scale.device, dtype=old_pre_quant_scale.dtype
-                )
                 if torch.equal(old_pre_quant_scale, avg_pre_quant_scale):
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/export/plugins/vllm_fakequant_hf.py` around lines 165 - 175,
The loop currently reassigns avg_pre_quant_scale via .to(...) inline which is
confusing; extract that conversion into a clearly named variable (e.g.,
avg_pre_quant_scale_converted) before using it so you compare
old_pre_quant_scale to avg_pre_quant_scale_converted and reuse that converted
tensor when computing updated_weight (which references weight,
old_pre_quant_scale and avg_pre_quant_scale). Ensure you replace all occurrences
of the inline .to(...) on avg_pre_quant_scale with the new converted variable to
avoid repeated conversions while keeping the same device/dtype semantics.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro Plus

Run ID: 673f5605-3950-42d4-b14e-576cda6a0306

📥 Commits

Reviewing files that changed from the base of the PR and between 6806f16 and f55d623.

📒 Files selected for processing (1)
  • modelopt/torch/export/plugins/vllm_fakequant_hf.py


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 2

🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@modelopt/torch/export/plugins/vllm_fakequant_hf.py`:
- Around line 207-210: The probe forward helper _dummy_forward should avoid
building autograd graphs; wrap the model invocation inside a
torch.inference_mode() context so the call model(torch.ones([1, 2],
dtype=torch.long, device=dev)) executes without tracking gradients or allocating
unnecessary graph tensors; update the body of _dummy_forward to use with
torch.inference_mode(): and keep the existing contextlib.suppress(Exception)
behavior and hook side-effects intact.
- Around line 158-159: avg_pqs is being averaged while still in fp16/bf16 (from
pqs_list), so comparing to torch.finfo(torch.float32).tiny underflows when cast
and allows zeros; compute the averaged pre_quant_scale in fp32: cast pqs_list
tensors (or the stacked result) to torch.float32 before calling mean and clamp
using torch.finfo(torch.float32).tiny, then use that fp32-clamped avg_pqs to
compute pre_quant_scale (and only cast back to the original dtype if required
later, e.g., where pre_quant_scale is used downstream such as in the weight
division code that currently references pre_quant_scale).

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro Plus

Run ID: ebe9f85e-39c6-4ca8-85d0-db896c435a03

📥 Commits

Reviewing files that changed from the base of the PR and between f55d623 and 6e17da8.

📒 Files selected for processing (2)
  • modelopt/torch/export/plugins/vllm_fakequant_hf.py
  • modelopt/torch/export/unified_export_hf.py


@cjluo-nv cjluo-nv left a comment


This PR makes significant improvements to enable AWQ checkpoint export and reload in the vLLM fakequant serving path. The changes are substantial (~473 lines across 6 files) and touch critical export/reload paths. While the overall design is sound and addresses a real gap (AWQ export was listed as unsupported), there are several correctness concerns, missing tests for new utility functions, and a potentially risky silent-failure pattern that should be addressed before merging.

What's good:

  • Proper refactoring of TP sharding into _narrow_tensor_to_tp_local_shard (eliminates duplication)
  • merge_amax_tensors_for_group correctly handles same-shape, different-shape (GQA), and fallback cases
  • Non-mutative resmoothing via _resmooth_experts_for_export avoids modifying the live model
  • Use of safe_load/safe_save for security best practices
  • Good use of is_weight_quantizer_state_key regex for consistent matching
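The three merge cases named above can be sketched as follows. This is an illustrative reimplementation, not the PR's actual function — the real signature and dispatch may differ:

```python
import torch

def merge_amax_tensors_for_group(amax_list):
    """Merge per-quantizer _amax tensors for a fused group (sketch).

    - same shape: element-wise max, preserving per-channel structure
    - different shapes (e.g. GQA q/k/v): concatenate along dim 0
    - incompatible shapes: fall back to a single scalar max
    """
    shapes = {tuple(a.shape) for a in amax_list}
    if len(shapes) == 1:
        merged = amax_list[0]
        for a in amax_list[1:]:
            merged = torch.maximum(merged, a)
        return merged
    try:
        return torch.cat(amax_list, dim=0)
    except RuntimeError:
        return torch.stack([a.max() for a in amax_list]).max()

# Same-shape per-channel amax: element-wise max, shape preserved.
a = torch.tensor([1.0, 3.0])
b = torch.tensor([2.0, 1.0])
print(merge_amax_tensors_for_group([a, b]))  # → tensor([2., 3.])
```

Contrast with the replaced `torch.stack().max()`, which would collapse `[a, b]` to the scalar `3.0` and lose per-channel resolution.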

Key concerns:

  1. No unit tests for new utility functions (merge_amax_tensors_for_group, is_weight_quantizer_state_key, _resmooth_experts_for_export, shard_pre_quant_scale_for_tp, _infer_prefix_remap). These are critical correctness-sensitive functions.
  2. A silent contextlib.suppress(Exception) in _resmooth_experts_for_export that could mask real failures.
  3. The _collect_expert_pre_quant_scales function name mentions "experts" but is used for all shared-input groups including dense GQA projections, which could be confusing.
  4. The existing test test_hf_vllm_export only covers FP8, but AWQ is the main feature being added — no AWQ test case was added.


@cjluo-nv cjluo-nv left a comment


This is a substantial and well-designed PR that enables AWQ checkpoint export/reload in vLLM fakequant. Many previous review comments have been addressed (renamed functions, is_enabled instead of _disabled, merge_amax_tensors_for_group usage, dummy forward logging, device alignment for deepcopy, invariant assertions, tests added for infer_quantizer_prefix_remap and merge_amax_tensors_for_group).

However, several critical issues from previous reviews remain unresolved:

  1. Missing try/finally for wqs_to_restore (lines 430-493 of vllm_fakequant_hf.py) — weight quantizers are disabled but restoration only happens at the end. If safe_save or save_pretrained throws, the model is left in a mutated state. This was flagged in the second review iteration and is still unresolved.

  2. avg_pqs computed in half precision (line 271) — torch.stack(pqs_list).mean(0) preserves the original fp16/bf16 dtype, then clamping with torch.finfo(torch.float32).tiny is meaningless because that value underflows to zero in fp16. Division by zero may produce inf weights. This was flagged as Major in the 4th review.

  3. Missing torch.inference_mode() in _dummy_forward (line 307) — Builds unnecessary autograd graphs during the probe forward. Flagged in 4th review.

  4. Bare assert in _check_all_weight_quantizers_disabled (line 127) — can be optimized away with -O. Per codebase guidelines, use RuntimeError.

  5. Bug in requant_weights_for_export — The single-quantizer calibration path runs sequence_quantizers[0](weight) without .float() but the application pass uses weight.float(). While the intent seems correct (calibrate then apply), the inconsistency could lead to precision-dependent amax values.

  6. Duplicate _infer_prefix_remap — Two implementations exist: infer_quantizer_prefix_remap (in vllm_fakequant_hf.py, with consistency checks) and _infer_prefix_remap (in vllm_reload_utils.py, without consistency checks). The reload-side function should delegate to the public API to avoid divergence.

The test coverage has improved with tests for merge_amax_tensors_for_group and infer_quantizer_prefix_remap, but key functions like _resmooth_experts_for_export, requant_weights_for_export, is_weight_quantizer_state_key, and shard_pre_quant_scale_for_tp still lack tests.
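The `torch.inference_mode()` fix from point 3 can be sketched with a stand-in model; `Probe` is hypothetical, and only the shape of `_dummy_forward` mirrors the snippet under review:

```python
import contextlib
import torch

class Probe(torch.nn.Module):
    """Tiny stand-in model that records whether inference mode was active."""
    def forward(self, x):
        self.was_inference = torch.is_inference_mode_enabled()
        return x

def _dummy_forward(model, dev="cpu"):
    # Probe forward that triggers calibration hooks without building
    # autograd graphs; exceptions are still swallowed as in the original.
    with contextlib.suppress(Exception), torch.inference_mode():
        model(torch.ones([1, 2], dtype=torch.long, device=dev))

m = Probe()
_dummy_forward(m)
print(m.was_inference)  # True
```

Because the probe's outputs are discarded, skipping autograd tracking here costs nothing and avoids transient graph allocations.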

@kinjalpatel27 kinjalpatel27 requested a review from cjluo-nv April 16, 2026 01:51

@cjluo-nv cjluo-nv left a comment


All critical issues from previous reviews have been addressed:

  1. try/finally for wqs_to_restore — ✅ Fixed. Weight quantizer disable/restore is now properly wrapped in try/finally.
  2. fp32 averaging of avg_pqs — ✅ Fixed. Uses torch.stack([p.float() for p in pqs_list]).mean(0) to avoid fp16/bf16 underflow.
  3. torch.inference_mode() in _dummy_forward — ✅ Fixed.
  4. Bare assert → RuntimeError — ✅ Fixed across all changed files.
  5. _disabled → is_enabled public API — ✅ Fixed.
  6. merge_amax_tensors_for_group usage — ✅ Used for all amax merging, preserving per-channel structure.
  7. Duplicate _infer_prefix_remap — ✅ Eliminated. Reload side now imports infer_quantizer_prefix_remap from the export module.
  8. Weight quantizer invariant assertion — ✅ Added proper validation.
  9. deepcopy device alignment — ✅ requant_weights_for_export does .to(device=weight.device).
  10. safe_save/safe_load — ✅ Used instead of raw torch.save/torch.load.

Tests have been meaningfully expanded: GPU test now covers INT4_AWQ_CFG in addition to FP8, and unit tests cover infer_quantizer_prefix_remap (7 cases) and merge_amax_tensors_for_group (3 cases). The code is well-structured with clear documentation and reasonable design decisions throughout.

Signed-off-by: Kinjal Patel <kinjalpravin@nvidia.com>
@kinjalpatel27 kinjalpatel27 force-pushed the kinjal/vllm_fix_prequant_scale branch from ea4035e to 585cc06 on April 16, 2026 17:21

@meenchen meenchen left a comment


Change LGTM overall, request to add a unit test for MoE

Signed-off-by: Kinjal Patel <kinjalpravin@nvidia.com>

@mxinO mxinO left a comment


LGTM, thanks.
One further question: we are solving the pqs fusing at export time, which is sub-optimal. This is really a calibration-time problem, since all our inference frameworks use fused q, k, v and experts. Any thoughts on solving it at calibration time instead? cc @realAsma

## Known Problems

1. **MCore reload does not use `MODELOPT_STATE_PATH`**; use `QUANT_FILE_PATH` and make sure `QUANT_CFG` matches the quantization recipe used for the original MCore model (otherwise quantizer keys/config won’t align).
2. AWQ reload is not supported yet

Change 3, 4 -> 2, 3 (renumber the remaining items)

for quantizer_copy in quantizers:
    weight_quantized = quantizer_copy(weight_quantized)
for quantizer_copy in quantizers:
    finish_stats_collection(quantizer_copy)

This logic looks like it can be simplified: distinguishing the len(quantizers) == 1 case is redundant, since the loop in the else branch already handles it.
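A sketch of the suggested simplification, with `quantizers` stubbed as plain callables and the helper names assumed from the snippet above — the loops handle one or many quantizers uniformly, so no separate single-quantizer branch is needed:

```python
def requant_weights(weight, quantizers, finish_stats_collection):
    # Apply every quantizer in sequence; a list of length 1 falls out
    # of the same loop, so no len(quantizers) == 1 special case.
    weight_quantized = weight
    for quantizer_copy in quantizers:
        weight_quantized = quantizer_copy(weight_quantized)
    for quantizer_copy in quantizers:
        finish_stats_collection(quantizer_copy)
    return weight_quantized

# Stub quantizers for illustration only.
out = requant_weights(5, [lambda x: x + 1, lambda x: x * 2], lambda q: None)
print(out)  # 12
```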
