# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name
"""
ONNX Backend Tests
===================
Systematically verify the Relax ONNX importer using the official ONNX
Backend Test Suite (node-level tests only). Each test loads a small
ONNX model with protobuf reference inputs/outputs and checks that the
Relax-imported model produces numerically correct results.

Only ``onnx.backend.test.data.node`` tests are registered here; real,
simple, and PyTorch model tests are out of scope for importer-level
semantic verification.
"""
2631
2732import numpy as np
3338from tvm import relax
3439from tvm .relax .frontend .onnx import from_onnx
3540
41+ # ---------------------------------------------------------------------------
42+ # Backend adapter
43+ # ---------------------------------------------------------------------------
44+
3645
3746class TVMRelaxBackendRep (BackendRep ):
38- """Compiled Relax VM representation for running ONNX models ."""
47+ """Compiled Relax VM representation for running an ONNX model ."""
3948
4049 def __init__ (self , mod , params , func_param_names , graph_input_names ):
4150 super ().__init__ ()
42- self ._mod = mod
4351 self ._params = params
4452 self ._func_param_names = func_param_names
4553 self ._graph_input_names = graph_input_names
@@ -49,30 +57,26 @@ def __init__(self, mod, params, func_param_names, graph_input_names):
4957 self ._vm = relax .VirtualMachine (ex , tvm .cpu ())
5058
5159 def run (self , inputs , ** kwargs ):
52- # Build a name -> array mapping from the positional inputs list.
53- # The runner loads inputs matching model.graph.input order, but only
54- # for the number of .pb files found (non-initializer inputs).
60+ # Map positional inputs to names. The runner loads one .pb per
61+ # non-initializer input, aligned with model.graph.input order.
5562 input_map = {}
5663 for i , arr in enumerate (inputs ):
5764 if i < len (self ._graph_input_names ):
5865 input_map [self ._graph_input_names [i ]] = arr
5966
60- # Build input_list matching the Relax function's param order.
61- # User inputs come first (by func_param_names[:num_input]),
62- # then weight params from self._params.
67+ # Build the argument list matching the Relax function's param order:
68+ # user inputs first, then weight params from self._params.
6369 input_list = []
6470 for name in self ._func_param_names :
6571 if name in input_map :
6672 input_list .append (input_map [name ])
67-
6873 if self ._params and "main" in self ._params :
6974 input_list += self ._params ["main" ]
7075
7176 self ._vm .set_input ("main" , * input_list )
7277 self ._vm .invoke_stateful ("main" )
7378 output = self ._vm .get_outputs ("main" )
7479
75- # Normalize output to tuple of numpy arrays.
7680 if isinstance (output , (tvm .runtime .Tensor , np .ndarray )):
7781 return (output .numpy () if hasattr (output , "numpy" ) else output ,)
7882 if isinstance (output , (tuple , list )):
@@ -91,24 +95,19 @@ def is_compatible(cls, model, device="CPU", **kwargs):
9195
9296 @classmethod
9397 def prepare (cls , model , device = "CPU" , ** kwargs ):
94- # Extract opset version.
9598 opset = None
9699 for opset_import in model .opset_import :
97100 if opset_import .domain in ("" , "ai.onnx" ):
98101 opset = opset_import .version
99102 break
100103
101- # Import ONNX model into Relax.
102104 tvm_model = from_onnx (model , opset = opset , keep_params_in_input = True )
103105 tvm_model = relax .transform .DecomposeOpsForInference ()(tvm_model )
104106 tvm_model = relax .transform .LegalizeOps ()(tvm_model )
105107 tvm_model , params = relax .frontend .detach_params (tvm_model )
106108
107- # Collect function parameter names (user inputs) and graph input names.
108109 func = tvm_model ["main" ]
109110 func_param_names = [p .name_hint for p in func .params ]
110-
111- # Graph input names from the ONNX model (all inputs including initializers).
112111 graph_input_names = [inp .name for inp in model .graph .input ]
113112
114113 return TVMRelaxBackendRep (
@@ -120,114 +119,49 @@ def supports_device(cls, device: str) -> bool:
120119 return device == "CPU"
121120
122121
# ---------------------------------------------------------------------------
# Test registration
# ---------------------------------------------------------------------------

backend_test = onnx.backend.test.BackendTest(TVMRelaxBackend, __name__)

# Operators where ALL ONNX node tests pass on the Relax importer.
# Each prefix covers the base test and all its variants
# (e.g. test_add, test_add_bcast, test_add_uint8).
#
# Operators not listed here have known importer gaps or have not yet been
# validated against the ONNX Backend Test Suite. They can be added
# incrementally as the importer improves.
_INCLUDE_OPS = [
    "abs", "acos", "acosh", "add", "and", "argmax", "argmin",
    "averagepool", "bitshift",
    "bitwise_and", "bitwise_not", "bitwise_or", "bitwise_xor",
    "ceil", "clip", "compress", "concat",
    "conv", "cos", "cosh",
    "depthtospace", "div",
    "einsum", "erf", "exp",
    "flatten", "floor",
    "gathernd", "gemm",
    "globalaveragepool", "globalmaxpool", "greater", "greater_equal",
    "hardmax", "hardswish",
    "isnan",
    "less", "less_equal", "lrn",
    "matmul", "matmulinteger", "mean", "min", "mod", "mul", "neg",
    "nonzero", "not",
    "or",
    "reciprocal",
    "round",
    "scatternd",
    "sigmoid", "sign",
    "sin", "sinh", "size", "slice",
    "spacetodepth",
    "sqrt", "squeeze", "sub", "sum",
    "tan", "tanh", "tile", "transpose",
    "unique", "unsqueeze",
    "where", "xor",
]

for _op in _INCLUDE_OPS:
    # The runner appends a device suffix (_cpu/_cuda) to every test name.
    # Anchor the prefix and require an underscore before any variant suffix
    # so e.g. "sin" does not accidentally match "sinh" tests.
    backend_test.include(rf"^test_{_op}(?:_.*)?(?:_cpu|_cuda)$")

# Expose the generated unittest cases to the test collector (pytest/unittest
# discover them via module globals).
globals().update(backend_test.test_cases)