Good First Issue: Add MLX Op Handler for aten.roll #18919

@metascroy

Description

Summary

Add support for aten.roll in the MLX delegate. This op shifts tensor elements along specified dimensions with wrap-around and is needed by Swin Transformer's shift-window attention mechanism.
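For reference, the wrap-around semantics of aten.roll can be checked directly in plain PyTorch (no ExecuTorch dependencies):

```python
import torch

x = torch.arange(5)                 # tensor([0, 1, 2, 3, 4])
# Elements shifted past the end wrap around to the front.
print(torch.roll(x, 2))             # tensor([3, 4, 0, 1, 2])

m = torch.arange(6).reshape(2, 3)
# Multiple axes can be rolled in one call, with independent shifts.
print(torch.roll(m, shifts=(1, -1), dims=(0, 1)))
```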

Background

The MLX delegate converts PyTorch aten ops into MLX graph nodes during export. While aten.roll decomposes into index_select + arange + cat operations, a native MLX implementation using mlx::core::roll would be more efficient (single kernel vs multiple ops).
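To make the decomposition concrete: a roll along one axis is equivalent to slicing the tensor at the wrap point and concatenating the two pieces. This is the multi-op representation a native mlx::core::roll avoids (`roll_via_cat` below is purely illustrative, not part of the delegate):

```python
import torch

def roll_via_cat(x: torch.Tensor, shift: int, dim: int) -> torch.Tensor:
    """Decomposed roll: split at the wrap point and concatenate."""
    n = x.shape[dim]
    shift = shift % n
    if shift == 0:
        return x.clone()
    tail = x.narrow(dim, n - shift, shift)  # elements that wrap to the front
    head = x.narrow(dim, 0, n - shift)
    return torch.cat([tail, head], dim=dim)

x = torch.arange(6).reshape(2, 3)
assert torch.equal(roll_via_cat(x, 1, 1), torch.roll(x, 1, 1))
```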

Approach: New schema node + runtime

This requires a new RollNode because MLX has a dedicated roll function that's more efficient than the decomposed representation.

Steps

  1. Add node to backends/mlx/serialization/schema.fbs

    table RollNode {
      x: Tid;
      out: Tid;
      shift: [IntOrVid];  // shift amounts per axis
      axes: [int];        // axes to roll
    }

    Add RollNode to the OpNode union (append only, do not reorder).

  2. Regenerate serialization code

    python backends/mlx/serialization/generate.py
  3. Add C++ runtime exec function in backends/mlx/runtime/MLXInterpreter.h

    void exec_RollNode(const RollNode& node) {
      auto x = get_tensor(node.x());

      // Shift amounts may be literal ints or symbolic values (Vid),
      // so resolve each one at execution time.
      std::vector<int> shifts;
      for (auto s : *node.shift()) {
        shifts.push_back(resolve_int_or_vid(s));
      }

      std::vector<int> axes;
      for (auto a : *node.axes()) {
        axes.push_back(a);
      }

      auto out = mlx::core::roll(x, shifts, axes, stream_);
      set_tensor(node.out(), out);
    }
  4. Add handler in backends/mlx/ops.py

    Although roll maps a single input tensor to a single output, it carries extra parameters (shifts, dims), so it does not fit the standard unary signature (x, out) used by the _UNARY_OPS table. Register a custom handler instead:

    @REGISTRY.register(target=[torch.ops.aten.roll.default])
    def _roll_handler(P: MLXProgramBuilder, n: Node) -> Slot:
        args = P.args(n)
        require_args(args, 2, 3, "aten.roll")
        require_kwargs(P.kwargs(n), set(), "aten.roll")
        x = args[0]
        shifts = args[1]
        dims = args[2] if len(args) > 2 else []
        
        # Normalize shifts and dims to lists
        if isinstance(shifts, int):
            shifts = [shifts]
        if isinstance(dims, int):
            dims = [dims]
        if not dims:
            # torch.roll with no dims flattens the tensor, rolls the
            # flattened view, and reshapes back -- this is NOT the same
            # as rolling along every axis. Reject it explicitly rather
            # than silently producing wrong results.
            raise NotImplementedError("aten.roll without dims is not lowered yet")
        
        out = P.make_or_get_slot(n)
        P.emit(
            RollNode(
                x=P.slot_to_tid(x),
                out=P.slot_to_tid(out),
                shift=[P.to_int_or_vid(s) for s in shifts],
                axes=list(dims),
            )
        )
        return out
  5. Add test in backends/mlx/test/test_ops.py

    This op doesn't fit the simple unary pattern, so create a custom test class:

    class RollModel(nn.Module):
        def __init__(self, shifts: int, dims: int):
            super().__init__()
            self.shifts = shifts
            self.dims = dims
        
        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return torch.roll(x, shifts=self.shifts, dims=self.dims)
    
    @register_test
    class RollTest(OpTestCase):
        name = "roll"
        
        def __init__(self, shape: Tuple[int, ...], shifts: int, dims: int):
            self.shape = shape
            self.shifts = shifts
            self.dims = dims
            self.name = f"roll_shift{shifts}_dim{dims}"
        
        @classmethod
        def get_test_configs(cls) -> List["RollTest"]:
            return [
                cls(shape=(8,), shifts=2, dims=0),
                cls(shape=(4, 5), shifts=1, dims=0),
                cls(shape=(4, 5), shifts=-2, dims=1),
                cls(shape=(3, 4, 5), shifts=3, dims=2),
            ]
        
        def create_model(self) -> nn.Module:
            return RollModel(self.shifts, self.dims)
        
        def create_inputs(self) -> Tuple[torch.Tensor, ...]:
            return (torch.randn(self.shape),)
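One subtlety the test configs above do not cover: torch.roll with dims omitted flattens the tensor before rolling, which is not the same as rolling along every axis. A standalone check of the difference:

```python
import torch

x = torch.arange(6).reshape(2, 3)
flat = torch.roll(x, 1)                     # flatten, roll, reshape
per_axis = torch.roll(x, (1, 1), (0, 1))    # roll each axis independently
assert not torch.equal(flat, per_axis)
assert torch.equal(flat, torch.tensor([[5, 0, 1], [2, 3, 4]]))
```

If the handler only supports explicit dims, the no-dims form should be rejected (or decomposed via reshape) rather than mapped to a per-axis roll.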

Running tests

python -m executorch.backends.mlx.test.run_all_tests -k roll

References

  • MLX C++: array roll(const array &a, int shift, int axis, StreamOrDevice s = {})
  • PyTorch signature: roll(Tensor self, int[1] shifts, int[1] dims=[]) -> Tensor
  • Use case: Swin Transformer shift-window attention

Metadata

Labels: good first issue, triaged