-
Notifications
You must be signed in to change notification settings - Fork 358
Add ResNet50 support for torch_onnx quantization workflow #1263
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
56a1d21
9810e1d
54cefb2
ba020d5
b8b1d5c
6419b34
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -17,6 +17,7 @@ | |
| import copy | ||
| import json | ||
| import re | ||
| import subprocess | ||
| import sys | ||
| import warnings | ||
| from pathlib import Path | ||
|
|
@@ -35,13 +36,17 @@ | |
| import modelopt.torch.quantization as mtq | ||
|
|
||
| """ | ||
| This script is used to quantize a timm model using dynamic quantization like MXFP8 or NVFP4, | ||
| or using auto quantization for optimal per-layer quantization. | ||
| Quantize a timm vision model and export to ONNX for TensorRT deployment. | ||
|
|
||
| Supports FP8, INT8, MXFP8, NVFP4, INT4_AWQ, and AUTO (mixed-precision) quantization modes. | ||
|
|
||
| The script will: | ||
| 1. Given the model name, create a timm torch model. | ||
| 2. Quantize the torch model in MXFP8, NVFP4, INT4_AWQ, or AUTO mode. | ||
| 3. Export the quantized torch model to ONNX format. | ||
| 1. Load a pretrained timm model (e.g., ViT, Swin, ResNet). | ||
| 2. Quantize the model using the specified mode. For models with Conv2d layers, | ||
| Conv2d quantization is automatically overridden for TensorRT compatibility | ||
| (FP8 for MXFP8/NVFP4, INT8 for INT4_AWQ). | ||
| 3. Export the quantized model to ONNX with FP16 weights. | ||
| 4. Optionally evaluate accuracy on ImageNet-1k before and after quantization. | ||
| """ | ||
|
|
||
|
|
||
|
|
@@ -81,6 +86,11 @@ | |
| }, | ||
| ] | ||
|
|
||
| # Auto-quantize format configs that use block quantization and need Conv2d overrides for TRT. | ||
| # TRT DynamicQuantize requires 2D/3D input, but Conv2d operates on 4D tensors. | ||
| _NEEDS_FP8_CONV_OVERRIDE: set[str] = {"NVFP4_AWQ_LITE_CFG", "NVFP4_DEFAULT_CFG", "MXFP8_DEFAULT_CFG"} | ||
| _NEEDS_INT8_CONV_OVERRIDE: set[str] = {"INT4_AWQ_CFG"} | ||
|
|
||
|
|
||
| def get_quant_config(quantize_mode): | ||
| """Get quantization config, overriding Conv2d for TRT compatibility. | ||
|
|
@@ -109,7 +119,8 @@ def filter_func(name): | |
| """Filter function to exclude certain layers from quantization.""" | ||
| pattern = re.compile( | ||
| r".*(time_emb_proj|time_embedding|conv_in|conv_out|conv_shortcut|add_embedding|" | ||
| r"pos_embed|time_text_embed|context_embedder|norm_out|x_embedder|patch_embed|cpb_mlp|downsample).*" | ||
| r"pos_embed|time_text_embed|context_embedder|norm_out|x_embedder|patch_embed|cpb_mlp|" | ||
| r"maxpool|global_pool).*" | ||
| ) | ||
| return pattern.match(name) is not None | ||
|
|
||
|
|
@@ -147,6 +158,40 @@ def load_calibration_data(model_name, data_size, batch_size, device, with_labels | |
| ) | ||
|
|
||
|
|
||
| def _calibrate_uncalibrated_quantizers(model, data_loader): | ||
| """Calibrate FP8 quantizers that weren't calibrated by mtq.quantize(). | ||
|
|
||
| When MXFP8/NVFP4 modes override Conv2d to FP8, the FP8 quantizers may not | ||
| be calibrated because the MXFP8/NVFP4 quantization pipeline skips standard | ||
| calibration. This function explicitly calibrates those uncalibrated quantizers. | ||
| """ | ||
| uncalibrated = [] | ||
| for _, module in model.named_modules(): | ||
| for attr_name in ("input_quantizer", "weight_quantizer"): | ||
| if not hasattr(module, attr_name): | ||
| continue | ||
| quantizer = getattr(module, attr_name) | ||
| if ( | ||
| quantizer.is_enabled | ||
| and not quantizer.block_sizes | ||
| and not hasattr(quantizer, "_amax") | ||
| ): | ||
| quantizer.enable_calib() | ||
|
coderabbitai[bot] marked this conversation as resolved.
|
||
| uncalibrated.append(quantizer) | ||
|
|
||
| if not uncalibrated: | ||
| return | ||
|
|
||
| model.eval() | ||
| with torch.no_grad(): | ||
| for batch in data_loader: | ||
| model(batch) | ||
|
|
||
| for quantizer in uncalibrated: | ||
| quantizer.disable_calib() | ||
| quantizer.load_calib_amax() | ||
|
|
||
|
|
||
| def quantize_model(model, config, data_loader=None): | ||
| """Quantize the model using the given config and calibration data.""" | ||
| if data_loader is not None: | ||
|
|
@@ -159,6 +204,10 @@ def forward_loop(model): | |
| else: | ||
| quantized_model = mtq.quantize(model, config) | ||
|
|
||
| # Calibrate any FP8 override quantizers that weren't calibrated by mtq.quantize() | ||
| if data_loader is not None: | ||
| _calibrate_uncalibrated_quantizers(quantized_model, data_loader) | ||
|
|
||
| mtq.disable_quantizer(quantized_model, filter_func) | ||
| return quantized_model | ||
|
|
||
|
|
@@ -209,11 +258,19 @@ def auto_quantize_model( | |
| _disable_inplace_relu(model) | ||
| constraints = {"effective_bits": effective_bits} | ||
|
|
||
| # Convert string format names to actual config objects | ||
| # Convert string format names to config objects, incorporating Conv2d TRT overrides. | ||
| # TRT DynamicQuantize requires 2D/3D input, but Conv2d operates on 4D tensors. | ||
| # By including the overrides in the format configs, the auto_quantize search | ||
| # correctly accounts for Conv2d being FP8/INT8 in the effective_bits budget. | ||
| format_configs = [] | ||
| for fmt in quantization_formats: | ||
| if isinstance(fmt, str): | ||
| format_configs.append(getattr(mtq, fmt)) | ||
| config = copy.deepcopy(getattr(mtq, fmt)) | ||
| if fmt in _NEEDS_FP8_CONV_OVERRIDE: | ||
| config["quant_cfg"].extend(_FP8_CONV_OVERRIDE) | ||
| elif fmt in _NEEDS_INT8_CONV_OVERRIDE: | ||
| config["quant_cfg"].extend(_INT8_CONV_OVERRIDE) | ||
| format_configs.append(config) | ||
| else: | ||
| format_configs.append(fmt) | ||
|
|
||
|
|
@@ -320,6 +377,11 @@ def main(): | |
| default=128, | ||
| help="Number of scoring steps for auto quantization. Default is 128.", | ||
| ) | ||
| parser.add_argument( | ||
| "--trt_build", | ||
| action="store_true", | ||
| help="Build a TensorRT engine from the exported ONNX model using trtexec.", | ||
| ) | ||
| parser.add_argument( | ||
| "--no_pretrained", | ||
| action="store_true", | ||
|
|
@@ -378,18 +440,18 @@ def main(): | |
| args.num_score_steps, | ||
| ) | ||
| else: | ||
| # Standard quantization - only load calibration data if needed | ||
| # Standard quantization - load calibration data | ||
| # Note: MXFP8 is dynamic and does not need calibration itself, but when | ||
| # Conv2d layers are overridden to FP8 (for TRT compatibility), those FP8 | ||
| # quantizers require calibration data. | ||
| config = get_quant_config(args.quantize_mode) | ||
| if args.quantize_mode == "mxfp8": | ||
| data_loader = None | ||
| else: | ||
| data_loader = load_calibration_data( | ||
| args.timm_model_name, | ||
| args.calibration_data_size, | ||
| args.batch_size, | ||
| device, | ||
| with_labels=False, | ||
| ) | ||
| data_loader = load_calibration_data( | ||
| args.timm_model_name, | ||
| args.calibration_data_size, | ||
| args.batch_size, | ||
| device, | ||
| with_labels=False, | ||
| ) | ||
|
|
||
| quantized_model = quantize_model(model, config, data_loader) | ||
|
|
||
|
|
@@ -421,6 +483,26 @@ def main(): | |
|
|
||
| print(f"Quantized ONNX model is saved to {args.onnx_save_path}") | ||
|
|
||
| if args.trt_build: | ||
| build_trt_engine(args.onnx_save_path) | ||
|
|
||
|
|
||
| def build_trt_engine(onnx_path): | ||
| """Build a TensorRT engine from the exported ONNX model using trtexec.""" | ||
| cmd = [ | ||
| "trtexec", | ||
| f"--onnx={onnx_path}", | ||
| "--stronglyTyped", | ||
| "--builderOptimizationLevel=4", | ||
| ] | ||
| print(f"\nBuilding TensorRT engine: {' '.join(cmd)}") | ||
| result = subprocess.run(cmd, capture_output=True, text=True, timeout=600) | ||
| if result.returncode != 0: | ||
| raise RuntimeError( | ||
| f"TensorRT engine build failed for {onnx_path}:\n{result.stdout}\n{result.stderr}" | ||
| ) | ||
|
Comment on lines +490 to +503
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Wrap
🛠️ Suggested fix def build_trt_engine(onnx_path):
"""Build a TensorRT engine from the exported ONNX model using trtexec."""
cmd = [
"trtexec",
f"--onnx={onnx_path}",
"--stronglyTyped",
"--builderOptimizationLevel=4",
]
print(f"\nBuilding TensorRT engine: {' '.join(cmd)}")
- result = subprocess.run(cmd, capture_output=True, text=True, timeout=600)
+ try:
+ result = subprocess.run(cmd, capture_output=True, text=True, timeout=600)
+ except FileNotFoundError as exc:
+ raise RuntimeError("`trtexec` was not found in PATH. Install TensorRT or omit --trt_build.") from exc
+ except subprocess.TimeoutExpired as exc:
+ raise RuntimeError(f"TensorRT engine build timed out after {exc.timeout}s for {onnx_path}.") from exc
if result.returncode != 0:
raise RuntimeError(
f"TensorRT engine build failed for {onnx_path}:\n{result.stdout}\n{result.stderr}"
)🤖 Prompt for AI Agents |
||
| print("TensorRT engine build succeeded.") | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| main() | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Don’t advertise `INT4_AWQ` as supported end-to-end here.

The PR objectives still call out `INT4_AWQ` as a known limitation, but this docstring now groups it with the working modes. Please caveat or remove it here so users do not assume this example is expected to succeed in that mode.

✏️ Suggested wording
🤖 Prompt for AI Agents
🤖 Prompt for AI Agents