diff --git a/examples/torch_onnx/README.md b/examples/torch_onnx/README.md index d540770116..2dfe33b864 100644 --- a/examples/torch_onnx/README.md +++ b/examples/torch_onnx/README.md @@ -307,6 +307,7 @@ python torch_quant_to_onnx.py \ | [vit_base_patch16_224](https://huggingface.co/timm/vit_base_patch16_224.augreg_in21k_ft_in1k) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | | [swin_tiny_patch4_window7_224](https://huggingface.co/timm/swin_tiny_patch4_window7_224.ms_in1k) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | | [swinv2_tiny_window8_256](https://huggingface.co/timm/swinv2_tiny_window8_256.ms_in1k) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | +| [efficientvit_l2](https://huggingface.co/timm/efficientvit_l2.r224_in1k) | ✅ | ✅ | ✅ | ✅ | | | ## Resources diff --git a/examples/torch_onnx/torch_quant_to_onnx.py b/examples/torch_onnx/torch_quant_to_onnx.py index 7f74e617e8..ede3058735 100644 --- a/examples/torch_onnx/torch_quant_to_onnx.py +++ b/examples/torch_onnx/torch_quant_to_onnx.py @@ -17,6 +17,7 @@ import copy import json import re +import subprocess import sys import warnings from pathlib import Path @@ -96,6 +97,10 @@ def get_quant_config(quantize_mode): f"Overriding Conv2d quantization to FP8 for '{quantize_mode}' mode." ) config["quant_cfg"].extend(_FP8_CONV_OVERRIDE) + # The FP8 Conv2d overrides use static quantization which requires + # calibration (amax). Ensure the calibration algorithm is set. + if config.get("algorithm") is None: + config["algorithm"] = "max" elif quantize_mode == "int4_awq": warnings.warn( "TensorRT only supports FP8/INT8 for Conv layers. " @@ -109,7 +114,8 @@ def filter_func(name): """Filter function to exclude certain layers from quantization.""" pattern = re.compile( r".*(time_emb_proj|time_embedding|conv_in|conv_out|conv_shortcut|add_embedding|" - r"pos_embed|time_text_embed|context_embedder|norm_out|x_embedder|patch_embed|cpb_mlp|downsample).*" + r"pos_embed|time_text_embed|context_embedder|norm_out|x_embedder|patch_embed|cpb_mlp|" + r"downsample|global_pool).*" ) return pattern.match(name) is not None @@ -147,6 +153,21 @@ def load_calibration_data(model_name, data_size, batch_size, device, with_labels ) +def _disable_conv2d_dynamic_quantizers(model): + """Disable dynamic block quantizers (NVFP4/MXFP8) on Conv2d modules. + + TRT's FP4/MXFP8 DynamicQuantize only supports 2D/3D input tensors, but Conv2d + layers have 4D inputs. Disable these quantizers to avoid TRT build failures. + """ + for name, module in model.named_modules(): + if not isinstance(module, torch.nn.Conv2d): + continue + for qname in ("input_quantizer", "weight_quantizer"): + quantizer = getattr(module, qname, None) + if quantizer is not None and getattr(quantizer, "block_sizes", None): + quantizer.disable() + + def quantize_model(model, config, data_loader=None): """Quantize the model using the given config and calibration data.""" if data_loader is not None: @@ -160,6 +181,7 @@ def forward_loop(model): quantized_model = mtq.quantize(model, config) mtq.disable_quantizer(quantized_model, filter_func) + _disable_conv2d_dynamic_quantizers(quantized_model) return quantized_model @@ -235,6 +257,7 @@ def auto_quantize_model( # Disable quantization for specified layers mtq.disable_quantizer(quantized_model, filter_func) + _disable_conv2d_dynamic_quantizers(quantized_model) return quantized_model, search_state @@ -320,6 +343,17 @@ def main(): default=128, help="Number of scoring steps for auto quantization. Default is 128.", ) + parser.add_argument( + "--trt_build", + action="store_true", + help="Build a TRT engine from the exported ONNX model to verify compatibility.", + ) + parser.add_argument( + "--trt_builder_opt_level", + type=int, + default=4, + help="TRT builder optimization level (default: 4).", + ) parser.add_argument( "--no_pretrained", action="store_true", @@ -378,18 +412,19 @@ def main(): args.num_score_steps, ) else: - # Standard quantization - only load calibration data if needed + # Standard quantization config = get_quant_config(args.quantize_mode) - if args.quantize_mode == "mxfp8": - data_loader = None - else: - data_loader = load_calibration_data( - args.timm_model_name, - args.calibration_data_size, - args.batch_size, - device, - with_labels=False, - ) + # Always load calibration data. Even though MXFP8 uses dynamic quantization + # and doesn't strictly require calibration, the Conv2d FP8 overrides (applied + # by get_quant_config for MXFP8/NVFP4) use static FP8 quantization which + # needs calibration data to compute amax values. + data_loader = load_calibration_data( + args.timm_model_name, + args.calibration_data_size, + args.batch_size, + device, + with_labels=False, + ) quantized_model = quantize_model(model, config, data_loader) @@ -421,6 +456,25 @@ def main(): print(f"Quantized ONNX model is saved to {args.onnx_save_path}") + if args.trt_build: + print("\n=== Building TRT Engine ===") + cmd = [ + "trtexec", + f"--onnx={args.onnx_save_path}", + "--stronglyTyped", + f"--builderOptimizationLevel={args.trt_builder_opt_level}", + ] + print(f"Running: {' '.join(cmd)}") + result = subprocess.run(cmd, capture_output=True, text=True) + if result.returncode != 0: + print("TRT engine build FAILED:") + for line in result.stderr.splitlines(): + if "Error" in line or "FAIL" in line or "error" in line: + print(f" {line.strip()}") + sys.exit(1) + else: + print("TRT engine build succeeded.") + if __name__ == "__main__": main() diff --git a/modelopt/onnx/utils.py b/modelopt/onnx/utils.py index 54efd0a111..1f4ab2ec94 100644 --- a/modelopt/onnx/utils.py +++ b/modelopt/onnx/utils.py @@ -1504,6 +1504,130 @@ def remove_redundant_casts(onnx_model: onnx.ModelProto) -> onnx.ModelProto: return onnx_model +def fix_fp16_fp32_mismatches(model: onnx.ModelProto) -> onnx.ModelProto: + """Insert Cast nodes to resolve FP32/FP16 type mismatches after blocked-op FP16 conversion. + + After convert_float_to_float16 with an op_block_list, FP32 data from blocked ops + (e.g., QDQ paths) can flow into nodes whose other inputs are FP16. TensorRT + --stronglyTyped rejects such mismatches. This function propagates "real" types + through the graph and inserts FP32->FP16 Cast nodes where needed. + + Note: value_info types are unreliable after convert_float_to_float16 with blocked ops + (metadata may say FP16 even when actual data is FP32), so this function re-derives + types by following op semantics. + + Args: + model: The ONNX model to fix. + + Returns: + The modified ONNX model with Cast nodes inserted to resolve mismatches. + """ + FLOAT = onnx.TensorProto.FLOAT + FLOAT16 = onnx.TensorProto.FLOAT16 + + # Ops whose data inputs must all have the same type in TRT stronglyTyped mode. + _ELEMENTWISE_OPS = { + "Add", "Sub", "Mul", "Div", "Pow", "Min", "Max", "Equal", "Less", + "Greater", "Where", "Sum", "Mean", "Concat", + } + + # Ops that are FP32-only (QDQ) — never cast their I/O. + _BLOCKED_OPS = {"QuantizeLinear", "DequantizeLinear"} + + # --- Step 1: Propagate real element types through the graph. --- + real_type: dict[str, int] = {} + + # Seed from graph inputs and initializers (these are authoritative). + for inp in model.graph.input: + real_type[inp.name] = inp.type.tensor_type.elem_type + for init in model.graph.initializer: + real_type[init.name] = init.data_type + + # Process nodes in topological order. + for node in model.graph.node: + if node.op_type == "Constant": + for attr in node.attribute: + if attr.name == "value" and attr.type == onnx.AttributeProto.TENSOR: + for out in node.output: + real_type[out] = attr.t.data_type + continue + + if node.op_type == "Cast": + cast_to = get_cast_to_type(node) + for out in node.output: + real_type[out] = cast_to + continue + + if node.op_type in _BLOCKED_OPS: + for out in node.output: + real_type[out] = FLOAT + continue + + # For other ops: output type matches the predominant data-input type. + data_types = [] + for inp_name in node.input: + if inp_name and inp_name in real_type and real_type[inp_name] in (FLOAT, FLOAT16): + data_types.append(real_type[inp_name]) + + if data_types: + out_type = FLOAT if FLOAT in data_types else FLOAT16 + else: + out_type = FLOAT16 + + for out in node.output: + real_type[out] = out_type + + # --- Step 2: Find nodes with mixed real types and insert Casts. --- + nodes_to_insert: list[tuple[int, onnx.NodeProto]] = [] + + for node_idx, node in enumerate(model.graph.node): + if node.op_type not in _ELEMENTWISE_OPS: + continue + + input_real_types = [] + for inp_name in node.input: + if inp_name and inp_name in real_type and real_type[inp_name] in (FLOAT, FLOAT16): + input_real_types.append((inp_name, real_type[inp_name])) + + if not input_real_types: + continue + + has_fp32 = any(t == FLOAT for _, t in input_real_types) + has_fp16 = any(t == FLOAT16 for _, t in input_real_types) + if not (has_fp32 and has_fp16): + continue + + # Insert Cast(FP32 -> FP16) for each FP32 input. + # Reuse existing Cast if the same input was already cast (avoids duplicate names). + for inp_idx, inp_name in enumerate(node.input): + if not inp_name or inp_name not in real_type: + continue + if real_type[inp_name] != FLOAT: + continue + cast_out_name = inp_name + "_cast_to_fp16" + if cast_out_name not in real_type: + cast_node = onnx.helper.make_node( + "Cast", + inputs=[inp_name], + outputs=[cast_out_name], + to=FLOAT16, + ) + real_type[cast_out_name] = FLOAT16 + nodes_to_insert.append((node_idx, cast_node)) + node.input[inp_idx] = cast_out_name + + # Insert cast nodes in reverse order so positions stay valid. + for pos, cast_node in sorted(nodes_to_insert, key=lambda x: x[0], reverse=True): + model.graph.node.insert(pos, cast_node) + + if nodes_to_insert: + logger.info( + f"Inserted {len(nodes_to_insert)} Cast node(s) to fix FP32/FP16 mismatches" + ) + + return model + + def remove_node_training_mode(onnx_model: onnx.ModelProto, node_op_type: str) -> onnx.ModelProto: """Remove `training_mode` attribute and extra training outputs from nodes of a given op type. diff --git a/modelopt/torch/_deploy/utils/torch_onnx.py b/modelopt/torch/_deploy/utils/torch_onnx.py index 8cb741dbc7..0073cd7143 100644 --- a/modelopt/torch/_deploy/utils/torch_onnx.py +++ b/modelopt/torch/_deploy/utils/torch_onnx.py @@ -46,6 +46,7 @@ from modelopt.onnx.utils import ( change_casts_to_fp16, check_model_uses_external_data, + fix_fp16_fp32_mismatches, get_input_names, get_input_shapes, get_node_names, @@ -382,22 +383,30 @@ def is_int8_quantized(model: nn.Module) -> bool: return False +def _is_fp8_quantizer(quantizer) -> bool: + """Check if a single quantizer is configured for FP8 (not MXFP8).""" + return ( + quantizer.is_enabled + and quantizer._num_bits == (4, 3) + and not ( + quantizer.block_sizes + and quantizer.block_sizes.get("scale_bits", None) == (8, 0) + ) + ) + + def is_fp8_quantized(model: nn.Module) -> bool: - """Check if the model is quantized in FP8 mode.""" + """Check if the model is quantized in FP8 mode. + + Returns True if any module has an FP8-configured quantizer (weight or input). + This covers mixed-precision scenarios (e.g., auto_quantize) where only the + input_quantizer might be FP8 while the weight_quantizer is disabled or uses + a different format. + """ for _, module in model.named_modules(): - if ( - hasattr(module, "weight_quantizer") - and hasattr(module, "input_quantizer") - and module.weight_quantizer.is_enabled - and module.input_quantizer.is_enabled - and module.weight_quantizer._num_bits == (4, 3) - and module.input_quantizer._num_bits == (4, 3) - # Exclude MXFP8 which also uses (4,3) but has block_sizes with scale_bits - and not ( - module.input_quantizer.block_sizes - and module.input_quantizer.block_sizes.get("scale_bits", None) == (8, 0) - ) - ): + if hasattr(module, "weight_quantizer") and _is_fp8_quantizer(module.weight_quantizer): + return True + if hasattr(module, "input_quantizer") and _is_fp8_quantizer(module.input_quantizer): return True return False @@ -522,7 +531,10 @@ def get_onnx_bytes_and_metadata( input_none_names = list(set(tree_spec_input.names) - set(input_names)) use_torch_autocast = not ( - is_fp4_quantized(model) or is_mxfp8_quantized(model) or weights_dtype == "fp32" + is_fp4_quantized(model) + or is_mxfp8_quantized(model) + or is_fp8_quantized(model) + or weights_dtype == "fp32" ) autocast = torch.autocast("cuda") if use_torch_autocast else nullcontext() @@ -556,6 +568,22 @@ def get_onnx_bytes_and_metadata( if is_fp4_quantized(model) or is_mxfp8_quantized(model) else nullcontext() ) + + # Disable Conv2d FP8 weight quantizer for ONNX export. + # FP8 TRT_FP8QuantizeLinear/DequantizeLinear custom ops produce tensors with + # dynamic shapes, and the ONNX Conv exporter requires static kernel shapes. + # Disabling the weight quantizer keeps Conv2d weights as static constants in + # the ONNX graph. Input quantizer remains enabled so TRT still uses FP8 for + # Conv2d activations. Weights are converted to FP16 by post-export processing. + conv_quantizers_to_reenable: list[tuple[nn.Module, str]] = [] + if is_fp8_quantized(model): + for module in model.modules(): + if not isinstance(module, nn.Conv2d): + continue + quantizer = getattr(module, "weight_quantizer", None) + if quantizer is not None and _is_fp8_quantizer(quantizer): + quantizer.disable() + conv_quantizers_to_reenable.append((module, "weight_quantizer")) with torch.inference_mode(), autocast, quantizer_context: additional_kwargs = {} if not dynamo_export: @@ -571,6 +599,10 @@ def get_onnx_bytes_and_metadata( **additional_kwargs, ) + # Re-enable Conv2d quantizers that were temporarily disabled for FP8 export + for module, qname in conv_quantizers_to_reenable: + getattr(module, qname).enable() + # Check that export worked assert len(os.listdir(onnx_path)) > 0, "Torch to onnx export failed." @@ -617,6 +649,16 @@ def get_onnx_bytes_and_metadata( onnx_opt_graph = remove_redundant_casts(onnx_opt_graph) + # Fix remaining FP32/FP16 mismatches AFTER remove_redundant_casts. + # Only needed for the convert_float_to_float16 path (FP8/MXFP8/NVFP4) where + # blocked QDQ ops produce FP32 that flows into nodes with FP16 inputs. + # Must run after remove_redundant_casts because that function uses unreliable + # value_info metadata and would incorrectly remove the Cast nodes we insert. + if weights_dtype in ["fp16", "bf16"] and ( + is_int4_quantized(model) or is_mxfp8_quantized(model) or is_fp8_quantized(model) + ): + onnx_opt_graph = fix_fp16_fp32_mismatches(onnx_opt_graph) + # TensorRT expects all scales to be postive onnx_opt_graph = replace_zero_scale_with_smallest_nonzero(onnx_opt_graph) diff --git a/modelopt/torch/quantization/export_onnx.py b/modelopt/torch/quantization/export_onnx.py index fe4fb8d70a..c00afaef4a 100644 --- a/modelopt/torch/quantization/export_onnx.py +++ b/modelopt/torch/quantization/export_onnx.py @@ -656,9 +656,20 @@ def export_fp4( @contextlib.contextmanager def configure_linear_module_onnx_quantizers(model): - """Sets the onnx export attributes for the given model.""" + """Sets the onnx export attributes for the given model. + + For Linear modules, sets both input and weight quantizer types. + For other modules with block-quantized input_quantizers (e.g., pooling layers + in models like EfficientViT), sets the input quantizer to "dynamic" to prevent + TRT_FP4QDQ static export for activations. + """ for _, module in model.named_modules(): if isinstance(module, torch.nn.Linear): module.input_quantizer._onnx_quantizer_type = "dynamic" module.weight_quantizer._onnx_quantizer_type = "static" + elif ( + hasattr(module, "input_quantizer") + and getattr(module.input_quantizer, "block_sizes", None) + ): + module.input_quantizer._onnx_quantizer_type = "dynamic" yield diff --git a/tests/_test_utils/torch/vision_models.py b/tests/_test_utils/torch/vision_models.py index 5fed1d20c1..cf29bc3a96 100644 --- a/tests/_test_utils/torch/vision_models.py +++ b/tests/_test_utils/torch/vision_models.py @@ -117,6 +117,7 @@ def get_model_and_input(on_gpu: bool = False): # "dm_nfnet_f0", "efficientnet_b0", "swin_tiny_patch4_window7_224", + "efficientvit_l2", ], _create_timm_fn, ), diff --git a/tests/examples/torch_onnx/test_torch_quant_to_onnx.py b/tests/examples/torch_onnx/test_torch_quant_to_onnx.py index 7c2692c1d9..786486c69d 100644 --- a/tests/examples/torch_onnx/test_torch_quant_to_onnx.py +++ b/tests/examples/torch_onnx/test_torch_quant_to_onnx.py @@ -14,9 +14,6 @@ # limitations under the License. -import os -import subprocess - import pytest from _test_utils.examples.run_command import extend_cmd_parts, run_example_command @@ -24,46 +21,31 @@ # (e.g., DQ -> Reshape -> Slice in small ViT / SwinTransformer ONNX graphs). _QUANT_MODES = ["fp8", "int8", "mxfp8", "nvfp4", "auto"] +# Models where auto mode is excluded due to Conv2d FP8 input/weight type mismatch in TRT. +# Auto mode may assign FP8 to Conv2d input quantizer, producing FP32 output (from blocked QDQ), +# while the Conv2d weight is FP16 — TRT stronglyTyped rejects this mismatch. +_AUTO_EXCLUDED_MODELS = {"efficientvit_l2"} + _MODELS = { "vit_tiny": ("vit_tiny_patch16_224", '{"depth": 1}'), "swin_tiny": ("swin_tiny_patch4_window7_224", '{"depths": [1, 1, 1, 1]}'), "swinv2_tiny": ("swinv2_tiny_window8_256", '{"depths": [1, 1, 1, 1]}'), + "efficientvit_l2": ("efficientvit_l2", '{"depths": [1, 1, 1, 1, 1]}'), } # Builder optimization level: 4 for low-bit modes, 3 otherwise _LOW_BIT_MODES = {"fp8", "int8", "nvfp4"} -def _verify_trt_engine_build(onnx_save_path, quantize_mode): - """Verify the exported ONNX model can be compiled into a TensorRT engine.""" - example_dir = os.path.join( - os.path.dirname(__file__), "..", "..", "..", "examples", "torch_onnx" - ) - onnx_path = os.path.join(example_dir, onnx_save_path) - assert os.path.exists(onnx_path), f"ONNX file not found: {onnx_path}" - - opt_level = "4" if quantize_mode in _LOW_BIT_MODES else "3" - cmd = [ - "trtexec", - f"--onnx={onnx_path}", - "--stronglyTyped", - f"--builderOptimizationLevel={opt_level}", - ] - - result = subprocess.run(cmd, capture_output=True, text=True, timeout=600) - assert result.returncode == 0, ( - f"TensorRT engine build failed for {onnx_save_path} " - f"(mode={quantize_mode}):\n{result.stdout}\n{result.stderr}" - ) - - @pytest.mark.parametrize("quantize_mode", _QUANT_MODES) @pytest.mark.parametrize("model_key", list(_MODELS)) def test_torch_onnx(model_key, quantize_mode): + if quantize_mode == "auto" and model_key in _AUTO_EXCLUDED_MODELS: + pytest.skip(f"auto mode not supported for {model_key} (Conv2d FP8 type mismatch)") timm_model_name, model_kwargs = _MODELS[model_key] onnx_save_path = f"{model_key}.{quantize_mode}.onnx" + opt_level = "4" if quantize_mode in _LOW_BIT_MODES else "3" - # Step 1: Quantize and export to ONNX cmd_parts = extend_cmd_parts( ["python", "torch_quant_to_onnx.py"], timm_model_name=timm_model_name, @@ -72,9 +54,8 @@ def test_torch_onnx(model_key, quantize_mode): onnx_save_path=onnx_save_path, calibration_data_size="1", num_score_steps="1", + trt_builder_opt_level=opt_level, ) cmd_parts.append("--no_pretrained") + cmd_parts.append("--trt_build") run_example_command(cmd_parts, "torch_onnx") - - # Step 2: Verify the exported ONNX model builds a TensorRT engine - _verify_trt_engine_build(onnx_save_path, quantize_mode)