Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions examples/torch_onnx/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -307,6 +307,7 @@ python torch_quant_to_onnx.py \
| [vit_base_patch16_224](https://huggingface.co/timm/vit_base_patch16_224.augreg_in21k_ft_in1k) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
| [swin_tiny_patch4_window7_224](https://huggingface.co/timm/swin_tiny_patch4_window7_224.ms_in1k) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
| [swinv2_tiny_window8_256](https://huggingface.co/timm/swinv2_tiny_window8_256.ms_in1k) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
| [resnet50](https://huggingface.co/timm/resnet50.a1_in1k) | ✅ | ✅ | ✅ | ✅ | | ✅ |

## Resources

Expand Down
120 changes: 101 additions & 19 deletions examples/torch_onnx/torch_quant_to_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import copy
import json
import re
import subprocess
import sys
import warnings
from pathlib import Path
Expand All @@ -35,13 +36,17 @@
import modelopt.torch.quantization as mtq

"""
This script is used to quantize a timm model using dynamic quantization like MXFP8 or NVFP4,
or using auto quantization for optimal per-layer quantization.
Quantize a timm vision model and export to ONNX for TensorRT deployment.

Supports FP8, INT8, MXFP8, NVFP4, and AUTO (mixed-precision) quantization modes.
INT4_AWQ is also accepted but remains a known limitation for this example and may not work end-to-end.

The script will:
1. Given the model name, create a timm torch model.
2. Quantize the torch model in MXFP8, NVFP4, INT4_AWQ, or AUTO mode.
3. Export the quantized torch model to ONNX format.
1. Load a pretrained timm model (e.g., ViT, Swin, ResNet).
2. Quantize the model using the specified mode. For models with Conv2d layers,
Conv2d quantization is automatically overridden for TensorRT compatibility
(FP8 for MXFP8/NVFP4, INT8 for INT4_AWQ).
Comment on lines +39 to +47
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

Don’t advertise INT4_AWQ as supported end-to-end here.

The PR objectives still call out INT4_AWQ as a known limitation, but this docstring now groups it with the working modes. Please caveat or remove it here so users do not assume this example is expected to succeed in that mode.

✏️ Suggested wording
-Supports FP8, INT8, MXFP8, NVFP4, INT4_AWQ, and AUTO (mixed-precision) quantization modes.
+Supports FP8, INT8, MXFP8, NVFP4, and AUTO (mixed-precision) quantization modes.
+`INT4_AWQ` remains a known limitation for this example.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/torch_onnx/torch_quant_to_onnx.py` around lines 38 - 46, The
docstring in torch_quant_to_onnx.py incorrectly groups INT4_AWQ with fully
supported quantization modes; update the top-level description to either remove
INT4_AWQ from the supported list or add a clear caveat that INT4_AWQ is a known
limitation and may not work end-to-end (e.g., "INT4_AWQ is experimental/limited
— see PR objectives for current limitations"), ensuring references to the script
name and the quantization modes (FP8, INT8, MXFP8, NVFP4, INT4_AWQ, AUTO) are
adjusted so users won't assume INT4_AWQ is fully supported.

3. Export the quantized model to ONNX with FP16 weights.
4. Optionally evaluate accuracy on ImageNet-1k before and after quantization.
"""


Expand Down Expand Up @@ -81,6 +86,11 @@
},
]

# Auto-quantize format configs that use block quantization and need Conv2d overrides for TRT.
# TRT DynamicQuantize requires 2D/3D input, but Conv2d operates on 4D tensors.
_NEEDS_FP8_CONV_OVERRIDE: set[str] = {"NVFP4_AWQ_LITE_CFG", "NVFP4_DEFAULT_CFG", "MXFP8_DEFAULT_CFG"}
_NEEDS_INT8_CONV_OVERRIDE: set[str] = {"INT4_AWQ_CFG"}


def get_quant_config(quantize_mode):
"""Get quantization config, overriding Conv2d for TRT compatibility.
Expand Down Expand Up @@ -109,7 +119,8 @@ def filter_func(name):
"""Filter function to exclude certain layers from quantization."""
pattern = re.compile(
r".*(time_emb_proj|time_embedding|conv_in|conv_out|conv_shortcut|add_embedding|"
r"pos_embed|time_text_embed|context_embedder|norm_out|x_embedder|patch_embed|cpb_mlp|downsample).*"
r"pos_embed|time_text_embed|context_embedder|norm_out|x_embedder|patch_embed|cpb_mlp|"
r"maxpool|global_pool).*"
)
return pattern.match(name) is not None

Expand Down Expand Up @@ -147,6 +158,40 @@ def load_calibration_data(model_name, data_size, batch_size, device, with_labels
)


def _calibrate_uncalibrated_quantizers(model, data_loader):
"""Calibrate FP8 quantizers that weren't calibrated by mtq.quantize().

When MXFP8/NVFP4 modes override Conv2d to FP8, the FP8 quantizers may not
be calibrated because the MXFP8/NVFP4 quantization pipeline skips standard
calibration. This function explicitly calibrates those uncalibrated quantizers.
"""
uncalibrated = []
for _, module in model.named_modules():
for attr_name in ("input_quantizer", "weight_quantizer"):
if not hasattr(module, attr_name):
continue
quantizer = getattr(module, attr_name)
if (
quantizer.is_enabled
and not quantizer.block_sizes
and not hasattr(quantizer, "_amax")
):
quantizer.enable_calib()
Comment thread
coderabbitai[bot] marked this conversation as resolved.
uncalibrated.append(quantizer)

if not uncalibrated:
return

model.eval()
with torch.no_grad():
for batch in data_loader:
model(batch)

for quantizer in uncalibrated:
quantizer.disable_calib()
quantizer.load_calib_amax()


def quantize_model(model, config, data_loader=None):
"""Quantize the model using the given config and calibration data."""
if data_loader is not None:
Expand All @@ -159,6 +204,10 @@ def forward_loop(model):
else:
quantized_model = mtq.quantize(model, config)

# Calibrate any FP8 override quantizers that weren't calibrated by mtq.quantize()
if data_loader is not None:
_calibrate_uncalibrated_quantizers(quantized_model, data_loader)

mtq.disable_quantizer(quantized_model, filter_func)
return quantized_model

Expand Down Expand Up @@ -209,11 +258,19 @@ def auto_quantize_model(
_disable_inplace_relu(model)
constraints = {"effective_bits": effective_bits}

# Convert string format names to actual config objects
# Convert string format names to config objects, incorporating Conv2d TRT overrides.
# TRT DynamicQuantize requires 2D/3D input, but Conv2d operates on 4D tensors.
# By including the overrides in the format configs, the auto_quantize search
# correctly accounts for Conv2d being FP8/INT8 in the effective_bits budget.
format_configs = []
for fmt in quantization_formats:
if isinstance(fmt, str):
format_configs.append(getattr(mtq, fmt))
config = copy.deepcopy(getattr(mtq, fmt))
if fmt in _NEEDS_FP8_CONV_OVERRIDE:
config["quant_cfg"].extend(_FP8_CONV_OVERRIDE)
elif fmt in _NEEDS_INT8_CONV_OVERRIDE:
config["quant_cfg"].extend(_INT8_CONV_OVERRIDE)
format_configs.append(config)
else:
format_configs.append(fmt)

Expand Down Expand Up @@ -320,6 +377,11 @@ def main():
default=128,
help="Number of scoring steps for auto quantization. Default is 128.",
)
parser.add_argument(
"--trt_build",
action="store_true",
help="Build a TensorRT engine from the exported ONNX model using trtexec.",
)
parser.add_argument(
"--no_pretrained",
action="store_true",
Expand Down Expand Up @@ -378,18 +440,18 @@ def main():
args.num_score_steps,
)
else:
# Standard quantization - only load calibration data if needed
# Standard quantization - load calibration data
# Note: MXFP8 is dynamic and does not need calibration itself, but when
# Conv2d layers are overridden to FP8 (for TRT compatibility), those FP8
# quantizers require calibration data.
config = get_quant_config(args.quantize_mode)
if args.quantize_mode == "mxfp8":
data_loader = None
else:
data_loader = load_calibration_data(
args.timm_model_name,
args.calibration_data_size,
args.batch_size,
device,
with_labels=False,
)
data_loader = load_calibration_data(
args.timm_model_name,
args.calibration_data_size,
args.batch_size,
device,
with_labels=False,
)

quantized_model = quantize_model(model, config, data_loader)

Expand Down Expand Up @@ -421,6 +483,26 @@ def main():

print(f"Quantized ONNX model is saved to {args.onnx_save_path}")

if args.trt_build:
build_trt_engine(args.onnx_save_path)


def build_trt_engine(onnx_path):
    """Build a TensorRT engine from the exported ONNX model using trtexec.

    Args:
        onnx_path: Path to the ONNX model to compile into an engine.

    Raises:
        RuntimeError: If ``trtexec`` is not installed, the build times out,
            or ``trtexec`` exits with a non-zero status.
    """
    cmd = [
        "trtexec",
        f"--onnx={onnx_path}",
        "--stronglyTyped",
        "--builderOptimizationLevel=4",
    ]
    print(f"\nBuilding TensorRT engine: {' '.join(cmd)}")
    try:
        result = subprocess.run(cmd, capture_output=True, text=True, timeout=600)
    except FileNotFoundError as exc:
        # Missing TensorRT install: surface an actionable message, not a traceback.
        raise RuntimeError(
            "`trtexec` was not found in PATH. Install TensorRT or omit --trt_build."
        ) from exc
    except subprocess.TimeoutExpired as exc:
        raise RuntimeError(
            f"TensorRT engine build timed out after {exc.timeout}s for {onnx_path}."
        ) from exc
    if result.returncode != 0:
        raise RuntimeError(
            f"TensorRT engine build failed for {onnx_path}:\n{result.stdout}\n{result.stderr}"
        )
    print("TensorRT engine build succeeded.")


if __name__ == "__main__":
main()
87 changes: 84 additions & 3 deletions modelopt/onnx/export/fp8_exporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,13 +101,89 @@ def compress_weights(onnx_model: onnx.ModelProto) -> onnx.ModelProto:

return gs.export_onnx(graph)

@staticmethod
def _quantize_conv_weights_to_fp8(graph: gs.Graph) -> int:
    """Add FP8 weight DequantizeLinear for Conv layers with unquantized weights.

    Conv weight quantizers are disabled during TorchScript ONNX export because the
    TRT_FP8DequantizeLinear custom op produces outputs with unknown shapes, causing
    the _convolution symbolic to fail. This method restores FP8 weight quantization
    by inserting DQ nodes in the ONNX graph, mirroring the compress_weights logic.

    For each Conv node with an unquantized constant weight:
    1. Compute per-tensor scale = max(abs(weight)) / 448.0
    2. Quantize weights to FP8E4M3FN
    3. Insert a DequantizeLinear(fp8_weights, scale) before the Conv weight input

    Args:
        graph: The onnx-graphsurgeon graph to modify in-place.

    Returns:
        Number of Conv weight DQ nodes inserted.
    """
    # Hoisted out of the loop: importing per-iteration is wasteful and hides the dep.
    import numpy as np

    FP8_MAX = 448.0  # Max representable magnitude of float8_e4m3fn.
    count = 0

    # Iterate over a snapshot because DQ nodes are appended to graph.nodes below.
    for node in list(graph.nodes):
        if node.op != "Conv":
            continue
        if len(node.inputs) < 2:
            continue

        weight_input = node.inputs[1]
        if not isinstance(weight_input, gs.Constant):
            continue

        # Skip if weight already has a DQ producer (already quantized).
        if any(out.op == "DequantizeLinear" for out in weight_input.outputs):
            continue

        torch_weights = torch.from_numpy(weight_input.values.copy())
        amax = torch_weights.abs().max().float()
        if amax == 0:
            # All-zero weights would yield a zero (invalid) scale; leave unquantized.
            continue
        scale_val = (amax / FP8_MAX).item()

        # Quantize weights to FP8 (WAR: numpy doesn't support fp8, so round-trip
        # through torch and carry the raw bytes as uint8).
        fp8_data = (
            (torch_weights / scale_val).to(torch.float8_e4m3fn).view(torch.uint8).numpy()
        )
        fp8_tensor = onnx.TensorProto()
        fp8_tensor.data_type = onnx.TensorProto.FLOAT8E4M3FN
        fp8_tensor.dims.extend(fp8_data.shape)
        fp8_tensor.raw_data = fp8_data.tobytes()
        fp8_constant = gs.Constant(
            node.name + "/weight_quantizer/fp8_weights", LazyValues(fp8_tensor)
        )

        # Scale in FP16 — DQ output type matches scale dtype, must match activation type.
        scale_constant = gs.Constant(
            node.name + "/weight_quantizer/scale",
            np.array(scale_val, dtype=np.float16),
        )

        dq_output = gs.Variable(node.name + "/weight_quantizer/dq_output")
        dq_node = gs.Node(
            op="DequantizeLinear",
            name=node.name + "/weight_quantizer/DequantizeLinear",
            inputs=[fp8_constant, scale_constant],
            outputs=[dq_output],
        )
        graph.nodes.append(dq_node)
        node.inputs[1] = dq_output
        count += 1

    return count

@staticmethod
def post_process(onnx_model: onnx.ModelProto) -> onnx.ModelProto:
"""Post-processes the ONNX model for FP8 quantization.

Converts TRT_FP8 QDQ ops to native ONNX QuantizeLinear/DequantizeLinear:
- TRT_FP8QuantizeLinear -> QuantizeLinear with FP8E4M3FN zero_point and saturate=1
- TRT_FP8DequantizeLinear -> DequantizeLinear
Converts TRT_FP8 QDQ ops to native ONNX QuantizeLinear/DequantizeLinear and
adds FP8 weight DQ for Conv layers whose weight quantizers were disabled during
TorchScript export.

Args:
onnx_model: The ONNX model containing TRT_FP8 quantization nodes.
Expand Down Expand Up @@ -144,5 +220,10 @@ def post_process(onnx_model: onnx.ModelProto) -> onnx.ModelProto:
f"Converted {node.name} from TRT_FP8DequantizeLinear to DequantizeLinear"
)

# Add FP8 weight DQ for Conv layers that had weight quantizers disabled during export
count = FP8QuantExporter._quantize_conv_weights_to_fp8(graph)
if count > 0:
logger.info(f"Inserted FP8 weight DequantizeLinear for {count} Conv nodes")

graph.cleanup().toposort()
return gs.export_onnx(graph)
81 changes: 81 additions & 0 deletions modelopt/onnx/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1504,6 +1504,87 @@ def remove_redundant_casts(onnx_model: onnx.ModelProto) -> onnx.ModelProto:
return onnx_model


def fold_dq_fp32_to_fp16_casts(onnx_model: onnx.ModelProto) -> onnx.ModelProto:
    """Remove Cast(FP32->FP16) nodes after DequantizeLinear by setting DQ output to FP16.

    When convert_float_to_float16 blocks DequantizeLinear, it inserts Cast nodes to bridge
    the FP32 DQ output to the FP16 graph. This function removes those Cast nodes by:
    1. Converting the DQ scale initializer from FP32 to FP16
    2. Updating the DQ output type to FP16 in value_info
    3. Bypassing and removing the Cast node

    Args:
        onnx_model: The ONNX model with DQ -> Cast(FP32->FP16) patterns.

    Returns:
        The ONNX model with Cast nodes removed and DQ outputs set to FP16.
    """
    # Hoisted out of the loop: importing per-iteration is wasteful and hides the dep.
    import numpy as np

    DQ_OPS = {"DequantizeLinear", "TRT_FP8DequantizeLinear"}

    # Map tensor name -> producer node so the DQ feeding each Cast can be found.
    producer_map: dict[str, onnx.NodeProto] = {}
    for node in onnx_model.graph.node:
        for out in node.output:
            producer_map[out] = node

    # Initializer lookup for rewriting DQ scale tensors in place.
    initializer_map: dict[str, onnx.TensorProto] = {
        init.name: init for init in onnx_model.graph.initializer
    }

    nodes_to_remove = []
    for node in onnx_model.graph.node:
        if node.op_type != "Cast":
            continue

        # Only fold casts whose target type is FP16.
        cast_to = next((attr.i for attr in node.attribute if attr.name == "to"), None)
        if cast_to != onnx.TensorProto.FLOAT16:
            continue

        # Only fold when the Cast input is produced by a DQ node.
        producer = producer_map.get(node.input[0])
        if producer is None or producer.op_type not in DQ_OPS:
            continue

        # Convert the DQ scale initializer from FP32 to FP16.
        # DQ inputs: [input, scale, (zero_point)]
        if len(producer.input) >= 2:
            scale_init = initializer_map.get(producer.input[1])
            if scale_init is not None and scale_init.data_type == onnx.TensorProto.FLOAT:
                scale_data = np.frombuffer(scale_init.raw_data, dtype=np.float32)
                if not scale_data.size:
                    # Scale may live in float_data rather than raw_data.
                    scale_data = np.array(scale_init.float_data, dtype=np.float32)
                scale_init.data_type = onnx.TensorProto.FLOAT16
                scale_init.raw_data = scale_data.astype(np.float16).tobytes()
                del scale_init.float_data[:]

        # Bypass the Cast node so consumers read the (now FP16) DQ output directly.
        _bypass_cast_node(onnx_model, node)
        nodes_to_remove.append(node)

        # Update the DQ output type in value_info to reflect the FP16 output.
        # NOTE(review): if the DQ output also feeds non-Cast FP32 consumers, this
        # retypes it for them too — assumed not to occur in the targeted graphs; verify.
        dq_output_name = producer.output[0]
        for vi in onnx_model.graph.value_info:
            if vi.name == dq_output_name:
                vi.type.tensor_type.elem_type = onnx.TensorProto.FLOAT16
                break

    logger.debug(f"Folded {len(nodes_to_remove)} DQ -> Cast(FP32->FP16) patterns")
    for node in nodes_to_remove:
        onnx_model.graph.node.remove(node)

    return onnx_model


def remove_node_training_mode(onnx_model: onnx.ModelProto, node_op_type: str) -> onnx.ModelProto:
"""Remove `training_mode` attribute and extra training outputs from nodes of a given op type.

Expand Down
Loading
Loading