diff --git a/examples/torch_onnx/README.md b/examples/torch_onnx/README.md index d540770116..cfd4dc380c 100644 --- a/examples/torch_onnx/README.md +++ b/examples/torch_onnx/README.md @@ -307,6 +307,7 @@ python torch_quant_to_onnx.py \ | [vit_base_patch16_224](https://huggingface.co/timm/vit_base_patch16_224.augreg_in21k_ft_in1k) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | | [swin_tiny_patch4_window7_224](https://huggingface.co/timm/swin_tiny_patch4_window7_224.ms_in1k) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | | [swinv2_tiny_window8_256](https://huggingface.co/timm/swinv2_tiny_window8_256.ms_in1k) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | +| [resnet50](https://huggingface.co/timm/resnet50.a1_in1k) | ✅ | ✅ | ✅ | ✅ | | ✅ | ## Resources diff --git a/examples/torch_onnx/torch_quant_to_onnx.py b/examples/torch_onnx/torch_quant_to_onnx.py index 7f74e617e8..ceae3b6518 100644 --- a/examples/torch_onnx/torch_quant_to_onnx.py +++ b/examples/torch_onnx/torch_quant_to_onnx.py @@ -17,6 +17,7 @@ import copy import json import re +import subprocess import sys import warnings from pathlib import Path @@ -35,13 +36,17 @@ import modelopt.torch.quantization as mtq """ -This script is used to quantize a timm model using dynamic quantization like MXFP8 or NVFP4, -or using auto quantization for optimal per-layer quantization. +Quantize a timm vision model and export to ONNX for TensorRT deployment. + +Supports FP8, INT8, MXFP8, NVFP4, INT4_AWQ, and AUTO (mixed-precision) quantization modes. The script will: -1. Given the model name, create a timm torch model. -2. Quantize the torch model in MXFP8, NVFP4, INT4_AWQ, or AUTO mode. -3. Export the quantized torch model to ONNX format. +1. Load a pretrained timm model (e.g., ViT, Swin, ResNet). +2. Quantize the model using the specified mode. For models with Conv2d layers, + Conv2d quantization is automatically overridden for TensorRT compatibility + (FP8 for MXFP8/NVFP4, INT8 for INT4_AWQ). +3. Export the quantized model to ONNX with FP16 weights. +4. 
Optionally evaluate accuracy on ImageNet-1k before and after quantization. """ @@ -81,6 +86,11 @@ }, ] +# Auto-quantize format configs that use block quantization and need Conv2d overrides for TRT. +# TRT DynamicQuantize requires 2D/3D input, but Conv2d operates on 4D tensors. +_NEEDS_FP8_CONV_OVERRIDE: set[str] = {"NVFP4_AWQ_LITE_CFG", "NVFP4_DEFAULT_CFG", "MXFP8_DEFAULT_CFG"} +_NEEDS_INT8_CONV_OVERRIDE: set[str] = {"INT4_AWQ_CFG"} + def get_quant_config(quantize_mode): """Get quantization config, overriding Conv2d for TRT compatibility. @@ -109,7 +119,8 @@ def filter_func(name): """Filter function to exclude certain layers from quantization.""" pattern = re.compile( r".*(time_emb_proj|time_embedding|conv_in|conv_out|conv_shortcut|add_embedding|" - r"pos_embed|time_text_embed|context_embedder|norm_out|x_embedder|patch_embed|cpb_mlp|downsample).*" + r"pos_embed|time_text_embed|context_embedder|norm_out|x_embedder|patch_embed|cpb_mlp|" + r"maxpool|global_pool).*" ) return pattern.match(name) is not None @@ -147,6 +158,40 @@ def load_calibration_data(model_name, data_size, batch_size, device, with_labels ) +def _calibrate_uncalibrated_quantizers(model, data_loader): + """Calibrate FP8 quantizers that weren't calibrated by mtq.quantize(). + + When MXFP8/NVFP4 modes override Conv2d to FP8, the FP8 quantizers may not + be calibrated because the MXFP8/NVFP4 quantization pipeline skips standard + calibration. This function explicitly calibrates those uncalibrated quantizers. 
+ """ + uncalibrated = [] + for _, module in model.named_modules(): + for attr_name in ("input_quantizer", "weight_quantizer"): + if not hasattr(module, attr_name): + continue + quantizer = getattr(module, attr_name) + if ( + quantizer.is_enabled + and not quantizer.block_sizes + and not hasattr(quantizer, "_amax") + ): + quantizer.enable_calib() + uncalibrated.append(quantizer) + + if not uncalibrated: + return + + model.eval() + with torch.no_grad(): + for batch in data_loader: + model(batch) + + for quantizer in uncalibrated: + quantizer.disable_calib() + quantizer.load_calib_amax() + + def quantize_model(model, config, data_loader=None): """Quantize the model using the given config and calibration data.""" if data_loader is not None: @@ -159,6 +204,10 @@ def forward_loop(model): else: quantized_model = mtq.quantize(model, config) + # Calibrate any FP8 override quantizers that weren't calibrated by mtq.quantize() + if data_loader is not None: + _calibrate_uncalibrated_quantizers(quantized_model, data_loader) + mtq.disable_quantizer(quantized_model, filter_func) return quantized_model @@ -209,11 +258,19 @@ def auto_quantize_model( _disable_inplace_relu(model) constraints = {"effective_bits": effective_bits} - # Convert string format names to actual config objects + # Convert string format names to config objects, incorporating Conv2d TRT overrides. + # TRT DynamicQuantize requires 2D/3D input, but Conv2d operates on 4D tensors. + # By including the overrides in the format configs, the auto_quantize search + # correctly accounts for Conv2d being FP8/INT8 in the effective_bits budget. 
format_configs = [] for fmt in quantization_formats: if isinstance(fmt, str): - format_configs.append(getattr(mtq, fmt)) + config = copy.deepcopy(getattr(mtq, fmt)) + if fmt in _NEEDS_FP8_CONV_OVERRIDE: + config["quant_cfg"].extend(_FP8_CONV_OVERRIDE) + elif fmt in _NEEDS_INT8_CONV_OVERRIDE: + config["quant_cfg"].extend(_INT8_CONV_OVERRIDE) + format_configs.append(config) else: format_configs.append(fmt) @@ -320,6 +377,11 @@ def main(): default=128, help="Number of scoring steps for auto quantization. Default is 128.", ) + parser.add_argument( + "--trt_build", + action="store_true", + help="Build a TensorRT engine from the exported ONNX model using trtexec.", + ) parser.add_argument( "--no_pretrained", action="store_true", @@ -378,18 +440,18 @@ def main(): args.num_score_steps, ) else: - # Standard quantization - only load calibration data if needed + # Standard quantization - load calibration data + # Note: MXFP8 is dynamic and does not need calibration itself, but when + # Conv2d layers are overridden to FP8 (for TRT compatibility), those FP8 + # quantizers require calibration data. 
def build_trt_engine(onnx_path, timeout=600, builder_optimization_level=4):
    """Build a TensorRT engine from the exported ONNX model using trtexec.

    Runs ``trtexec --onnx=<onnx_path> --stronglyTyped
    --builderOptimizationLevel=<level>`` as a subprocess and captures its output.

    Args:
        onnx_path: Path of the ONNX model handed to ``trtexec``.
        timeout: Maximum number of seconds to wait for ``trtexec``. Default is 600.
        builder_optimization_level: Value passed to ``--builderOptimizationLevel``.
            Default is 4.

    Raises:
        RuntimeError: If ``trtexec`` cannot be started, does not finish within
            ``timeout`` seconds, or exits with a non-zero return code. For a
            non-zero exit the captured stdout/stderr are included in the message.
    """
    cmd = [
        "trtexec",
        f"--onnx={onnx_path}",
        "--stronglyTyped",
        f"--builderOptimizationLevel={builder_optimization_level}",
    ]
    print(f"\nBuilding TensorRT engine: {' '.join(cmd)}")
    try:
        result = subprocess.run(
            cmd, capture_output=True, text=True, timeout=timeout, check=False
        )
    except FileNotFoundError as err:
        # subprocess raises FileNotFoundError when the executable is not on PATH.
        raise RuntimeError(
            f"Cannot build TensorRT engine for {onnx_path}: 'trtexec' executable was not found."
        ) from err
    except subprocess.TimeoutExpired as err:
        raise RuntimeError(
            f"TensorRT engine build for {onnx_path} timed out after {timeout} seconds."
        ) from err

    if result.returncode != 0:
        raise RuntimeError(
            f"TensorRT engine build failed for {onnx_path}:\n{result.stdout}\n{result.stderr}"
        )
    print("TensorRT engine build succeeded.")
