vLLM fakequant export update for AWQ checkpoint #1242
kinjalpatel27 wants to merge 19 commits into main from
Conversation
Auto-sync is disabled for draft pull requests in this repository. Workflows must be run manually. Contributors can view more details about this message here.
Note: Reviews paused. It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior in the CodeRabbit settings.
📝 Walkthrough
Updates enable AWQ-related reload/resmoothing and TP-sharding support across vLLM serve examples and HF export plugins, including checkpoint loading changes, quantizer-state prefix remapping and merging, pre-quant-scale sharding, NemotronHMOE support, and export-time AWQ resmoothing and requantization.
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~75 minutes
🚥 Pre-merge checks: ✅ 4 passed
Actionable comments posted: 1
🧹 Nitpick comments (2)
examples/vllm_serve/fakequant_worker.py (2)
64-71: Adequate security justification for `weights_only=False`.
The inline comment correctly explains why `weights_only=False` is required (ModelOpt state contains metadata, nested dicts, dtypes that PyTorch's `weights_only=True` rejects) and appropriately warns users to only load trusted paths. This satisfies the coding guideline requirement for documented justification. However, consider adding a note that this file format is internally generated by `export_hf_vllm_fq_checkpoint` to strengthen the justification.

💡 Optional: Strengthen the security comment
```diff
 # map_location="cpu": load tensors on CPU so device ids in the file need not match this worker.
 # weights_only=False: ``vllm_fq_modelopt_state.pth`` is a full ModelOpt pickle (metadata,
 # nested dicts, dtypes, etc.); PyTorch's ``weights_only=True`` rejects that and only
 # allows tensor-only checkpoints. Loading arbitrary pickles can execute stored code—use
-# paths you trust (your own exports or verified checkpoints).
+# paths you trust. This file format is internally-generated by export_hf_vllm_fq_checkpoint;
+# only load checkpoints you created or from trusted sources.
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@examples/vllm_serve/fakequant_worker.py` around lines 64 - 71, The comment justifying weights_only=False is fine but should explicitly state that the ModelOpt pickle format is produced by our exporter; update the inline comment near the torch.load call that reads modelopt_state = torch.load(quant_config["modelopt_state_path"], weights_only=False, map_location="cpu") to mention that the checkpoint file is the internal format created by export_hf_vllm_fq_checkpoint, so only trusted exporter outputs should be loaded; keep the existing security warning about arbitrary pickle execution and retain the explanation why weights_only=False is required.
136: Consider documenting why `_dummy_run` is needed here.
The `_dummy_run(1)` call was added before `load_state_dict_from_path`, but there's no comment explaining why this is necessary. If this initializes lazy modules or allocates buffers needed for state dict loading, a brief comment would help maintainability.

📝 Optional: Add explanatory comment
```diff
 if quantizer_file_path:
+    # Ensure model buffers/lazy modules are initialized before loading state dict
     self.model_runner._dummy_run(1)
     current_state_dict = load_state_dict_from_path(self, quantizer_file_path, model)
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@examples/vllm_serve/fakequant_worker.py` at line 136, Add a brief inline comment above the self.model_runner._dummy_run(1) call explaining why the dummy run is required (e.g., it triggers lazy module initialization or allocates buffers so load_state_dict_from_path can succeed), referencing the involved symbols (_dummy_run, load_state_dict_from_path, model_runner) so future maintainers understand the dependency between the dummy run and subsequent state dict loading.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@modelopt/torch/export/plugins/vllm_fakequant_hf.py`:
- Around line 56-70: In _collect_expert_pre_quant_scales replace the private
attribute check getattr(iq, "_disabled", False) with the public API by testing
iq.is_enabled (negated) so the condition becomes: if iq is None or not
iq.is_enabled or iq.pre_quant_scale is None: return None; update the check in
that function to use the public is_enabled property of the input_quantizer for
consistency with other code (e.g., other uses of is_enabled in the file).
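The public-property check this comment asks for can be sketched with a minimal stand-in class; `FakeInputQuantizer` and `collect_pre_quant_scale` are hypothetical names for illustration, not the ModelOpt API:

```python
# Minimal stand-in for a TensorQuantizer-like object; `is_enabled` and
# `pre_quant_scale` are the only attributes the suggested check relies on.
class FakeInputQuantizer:
    def __init__(self, enabled, pre_quant_scale):
        self._disabled = not enabled  # private flag the review asks to avoid touching
        self.pre_quant_scale = pre_quant_scale

    @property
    def is_enabled(self):
        # Public accessor wrapping the private flag.
        return not self._disabled


def collect_pre_quant_scale(iq):
    # The reviewer-suggested condition, expressed via the public property.
    if iq is None or not iq.is_enabled or iq.pre_quant_scale is None:
        return None
    return iq.pre_quant_scale
```

The condition short-circuits the same way as the `_disabled` version, but survives refactors of the private flag.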
---
Nitpick comments:
In `@examples/vllm_serve/fakequant_worker.py`:
- Around line 64-71: The comment justifying weights_only=False is fine but
should explicitly state that the ModelOpt pickle format is produced by our
exporter; update the inline comment near the torch.load call that reads
modelopt_state = torch.load(quant_config["modelopt_state_path"],
weights_only=False, map_location="cpu") to mention that the checkpoint file is
the internal format created by export_hf_vllm_fq_checkpoint, so only trusted
exporter outputs should be loaded; keep the existing security warning about
arbitrary pickle execution and retain the explanation why weights_only=False is
required.
- Line 136: Add a brief inline comment above the self.model_runner._dummy_run(1)
call explaining why the dummy run is required (e.g., it triggers lazy module
initialization or allocates buffers so load_state_dict_from_path can succeed),
referencing the involved symbols (_dummy_run, load_state_dict_from_path,
model_runner) so future maintainers understand the dependency between the dummy
run and subsequent state dict loading.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: 74cd1761-c237-4334-a1d7-e995964be2e1
📒 Files selected for processing (5)
- examples/vllm_serve/README.md
- examples/vllm_serve/fakequant_worker.py
- examples/vllm_serve/vllm_reload_utils.py
- modelopt/torch/export/layer_utils.py
- modelopt/torch/export/plugins/vllm_fakequant_hf.py
💤 Files with no reviewable changes (1)
- examples/vllm_serve/README.md
Actionable comments posted: 1
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
modelopt/torch/export/plugins/vllm_fakequant_hf.py (1)
282-344: ⚠️ Potential issue | 🟠 Major
Restore quantizer state in a `finally` block.
Lines 291-295 mutate the in-memory model, but restoration only happens on lines 342-344. If `get_quantizer_state_dict`, `safe_save`, or `save_pretrained` raises, the caller is left with disabled weight quantizers and patched `_rotate` values. This should be guarded with `try`/`finally`.

💡 Proposed fix
```diff
 wqs_to_restore: list[tuple[TensorQuantizer, Any]] = []
-for _, module in model.named_modules():
-    if isinstance(module, QuantModule):
-        for attr_name, quantizer in module.named_children():
-            if (
-                attr_name.endswith("weight_quantizer")
-                and isinstance(quantizer, TensorQuantizer)
-                and quantizer.is_enabled
-            ):
-                quantizer.disable()
-                orig_rotate = quantizer._rotate
-                if quantizer.rotate_is_enabled:
-                    quantizer._rotate = disable_rotate(quantizer)
-                wqs_to_restore.append((quantizer, orig_rotate))
-
-quantizer_state_dict = get_quantizer_state_dict(model)
-for key in list(quantizer_state_dict):
-    if is_weight_quantizer_state_key(key):
-        quantizer_state_dict.pop(key)
-    elif key in input_quantizers_folded_pqs:
-        qstate_val = quantizer_state_dict[key]
-        if isinstance(qstate_val, dict) and "_pre_quant_scale" in qstate_val:
-            quantizer_state_dict[key]["_pre_quant_scale"] = torch.ones_like(
-                qstate_val["_pre_quant_scale"]
-            )
-
-for iq_key, (avg_pqs, max_input_amax) in expert_pqs_overrides.items():
-    if iq_key in quantizer_state_dict:
-        qstate_val = quantizer_state_dict[iq_key]
-        if isinstance(qstate_val, dict):
-            if "_pre_quant_scale" in qstate_val:
-                qstate_val["_pre_quant_scale"] = avg_pqs
-            if max_input_amax is not None and "_amax" in qstate_val:
-                qstate_val["_amax"] = max_input_amax
-
-modelopt_state = mto.modelopt_state(model)
-qstate = quantizer_state(model)
-for key in list(qstate):
-    if is_weight_quantizer_state_key(key) and qstate[key].get("_disabled"):
-        qstate.pop(key)
-
-for mode_str, m_state in modelopt_state.get("modelopt_state_dict", []):
-    if mode_str == "quantize" and "metadata" in m_state:
-        m_state["metadata"]["quantizer_state"] = qstate
-        break
-
-modelopt_state["modelopt_state_weights"] = quantizer_state_dict
-safe_save(modelopt_state, export_dir / "vllm_fq_modelopt_state.pth")
-
-model.save_pretrained(export_dir, state_dict=clean_sd, save_modelopt_state=False)
-
-for wq, orig_rotate in wqs_to_restore:
-    wq.enable()
-    wq._rotate = orig_rotate
+try:
+    for _, module in model.named_modules():
+        if isinstance(module, QuantModule):
+            for attr_name, quantizer in module.named_children():
+                if (
+                    attr_name.endswith("weight_quantizer")
+                    and isinstance(quantizer, TensorQuantizer)
+                    and quantizer.is_enabled
+                ):
+                    quantizer.disable()
+                    orig_rotate = quantizer._rotate
+                    if quantizer.rotate_is_enabled:
+                        quantizer._rotate = disable_rotate(quantizer)
+                    wqs_to_restore.append((quantizer, orig_rotate))
+
+    quantizer_state_dict = get_quantizer_state_dict(model)
+    # ... existing quantizer_state_dict/modelopt_state/save logic ...
+    safe_save(modelopt_state, export_dir / "vllm_fq_modelopt_state.pth")
+    model.save_pretrained(export_dir, state_dict=clean_sd, save_modelopt_state=False)
+finally:
+    for wq, orig_rotate in reversed(wqs_to_restore):
+        wq.enable()
+        wq._rotate = orig_rotate
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/export/plugins/vllm_fakequant_hf.py` around lines 282 - 344, Collecting disabled weight quantizers into wqs_to_restore then performing operations (get_quantizer_state_dict, safe_save, model.save_pretrained) can leave the model mutated if an exception occurs; wrap the mutation scope in a try/finally: after the loop that builds wqs_to_restore and disables quantizers (refer to QuantModule, TensorQuantizer, disable_rotate and wqs_to_restore), run the subsequent state-building, patching, safe_save, and model.save_pretrained calls inside a try block and in the finally iterate wqs_to_restore to restore each quantizer by calling wq.enable() and resetting wq._rotate to orig_rotate so restoration always occurs even if get_quantizer_state_dict, safe_save, or model.save_pretrained raises.
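The disable-mutate-restore shape of the proposed fix can be sketched in isolation. `Quantizer` and `export_with_restore` below are hypothetical stand-ins for the real `TensorQuantizer` and export path, assuming only an `enable`/`disable` surface and a `_rotate` attribute:

```python
# Hypothetical stand-in exposing only the surface the export path touches.
class Quantizer:
    def __init__(self):
        self.enabled = True
        self._rotate = "orig"

    def disable(self):
        self.enabled = False

    def enable(self):
        self.enabled = True


def export_with_restore(quantizers, save_fn):
    """Disable quantizers for export; try/finally guarantees restoration."""
    to_restore = []
    try:
        for q in quantizers:
            if q.enabled:
                q.disable()
                orig_rotate = q._rotate
                q._rotate = "disabled"
                to_restore.append((q, orig_rotate))
        save_fn()  # may raise, like safe_save / save_pretrained failing mid-export
    finally:
        # Runs even when save_fn raised, in reverse order of mutation.
        for q, orig_rotate in reversed(to_restore):
            q.enable()
            q._rotate = orig_rotate
```

Even if `save_fn` throws, the caller gets the quantizers back enabled with their original `_rotate` values, which is the invariant the review asks for.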
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@modelopt/torch/export/plugins/vllm_fakequant_hf.py`:
- Around line 173-178: The current merge reduces per-expert amax tensors to a
scalar via torch.stack(...).max(), losing per-channel/block layouts and failing
for differing-but-mergeable shapes; replace that reduction by calling the
existing helper merge_amax_tensors_for_vllm_group() to compute max_in_amax.
Specifically, where iq0 = experts[0].input_quantizer and you build amaxes =
[e.input_quantizer.amax for e in experts], pass that amaxes list into
merge_amax_tensors_for_vllm_group() and assign its result to max_in_amax instead
of using torch.stack(...).max(), preserving layout and handling shape
differences.
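The merge policy this comment describes (element-wise max for same shapes, concatenation for GQA-style differing shapes, scalar-max fallback) can be illustrated with plain Python lists standing in for 1-D amax tensors; `merge_amax_for_group` is a sketch, not the actual `merge_amax_tensors_for_vllm_group` implementation:

```python
def _flat(a):
    # Treat scalars as one-element sequences when flattening.
    return a if isinstance(a, list) else [a]


def merge_amax_for_group(amaxes):
    """Merge per-quantizer amax values for a shared-input group."""
    shapes = {len(a) if isinstance(a, list) else 0 for a in amaxes}
    if len(shapes) == 1 and isinstance(amaxes[0], list):
        # Same shape: element-wise max preserves per-channel structure.
        return [max(vals) for vals in zip(*amaxes)]
    if len(shapes) > 1 and 0 not in shapes:
        # Differing vector shapes (e.g. GQA q/k/v): concatenate.
        return [v for a in amaxes for v in a]
    # Fallback (scalars, or mixed scalar/vector): collapse to a scalar max.
    return max(v for a in amaxes for v in _flat(a))
```

The first branch is the behavior the review wants instead of `torch.stack(...).max()`, which would collapse everything to the scalar of the third branch.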
---
Outside diff comments:
In `@modelopt/torch/export/plugins/vllm_fakequant_hf.py`:
- Around line 282-344: Collecting disabled weight quantizers into wqs_to_restore
then performing operations (get_quantizer_state_dict, safe_save,
model.save_pretrained) can leave the model mutated if an exception occurs; wrap
the mutation scope in a try/finally: after the loop that builds wqs_to_restore
and disables quantizers (refer to QuantModule, TensorQuantizer, disable_rotate
and wqs_to_restore), run the subsequent state-building, patching, safe_save, and
model.save_pretrained calls inside a try block and in the finally iterate
wqs_to_restore to restore each quantizer by calling wq.enable() and resetting
wq._rotate to orig_rotate so restoration always occurs even if
get_quantizer_state_dict, safe_save, or model.save_pretrained raises.
📒 Files selected for processing (3)
- examples/vllm_serve/fakequant_worker.py
- examples/vllm_serve/vllm_reload_utils.py
- modelopt/torch/export/plugins/vllm_fakequant_hf.py
🚧 Files skipped from review as they are similar to previous changes (2)
- examples/vllm_serve/fakequant_worker.py
- examples/vllm_serve/vllm_reload_utils.py
0acc835 to 6806f16 · Compare
🧹 Nitpick comments (2)
examples/vllm_serve/vllm_reload_utils.py (1)
200-229: Consider narrowing the exception handler.
The `except Exception` at line 227 catches all exceptions, including unexpected ones that may indicate bugs rather than benign mapping failures. Consider catching more specific exceptions (e.g., `KeyError`, `ValueError`, `AttributeError`) or at least logging at a higher level (warning) for unexpected exception types.

♻️ Suggested refinement
```diff
 try:
     result = map_fun({probe_key: probe_weight})
     if result:
         new_key = next(iter(result))
         new_first = new_key.split(".")[0]
         if new_first != first_component:
             prefix_remap[first_component] = new_first
-except Exception as e:
+except (KeyError, ValueError, AttributeError, StopIteration) as e:
     logging.getLogger(__name__).debug("prefix-remap probe failed for %r: %s", probe_key, e)
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@examples/vllm_serve/vllm_reload_utils.py` around lines 200 - 229, The except Exception in _infer_prefix_remap is too broad; narrow it to handle only expected mapping failures (e.g., catch KeyError, ValueError, AttributeError) around the map_fun({probe_key: probe_weight}) call and handle other exceptions by re-raising or logging them at a higher severity; specifically, replace the blanket except Exception as e with targeted except blocks for KeyError/ValueError/AttributeError that emit a debug log mentioning probe_key and continue, and add a fallback except Exception that logs a warning or re-raises so unexpected errors aren't silently swallowed.

examples/vllm_serve/fakequant_worker.py (1)
133: Add a comment explaining the purpose of the dummy run.
The `_dummy_run(1)` call was added but its purpose isn't documented. Is this needed to trigger lazy parameter initialization before loading the quantizer state? A brief comment would help future maintainers understand why this is necessary.

📝 Suggested documentation
```diff
+# Trigger lazy initialization so model parameters are materialized before loading state.
 self.model_runner._dummy_run(1)
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@examples/vllm_serve/fakequant_worker.py` at line 133, The call to self.model_runner._dummy_run(1) lacks documentation; add an inline comment next to the call explaining that _dummy_run(1) forces a single inference step to trigger lazy parameter/module initialization (so all tensors/shapes are created) before loading the quantizer state, preventing mismatched or missing parameters when the quantizer is applied; reference the _dummy_run method on model_runner and mention it intentionally runs exactly one step for this initialization side-effect.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Nitpick comments:
In `@examples/vllm_serve/fakequant_worker.py`:
- Line 133: The call to self.model_runner._dummy_run(1) lacks documentation; add
an inline comment next to the call explaining that _dummy_run(1) forces a single
inference step to trigger lazy parameter/module initialization (so all
tensors/shapes are created) before loading the quantizer state, preventing
mismatched or missing parameters when the quantizer is applied; reference the
_dummy_run method on model_runner and mention it intentionally runs exactly one
step for this initialization side-effect.
In `@examples/vllm_serve/vllm_reload_utils.py`:
- Around line 200-229: The except Exception in _infer_prefix_remap is too broad;
narrow it to handle only expected mapping failures (e.g., catch KeyError,
ValueError, AttributeError) around the map_fun({probe_key: probe_weight}) call
and handle other exceptions by re-raising or logging them at a higher severity;
specifically, replace the blanket except Exception as e with targeted except
blocks for KeyError/ValueError/AttributeError that emit a debug log mentioning
probe_key and continue, and add a fallback except Exception that logs a warning
or re-raises so unexpected errors aren’t silently swallowed.
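The split between benign and unexpected failures that this prompt describes might look like the sketch below; `probe_prefix`, `map_fun`, and `probe_key` are illustrative stand-ins for the reload-side helpers, not their actual signatures:

```python
import logging

log = logging.getLogger(__name__)


def probe_prefix(map_fun, probe_key, probe_weight, prefix_remap):
    """Probe the weight-name mapper once to infer a prefix remap entry."""
    first_component = probe_key.split(".")[0]
    try:
        result = map_fun({probe_key: probe_weight})
        if result:
            new_first = next(iter(result)).split(".")[0]
            if new_first != first_component:
                prefix_remap[first_component] = new_first
    except (KeyError, ValueError, AttributeError) as e:
        # Expected, benign mapping failures: debug-level only.
        log.debug("prefix-remap probe failed for %r: %s", probe_key, e)
    except Exception as e:
        # Unexpected failures likely indicate a bug: surface them loudly.
        log.warning("unexpected prefix-remap probe error for %r: %s", probe_key, e)
```

The narrow first handler keeps known-noisy probes quiet, while anything else reaches the warning handler instead of being silently swallowed.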
📒 Files selected for processing (5)
- examples/vllm_serve/README.md
- examples/vllm_serve/fakequant_worker.py
- examples/vllm_serve/vllm_reload_utils.py
- modelopt/torch/export/layer_utils.py
- modelopt/torch/export/plugins/vllm_fakequant_hf.py
💤 Files with no reviewable changes (1)
- examples/vllm_serve/README.md
🚧 Files skipped from review as they are similar to previous changes (2)
- modelopt/torch/export/layer_utils.py
- modelopt/torch/export/plugins/vllm_fakequant_hf.py
Codecov Report
❌ Patch coverage is

Additional details and impacted files

```
@@            Coverage Diff             @@
##             main    #1242      +/-   ##
==========================================
- Coverage   75.58%   72.49%    -3.10%
==========================================
  Files         459      459
  Lines       48613    48817     +204
==========================================
- Hits        36745    35388    -1357
- Misses      11868    13429    +1561
```

Flags with carried forward coverage won't be shown. Click here to find out more. ☔ View full report in Codecov by Sentry.
🧹 Nitpick comments (1)
modelopt/torch/export/plugins/vllm_fakequant_hf.py (1)
165-175: Consider extracting the device/dtype conversion outside the loop for clarity.
The reassignment of `avg_pre_quant_scale` on each iteration (lines 165-167) works correctly when all experts share the same device/dtype (expected for MoE), but the pattern can be confusing. A minor clarification would make the intent clearer:

🔧 Optional refactor for clarity
```diff
+# Move to common device/dtype (all experts expected to be on same device)
+ref_pqs = experts[0].input_quantizer._pre_quant_scale
+avg_pre_quant_scale = avg_pre_quant_scale.to(device=ref_pqs.device, dtype=ref_pqs.dtype)
+
 for ex in experts:
     nm = id_to_name.get(id(ex))
     if nm is None or f"{nm}.weight" not in state_dict:
         continue
     old_pre_quant_scale = ex.input_quantizer._pre_quant_scale
-    avg_pre_quant_scale = avg_pre_quant_scale.to(
-        device=old_pre_quant_scale.device, dtype=old_pre_quant_scale.dtype
-    )
     if torch.equal(old_pre_quant_scale, avg_pre_quant_scale):
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/export/plugins/vllm_fakequant_hf.py` around lines 165 - 175, The loop currently reassigns avg_pre_quant_scale via .to(...) inline which is confusing; extract that conversion into a clearly named variable (e.g., avg_pre_quant_scale_converted) before using it so you compare old_pre_quant_scale to avg_pre_quant_scale_converted and reuse that converted tensor when computing updated_weight (which references weight, old_pre_quant_scale and avg_pre_quant_scale). Ensure you replace all occurrences of the inline .to(...) on avg_pre_quant_scale with the new converted variable to avoid repeated conversions while keeping the same device/dtype semantics.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Nitpick comments:
In `@modelopt/torch/export/plugins/vllm_fakequant_hf.py`:
- Around line 165-175: The loop currently reassigns avg_pre_quant_scale via
.to(...) inline which is confusing; extract that conversion into a clearly named
variable (e.g., avg_pre_quant_scale_converted) before using it so you compare
old_pre_quant_scale to avg_pre_quant_scale_converted and reuse that converted
tensor when computing updated_weight (which references weight,
old_pre_quant_scale and avg_pre_quant_scale). Ensure you replace all occurrences
of the inline .to(...) on avg_pre_quant_scale with the new converted variable to
avoid repeated conversions while keeping the same device/dtype semantics.
📒 Files selected for processing (1)
modelopt/torch/export/plugins/vllm_fakequant_hf.py
Actionable comments posted: 2
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@modelopt/torch/export/plugins/vllm_fakequant_hf.py`:
- Around line 207-210: The probe forward helper _dummy_forward should avoid
building autograd graphs; wrap the model invocation inside a
torch.inference_mode() context so the call model(torch.ones([1, 2],
dtype=torch.long, device=dev)) executes without tracking gradients or allocating
unnecessary graph tensors; update the body of _dummy_forward to use with
torch.inference_mode(): and keep the existing contextlib.suppress(Exception)
behavior and hook side-effects intact.
- Around line 158-159: avg_pqs is being averaged while still in fp16/bf16 (from
pqs_list), so comparing to torch.finfo(torch.float32).tiny underflows when cast
and allows zeros; compute the averaged pre_quant_scale in fp32: cast pqs_list
tensors (or the stacked result) to torch.float32 before calling mean and clamp
using torch.finfo(torch.float32).tiny, then use that fp32-clamped avg_pqs to
compute pre_quant_scale (and only cast back to the original dtype if required
later, e.g., where pre_quant_scale is used downstream such as in the weight
division code that currently references pre_quant_scale).
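The fp16 underflow this finding describes can be demonstrated with only the standard library, using `struct`'s half-precision format as a stand-in for fp16 tensors. `avg_pre_quant_scale` below is a scalar sketch of the average-in-fp32-then-clamp fix, not the actual tensor code:

```python
import struct


def to_fp16(x):
    # Round-trip a Python float through IEEE-754 half precision (struct format 'e').
    return struct.unpack("e", struct.pack("e", x))[0]


FP32_TINY = 1.1754943508222875e-38  # value of torch.finfo(torch.float32).tiny

# The clamp threshold itself rounds to 0.0 in half precision, so clamping
# half-precision averages against it is a no-op check.
assert to_fp16(FP32_TINY) == 0.0


def avg_pre_quant_scale(scales):
    # Average in full precision first, then clamp away exact zeros; casting
    # back to half (if needed) happens only after the clamp.
    avg = sum(scales) / len(scales)
    return max(avg, FP32_TINY)
```

This mirrors why `torch.stack([p.float() for p in pqs_list]).mean(0)` followed by the clamp is safe while averaging in fp16/bf16 is not.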
📒 Files selected for processing (2)
- modelopt/torch/export/plugins/vllm_fakequant_hf.py
- modelopt/torch/export/unified_export_hf.py
cjluo-nv
left a comment
This PR makes significant improvements to enable AWQ checkpoint export and reload in the vLLM fakequant serving path. The changes are substantial (~473 lines across 6 files) and touch critical export/reload paths. While the overall design is sound and addresses a real gap (AWQ export was listed as unsupported), there are several correctness concerns, missing tests for new utility functions, and a potentially risky silent-failure pattern that should be addressed before merging.
What's good:
- Proper refactoring of TP sharding into `_narrow_tensor_to_tp_local_shard` (eliminates duplication)
- `merge_amax_tensors_for_group` correctly handles same-shape, different-shape (GQA), and fallback cases
- Non-mutative resmoothing via `_resmooth_experts_for_export` avoids modifying the live model
- Use of `safe_load`/`safe_save` for security best practices
- Good use of the `is_weight_quantizer_state_key` regex for consistent matching
Key concerns:
- No unit tests for new utility functions (`merge_amax_tensors_for_group`, `is_weight_quantizer_state_key`, `_resmooth_experts_for_export`, `shard_pre_quant_scale_for_tp`, `_infer_prefix_remap`). These are critical correctness-sensitive functions.
- A silent `contextlib.suppress(Exception)` in `_resmooth_experts_for_export` that could mask real failures.
- The `_collect_expert_pre_quant_scales` function name mentions "experts" but is used for all shared-input groups including dense GQA projections, which could be confusing.
- The existing test `test_hf_vllm_export` only covers FP8, but AWQ is the main feature being added — no AWQ test case was added.
cjluo-nv
left a comment
This is a substantial and well-designed PR that enables AWQ checkpoint export/reload in vLLM fakequant. Many previous review comments have been addressed (renamed functions, is_enabled instead of _disabled, merge_amax_tensors_for_group usage, dummy forward logging, device alignment for deepcopy, invariant assertions, tests added for infer_quantizer_prefix_remap and merge_amax_tensors_for_group).
However, several critical issues from previous reviews remain unresolved:
- Missing `try`/`finally` for `wqs_to_restore` (lines 430-493 of `vllm_fakequant_hf.py`) — weight quantizers are disabled but restoration only happens at the end. If `safe_save` or `save_pretrained` throws, the model is left in a mutated state. This was flagged in the second review iteration and is still unresolved.
- `avg_pqs` computed in half precision (line 271) — `torch.stack(pqs_list).mean(0)` preserves the original fp16/bf16 dtype, then clamping with `torch.finfo(torch.float32).tiny` is meaningless because that value underflows to zero in fp16. Division by zero may produce inf weights. This was flagged as Major in the 4th review.
- Missing `torch.inference_mode()` in `_dummy_forward` (line 307) — builds unnecessary autograd graphs during the probe forward. Flagged in 4th review.
- Bare `assert` in `_check_all_weight_quantizers_disabled` (line 127) — can be optimized away with `-O`. Per codebase guidelines, use `RuntimeError`.
- Bug in `requant_weights_for_export` — the single-quantizer calibration path runs `sequence_quantizers[0](weight)` without `.float()` but the application pass uses `weight.float()`. While the intent seems correct (calibrate then apply), the inconsistency could lead to precision-dependent amax values.
- Duplicate `_infer_prefix_remap` — two implementations exist: `infer_quantizer_prefix_remap` (in `vllm_fakequant_hf.py`, with consistency checks) and `_infer_prefix_remap` (in `vllm_reload_utils.py`, without consistency checks). The reload-side function should delegate to the public API to avoid divergence.
The test coverage has improved with tests for merge_amax_tensors_for_group and infer_quantizer_prefix_remap, but key functions like _resmooth_experts_for_export, requant_weights_for_export, is_weight_quantizer_state_key, and shard_pre_quant_scale_for_tp still lack tests.
cjluo-nv
left a comment
All critical issues from previous reviews have been addressed:
- `try`/`finally` for `wqs_to_restore` — ✅ Fixed. Weight quantizer disable/restore is now properly wrapped in try/finally.
- fp32 averaging of `avg_pqs` — ✅ Fixed. Uses `torch.stack([p.float() for p in pqs_list]).mean(0)` to avoid fp16/bf16 underflow.
- `torch.inference_mode()` in `_dummy_forward` — ✅ Fixed.
- Bare `assert` → `RuntimeError` — ✅ Fixed across all changed files.
- `_disabled` → `is_enabled` public API — ✅ Fixed.
- `merge_amax_tensors_for_group` usage — ✅ Used for all amax merging, preserving per-channel structure.
- Duplicate `_infer_prefix_remap` — ✅ Eliminated. Reload side now imports `infer_quantizer_prefix_remap` from the export module.
- Weight quantizer invariant assertion — ✅ Added proper validation.
- `deepcopy` device alignment — ✅ `requant_weights_for_export` does `.to(device=weight.device)`.
- `safe_save`/`safe_load` — ✅ Used instead of raw `torch.save`/`torch.load`.
Tests have been meaningfully expanded: GPU test now covers INT4_AWQ_CFG in addition to FP8, and unit tests cover infer_quantizer_prefix_remap (7 cases) and merge_amax_tensors_for_group (3 cases). The code is well-structured with clear documentation and reasonable design decisions throughout.
Signed-off-by: Kinjal Patel <kinjalpravin@nvidia.com>
ea4035e to 585cc06 · Compare
meenchen
left a comment
Change LGTM overall, request to add a unit test for MoE
Signed-off-by: Kinjal Patel <kinjalpravin@nvidia.com>
mxinO
left a comment
LGTM. Thanks.
Further question.
We are solving the pqs fusing at export time, which is sub-optimal; this is actually a calibration-time problem. Since all our inference frameworks use fused q,k,v and experts, any thoughts on solving it at calibration time? cc @realAsma
> ## Known Problems
>
> 1. **MCore reload does not use `MODELOPT_STATE_PATH`**; use `QUANT_FILE_PATH` and make sure `QUANT_CFG` matches the quantization recipe used for the original MCore model (otherwise quantizer keys/config won't align).
> 2. AWQ reload is not supported yet
```python
for quantizer_copy in quantizers:
    weight_quantized = quantizer_copy(weight_quantized)
for quantizer_copy in quantizers:
    finish_stats_collection(quantizer_copy)
```
This logic can be simplified: distinguishing the `len(quantizers) == 1` case is redundant; the loop in the `else` branch just works.
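The simplification suggested here can be sketched with stand-in objects; `apply_quantizers` and the dict-based quantizers are illustrative, not the ModelOpt API. The same loop handles one quantizer and many:

```python
def finish_stats_collection(quantizer):
    # Stand-in for the real stats-finalization call.
    quantizer["finished"] = True


def apply_quantizers(weight, quantizers):
    """Apply each quantizer in sequence, then finish stats for all of them.

    No special case for len(quantizers) == 1: a one-element list behaves
    identically under both loops.
    """
    for q in quantizers:
        weight = weight * q["scale"]  # stand-in for q(weight)
    for q in quantizers:
        finish_stats_collection(q)
    return weight
```

Calling this with a single-quantizer list exercises exactly the path the review says is redundant to special-case.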
What does this PR do?
Type of change: Bug
Enables end-to-end AWQ checkpoint export and reload in the vLLM fake-quant serving path (`MODELOPT_STATE_PATH`). Previously, the `input_quantizer` was using an incorrect `pre_quant_scale`, especially with grouped quantizers like `qkv_proj`, simply taking the first `input_quantizer.pre_quant_scale`. This MR adds `_resmooth_experts_for_export`, which non-mutatively averages `pre_quant_scale` across MoE experts and unifies the input `_amax`, required because vLLM uses a single input quantizer per expert group. It also adds `merge_amax_tensors_for_group` (element-wise max for same-shape, `cat` for GQA, scalar-max fallback), replacing the scalar-collapsing `torch.stack().max()` that dropped per-channel `_amax` structure.

Usage
Testing
Step 1 — Export the quantized checkpoint:
This produces `<EXPORT_DIR>/vllm_fq_modelopt_state.pth` with the averaged per-expert `pre_quant_scale` and unified `_amax` now included.
Step 2 — Serve via vLLM fakequant worker:
Tested for quantization configurations:
Before your PR is "Ready for review"
- Make sure you read and follow Contributor guidelines and your commits are signed (`git commit -s -S`).
- Make sure you read and follow the Security Best Practices (e.g. avoiding hardcoded `trust_remote_code=True`, `torch.load(..., weights_only=False)`, `pickle`, etc.).
- CONTRIBUTING.md: N/A

Additional Information
Summary by CodeRabbit
New Features
Bug Fixes