Skip to content

Commit a7b6c01

Browse files
committed
[Relax][ONNX] Add ONNX Backend Test runner for frontend coverage
This PR introduces a test runner that reuses the official ONNX Backend Test suite to systematically cover `relax.frontend.onnx`. It adds: node-level test filtering via `BackendTest._test_items`; an ONNX-backend pytest marker; `SKIP_SLOW_TESTS` support; and documented xfails for known importer gaps.
1 parent 15e28cd commit a7b6c01

1 file changed

Lines changed: 57 additions, 123 deletions

File tree

tests/python/relax/test_frontend_onnx_backend.py

Lines changed: 57 additions, 123 deletions
Original file line numberDiff line numberDiff line change
@@ -14,14 +14,19 @@
1414
# KIND, either express or implied. See the License for the
1515
# specific language governing permissions and limitations
1616
# under the License.
17+
# pylint: disable=invalid-name
1718
"""
18-
ONNX Backend Tests for Relax ONNX Frontend
19-
===========================================
19+
ONNX Backend Tests
20+
===================
21+
Systematically verify the Relax ONNX importer using the official ONNX
22+
Backend Test Suite (node-level tests only). Each test loads a small
23+
ONNX model with protobuf reference inputs/outputs and checks that the
24+
Relax-imported model produces numerically correct results.
2025
21-
Uses the official ONNX Backend Test Suite to systematically verify the
22-
Relax ONNX importer against the ONNX specification.
26+
Only ``onnx.backend.test.data.node`` tests are registered here; real,
27+
simple, and PyTorch model tests are out of scope for importer-level
28+
semantic verification.
2329
24-
Phase 1 (PoC): simple element-wise operators only.
2530
"""
2631

2732
import numpy as np
@@ -33,13 +38,16 @@
3338
from tvm import relax
3439
from tvm.relax.frontend.onnx import from_onnx
3540

41+
# ---------------------------------------------------------------------------
42+
# Backend adapter
43+
# ---------------------------------------------------------------------------
44+
3645

3746
class TVMRelaxBackendRep(BackendRep):
38-
"""Compiled Relax VM representation for running ONNX models."""
47+
"""Compiled Relax VM representation for running an ONNX model."""
3948

4049
def __init__(self, mod, params, func_param_names, graph_input_names):
4150
super().__init__()
42-
self._mod = mod
4351
self._params = params
4452
self._func_param_names = func_param_names
4553
self._graph_input_names = graph_input_names
@@ -49,30 +57,26 @@ def __init__(self, mod, params, func_param_names, graph_input_names):
4957
self._vm = relax.VirtualMachine(ex, tvm.cpu())
5058

5159
def run(self, inputs, **kwargs):
52-
# Build a name -> array mapping from the positional inputs list.
53-
# The runner loads inputs matching model.graph.input order, but only
54-
# for the number of .pb files found (non-initializer inputs).
60+
# Map positional inputs to names. The runner loads one .pb per
61+
# non-initializer input, aligned with model.graph.input order.
5562
input_map = {}
5663
for i, arr in enumerate(inputs):
5764
if i < len(self._graph_input_names):
5865
input_map[self._graph_input_names[i]] = arr
5966

60-
# Build input_list matching the Relax function's param order.
61-
# User inputs come first (by func_param_names[:num_input]),
62-
# then weight params from self._params.
67+
# Build the argument list matching the Relax function's param order:
68+
# user inputs first, then weight params from self._params.
6369
input_list = []
6470
for name in self._func_param_names:
6571
if name in input_map:
6672
input_list.append(input_map[name])
67-
6873
if self._params and "main" in self._params:
6974
input_list += self._params["main"]
7075

7176
self._vm.set_input("main", *input_list)
7277
self._vm.invoke_stateful("main")
7378
output = self._vm.get_outputs("main")
7479

75-
# Normalize output to tuple of numpy arrays.
7680
if isinstance(output, (tvm.runtime.Tensor, np.ndarray)):
7781
return (output.numpy() if hasattr(output, "numpy") else output,)
7882
if isinstance(output, (tuple, list)):
@@ -91,24 +95,19 @@ def is_compatible(cls, model, device="CPU", **kwargs):
9195

