Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
298 changes: 298 additions & 0 deletions backends/cadence/aot/pass_utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
Expand All @@ -19,6 +19,7 @@
from torch._ops import OpOverloadPacket
from torch.fx import Node
from torch.fx.node import Argument
from typing import Hashable

T = TypeVar("T")

Expand Down Expand Up @@ -356,3 +357,300 @@
continue
changed |= self.maybe_remove_or_replace(node)
return changed

class SwapOnCostModelPassInterface(HierarchicalInplacePassInterface):
    """
    Base class for passes that reduce op count by moving wrapper operations
    (e.g. dequant/quant, permute) from the majority side of a target op's
    inputs/outputs to the minority side.

    Given a target op with some inputs wrapped by ``input_to_swap`` and some
    outputs wrapped by ``output_to_swap``, the pass checks whether it is
    cheaper to wrap the *other* inputs/outputs instead. A swap is performed
    only when:

    1. All input wrappers share the same hash (via ``input_hash``).
    2. All output wrappers share the same hash (via ``output_hash``).
    3. The input and output hashes are compatible
       (via ``hashes_are_compatible``).
    4. The number of wrappers after swapping is strictly less than before.

    Subclasses must implement:
    - ``targets``: the ops to scan (e.g. cat, add, mul).
    - ``input_to_swap`` / ``output_to_swap``: the wrapper op targets.
    - ``input_hash`` / ``output_hash``: extract a hashable identity from a
      wrapper node for equality checking.
    - ``hashes_are_compatible``: define the relationship between input and
      output hashes (equality for quant/dequant, inverse for permute).

    Example 1 — dequant/quant around cat::

        # Before (5 wrapper ops: 3 dequants + 2 quants):
        # dequant(A, s, zp) ──┐
        # dequant(B, s, zp) ──┤              ┌── quant(out1, s, zp)
        # dequant(C, s, zp) ──┼── cat(fp32) ─┼── quant(out2, s, zp)
        # D(fp32) ────────────┘              └── consumer(fp32)
        #
        # After (2 wrapper ops: 1 quant + 1 dequant): A, B, C feed cat
        # directly (already quantized), D gets a quant, the cat runs in the
        # quantized domain, and the fp32 consumer gets a dequant:
        # A ───────────────┐
        # B ───────────────┤           ┌── out1 (quantized, no wrapper)
        # C ───────────────┼── cat(q) ─┼── out2 (quantized, no wrapper)
        # quant(D, s, zp) ─┘           └── dequant(out3, s, zp) ── consumer
        #
        # input_to_swap  = dequantize_per_tensor
        # output_to_swap = quantize_per_tensor
        # input_hash = output_hash = (scale, zero_point, quant_min,
        #                             quant_max, dtype)
        # hashes_are_compatible: input_hash == output_hash

    Example 2 — permutes around a binary op (2 wrappers -> 1)::

        # Before:                          After:
        # permute(A, dims) ──┐             A ───────────────────────┐
        #                    ├── add ──    ├── add ── out
        # B ─────────────────┘  permute(   permute(B, inv_dims) ────┘
        #                       out, inv_dims)
        #
        # input_to_swap = output_to_swap = permute_copy
        # input_hash = output_hash = tuple(dims)
        # hashes_are_compatible: input perm followed by output perm is the
        # identity permutation.

    Because targets are visited in graph order, wrappers injected by an
    earlier swap become the inputs/outputs of later targets, so chains of
    targets sharing edges can cascade and eliminate every wrapper in a
    single run of the pass.
    """

    @property
    @abstractmethod
    def targets(self) -> list[EdgeOpOverload]:
        """The ops whose inputs and outputs may be swapped."""
        ...

    @property
    @abstractmethod
    def input_to_swap(self) -> EdgeOpOverload:
        """
        The wrapper op target to match on inputs (e.g. dequantize_per_tensor,
        permute_copy). Inputs to the target op whose target matches this are
        candidates for removal during the swap.
        """
        ...

    @property
    @abstractmethod
    def output_to_swap(self) -> EdgeOpOverload:
        """
        The wrapper op target to match on outputs (e.g. quantize_per_tensor,
        permute_copy). Users of the target op whose target matches this are
        candidates for removal during the swap.
        """
        ...

    @abstractmethod
    def input_hash(self, node: Node) -> Hashable:
        """
        Extract a hashable identity from an input wrapper node. All input
        wrappers must produce the same hash for the swap to be valid.
        E.g. for dequant: (scale, zero_point, quant_min, quant_max, dtype);
        for permute: tuple(dims).
        """
        ...

    @abstractmethod
    def output_hash(self, node: Node) -> Hashable:
        """
        Extract a hashable identity from an output wrapper node. All output
        wrappers must produce the same hash for the swap to be valid.
        """
        ...

    @abstractmethod
    def hashes_are_compatible(
        self, input_hash: Hashable, output_hash: Hashable
    ) -> bool:
        """
        Check whether the input and output wrapper hashes are compatible,
        meaning the swap is semantically legal.
        For quant/dequant: hashes must be equal (same scale, zp, dtype).
        For permute: output dims must be the inverse permutation of input dims.
        """
        ...

    @abstractmethod
    def create_inverse_wrapper_args(
        self, template: Node,
    ) -> tuple[EdgeOpOverload, tuple, dict]:
        """
        Given a wrapper node from one side, return (target, args_tail, kwargs)
        for the inverse wrapper on the other side. args_tail excludes the data
        input (first arg), which is supplied by the caller.

        For quant/dequant: swap the op target, keep the same params.
        For permute: same op target, compute inverse dims.
        """
        ...

    def _partition_side(
        self,
        candidates: list[Node],
        wrapper_target: EdgeOpOverload,
        hash_fn,
    ) -> Optional[tuple[Optional[Hashable], list[Node], list[Node]]]:
        """
        Split ``candidates`` into wrapper nodes (target == ``wrapper_target``)
        and everything else.

        Returns ``(shared_hash, wrapped, unwrapped)``, where ``shared_hash``
        is None when no wrappers were found. Returns None (whole result) when
        the swap is invalid for this side: a wrapper has users other than the
        target op, or two wrappers hash differently (e.g. scale/zero_point of
        one quant/dequant doesn't match another).
        """
        shared_hash: Optional[Hashable] = None
        wrapped: list[Node] = []
        unwrapped: list[Node] = []
        for cand in candidates:
            if cand.target != wrapper_target:
                unwrapped.append(cand)
                continue
            # A wrapper feeding any other consumer cannot be removed safely.
            if len(cand.users) != 1:
                return None
            cand_hash = hash_fn(cand)
            if shared_hash is None:
                shared_hash = cand_hash
            elif shared_hash != cand_hash:
                return None
            wrapped.append(cand)
        return shared_hash, wrapped, unwrapped

    def _wrapper_template(
        self, preferred: list[Node], fallback: list[Node]
    ) -> tuple[EdgeOpOverload, tuple, dict, dict]:
        """
        Derive ``(target, args_tail, kwargs, meta)`` for a new wrapper node.

        Prefer cloning an existing wrapper from the opposite side
        (``preferred``); otherwise compute the inverse of a wrapper from this
        side (``fallback``). At least one of the two lists is non-empty by
        construction (see ``_maybe_swap``).
        """
        if preferred:
            template = preferred[0]
            return (
                template.target,
                template.args[1:],
                dict(template.kwargs),
                template.meta,
            )
        template = fallback[0]
        target, args_tail, kwargs = self.create_inverse_wrapper_args(template)
        return target, args_tail, kwargs, template.meta

    def _maybe_swap(
        self, graph_module: torch.fx.GraphModule, node: Node
    ) -> bool:
        """
        Swap wrappers around ``node`` when legal and strictly cheaper.
        Returns True iff the graph was modified.
        """
        inputs = self._partition_side(
            node.all_input_nodes, self.input_to_swap, self.input_hash
        )
        if inputs is None:
            return False
        input_hash_value, wrapped_inputs, unwrapped_inputs = inputs

        outputs = self._partition_side(
            list(node.users), self.output_to_swap, self.output_hash
        )
        if outputs is None:
            return False
        output_hash_value, wrapped_outputs, unwrapped_outputs = outputs

        # At least one side must have wrappers for a swap to mean anything.
        if input_hash_value is None and output_hash_value is None:
            return False

        # Cross-compatibility between the input and output wrapper hashes.
        if (
            input_hash_value is not None
            and output_hash_value is not None
            and not self.hashes_are_compatible(input_hash_value, output_hash_value)
        ):
            return False

        # Cost model: only swap when the wrapper count strictly decreases.
        wrappers_after = len(unwrapped_inputs) + len(unwrapped_outputs)
        wrappers_before = len(wrapped_inputs) + len(wrapped_outputs)
        if wrappers_after >= wrappers_before:
            return False

        graph = graph_module.graph

        # Inject wrappers on the previously unwrapped inputs.
        if unwrapped_inputs:
            target, args_tail, kwargs, meta = self._wrapper_template(
                wrapped_outputs, wrapped_inputs
            )
            for unwrapped in unwrapped_inputs:
                with graph.inserting_before(node):
                    new_wrapper = graph.call_function(
                        target, args=(unwrapped,) + args_tail, kwargs=kwargs
                    )
                # Copy per node: sharing one meta dict would couple the new
                # wrappers' metadata. NOTE(review): the template's meta (e.g.
                # shape info) may not match this edge — confirm downstream
                # passes recompute it.
                new_wrapper.meta = meta.copy()
                node.replace_input_with(unwrapped, new_wrapper)

        # Inject wrappers on the previously unwrapped outputs.
        if unwrapped_outputs:
            target, args_tail, kwargs, meta = self._wrapper_template(
                wrapped_inputs, wrapped_outputs
            )
            for user in unwrapped_outputs:
                with graph.inserting_before(user):
                    new_wrapper = graph.call_function(
                        target, args=(node,) + args_tail, kwargs=kwargs
                    )
                new_wrapper.meta = meta.copy()
                user.replace_input_with(node, new_wrapper)

        # Bypass the old input wrappers (e.g. remove dequants). The bypassed
        # nodes become dead and are left for dead-code elimination.
        for wrapper in wrapped_inputs:
            node.replace_input_with(wrapper, wrapper.args[0])

        # Bypass the old output wrappers (e.g. remove quants).
        for wrapper in wrapped_outputs:
            wrapper.replace_all_uses_with(node)

        return True

    def _apply_flat_inplace(self, graph_module: torch.fx.GraphModule) -> bool:
        """
        Scan every target op in the graph and perform the swap wherever it
        strictly reduces the wrapper count. Returns True if the graph changed.
        """
        changed = False
        for target in self.targets:
            # Materialize the node list up front: swapping mutates the graph.
            for node in list(
                graph_module.graph.find_nodes(op="call_function", target=target)
            ):
                changed |= self._maybe_swap(graph_module, node)
        return changed
