## 🚀 The feature, motivation and pitch

**Good First Issue: Add MLX Op Handler for `aten.isinf`**

### Summary

Add support for `aten.isinf` in the MLX delegate. This op checks element-wise for infinite values and is needed for numerical stability checks and gradient clipping.
### Background

The MLX delegate converts PyTorch aten ops into MLX graph nodes during export. When an aten op has no handler, it falls back to CPU execution, breaking the GPU acceleration pipeline. Adding a handler lets the op run on the Metal GPU via MLX.
### Approach: Decomposed handler (preferred)

`aten.isinf` can be decomposed using a comparison with infinity:

```python
# isinf(x) = abs(x) == inf
```

This uses the existing `AbsNode` and `EqualNode`, which are already supported.
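The decomposition can be sanity-checked against eager PyTorch (a standalone check, no delegate involved):

```python
import torch

x = torch.tensor([1.0, float('inf'), float('-inf'), float('nan'), 0.0])

# abs() maps -inf to inf, so a single equality check covers both signs;
# NaN compares unequal to everything, so it is correctly excluded.
decomposed = torch.abs(x) == float('inf')

assert torch.equal(decomposed, torch.isinf(x))
print(decomposed)  # tensor([False,  True,  True, False, False])
```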
### Steps

1. **Add a handler in `backends/mlx/ops.py`:**

   ```python
   @REGISTRY.register(target=[torch.ops.aten.isinf.default])
   def _isinf_handler(P: MLXProgramBuilder, n: Node) -> Slot:
       """Handle aten.isinf - check for infinite values element-wise.

       isinf(x) is equivalent to abs(x) == inf.
       """
       args = P.args(n)
       require_args(args, 1, 1, "aten.isinf")
       require_kwargs(P.kwargs(n), set(), "aten.isinf")
       x = args[0]

       # Create abs(x)
       _, abs_tmp = P.make_tmp_slot()
       P.emit(
           AbsNode(
               x=P.slot_to_tid(x),
               out=P.slot_to_tid(abs_tmp),
           )
       )

       # Create inf constant
       inf_slot = emit_lifted_constant(P, float('inf'), torch.float32)

       # Compare abs(x) == inf
       out = P.make_or_get_slot(n)
       P.emit(
           EqualNode(
               a=P.slot_to_tid(abs_tmp),
               b=P.slot_to_tid(inf_slot),
               out=P.slot_to_tid(out),
           )
       )
       return out
   ```
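One detail worth checking in review: the handler sketch lifts the inf constant as `torch.float32`, while eager `isinf` also accepts half-precision and integer inputs. A quick eager-mode check of those dtype cases (plain PyTorch, independent of the delegate):

```python
import torch

# Integer tensors cannot represent infinity, so eager isinf is all-False.
xi = torch.arange(6, dtype=torch.int32)
assert not torch.isinf(xi).any()

# For float16, abs(x) == inf still matches eager isinf: inf is representable
# in half precision, and the Python float scalar promotes cleanly.
xh = torch.tensor([65504.0, float('inf')], dtype=torch.float16)
assert torch.equal(torch.abs(xh) == float('inf'), torch.isinf(xh))
```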
2. **Add a test in `backends/mlx/test/test_ops.py`:**

   Use the existing `_make_unary_op_test` infrastructure with a custom input function that includes inf values:

   ```python
   def _inf_input_fn():
       """Return a callable(shape, dtype) that generates inputs with some inf values."""
       def fn(shape, dtype):
           x = torch.randn(shape, dtype=dtype)
           # Insert some inf values
           mask_pos = torch.rand(shape) > 0.8
           mask_neg = torch.rand(shape) > 0.9
           x[mask_pos] = float('inf')
           x[mask_neg] = float('-inf')
           return (x,)
       return fn

   # Add to _UNARY_OP_TESTS list:
   {"op_name": "isinf", "op_fn": torch.isinf, "shapes": _SHAPES_3, "input_fn": _inf_input_fn()},
   ```
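The harness compares delegate output against eager mode; for reference, this is the eager behavior the handler must reproduce (hypothetical shape, written independently of the test infrastructure):

```python
import torch

x = torch.randn(4, 8)        # randn never produces inf on its own
x[0, 0] = float('inf')
x[1, 1] = float('-inf')

ref = torch.isinf(x)
assert ref.dtype == torch.bool   # isinf always yields a bool tensor
assert ref.shape == x.shape      # element-wise: shape is preserved
assert ref.sum().item() == 2     # exactly the two planted infinities
```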
### Running tests

```shell
python -m executorch.backends.mlx.test.run_all_tests -k isinf
```
### References

- MLX C++: `array isinf(const array &a, StreamOrDevice s = {})`
- PyTorch signature: `isinf(Tensor self) -> Tensor`
  - Note: this catches both positive and negative infinity.
- Test infrastructure: see `_make_unary_op_test` and `_UNARY_OP_TESTS` in `test_ops.py`.
## Alternatives

_No response_

## Additional context

_No response_

## RFC (Optional)

_No response_