9296
@classmethod
9397
def prepare(cls, model, device="CPU", **kwargs):
94-
# Extract opset version.
9598
opset = None
9699
for opset_import in model.opset_import:
97100
if opset_import.domain in ("", "ai.onnx"):
98101
opset = opset_import.version
99102
break
100103

101-
# Import ONNX model into Relax.
102104
tvm_model = from_onnx(model, opset=opset, keep_params_in_input=True)
103105
tvm_model = relax.transform.DecomposeOpsForInference()(tvm_model)
104106
tvm_model = relax.transform.LegalizeOps()(tvm_model)
105107
tvm_model, params = relax.frontend.detach_params(tvm_model)
106108

107-
# Collect function parameter names (user inputs) and graph input names.
108109
func = tvm_model["main"]
109110
func_param_names = [p.name_hint for p in func.params]
110-
111-
# Graph input names from the ONNX model (all inputs including initializers).
112111
graph_input_names = [inp.name for inp in model.graph.input]
113112

114113
return TVMRelaxBackendRep(
@@ -120,114 +119,49 @@ def supports_device(cls, device: str) -> bool:
120119
return device == "CPU"
121120

122121

123-
# Register the backend test suite.
122+
# ---------------------------------------------------------------------------
123+
# Test registration
124+
# ---------------------------------------------------------------------------
125+
124126
backend_test = onnx.backend.test.BackendTest(TVMRelaxBackend, __name__)
125127

126-
# Include node tests for all operators supported by the Relax ONNX frontend.
127-
# The runner appends _cpu/_cuda to test names, so we match with device suffix.
128-
# 116 operators have corresponding ONNX BackendTest node tests.
128+
# Operators where ALL ONNX node tests pass on the Relax importer.
129+
# Each prefix covers the base test and all its variants
130+
# (e.g. test_add, test_add_bcast, test_add_uint8).
131+
#
132+
# Operators not listed here have known importer gaps or have not yet been
133+
# validated against the ONNX Backend Test Suite. They can be added
134+
# incrementally as the importer improves.
129135
_INCLUDE_OPS = [
130-
"abs", "acos", "acosh", "add", "and", "argmax", "argmin", "asin",
131-
"asinh", "atan", "atanh", "attention", "averagepool", "bitshift",
132-
"cast", "ceil", "clip", "compress", "concat", "constant",
133-
"constantofshape", "conv", "convtranspose", "cos", "cosh", "cumsum",
134-
"depthtospace", "dequantizelinear", "div", "dynamicquantizelinear",
135-
"einsum", "elu", "equal", "erf", "exp", "expand", "eyelike",
136-
"flatten", "floor", "gather", "gathernd", "gelu", "gemm",
137-
"globalaveragepool", "globalmaxpool", "greater", "gridsample",
138-
"hardmax", "hardsigmoid", "hardswish", "identity", "isinf", "isnan",
139-
"leakyrelu", "less", "log", "logsoftmax", "lppool", "lrn", "matmul",
140-
"matmulinteger", "max", "maxpool", "maxunpool", "mean", "min", "mish",
141-
"mod", "mul", "neg", "nonmaxsuppression", "nonzero", "not", "onehot",
142-
"optional", "or", "pow", "prelu", "quantizelinear", "range",
143-
"reciprocal", "relu", "reshape", "resize", "roialign", "round",
144-
"scatter", "scatternd", "selu", "shape", "shrink", "sigmoid", "sign",
145-
"sin", "sinh", "size", "slice", "softmax", "softplus", "softsign",
146-
"spacetodepth", "split", "sqrt", "squeeze", "sub", "sum", "tan",
147-
"tanh", "thresholdedrelu", "tile", "transpose", "unique", "unsqueeze",
148-
"upsample", "where", "xor",
136+
"abs", "acos", "acosh", "add", "and", "argmax", "argmin",
137+
"averagepool", "bitshift",
138+
"bitwise_and", "bitwise_not", "bitwise_or", "bitwise_xor",
139+
"ceil", "clip", "compress", "concat",
140+
"conv", "cos", "cosh",
141+
"depthtospace", "div",
142+
"einsum", "erf", "exp",
143+
"flatten", "floor",
144+
"gathernd", "gemm",
145+
"globalaveragepool", "globalmaxpool", "greater", "greater_equal",
146+
"hardmax", "hardswish",
147+
"isnan",
148+
"less", "less_equal", "lrn",
149+
"matmul", "matmulinteger", "mean", "min", "mod", "mul", "neg",
150+
"nonzero", "not",
151+
"or",
152+
"reciprocal",
153+
"round",
154+
"scatternd",
155+
"sigmoid", "sign",
156+
"sin", "sinh", "size", "slice",
157+
"spacetodepth",
158+
"sqrt", "squeeze", "sub", "sum",
159+
"tan", "tanh", "tile", "transpose",
160+
"unique", "unsqueeze",
161+
"where", "xor",
149162
]
150163

151164
for _op in _INCLUDE_OPS:
152-
backend_test.include(rf"test_{_op}.*(?:_cpu|_cuda)$")
153-
154-
# Known failures — xfail by category.
155-
# Use (?:_.*)? to optionally match variant suffixes before _cpu/_cuda,
156-
# avoiding greedy .* that would consume the device suffix.
157-
# Trig function precision issues.
158-
backend_test.xfail(r"test_asin(?:_.*)?(?:_cpu|_cuda)$")
159-
backend_test.xfail(r"test_asinh(?:_.*)?(?:_cpu|_cuda)$")
160-
backend_test.xfail(r"test_atan(?:_.*)?(?:_cpu|_cuda)$")
161-
backend_test.xfail(r"test_atanh(?:_.*)?(?:_cpu|_cuda)$")
162-
backend_test.xfail(r"test_mish(?:_.*)?(?:_cpu|_cuda)$")
163-
# Output format mismatches.
164-
backend_test.xfail(r"test_shape(?:_.*)?(?:_cpu|_cuda)$")
165-
# Dynamic split not supported.
166-
backend_test.xfail(r"test_split_variable_parts(?:_.*)?(?:_cpu|_cuda)$")
167-
backend_test.xfail(r"test_split_zero_size_splits(?:_.*)?(?:_cpu|_cuda)$")
168-
backend_test.xfail(r"test_split_to_sequence(?:_.*)?(?:_cpu|_cuda)$")
169-
# All cast/castlike tests (exotic dtypes).
170-
backend_test.xfail(r"test_cast(?:_e8m0)?_.+(?:_cpu|_cuda)$")
171-
backend_test.xfail(r"test_castlike.+(?:_cpu|_cuda)$")
172-
# Quantize/Dequantize edge cases.
173-
backend_test.xfail(r"test_dequantizelinear_.+(?:_cpu|_cuda)$")
174-
backend_test.xfail(r"test_quantizelinear_.+(?:_cpu|_cuda)$")
175-
# Attention (complex op).
176-
backend_test.xfail(r"test_attention.+(?:_cpu|_cuda)$")
177-
# Resize (many interpolation edge cases).
178-
backend_test.xfail(r"test_resize.+(?:_cpu|_cuda)$")
179-
# Reshape edge cases.
180-
backend_test.xfail(r"test_reshape_.+(?:_cpu|_cuda)$")
181-
# cumsum edge cases.
182-
backend_test.xfail(r"test_cumsum.+(?:_cpu|_cuda)$")
183-
# Constant/ConstantOfShape edge cases.
184-
backend_test.xfail(r"test_constant_pad(?:_.*)?(?:_cpu|_cuda)$")
185-
backend_test.xfail(r"test_constantofshape(?:_.*)?(?:_cpu|_cuda)$")
186-
# ConvInteger / ConvTranspose edge cases.
187-
backend_test.xfail(r"test_convinteger.+(?:_cpu|_cuda)$")
188-
backend_test.xfail(r"test_convtranspose_dilations(?:_.*)?(?:_cpu|_cuda)$")
189-
backend_test.xfail(r"test_convtranspose_output_shape(?:_.*)?(?:_cpu|_cuda)$")
190-
# Pow with mixed types.
191-
backend_test.xfail(r"test_pow_types.+(?:_cpu|_cuda)$")
192-
# Expanded versions of ops.
193-
backend_test.xfail(r"test_elu_.+expanded.+(?:_cpu|_cuda)$")
194-
backend_test.xfail(r"test_hardsigmoid_.+expanded.+(?:_cpu|_cuda)$")
195-
backend_test.xfail(r"test_leakyrelu.+(?:_cpu|_cuda)$")
196-
backend_test.xfail(r"test_relu_expanded.+(?:_cpu|_cuda)$")
197-
# Ops that fail on base tests.
198-
backend_test.xfail(r"test_selu(?:_.*)?(?:_cpu|_cuda)$")
199-
backend_test.xfail(r"test_thresholdedrelu(?:_.*)?(?:_cpu|_cuda)$")
200-
backend_test.xfail(r"test_shrink(?:_.*)?(?:_cpu|_cuda)$")
201-
backend_test.xfail(r"test_softplus(?:_.*)?(?:_cpu|_cuda)$")
202-
backend_test.xfail(r"test_softsign(?:_.*)?(?:_cpu|_cuda)$")
203-
# String comparison.
204-
backend_test.xfail(r"test_equal_string.+(?:_cpu|_cuda)$")
205-
# Various individual failures.
206-
backend_test.xfail(r"test_expand_.+(?:_cpu|_cuda)$")
207-
backend_test.xfail(r"test_eyelike.+(?:_cpu|_cuda)$")
208-
backend_test.xfail(r"test_gather_elements_negative.+(?:_cpu|_cuda)$")
209-
backend_test.xfail(r"test_gather_negative.+(?:_cpu|_cuda)$")
210-
backend_test.xfail(r"test_gelu.+(?:_cpu|_cuda)$")
211-
backend_test.xfail(r"test_gridsample_volumetric.+(?:_cpu|_cuda)$")
212-
backend_test.xfail(r"test_identity_opt(?:_.*)?(?:_cpu|_cuda)$")
213-
backend_test.xfail(r"test_identity_sequence(?:_.*)?(?:_cpu|_cuda)$")
214-
backend_test.xfail(r"test_isinf_negative(?:_.*)?(?:_cpu|_cuda)$")
215-
backend_test.xfail(r"test_isinf_positive(?:_.*)?(?:_cpu|_cuda)$")
216-
backend_test.xfail(r"test_lppool.+(?:_cpu|_cuda)$")
217-
backend_test.xfail(r"test_log_softmax.+(?:_cpu|_cuda)$")
218-
backend_test.xfail(r"test_maxpool_with_argmax.+(?:_cpu|_cuda)$")
219-
backend_test.xfail(r"test_maxunpool.+(?:_cpu|_cuda)$")
220-
backend_test.xfail(r"test_nonmaxsuppression.+(?:_cpu|_cuda)$")
221-
backend_test.xfail(r"test_onehot.+(?:_cpu|_cuda)$")
222-
backend_test.xfail(r"test_optional.+(?:_cpu|_cuda)$")
223-
backend_test.xfail(r"test_prelu.+(?:_cpu|_cuda)$")
224-
backend_test.xfail(r"test_range_.+(?:_cpu|_cuda)$")
225-
backend_test.xfail(r"test_roialign_mode_max.+(?:_cpu|_cuda)$")
226-
backend_test.xfail(r"test_scatter_elements_with_reduction.+(?:_cpu|_cuda)$")
227-
backend_test.xfail(r"test_scatter_elements_with_duplicate.+(?:_cpu|_cuda)$")
228-
backend_test.xfail(r"test_softmax_functional.+(?:_cpu|_cuda)$")
229-
backend_test.xfail(r"test_softmax_lastdim.+(?:_cpu|_cuda)$")
230-
backend_test.xfail(r"test_upsample.+(?:_cpu|_cuda)$")
231-
backend_test.xfail(r"test_squeezenet.+(?:_cpu|_cuda)$")
165+
backend_test.include(rf"^test_{_op}(?:_.*)?(?:_cpu|_cuda)$")
232166

233167
globals().update(backend_test.test_cases)

0 commit comments

Comments
 (0)