58 changes: 57 additions & 1 deletion backends/cadence/aot/replace_ops.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# Copyright 2025 Arm Limited and/or its affiliates.
Expand All @@ -15,7 +15,7 @@
import math
import operator
from operator import neg
from typing import cast, Dict, Optional, Sequence
from typing import cast, Dict, Hashable, Optional, Sequence

import torch
import torch.fx
Expand All @@ -26,6 +26,7 @@
get_arg,
register_cadence_pass,
RemoveOrReplacePassInterface,
SwapOnCostModelPassInterface,
)
from executorch.backends.cadence.aot.utils import is_depthwise_conv
from executorch.backends.transforms.replace_scalar_with_tensor import (
Expand Down Expand Up @@ -2597,6 +2598,60 @@
return True


@register_cadence_pass(CadencePassAttribute(opt_level=1))
class SwapDequantQuantAroundDataMovementOps(SwapOnCostModelPassInterface):
    """
    Move dequant/quant wrappers across data movement ops (cat, slice_copy)
    to whichever side has fewer of them, reducing the total number of
    quant/dequant ops when all quantization parameters match.
    """

    # Quantization parameters that define a quant/dequant node's identity.
    # Every wrapper in a group must agree on all of them for a swap to be
    # semantically legal.
    _QUANT_PARAMS = (
        ("scale", float),
        ("zero_point", int),
        ("quant_min", int),
        ("quant_max", int),
        ("dtype", torch.dtype),
    )

    @property
    def targets(self) -> list[EdgeOpOverload]:
        return [
            exir_ops.edge.aten.cat.default,
            exir_ops.edge.aten.slice_copy.Tensor,
        ]

    @property
    def input_to_swap(self) -> EdgeOpOverload:
        return exir_ops.edge.cadence.dequantize_per_tensor.default

    @property
    def output_to_swap(self) -> EdgeOpOverload:
        return exir_ops.edge.cadence.quantize_per_tensor.default

    def _quant_key(self, node: torch.fx.Node) -> Hashable:
        # Collect the node's quantization params into a hashable tuple.
        return tuple(
            get_arg(node, name, arg_type) for name, arg_type in self._QUANT_PARAMS
        )

    def input_hash(self, node: torch.fx.Node) -> Hashable:
        return self._quant_key(node)

    def output_hash(self, node: torch.fx.Node) -> Hashable:
        return self._quant_key(node)

    def hashes_are_compatible(
        self, input_hash: Hashable, output_hash: Hashable
    ) -> bool:
        # Quant and dequant are inverses exactly when their params match.
        return input_hash == output_hash

    def create_inverse_wrapper_args(
        self, template: torch.fx.Node,
    ) -> tuple[EdgeOpOverload, tuple, dict]:
        # A dequant template calls for a quant on the other side and vice
        # versa; the params (everything after the data input) carry over.
        inverse_target = (
            self.output_to_swap
            if template.target == self.input_to_swap
            else self.input_to_swap
        )
        return inverse_target, template.args[1:], dict(template.kwargs)


class CommonReplacePasses:
passes = [
ReplaceScalarWithTensorArgPass,
Expand Down Expand Up @@ -2666,4 +2721,5 @@
ReplaceAtenAvgPoolWithCadenceAvgPoolPass,
ReplaceWhereWithFullArgsWithWhereScalar,
ReplaceMulTensorWithMulAndFullOpsPass,
SwapDequantQuantAroundDataMovementOps,
]
Loading
Loading