🚀 The feature, motivation and pitch
Good First Issue: Add Full Integer Support for aten.bitwise_and
Summary
Extend aten.bitwise_and support in the MLX delegate to handle integer tensors, not just boolean tensors. Currently the handler supports only the bool dtype and falls back to CPU for integer inputs.
Background
The MLX delegate currently has a bitwise_and handler that only supports boolean tensors (dispatching to LogicalAndNode). However, MLX has native support for bitwise operations on integers via mlx::core::bitwise_and.
Current limitation (in ops.py):
@REGISTRY.register(
    target=[torch.ops.aten.logical_and.default, torch.ops.aten.bitwise_and.Tensor]
)
def _logical_and_handler(P: MLXProgramBuilder, n: Node) -> Slot:
    ...
    # bitwise_and is only equivalent to logical_and for bool tensors.
    if n.target == torch.ops.aten.bitwise_and.Tensor:
        if dtype.dtype != torch.bool:
            raise NotImplementedError(
                f"aten.bitwise_and on non-bool dtype {dtype.dtype} is not supported"
            )
Approach: New schema node + runtime
Add a BitwiseAndNode to handle integer types via MLX's bitwise_and.
Steps
1. Add the node to backends/mlx/serialization/schema.fbs:

   table BitwiseAndNode {
     a: Tid;
     b: Tid;
     out: Tid;
   }

   Add BitwiseAndNode to the OpNode union (append only, do not reorder).
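Why append-only matters can be sketched in Python (the enum below is a hypothetical subset, not the real generated code): FlatBuffers assigns each union member a numeric type tag in declaration order, so reordering would silently re-map nodes in already-serialized programs.

```python
from enum import IntEnum

# Hypothetical subset of the generated OpNode union tag (not the real code):
# each member's tag is its declaration position, so appending a new node
# at the end leaves every existing tag value untouched.
class OpNodeTag(IntEnum):
    NONE = 0
    LogicalAndNode = 1
    BitwiseAndNode = 2  # appended last; earlier tags unchanged

# A program serialized before this change still decodes tag 1 correctly.
assert OpNodeTag(1) is OpNodeTag.LogicalAndNode
assert OpNodeTag.BitwiseAndNode == 2
```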
2. Regenerate the serialization code:

   python backends/mlx/serialization/generate.py
3. Add a C++ runtime exec function in backends/mlx/runtime/MLXInterpreter.h:

   inline void exec_bitwise_and(
       const BitwiseAndNode& n, ExecutionState& st, StreamOrDevice s) {
     auto a = st.get_tensor(n.a());
     auto b = st.get_tensor(n.b());
     auto out = mlx::core::bitwise_and(a, b, s);
     st.set_tensor(n.out(), out);
   }
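The exec function's data flow can be mirrored in a few lines of Python (a sketch only: plain ints stand in for MLX arrays, a dict for the real ExecutionState, and the node is a dict rather than a FlatBuffers table):

```python
# Python mirror of exec_bitwise_and's data flow. The real runtime calls
# mlx::core::bitwise_and on arrays; here `&` on ints stands in for it.
class ExecutionState:
    def __init__(self):
        self.tensors = {}  # tensor id -> value

    def get_tensor(self, tid):
        return self.tensors[tid]

    def set_tensor(self, tid, value):
        self.tensors[tid] = value

def exec_bitwise_and(n, st):
    a = st.get_tensor(n["a"])
    b = st.get_tensor(n["b"])
    st.set_tensor(n["out"], a & b)

st = ExecutionState()
st.set_tensor(0, 0b1100)
st.set_tensor(1, 0b1010)
exec_bitwise_and({"a": 0, "b": 1, "out": 2}, st)
assert st.get_tensor(2) == 0b1000
```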
4. Update the handler in backends/mlx/ops.py. Add to the _BINARY_OPS table:

   # In the _BINARY_OPS list, add:
   ([torch.ops.aten.bitwise_and.Tensor, torch.ops.aten.bitwise_and.Scalar],
    BitwiseAndNode, "aten.bitwise_and", True),

   Note: you'll also need to update or remove the existing _logical_and_handler, which currently handles bitwise_and.Tensor with a bool-only restriction. The table entry will handle integer types; you may keep a separate handler that dispatches bool inputs to LogicalAndNode for efficiency.
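One way to structure that dtype split can be sketched as follows (dtype names as strings and node names as placeholders; the real handler works with torch dtypes and the schema node classes):

```python
# Sketch of the suggested bool/int split (names are placeholders):
# bool keeps the existing logical path, since bitwise == logical for bools;
# integer dtypes route to the new BitwiseAndNode.
_INT_DTYPES = {"int8", "int16", "int32", "int64", "uint8"}

def pick_node(dtype_name):
    if dtype_name == "bool":
        return "LogicalAndNode"
    if dtype_name in _INT_DTYPES:
        return "BitwiseAndNode"
    raise NotImplementedError(
        f"aten.bitwise_and on dtype {dtype_name} is not supported"
    )

assert pick_node("bool") == "LogicalAndNode"
assert pick_node("int32") == "BitwiseAndNode"
```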
5. Add a test in backends/mlx/test/test_ops.py using the _BINARY_OP_TESTS table with integer inputs:

   # Add to the _BINARY_OP_TESTS list:
   {"op_name": "bitwise_and_int", "op_fn": torch.bitwise_and,
    "shapes": _SHAPES_3, "dtypes": [torch.int32, torch.int64],
    "input_fn_a": _int_input_fn(0, 256), "input_fn_b": _int_input_fn(0, 256)},
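For intuition on what the table-driven test ultimately asserts, a pure-Python reference (not the actual harness; _int_input_fn and the delegate-vs-eager comparison machinery live in the repo):

```python
# Pure-Python reference for torch.bitwise_and on integers: plain `&` on
# values in [0, 256), the range the _int_input_fn arguments above draw from.
def reference_bitwise_and(a, b):
    return a & b

samples = [(0, 0), (1, 2), (0b1100, 0b1010), (255, 128)]
results = [reference_bitwise_and(a, b) for a, b in samples]
assert results == [0, 0, 0b1000, 128]
```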
Running tests
python -m executorch.backends.mlx.test.run_all_tests -k bitwise_and
References
- MLX C++: array bitwise_and(const array &a, const array &b, StreamOrDevice s = {})
- PyTorch signature: bitwise_and(Tensor self, Tensor other) -> Tensor
- Supported dtypes: int8, int16, int32, int64, uint8, bool
Alternatives
No response
Additional context
No response
RFC (Optional)
No response