Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 35 additions & 0 deletions backends/mlx/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -2697,6 +2697,41 @@ def _relu_handler(P: MLXProgramBuilder, n: Node) -> Slot:
return out


@REGISTRY.register(target=[torch.ops.aten.isinf.default])
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are there any other variants of this torch op? E.g., isinf.Tensor, etc.

There may not be, just curious.

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I checked the PyTorch ATen operator registry and the executorch codebase — the only variant is aten.isinf.default. There is no isinf.Tensor or isinf.out overload defined. The other backends (MPS, Qualcomm) also only register aten.isinf.default, so this handler covers the complete surface.

def _isinf_handler(P: MLXProgramBuilder, n: Node) -> Slot:
    """Lower ``aten.isinf`` to the MLX graph.

    Emits nodes computing ``abs(x) == inf``, which is True exactly where the
    input is +inf or -inf (NaN compares unequal to inf, so it yields False).
    """
    positional = P.args(n)
    require_args(positional, 1, 1, "aten.isinf")
    require_kwargs(P.kwargs(n), set(), "aten.isinf")
    operand = positional[0]

    # Step 1: magnitude of the input, so one comparison covers both signs.
    _, magnitude = P.make_tmp_slot()
    P.emit(AbsNode(x=P.slot_to_tid(operand), out=P.slot_to_tid(magnitude)))

    # Step 2: lifted +inf scalar constant to compare against.
    pos_inf = emit_lifted_constant(P, float("inf"), torch.float32)

    # Step 3: element-wise equality produces the boolean result.
    result = P.make_or_get_slot(n)
    P.emit(
        EqualNode(
            a=P.slot_to_tid(magnitude),
            b=P.slot_to_tid(pos_inf),
            out=P.slot_to_tid(result),
        )
    )
    return result


@REGISTRY.register(target=[torch.ops.aten._log_softmax.default])
def _log_softmax_handler(P: MLXProgramBuilder, n: Node) -> Slot:
"""Handle aten._log_softmax.default - log of softmax.
Expand Down
17 changes: 17 additions & 0 deletions backends/mlx/test/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -4004,6 +4004,22 @@ def fn(shape, dtype):
return fn


def _inf_input_fn():
"""Return a callable(shape, dtype) that generates inputs with some inf/nan values."""

def fn(shape, dtype):
x = torch.randn(shape, dtype=dtype)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we add some nans to this generated test input as well?

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done — added NaN values to _inf_input_fn. The generated inputs now include float("nan") alongside inf/-inf, so the test verifies that isinf correctly returns False for NaN. Pushed in 7647a7e.

mask_pos = torch.rand(shape) > 0.8
mask_neg = torch.rand(shape) > 0.9
mask_nan = torch.rand(shape) > 0.85
x[mask_pos] = float("inf")
x[mask_neg] = float("-inf")
x[mask_nan] = float("nan")
return (x,)

return fn


# Standard shape and dtype configs used by unary tests.
# 1-D, 2-D, and 3-D sample shapes, for ops exercised across tensor ranks.
_SHAPES_3 = [(16,), (4, 4), (2, 3, 4)]
# Reduced variant covering only 1-D and 2-D inputs.
_SHAPES_2 = [(16,), (4, 4)]
Expand Down Expand Up @@ -4103,6 +4119,7 @@ def create_model(self) -> nn.Module:
# math
{"op_name": "rsqrt", "op_fn": torch.rsqrt, "shapes": [(2, 3, 4), (10,), (4, 8), (2, 8, 16), (1, 64)], "dtypes": [torch.float32], "input_fn": _input_fn(uniform=True, offset=0.1)},
{"op_name": "clone", "op_fn": torch.clone, "shapes": [(2, 3, 4), (8, 8), (16,)], "dtypes": [torch.float32]},
{"op_name": "isinf", "op_fn": torch.isinf, "shapes": _SHAPES_3, "input_fn": _inf_input_fn()},
]
# fmt: on

Expand Down
Loading