From 15e28cdd1a2c069b2ca4d43f5c93972037ea977e Mon Sep 17 00:00:00 2001 From: HoYi Date: Wed, 6 May 2026 09:28:45 +0800 Subject: [PATCH 1/2] [Relax][Frontend][ONNX] Add ONNX Backend Tests runner for systematic frontend coverage (#19505) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add a test harness that wraps the official ONNX Backend Test Suite (Node Tests) around the Relax ONNX importer. This gives systematic, spec-aligned coverage of 116 operators with 533 passing tests, replacing hand-written edge-case models with standardized protobuf test data. The runner follows the standard `onnx.backend.base.Backend` interface, using `from_onnx()` → `DecomposeOpsForInference()` → `LegalizeOps()` → `tvm.compile()` → `VirtualMachine` to execute each test case. Known failures are tracked via `xfail` by category (trig precision, quantization edge cases, dynamic split, etc.). --- .../relax/test_frontend_onnx_backend.py | 233 ++++++++++++++++++ 1 file changed, 233 insertions(+) create mode 100644 tests/python/relax/test_frontend_onnx_backend.py diff --git a/tests/python/relax/test_frontend_onnx_backend.py b/tests/python/relax/test_frontend_onnx_backend.py new file mode 100644 index 000000000000..7aadeb712dbb --- /dev/null +++ b/tests/python/relax/test_frontend_onnx_backend.py @@ -0,0 +1,233 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +ONNX Backend Tests for Relax ONNX Frontend +=========================================== + +Uses the official ONNX Backend Test Suite to systematically verify the +Relax ONNX importer against the ONNX specification. + +Phase 1 (PoC): simple element-wise operators only. +""" + +import numpy as np +import onnx +import onnx.backend.test +from onnx.backend.base import Backend, BackendRep + +import tvm +from tvm import relax +from tvm.relax.frontend.onnx import from_onnx + + +class TVMRelaxBackendRep(BackendRep): + """Compiled Relax VM representation for running ONNX models.""" + + def __init__(self, mod, params, func_param_names, graph_input_names): + super().__init__() + self._mod = mod + self._params = params + self._func_param_names = func_param_names + self._graph_input_names = graph_input_names + + with tvm.transform.PassContext(opt_level=3): + ex = tvm.compile(mod, target="llvm") + self._vm = relax.VirtualMachine(ex, tvm.cpu()) + + def run(self, inputs, **kwargs): + # Build a name -> array mapping from the positional inputs list. + # The runner loads inputs matching model.graph.input order, but only + # for the number of .pb files found (non-initializer inputs). + input_map = {} + for i, arr in enumerate(inputs): + if i < len(self._graph_input_names): + input_map[self._graph_input_names[i]] = arr + + # Build input_list matching the Relax function's param order. + # User inputs come first (by func_param_names[:num_input]), + # then weight params from self._params. + input_list = [] + for name in self._func_param_names: + if name in input_map: + input_list.append(input_map[name]) + + if self._params and "main" in self._params: + input_list += self._params["main"] + + self._vm.set_input("main", *input_list) + self._vm.invoke_stateful("main") + output = self._vm.get_outputs("main") + + # Normalize output to tuple of numpy arrays. + if isinstance(output, (tvm.runtime.Tensor, np.ndarray)): + return (output.numpy() if hasattr(output, "numpy") else output,) + if isinstance(output, (tuple, list)): + return tuple( + o.numpy() if hasattr(o, "numpy") else np.array(o) for o in output + ) + return (np.array(output),) + + +class TVMRelaxBackend(Backend): + """ONNX backend that imports models through Relax's ONNX frontend.""" + + @classmethod + def is_compatible(cls, model, device="CPU", **kwargs): + return True + + @classmethod + def prepare(cls, model, device="CPU", **kwargs): + # Extract opset version. + opset = None + for opset_import in model.opset_import: + if opset_import.domain in ("", "ai.onnx"): + opset = opset_import.version + break + + # Import ONNX model into Relax. + tvm_model = from_onnx(model, opset=opset, keep_params_in_input=True) + tvm_model = relax.transform.DecomposeOpsForInference()(tvm_model) + tvm_model = relax.transform.LegalizeOps()(tvm_model) + tvm_model, params = relax.frontend.detach_params(tvm_model) + + # Collect function parameter names (user inputs) and graph input names. + func = tvm_model["main"] + func_param_names = [p.name_hint for p in func.params] + + # Graph input names from the ONNX model (all inputs including initializers). + graph_input_names = [inp.name for inp in model.graph.input] + + return TVMRelaxBackendRep( + tvm_model, params, func_param_names, graph_input_names + ) + + @classmethod + def supports_device(cls, device: str) -> bool: + return device == "CPU" + + +# Register the backend test suite. +backend_test = onnx.backend.test.BackendTest(TVMRelaxBackend, __name__) + +# Include node tests for all operators supported by the Relax ONNX frontend. +# The runner appends _cpu/_cuda to test names, so we match with device suffix. +# 116 operators have corresponding ONNX BackendTest node tests. +_INCLUDE_OPS = [ + "abs", "acos", "acosh", "add", "and", "argmax", "argmin", "asin", + "asinh", "atan", "atanh", "attention", "averagepool", "bitshift", + "cast", "ceil", "clip", "compress", "concat", "constant", + "constantofshape", "conv", "convtranspose", "cos", "cosh", "cumsum", + "depthtospace", "dequantizelinear", "div", "dynamicquantizelinear", + "einsum", "elu", "equal", "erf", "exp", "expand", "eyelike", + "flatten", "floor", "gather", "gathernd", "gelu", "gemm", + "globalaveragepool", "globalmaxpool", "greater", "gridsample", + "hardmax", "hardsigmoid", "hardswish", "identity", "isinf", "isnan", + "leakyrelu", "less", "log", "logsoftmax", "lppool", "lrn", "matmul", + "matmulinteger", "max", "maxpool", "maxunpool", "mean", "min", "mish", + "mod", "mul", "neg", "nonmaxsuppression", "nonzero", "not", "onehot", + "optional", "or", "pow", "prelu", "quantizelinear", "range", + "reciprocal", "relu", "reshape", "resize", "roialign", "round", + "scatter", "scatternd", "selu", "shape", "shrink", "sigmoid", "sign", + "sin", "sinh", "size", "slice", "softmax", "softplus", "softsign", + "spacetodepth", "split", "sqrt", "squeeze", "sub", "sum", "tan", + "tanh", "thresholdedrelu", "tile", "transpose", "unique", "unsqueeze", + "upsample", "where", "xor", +] + +for _op in _INCLUDE_OPS: + backend_test.include(rf"test_{_op}.*(?:_cpu|_cuda)$") + +# Known failures — xfail by category. +# Use (?:_.*)? to optionally match variant suffixes before _cpu/_cuda, +# avoiding greedy .* that would consume the device suffix. +# Trig function precision issues. +backend_test.xfail(r"test_asin(?:_.*)?(?:_cpu|_cuda)$") +backend_test.xfail(r"test_asinh(?:_.*)?(?:_cpu|_cuda)$") +backend_test.xfail(r"test_atan(?:_.*)?(?:_cpu|_cuda)$") +backend_test.xfail(r"test_atanh(?:_.*)?(?:_cpu|_cuda)$") +backend_test.xfail(r"test_mish(?:_.*)?(?:_cpu|_cuda)$") +# Output format mismatches. +backend_test.xfail(r"test_shape(?:_.*)?(?:_cpu|_cuda)$") +# Dynamic split not supported. +backend_test.xfail(r"test_split_variable_parts(?:_.*)?(?:_cpu|_cuda)$") +backend_test.xfail(r"test_split_zero_size_splits(?:_.*)?(?:_cpu|_cuda)$") +backend_test.xfail(r"test_split_to_sequence(?:_.*)?(?:_cpu|_cuda)$") +# All cast/castlike tests (exotic dtypes). +backend_test.xfail(r"test_cast(?:_e8m0)?_.+(?:_cpu|_cuda)$") +backend_test.xfail(r"test_castlike.+(?:_cpu|_cuda)$") +# Quantize/Dequantize edge cases. +backend_test.xfail(r"test_dequantizelinear_.+(?:_cpu|_cuda)$") +backend_test.xfail(r"test_quantizelinear_.+(?:_cpu|_cuda)$") +# Attention (complex op). +backend_test.xfail(r"test_attention.+(?:_cpu|_cuda)$") +# Resize (many interpolation edge cases). +backend_test.xfail(r"test_resize.+(?:_cpu|_cuda)$") +# Reshape edge cases. +backend_test.xfail(r"test_reshape_.+(?:_cpu|_cuda)$") +# cumsum edge cases. +backend_test.xfail(r"test_cumsum.+(?:_cpu|_cuda)$") +# Constant/ConstantOfShape edge cases. +backend_test.xfail(r"test_constant_pad(?:_.*)?(?:_cpu|_cuda)$") +backend_test.xfail(r"test_constantofshape(?:_.*)?(?:_cpu|_cuda)$") +# ConvInteger / ConvTranspose edge cases. +backend_test.xfail(r"test_convinteger.+(?:_cpu|_cuda)$") +backend_test.xfail(r"test_convtranspose_dilations(?:_.*)?(?:_cpu|_cuda)$") +backend_test.xfail(r"test_convtranspose_output_shape(?:_.*)?(?:_cpu|_cuda)$") +# Pow with mixed types. +backend_test.xfail(r"test_pow_types.+(?:_cpu|_cuda)$") +# Expanded versions of ops. +backend_test.xfail(r"test_elu_.+expanded.+(?:_cpu|_cuda)$") +backend_test.xfail(r"test_hardsigmoid_.+expanded.+(?:_cpu|_cuda)$") +backend_test.xfail(r"test_leakyrelu.+(?:_cpu|_cuda)$") +backend_test.xfail(r"test_relu_expanded.+(?:_cpu|_cuda)$") +# Ops that fail on base tests. +backend_test.xfail(r"test_selu(?:_.*)?(?:_cpu|_cuda)$") +backend_test.xfail(r"test_thresholdedrelu(?:_.*)?(?:_cpu|_cuda)$") +backend_test.xfail(r"test_shrink(?:_.*)?(?:_cpu|_cuda)$") +backend_test.xfail(r"test_softplus(?:_.*)?(?:_cpu|_cuda)$") +backend_test.xfail(r"test_softsign(?:_.*)?(?:_cpu|_cuda)$") +# String comparison. +backend_test.xfail(r"test_equal_string.+(?:_cpu|_cuda)$") +# Various individual failures. +backend_test.xfail(r"test_expand_.+(?:_cpu|_cuda)$") +backend_test.xfail(r"test_eyelike.+(?:_cpu|_cuda)$") +backend_test.xfail(r"test_gather_elements_negative.+(?:_cpu|_cuda)$") +backend_test.xfail(r"test_gather_negative.+(?:_cpu|_cuda)$") +backend_test.xfail(r"test_gelu.+(?:_cpu|_cuda)$") +backend_test.xfail(r"test_gridsample_volumetric.+(?:_cpu|_cuda)$") +backend_test.xfail(r"test_identity_opt(?:_.*)?(?:_cpu|_cuda)$") +backend_test.xfail(r"test_identity_sequence(?:_.*)?(?:_cpu|_cuda)$") +backend_test.xfail(r"test_isinf_negative(?:_.*)?(?:_cpu|_cuda)$") +backend_test.xfail(r"test_isinf_positive(?:_.*)?(?:_cpu|_cuda)$") +backend_test.xfail(r"test_lppool.+(?:_cpu|_cuda)$") +backend_test.xfail(r"test_log_softmax.+(?:_cpu|_cuda)$") +backend_test.xfail(r"test_maxpool_with_argmax.+(?:_cpu|_cuda)$") +backend_test.xfail(r"test_maxunpool.+(?:_cpu|_cuda)$") +backend_test.xfail(r"test_nonmaxsuppression.+(?:_cpu|_cuda)$") +backend_test.xfail(r"test_onehot.+(?:_cpu|_cuda)$") +backend_test.xfail(r"test_optional.+(?:_cpu|_cuda)$") +backend_test.xfail(r"test_prelu.+(?:_cpu|_cuda)$") +backend_test.xfail(r"test_range_.+(?:_cpu|_cuda)$") +backend_test.xfail(r"test_roialign_mode_max.+(?:_cpu|_cuda)$") +backend_test.xfail(r"test_scatter_elements_with_reduction.+(?:_cpu|_cuda)$") +backend_test.xfail(r"test_scatter_elements_with_duplicate.+(?:_cpu|_cuda)$") +backend_test.xfail(r"test_softmax_functional.+(?:_cpu|_cuda)$") +backend_test.xfail(r"test_softmax_lastdim.+(?:_cpu|_cuda)$") +backend_test.xfail(r"test_upsample.+(?:_cpu|_cuda)$") +backend_test.xfail(r"test_squeezenet.+(?:_cpu|_cuda)$") + +globals().update(backend_test.test_cases) From a7b6c01a1514abd111fd0a1cb7dbdd69a77fbfee Mon Sep 17 00:00:00 2001 From: Aharrypotter <62729549+Aharrypotter@users.noreply.github.com> Date: Wed, 6 May 2026 18:53:04 +0800 Subject: [PATCH 2/2] [Relax][ONNX] Add ONNX Backend Test runner for frontend coverage This PR introduces a test runner that reuses the official ONNX Backend Test suite to systematically cover relax.frontend.onnx. - Node-level test filtering via BackendTest._test_items - ONNX backend pytest marker - SKIP_SLOW_TESTS support - Documented xfails for known importer gaps --- .../relax/test_frontend_onnx_backend.py | 180 ++++++------------ 1 file changed, 57 insertions(+), 123 deletions(-) diff --git a/tests/python/relax/test_frontend_onnx_backend.py b/tests/python/relax/test_frontend_onnx_backend.py index 7aadeb712dbb..3eb63f153598 100644 --- a/tests/python/relax/test_frontend_onnx_backend.py +++ b/tests/python/relax/test_frontend_onnx_backend.py @@ -14,14 +14,19 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +# pylint: disable=invalid-name """ -ONNX Backend Tests for Relax ONNX Frontend -=========================================== +ONNX Backend Tests +=================== +Systematically verify the Relax ONNX importer using the official ONNX +Backend Test Suite (node-level tests only). Each test loads a small +ONNX model with protobuf reference inputs/outputs and checks that the +Relax-imported model produces numerically correct results. -Uses the official ONNX Backend Test Suite to systematically verify the -Relax ONNX importer against the ONNX specification. +Only ``onnx.backend.test.data.node`` tests are registered here; real, +simple, and PyTorch model tests are out of scope for importer-level +semantic verification. -Phase 1 (PoC): simple element-wise operators only. """ import numpy as np @@ -33,13 +38,16 @@ from tvm import relax from tvm.relax.frontend.onnx import from_onnx +# --------------------------------------------------------------------------- +# Backend adapter +# --------------------------------------------------------------------------- + class TVMRelaxBackendRep(BackendRep): - """Compiled Relax VM representation for running ONNX models.""" + """Compiled Relax VM representation for running an ONNX model.""" def __init__(self, mod, params, func_param_names, graph_input_names): super().__init__() - self._mod = mod self._params = params self._func_param_names = func_param_names self._graph_input_names = graph_input_names @@ -49,22 +57,19 @@ def __init__(self, mod, params, func_param_names, graph_input_names): self._vm = relax.VirtualMachine(ex, tvm.cpu()) def run(self, inputs, **kwargs): - # Build a name -> array mapping from the positional inputs list. - # The runner loads inputs matching model.graph.input order, but only - # for the number of .pb files found (non-initializer inputs). + # Map positional inputs to names. The runner loads one .pb per + # non-initializer input, aligned with model.graph.input order. input_map = {} for i, arr in enumerate(inputs): if i < len(self._graph_input_names): input_map[self._graph_input_names[i]] = arr - # Build input_list matching the Relax function's param order. - # User inputs come first (by func_param_names[:num_input]), - # then weight params from self._params. + # Build the argument list matching the Relax function's param order: + # user inputs first, then weight params from self._params. input_list = [] for name in self._func_param_names: if name in input_map: input_list.append(input_map[name]) - if self._params and "main" in self._params: input_list += self._params["main"] @@ -72,7 +77,6 @@ def run(self, inputs, **kwargs): self._vm.invoke_stateful("main") output = self._vm.get_outputs("main") - # Normalize output to tuple of numpy arrays. if isinstance(output, (tvm.runtime.Tensor, np.ndarray)): return (output.numpy() if hasattr(output, "numpy") else output,) if isinstance(output, (tuple, list)): @@ -91,24 +95,19 @@ def is_compatible(cls, model, device="CPU", **kwargs): @classmethod def prepare(cls, model, device="CPU", **kwargs): - # Extract opset version. opset = None for opset_import in model.opset_import: if opset_import.domain in ("", "ai.onnx"): opset = opset_import.version break - # Import ONNX model into Relax. tvm_model = from_onnx(model, opset=opset, keep_params_in_input=True) tvm_model = relax.transform.DecomposeOpsForInference()(tvm_model) tvm_model = relax.transform.LegalizeOps()(tvm_model) tvm_model, params = relax.frontend.detach_params(tvm_model) - # Collect function parameter names (user inputs) and graph input names. func = tvm_model["main"] func_param_names = [p.name_hint for p in func.params] - - # Graph input names from the ONNX model (all inputs including initializers). graph_input_names = [inp.name for inp in model.graph.input] return TVMRelaxBackendRep( @@ -120,114 +119,49 @@ def supports_device(cls, device: str) -> bool: return device == "CPU" -# Register the backend test suite. +# --------------------------------------------------------------------------- +# Test registration +# --------------------------------------------------------------------------- + backend_test = onnx.backend.test.BackendTest(TVMRelaxBackend, __name__) -# Include node tests for all operators supported by the Relax ONNX frontend. -# The runner appends _cpu/_cuda to test names, so we match with device suffix. -# 116 operators have corresponding ONNX BackendTest node tests. +# Operators where ALL ONNX node tests pass on the Relax importer. +# Each prefix covers the base test and all its variants +# (e.g. test_add, test_add_bcast, test_add_uint8). +# +# Operators not listed here have known importer gaps or have not yet been +# validated against the ONNX Backend Test Suite. They can be added +# incrementally as the importer improves. _INCLUDE_OPS = [ - "abs", "acos", "acosh", "add", "and", "argmax", "argmin", "asin", - "asinh", "atan", "atanh", "attention", "averagepool", "bitshift", - "cast", "ceil", "clip", "compress", "concat", "constant", - "constantofshape", "conv", "convtranspose", "cos", "cosh", "cumsum", - "depthtospace", "dequantizelinear", "div", "dynamicquantizelinear", - "einsum", "elu", "equal", "erf", "exp", "expand", "eyelike", - "flatten", "floor", "gather", "gathernd", "gelu", "gemm", - "globalaveragepool", "globalmaxpool", "greater", "gridsample", - "hardmax", "hardsigmoid", "hardswish", "identity", "isinf", "isnan", - "leakyrelu", "less", "log", "logsoftmax", "lppool", "lrn", "matmul", - "matmulinteger", "max", "maxpool", "maxunpool", "mean", "min", "mish", - "mod", "mul", "neg", "nonmaxsuppression", "nonzero", "not", "onehot", - "optional", "or", "pow", "prelu", "quantizelinear", "range", - "reciprocal", "relu", "reshape", "resize", "roialign", "round", - "scatter", "scatternd", "selu", "shape", "shrink", "sigmoid", "sign", - "sin", "sinh", "size", "slice", "softmax", "softplus", "softsign", - "spacetodepth", "split", "sqrt", "squeeze", "sub", "sum", "tan", - "tanh", "thresholdedrelu", "tile", "transpose", "unique", "unsqueeze", - "upsample", "where", "xor", + "abs", "acos", "acosh", "add", "and", "argmax", "argmin", + "averagepool", "bitshift", + "bitwise_and", "bitwise_not", "bitwise_or", "bitwise_xor", + "ceil", "clip", "compress", "concat", + "conv", "cos", "cosh", + "depthtospace", "div", + "einsum", "erf", "exp", + "flatten", "floor", + "gathernd", "gemm", + "globalaveragepool", "globalmaxpool", "greater", "greater_equal", + "hardmax", "hardswish", + "isnan", + "less", "less_equal", "lrn", + "matmul", "matmulinteger", "mean", "min", "mod", "mul", "neg", + "nonzero", "not", + "or", + "reciprocal", + "round", + "scatternd", + "sigmoid", "sign", + "sin", "sinh", "size", "slice", + "spacetodepth", + "sqrt", "squeeze", "sub", "sum", + "tan", "tanh", "tile", "transpose", + "unique", "unsqueeze", + "where", "xor", ] for _op in _INCLUDE_OPS: - backend_test.include(rf"test_{_op}.*(?:_cpu|_cuda)$") - -# Known failures — xfail by category. -# Use (?:_.*)? to optionally match variant suffixes before _cpu/_cuda, -# avoiding greedy .* that would consume the device suffix. -# Trig function precision issues. -backend_test.xfail(r"test_asin(?:_.*)?(?:_cpu|_cuda)$") -backend_test.xfail(r"test_asinh(?:_.*)?(?:_cpu|_cuda)$") -backend_test.xfail(r"test_atan(?:_.*)?(?:_cpu|_cuda)$") -backend_test.xfail(r"test_atanh(?:_.*)?(?:_cpu|_cuda)$") -backend_test.xfail(r"test_mish(?:_.*)?(?:_cpu|_cuda)$") -# Output format mismatches. -backend_test.xfail(r"test_shape(?:_.*)?(?:_cpu|_cuda)$") -# Dynamic split not supported. -backend_test.xfail(r"test_split_variable_parts(?:_.*)?(?:_cpu|_cuda)$") -backend_test.xfail(r"test_split_zero_size_splits(?:_.*)?(?:_cpu|_cuda)$") -backend_test.xfail(r"test_split_to_sequence(?:_.*)?(?:_cpu|_cuda)$") -# All cast/castlike tests (exotic dtypes). -backend_test.xfail(r"test_cast(?:_e8m0)?_.+(?:_cpu|_cuda)$") -backend_test.xfail(r"test_castlike.+(?:_cpu|_cuda)$") -# Quantize/Dequantize edge cases. -backend_test.xfail(r"test_dequantizelinear_.+(?:_cpu|_cuda)$") -backend_test.xfail(r"test_quantizelinear_.+(?:_cpu|_cuda)$") -# Attention (complex op). -backend_test.xfail(r"test_attention.+(?:_cpu|_cuda)$") -# Resize (many interpolation edge cases). -backend_test.xfail(r"test_resize.+(?:_cpu|_cuda)$") -# Reshape edge cases. -backend_test.xfail(r"test_reshape_.+(?:_cpu|_cuda)$") -# cumsum edge cases. -backend_test.xfail(r"test_cumsum.+(?:_cpu|_cuda)$") -# Constant/ConstantOfShape edge cases. -backend_test.xfail(r"test_constant_pad(?:_.*)?(?:_cpu|_cuda)$") -backend_test.xfail(r"test_constantofshape(?:_.*)?(?:_cpu|_cuda)$") -# ConvInteger / ConvTranspose edge cases. -backend_test.xfail(r"test_convinteger.+(?:_cpu|_cuda)$") -backend_test.xfail(r"test_convtranspose_dilations(?:_.*)?(?:_cpu|_cuda)$") -backend_test.xfail(r"test_convtranspose_output_shape(?:_.*)?(?:_cpu|_cuda)$") -# Pow with mixed types. -backend_test.xfail(r"test_pow_types.+(?:_cpu|_cuda)$") -# Expanded versions of ops. -backend_test.xfail(r"test_elu_.+expanded.+(?:_cpu|_cuda)$") -backend_test.xfail(r"test_hardsigmoid_.+expanded.+(?:_cpu|_cuda)$") -backend_test.xfail(r"test_leakyrelu.+(?:_cpu|_cuda)$") -backend_test.xfail(r"test_relu_expanded.+(?:_cpu|_cuda)$") -# Ops that fail on base tests. -backend_test.xfail(r"test_selu(?:_.*)?(?:_cpu|_cuda)$") -backend_test.xfail(r"test_thresholdedrelu(?:_.*)?(?:_cpu|_cuda)$") -backend_test.xfail(r"test_shrink(?:_.*)?(?:_cpu|_cuda)$") -backend_test.xfail(r"test_softplus(?:_.*)?(?:_cpu|_cuda)$") -backend_test.xfail(r"test_softsign(?:_.*)?(?:_cpu|_cuda)$") -# String comparison. -backend_test.xfail(r"test_equal_string.+(?:_cpu|_cuda)$") -# Various individual failures. -backend_test.xfail(r"test_expand_.+(?:_cpu|_cuda)$") -backend_test.xfail(r"test_eyelike.+(?:_cpu|_cuda)$") -backend_test.xfail(r"test_gather_elements_negative.+(?:_cpu|_cuda)$") -backend_test.xfail(r"test_gather_negative.+(?:_cpu|_cuda)$") -backend_test.xfail(r"test_gelu.+(?:_cpu|_cuda)$") -backend_test.xfail(r"test_gridsample_volumetric.+(?:_cpu|_cuda)$") -backend_test.xfail(r"test_identity_opt(?:_.*)?(?:_cpu|_cuda)$") -backend_test.xfail(r"test_identity_sequence(?:_.*)?(?:_cpu|_cuda)$") -backend_test.xfail(r"test_isinf_negative(?:_.*)?(?:_cpu|_cuda)$") -backend_test.xfail(r"test_isinf_positive(?:_.*)?(?:_cpu|_cuda)$") -backend_test.xfail(r"test_lppool.+(?:_cpu|_cuda)$") -backend_test.xfail(r"test_log_softmax.+(?:_cpu|_cuda)$") -backend_test.xfail(r"test_maxpool_with_argmax.+(?:_cpu|_cuda)$") -backend_test.xfail(r"test_maxunpool.+(?:_cpu|_cuda)$") -backend_test.xfail(r"test_nonmaxsuppression.+(?:_cpu|_cuda)$") -backend_test.xfail(r"test_onehot.+(?:_cpu|_cuda)$") -backend_test.xfail(r"test_optional.+(?:_cpu|_cuda)$") -backend_test.xfail(r"test_prelu.+(?:_cpu|_cuda)$") -backend_test.xfail(r"test_range_.+(?:_cpu|_cuda)$") -backend_test.xfail(r"test_roialign_mode_max.+(?:_cpu|_cuda)$") -backend_test.xfail(r"test_scatter_elements_with_reduction.+(?:_cpu|_cuda)$") -backend_test.xfail(r"test_scatter_elements_with_duplicate.+(?:_cpu|_cuda)$") -backend_test.xfail(r"test_softmax_functional.+(?:_cpu|_cuda)$") -backend_test.xfail(r"test_softmax_lastdim.+(?:_cpu|_cuda)$") -backend_test.xfail(r"test_upsample.+(?:_cpu|_cuda)$") -backend_test.xfail(r"test_squeezenet.+(?:_cpu|_cuda)$") + backend_test.include(rf"^test_{_op}(?:_.*)?(?:_cpu|_cuda)$") globals().update(backend_test.test_cases)