From 56a1d211b4b280ce07e5e802506cb44de5c3edbf Mon Sep 17 00:00:00 2001 From: ajrasane <131806219+ajrasane@users.noreply.github.com> Date: Tue, 14 Apr 2026 21:26:39 +0000 Subject: [PATCH 1/6] Add ResNet50 support for torch_onnx quantization workflow MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add end-to-end support for ResNet50 (Conv2d-heavy model) in the torch_onnx quantization → ONNX export → TRT engine pipeline. Key fixes for Conv2d-heavy models: - Disable FP8 Conv2d weight quantizers during ONNX export to avoid TorchScript exporter's "kernel of unknown shape" error (FP8 DequantizeLinear produces dynamic-shape outputs incompatible with Conv2d's static kernel requirement) - Disable autocast for FP8/INT8 quantized models during export (prevents dynamic-shape kernels from autocast-induced FP16 casting) - Fix configure_linear_module_onnx_quantizers to handle all modules with block quantization (not just nn.Linear), fixing NVFP4/MXFP8 export for models with quantized non-Linear modules like MaxPool2d - Add calibration step for FP8 override quantizers that aren't calibrated by mtq.quantize() in MXFP8/NVFP4 modes - Override Conv2d block quantizers to FP8 in auto mode for TRT compat - Add maxpool and global_pool to filter_func (TRT DynamicQuantize requires 2D/3D input, but pooling layers operate on 4D tensors) - Always load calibration data (MXFP8 Conv2d FP8 overrides need it) Signed-off-by: ajrasane Co-Authored-By: Claude Opus 4.6 (1M context) Signed-off-by: ajrasane <131806219+ajrasane@users.noreply.github.com> --- examples/torch_onnx/README.md | 1 + examples/torch_onnx/torch_quant_to_onnx.py | 95 ++++++++++++++++--- modelopt/torch/_deploy/utils/torch_onnx.py | 46 ++++++++- modelopt/torch/quantization/export_onnx.py | 15 ++- tests/_test_utils/torch/vision_models.py | 1 + .../torch_onnx/test_torch_quant_to_onnx.py | 1 + 6 files changed, 142 insertions(+), 17 deletions(-) diff --git a/examples/torch_onnx/README.md b/examples/torch_onnx/README.md index d540770116..cfd4dc380c 100644 --- a/examples/torch_onnx/README.md +++ b/examples/torch_onnx/README.md @@ -307,6 +307,7 @@ python torch_quant_to_onnx.py \ | [vit_base_patch16_224](https://huggingface.co/timm/vit_base_patch16_224.augreg_in21k_ft_in1k) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | | [swin_tiny_patch4_window7_224](https://huggingface.co/timm/swin_tiny_patch4_window7_224.ms_in1k) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | | [swinv2_tiny_window8_256](https://huggingface.co/timm/swinv2_tiny_window8_256.ms_in1k) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | +| [resnet50](https://huggingface.co/timm/resnet50.a1_in1k) | ✅ | ✅ | ✅ | ✅ | | ✅ | ## Resources diff --git a/examples/torch_onnx/torch_quant_to_onnx.py b/examples/torch_onnx/torch_quant_to_onnx.py index 7f74e617e8..220e3f0684 100644 --- a/examples/torch_onnx/torch_quant_to_onnx.py +++ b/examples/torch_onnx/torch_quant_to_onnx.py @@ -109,7 +109,8 @@ def filter_func(name): """Filter function to exclude certain layers from quantization.""" pattern = re.compile( r".*(time_emb_proj|time_embedding|conv_in|conv_out|conv_shortcut|add_embedding|" - r"pos_embed|time_text_embed|context_embedder|norm_out|x_embedder|patch_embed|cpb_mlp|downsample).*" + r"pos_embed|time_text_embed|context_embedder|norm_out|x_embedder|patch_embed|cpb_mlp|" + r"downsample|maxpool|global_pool).*" ) return pattern.match(name) is not None @@ -147,6 +148,36 @@ def load_calibration_data(model_name, data_size, batch_size, device, with_labels ) +def _calibrate_uncalibrated_quantizers(model, data_loader): + """Calibrate FP8 quantizers that weren't calibrated by mtq.quantize(). + + When MXFP8/NVFP4 modes override Conv2d to FP8, the FP8 quantizers may not + be calibrated because the MXFP8/NVFP4 quantization pipeline skips standard + calibration. This function explicitly calibrates those uncalibrated quantizers. + """ + uncalibrated = [] + for _, module in model.named_modules(): + for attr_name in ("input_quantizer", "weight_quantizer"): + if not hasattr(module, attr_name): + continue + quantizer = getattr(module, attr_name) + if quantizer.is_enabled and not quantizer.block_sizes and not hasattr(quantizer, "_amax"): + quantizer.enable_calib() + uncalibrated.append(quantizer) + + if not uncalibrated: + return + + model.eval() + with torch.no_grad(): + for batch in data_loader: + model(batch) + + for quantizer in uncalibrated: + quantizer.disable_calib() + quantizer.load_calib_amax() + + def quantize_model(model, config, data_loader=None): """Quantize the model using the given config and calibration data.""" if data_loader is not None: @@ -159,6 +190,10 @@ def forward_loop(model): else: quantized_model = mtq.quantize(model, config) + # Calibrate any FP8 override quantizers that weren't calibrated by mtq.quantize() + if data_loader is not None: + _calibrate_uncalibrated_quantizers(quantized_model, data_loader) + mtq.disable_quantizer(quantized_model, filter_func) return quantized_model @@ -185,6 +220,38 @@ def _disable_inplace_relu(model): module.inplace = False +def _override_conv2d_to_fp8(model, data_loader): + """Override Conv2d layers with NVFP4/MXFP8 block quantization to FP8. + + TRT DynamicQuantize requires 2D/3D input, but Conv2d operates on 4D tensors. + This overrides Conv2d block quantizers to FP8 per-tensor and calibrates them. + """ + overridden = [] + for _, module in model.named_modules(): + if not isinstance(module, torch.nn.Conv2d): + continue + for attr_name in ("input_quantizer", "weight_quantizer"): + if not hasattr(module, attr_name): + continue + quantizer = getattr(module, attr_name) + if quantizer.is_enabled and quantizer.block_sizes: + # Override to FP8 per-tensor + quantizer.block_sizes = None + quantizer._num_bits = (4, 3) + quantizer._axis = None + quantizer.enable_calib() + overridden.append(quantizer) + + if overridden: + model.eval() + with torch.no_grad(): + for batch in data_loader: + model(batch["image"]) + for quantizer in overridden: + quantizer.disable_calib() + quantizer.load_calib_amax() + + def auto_quantize_model( model, data_loader, @@ -233,6 +300,10 @@ def auto_quantize_model( verbose=True, ) + # Override Conv2d layers that got NVFP4/MXFP8 to FP8 for TRT compatibility. + # TRT DynamicQuantize requires 2D/3D input, but Conv2d operates on 4D tensors. + _override_conv2d_to_fp8(quantized_model, data_loader) + # Disable quantization for specified layers mtq.disable_quantizer(quantized_model, filter_func) @@ -378,18 +449,18 @@ def main(): args.num_score_steps, ) else: - # Standard quantization - only load calibration data if needed + # Standard quantization - load calibration data + # Note: MXFP8 is dynamic and does not need calibration itself, but when + # Conv2d layers are overridden to FP8 (for TRT compatibility), those FP8 + # quantizers require calibration data. config = get_quant_config(args.quantize_mode) - if args.quantize_mode == "mxfp8": - data_loader = None - else: - data_loader = load_calibration_data( - args.timm_model_name, - args.calibration_data_size, - args.batch_size, - device, - with_labels=False, - ) + data_loader = load_calibration_data( + args.timm_model_name, + args.calibration_data_size, + args.batch_size, + device, + with_labels=False, + ) quantized_model = quantize_model(model, config, data_loader) diff --git a/modelopt/torch/_deploy/utils/torch_onnx.py b/modelopt/torch/_deploy/utils/torch_onnx.py index 8cb741dbc7..d9eab72e33 100644 --- a/modelopt/torch/_deploy/utils/torch_onnx.py +++ b/modelopt/torch/_deploy/utils/torch_onnx.py @@ -16,6 +16,7 @@ """Utility functions related to Onnx.""" import base64 +import contextlib import inspect import json import logging @@ -402,6 +403,29 @@ def is_fp8_quantized(model: nn.Module) -> bool: return False +@contextlib.contextmanager +def _disable_fp8_conv_weight_quantizers(model: nn.Module): + """Temporarily disable FP8 weight quantizers on Conv layers during ONNX export. + + The TorchScript ONNX exporter requires static kernel shapes for Conv operations, + but FP8 weight quantization (TRT_FP8QuantizeLinear -> TRT_FP8DequantizeLinear) + produces dynamic-shape outputs that break this requirement. Disabling Conv weight + quantizers during export allows the Conv to export with static-shape FP16/FP32 + weights. Conv activations still have FP8 QDQ nodes (input quantizers remain enabled). + """ + disabled = [] + for _, module in model.named_modules(): + if isinstance(module, (nn.Conv1d, nn.Conv2d, nn.Conv3d)): + if hasattr(module, "weight_quantizer") and module.weight_quantizer.is_enabled: + module.weight_quantizer.disable() + disabled.append(module) + try: + yield + finally: + for module in disabled: + module.weight_quantizer.enable() + + def quantize_weights(model: nn.Module, onnx_model: onnx.ModelProto) -> onnx.ModelProto: """Real quantizes the weights in the onnx model. @@ -522,7 +546,11 @@ def get_onnx_bytes_and_metadata( input_none_names = list(set(tree_spec_input.names) - set(input_names)) use_torch_autocast = not ( - is_fp4_quantized(model) or is_mxfp8_quantized(model) or weights_dtype == "fp32" + is_fp4_quantized(model) + or is_mxfp8_quantized(model) + or is_fp8_quantized(model) + or is_int8_quantized(model) + or weights_dtype == "fp32" ) autocast = torch.autocast("cuda") if use_torch_autocast else nullcontext() @@ -556,7 +584,14 @@ def get_onnx_bytes_and_metadata( if is_fp4_quantized(model) or is_mxfp8_quantized(model) else nullcontext() ) - with torch.inference_mode(), autocast, quantizer_context: + # Disable FP8 Conv weight quantizers: TorchScript ONNX exporter requires static + # kernel shapes, but FP8 DequantizeLinear produces dynamic shapes. + conv_wq_context = ( + _disable_fp8_conv_weight_quantizers(model) + if is_fp8_quantized(model) + else nullcontext() + ) + with torch.inference_mode(), autocast, quantizer_context, conv_wq_context: additional_kwargs = {} if not dynamo_export: additional_kwargs["dynamic_axes"] = dynamic_axes @@ -598,7 +633,12 @@ def get_onnx_bytes_and_metadata( onnx_opt_graph = qdq_to_dq(onnx_opt_graph) if weights_dtype in ["fp16", "bf16"]: - if is_int4_quantized(model) or is_mxfp8_quantized(model) or is_fp8_quantized(model): + if ( + is_int4_quantized(model) + or is_mxfp8_quantized(model) + or is_fp8_quantized(model) + or is_int8_quantized(model) + ): assert weights_dtype == "fp16", "BF16 + MXFP8/INT4 mixed precision is not supported yet" onnx_opt_graph = convert_float_to_float16( onnx_opt_graph, diff --git a/modelopt/torch/quantization/export_onnx.py b/modelopt/torch/quantization/export_onnx.py index fe4fb8d70a..c91cef760b 100644 --- a/modelopt/torch/quantization/export_onnx.py +++ b/modelopt/torch/quantization/export_onnx.py @@ -656,9 +656,20 @@ def export_fp4( @contextlib.contextmanager def configure_linear_module_onnx_quantizers(model): - """Sets the onnx export attributes for the given model.""" + """Sets the onnx export attributes for the given model. + + For modules with block quantization (NVFP4/MXFP8): + - Weight quantizers use "static" export (TRT_FP4QDQ for NVFP4, DQ-only for MXFP8) + - Input/activation quantizers use "dynamic" export (TRT_FP4DynamicQuantize, etc.) + + This must be set for ALL modules with block quantization, not just nn.Linear, + because models like ResNet have non-Linear modules (e.g., MaxPool2d) with NVFP4/MXFP8 + input quantizers that would otherwise default to the static path and produce + TRT_FP4QDQ nodes on activations (which the NVFP4 exporter cannot handle). + """ for _, module in model.named_modules(): - if isinstance(module, torch.nn.Linear): + if hasattr(module, "input_quantizer") and module.input_quantizer.block_sizes: module.input_quantizer._onnx_quantizer_type = "dynamic" + if hasattr(module, "weight_quantizer") and module.weight_quantizer.block_sizes: module.weight_quantizer._onnx_quantizer_type = "static" yield diff --git a/tests/_test_utils/torch/vision_models.py b/tests/_test_utils/torch/vision_models.py index 5fed1d20c1..942167d763 100644 --- a/tests/_test_utils/torch/vision_models.py +++ b/tests/_test_utils/torch/vision_models.py @@ -117,6 +117,7 @@ def get_model_and_input(on_gpu: bool = False): # "dm_nfnet_f0", "efficientnet_b0", "swin_tiny_patch4_window7_224", + "resnet50", ], _create_timm_fn, ), diff --git a/tests/examples/torch_onnx/test_torch_quant_to_onnx.py b/tests/examples/torch_onnx/test_torch_quant_to_onnx.py index 7c2692c1d9..51a6e462d6 100644 --- a/tests/examples/torch_onnx/test_torch_quant_to_onnx.py +++ b/tests/examples/torch_onnx/test_torch_quant_to_onnx.py @@ -28,6 +28,7 @@ "vit_tiny": ("vit_tiny_patch16_224", '{"depth": 1}'), "swin_tiny": ("swin_tiny_patch4_window7_224", '{"depths": [1, 1, 1, 1]}'), "swinv2_tiny": ("swinv2_tiny_window8_256", '{"depths": [1, 1, 1, 1]}'), + "resnet50": ("resnet50", None), } # Builder optimization level: 4 for low-bit modes, 3 otherwise From 9810e1d7c144132b9798565119c52f3ea7c5a608 Mon Sep 17 00:00:00 2001 From: ajrasane <131806219+ajrasane@users.noreply.github.com> Date: Tue, 14 Apr 2026 21:43:45 +0000 Subject: [PATCH 2/6] Update torch_quant_to_onnx script docstring Reflect all supported quantization modes and Conv2d override behavior. Signed-off-by: ajrasane Co-Authored-By: Claude Opus 4.6 (1M context) Signed-off-by: ajrasane <131806219+ajrasane@users.noreply.github.com> --- examples/torch_onnx/torch_quant_to_onnx.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/examples/torch_onnx/torch_quant_to_onnx.py b/examples/torch_onnx/torch_quant_to_onnx.py index 220e3f0684..c39c878c97 100644 --- a/examples/torch_onnx/torch_quant_to_onnx.py +++ b/examples/torch_onnx/torch_quant_to_onnx.py @@ -35,13 +35,17 @@ import modelopt.torch.quantization as mtq """ -This script is used to quantize a timm model using dynamic quantization like MXFP8 or NVFP4, -or using auto quantization for optimal per-layer quantization. +Quantize a timm vision model and export to ONNX for TensorRT deployment. + +Supports FP8, INT8, MXFP8, NVFP4, INT4_AWQ, and AUTO (mixed-precision) quantization modes. The script will: -1. Given the model name, create a timm torch model. -2. Quantize the torch model in MXFP8, NVFP4, INT4_AWQ, or AUTO mode. -3. Export the quantized torch model to ONNX format. +1. Load a pretrained timm model (e.g., ViT, Swin, ResNet). +2. Quantize the model using the specified mode. For models with Conv2d layers, + Conv2d quantization is automatically overridden for TensorRT compatibility + (FP8 for MXFP8/NVFP4, INT8 for INT4_AWQ). +3. Export the quantized model to ONNX with FP16 weights. +4. Optionally evaluate accuracy on ImageNet-1k before and after quantization. """ From 54cefb250847edff4d074931fc2ada6bb1581a3d Mon Sep 17 00:00:00 2001 From: ajrasane <131806219+ajrasane@users.noreply.github.com> Date: Tue, 14 Apr 2026 21:56:09 +0000 Subject: [PATCH 3/6] Add --trt_build flag to torch_quant_to_onnx and simplify tests Move TRT engine build logic into the script as a --trt_build flag, removing the duplicate trtexec invocation from the test file. Signed-off-by: ajrasane Co-Authored-By: Claude Opus 4.6 (1M context) Signed-off-by: ajrasane <131806219+ajrasane@users.noreply.github.com> --- examples/torch_onnx/torch_quant_to_onnx.py | 26 ++++++++++++++ .../torch_onnx/test_torch_quant_to_onnx.py | 35 +------------------ 2 files changed, 27 insertions(+), 34 deletions(-) diff --git a/examples/torch_onnx/torch_quant_to_onnx.py b/examples/torch_onnx/torch_quant_to_onnx.py index c39c878c97..37caebb36a 100644 --- a/examples/torch_onnx/torch_quant_to_onnx.py +++ b/examples/torch_onnx/torch_quant_to_onnx.py @@ -17,6 +17,7 @@ import copy import json import re +import subprocess import sys import warnings from pathlib import Path @@ -395,6 +396,11 @@ def main(): default=128, help="Number of scoring steps for auto quantization. Default is 128.", ) + parser.add_argument( + "--trt_build", + action="store_true", + help="Build a TensorRT engine from the exported ONNX model using trtexec.", + ) parser.add_argument( "--no_pretrained", action="store_true", @@ -496,6 +502,26 @@ def main(): print(f"Quantized ONNX model is saved to {args.onnx_save_path}") + if args.trt_build: + build_trt_engine(args.onnx_save_path) + + +def build_trt_engine(onnx_path): + """Build a TensorRT engine from the exported ONNX model using trtexec.""" + cmd = [ + "trtexec", + f"--onnx={onnx_path}", + "--stronglyTyped", + "--builderOptimizationLevel=4", + ] + print(f"\nBuilding TensorRT engine: {' '.join(cmd)}") + result = subprocess.run(cmd, capture_output=True, text=True, timeout=600) + if result.returncode != 0: + raise RuntimeError( + f"TensorRT engine build failed for {onnx_path}:\n{result.stdout}\n{result.stderr}" + ) + print("TensorRT engine build succeeded.") + if __name__ == "__main__": main() diff --git a/tests/examples/torch_onnx/test_torch_quant_to_onnx.py b/tests/examples/torch_onnx/test_torch_quant_to_onnx.py index 51a6e462d6..6d6e0d9de5 100644 --- a/tests/examples/torch_onnx/test_torch_quant_to_onnx.py +++ b/tests/examples/torch_onnx/test_torch_quant_to_onnx.py @@ -14,9 +14,6 @@ # limitations under the License. -import os -import subprocess - import pytest from _test_utils.examples.run_command import extend_cmd_parts, run_example_command @@ -31,32 +28,6 @@ "resnet50": ("resnet50", None), } -# Builder optimization level: 4 for low-bit modes, 3 otherwise -_LOW_BIT_MODES = {"fp8", "int8", "nvfp4"} - - -def _verify_trt_engine_build(onnx_save_path, quantize_mode): - """Verify the exported ONNX model can be compiled into a TensorRT engine.""" - example_dir = os.path.join( - os.path.dirname(__file__), "..", "..", "..", "examples", "torch_onnx" - ) - onnx_path = os.path.join(example_dir, onnx_save_path) - assert os.path.exists(onnx_path), f"ONNX file not found: {onnx_path}" - - opt_level = "4" if quantize_mode in _LOW_BIT_MODES else "3" - cmd = [ - "trtexec", - f"--onnx={onnx_path}", - "--stronglyTyped", - f"--builderOptimizationLevel={opt_level}", - ] - - result = subprocess.run(cmd, capture_output=True, text=True, timeout=600) - assert result.returncode == 0, ( - f"TensorRT engine build failed for {onnx_save_path} " - f"(mode={quantize_mode}):\n{result.stdout}\n{result.stderr}" - ) - @pytest.mark.parametrize("quantize_mode", _QUANT_MODES) @pytest.mark.parametrize("model_key", list(_MODELS)) @@ -64,7 +35,6 @@ def test_torch_onnx(model_key, quantize_mode): timm_model_name, model_kwargs = _MODELS[model_key] onnx_save_path = f"{model_key}.{quantize_mode}.onnx" - # Step 1: Quantize and export to ONNX cmd_parts = extend_cmd_parts( ["python", "torch_quant_to_onnx.py"], timm_model_name=timm_model_name, @@ -74,8 +44,5 @@ def test_torch_onnx(model_key, quantize_mode): calibration_data_size="1", num_score_steps="1", ) - cmd_parts.append("--no_pretrained") + cmd_parts.extend(["--no_pretrained", "--trt_build"]) run_example_command(cmd_parts, "torch_onnx") - - # Step 2: Verify the exported ONNX model builds a TensorRT engine - _verify_trt_engine_build(onnx_save_path, quantize_mode) From ba020d5485b909da1a0e2b5135141b70f8a936de Mon Sep 17 00:00:00 2001 From: ajrasane <131806219+ajrasane@users.noreply.github.com> Date: Tue, 14 Apr 2026 22:04:01 +0000 Subject: [PATCH 4/6] Fix ruff formatting issues Signed-off-by: ajrasane Co-Authored-By: Claude Opus 4.6 (1M context) Signed-off-by: ajrasane <131806219+ajrasane@users.noreply.github.com> --- examples/torch_onnx/torch_quant_to_onnx.py | 6 +++++- modelopt/torch/_deploy/utils/torch_onnx.py | 4 +--- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/examples/torch_onnx/torch_quant_to_onnx.py b/examples/torch_onnx/torch_quant_to_onnx.py index 37caebb36a..7411bdc579 100644 --- a/examples/torch_onnx/torch_quant_to_onnx.py +++ b/examples/torch_onnx/torch_quant_to_onnx.py @@ -166,7 +166,11 @@ def _calibrate_uncalibrated_quantizers(model, data_loader): if not hasattr(module, attr_name): continue quantizer = getattr(module, attr_name) - if quantizer.is_enabled and not quantizer.block_sizes and not hasattr(quantizer, "_amax"): + if ( + quantizer.is_enabled + and not quantizer.block_sizes + and not hasattr(quantizer, "_amax") + ): quantizer.enable_calib() uncalibrated.append(quantizer) diff --git a/modelopt/torch/_deploy/utils/torch_onnx.py b/modelopt/torch/_deploy/utils/torch_onnx.py index d9eab72e33..472bb919a9 100644 --- a/modelopt/torch/_deploy/utils/torch_onnx.py +++ b/modelopt/torch/_deploy/utils/torch_onnx.py @@ -587,9 +587,7 @@ def get_onnx_bytes_and_metadata( # Disable FP8 Conv weight quantizers: TorchScript ONNX exporter requires static # kernel shapes, but FP8 DequantizeLinear produces dynamic shapes. conv_wq_context = ( - _disable_fp8_conv_weight_quantizers(model) - if is_fp8_quantized(model) - else nullcontext() + _disable_fp8_conv_weight_quantizers(model) if is_fp8_quantized(model) else nullcontext() ) with torch.inference_mode(), autocast, quantizer_context, conv_wq_context: additional_kwargs = {} From b8b1d5cdf2cc1a21e749f8aa8e799db64bbfd887 Mon Sep 17 00:00:00 2001 From: ajrasane <131806219+ajrasane@users.noreply.github.com> Date: Wed, 15 Apr 2026 16:18:37 +0000 Subject: [PATCH 5/6] Fix auto_quantize Conv2d budget by incorporating TRT overrides into format configs Previously, Conv2d layers were overridden from block quantization to FP8 after mtq.auto_quantize() returned, causing the effective_bits budget and search_state to be stale. Move the Conv2d TRT overrides into the format configs passed to auto_quantize so the search correctly accounts for Conv2d being FP8/INT8 in the budget. Co-Authored-By: Claude Opus 4.6 (1M context) Signed-off-by: ajrasane <131806219+ajrasane@users.noreply.github.com> --- examples/torch_onnx/torch_quant_to_onnx.py | 53 ++++++---------------- 1 file changed, 15 insertions(+), 38 deletions(-) diff --git a/examples/torch_onnx/torch_quant_to_onnx.py b/examples/torch_onnx/torch_quant_to_onnx.py index 7411bdc579..c72acbc804 100644 --- a/examples/torch_onnx/torch_quant_to_onnx.py +++ b/examples/torch_onnx/torch_quant_to_onnx.py @@ -86,6 +86,11 @@ }, ] +# Auto-quantize format configs that use block quantization and need Conv2d overrides for TRT. +# TRT DynamicQuantize requires 2D/3D input, but Conv2d operates on 4D tensors. +_NEEDS_FP8_CONV_OVERRIDE: set[str] = {"NVFP4_AWQ_LITE_CFG", "NVFP4_DEFAULT_CFG", "MXFP8_DEFAULT_CFG"} +_NEEDS_INT8_CONV_OVERRIDE: set[str] = {"INT4_AWQ_CFG"} + def get_quant_config(quantize_mode): """Get quantization config, overriding Conv2d for TRT compatibility. @@ -229,38 +234,6 @@ def _disable_inplace_relu(model): module.inplace = False -def _override_conv2d_to_fp8(model, data_loader): - """Override Conv2d layers with NVFP4/MXFP8 block quantization to FP8. - - TRT DynamicQuantize requires 2D/3D input, but Conv2d operates on 4D tensors. - This overrides Conv2d block quantizers to FP8 per-tensor and calibrates them. - """ - overridden = [] - for _, module in model.named_modules(): - if not isinstance(module, torch.nn.Conv2d): - continue - for attr_name in ("input_quantizer", "weight_quantizer"): - if not hasattr(module, attr_name): - continue - quantizer = getattr(module, attr_name) - if quantizer.is_enabled and quantizer.block_sizes: - # Override to FP8 per-tensor - quantizer.block_sizes = None - quantizer._num_bits = (4, 3) - quantizer._axis = None - quantizer.enable_calib() - overridden.append(quantizer) - - if overridden: - model.eval() - with torch.no_grad(): - for batch in data_loader: - model(batch["image"]) - for quantizer in overridden: - quantizer.disable_calib() - quantizer.load_calib_amax() - - def auto_quantize_model( model, data_loader, @@ -285,11 +258,19 @@ def auto_quantize_model( _disable_inplace_relu(model) constraints = {"effective_bits": effective_bits} - # Convert string format names to actual config objects + # Convert string format names to config objects, incorporating Conv2d TRT overrides. + # TRT DynamicQuantize requires 2D/3D input, but Conv2d operates on 4D tensors. + # By including the overrides in the format configs, the auto_quantize search + # correctly accounts for Conv2d being FP8/INT8 in the effective_bits budget. format_configs = [] for fmt in quantization_formats: if isinstance(fmt, str): - format_configs.append(getattr(mtq, fmt)) + config = copy.deepcopy(getattr(mtq, fmt)) + if fmt in _NEEDS_FP8_CONV_OVERRIDE: + config["quant_cfg"].extend(_FP8_CONV_OVERRIDE) + elif fmt in _NEEDS_INT8_CONV_OVERRIDE: + config["quant_cfg"].extend(_INT8_CONV_OVERRIDE) + format_configs.append(config) else: format_configs.append(fmt) @@ -309,10 +290,6 @@ def auto_quantize_model( verbose=True, ) - # Override Conv2d layers that got NVFP4/MXFP8 to FP8 for TRT compatibility. - # TRT DynamicQuantize requires 2D/3D input, but Conv2d operates on 4D tensors. - _override_conv2d_to_fp8(quantized_model, data_loader) - # Disable quantization for specified layers mtq.disable_quantizer(quantized_model, filter_func) From 6419b34523d2217d6f7eadc2c8798cfe7e8f9771 Mon Sep 17 00:00:00 2001 From: ajrasane <131806219+ajrasane@users.noreply.github.com> Date: Thu, 16 Apr 2026 16:01:40 +0000 Subject: [PATCH 6/6] Fix FP8 Conv weight quantization in ONNX export pipeline TorchScript ONNX export breaks when Conv weight quantizers are enabled because TRT_FP8DequantizeLinear produces unknown shapes. This restores FP8 weight quantization as a post-processing step in FP8QuantExporter and adds a utility to fold redundant DQ->Cast(FP32->FP16) patterns inserted by float16 conversion. Co-Authored-By: Claude Opus 4.6 (1M context) Signed-off-by: ajrasane <131806219+ajrasane@users.noreply.github.com> --- examples/torch_onnx/torch_quant_to_onnx.py | 2 +- modelopt/onnx/export/fp8_exporter.py | 87 +++++++++++++++++++++- modelopt/onnx/utils.py | 81 ++++++++++++++++++++ modelopt/torch/_deploy/utils/torch_onnx.py | 16 ++-- 4 files changed, 176 insertions(+), 10 deletions(-) diff --git a/examples/torch_onnx/torch_quant_to_onnx.py b/examples/torch_onnx/torch_quant_to_onnx.py index c72acbc804..ceae3b6518 100644 --- a/examples/torch_onnx/torch_quant_to_onnx.py +++ b/examples/torch_onnx/torch_quant_to_onnx.py @@ -120,7 +120,7 @@ def filter_func(name): pattern = re.compile( r".*(time_emb_proj|time_embedding|conv_in|conv_out|conv_shortcut|add_embedding|" r"pos_embed|time_text_embed|context_embedder|norm_out|x_embedder|patch_embed|cpb_mlp|" - r"downsample|maxpool|global_pool).*" + r"maxpool|global_pool).*" ) return pattern.match(name) is not None diff --git a/modelopt/onnx/export/fp8_exporter.py b/modelopt/onnx/export/fp8_exporter.py index ffcbd89423..c86150b205 100644 --- a/modelopt/onnx/export/fp8_exporter.py +++ b/modelopt/onnx/export/fp8_exporter.py @@ -101,13 +101,89 @@ def compress_weights(onnx_model: onnx.ModelProto) -> onnx.ModelProto: return gs.export_onnx(graph) + @staticmethod + def _quantize_conv_weights_to_fp8(graph: gs.Graph) -> int: + """Add FP8 weight DequantizeLinear for Conv layers with unquantized weights. + + Conv weight quantizers are disabled during TorchScript ONNX export because the + TRT_FP8DequantizeLinear custom op produces outputs with unknown shapes, causing + the _convolution symbolic to fail. This method restores FP8 weight quantization + by inserting DQ nodes in the ONNX graph, mirroring the compress_weights logic. + + For each Conv node with an unquantized constant weight: + 1. Compute per-tensor scale = max(abs(weight)) / 448.0 + 2. Quantize weights to FP8E4M3FN + 3. Insert a DequantizeLinear(fp8_weights, scale) before the Conv weight input + + Args: + graph: The onnx-graphsurgeon graph to modify in-place. + + Returns: + Number of Conv weight DQ nodes inserted. + """ + FP8_MAX = 448.0 + count = 0 + + for node in list(graph.nodes): + if node.op != "Conv": + continue + if len(node.inputs) < 2: + continue + + weight_input = node.inputs[1] + if not isinstance(weight_input, gs.Constant): + continue + + # Skip if weight already has a DQ producer + if any(out.op == "DequantizeLinear" for out in weight_input.outputs): + continue + + torch_weights = torch.from_numpy(weight_input.values.copy()) + amax = torch_weights.abs().max().float() + if amax == 0: + continue + scale_val = (amax / FP8_MAX).item() + + # Quantize weights to FP8 (WAR: numpy doesn't support fp8) + fp8_data = ( + (torch_weights / scale_val).to(torch.float8_e4m3fn).view(torch.uint8).numpy() + ) + fp8_tensor = onnx.TensorProto() + fp8_tensor.data_type = onnx.TensorProto.FLOAT8E4M3FN + fp8_tensor.dims.extend(fp8_data.shape) + fp8_tensor.raw_data = fp8_data.tobytes() + fp8_constant = gs.Constant( + node.name + "/weight_quantizer/fp8_weights", LazyValues(fp8_tensor) + ) + + # Scale in FP16 — DQ output type matches scale dtype, must match activation type + import numpy as np + + scale_constant = gs.Constant( + node.name + "/weight_quantizer/scale", + np.array(scale_val, dtype=np.float16), + ) + + dq_output = gs.Variable(node.name + "/weight_quantizer/dq_output") + dq_node = gs.Node( + op="DequantizeLinear", + name=node.name + "/weight_quantizer/DequantizeLinear", + inputs=[fp8_constant, scale_constant], + outputs=[dq_output], + ) + graph.nodes.append(dq_node) + node.inputs[1] = dq_output + count += 1 + + return count + @staticmethod def post_process(onnx_model: onnx.ModelProto) -> onnx.ModelProto: """Post-processes the ONNX model for FP8 quantization. - Converts TRT_FP8 QDQ ops to native ONNX QuantizeLinear/DequantizeLinear: - - TRT_FP8QuantizeLinear -> QuantizeLinear with FP8E4M3FN zero_point and saturate=1 - - TRT_FP8DequantizeLinear -> DequantizeLinear + Converts TRT_FP8 QDQ ops to native ONNX QuantizeLinear/DequantizeLinear and + adds FP8 weight DQ for Conv layers whose weight quantizers were disabled during + TorchScript export. Args: onnx_model: The ONNX model containing TRT_FP8 quantization nodes. @@ -144,5 +220,10 @@ def post_process(onnx_model: onnx.ModelProto) -> onnx.ModelProto: f"Converted {node.name} from TRT_FP8DequantizeLinear to DequantizeLinear" ) + # Add FP8 weight DQ for Conv layers that had weight quantizers disabled during export + count = FP8QuantExporter._quantize_conv_weights_to_fp8(graph) + if count > 0: + logger.info(f"Inserted FP8 weight DequantizeLinear for {count} Conv nodes") + graph.cleanup().toposort() return gs.export_onnx(graph) diff --git a/modelopt/onnx/utils.py b/modelopt/onnx/utils.py index 54efd0a111..a37fd59845 100644 --- a/modelopt/onnx/utils.py +++ b/modelopt/onnx/utils.py @@ -1504,6 +1504,87 @@ def remove_redundant_casts(onnx_model: onnx.ModelProto) -> onnx.ModelProto: return onnx_model +def fold_dq_fp32_to_fp16_casts(onnx_model: onnx.ModelProto) -> onnx.ModelProto: + """Remove Cast(FP32->FP16) nodes after DequantizeLinear by setting DQ output to FP16. + + When convert_float_to_float16 blocks DequantizeLinear, it inserts Cast nodes to bridge + the FP32 DQ output to the FP16 graph. This function removes those Cast nodes by: + 1. Converting the DQ scale initializer from FP32 to FP16 + 2. Updating the DQ output type to FP16 in value_info + 3. Bypassing and removing the Cast node + + Args: + onnx_model: The ONNX model with DQ -> Cast(FP32->FP16) patterns. + + Returns: + The ONNX model with Cast nodes removed and DQ outputs set to FP16. + """ + DQ_OPS = {"DequantizeLinear", "TRT_FP8DequantizeLinear"} + + # Build a map of tensor name -> producer node + producer_map: dict[str, onnx.NodeProto] = {} + for node in onnx_model.graph.node: + for out in node.output: + producer_map[out] = node + + # Build initializer lookup + initializer_map: dict[str, onnx.TensorProto] = { + init.name: init for init in onnx_model.graph.initializer + } + + nodes_to_remove = [] + for node in onnx_model.graph.node: + if node.op_type != "Cast": + continue + + # Check: Cast target is FP16 + cast_to = None + for attr in node.attribute: + if attr.name == "to": + cast_to = attr.i + if cast_to != onnx.TensorProto.FLOAT16: + continue + + # Check: producer is a DQ node + producer = producer_map.get(node.input[0]) + if producer is None or producer.op_type not in DQ_OPS: + continue + + # Convert the DQ scale initializer from FP32 to FP16 + # DQ inputs: [input, scale, (zero_point)] + if len(producer.input) >= 2: + scale_name = producer.input[1] + if scale_name in initializer_map: + scale_init = initializer_map[scale_name] + if scale_init.data_type == onnx.TensorProto.FLOAT: + import numpy as np + + scale_data = np.frombuffer(scale_init.raw_data, dtype=np.float32) + if not scale_data.size: + scale_data = np.array(scale_init.float_data, dtype=np.float32) + scale_fp16 = scale_data.astype(np.float16) + scale_init.data_type = onnx.TensorProto.FLOAT16 + scale_init.raw_data = scale_fp16.tobytes() + del scale_init.float_data[:] + + # Bypass the Cast node + _bypass_cast_node(onnx_model, node) + nodes_to_remove.append(node) + + # Update the DQ output type in value_info + dq_output_name = producer.output[0] + for vi in onnx_model.graph.value_info: + if vi.name == dq_output_name: + vi.type.tensor_type.elem_type = onnx.TensorProto.FLOAT16 + break + + logger.debug(f"Folded {len(nodes_to_remove)} DQ -> Cast(FP32->FP16) patterns") + for node in nodes_to_remove: + onnx_model.graph.node.remove(node) + + return onnx_model + + def remove_node_training_mode(onnx_model: onnx.ModelProto, node_op_type: str) -> onnx.ModelProto: """Remove `training_mode` attribute and extra training outputs from nodes of a given op type. diff --git a/modelopt/torch/_deploy/utils/torch_onnx.py b/modelopt/torch/_deploy/utils/torch_onnx.py index 472bb919a9..fdaf4e08d1 100644 --- a/modelopt/torch/_deploy/utils/torch_onnx.py +++ b/modelopt/torch/_deploy/utils/torch_onnx.py @@ -47,6 +47,7 @@ from modelopt.onnx.utils import ( change_casts_to_fp16, check_model_uses_external_data, + fold_dq_fp32_to_fp16_casts, get_input_names, get_input_shapes, get_node_names, @@ -408,10 +409,10 @@ def _disable_fp8_conv_weight_quantizers(model: nn.Module): """Temporarily disable FP8 weight quantizers on Conv layers during ONNX export. The TorchScript ONNX exporter requires static kernel shapes for Conv operations, - but FP8 weight quantization (TRT_FP8QuantizeLinear -> TRT_FP8DequantizeLinear) - produces dynamic-shape outputs that break this requirement. Disabling Conv weight - quantizers during export allows the Conv to export with static-shape FP16/FP32 - weights. Conv activations still have FP8 QDQ nodes (input quantizers remain enabled). + but the TRT_FP8DequantizeLinear custom op produces outputs with unknown shapes in + the TorchScript IR, causing the _convolution symbolic to fail. Disabling Conv weight + quantizers during export allows the Conv to export with static-shape FP16/FP32 weights. + FP8 weight quantization is restored as a post-processing step in FP8QuantExporter. """ disabled = [] for _, module in model.named_modules(): @@ -584,8 +585,9 @@ def get_onnx_bytes_and_metadata( if is_fp4_quantized(model) or is_mxfp8_quantized(model) else nullcontext() ) - # Disable FP8 Conv weight quantizers: TorchScript ONNX exporter requires static - # kernel shapes, but FP8 DequantizeLinear produces dynamic shapes. + # Disable FP8 Conv weight quantizers: TorchScript custom ops produce outputs with + # unknown shapes, causing _convolution symbolic to fail. Conv weights are quantized + # to FP8 in post-processing by FP8QuantExporter instead. conv_wq_context = ( _disable_fp8_conv_weight_quantizers(model) if is_fp8_quantized(model) else nullcontext() ) @@ -648,6 +650,8 @@ def get_onnx_bytes_and_metadata( # Change FP32 cast nodes feeding into Concat/Add to FP16 op_list = ["Concat", "Add", "Sqrt", "LayerNormalization", "Clip", "Mul", "Exp"] onnx_opt_graph = change_casts_to_fp16(onnx_opt_graph, op_list) + # Remove Cast(FP32->FP16) nodes after DQ by setting DQ output to FP16 directly + onnx_opt_graph = fold_dq_fp32_to_fp16_casts(onnx_opt_graph) else: onnx_opt_graph = convert_to_f16( onnx_opt_graph, low_precision_type=weights_dtype, keep_io_types=False