From bcbf705c22ab755202601af1b1bb414b1b77e194 Mon Sep 17 00:00:00 2001 From: Ryan Monroe Date: Tue, 14 Apr 2026 15:05:01 -0700 Subject: [PATCH 1/2] Add FuseConcatPass to eliminate redundant concat ops (#18827) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: Concat (torch.cat) in the Gen2 Executorch ARM/Ethos-U stack is lowered to TOSA CONCAT, which Vela then converts to N x MemoryCopy operations — real DMA data movement on the NPU. This pass eliminates concat operations that can be proven unnecessary at the FX graph level, preventing Vela from generating MemoryCopy ops entirely. Inspired by Espresso's concat elimination techniques (bolt/nn/espresso/transforms/remove_nops.py), five patterns are handled: 1. Single-input concat: cat([x]) is a no-op, replaced with x. 2. Concat-then-slice: if every consumer of cat([a, b, ...]) is a slice_copy that extracts exactly one original input, bypass both. 3. Slice-then-concat: if contiguous slices of the same tensor are concatenated back, the result is the original tensor. 4. Concat-then-sub-slice: if a consumer's slice range falls entirely within one original input, re-slice that input directly. 5. Slice-then-concat (partial): if contiguous slices of the same tensor cover only a sub-range of the source dimension, replace with a single slice of the source. 
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# Copyright 2026 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import logging
from typing import Set, Type

import torch.fx
from executorch.backends.arm._passes import ArmPass
from executorch.backends.arm._passes.arm_pass_utils import get_first_fake_tensor
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, PassResult

logger = logging.getLogger(__name__)

_SLICE_OP = exir_ops.edge.aten.slice_copy.Tensor


def _int_arg(node: torch.fx.Node, index: int, default: int) -> int:
    """Get an integer argument from a node, with a default if missing.

    An explicit ``None`` argument is treated the same as a missing one:
    ``aten.slice_copy`` declares ``start``/``end`` as optional, so a graph may
    legally carry ``None`` to mean "use the schema default".
    """
    val = node.args[index] if len(node.args) > index else default
    if val is None:
        # Fix: previously an explicit None tripped the assert below even
        # though the aten schema permits it.
        val = default
    assert isinstance(val, int)
    return val


def _slice_params(node: torch.fx.Node, dim_size: int) -> tuple[int, int, int, int]:
    """Extract (dim, start, end, step) from a slice_copy node.

    ``dim`` is normalized to a positive index. ``start`` and ``end`` follow
    aten slice semantics: negative values count from the end of the
    dimension, and both are clamped to ``[0, dim_size]`` (``dim_size`` is the
    size of the source tensor along the slice dimension).
    """
    rank = len(get_first_fake_tensor(node).shape)
    dim = _int_arg(node, 1, 0)
    dim = (dim + rank) % rank
    start = _int_arg(node, 2, 0)
    end = _int_arg(node, 3, dim_size)
    # Fix: normalize negative indices (e.g. x[:, :-1] carries end=-1) before
    # they are compared against non-negative input offsets. Previously a
    # negative end leaked through and could match the wrong range in
    # _find_slice_replacement.
    if start < 0:
        start += dim_size
    if end < 0:
        end += dim_size
    start = min(max(start, 0), dim_size)
    end = min(max(end, 0), dim_size)
    step = _int_arg(node, 4, 1)
    return dim, start, end, step


def _is_valid_slice(node: torch.fx.Node, cat_dim: int) -> bool:
    """Check that node is a slice_copy on cat_dim with step=1.

    Only ``dim`` and ``step`` are inspected, so no source-dimension size is
    needed here; start/end validation happens in the callers, which know the
    relevant dimension size.
    """
    if node.target != _SLICE_OP:
        return False
    rank = len(get_first_fake_tensor(node).shape)
    s_dim = (_int_arg(node, 1, 0) + rank) % rank
    s_step = _int_arg(node, 4, 1)
    return s_dim == cat_dim and s_step == 1


def _find_slice_replacement(
    slice_op: torch.fx.Node,
    cat_node: torch.fx.Node,
    cat_dim: int,
    s_start: int,
    s_end: int,
    offsets: list[tuple[int, int, torch.fx.Node]],
) -> torch.fx.Node | None:
    """Find a replacement for a slice that consumes a cat output.

    ``offsets`` maps each concat input to its range in the concatenated
    output: [(start, end, input_node), ...] along ``cat_dim``.

    Returns the replacement node (exact input match or adjusted sub-slice),
    or None if the slice crosses input boundaries.
    """
    for o_start, o_end, inp in offsets:
        if s_start == o_start and s_end == o_end:
            # Exact match: the slice reproduces one concat input verbatim.
            return inp
        if s_start >= o_start and s_end <= o_end:
            # Sub-range: re-slice the single input the range falls into,
            # shifting start/end from cat-output coordinates to input
            # coordinates.
            graph = cat_node.graph
            with graph.inserting_before(slice_op):
                new_slice = graph.call_function(
                    _SLICE_OP,
                    (inp, cat_dim, s_start - o_start, s_end - o_start),
                )
                new_slice.meta = slice_op.meta.copy()
            return new_slice
    return None


def _find_common_slice_source(
    cat_inputs: list | tuple,
    cat_dim: int,
) -> torch.fx.Node | None:
    """Check all inputs are valid slices (step=1, on cat_dim) of one source.

    Returns the shared source node, or None if any input is not such a
    slice or the inputs slice different tensors.
    """
    source_node = None
    for inp in cat_inputs:
        if not isinstance(inp, torch.fx.Node):
            return None
        if not _is_valid_slice(inp, cat_dim):
            return None
        slice_source = inp.args[0]
        if source_node is None:
            source_node = slice_source
        elif slice_source is not source_node:
            return None
    # Callers guarantee at least two inputs, so a source was found.
    assert isinstance(source_node, torch.fx.Node)
    return source_node


def _check_contiguous_slices(
    cat_inputs: list | tuple,
    source_dim_size: int,
) -> tuple[int, int] | None:
    """Check slices are contiguous (each starts where the previous ended).

    Returns (first_start, last_end) or None if there is a gap or overlap.
    """
    _, first_start, _, _ = _slice_params(cat_inputs[0], source_dim_size)
    expected_start = first_start
    for inp in cat_inputs:
        _, s_start, s_end, _ = _slice_params(inp, source_dim_size)
        if s_start != expected_start:
            return None
        expected_start = s_end

    # expected_start is now the end of the last slice
    return first_start, expected_start


class FuseConcatPass(ArmPass):
    """Eliminate redundant concat (cat) operations via graph pattern matching.

    Inspired by Espresso's concat elimination techniques
    (bolt/nn/espresso/transforms/remove_nops.py), this pass recognizes and
    removes concat operations that can be proven to produce no useful data
    movement. Eliminating these at the FX/TOSA level prevents Vela from
    generating MemoryCopy operations on the Ethos-U NPU.

    Five patterns are handled:

    1. Single-input concat: cat([x], dim) is a no-op; replace with x.
    2. Concat-then-slice (exact): if a consumer of cat([a, b, ...], dim) is
       a slice_copy that extracts exactly one original input, replace it
       with the corresponding concat input directly.
    3. Slice-then-concat (full): if cat([slice(x, d, s0, e0),
       slice(x, d, s1, e1), ...], dim) reconstructs x exactly (contiguous
       slices covering the full source dimension), replace with x.
    4. Concat-then-sub-slice: if a consumer of cat([a, b, ...], dim) is a
       slice_copy whose range falls entirely within one original input,
       replace it with an adjusted slice on that input directly.
    5. Slice-then-concat (partial): if contiguous slices of the same tensor
       are concatenated but cover only a sub-range of the source dimension,
       replace with a single slice on the source.
    """

    _passes_required_after: Set[Type[ExportPass]] = set()

    cat_ops = {
        exir_ops.edge.aten.cat.default,
    }
    slice_op = _SLICE_OP

    def call(self, graph_module: torch.fx.GraphModule):
        """Apply all concat-elimination patterns to ``graph_module``."""
        modified = False
        graph = graph_module.graph

        for node in list(graph.nodes):
            if node.op != "call_function" or node.target not in self.cat_ops:
                continue
            if node.graph is None:
                # Defensive: skip nodes detached by an earlier rewrite.
                continue

            if self._eliminate_single_input_cat(node):
                modified = True
                continue

            if self._eliminate_cat_then_slice(node):
                modified = True
                continue

            if self._eliminate_slice_then_cat(node):
                modified = True
                continue

        if modified:
            # Bypassed cat/slice nodes are now unused; drop them and retrace.
            graph.eliminate_dead_code()
            graph_module.recompile()
            graph_module = super().call(graph_module).graph_module

        return PassResult(graph_module, modified)

    # ------------------------------------------------------------------
    # Pattern 1: single-input cat
    # ------------------------------------------------------------------
    @staticmethod
    def _eliminate_single_input_cat(cat_node: torch.fx.Node) -> bool:
        """Replace cat([x], dim) with x. Returns True if rewritten."""
        inputs = cat_node.args[0]
        if not isinstance(inputs, (list, tuple)) or len(inputs) != 1:
            return False
        sole_input = inputs[0]
        assert isinstance(sole_input, torch.fx.Node)
        cat_node.replace_all_uses_with(sole_input)
        logger.debug("Eliminated single-input cat: %s", cat_node.name)
        return True

    # ------------------------------------------------------------------
    # Patterns 2 & 4: cat -> slice (exact input or sub-range of input)
    # ------------------------------------------------------------------
    @staticmethod
    def _eliminate_cat_then_slice(
        cat_node: torch.fx.Node,
    ) -> bool:
        """Redirect slice consumers of a cat to the original inputs.

        Returns True if at least one consumer was redirected.
        """
        cat_inputs = cat_node.args[0]
        if not isinstance(cat_inputs, (list, tuple)) or len(cat_inputs) < 2:
            return False

        # if the dim does not exist as an arg, it defaults to '0'
        cat_dim = _int_arg(cat_node, 1, 0)
        output_rank = len(get_first_fake_tensor(cat_node).shape)
        cat_dim = (cat_dim + output_rank) % output_rank

        users = list(cat_node.users.keys())
        if not users:
            return False

        # Build the offset map for each concat input along cat_dim.
        offsets = []
        offset = 0
        for inp in cat_inputs:
            assert isinstance(inp, torch.fx.Node)
            inp_shape = get_first_fake_tensor(inp).shape
            size = inp_shape[cat_dim]
            offsets.append((offset, offset + size, inp))
            offset += size

        # Every user must be a slice_copy on the same dim with step=1.
        # Collect validated (node, start, end) for replacement below.
        # ``offset`` now holds the full concatenated size along cat_dim,
        # which is the dim_size the slices are taken against.
        validated_slices: list[tuple[torch.fx.Node, int, int]] = []
        for slice_op in users:
            if not _is_valid_slice(slice_op, cat_dim):
                return False
            if slice_op.args[0] is not cat_node:
                return False
            _, s_start, s_end, _ = _slice_params(slice_op, offset)
            validated_slices.append((slice_op, s_start, s_end))

        # For each user, try exact match (Pattern 2) then sub-range (Pattern 4).
        # Users that cross input boundaries are skipped (the cat survives for
        # them and dead-code elimination leaves it in place).
        replacements: list[tuple[torch.fx.Node, torch.fx.Node]] = []

        for slice_op, s_start, s_end in validated_slices:
            replacement = _find_slice_replacement(
                slice_op, cat_node, cat_dim, s_start, s_end, offsets
            )
            if replacement is not None:
                replacements.append((slice_op, replacement))

        if not replacements:
            return False

        for old_node, new_node in replacements:
            old_node.replace_all_uses_with(new_node)

        logger.debug(
            "Eliminated cat-then-slice pattern: %s (%d slices redirected)",
            cat_node.name,
            len(replacements),
        )
        return True

    # ------------------------------------------------------------------
    # Patterns 3 & 5: slice -> cat (contiguous slices, full or partial)
    # ------------------------------------------------------------------
    @staticmethod
    def _eliminate_slice_then_cat(
        cat_node: torch.fx.Node,
    ) -> bool:
        """Collapse a cat of contiguous slices of one tensor.

        Returns True if the cat was replaced by the source tensor (full
        coverage) or by a single slice of it (partial coverage).
        """
        cat_inputs = cat_node.args[0]
        if not isinstance(cat_inputs, (list, tuple)) or len(cat_inputs) < 2:
            return False

        cat_dim = _int_arg(cat_node, 1, 0)
        output_rank = len(get_first_fake_tensor(cat_node).shape)
        cat_dim = (cat_dim + output_rank) % output_rank

        # All inputs must be slice_copy on the same source tensor and dim,
        # with step=1. (Fix: the previous signature took a dim_size and was
        # handed the tensor *rank* here; validity only needs dim and step.)
        source_node = _find_common_slice_source(cat_inputs, cat_dim)
        if source_node is None:
            return False

        source_shape = get_first_fake_tensor(source_node).shape
        source_dim_size = source_shape[cat_dim]

        # Verify slices are contiguous (but not necessarily starting at 0).
        bounds = _check_contiguous_slices(cat_inputs, source_dim_size)
        if bounds is None:
            return False
        first_start, last_end = bounds

        # Verify output shape matches expectations.
        cat_shape = get_first_fake_tensor(cat_node).shape

        if first_start == 0 and last_end == source_dim_size:
            # Pattern 3: full coverage — replace with source tensor.
            if list(cat_shape) != list(source_shape):
                return False
            cat_node.replace_all_uses_with(source_node)
            logger.debug(
                "Eliminated slice-then-cat (full): %s -> %s",
                cat_node.name,
                source_node.name,
            )
        else:
            # Pattern 5: partial coverage — replace with single slice.
            expected_dim_size = last_end - first_start
            if cat_shape[cat_dim] != expected_dim_size:
                return False
            for i, (cs, ss) in enumerate(zip(cat_shape, source_shape)):
                if i != cat_dim and cs != ss:  # dims must match except for cat_dim
                    return False
            graph = cat_node.graph
            with graph.inserting_before(cat_node):
                new_slice = graph.call_function(
                    _SLICE_OP,
                    (source_node, cat_dim, first_start, last_end),
                )
                new_slice.meta = cat_node.meta.copy()
            cat_node.replace_all_uses_with(new_slice)
            logger.debug(
                "Eliminated slice-then-cat (partial): %s -> slice(%s, %d, %d:%d)",
                cat_node.name,
                source_node.name,
                cat_dim,
                first_start,
                last_end,
            )
        return True
import torch
from executorch.backends.arm._passes.fuse_concat_pass import FuseConcatPass
from executorch.backends.arm.test import common
from executorch.backends.arm.test.tester.test_pipeline import PassPipeline


cat_op = "executorch_exir_dialects_edge__ops_aten_cat_default"
slice_op = "executorch_exir_dialects_edge__ops_aten_slice_copy_Tensor"


class SingleInputCat(torch.nn.Module):
    """Pattern 1: a cat with one input moves no data; the pass forwards x."""

    data = (torch.randn(2, 3, 4),)
    ops_before_pass = {cat_op: 1}
    ops_after_pass: dict = {}
    ops_not_after_pass = [cat_op]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.cat([x], dim=0)


class CatThenSlice(torch.nn.Module):
    """Pattern 2: every consumer slices exactly one cat input back out."""

    data = (torch.randn(1, 3, 4), torch.randn(1, 5, 4))
    ops_before_pass = {cat_op: 1, slice_op: 2}
    ops_after_pass: dict = {}
    ops_not_after_pass = [cat_op, slice_op]

    def forward(self, a: torch.Tensor, b: torch.Tensor) -> tuple[torch.Tensor, ...]:
        combined = torch.cat([a, b], dim=1)
        # The two slices recover a and b exactly, so both cat and slices go.
        part_a = combined[:, :3, :]
        part_b = combined[:, 3:, :]
        return part_a + 1, part_b + 1


class SliceThenCat(torch.nn.Module):
    """Pattern 3: contiguous slices of x concatenated back into x."""

    data = (torch.randn(1, 8, 4),)
    ops_before_pass = {cat_op: 1, slice_op: 2}
    ops_after_pass: dict = {}
    ops_not_after_pass = [cat_op, slice_op]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        a = x[:, :3, :]
        b = x[:, 3:, :]
        return torch.cat([a, b], dim=1)


class CatNotEliminated(torch.nn.Module):
    """Negative: a genuine cat of two distinct tensors must survive."""

    data = (torch.randn(1, 3, 4), torch.randn(1, 5, 4))
    ops_before_pass = {cat_op: 1}
    ops_after_pass = {cat_op: 1}

    def forward(self, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
        return torch.cat([a, b], dim=1)


class SliceThenCatPartial(torch.nn.Module):
    """Negative: slices with a gap are not a reconstruction of x."""

    data = (torch.randn(1, 8, 4),)
    ops_before_pass = {cat_op: 1}
    ops_after_pass = {cat_op: 1}

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        a = x[:, :3, :]
        b = x[:, 4:, :]  # Gap at index 3
        return torch.cat([a, b], dim=1)


class CatThenSliceMismatch(torch.nn.Module):
    """Negative: a slice spanning the input boundary cannot be bypassed."""

    data = (torch.randn(1, 3, 4), torch.randn(1, 5, 4))
    ops_before_pass = {cat_op: 1}
    ops_after_pass = {cat_op: 1}

    def forward(self, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
        combined = torch.cat([a, b], dim=1)
        return combined[:, 1:5, :]  # Crosses the boundary


class CatThenSliceWithStep(torch.nn.Module):
    """Negative: a strided slice changes the data layout; keep the cat."""

    data = (torch.randn(1, 3, 4), torch.randn(1, 5, 4))
    ops_before_pass = {cat_op: 1}
    ops_after_pass = {cat_op: 1}

    def forward(self, a: torch.Tensor, b: torch.Tensor) -> tuple[torch.Tensor, ...]:
        combined = torch.cat([a, b], dim=1)
        part_a = combined[:, :3:2, :]  # step=2, output shape differs from a
        part_b = combined[:, 3::1, :]
        return part_a + 1, part_b + 1


class CatThenSubSlice(torch.nn.Module):
    """Pattern 4: the slice stays inside the first input's range."""

    data = (torch.randn(1, 6, 4), torch.randn(1, 4, 4))
    ops_before_pass = {cat_op: 1, slice_op: 1}
    ops_after_pass = {slice_op: 1}
    ops_not_after_pass = [cat_op]

    def forward(self, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
        combined = torch.cat([a, b], dim=1)  # a dim1=6, b dim1=4
        # Range [1,5) falls entirely within a's range [0,6)
        return combined[:, 1:5, :] + 1


class CatThenSubSliceSecondInput(torch.nn.Module):
    """Pattern 4 on the second input, exercising the offset shift."""

    data = (torch.randn(1, 3, 4), torch.randn(1, 8, 4))
    ops_before_pass = {cat_op: 1, slice_op: 1}
    ops_after_pass = {slice_op: 1}
    ops_not_after_pass = [cat_op]

    def forward(self, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
        combined = torch.cat([a, b], dim=1)  # a dim1=3, b dim1=8
        # Range [5,9) falls within b's range [3,11), adjusted to [2,6) on b
        return combined[:, 5:9, :] + 1


class SliceThenCatPartialContiguous(torch.nn.Module):
    """Pattern 5: contiguous slices covering only part of the dimension."""

    data = (torch.randn(1, 10, 4),)
    ops_before_pass = {cat_op: 1, slice_op: 2}
    ops_after_pass = {slice_op: 1}
    ops_not_after_pass = [cat_op]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        a = x[:, 2:5, :]
        b = x[:, 5:8, :]
        return torch.cat([a, b], dim=1)  # Equivalent to x[:, 2:8, :]


# Models whose cat must be removed (order matters for test ids).
positive_tests = {
    "single_input_cat": SingleInputCat(),
    "cat_then_slice": CatThenSlice(),
    "slice_then_cat": SliceThenCat(),
    "cat_then_sub_slice": CatThenSubSlice(),
    "cat_then_sub_slice_second_input": CatThenSubSliceSecondInput(),
    "slice_then_cat_partial_contiguous": SliceThenCatPartialContiguous(),
}

# Models whose cat must be left untouched.
negative_tests = {
    "cat_not_eliminated": CatNotEliminated(),
    "slice_then_cat_partial": SliceThenCatPartial(),
    "cat_then_slice_mismatch": CatThenSliceMismatch(),
    "cat_then_slice_with_step": CatThenSliceWithStep(),
}


@common.parametrize("model", positive_tests)
def test_fuse_concat_eliminates(model):
    """The pass removes the redundant cat (and bypassed slices)."""
    PassPipeline(
        model,
        model.data,
        quantize=False,
        ops_before_pass=model.ops_before_pass,
        ops_after_pass=model.ops_after_pass,
        ops_not_after_pass=getattr(model, "ops_not_after_pass", []),
        pass_list=[FuseConcatPass],
    ).run()


@common.parametrize("model", negative_tests)
def test_fuse_concat_preserves(model):
    """The pass leaves non-redundant cats in place."""
    PassPipeline(
        model,
        model.data,
        quantize=False,
        ops_before_pass=model.ops_before_pass,
        ops_after_pass=model.ops_after_pass,
        pass_list=[FuseConcatPass],
    ).run()
-0700 Subject: [PATCH 2/2] jk don't apply passes Differential Revision: D100874753 --- backends/arm/_passes/arm_pass_manager.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/backends/arm/_passes/arm_pass_manager.py b/backends/arm/_passes/arm_pass_manager.py index 1af04b71084..fb3638e3989 100644 --- a/backends/arm/_passes/arm_pass_manager.py +++ b/backends/arm/_passes/arm_pass_manager.py @@ -487,7 +487,7 @@ def _tosa_pipeline( # Aten -> TOSA transformation passes self.add_passes( [ - FuseConcatPass(), + # FuseConcatPass(), RewriteUpsamplePass(), RewriteConvPass(exported_program), RewriteMatmulPass(),