
Good First Issue: Add MLX Op Handler for aten.isinf #18922

@metascroy

Description

🚀 The feature, motivation and pitch

Summary

Add support for aten.isinf in the MLX delegate. This op tests each element for infinity (returning True for both +inf and -inf) and is used in numerical stability checks and gradient clipping.

Background

The MLX delegate converts PyTorch aten ops into MLX graph nodes during export. When an aten op has no handler, it falls back to CPU execution, breaking the GPU acceleration pipeline. Adding a handler lets the op run on the Metal GPU via MLX.

Approach: Decomposed handler (preferred)

aten.isinf can be decomposed using comparison with infinity:

# isinf(x) = abs(x) == inf

This reuses AbsNode and EqualNode, which the MLX delegate already supports, so no new node type is needed.
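
As a quick sanity check of the identity, here is a plain-Python sketch (not MLX or PyTorch code) showing that element-wise `abs(x) == inf` agrees with `math.isinf` on finite values, both infinities, NaN, and integers. The helper name `isinf_decomposed` is illustrative only.

```python
import math

def isinf_decomposed(values):
    """Element-wise isinf using the abs(x) == inf decomposition."""
    return [abs(v) == math.inf for v in values]

samples = [0.0, -1.5, 1e308, math.inf, -math.inf, math.nan, 7]
reference = [math.isinf(v) for v in samples]

# abs() maps -inf onto +inf; NaN compares unequal to everything, so it stays False.
assert isinf_decomposed(samples) == reference

# Comparing against +inf alone (without abs) would miss -inf.
naive = [v == math.inf for v in samples]
assert naive != reference

print(isinf_decomposed(samples))  # [False, False, False, True, True, False, False]
```

The abs step is what lets a single equality comparison cover both signs of infinity.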

Steps

  1. Add handler in backends/mlx/ops.py

    @REGISTRY.register(target=[torch.ops.aten.isinf.default])
    def _isinf_handler(P: MLXProgramBuilder, n: Node) -> Slot:
        """Handle aten.isinf - check for infinite values element-wise.
        
        isinf(x) is equivalent to abs(x) == inf.
        """
        args = P.args(n)
        require_args(args, 1, 1, "aten.isinf")
        require_kwargs(P.kwargs(n), set(), "aten.isinf")
        x = args[0]
        
        # Create abs(x)
        _, abs_tmp = P.make_tmp_slot()
        P.emit(
            AbsNode(
                x=P.slot_to_tid(x),
                out=P.slot_to_tid(abs_tmp),
            )
        )
        
        # Create inf constant
        inf_slot = emit_lifted_constant(P, float('inf'), torch.float32)
        
        # Compare abs(x) == inf
        out = P.make_or_get_slot(n)
        P.emit(
            EqualNode(
                a=P.slot_to_tid(abs_tmp),
                b=P.slot_to_tid(inf_slot),
                out=P.slot_to_tid(out),
            )
        )
        return out

  2. Add test in backends/mlx/test/test_ops.py

    Use the existing _make_unary_op_test infrastructure with a custom input function that includes inf values:

    def _inf_input_fn():
        """Return a callable(shape, dtype) that generates inputs with some inf values."""
        def fn(shape, dtype):
            x = torch.randn(shape, dtype=dtype)
            # Insert some inf values
            mask_pos = torch.rand(shape) > 0.8
            mask_neg = torch.rand(shape) > 0.9
            x[mask_pos] = float('inf')
            x[mask_neg] = float('-inf')
            return (x,)
        return fn
    
    # Add to _UNARY_OP_TESTS list:
    {"op_name": "isinf", "op_fn": torch.isinf, "shapes": _SHAPES_3, "input_fn": _inf_input_fn()},

Running tests

python -m executorch.backends.mlx.test.run_all_tests -k isinf

References

  • MLX C++: array isinf(const array &a, StreamOrDevice s = {})
  • PyTorch signature: isinf(Tensor self) -> Tensor
  • Note: isinf returns True for both positive and negative infinity, and False for NaN
  • Test infrastructure: See _make_unary_op_test and _UNARY_OP_TESTS in test_ops.py

Alternatives

No response

Additional context

No response

RFC (Optional)

No response