This method restores FP8 weight quantization + by inserting DQ nodes in the ONNX graph, mirroring the compress_weights logic. + + For each Conv node with an unquantized constant weight: + 1. Compute per-tensor scale = max(abs(weight)) / 448.0 + 2. Quantize weights to FP8E4M3FN + 3. Insert a DequantizeLinear(fp8_weights, scale) before the Conv weight input + + Args: + graph: The onnx-graphsurgeon graph to modify in-place. + + Returns: + Number of Conv weight DQ nodes inserted. + """ + FP8_MAX = 448.0 + count = 0 + + for node in list(graph.nodes): + if node.op != "Conv": + continue + if len(node.inputs) < 2: + continue + + weight_input = node.inputs[1] + if not isinstance(weight_input, gs.Constant): + continue + + # Skip if weight already has a DQ producer + if any(out.op == "DequantizeLinear" for out in weight_input.outputs): + continue + + torch_weights = torch.from_numpy(weight_input.values.copy()) + amax = torch_weights.abs().max().float() + if amax == 0: + continue + scale_val = (amax / FP8_MAX).item() + + # Quantize weights to FP8 (WAR: numpy doesn't support fp8) + fp8_data = ( + (torch_weights / scale_val).to(torch.float8_e4m3fn).view(torch.uint8).numpy() + ) + fp8_tensor = onnx.TensorProto() + fp8_tensor.data_type = onnx.TensorProto.FLOAT8E4M3FN + fp8_tensor.dims.extend(fp8_data.shape) + fp8_tensor.raw_data = fp8_data.tobytes() + fp8_constant = gs.Constant( + node.name + "/weight_quantizer/fp8_weights", LazyValues(fp8_tensor) + ) + + # Scale in FP16 — DQ output type matches scale dtype, must match activation type + import numpy as np + + scale_constant = gs.Constant( + node.name + "/weight_quantizer/scale", + np.array(scale_val, dtype=np.float16), + ) + + dq_output = gs.Variable(node.name + "/weight_quantizer/dq_output") + dq_node = gs.Node( + op="DequantizeLinear", + name=node.name + "/weight_quantizer/DequantizeLinear", + inputs=[fp8_constant, scale_constant], + outputs=[dq_output], + ) + graph.nodes.append(dq_node) + node.inputs[1] = dq_output + count += 1 
def fold_dq_fp32_to_fp16_casts(onnx_model: onnx.ModelProto) -> onnx.ModelProto:
    """Remove Cast(FP32->FP16) nodes after DequantizeLinear by setting DQ output to FP16.

    When convert_float_to_float16 blocks DequantizeLinear, it inserts Cast nodes to bridge
    the FP32 DQ output to the FP16 graph. This function removes those Cast nodes by:
    1. Converting the DQ scale initializer from FP32 to FP16
    2. Bypassing and removing the Cast node
    3. Updating the DQ output type to FP16 in value_info

    A Cast is only folded when doing so keeps the graph type-consistent:
    - every consumer of the DQ output is a Cast to FP16 and the DQ output is not a
      graph output (otherwise another consumer would silently change dtype);
    - the DQ scale is an initializer that is already FP16, or is FP32 and can be
      converted: all nodes that read it are foldable DQ nodes using it as their
      scale, and its values stay finite and non-zero in FP16.
    Any Cast that does not meet these conditions is left untouched.

    Args:
        onnx_model: The ONNX model with DQ -> Cast(FP32->FP16) patterns.

    Returns:
        The ONNX model with Cast nodes removed and DQ outputs set to FP16.
    """
    # Function-scope import: the module's top-level imports are outside this view.
    import numpy as np

    dq_ops = {"DequantizeLinear", "TRT_FP8DequantizeLinear"}
    graph = onnx_model.graph

    # All maps describe the graph as it is BEFORE any Cast is bypassed, so the
    # fold decision for one Cast does not depend on the order of earlier folds.
    producer_map: dict[str, onnx.NodeProto] = {
        out: node for node in graph.node for out in node.output
    }
    consumer_map: dict[str, list[onnx.NodeProto]] = {}
    for node in graph.node:
        for tensor_name in node.input:
            consumer_map.setdefault(tensor_name, []).append(node)
    initializer_map: dict[str, onnx.TensorProto] = {
        init.name: init for init in graph.initializer
    }
    graph_output_names = {out.name for out in graph.output}

    def is_cast_to_fp16(candidate):
        """Return True for a Cast node whose ``to`` attribute is FLOAT16."""
        if candidate.op_type != "Cast":
            return False
        cast_to = None
        for attr in candidate.attribute:
            if attr.name == "to":
                cast_to = attr.i
        return cast_to == onnx.TensorProto.FLOAT16

    def dq_is_foldable(dq_node):
        """Return True when the DQ output is read exclusively by Cast(->FP16) nodes."""
        out_name = dq_node.output[0]
        readers = consumer_map.get(out_name, [])
        return (
            out_name not in graph_output_names
            and bool(readers)
            and all(is_cast_to_fp16(reader) for reader in readers)
        )

    def scale_values_as_fp16(scale_init):
        """Return the scale as an FP16 array, or None if FP16 cannot represent it."""
        if scale_init.raw_data:
            scale_fp32 = np.frombuffer(scale_init.raw_data, dtype=np.float32)
        else:
            scale_fp32 = np.array(scale_init.float_data, dtype=np.float32)
        with np.errstate(over="ignore"):
            scale_fp16 = scale_fp32.astype(np.float16)
        overflowed = not np.all(np.isfinite(scale_fp16))
        underflowed = bool(np.any((scale_fp16 == 0) & (scale_fp32 != 0)))
        return None if overflowed or underflowed else scale_fp16

    def ensure_fp16_scale(dq_node):
        """Make the DQ scale initializer FP16; return False when that is not possible."""
        # DQ inputs: [input, scale, (zero_point)]
        if len(dq_node.input) < 2:
            return False
        scale_name = dq_node.input[1]
        scale_init = initializer_map.get(scale_name)
        if scale_init is None:
            # Scale is not an initializer; its dtype cannot be rewritten here.
            return False
        if scale_init.data_type == onnx.TensorProto.FLOAT16:
            return True
        if scale_init.data_type != onnx.TensorProto.FLOAT:
            return False
        # A shared scale may only change dtype if every reader is a DQ that is folded too.
        for reader in consumer_map.get(scale_name, []):
            if (
                reader.op_type not in dq_ops
                or len(reader.input) < 2
                or reader.input[1] != scale_name
                or not dq_is_foldable(reader)
            ):
                return False
        scale_fp16 = scale_values_as_fp16(scale_init)
        if scale_fp16 is None:
            return False
        scale_init.data_type = onnx.TensorProto.FLOAT16
        scale_init.raw_data = scale_fp16.tobytes()
        del scale_init.float_data[:]
        return True

    nodes_to_remove = []
    for node in graph.node:
        if not is_cast_to_fp16(node):
            continue

        # Check: producer is a DQ node whose output can safely become FP16
        producer = producer_map.get(node.input[0])
        if producer is None or producer.op_type not in dq_ops:
            continue
        if not dq_is_foldable(producer) or not ensure_fp16_scale(producer):
            continue

        # Bypass the Cast node
        # NOTE(review): assumes _bypass_cast_node rewires the Cast's consumers to the
        # DQ output and keeps the DQ output name unchanged — confirm against the helper.
        _bypass_cast_node(onnx_model, node)
        nodes_to_remove.append(node)

        # Update the DQ output type in value_info
        dq_output_name = producer.output[0]
        for vi in graph.value_info:
            if vi.name == dq_output_name:
                vi.type.tensor_type.elem_type = onnx.TensorProto.FLOAT16
                break

    logger.debug(f"Folded {len(nodes_to_remove)} DQ -> Cast(FP32->FP16) patterns")
    for node in nodes_to_remove:
        graph.node.remove(node)

    return onnx_model
@@ -522,7 +547,11 @@ def get_onnx_bytes_and_metadata( input_none_names = list(set(tree_spec_input.names) - set(input_names)) use_torch_autocast = not ( - is_fp4_quantized(model) or is_mxfp8_quantized(model) or weights_dtype == "fp32" + is_fp4_quantized(model) + or is_mxfp8_quantized(model) + or is_fp8_quantized(model) + or is_int8_quantized(model) + or weights_dtype == "fp32" ) autocast = torch.autocast("cuda") if use_torch_autocast else nullcontext() @@ -556,7 +585,13 @@ def get_onnx_bytes_and_metadata( if is_fp4_quantized(model) or is_mxfp8_quantized(model) else nullcontext() ) - with torch.inference_mode(), autocast, quantizer_context: + # Disable FP8 Conv weight quantizers: TorchScript custom ops produce outputs with + # unknown shapes, causing _convolution symbolic to fail. Conv weights are quantized + # to FP8 in post-processing by FP8QuantExporter instead. + conv_wq_context = ( + _disable_fp8_conv_weight_quantizers(model) if is_fp8_quantized(model) else nullcontext() + ) + with torch.inference_mode(), autocast, quantizer_context, conv_wq_context: additional_kwargs = {} if not dynamo_export: additional_kwargs["dynamic_axes"] = dynamic_axes @@ -598,7 +633,12 @@ def get_onnx_bytes_and_metadata( onnx_opt_graph = qdq_to_dq(onnx_opt_graph) if weights_dtype in ["fp16", "bf16"]: - if is_int4_quantized(model) or is_mxfp8_quantized(model) or is_fp8_quantized(model): + if ( + is_int4_quantized(model) + or is_mxfp8_quantized(model) + or is_fp8_quantized(model) + or is_int8_quantized(model) + ): assert weights_dtype == "fp16", "BF16 + MXFP8/INT4 mixed precision is not supported yet" onnx_opt_graph = convert_float_to_float16( onnx_opt_graph, @@ -610,6 +650,8 @@ def get_onnx_bytes_and_metadata( # Change FP32 cast nodes feeding into Concat/Add to FP16 op_list = ["Concat", "Add", "Sqrt", "LayerNormalization", "Clip", "Mul", "Exp"] onnx_opt_graph = change_casts_to_fp16(onnx_opt_graph, op_list) + # Remove Cast(FP32->FP16) nodes after DQ by setting DQ output to FP16 directly + 
@contextlib.contextmanager
def configure_linear_module_onnx_quantizers(model):
    """Sets the onnx export attributes for the given model.

    For modules with block quantization (NVFP4/MXFP8):
    - Weight quantizers use "static" export (TRT_FP4QDQ for NVFP4, DQ-only for MXFP8)
    - Input/activation quantizers use "dynamic" export (TRT_FP4DynamicQuantize, etc.)

    This must be set for ALL modules with block quantization, not just nn.Linear,
    because models like ResNet have non-Linear modules (e.g., MaxPool2d) with NVFP4/MXFP8
    input quantizers that would otherwise default to the static path and produce
    TRT_FP4QDQ nodes on activations (which the NVFP4 exporter cannot handle).

    A quantizer is considered block-quantized when it has a truthy ``block_sizes``
    attribute. Modules without the respective quantizer, and quantizer objects that
    expose no ``block_sizes`` attribute at all, are skipped instead of raising.

    Args:
        model: Model whose ``named_modules()`` are scanned for ``input_quantizer``
            and ``weight_quantizer`` attributes.

    Yields:
        None. The ``_onnx_quantizer_type`` attributes are set before the ``with``
        body runs and are not reset when the context exits.
    """

    def uses_block_quantization(owner, quantizer_attr):
        """Return the quantizer if it has truthy ``block_sizes``, else None."""
        quantizer = getattr(owner, quantizer_attr, None)
        if getattr(quantizer, "block_sizes", None):
            return quantizer
        return None

    for _, module in model.named_modules():
        input_quantizer = uses_block_quantization(module, "input_quantizer")
        if input_quantizer is not None:
            input_quantizer._onnx_quantizer_type = "dynamic"
        weight_quantizer = uses_block_quantization(module, "weight_quantizer")
        if weight_quantizer is not None:
            weight_quantizer._onnx_quantizer_type = "static"
    yield
-import os -import subprocess - import pytest from _test_utils.examples.run_command import extend_cmd_parts, run_example_command @@ -28,34 +25,9 @@ "vit_tiny": ("vit_tiny_patch16_224", '{"depth": 1}'), "swin_tiny": ("swin_tiny_patch4_window7_224", '{"depths": [1, 1, 1, 1]}'), "swinv2_tiny": ("swinv2_tiny_window8_256", '{"depths": [1, 1, 1, 1]}'), + "resnet50": ("resnet50", None), } -# Builder optimization level: 4 for low-bit modes, 3 otherwise -_LOW_BIT_MODES = {"fp8", "int8", "nvfp4"} - - -def _verify_trt_engine_build(onnx_save_path, quantize_mode): - """Verify the exported ONNX model can be compiled into a TensorRT engine.""" - example_dir = os.path.join( - os.path.dirname(__file__), "..", "..", "..", "examples", "torch_onnx" - ) - onnx_path = os.path.join(example_dir, onnx_save_path) - assert os.path.exists(onnx_path), f"ONNX file not found: {onnx_path}" - - opt_level = "4" if quantize_mode in _LOW_BIT_MODES else "3" - cmd = [ - "trtexec", - f"--onnx={onnx_path}", - "--stronglyTyped", - f"--builderOptimizationLevel={opt_level}", - ] - - result = subprocess.run(cmd, capture_output=True, text=True, timeout=600) - assert result.returncode == 0, ( - f"TensorRT engine build failed for {onnx_save_path} " - f"(mode={quantize_mode}):\n{result.stdout}\n{result.stderr}" - ) - @pytest.mark.parametrize("quantize_mode", _QUANT_MODES) @pytest.mark.parametrize("model_key", list(_MODELS)) @@ -63,7 +35,6 @@ def test_torch_onnx(model_key, quantize_mode): timm_model_name, model_kwargs = _MODELS[model_key] onnx_save_path = f"{model_key}.{quantize_mode}.onnx" - # Step 1: Quantize and export to ONNX cmd_parts = extend_cmd_parts( ["python", "torch_quant_to_onnx.py"], timm_model_name=timm_model_name, @@ -73,8 +44,5 @@ def test_torch_onnx(model_key, quantize_mode): calibration_data_size="1", num_score_steps="1", ) - cmd_parts.append("--no_pretrained") + cmd_parts.extend(["--no_pretrained", "--trt_build"]) run_example_command(cmd_parts, "torch_onnx") - - # Step 2: Verify the exported 
ONNX model builds a TensorRT engine - _verify_trt_engine_build(onnx_save_path, quantize_mode)