Add ResNet50 support for torch_onnx quantization workflow #1263
Conversation
Add end-to-end support for ResNet50 (Conv2d-heavy model) in the torch_onnx quantization → ONNX export → TRT engine pipeline. Key fixes for Conv2d-heavy models:

- Disable FP8 Conv2d weight quantizers during ONNX export to avoid TorchScript exporter's "kernel of unknown shape" error (FP8 DequantizeLinear produces dynamic-shape outputs incompatible with Conv2d's static kernel requirement)
- Disable autocast for FP8/INT8 quantized models during export (prevents dynamic-shape kernels from autocast-induced FP16 casting)
- Fix `configure_linear_module_onnx_quantizers` to handle all modules with block quantization (not just `nn.Linear`), fixing NVFP4/MXFP8 export for models with quantized non-Linear modules like MaxPool2d
- Add calibration step for FP8 override quantizers that aren't calibrated by `mtq.quantize()` in MXFP8/NVFP4 modes
- Override Conv2d block quantizers to FP8 in auto mode for TRT compat
- Add maxpool and global_pool to `filter_func` (TRT DynamicQuantize requires 2D/3D input, but pooling layers operate on 4D tensors)
- Always load calibration data (MXFP8 Conv2d FP8 overrides need it)

Signed-off-by: ajrasane <arasane@nvidia.com>
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: ajrasane <131806219+ajrasane@users.noreply.github.com>
Note: Reviews paused

It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior in your CodeRabbit settings.
📝 Walkthrough

Adds resnet50 to tests/docs; introduces calibration of uncalibrated quantizers and Conv2d block-quantizer overrides during quantization; temporarily disables FP8 Conv weight quantizers during ONNX export; adds a CLI flag to build a TensorRT engine via trtexec after ONNX export.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant Script as Quantization Script
    participant Model as Model
    participant DL as DataLoader
    participant Quant as Quantizers
    participant Export as ONNX Export
    Script->>Model: load model
    Script->>Model: auto_quantize()/quantize_model(data_loader)
    Model->>Quant: enumerate enabled quantizers
    Quant-->>Model: identify uncalibrated & Conv2d block quantizers
    Model->>DL: run _calibrate_uncalibrated_quantizers(data_loader)
    DL-->>Model: calibration statistics (amax)
    Model->>Model: apply calib_amax, disable calibration
    Model->>Model: apply Conv2d override config if needed
    Script->>Export: get_onnx_bytes_and_metadata(model)
    Export->>Export: enter autocast/precision/quantizer contexts
    Export->>Export: enter _disable_fp8_conv_weight_quantizers() if FP8 present
    Export->>Model: torch.onnx.export
    Export-->>Script: return ONNX bytes + metadata
    Script->>Script: if --trt_build -> build_trt_engine(onnx_path) (trtexec)
```
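The `_disable_fp8_conv_weight_quantizers()` step in the export path above can be sketched as a restore-on-exit context manager. This is a hypothetical stand-in, not the ModelOpt implementation: `FakeQuantizer`, `FakeConv`, and using `(4, 3)` num_bits to denote FP8 E4M3 are illustrative assumptions.

```python
import contextlib


class FakeQuantizer:
    """Stand-in for a weight quantizer with enable/disable state."""

    def __init__(self, num_bits=(4, 3)):
        self.num_bits = num_bits  # (4, 3) denotes FP8 E4M3 in this sketch
        self.enabled = True

    def disable(self):
        self.enabled = False

    def enable(self):
        self.enabled = True


class FakeConv:
    def __init__(self):
        self.weight_quantizer = FakeQuantizer()


@contextlib.contextmanager
def disable_fp8_conv_weight_quantizers(modules):
    """Temporarily disable enabled FP8 weight quantizers, restoring them afterwards."""
    touched = []
    for m in modules:
        wq = getattr(m, "weight_quantizer", None)
        if wq is not None and wq.num_bits == (4, 3) and wq.enabled:
            wq.disable()
            touched.append(wq)
    try:
        yield
    finally:
        for wq in touched:
            wq.enable()


convs = [FakeConv(), FakeConv()]
with disable_fp8_conv_weight_quantizers(convs):
    states_inside = [c.weight_quantizer.enabled for c in convs]  # disabled during export
states_after = [c.weight_quantizer.enabled for c in convs]  # restored on exit
print(states_inside, states_after)  # [False, False] [True, True]
```

The `finally` block guarantees the quantizers come back even if `torch.onnx.export` raises inside the `with` body.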
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~50 minutes

Important: Pre-merge checks failed

Please resolve all errors before merging. Addressing warnings is optional.

❌ Failed checks (1 inconclusive)
✅ Passed checks (3 passed)
Actionable comments posted: 2
🧹 Nitpick comments (3)
examples/torch_onnx/torch_quant_to_onnx.py (2)
164-164: Consider using the public `amax` property instead of checking an internal attribute.

The code uses `hasattr(quantizer, "_amax")` to detect uncalibrated quantizers. Based on the `TensorQuantizer` implementation, the public `amax` property returns `None` when `_amax` is not set. Using `quantizer.amax is None` would be more aligned with the public API.

♻️ Suggested fix

```diff
- if quantizer.is_enabled and not quantizer.block_sizes and not hasattr(quantizer, "_amax"):
+ if quantizer.is_enabled and not quantizer.block_sizes and quantizer.amax is None:
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@examples/torch_onnx/torch_quant_to_onnx.py` at line 164, Replace the internal attribute check on quantizer with the public property: instead of checking hasattr(quantizer, "_amax") in the conditional that currently reads if quantizer.is_enabled and not quantizer.block_sizes and not hasattr(quantizer, "_amax"): use quantizer.amax is None to detect uncalibrated TensorQuantizer instances; update the condition to if quantizer.is_enabled and not quantizer.block_sizes and quantizer.amax is None so it uses the public TensorQuantizer API.
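A minimal sketch of why the public check is preferable, assuming (as the review notes) that `TensorQuantizer.amax` returns `None` when `_amax` is unset. `TensorQuantizerSketch` is a hypothetical stand-in, not the real class:

```python
class TensorQuantizerSketch:
    """Stand-in: the public `amax` property returns None when `_amax` is unset."""

    @property
    def amax(self):
        return getattr(self, "_amax", None)


q = TensorQuantizerSketch()
before = (q.amax is None, hasattr(q, "_amax"))  # both checks agree: uncalibrated

# If the internal slot exists but carries no calibration result, the checks diverge:
q._amax = None
after = (q.amax is None, hasattr(q, "_amax"))
print(before, after)  # (True, False) (True, True)
```

In the second state the `hasattr`-based detection would wrongly treat the quantizer as calibrated, while `amax is None` still flags it.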
239-241: Prefer public property setters over internal attribute assignment.

The code directly assigns to `_num_bits` and `_axis`, which are internal attributes. Based on the relevant code snippet from `tensor_quantizer.py`, these have public property setters (`num_bits` and `axis`) that also update the calibrator state. Using the internal attributes may leave the calibrator out of sync.

♻️ Suggested fix

```diff
  # Override to FP8 per-tensor
  quantizer.block_sizes = None
- quantizer._num_bits = (4, 3)
- quantizer._axis = None
+ quantizer.num_bits = (4, 3)
+ quantizer.axis = None
  quantizer.enable_calib()
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@examples/torch_onnx/torch_quant_to_onnx.py` around lines 239 - 241, The code assigns internal attributes quantizer._num_bits and quantizer._axis directly which can leave the calibrator out of sync; replace those direct assignments with the public property setters quantizer.num_bits = (4, 3) and quantizer.axis = None (leave quantizer.block_sizes = None as-is) so the setter logic in tensor_quantizer.py runs and updates the calibrator state accordingly.

modelopt/torch/quantization/export_onnx.py (1)
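The calibrator-sync concern can be illustrated with a toy property setter. `QuantizerSketch` and `CalibratorSketch` are hypothetical names; the premise that the real `num_bits` setter also updates calibrator state comes from the review above:

```python
class CalibratorSketch:
    def __init__(self):
        self.num_bits = None


class QuantizerSketch:
    """Sketch: the public setter keeps the calibrator in sync; writing _num_bits does not."""

    def __init__(self):
        self._num_bits = 8
        self._calibrator = CalibratorSketch()
        self._calibrator.num_bits = self._num_bits

    @property
    def num_bits(self):
        return self._num_bits

    @num_bits.setter
    def num_bits(self, value):
        self._num_bits = value
        self._calibrator.num_bits = value  # side effect the internal write skips


q = QuantizerSketch()
q._num_bits = (4, 3)  # internal write: calibrator still thinks 8 bits
stale = q._calibrator.num_bits
q.num_bits = (4, 3)  # public setter: calibrator updated too
synced = q._calibrator.num_bits
print(stale, synced)  # 8 (4, 3)
```

Any code path that later reads the calibrator's `num_bits` would silently quantize against the stale value after the internal write.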
657-674: Consider restoring quantizer state after export (optional).

The context manager sets `_onnx_quantizer_type` but doesn't restore the original state in a `finally` block. While this is likely fine since models aren't typically reused after ONNX export, adding state restoration would make the context manager more robust for edge cases.

♻️ Optional: Add state restoration

```diff
 @contextlib.contextmanager
 def configure_linear_module_onnx_quantizers(model):
     """Sets the onnx export attributes for the given model. ... """
+    original_states = []
     for _, module in model.named_modules():
         if hasattr(module, "input_quantizer") and module.input_quantizer.block_sizes:
+            original_states.append((module.input_quantizer, "_onnx_quantizer_type", getattr(module.input_quantizer, "_onnx_quantizer_type", None)))
             module.input_quantizer._onnx_quantizer_type = "dynamic"
         if hasattr(module, "weight_quantizer") and module.weight_quantizer.block_sizes:
+            original_states.append((module.weight_quantizer, "_onnx_quantizer_type", getattr(module.weight_quantizer, "_onnx_quantizer_type", None)))
             module.weight_quantizer._onnx_quantizer_type = "static"
-    yield
+    try:
+        yield
+    finally:
+        for obj, attr, orig_val in original_states:
+            if orig_val is None:
+                if hasattr(obj, attr):
+                    delattr(obj, attr)
+            else:
+                setattr(obj, attr, orig_val)
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/quantization/export_onnx.py` around lines 657 - 674, The context manager configure_linear_module_onnx_quantizers mutates module.input_quantizer._onnx_quantizer_type and module.weight_quantizer._onnx_quantizer_type but never restores original values; update configure_linear_module_onnx_quantizers to record the original _onnx_quantizer_type for each module found via model.named_modules() (checking hasattr(module, "input_quantizer") / "weight_quantizer" and block_sizes) and then yield, ensuring a finally block iterates the saved entries to reset each input_quantizer._onnx_quantizer_type and weight_quantizer._onnx_quantizer_type back to their original values (including handling missing/None originals).
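The restore logic in the suggested fix follows a general save/restore pattern. A self-contained sketch (hypothetical helper names) that also handles attributes that did not exist beforehand:

```python
import contextlib

_MISSING = object()  # sentinel: attribute did not exist before we set it


@contextlib.contextmanager
def set_attrs_temporarily(pairs):
    """Set (obj, attr, value) triples, restoring prior values (or deleting) on exit."""
    saved = []
    for obj, name, value in pairs:
        saved.append((obj, name, getattr(obj, name, _MISSING)))
        setattr(obj, name, value)
    try:
        yield
    finally:
        for obj, name, old in reversed(saved):
            if old is _MISSING:
                delattr(obj, name)
            else:
                setattr(obj, name, old)


class Q:
    pass


q = Q()
with set_attrs_temporarily([(q, "_onnx_quantizer_type", "static")]):
    inside = q._onnx_quantizer_type  # "static" while exporting
after = hasattr(q, "_onnx_quantizer_type")  # attribute removed again on exit
print(inside, after)  # static False
```

Using a sentinel instead of `None` distinguishes "attribute absent" from "attribute present with value `None`", which is the edge case the review's diff handles explicitly.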
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@examples/torch_onnx/torch_quant_to_onnx.py`:
- Around line 161-165: Reformat the if-block to satisfy ruff: keep the same
logic but use consistent spacing and line breaks (wrap the long conditional in
parentheses across lines) and standardize quotes; specifically, keep the initial
guard "if not hasattr(module, attr_name): continue", assign quantizer =
getattr(module, attr_name), then write the conditional as e.g. "if
(quantizer.is_enabled and not quantizer.block_sizes and not hasattr(quantizer,
'_amax')): quantizer.enable_calib()" (or break into multiple indented lines) so
spacing, parentheses, and quotation are ruff-compliant while preserving behavior
of module, attr_name, quantizer, .is_enabled, .block_sizes, ._amax and
enable_calib().
In `@modelopt/torch/_deploy/utils/torch_onnx.py`:
- Around line 587-593: The conditional assignment to conv_wq_context is
misformatted for the project's linter; reformat the expression using a single
parenthesized conditional expression: set conv_wq_context =
(_disable_fp8_conv_weight_quantizers(model) if is_fp8_quantized(model) else
nullcontext()), referencing the existing symbols
_disable_fp8_conv_weight_quantizers, is_fp8_quantized, nullcontext, and model so
the ternary is on one line inside parentheses and passes ruff formatting.
---
Nitpick comments:
In `@examples/torch_onnx/torch_quant_to_onnx.py`:
- Line 164: Replace the internal attribute check on quantizer with the public
property: instead of checking hasattr(quantizer, "_amax") in the conditional
that currently reads if quantizer.is_enabled and not quantizer.block_sizes and
not hasattr(quantizer, "_amax"): use quantizer.amax is None to detect
uncalibrated TensorQuantizer instances; update the condition to if
quantizer.is_enabled and not quantizer.block_sizes and quantizer.amax is None so
it uses the public TensorQuantizer API.
- Around line 239-241: The code assigns internal attributes quantizer._num_bits
and quantizer._axis directly which can leave the calibrator out of sync; replace
those direct assignments with the public property setters quantizer.num_bits =
(4, 3) and quantizer.axis = None (leave quantizer.block_sizes = None as-is) so
the setter logic in tensor_quantizer.py runs and updates the calibrator state
accordingly.
In `@modelopt/torch/quantization/export_onnx.py`:
- Around line 657-674: The context manager
configure_linear_module_onnx_quantizers mutates
module.input_quantizer._onnx_quantizer_type and
module.weight_quantizer._onnx_quantizer_type but never restores original values;
update configure_linear_module_onnx_quantizers to record the original
_onnx_quantizer_type for each module found via model.named_modules() (checking
hasattr(module, "input_quantizer") / "weight_quantizer" and block_sizes) and
then yield, ensuring a finally block iterates the saved entries to reset each
input_quantizer._onnx_quantizer_type and weight_quantizer._onnx_quantizer_type
back to their original values (including handling missing/None originals).
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro Plus
Run ID: cb1433e9-ae94-4122-80a6-bc7f9a108588
📒 Files selected for processing (6)
- examples/torch_onnx/README.md
- examples/torch_onnx/torch_quant_to_onnx.py
- modelopt/torch/_deploy/utils/torch_onnx.py
- modelopt/torch/quantization/export_onnx.py
- tests/_test_utils/torch/vision_models.py
- tests/examples/torch_onnx/test_torch_quant_to_onnx.py
Codecov Report

❌ Patch coverage is

Additional details and impacted files:

```
@@            Coverage Diff             @@
##             main    #1263      +/-   ##
==========================================
- Coverage   76.90%   67.95%   -8.96%
==========================================
  Files         350      459     +109
  Lines       40524    49658    +9134
==========================================
+ Hits        31166    33743    +2577
- Misses       9358    15915    +6557
```

Flags with carried forward coverage won't be shown. ☔ View full report in Codecov by Sentry.
Reflect all supported quantization modes and Conv2d override behavior.

Signed-off-by: ajrasane <arasane@nvidia.com>
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: ajrasane <131806219+ajrasane@users.noreply.github.com>
Actionable comments posted: 2
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@examples/torch_onnx/torch_quant_to_onnx.py`:
- Around line 38-46: The docstring in torch_quant_to_onnx.py incorrectly groups
INT4_AWQ with fully supported quantization modes; update the top-level
description to either remove INT4_AWQ from the supported list or add a clear
caveat that INT4_AWQ is a known limitation and may not work end-to-end (e.g.,
"INT4_AWQ is experimental/limited — see PR objectives for current limitations"),
ensuring references to the script name and the quantization modes (FP8, INT8,
MXFP8, NVFP4, INT4_AWQ, AUTO) are adjusted so users won't assume INT4_AWQ is
fully supported.
- Around line 227-256: The override of Conv2d block quantizers to FP8 in
_override_conv2d_to_fp8 must not be applied after mtq.auto_quantize() because
mtq.auto_quantize() returns a search_state based on the candidate configs;
either incorporate the Conv2d->FP8/disable-block-quantization rule into the
candidate generation passed into mtq.auto_quantize() (so those Conv2d layers are
treated as FP8 during the search) or, if you must keep the override path,
recompute and validate the effective_bits/budget and update/refresh the returned
search_state after running _override_conv2d_to_fp8 so the final model’s budget
reflects FP8 costs for Conv2d (reference symbols: _override_conv2d_to_fp8,
mtq.auto_quantize, search_state).
📒 Files selected for processing (1)
examples/torch_onnx/torch_quant_to_onnx.py
```diff
 Quantize a timm vision model and export to ONNX for TensorRT deployment.

 Supports FP8, INT8, MXFP8, NVFP4, INT4_AWQ, and AUTO (mixed-precision) quantization modes.

 The script will:
-1. Given the model name, create a timm torch model.
-2. Quantize the torch model in MXFP8, NVFP4, INT4_AWQ, or AUTO mode.
-3. Export the quantized torch model to ONNX format.
+1. Load a pretrained timm model (e.g., ViT, Swin, ResNet).
+2. Quantize the model using the specified mode. For models with Conv2d layers,
+   Conv2d quantization is automatically overridden for TensorRT compatibility
+   (FP8 for MXFP8/NVFP4, INT8 for INT4_AWQ).
```
Don’t advertise INT4_AWQ as supported end-to-end here.
The PR objectives still call out INT4_AWQ as a known limitation, but this docstring now groups it with the working modes. Please caveat or remove it here so users do not assume this example is expected to succeed in that mode.
✏️ Suggested wording

```diff
-Supports FP8, INT8, MXFP8, NVFP4, INT4_AWQ, and AUTO (mixed-precision) quantization modes.
+Supports FP8, INT8, MXFP8, NVFP4, and AUTO (mixed-precision) quantization modes.
+`INT4_AWQ` remains a known limitation for this example.
```

🤖 Prompt for AI Agents
+`INT4_AWQ` remains a known limitation for this example.🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@examples/torch_onnx/torch_quant_to_onnx.py` around lines 38 - 46, The
docstring in torch_quant_to_onnx.py incorrectly groups INT4_AWQ with fully
supported quantization modes; update the top-level description to either remove
INT4_AWQ from the supported list or add a clear caveat that INT4_AWQ is a known
limitation and may not work end-to-end (e.g., "INT4_AWQ is experimental/limited
— see PR objectives for current limitations"), ensuring references to the script
name and the quantization modes (FP8, INT8, MXFP8, NVFP4, INT4_AWQ, AUTO) are
adjusted so users won't assume INT4_AWQ is fully supported.
Move TRT engine build logic into the script as a --trt_build flag, removing the duplicate trtexec invocation from the test file.

Signed-off-by: ajrasane <arasane@nvidia.com>
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: ajrasane <131806219+ajrasane@users.noreply.github.com>
♻️ Duplicate comments (2)
examples/torch_onnx/torch_quant_to_onnx.py (2)
39-47: ⚠️ Potential issue | 🟡 Minor — Remove `INT4_AWQ` from the supported-modes docstring.

This still reads like end-to-end support, but the PR summary calls out `INT4_AWQ` as a known limitation for this example.

✏️ Suggested wording

```diff
-Supports FP8, INT8, MXFP8, NVFP4, INT4_AWQ, and AUTO (mixed-precision) quantization modes.
+Supports FP8, INT8, MXFP8, NVFP4, and AUTO (mixed-precision) quantization modes.
+`INT4_AWQ` remains a known limitation for this example.
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@examples/torch_onnx/torch_quant_to_onnx.py` around lines 39 - 47, Update the module-level docstring that lists supported quantization modes by removing "INT4_AWQ" from the supported-modes sentence (the top-level docstring/description in torch_quant_to_onnx.py); ensure any other occurrences in the same docstring (e.g., the second paragraph that enumerates FP8, INT8, MXFP8, NVFP4, INT4_AWQ, and AUTO) are updated so INT4_AWQ is no longer mentioned, leaving the remaining modes and explanatory text intact.
228-257: ⚠️ Potential issue | 🟠 Major — Apply the Conv2d TRT restriction inside the auto-quant search.

`mtq.auto_quantize()` returns a `search_state` for the formats it actually evaluated. Rewriting selected Conv2d layers to FP8 afterwards means the exported model no longer matches that search space or its `effective_bits` budget, so the returned `search_state` is stale. Push this compatibility rule into the candidate configs before `mtq.auto_quantize()`, or recompute/validate the budget after the override.

Also applies to: 308-310
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@examples/torch_onnx/torch_quant_to_onnx.py` around lines 228 - 257, The current post-hoc override in _override_conv2d_to_fp8 mutates Conv2d quantizers after mtq.auto_quantize has already produced a search_state and effective_bits budget, making search_state stale; move the Conv2d -> FP8 compatibility rule into the candidate configuration generation that mtq.auto_quantize consumes (or run a budget/validation pass to recompute search_state/effective_bits after performing the override). Specifically, incorporate the TRT Conv2d restriction into the candidate configs or pre-filter candidates before calling mtq.auto_quantize, or if you keep _override_conv2d_to_fp8, call a function to recompute/validate search_state and effective_bits (the returned object from mtq.auto_quantize) immediately after performing the override so the exported model and reported budget remain consistent.
🧹 Nitpick comments (1)
examples/torch_onnx/torch_quant_to_onnx.py (1)
198-202: Avoid forcing calibration data for quantizers that are disabled later.

This now downloads calibration data for every non-`auto` run, then does the extra calibration pass before `mtq.disable_quantizer()`. For models whose Conv2d quantizers are later filtered out, that makes `mxfp8`/`nvfp4` pay the Tiny-ImageNet + calibration cost even though no active override survives to export. Consider applying the filter before `_calibrate_uncalibrated_quantizers()`, or loading calibration data lazily only when an enabled Conv2d override remains.

Also applies to: 462-473
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@examples/torch_onnx/torch_quant_to_onnx.py` around lines 198 - 202, The current flow calls _calibrate_uncalibrated_quantizers(quantized_model, data_loader) before mtq.disable_quantizer(...), causing calibration data to be downloaded even for quantizers that will be disabled; change the order or gate calibration so only enabled overrides are calibrated: either call mtq.disable_quantizer(quantized_model, filter_func) before invoking _calibrate_uncalibrated_quantizers, or add a pre-check that filters quantized_model (using filter_func) to determine if any Conv2d override remains enabled and only then load/process data_loader and call _calibrate_uncalibrated_quantizers; reference functions: _calibrate_uncalibrated_quantizers, mtq.disable_quantizer, filter_func, quantized_model, and data_loader.
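The lazy-loading suggestion amounts to a cheap pre-check before paying for the dataset download. A sketch under assumed names (`FakeQuantizer` is hypothetical; `filter_func` mirrors the example's pooling filter described in the PR summary):

```python
class FakeQuantizer:
    def __init__(self, enabled=True):
        self.enabled = enabled


def filter_func(name):
    # Mirrors the example's TRT filter: pooling-layer quantizers get disabled.
    return "maxpool" in name or "global_pool" in name


def any_quantizer_survives(named_quantizers):
    """Decide whether calibration data is worth loading at all."""
    return any(q.enabled and not filter_func(name) for name, q in named_quantizers)


only_pooling = [("maxpool.input_quantizer", FakeQuantizer())]
mixed = only_pooling + [("conv1.weight_quantizer", FakeQuantizer())]
print(any_quantizer_survives(only_pooling), any_quantizer_survives(mixed))  # False True
```

Gating `load_calibration_data()` on such a check would skip the Tiny-ImageNet download entirely when every override is filtered out anyway.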
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Duplicate comments:
In `@examples/torch_onnx/torch_quant_to_onnx.py`:
- Around line 39-47: Update the module-level docstring that lists supported
quantization modes by removing "INT4_AWQ" from the supported-modes sentence (the
top-level docstring/description in torch_quant_to_onnx.py); ensure any other
occurrences in the same docstring (e.g., the second paragraph that enumerates
FP8, INT8, MXFP8, NVFP4, INT4_AWQ, and AUTO) are updated so INT4_AWQ is no
longer mentioned, leaving the remaining modes and explanatory text intact.
- Around line 228-257: The current post-hoc override in _override_conv2d_to_fp8
mutates Conv2d quantizers after mtq.auto_quantize has already produced a
search_state and effective_bits budget, making search_state stale; move the
Conv2d -> FP8 compatibility rule into the candidate configuration generation
that mtq.auto_quantize consumes (or run a budget/validation pass to recompute
search_state/effective_bits after performing the override). Specifically,
incorporate the TRT Conv2d restriction into the candidate configs or pre-filter
candidates before calling mtq.auto_quantize, or if you keep
_override_conv2d_to_fp8, call a function to recompute/validate search_state and
effective_bits (the returned object from mtq.auto_quantize) immediately after
performing the override so the exported model and reported budget remain
consistent.
---
Nitpick comments:
In `@examples/torch_onnx/torch_quant_to_onnx.py`:
- Around line 198-202: The current flow calls
_calibrate_uncalibrated_quantizers(quantized_model, data_loader) before
mtq.disable_quantizer(...), causing calibration data to be downloaded even for
quantizers that will be disabled; change the order or gate calibration so only
enabled overrides are calibrated: either call
mtq.disable_quantizer(quantized_model, filter_func) before invoking
_calibrate_uncalibrated_quantizers, or add a pre-check that filters
quantized_model (using filter_func) to determine if any Conv2d override remains
enabled and only then load/process data_loader and call
_calibrate_uncalibrated_quantizers; reference functions:
_calibrate_uncalibrated_quantizers, mtq.disable_quantizer, filter_func,
quantized_model, and data_loader.
📒 Files selected for processing (2)
- examples/torch_onnx/torch_quant_to_onnx.py
- tests/examples/torch_onnx/test_torch_quant_to_onnx.py
🚧 Files skipped from review as they are similar to previous changes (1)
- tests/examples/torch_onnx/test_torch_quant_to_onnx.py
Signed-off-by: ajrasane <arasane@nvidia.com>
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: ajrasane <131806219+ajrasane@users.noreply.github.com>
Actionable comments posted: 1
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@examples/torch_onnx/torch_quant_to_onnx.py`:
- Around line 248-250: Replace the direct internal assignments to
quantizer._num_bits and quantizer._axis with the public property setters so the
calibrator stays synchronized: use quantizer.num_bits = (4, 3) and
quantizer.axis = None instead of writing to _num_bits and _axis (keeping
quantizer.block_sizes assignment as-is); update references in the same snippet
to call the num_bits and axis properties on the quantizer object to ensure
proper calibrator/internal state updates.
📒 Files selected for processing (2)
- examples/torch_onnx/torch_quant_to_onnx.py
- modelopt/torch/_deploy/utils/torch_onnx.py
🚧 Files skipped from review as they are similar to previous changes (1)
- modelopt/torch/_deploy/utils/torch_onnx.py
…ormat configs

Previously, Conv2d layers were overridden from block quantization to FP8 after mtq.auto_quantize() returned, causing the effective_bits budget and search_state to be stale. Move the Conv2d TRT overrides into the format configs passed to auto_quantize so the search correctly accounts for Conv2d being FP8/INT8 in the budget.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: ajrasane <131806219+ajrasane@users.noreply.github.com>
Actionable comments posted: 1
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
examples/torch_onnx/torch_quant_to_onnx.py (1)
128-143: ⚠️ Potential issue | 🟠 Major — Use the constructed model to derive calibration transforms.

Now that every non-auto run goes through `load_calibration_data()`, this helper's internal `timm.create_model(..., pretrained=True, ...)` becomes user-visible behavior. That ignores `--no_pretrained`, ignores `--model_kwargs`, and can calibrate with a different `data_config` than the model you actually quantize/export.

♻️ Suggested direction

```diff
-def load_calibration_data(model_name, data_size, batch_size, device, with_labels=False):
+def load_calibration_data(model, data_size, batch_size, device, with_labels=False):
     """Load and prepare calibration data."""
     dataset = load_dataset("zh-plus/tiny-imagenet")
-    model = timm.create_model(model_name, pretrained=True, num_classes=1000)
     data_config = timm.data.resolve_model_data_config(model)
     transforms = timm.data.create_transform(**data_config, is_training=False)
```

```diff
-    data_loader = load_calibration_data(
-        args.timm_model_name,
+    data_loader = load_calibration_data(
+        model,
         args.calibration_data_size,
         args.batch_size,
         device,
         with_labels=False,
     )
```

Also applies to: 448-454
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@examples/torch_onnx/torch_quant_to_onnx.py` around lines 128 - 143, load_calibration_data currently constructs its own timm model (timm.create_model) which ignores caller flags (e.g., --no_pretrained, --model_kwargs) and can produce a data_config different from the model being quantized; change load_calibration_data to accept either an optional model instance (e.g., model argument) or explicit model creation params (pretrained flag + model_kwargs) and then derive data_config via timm.data.resolve_model_data_config using that same model instance; if no model instance is passed, create the model using the supplied pretrained and model_kwargs so transforms = timm.data.create_transform(**data_config, is_training=False) is always based on the exact model to be quantized/exported (apply same change to the similar helper around the 448-454 area).
♻️ Duplicate comments (1)
examples/torch_onnx/torch_quant_to_onnx.py (1)
41-47: ⚠️ Potential issue | 🟡 Minor — Caveat `INT4_AWQ` in the end-to-end messaging.

The PR summary still treats `INT4_AWQ` as a known limitation for this example, but these strings still read as full TRT-path support. Please either remove it from the supported list here or explicitly scope it to quantize/export-only so users do not assume `--quantize_mode int4_awq --trt_build` is expected to work.

Also applies to: 307-309
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@examples/torch_onnx/torch_quant_to_onnx.py` around lines 41 - 47, Update the messaging around INT4_AWQ so users don't assume full TensorRT support: locate the top script description text that lists supported quantize modes (the string "INT4_AWQ") and either remove "INT4_AWQ" from that supported-list or modify its wording to explicitly state "INT4_AWQ (quantize/export-only; not supported with --trt_build/TensorRT runtime)". Make the same change for the other occurrences referenced (the block around lines 307-309) so any mention of "--quantize_mode int4_awq" is explicitly scoped as export-only and not compatible with "--trt_build".
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@examples/torch_onnx/torch_quant_to_onnx.py`:
- Around line 490-503: In build_trt_engine, wrap the subprocess.run call in a
try/except that catches FileNotFoundError and subprocess.TimeoutExpired (and
optionally OSError) and re-raise a clear RuntimeError that explains trtexec is
missing or timed out (include the original exception message), then preserve the
existing returncode check and RuntimeError for non-zero exits; this makes
failures like missing TensorRT or long-running builds yield a concise
example-level error instead of raw exceptions.
---
Outside diff comments:
In `@examples/torch_onnx/torch_quant_to_onnx.py`:
- Around line 128-143: load_calibration_data currently constructs its own timm
model (timm.create_model) which ignores caller flags (e.g., --no_pretrained,
--model_kwargs) and can produce a data_config different from the model being
quantized; change load_calibration_data to accept either an optional model
instance (e.g., model argument) or explicit model creation params (pretrained
flag + model_kwargs) and then derive data_config via
timm.data.resolve_model_data_config using that same model instance; if no model
instance is passed, create the model using the supplied pretrained and
model_kwargs so transforms = timm.data.create_transform(**data_config,
is_training=False) is always based on the exact model to be quantized/exported
(apply same change to the similar helper around the 448-454 area).
---
Duplicate comments:
In `@examples/torch_onnx/torch_quant_to_onnx.py`:
- Around line 41-47: Update the messaging around INT4_AWQ so users don't assume
full TensorRT support: locate the top script description text that lists
supported quantize modes (the string "INT4_AWQ") and either remove "INT4_AWQ"
from that supported-list or modify its wording to explicitly state "INT4_AWQ
(quantize/export-only; not supported with --trt_build/TensorRT runtime)". Make
the same change for the other occurrences referenced (the block around lines
307-309) so any mention of "--quantize_mode int4_awq" is explicitly scoped as
export-only and not compatible with "--trt_build".
📒 Files selected for processing (1)
examples/torch_onnx/torch_quant_to_onnx.py
```python
def build_trt_engine(onnx_path):
    """Build a TensorRT engine from the exported ONNX model using trtexec."""
    cmd = [
        "trtexec",
        f"--onnx={onnx_path}",
        "--stronglyTyped",
        "--builderOptimizationLevel=4",
    ]
    print(f"\nBuilding TensorRT engine: {' '.join(cmd)}")
    result = subprocess.run(cmd, capture_output=True, text=True, timeout=600)
    if result.returncode != 0:
        raise RuntimeError(
            f"TensorRT engine build failed for {onnx_path}:\n{result.stdout}\n{result.stderr}"
        )
```
Wrap trtexec launch errors in an example-level error message.
--trt_build currently surfaces raw FileNotFoundError / TimeoutExpired exceptions. Catching those here would turn missing TensorRT installs and long-running builds into clear, actionable failures instead of a traceback.
🛠️ Suggested fix

```diff
 def build_trt_engine(onnx_path):
     """Build a TensorRT engine from the exported ONNX model using trtexec."""
     cmd = [
         "trtexec",
         f"--onnx={onnx_path}",
         "--stronglyTyped",
         "--builderOptimizationLevel=4",
     ]
     print(f"\nBuilding TensorRT engine: {' '.join(cmd)}")
-    result = subprocess.run(cmd, capture_output=True, text=True, timeout=600)
+    try:
+        result = subprocess.run(cmd, capture_output=True, text=True, timeout=600)
+    except FileNotFoundError as exc:
+        raise RuntimeError("`trtexec` was not found in PATH. Install TensorRT or omit --trt_build.") from exc
+    except subprocess.TimeoutExpired as exc:
+        raise RuntimeError(f"TensorRT engine build timed out after {exc.timeout}s for {onnx_path}.") from exc
     if result.returncode != 0:
         raise RuntimeError(
             f"TensorRT engine build failed for {onnx_path}:\n{result.stdout}\n{result.stderr}"
         )
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@examples/torch_onnx/torch_quant_to_onnx.py` around lines 490 - 503, In
build_trt_engine, wrap the subprocess.run call in a try/except that catches
FileNotFoundError and subprocess.TimeoutExpired (and optionally OSError) and
re-raise a clear RuntimeError that explains trtexec is missing or timed out
(include the original exception message), then preserve the existing returncode
check and RuntimeError for non-zero exits; this makes failures like missing
TensorRT or long-running builds yield a concise example-level error instead of
raw exceptions.
---
Does this quantized model perform similarly to the same model quantized via the ONNX path? The reason I'm asking is that previous experiments showed the Torch path does not add Q/DQ nodes in the residual branches, causing a perf regression compared to its ONNX counterpart.
cjluo-nv left a comment
The PR cleanly adds ResNet50 support to the torch_onnx quantization workflow with well-considered fixes:

- `configure_linear_module_onnx_quantizers` fix (export_onnx.py): Correctly generalizes from `isinstance(module, nn.Linear)` to checking `block_sizes` on any module with quantizers. This is more precise (only sets export attributes when block quantization is actually used) and handles non-Linear modules like MaxPool2d that get NVFP4/MXFP8 input quantizers.
- FP8 Conv weight quantizer disable (torch_onnx.py): Well-implemented context manager with proper try/finally cleanup. The issue (TorchScript exporter needs static kernel shapes, FP8 DequantizeLinear produces dynamic shapes) is clearly documented.
- Autocast disable for FP8/INT8: Correct — these quantized models should not use autocast during export since it could interfere with the QDQ patterns.
- `_calibrate_uncalibrated_quantizers`: Clean solution for the gap where MXFP8/NVFP4 pipelines skip standard calibration but FP8 Conv overrides need it. The detection heuristic (enabled, no block_sizes, no _amax) is sound.
- Test simplification: Moving the TRT build into the main script via `--trt_build` reduces test boilerplate and ensures the build logic is available for non-test use too.
- PR size: 159 additions, 58 deletions — well-scoped and cohesive.

All 5 quant modes tested for ResNet50 (fp8, int8, mxfp8, nvfp4, auto).
Won't the high-precision tensors be in FP32 instead of FP16? This will cause perf regressions when using
TorchScript ONNX export breaks when Conv weight quantizers are enabled because TRT_FP8DequantizeLinear produces unknown shapes. This restores FP8 weight quantization as a post-processing step in FP8QuantExporter and adds a utility to fold redundant DQ->Cast(FP32->FP16) patterns inserted by float16 conversion. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> Signed-off-by: ajrasane <131806219+ajrasane@users.noreply.github.com>
We are using the `convert_float_to_float16` API because the autocast API fails for these precisions. I will put out a follow-up PR to make sure that we can run all the precisions with the AutoCast API.
Summary

- Fix `configure_linear_module_onnx_quantizers` to handle all modules with block quantization (not just `nn.Linear`), fixing NVFP4/MXFP8 export for models with quantized non-Linear modules
- Add a `--trt_build` flag to `torch_quant_to_onnx.py` and simplify test infrastructure

Files Changed

- `modelopt/torch/_deploy/utils/torch_onnx.py` — Disable FP8 Conv2d weight quantizers and autocast during ONNX export
- `modelopt/torch/quantization/export_onnx.py` — Fix `configure_linear_module_onnx_quantizers` for all module types with block quantization
- `examples/torch_onnx/torch_quant_to_onnx.py` — Add `--trt_build` flag, calibration for FP8 override quantizers, Conv2d→FP8 override for auto mode, filter_func updates
- `examples/torch_onnx/README.md` — Add ResNet50 to supported models table
- `tests/examples/torch_onnx/test_torch_quant_to_onnx.py` — Add ResNet50 test entry, simplify using `--trt_build`
- `tests/_test_utils/torch/vision_models.py` — Add ResNet50 to timm model registry

Quantization modes passing

Test plan

- `pytest tests/examples/torch_onnx/test_torch_quant_to_onnx.py -k resnet50` (5/5 passed)

🤖 Generated with Claude Code
Summary by CodeRabbit
New Features
Improvements