
Good First Issue: Add Full Integer Support for aten.bitwise_not #18924

@metascroy

Description


🚀 The feature, motivation and pitch


Summary

Extend aten.bitwise_not support in the MLX delegate to handle integer tensors as well as boolean tensors. Currently the handler supports only the bool dtype, and integer inputs fall back to CPU.

Background

The MLX delegate currently has a bitwise_not handler that only supports boolean tensors (dispatching to LogicalNotNode). However, MLX has native support for bitwise operations on integers via mlx::core::bitwise_invert.
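A new node is needed because bitwise NOT and logical NOT agree only on bool. The plain-Python sketch below shows the difference on scalars. The function name `bitwise_not_reference` is hypothetical and not part of the codebase. For bool, bitwise NOT flips True and False. For integers, it flips every bit, so `~x == -x - 1`, and a logical-not node cannot produce that result.

```python
def bitwise_not_reference(x):
    """Hypothetical scalar reference for aten.bitwise_not (bool or signed int)."""
    # bool is a subclass of int in Python, so it must be checked first.
    if isinstance(x, bool):
        # For bool, bitwise NOT is logical NOT; LogicalNotNode covers this case.
        return not x
    # For signed integers, bitwise NOT flips every bit: ~x == -x - 1.
    return ~x


# Logical NOT would map every nonzero int to False, which is wrong for integers.
for value in (True, False, 0, 5, -1):
    print(value, "->", bitwise_not_reference(value))
```

Applying Python's `~` directly to `True` gives `-2`. For that reason, the bool case is handled separately rather than treated as an integer.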

Current limitation (in backends/mlx/ops.py):

@REGISTRY.register(target=[torch.ops.aten.bitwise_not.default])
def _bitwise_not_handler(P: MLXProgramBuilder, n: Node) -> Slot:
    ...
    if dtype.dtype == torch.bool:
        # For boolean tensors, bitwise_not is equivalent to logical_not
        P.emit(LogicalNotNode(...))
    else:
        raise NotImplementedError(
            f"aten.bitwise_not is only supported for boolean tensors. "
        )
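If the custom handler is kept instead of being replaced, the dtype check above becomes a three-way decision. The sketch below models only that decision. It is a minimal sketch under these assumptions: dtypes and node names are plain strings so the example is self-contained, and the real handler would compare against `torch.bool` and friends and emit the node through `P.emit(...)`. The name `select_bitwise_not_node` is hypothetical.

```python
# Integer dtypes listed as supported in this issue's References section.
INTEGER_DTYPES = {"int8", "int16", "int32", "int64", "uint8"}


def select_bitwise_not_node(dtype: str) -> str:
    """Hypothetical: name of the node a bitwise_not handler would emit for dtype."""
    if dtype == "bool":
        # Existing path: bitwise NOT on bool is logical NOT.
        return "LogicalNotNode"
    if dtype in INTEGER_DTYPES:
        # New path: delegate to MLX's bitwise_invert via the new schema node.
        return "BitwiseInvertNode"
    # PyTorch defines bitwise_not only for bool and integer dtypes.
    raise NotImplementedError(f"aten.bitwise_not is not supported for dtype {dtype}")
```

With this split, the bool path stays on the node it already uses, and only integer inputs reach the new runtime op.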

Approach: New schema node + runtime

Add a BitwiseInvertNode to handle integer types via MLX's bitwise_invert.
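The new node must also be added to the OpNode union in the schema, and its position there matters. In FlatBuffers, union members receive type ids in declaration order, with NONE as 0. A serialized program stores those ids, so appending is the only change that keeps existing ids stable. The illustration below is in plain Python. The member names `NodeA` and `NodeC` are placeholders and do not reflect the real contents of OpNode.

```python
def union_type_ids(members):
    """FlatBuffers-style union ids: NONE is 0, members follow in declaration order."""
    ids = {"NONE": 0}
    for type_id, name in enumerate(members, start=1):
        ids[name] = type_id
    return ids


def ids_preserved(old, new):
    """True if every member of the old union keeps its type id in the new union."""
    old_ids, new_ids = union_type_ids(old), union_type_ids(new)
    return all(new_ids.get(name) == tid for name, tid in old_ids.items())


# Placeholder members; the real OpNode union is much longer.
existing = ["NodeA", "LogicalNotNode", "NodeC"]
appended = existing + ["BitwiseInvertNode"]
inserted = ["NodeA", "BitwiseInvertNode", "LogicalNotNode", "NodeC"]

print(ids_preserved(existing, appended))
print(ids_preserved(existing, inserted))
```

Appending leaves every existing id unchanged. Inserting in the middle shifts `LogicalNotNode` from id 2 to id 3, so previously serialized programs would decode to the wrong node.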

Steps

  1. Add node to backends/mlx/serialization/schema.fbs

    table BitwiseInvertNode {
      x: Tid;
      out: Tid;
    }

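    Add BitwiseInvertNode at the end of the OpNode union. Append only and do not reorder existing members: FlatBuffers assigns union type ids in declaration order, so reordering would change the ids of existing nodes.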

  2. Regenerate serialization code

    python backends/mlx/serialization/generate.py

  3. Add C++ runtime exec function in backends/mlx/runtime/MLXInterpreter.h

    inline void exec_bitwise_invert(
        const BitwiseInvertNode& n, ExecutionState& st, StreamOrDevice s) {
      // Look up the input tensor by its Tid.
      auto x = st.get_tensor(n.x());
      // Element-wise bitwise NOT on the integer input.
      auto out = mlx::core::bitwise_invert(x, s);
      // Store the result under the output Tid.
      st.set_tensor(n.out(), out);
    }

  4. Update handler in backends/mlx/ops.py

    Since BitwiseInvertNode is a unary op, add it to the _UNARY_OPS table:

    # In _UNARY_OPS list, add:
    (torch.ops.aten.bitwise_not.default, BitwiseInvertNode, "aten.bitwise_not"),

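    Note: the existing _bitwise_not_handler is already registered for aten.bitwise_not.default, so pick one of two options. Either remove that custom handler, including its bool-only restriction, in favor of the _UNARY_OPS entry. Or skip the _UNARY_OPS entry and extend the custom handler so it emits LogicalNotNode for bool inputs and BitwiseInvertNode for integer inputs.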

  5. Add test in backends/mlx/test/test_ops.py

    Use the _UNARY_OP_TESTS table with integer inputs:

    # Add to _UNARY_OP_TESTS list:
    {"op_name": "bitwise_not_int", "op_fn": torch.bitwise_not, 
     "shapes": _SHAPES_3, "dtypes": [torch.int32, torch.int64], 
     "input_fn": _int_input_fn()},

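Integer tests should include each dtype's boundary values, because bitwise NOT maps the minimum to the maximum and the maximum to the minimum (`~min == max`). The issue does not show what `_int_input_fn` generates. The sketch below is a hypothetical input set worth covering; `int_test_values` is not an existing helper.

```python
import random


def int_test_values(bits, signed, count=16, seed=0):
    """Hypothetical generator: boundary values plus seeded random samples."""
    lo = -(1 << (bits - 1)) if signed else 0
    hi = (1 << (bits - 1)) - 1 if signed else (1 << bits) - 1
    rng = random.Random(seed)  # seeded so failures are reproducible
    # Extremes, zero, one, and (for signed types) minus one.
    boundary = [lo, hi, 0, 1] + ([-1] if signed else [])
    samples = [rng.randint(lo, hi) for _ in range(count)]  # randint is inclusive
    return boundary + samples


print(int_test_values(8, True, count=4))
```

In a real test, these values would be packed into a tensor of the requested dtype before calling `torch.bitwise_not`.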
Running tests

python -m executorch.backends.mlx.test.run_all_tests -k bitwise_not

References

  • MLX C++: array bitwise_invert(const array &a, StreamOrDevice s = {})
  • PyTorch signature: bitwise_not(Tensor self) -> Tensor
  • Supported dtypes: int8, int16, int32, int64, uint8, bool
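Tensors have a fixed bit width, so the expected result depends on the dtype. Signed types behave like Python's `~x`, while uint8 wraps to `255 - x`. The plain-Python reference model below can be used to check outputs. It assumes two's-complement encoding, and `invert_fixed_width` is a hypothetical helper.

```python
# (bit width, signed) for the integer dtypes listed above.
DTYPE_LAYOUT = {
    "int8": (8, True),
    "int16": (16, True),
    "int32": (32, True),
    "int64": (64, True),
    "uint8": (8, False),
}


def invert_fixed_width(x: int, dtype: str) -> int:
    """Hypothetical reference: bitwise NOT of x stored in the given integer dtype."""
    bits, signed = DTYPE_LAYOUT[dtype]
    mask = (1 << bits) - 1
    # Take the two's-complement bit pattern of x and flip every bit.
    raw = (x & mask) ^ mask
    # For signed dtypes, a set top bit means the value is negative.
    if signed and raw >= 1 << (bits - 1):
        raw -= 1 << bits
    return raw


for dtype, value in [("int8", 0), ("int8", -128), ("uint8", 0), ("uint8", 200)]:
    print(dtype, value, "->", invert_fixed_width(value, dtype))
```

For the signed dtypes, this model agrees with `~x` for every in-range value. Only uint8 differs from Python's unbounded `~`.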

Alternatives

No response

Additional context

No response

RFC (Optional)

No response
