From 0ca3eb03ca1615d2261a092bab1d8dc7579eeb5e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ai-chan-0411=20=28=E8=97=8D=29?= Date: Thu, 16 Apr 2026 13:56:21 +0900 Subject: [PATCH 1/2] Add MLX op handler for aten.isinf Decompose isinf(x) into abs(x) == inf using existing AbsNode and EqualNode, so the op runs on the Metal GPU via MLX instead of falling back to CPU execution. Closes #18922 Co-Authored-By: Claude Opus 4.6 --- backends/mlx/ops.py | 35 +++++++++++++++++++++++++++++++++++ backends/mlx/test/test_ops.py | 15 +++++++++++++++ 2 files changed, 50 insertions(+) diff --git a/backends/mlx/ops.py b/backends/mlx/ops.py index 4dc891ee984..32b75fb21af 100644 --- a/backends/mlx/ops.py +++ b/backends/mlx/ops.py @@ -2697,6 +2697,41 @@ def _relu_handler(P: MLXProgramBuilder, n: Node) -> Slot: return out +@REGISTRY.register(target=[torch.ops.aten.isinf.default]) +def _isinf_handler(P: MLXProgramBuilder, n: Node) -> Slot: + """Handle aten.isinf - check for infinite values element-wise. + + isinf(x) is equivalent to abs(x) == inf. + """ + args = P.args(n) + require_args(args, 1, 1, "aten.isinf") + require_kwargs(P.kwargs(n), set(), "aten.isinf") + x = args[0] + + # abs(x) + _, abs_tmp = P.make_tmp_slot() + P.emit( + AbsNode( + x=P.slot_to_tid(x), + out=P.slot_to_tid(abs_tmp), + ) + ) + + # inf constant + inf_slot = emit_lifted_constant(P, float("inf"), torch.float32) + + # abs(x) == inf + out = P.make_or_get_slot(n) + P.emit( + EqualNode( + a=P.slot_to_tid(abs_tmp), + b=P.slot_to_tid(inf_slot), + out=P.slot_to_tid(out), + ) + ) + return out + + @REGISTRY.register(target=[torch.ops.aten._log_softmax.default]) def _log_softmax_handler(P: MLXProgramBuilder, n: Node) -> Slot: """Handle aten._log_softmax.default - log of softmax. diff --git a/backends/mlx/test/test_ops.py b/backends/mlx/test/test_ops.py index e5ece4931b9..bcac8a7d085 100644 --- a/backends/mlx/test/test_ops.py +++ b/backends/mlx/test/test_ops.py @@ -4004,6 +4004,20 @@ def fn(shape, dtype): return fn +def _inf_input_fn(): + """Return a callable(shape, dtype) that generates inputs with some inf values.""" + + def fn(shape, dtype): + x = torch.randn(shape, dtype=dtype) + mask_pos = torch.rand(shape) > 0.8 + mask_neg = torch.rand(shape) > 0.9 + x[mask_pos] = float("inf") + x[mask_neg] = float("-inf") + return (x,) + + return fn + + # Standard shape and dtype configs used by unary tests. _SHAPES_3 = [(16,), (4, 4), (2, 3, 4)] _SHAPES_2 = [(16,), (4, 4)] @@ -4103,6 +4117,7 @@ def create_model(self) -> nn.Module: # math {"op_name": "rsqrt", "op_fn": torch.rsqrt, "shapes": [(2, 3, 4), (10,), (4, 8), (2, 8, 16), (1, 64)], "dtypes": [torch.float32], "input_fn": _input_fn(uniform=True, offset=0.1)}, {"op_name": "clone", "op_fn": torch.clone, "shapes": [(2, 3, 4), (8, 8), (16,)], "dtypes": [torch.float32]}, + {"op_name": "isinf", "op_fn": torch.isinf, "shapes": _SHAPES_3, "input_fn": _inf_input_fn()}, ] # fmt: on From 7647a7ee44579416ea373af155a13fc3b4cf53c4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ai-chan-0411=20=28=E8=97=8D=29?= Date: Fri, 17 Apr 2026 03:23:52 +0900 Subject: [PATCH 2/2] Add NaN values to isinf test inputs Include NaN in the generated test data alongside inf/-inf to ensure isinf correctly returns False for NaN inputs. --- backends/mlx/test/test_ops.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/backends/mlx/test/test_ops.py b/backends/mlx/test/test_ops.py index bcac8a7d085..8401eb8797c 100644 --- a/backends/mlx/test/test_ops.py +++ b/backends/mlx/test/test_ops.py @@ -4005,14 +4005,16 @@ def fn(shape, dtype): def _inf_input_fn(): - """Return a callable(shape, dtype) that generates inputs with some inf values.""" + """Return a callable(shape, dtype) that generates inputs with some inf/nan values.""" def fn(shape, dtype): x = torch.randn(shape, dtype=dtype) mask_pos = torch.rand(shape) > 0.8 mask_neg = torch.rand(shape) > 0.9 + mask_nan = torch.rand(shape) > 0.85 x[mask_pos] = float("inf") x[mask_neg] = float("-inf") + x[mask_nan] = float("nan") return (x,) return fn