
Good First Issue: Add MLX Op Handler for aten.isnan #18920

@metascroy


🚀 The feature, motivation and pitch


Summary

Add support for aten.isnan in the MLX delegate. This op tests each element for NaN and returns a boolean tensor of the same shape; it is needed for numerical-stability checks in training and inference pipelines.

Background

The MLX delegate converts PyTorch aten ops into MLX graph nodes during export. When an aten op has no handler, it falls back to CPU execution, breaking the GPU acceleration pipeline. Adding a handler lets the op run on the Metal GPU via MLX.

Approach: Decomposed handler (preferred)

aten.isnan can be decomposed using the property that NaN != NaN:

# isnan(x) = x != x

This reuses the existing NotEqualNode, which the MLX delegate already supports, so no new MLX node type is required.
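To illustrate why a single not-equal node is enough, here is a toy, pure-Python model of the lowering. `NotEqual` and `run_program` are hypothetical stand-ins, not the real MLX schema or runtime. The node's two inputs reference the same tensor id, and Python floats follow IEEE 754 comparison rules, so the result matches `math.isnan`. Infinities and negative zero compare equal to themselves and therefore map to False.

```python
import math
from dataclasses import dataclass


@dataclass
class NotEqual:
    """Hypothetical toy stand-in for NotEqualNode: out = (a != b), element-wise."""
    a: int    # tensor id of the left operand
    b: int    # tensor id of the right operand
    out: int  # tensor id that receives the boolean result


def run_program(nodes, tensors):
    """Evaluate toy nodes in order over a dict mapping tensor id -> list of floats."""
    for node in nodes:
        lhs, rhs = tensors[node.a], tensors[node.b]
        tensors[node.out] = [x != y for x, y in zip(lhs, rhs)]
    return tensors


# isnan(x) lowers to one node whose two inputs are both tensor id 0.
program = [NotEqual(a=0, b=0, out=1)]
samples = [0.0, -0.0, 1.5, float("inf"), float("-inf"), float("nan")]
result = run_program(program, {0: samples})[1]

print(result)  # [False, False, False, False, False, True]
assert result == [math.isnan(v) for v in samples]
```

Passing the same slot twice, as the handler does with `P.slot_to_tid(x)`, needs no copy of the input: the comparison simply reads the same buffer on both sides.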

Steps

  1. Add a handler in backends/mlx/ops.py

    @REGISTRY.register(target=[torch.ops.aten.isnan.default])
    def _isnan_handler(P: MLXProgramBuilder, n: Node) -> Slot:
        """Handle aten.isnan - check for NaN values element-wise.
        
        isnan(x) is equivalent to x != x (NaN is the only value not equal to itself).
        """
        args = P.args(n)
        require_args(args, 1, 1, "aten.isnan")
        require_kwargs(P.kwargs(n), set(), "aten.isnan")
        x = args[0]
        
        out = P.make_or_get_slot(n)
        P.emit(
            NotEqualNode(
                a=P.slot_to_tid(x),
                b=P.slot_to_tid(x),
                out=P.slot_to_tid(out),
            )
        )
        return out
  2. Add a test in backends/mlx/test/test_ops.py

    Use the existing _make_unary_op_test infrastructure with a custom input function that includes NaN values:

    def _nan_input_fn():
        """Return a callable(shape, dtype) that generates inputs with some NaN values."""
        def fn(shape, dtype):
            x = torch.randn(shape, dtype=dtype)
            # Insert some NaN values
            mask = torch.rand(shape) > 0.7
            x[mask] = float('nan')
            return (x,)
        return fn
    
    # Add to _UNARY_OP_TESTS list:
    {"op_name": "isnan", "op_fn": torch.isnan, "shapes": _SHAPES_3, "input_fn": _nan_input_fn()},

Running tests

python -m executorch.backends.mlx.test.run_all_tests -k isnan

References

  • MLX C++: array isnan(const array &a, StreamOrDevice s = {})
  • PyTorch signature: isnan(Tensor self) -> Tensor
  • Mathematical property: NaN is the only float value where x != x is true
  • Test infrastructure: See _make_unary_op_test and _UNARY_OP_TESTS in test_ops.py



Labels: good first issue, triaged
