🚀 The feature, motivation and pitch
Good First Issue: Add MLX Op Handler for `aten.isnan`
Summary
Add support for `aten.isnan` in the MLX delegate. This op checks for NaN values element-wise and is needed for numerical stability checks in training and inference pipelines.
Background
The MLX delegate converts PyTorch aten ops into MLX graph nodes during export. When an aten op has no handler, it falls back to CPU execution, breaking the GPU acceleration pipeline. Adding a handler lets the op run on the Metal GPU via MLX.
Approach: Decomposed handler (preferred)
`aten.isnan` can be decomposed using the property that NaN is the only value not equal to itself: `isnan(x) = x != x`. This uses the existing `NotEqualNode`, which is already supported.
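The decomposition can be sanity-checked in plain PyTorch, independent of the MLX delegate (a minimal sketch; note that `inf` is equal to itself, so only NaN triggers the inequality):

```python
import torch

# Check the x != x decomposition against torch.isnan,
# including non-NaN special values.
x = torch.tensor([1.0, float("nan"), float("inf"), -float("inf"), 0.0])

decomposed = x != x          # True only where x is NaN
reference = torch.isnan(x)

assert torch.equal(decomposed, reference)
print(decomposed.tolist())   # [False, True, False, False, False]
```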
Steps
- Add handler in `backends/mlx/ops.py`:
```python
@REGISTRY.register(target=[torch.ops.aten.isnan.default])
def _isnan_handler(P: MLXProgramBuilder, n: Node) -> Slot:
    """Handle aten.isnan - check for NaN values element-wise.

    isnan(x) is equivalent to x != x (NaN is the only value not equal to itself).
    """
    args = P.args(n)
    require_args(args, 1, 1, "aten.isnan")
    require_kwargs(P.kwargs(n), set(), "aten.isnan")
    x = args[0]
    out = P.make_or_get_slot(n)
    P.emit(
        NotEqualNode(
            a=P.slot_to_tid(x),
            b=P.slot_to_tid(x),
            out=P.slot_to_tid(out),
        )
    )
    return out
```
- Add test in `backends/mlx/test/test_ops.py`. Use the existing `_make_unary_op_test` infrastructure with a custom input function that includes NaN values:
```python
def _nan_input_fn():
    """Return a callable(shape, dtype) that generates inputs with some NaN values."""
    def fn(shape, dtype):
        x = torch.randn(shape, dtype=dtype)
        # Insert some NaN values
        mask = torch.rand(shape) > 0.7
        x[mask] = float('nan')
        return (x,)
    return fn

# Add to _UNARY_OP_TESTS list:
{"op_name": "isnan", "op_fn": torch.isnan, "shapes": _SHAPES_3, "input_fn": _nan_input_fn()},
```
Running tests
```shell
python -m executorch.backends.mlx.test.run_all_tests -k isnan
```
References
- MLX C++: `array isnan(const array &a, StreamOrDevice s = {})`
- PyTorch signature: `isnan(Tensor self) -> Tensor`
- Mathematical property: NaN is the only float value where `x != x` is true
- Test infrastructure: see `_make_unary_op_test` and `_UNARY_OP_TESTS` in `test_ops.py`
Alternatives
No response
Additional context
No response
RFC (Optional)
No response