Good First Issue: Add Full Integer Support for aten.bitwise_and #18925

@metascroy

Description

🚀 The feature, motivation and pitch

Good First Issue: Add Full Integer Support for aten.bitwise_and

Summary

Extend aten.bitwise_and support in the MLX delegate to handle integer tensors, not just boolean tensors. Currently the handler only works for bool dtype and falls back to CPU for integers.

Background

The MLX delegate currently has a bitwise_and handler that only supports boolean tensors (dispatching to LogicalAndNode). However, MLX has native support for bitwise operations on integers via mlx::core::bitwise_and.

Current limitation (in ops.py):

@REGISTRY.register(
    target=[torch.ops.aten.logical_and.default, torch.ops.aten.bitwise_and.Tensor]
)
def _logical_and_handler(P: MLXProgramBuilder, n: Node) -> Slot:
    ...
    # bitwise_and is only equivalent to logical_and for bool tensors.
    if n.target == torch.ops.aten.bitwise_and.Tensor:
        if dtype.dtype != torch.bool:
            raise NotImplementedError(
                f"aten.bitwise_and on non-bool dtype {dtype.dtype} is not supported"
            )
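The bool-only restriction exists because the two ops only coincide on bool inputs. A quick eager-mode demonstration (plain PyTorch, no delegate involved):

```python
import torch

# For bool tensors, bitwise_and and logical_and agree, which is why the
# current handler can reuse LogicalAndNode for bool inputs.
a_bool = torch.tensor([True, True, False])
b_bool = torch.tensor([True, False, False])
assert torch.equal(torch.bitwise_and(a_bool, b_bool),
                   torch.logical_and(a_bool, b_bool))

# For integer tensors the two ops differ: bitwise_and works per bit,
# while logical_and treats any nonzero value as True.
a_int = torch.tensor([6, 5, 4], dtype=torch.int32)
b_int = torch.tensor([3, 3, 3], dtype=torch.int32)
print(torch.bitwise_and(a_int, b_int))  # tensor([2, 1, 0], dtype=torch.int32)
print(torch.logical_and(a_int, b_int))  # tensor([True, True, True])
```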

Approach: New schema node + runtime

Add a BitwiseAndNode to handle integer types via MLX's bitwise_and.

Steps

  1. Add node to backends/mlx/serialization/schema.fbs

    table BitwiseAndNode {
      a: Tid;
      b: Tid;
      out: Tid;
    }

    Add BitwiseAndNode to the OpNode union (append only, do not reorder).
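The union addition would look roughly like this (a sketch only; the actual member list lives in the repo's schema.fbs and is elided here):

```
// schema.fbs sketch -- existing members elided; BitwiseAndNode
// must be appended as the last member, never inserted mid-list,
// so previously serialized programs keep their union discriminants.
union OpNode {
  // ... existing node tables (e.g. LogicalAndNode) ...
  BitwiseAndNode
}
```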

  2. Regenerate serialization code

    python backends/mlx/serialization/generate.py
  3. Add C++ runtime exec function in backends/mlx/runtime/MLXInterpreter.h

    inline void exec_bitwise_and(
        const BitwiseAndNode& n, ExecutionState& st, StreamOrDevice s) {
      auto a = st.get_tensor(n.a());
      auto b = st.get_tensor(n.b());
      auto out = mlx::core::bitwise_and(a, b, s);
      st.set_tensor(n.out(), out);
    }
  4. Update handler in backends/mlx/ops.py

    Add to the _BINARY_OPS table:

    # In _BINARY_OPS list, add:
    ([torch.ops.aten.bitwise_and.Tensor, torch.ops.aten.bitwise_and.Scalar], BitwiseAndNode, "aten.bitwise_and", True),

    Note: you'll also need to update or remove the existing _logical_and_handler, which currently claims bitwise_and.Tensor with a bool-only restriction. The new table entry covers integer types; you may keep a separate bool-only handler that dispatches to LogicalAndNode for efficiency.
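One way to structure the resulting dispatch, shown here as a plain standalone function rather than the real handler (the actual code goes through MLXProgramBuilder, whose API is not reproduced here; pick_bitwise_and_node is a hypothetical name for illustration):

```python
import torch

# Integer dtypes that MLX's bitwise_and supports natively.
_INT_DTYPES = {torch.int8, torch.int16, torch.int32, torch.int64, torch.uint8}

def pick_bitwise_and_node(dtype: torch.dtype) -> str:
    """Choose the schema node for an aten.bitwise_and input dtype.

    Bool keeps the existing LogicalAndNode path; integers go to the
    new BitwiseAndNode; anything else still raises, so unsupported
    dtypes fall back to CPU as before.
    """
    if dtype == torch.bool:
        return "LogicalAndNode"
    if dtype in _INT_DTYPES:
        return "BitwiseAndNode"
    raise NotImplementedError(
        f"aten.bitwise_and on dtype {dtype} is not supported"
    )

print(pick_bitwise_and_node(torch.int32))  # BitwiseAndNode
print(pick_bitwise_and_node(torch.bool))   # LogicalAndNode
```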

  5. Add test in backends/mlx/test/test_ops.py

    Use the _BINARY_OP_TESTS table with integer inputs:

    # Add to _BINARY_OP_TESTS list:
    {"op_name": "bitwise_and_int", "op_fn": torch.bitwise_and, 
     "shapes": _SHAPES_3, "dtypes": [torch.int32, torch.int64], 
     "input_fn_a": _int_input_fn(0, 256), "input_fn_b": _int_input_fn(0, 256)},

Running tests

python -m executorch.backends.mlx.test.run_all_tests -k bitwise_and

References

  • MLX C++: array bitwise_and(const array &a, const array &b, StreamOrDevice s = {})
  • PyTorch signature: bitwise_and(Tensor self, Tensor other) -> Tensor
  • Supported dtypes: int8, int16, int32, int64, uint8, bool
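A quick check that eager PyTorch accepts each dtype in the supported list, useful when filling out the test matrix:

```python
import torch

SUPPORTED = [torch.int8, torch.int16, torch.int32, torch.int64,
             torch.uint8, torch.bool]

for dtype in SUPPORTED:
    a = torch.tensor([1, 0], dtype=dtype)
    b = torch.tensor([1, 1], dtype=dtype)
    out = torch.bitwise_and(a, b)
    # Same-dtype inputs should produce an output of that dtype.
    assert out.dtype == dtype, dtype
print("all listed dtypes accepted by eager bitwise_and")
```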

Alternatives

No response

Additional context

No response

RFC (Optional)

No response

Metadata

Assignees

No one assigned

    Labels

    good first issue (Good for newcomers), triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)
