From a737ea5235affd6be159e3d4bb9a9c8f5188d707 Mon Sep 17 00:00:00 2001 From: Andrew Grebenisan Date: Thu, 16 Apr 2026 12:04:24 -0700 Subject: [PATCH 1/2] EdgeProgramManager passes (#16986) Summary: - Adds support to run passes on ExportedPrograms and EdgeProgramManager - EdgeProgramManager transform behaves basically like a pass manager Reviewed By: larryliu0820, ethansfng Differential Revision: D91725222 --- exir/BUCK | 12 + exir/edge_program_manager_pass_base.py | 313 +++++++++++++++++++ exir/program/BUCK | 2 +- exir/program/_program.py | 126 ++++---- exir/tests/TARGETS | 1 + exir/tests/test_pass_infra.py | 406 ++++++++++++++++++++++++- 6 files changed, 802 insertions(+), 58 deletions(-) create mode 100644 exir/edge_program_manager_pass_base.py diff --git a/exir/BUCK b/exir/BUCK index f00b3f1c787..9c9064390ea 100644 --- a/exir/BUCK +++ b/exir/BUCK @@ -259,6 +259,17 @@ fbcode_target(_kind = runtime.python_library, ], ) +fbcode_target(_kind = runtime.python_library, + name = "edge_program_manager_pass_base", + srcs = [ + "edge_program_manager_pass_base.py", + ], + deps = [ + "//caffe2:torch", + "//executorch/exir:pass_base", + ], +) + fbcode_target(_kind = runtime.python_library, name = "pass_manager", srcs = [ @@ -267,6 +278,7 @@ fbcode_target(_kind = runtime.python_library, deps = [ "fbsource//third-party/pypi/typing-extensions:typing-extensions", ":error", + ":pass_base", "//caffe2:torch", ], ) diff --git a/exir/edge_program_manager_pass_base.py b/exir/edge_program_manager_pass_base.py new file mode 100644 index 00000000000..af5d8942222 --- /dev/null +++ b/exir/edge_program_manager_pass_base.py @@ -0,0 +1,313 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +# pyre-strict + +import copy +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import Callable, Dict, Optional, Sequence, TYPE_CHECKING, Union + +import torch +from torch.export import ExportedProgram +from torch.fx.passes.infra.pass_base import PassResult + +if TYPE_CHECKING: + from executorch.exir.program._program import EdgeProgramManager + + +@dataclass(frozen=True) +class ExportedProgramPassResult: + """Result of running a pass on an ExportedProgram.""" + + exported_program: ExportedProgram + modified: bool + + +class ExportedProgramPassBase(ABC): + """ + Base interface for implementing passes that operate on ExportedProgram. + """ + + def __call__(self, exported_program: ExportedProgram) -> ExportedProgramPassResult: + """ + Runs the precondition check, the pass itself, and the postcondition check. + """ + + self.requires(exported_program) + res = self.call(exported_program) + self.ensures(exported_program) + return res + + @abstractmethod + def call(self, exported_program: ExportedProgram) -> ExportedProgramPassResult: + """ + The pass that is run through the given exported program. To implement a + pass, it is required to implement this function. + + Args: + exported_program: The exported program we will run a pass on + """ + + def requires(self, exported_program: ExportedProgram) -> None: # noqa: B027 + """ + This function will be called before the pass is run and will check that + the given exported program contains the preconditions needed to run the + pass. It is not required to implement this function. + + Args: + exported_program: The exported program we will run checks on + """ + + def ensures(self, exported_program: ExportedProgram) -> None: # noqa: B027 + """ + This function will be called after the pass is run and will check that + the given exported program contains the postconditions needed to run the + pass. It is not required to implement this function. 
+ + Args: + exported_program: The exported program we will run checks on + """ + + +@dataclass(frozen=True) +class EdgeProgramManagerPassResult: + """Result of running a pass on an EdgeProgramManager.""" + + edge_program_manager: "EdgeProgramManager" + modified: bool + + +class EdgeProgramManagerPassBase(ABC): + """ + Base interface for implementing passes that operate on EdgeProgramManager. + + This is the highest-level pass abstraction. Passes at this level can: + - Transform individual ExportedPrograms within the manager + - Modify constant methods + - Split one program into multiple programs + - Add or remove programs from the manager + + Lower-level passes (ExportedProgramPassBase, GraphModule callables) can be + lifted to this level using the provided wrapper classes. + """ + + def __call__( + self, epm: "EdgeProgramManager" + ) -> EdgeProgramManagerPassResult: + """ + Runs the precondition check, the pass itself, and the postcondition check. + """ + self.requires(epm) + res = self.call(epm) + self.ensures(res.edge_program_manager) + return res + + @abstractmethod + def call( + self, epm: "EdgeProgramManager" + ) -> EdgeProgramManagerPassResult: + """ + The pass that is run on the given EdgeProgramManager. To implement a + pass, it is required to implement this function. + + Args: + epm: The EdgeProgramManager to transform + """ + + def requires(self, epm: "EdgeProgramManager") -> None: # noqa: B027 + """ + This function will be called before the pass is run and will check that + the given EdgeProgramManager contains the preconditions needed to run the + pass. It is not required to implement this function. + + Args: + epm: The EdgeProgramManager we will run checks on + """ + + def ensures(self, epm: "EdgeProgramManager") -> None: # noqa: B027 + """ + This function will be called after the pass is run and will check that + the given EdgeProgramManager contains the postconditions needed to run the + pass. It is not required to implement this function. 
+ + Args: + epm: The EdgeProgramManager we will run checks on + """ + + +class GraphModuleBackedExportedProgramPassWrapper(ExportedProgramPassBase): + """ + Wrapper that adapts a GraphModule pass to work as an ExportedProgramPassBase. + + This wrapper takes a pass that operates on GraphModule and makes it compatible + with ExportedProgramPassBase by extracting the graph module, running the pass, + and updating the ExportedProgram in-place. + """ + + def __init__( + self, + graph_module_pass: Callable[[torch.fx.GraphModule], PassResult], + ) -> None: + super().__init__() + self._pass = graph_module_pass + + def call(self, exported_program: ExportedProgram) -> ExportedProgramPassResult: + from executorch.exir.program._program import ( + _get_updated_graph_signature, + _get_updated_range_constraints, + ) + + result = self._pass(exported_program.graph_module) + + if result.modified: + # Cannot use _update_exported_program_graph_module because it + # runs verification, and it is not the responsibility of the + # pass to run verification. EdgeProgram manager can + # optionally run verification after a pass. + result.graph_module.recompile() + exported_program = copy.copy(exported_program) # bypasses __init__ and _validate() + + exported_program._graph_module = result.graph_module + exported_program._graph_signature = _get_updated_graph_signature( + exported_program.graph_signature, result.graph_module + ) + exported_program._range_constraints = _get_updated_range_constraints( + result.graph_module + ) + exported_program._module_call_graph = copy.deepcopy( + exported_program._module_call_graph + ) + exported_program._graph_module.meta.update(exported_program.graph_module.meta) + + + return ExportedProgramPassResult(exported_program, result.modified) + + +class ExportedProgramToEdgeProgramManagerPassWrapper(EdgeProgramManagerPassBase): + """ + Adapts an ExportedProgramPassBase to run on every method in an EdgeProgramManager. 
+ + This wrapper takes a pass that operates on a single ExportedProgram and applies it + to every method in the EdgeProgramManager, collecting results into a new EPM. + This is where the iteration over methods lives -- not in the pass manager, and not + in EdgeProgramManager.transform(). + """ + + def __init__(self, ep_pass: ExportedProgramPassBase) -> None: + super().__init__() + self._pass = ep_pass + + def call( + self, epm: "EdgeProgramManager" + ) -> EdgeProgramManagerPassResult: + new_epm = copy.copy(epm) + new_epm._edge_programs = dict(epm._edge_programs) + + overall_modified = False + for name, program in epm._edge_programs.items(): + result = self._pass(program) + new_epm._edge_programs[name] = result.exported_program + overall_modified = overall_modified or result.modified + + new_epm._config_methods = epm._config_methods + return EdgeProgramManagerPassResult(new_epm, overall_modified) + + +PassType = Union[ + EdgeProgramManagerPassBase, + ExportedProgramPassBase, + Callable[[torch.fx.GraphModule], Optional[PassResult]], +] + +# Passes that operate on a single method (ExportedProgram or GraphModule level). +# Excludes EdgeProgramManagerPassBase, which operates on the whole EdgeProgramManager. +# Use this for per-method pass specifications (e.g. Dict[str, Sequence[MethodPassType]]). +MethodPassType = Union[ + ExportedProgramPassBase, + Callable[[torch.fx.GraphModule], Optional[PassResult]], +] + + +def _get_pass_name(fn: PassType) -> str: + """Unwraps wrapper chain to get the underlying pass name.""" + import inspect + + if isinstance(fn, ExportedProgramToEdgeProgramManagerPassWrapper): + return _get_pass_name(fn._pass) + if isinstance(fn, GraphModuleBackedExportedProgramPassWrapper): + return _get_pass_name(fn._pass) + return fn.__name__ if inspect.isfunction(fn) else type(fn).__name__ + + +def wrap_passes( + passes: Sequence[PassType], +) -> list[EdgeProgramManagerPassBase]: + """ + Wraps a list of mixed-level passes up to the EdgeProgramManager level. 
+ + Accepts passes at three levels: + - EdgeProgramManagerPassBase: used as-is + - ExportedProgramPassBase: wrapped with ExportedProgramToEdgeProgramManagerPassWrapper + - GraphModule callables: wrapped with GraphModuleBackedExportedProgramPassWrapper + then ExportedProgramToEdgeProgramManagerPassWrapper + + Args: + passes: A sequence of passes at any level. + + Returns: + A list of EdgeProgramManagerPassBase passes. + """ + from torch.fx.passes.infra.pass_manager import pass_result_wrapper + + wrapped: list[EdgeProgramManagerPassBase] = [] + for fn in passes: + if isinstance(fn, EdgeProgramManagerPassBase): + wrapped.append(fn) + elif isinstance(fn, ExportedProgramPassBase): + wrapped.append( + ExportedProgramToEdgeProgramManagerPassWrapper(fn) + ) + else: + assert callable(fn) + ep_pass = GraphModuleBackedExportedProgramPassWrapper( + pass_result_wrapper(fn) + ) + wrapped.append( + ExportedProgramToEdgeProgramManagerPassWrapper(ep_pass) + ) + return wrapped + + +class MethodFilteredEdgeProgramManagerPass(EdgeProgramManagerPassBase): + """ + Applies different passes to different methods in an EdgeProgramManager. + + Converts the Dict[str, Sequence[MethodPassType]] pattern (previously handled inline + in EdgeProgramManager.transform) into a proper pass. Used by + to_edge_transform_and_lower to handle the dict case. 
+ """ + + def __init__(self, passes_dict: Dict[str, Sequence[MethodPassType]]) -> None: + super().__init__() + self._passes_dict = passes_dict + + def call( + self, epm: "EdgeProgramManager" + ) -> EdgeProgramManagerPassResult: + from executorch.exir.program._program import _transform + + new_epm = copy.copy(epm) + new_epm._edge_programs = dict(epm._edge_programs) + + overall_modified = False + for name, program in epm._edge_programs.items(): + if name in self._passes_dict: + new_program = _transform(program, *self._passes_dict[name]) + new_epm._edge_programs[name] = new_program + overall_modified = True + + return EdgeProgramManagerPassResult(new_epm, overall_modified) diff --git a/exir/program/BUCK b/exir/program/BUCK index 7d9642efdb7..19e778f1274 100644 --- a/exir/program/BUCK +++ b/exir/program/BUCK @@ -1,6 +1,5 @@ load("@fbcode_macros//build_defs:build_file_migration.bzl", "fbcode_target", "non_fbcode_target") load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime") -load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime") oncall("executorch") @@ -48,6 +47,7 @@ fbcode_target(_kind = runtime.python_library, "//executorch/exir/passes:spec_prop_pass", "//executorch/exir/passes:weights_to_outputs_pass", "//executorch/exir/passes:convert_constant_dim_order_pass", + "//executorch/exir:edge_program_manager_pass_base", "//executorch/exir/verification:verifier", "//executorch/extension/flat_tensor/serialize:serialize", ] + (["//executorch/exir/program/fb:logger"] if not runtime.is_oss else []) diff --git a/exir/program/_program.py b/exir/program/_program.py index c68d0eed945..f50b80058bf 100644 --- a/exir/program/_program.py +++ b/exir/program/_program.py @@ -5,8 +5,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
-# pyre-unsafe - +# pyre-strict import copy import io import logging @@ -38,7 +37,14 @@ from executorch.exir.operator.convert import _pybind_schema_to_native_schema from executorch.exir.operator.util import _QUANT_PRIMITIVES from executorch.exir.pass_base import PassBase -from executorch.exir.pass_manager import PassType +from executorch.exir.edge_program_manager_pass_base import ( + _get_pass_name, + MethodFilteredEdgeProgramManagerPass, + MethodPassType, + PassType, + wrap_passes, +) + from executorch.exir.passes import ( base_post_op_replace_passes, base_pre_op_replace_passes, @@ -98,6 +104,7 @@ ) from torch.fx import _pytree as fx_pytree from torch.fx._compatibility import compatibility + from torch.fx.passes.infra.pass_manager import PassManager from torch.utils import _pytree as pytree @@ -248,7 +255,7 @@ def _transform( def _transform_with_pass_manager( - self, + self: ExportedProgram, pass_manager: PassManager, override_verifiers: None | list[Type[Verifier]] = None, ) -> "ExportedProgram": @@ -276,6 +283,7 @@ def _transform_with_pass_manager( ) + def _update_exported_program_graph_module( exported_program: ExportedProgram, gm: torch.fx.GraphModule, @@ -1314,7 +1322,7 @@ def collect_named_data_store_outputs( def to_edge_transform_and_lower( # noqa: C901 programs: Union[ExportedProgram, Dict[str, ExportedProgram]], transform_passes: Optional[ - Union[Sequence[PassType], Dict[str, Sequence[PassType]], PassManager] + Union[Sequence[PassType], Dict[str, Sequence[MethodPassType]], PassManager] ] = None, partitioner: Optional[ Union[List[Partitioner], Dict[str, List[Partitioner]]] @@ -1594,8 +1602,13 @@ def exported_program(self, method_name: str = "forward") -> ExportedProgram: @et_logger("transform") def transform( self, - passes: Union[Sequence[PassType], Dict[str, Sequence[PassType]], PassManager], + passes: Union[ + Sequence[PassType], + PassManager, + Dict[str, Sequence[MethodPassType]] + ], compile_config: Optional[EdgeCompileConfig] = None, + 
run_checks_after_each_pass: bool = False, ) -> "EdgeProgramManager": """ Transforms the program according to the provided passes. @@ -1605,66 +1618,77 @@ def transform( 1) a list of passes - all methods in the given EdgeProgramManager will be transformed with the provided passes. - 2) a dictionary mapping method names to lists of passes - - only method names specified in the dictionary will be - transformed with their corresponding passes. - 3) a PassManager instance - - all methods in the given EdgeProgramManager will be - transformed with the given PassManager instance. - compile_config: Compile config to use for veriy the correctness of model + Passes can be EdgeProgramManagerPassBase, + ExportedProgramPassBase, or GraphModule callables. + 2) a PassManager instance (deprecated) - + for backwards compatibility. + compile_config: Compile config to use for verifying the correctness of model graph after each pass. If not specified, the compile config of the - calling EdgeProgramManager will be used. It will be used in as compile - config of returned EdgeProgramManager. + calling EdgeProgramManager will be used. It will be used as compile + config of the returned EdgeProgramManager. + run_checks_after_each_pass: If True, run validation checks after each + pass is applied. Returns: EdgeProgramManager: A copy of the calling EdgeProgramManager with the transformations applied. """ - compile_config = compile_config or self.compile_config - new_programs: Dict[str, ExportedProgram] = {} - - # Cast passes parameter upfront. - passes_seq: Optional[Sequence[PassType]] = None - passes_dict: Optional[Dict[str, Sequence[PassType]]] = None - pass_manager: Optional[PassManager] = None - if isinstance(passes, Sequence): - passes_seq = passes - if isinstance(passes, dict): - passes_dict = passes if isinstance(passes, PassManager): - pass_manager = passes - - for name, program in self._edge_programs.items(): - # If the method name is enforced, but not matched, we skip transformation. 
- if ( - isinstance(passes, dict) - and passes_dict - and name not in passes_dict.keys() - ): - new_programs[name] = copy.deepcopy(program) - continue + # For backwards compatibility, extract the passes from the + # deprecated PassManager and wrap them. + wrapped_passes = wrap_passes([passes]) + elif isinstance(passes, dict): + wrapped_passes = [MethodFilteredEdgeProgramManagerPass(passes)] + else: + wrapped_passes = wrap_passes(list(passes)) - # Depending on the passes parameter, call the corresponding transform function. - if passes_seq is not None: - new_programs[name] = _transform(program, *passes_seq) - elif passes_dict is not None: - new_programs[name] = _transform(program, *passes_dict[name]) - elif pass_manager is not None: - new_programs[name] = _transform_with_pass_manager(program, pass_manager) + epm = self + for i, fn in enumerate(wrapped_passes): + try: + result = fn(epm) + epm = result.edge_program_manager + + if run_checks_after_each_pass: + self._check_edge_programs(epm) + + except Exception as e: + prev_names = [_get_pass_name(p) for p in wrapped_passes[:i]] + msg = ( + f"An error occurred when running the \'{_get_pass_name(fn)}\' pass " + f"after the following passes: {prev_names}\n" + f"Original error: {e}" + ) + raise Exception(msg) from e - # Verify the correctness of model graph after each transformation. + for name, program in epm._edge_programs.items(): EXIREdgeDialectVerifier(edge_compile_config=compile_config)( - new_programs[name].graph_module + program.graph_module ) - epm = EdgeProgramManager( - new_programs, copy.deepcopy(self._config_methods), compile_config - ) + new_epm = epm + new_epm.compile_config = compile_config + new_epm._etrecord = self._etrecord + return new_epm - epm._etrecord = self._etrecord - return epm + def _check_edge_programs(self, epm: "EdgeProgramManager") -> None: + """ + Runs validation checks on each ExportedProgram in the EdgeProgramManager. 
+ """ + from executorch.exir.error import ExportError, ExportErrorType + + for name, program in epm._edge_programs.items(): # noqa: B007 + module = program.graph_module + module.recompile() + module.graph.lint() + + for node in module.graph.nodes: + if node.op == "call_method": + raise ExportError( + ExportErrorType.NOT_SUPPORTED, + f"call_method `{node}` is not supported except for backend delegate.", + ) @et_logger("to_backend") def to_backend( diff --git a/exir/tests/TARGETS b/exir/tests/TARGETS index 322f72c870a..cce984a32f7 100644 --- a/exir/tests/TARGETS +++ b/exir/tests/TARGETS @@ -238,6 +238,7 @@ python_unittest( "//executorch/exir:pass_manager", "//executorch/exir/passes:lib", "//executorch/exir/passes:pass_registry", + "//executorch/exir:edge_program_manager_pass_base", ], ) diff --git a/exir/tests/test_pass_infra.py b/exir/tests/test_pass_infra.py index c3788a5a38e..ad2071dffb3 100644 --- a/exir/tests/test_pass_infra.py +++ b/exir/tests/test_pass_infra.py @@ -8,13 +8,24 @@ import unittest +import executorch.exir as exir import torch -from executorch.exir import to_edge +from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_manager import PassManager from executorch.exir.passes import ScalarToTensorPass from executorch.exir.passes.pass_registry import PassRegistry -from torch.export import export -from torch.fx.passes.infra.pass_base import PassBase +from executorch.exir.program import to_edge +from executorch.exir.edge_program_manager_pass_base import ( + EdgeProgramManagerPassBase, + ExportedProgramPassBase, + ExportedProgramPassResult, + EdgeProgramManagerPassResult, + ExportedProgramToEdgeProgramManagerPassWrapper, + MethodFilteredEdgeProgramManagerPass, +) +from torch.export import ExportedProgram, export +from torch.export.graph_signature import InputKind, InputSpec, TensorArgument +from torch.fx.passes.infra.pass_base import PassBase, PassResult class TestPassInfra(unittest.TestCase): @@ -90,9 +101,10 @@ def 
test_pass_manager(self) -> None: """ def replace_add_with_mul(gm: torch.fx.GraphModule) -> None: - for node in gm.graph.nodes: - if node.op == "call_function" and "aten.add.Tensor" in str(node.target): - node.target = torch.mul + for node in gm.graph.find_nodes( + op="call_function", target=exir_ops.edge.aten.add.Tensor + ): + node.target = torch.mul def replace_mul_with_div(gm: torch.fx.GraphModule) -> None: for node in gm.graph.nodes: @@ -178,3 +190,385 @@ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: for node in new_gm.graph.nodes: if node.target != "output": self.assertIn("val", node.meta) + + +class TestExportedProgramPassManager(unittest.TestCase): + """Tests for EdgeProgramManager.transform() pass infrastructure. + + These tests validate that the pass manager correctly operates on EdgeProgramManagers, + preserving the original test objectives from when it operated on ExportedPrograms directly. + """ + + def test_raises_spec_violation_error(self) -> None: + """ + Ensures that transform() raises a SpecViolationError after running + a pass which places a non-Edge operator in the graph. 
+ """ + + def replace_add_with_torch_aten_mul(gm: torch.fx.GraphModule) -> PassResult: + modified = False + for node in gm.graph.find_nodes( + op="call_function", target=exir_ops.edge.aten.add.Tensor + ): + node.target = torch.ops.aten.mul.Tensor + modified = True + return PassResult(gm, modified) + + def f(x: torch.Tensor) -> torch.Tensor: + y = torch.add(x, x) + z = torch.add(y, x) + return z + + epm = to_edge( + exir.capture(f, (torch.randn(10),), exir.CaptureConfig()) + .to_edge() + .exported_program + ) + + with self.assertRaisesRegex( + torch._export.verifier.SpecViolationError, + "Operator torch._ops.aten.mul.Tensor is not an Edge operator.", + ): + epm.transform([replace_add_with_torch_aten_mul]) + + def test_runs_graph_module_passes_on_exported_program(self) -> None: + """ + Tests that transform() runs GraphModule passes + on an EdgeProgramManager and the graph is correctly modified. + """ + + def replace_add_with_mul(gm: torch.fx.GraphModule) -> PassResult: + modified = False + for node in gm.graph.find_nodes( + op="call_function", target=exir_ops.edge.aten.add.Tensor + ): + node.target = exir_ops.edge.aten.mul.Tensor + modified = True + return PassResult(gm, modified) + + def f(x: torch.Tensor) -> torch.Tensor: + y = torch.add(x, x) + z = torch.add(y, x) + return z + + epm = to_edge( + exir.capture(f, (torch.randn(10),), exir.CaptureConfig()) + .to_edge() + .exported_program + ) + + result_epm = epm.transform([replace_add_with_mul]) + + # Check that all add ops were replaced with mul + add_nodes = result_epm.exported_program().graph.find_nodes( + op="call_function", target=exir_ops.edge.aten.add.Tensor + ) + self.assertEqual(len(add_nodes), 0) + + def test_updates_constants_on_exported_program(self) -> None: + """ + Tests that transform() can update constants + in the ExportedProgram using an ExportedProgram-aware pass. 
+ """ + + class DoubleConstantsPass(ExportedProgramPassBase): + """Pass that doubles all constant tensor values in the ExportedProgram.""" + + def call(self, ep: ExportedProgram) -> ExportedProgramPassResult: + modified = False + for key, const in ep.constants.items(): + if isinstance(const, torch.Tensor): + ep.constants[key] = const * 2 + modified = True + return ExportedProgramPassResult(ep, modified) + + class ModuleWithConstant(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.weight = torch.ones(3) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return x + self.weight + + module = ModuleWithConstant() + epm = to_edge( + torch.export.export(module, (torch.randn(3),)) + ) + exported_program = epm.exported_program() + + # Verify there are constants in the ExportedProgram + self.assertGreater( + len(exported_program.constants), 0, "Expected constants in ExportedProgram" + ) + + # Store original constant values + original_values = { + key: const.clone() + for key, const in exported_program.constants.items() + if isinstance(const, torch.Tensor) + } + + result_epm = epm.transform([DoubleConstantsPass()]) + + # Verify constants were doubled + new_ep = result_epm.exported_program() + for key, original_const in original_values.items(): + new_const = new_ep.constants[key] + self.assertTrue( + torch.allclose(new_const, original_const * 2), + f"Constant {key} was not doubled correctly", + ) + + def test_adds_constant_to_exported_program(self) -> None: + """ + Tests that transform() can add a new constant + to the ExportedProgram, including updating the graph and input specs. 
+ """ + + class AddConstantPass(ExportedProgramPassBase): + """Pass that adds a new constant tensor to the ExportedProgram.""" + + def call(self, ep: ExportedProgram) -> ExportedProgramPassResult: + graph = ep.graph_module.graph + sig = ep.graph_signature + + # Find the first user input to insert before it + placeholders = graph.find_nodes(op="placeholder") + assert len(placeholders) == 1 + user_input_node = placeholders[0] + + # Create a new constant tensor + new_constant_name = "_test_added_constant" + new_constant_tensor = torch.tensor([1.0, 2.0, 3.0]) + + # Add placeholder node for the new constant + with graph.inserting_before(user_input_node): + new_placeholder = graph.placeholder(new_constant_name) + # Set up meta for the new placeholder + new_placeholder.meta["val"] = new_constant_tensor + + # Add the constant to the constants dict + ep.constants[new_constant_name] = new_constant_tensor + + # Update input specs to include the new constant + new_input_spec = InputSpec( + kind=InputKind.CONSTANT_TENSOR, + arg=TensorArgument(name=new_placeholder.name), + target=new_constant_name, + persistent=False, + ) + sig.input_specs = (new_input_spec, sig.input_specs[0]) + + return ExportedProgramPassResult(ep, modified=True) + + class IdentityModule(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return x + + epm = to_edge( + torch.export.export(IdentityModule(), (torch.randn(3),)) + ) + exported_program = epm.exported_program() + assert len(exported_program.constants) == 0 + assert len(exported_program.graph_signature.input_specs) == 1 + + result_epm = epm.transform([AddConstantPass()]) + + new_ep = result_epm.exported_program() + + # Verify the new constant was added to constants dict + self.assertEqual(len(new_ep.constants), 1) + self.assertIn("_test_added_constant", new_ep.constants) + self.assertTrue( + torch.allclose( + new_ep.constants["_test_added_constant"], + torch.tensor([1.0, 2.0, 
3.0]), + ) + ) + + # Verify input_specs was updated + self.assertEqual( + len(new_ep.graph_signature.input_specs), + 2, + ) + + # Verify the new placeholder exists in the graph + placeholder_names = [ + node.target + for node in new_ep.graph_module.graph.find_nodes( + op="placeholder" + ) + ] + self.assertTrue(len(placeholder_names) == 2) + + # Verify the new input spec has the correct kind + new_spec = None + for spec in new_ep.graph_signature.input_specs: + if spec.target == "_test_added_constant": + new_spec = spec + break + self.assertIsNotNone(new_spec) + self.assertEqual(new_spec.kind, InputKind.CONSTANT_TENSOR) + + def test_invalid_pass_creates_call_method(self) -> None: + """ + Tests that transform() detects invalid passes + that introduce call_method nodes. + """ + + def introduce_call_method(gm: torch.fx.GraphModule) -> PassResult: + node = list(gm.graph.nodes)[-2] + with gm.graph.inserting_after(node): + gm.graph.call_method("torch.ops.relu", (torch.randn(2),)) + return PassResult(gm, True) + + def f(x: torch.Tensor) -> torch.Tensor: + y = torch.add(x, x) + return y + + epm = to_edge( + exir.capture(f, (torch.randn(10),), exir.CaptureConfig()) + .to_edge() + .exported_program + ) + + with self.assertRaisesRegex(Exception, "call_method"): + epm.transform( + [introduce_call_method], run_checks_after_each_pass=True + ) + + +class TestEdgeProgramManagerWrappers(unittest.TestCase): + """Tests for the new EPM-level pass wrappers and MethodFilteredEdgeProgramManagerPass.""" + + def _make_simple_epm(self): + """Helper to create a simple EdgeProgramManager with a single 'forward' method.""" + + class SimpleModule(torch.nn.Module): + def forward(self, x: torch.Tensor) -> torch.Tensor: + return torch.add(x, x) + + return to_edge(torch.export.export(SimpleModule(), (torch.randn(10),))) + + def test_runs_epm_pass_directly(self) -> None: + """ + Tests that EdgeProgramManagerPassBase subclasses can operate + directly on the EPM (e.g., modifying config methods). 
+ """ + + class AddConfigMethodPass(EdgeProgramManagerPassBase): + def call(self, epm): + import copy + + new_epm = copy.copy(epm) + new_epm._config_methods = dict(epm._config_methods or {}) + new_epm._config_methods["new_config"] = "test_value" + return EdgeProgramManagerPassResult(new_epm, modified=True) + + epm = self._make_simple_epm() + + result_epm = epm.transform([AddConfigMethodPass()]) + + self.assertIn("new_config", result_epm.config_methods) + + def test_exported_program_to_epm_wrapper(self) -> None: + """ + Tests that ExportedProgramToEdgeProgramManagerPassWrapper correctly + iterates over all methods in the EPM. + """ + + class NoOpPass(ExportedProgramPassBase): + def __init__(self): + super().__init__() + self.call_count = 0 + + def call(self, ep: ExportedProgram) -> ExportedProgramPassResult: + self.call_count += 1 + return ExportedProgramPassResult(ep, modified=False) + + epm = self._make_simple_epm() + inner_pass = NoOpPass() + wrapper = ExportedProgramToEdgeProgramManagerPassWrapper(inner_pass) + + result = wrapper(epm) + self.assertIsInstance(result, EdgeProgramManagerPassResult) + self.assertFalse(result.modified) + self.assertEqual(inner_pass.call_count, len(epm.methods)) + + def test_graph_module_to_epm_two_step_wrapping(self) -> None: + """ + Tests that wrapping a GraphModule pass with + GraphModuleBackedExportedProgramPassWrapper and then + ExportedProgramToEdgeProgramManagerPassWrapper correctly + applies it to all methods. 
+ """ + from executorch.exir.edge_program_manager_pass_base import ( + GraphModuleBackedExportedProgramPassWrapper, + ) + from torch.fx.passes.infra.pass_manager import pass_result_wrapper + + call_count = 0 + + def counting_pass(gm: torch.fx.GraphModule) -> PassResult: + nonlocal call_count + call_count += 1 + return PassResult(gm, False) + + epm = self._make_simple_epm() + ep_pass = GraphModuleBackedExportedProgramPassWrapper( + pass_result_wrapper(counting_pass) + ) + wrapper = ExportedProgramToEdgeProgramManagerPassWrapper(ep_pass) + + result = wrapper(epm) + self.assertIsInstance(result, EdgeProgramManagerPassResult) + self.assertFalse(result.modified) + self.assertEqual(call_count, len(epm.methods)) + + def test_method_filtered_pass(self) -> None: + """ + Tests that MethodFilteredEdgeProgramManagerPass applies passes + only to specified methods. + """ + + def replace_add_with_mul(gm: torch.fx.GraphModule) -> PassResult: + modified = False + for node in gm.graph.find_nodes( + op="call_function", target=exir_ops.edge.aten.add.Tensor + ): + node.target = exir_ops.edge.aten.mul.Tensor + modified = True + return PassResult(gm, modified) + + class AddModule(torch.nn.Module): + def forward(self, x: torch.Tensor) -> torch.Tensor: + return torch.add(x, x) + + epm = to_edge( + { + "forward": torch.export.export(AddModule(), (torch.randn(10),)), + "other": torch.export.export(AddModule(), (torch.randn(10),)), + } + ) + + filtered_pass = MethodFilteredEdgeProgramManagerPass( + {"forward": [replace_add_with_mul]} + ) + + result = filtered_pass(epm) + self.assertTrue(result.modified) + + # 'forward' should have mul ops (no add remaining) + add_nodes = result.edge_program_manager.exported_program("forward").graph.find_nodes( + op="call_function", target=exir_ops.edge.aten.add.Tensor + ) + self.assertEqual(len(add_nodes), 0) + + # 'other' should still have add ops + add_nodes = result.edge_program_manager.exported_program("other").graph.find_nodes( + op="call_function", 
target=exir_ops.edge.aten.add.Tensor + ) + self.assertGreater(len(add_nodes), 0, "Expected 'other' method to still have add ops") From f65b936015a888323867ea7fad271328871cead0 Mon Sep 17 00:00:00 2001 From: Andrew Grebenisan Date: Thu, 16 Apr 2026 12:04:24 -0700 Subject: [PATCH 2/2] Interface for swapping inputs and outputs based on cost model + working permute example (#18929) Summary: I have seen the following pattern quite a bit so far where we have wrapper(x0), wrapper(x1), x2 -> op -> wrapper(o0), o1 for example, and it is semantically equivalent to instead run x0, x1, wrapper(x2) -> op -> o0, wrapper(o1) In case 1, we had 3 wrapper ops, and in case 2, we have 2, so it's better to run 2. This interface formalizes this idea. I also implemented the pass which uses this interface for the permute case. What is really nice about this interface is that in a single pass, we can do a LOT of cleanup. See the docs in the diff for a case for adds. Differential Revision: D100917820 --- backends/cadence/aot/pass_utils.py | 444 +++++++++++++++++- backends/cadence/aot/replace_ops.py | 59 ++- .../aot/tests/test_replace_ops_passes.py | 376 +++++++++++++++ 3 files changed, 877 insertions(+), 2 deletions(-) diff --git a/backends/cadence/aot/pass_utils.py b/backends/cadence/aot/pass_utils.py index d03862d44fa..4a1d2bef252 100644 --- a/backends/cadence/aot/pass_utils.py +++ b/backends/cadence/aot/pass_utils.py @@ -9,7 +9,17 @@ import dataclasses from abc import abstractmethod from dataclasses import dataclass -from typing import Callable, List, Optional, override, Set, Type, TypeVar, Union +from typing import ( + Callable, + Hashable, + List, + Optional, + override, + Set, + Type, + TypeVar, + Union, +) import torch from beartype.door import die_if_unbearable @@ -356,3 +366,435 @@ def _apply_flat_inplace(self, graph_module: torch.fx.GraphModule) -> bool: continue changed |= self.maybe_remove_or_replace(node) return changed + + +class 
SwapOnCostModelPassInterface(HierarchicalInplacePassInterface): + """ + A base class for passes that reduce op count by moving wrapper operations + (e.g., dequant/quant, permute) from the majority side of a target op's + inputs/outputs to the minority side. + + Given a target op with some inputs wrapped by ``input_to_swap`` and some + outputs wrapped by ``output_to_swap``, this pass checks whether it is + cheaper to instead wrap the *other* inputs/outputs. The swap is performed + only when: + + 1. All input wrappers share the same hash (via ``input_hash``). + 2. All output wrappers share the same hash (via ``output_hash``). + 3. The input and output hashes are compatible (via ``hashes_are_compatible``). + 4. The cost after swapping is strictly less than before (via ``should_swap``). + + Subclasses must implement: + - ``targets``: the ops to scan (e.g., cat, add, mul). + - ``input_to_swap`` / ``output_to_swap``: the wrapper op targets. + - ``input_hash`` / ``output_hash``: extract a hashable identity from a + wrapper node for equality checking. + - ``hashes_are_compatible``: define the relationship between input and + output hashes (equality for quant/dequant, inverse for permute). + + Optionally override: + - ``should_swap``: custom cost model for the swap decision (default: + total bytes moved through wrapper ops, i.e. bytes read + written). + + Example 1 — Dequant/quant around cat:: + + # Before: 3 dequants on inputs, 1 fp32 input, 2 quant outputs, 1 fp32 output. + # dequant(A, s, zp) ──┐ + # dequant(B, s, zp) ──┤ ┌── quant(out1, s, zp) + # dequant(C, s, zp) ──┼── cat(fp32) ─┼── quant(out2, s, zp) + # D(fp32) ┘ └── consumer(fp32) + # Cost: 5 wrapper ops (3 dequants + 2 quants) + # + # After: A, B, C feed cat directly (already quantized). D gets a quant. + # The cat now runs in quantized domain. The fp32 consumer gets a dequant. 
+ # A ──┐ + # B ──┤ ┌── out1 (quantized, no wrapper needed) + # C ──┼── cat(q) ──┼── out2 (quantized, no wrapper needed) + # quant(D, s, zp) ─┘ └── dequant(out3, s, zp) ── consumer + # Cost: 2 wrapper ops (1 quant + 1 dequant) + # + # Delta: 5 → 2, net savings of 3 ops. + # + # input_to_swap = dequantize_per_tensor + # output_to_swap = quantize_per_tensor + # input_hash = (scale, zero_point, quant_min, quant_max, dtype) + # output_hash = (scale, zero_point, quant_min, quant_max, dtype) + # hashes_are_compatible: input_hash == output_hash + + Example 2 — Permutes around a binary op:: + + # Before: permute(A) and B feed add, output is inverse-permuted. + # permute(A, [0,3,1,2]) ──┐ + # ├── add ── permute(out, [0,2,3,1]) + # B ───┘ + # Cost: 2 wrapper ops + # + # After: A and inverse-permute(B) feed add, no output permute. + # A ──────────────────┐ + # ├── add ── out + # permute(B, [0,2,3,1]) ─────────┘ + # Cost: 1 wrapper op + # + # input_to_swap = permute_copy + # output_to_swap = permute_copy + # input_hash = tuple(dims) + # output_hash = tuple(dims) + # hashes_are_compatible: applying input_perm then output_perm = identity + + Example 3 — Chained elimination across multiple ops:: + + When targets share edges, a single pass run can cascade reductions. + The pass visits each target in graph order; wrappers injected by + earlier swaps become inputs/outputs for later targets. + + # Before (5 permutes total): + # permute(A, dims) ──┐ + # ├── add_1 ── permute(out1, inv_dims) ── consumer_1 + # permute(B, dims) ──┤ + # └── (fp32 edge to add_2) + # │ + # permute(C, dims) ─────────────┐ │ + # ├── add_2 ── permute(out2, inv_dims) + # (from add_1)┘ + # + # After pass visits add_1 (3 wrappers → 1): + # Removes 2 input permutes + 1 output permute, injects 1 inverse + # permute on the edge flowing to add_2. 
+ # A ──┐ + # ├── add_1 ── consumer_1 + # B ──┤ + # └── permute(edge, inv_dims) ── add_2 + # permute(C, dims) ─────────────────────────┘ │ + # └── permute(out2, inv_dims) + # + # After pass visits add_2 (3 wrappers → 0): + # add_2 now has 2 permuted inputs + 1 permuted output = 3, 0 unwrapped. + # All eliminated. + # A ──┐ + # ├── add_1 ── consumer_1 + # B ──┤ + # └── add_2 ── out2 + # C ──┘ + # + # Result: all 5 original permutes eliminated in a single pass run. + """ + + @property + @abstractmethod + def targets(self) -> list[EdgeOpOverload]: + """ + The list of targets that we will potentially swap inputs and outputs. + """ + raise NotImplementedError("`targets` must be implemented") + + @property + @abstractmethod + def input_to_swap(self) -> EdgeOpOverload: + """ + The wrapper op target to match on inputs (e.g., dequantize_per_tensor, + permute_copy). Inputs to the target op whose target matches this will + be candidates for removal during the swap. + """ + raise NotImplementedError("You must specify the input we are trying to swap") + + @property + @abstractmethod + def output_to_swap(self) -> EdgeOpOverload: + """ + The wrapper op target to match on outputs (e.g., quantize_per_tensor, + permute_copy). Users of the target op whose target matches this will + be candidates for removal during the swap. + """ + ... + + @abstractmethod + def input_hash(self, node: Node) -> Hashable: + """ + Extract a hashable identity from an input wrapper node. All input + wrappers must produce the same hash for the swap to be valid. + E.g., for dequant: (scale, zero_point, quant_min, quant_max, dtype). + E.g., for permute: tuple(dims). + """ + ... + + @abstractmethod + def output_hash(self, node: Node) -> Hashable: + """ + Extract a hashable identity from an output wrapper node. All output + wrappers must produce the same hash for the swap to be valid. + """ + ... 
+ + @abstractmethod + def hashes_are_compatible( + self, input_hash: Hashable, output_hash: Hashable + ) -> bool: + """ + Check whether the input and output wrapper hashes are compatible, + meaning the swap is semantically legal. + For quant/dequant: hashes must be equal (same scale, zp, dtype). + For permute: output dims must be the inverse permutation of input dims. + """ + ... + + @abstractmethod + def create_inverse_wrapper_args( + self, + template: Node, + ) -> tuple[EdgeOpOverload, tuple, dict]: + """ + Given a wrapper node from one side, return (target, args_tail, kwargs) + for the inverse wrapper on the other side. args_tail excludes the data + input (first arg), which is supplied by the caller. + + For quant/dequant: swap the op target, keep the same params. + For permute: same op target, compute inverse dims. + """ + ... + + def should_swap( + self, + input_nodes_before_swap: list[Node], + input_nodes_after_swap: list[Node], + output_nodes_before_swap: list[Node], + output_nodes_after_swap: list[Node], + ) -> bool: + """ + Return True if the swap reduces total data movement (bytes read + + bytes written across all wrapper ops). Override to use a different + cost model. + + The default implementation computes, for every wrapper op (existing + or hypothetical), the sum of its input tensor bytes and output + tensor bytes. For not-yet-existing inverse wrappers the output + metadata is obtained by running the op's meta kernel on FakeTensors. 
+ """ + + def _node_bytes(node: Node) -> int: + if node.op == "output": + return 0 + val = node.meta["val"] + if not isinstance(val, torch.Tensor): + raise ValueError( + f"should_swap: expected torch.Tensor in meta['val'], got {type(val)}" + ) + return val.nelement() * val.element_size() + + def _wrapper_cost(node: Node) -> int: + """bytes_in + bytes_out for an existing wrapper node.""" + return _node_bytes(node.args[0]) + _node_bytes(node) + + def _inverse_wrapper_cost(node: Node, template: Node) -> int: + """bytes_in + bytes_out for a wrapper that would be created on node.""" + if node.op == "output": + return 0 + inv_target, inv_args_tail, inv_kwargs = self.create_inverse_wrapper_args( + template + ) + resolved_args = tuple( + a.meta["val"] if isinstance(a, Node) else a for a in inv_args_tail + ) + resolved_kwargs = { + k: v.meta["val"] if isinstance(v, Node) else v + for k, v in inv_kwargs.items() + } + fake_output = inv_target( + node.meta["val"], *resolved_args, **resolved_kwargs + ) + bytes_in = _node_bytes(node) + bytes_out = fake_output.nelement() * fake_output.element_size() + return bytes_in + bytes_out + + cost_before = sum(_wrapper_cost(n) for n in input_nodes_before_swap) + sum( + _wrapper_cost(n) for n in output_nodes_before_swap + ) + + cost_after = 0 + if input_nodes_after_swap: + template = ( + output_nodes_before_swap[0] + if output_nodes_before_swap + else input_nodes_before_swap[0] + ) + cost_after += sum( + _inverse_wrapper_cost(n, template) for n in input_nodes_after_swap + ) + if output_nodes_after_swap: + template = ( + input_nodes_before_swap[0] + if input_nodes_before_swap + else output_nodes_before_swap[0] + ) + cost_after += sum( + _inverse_wrapper_cost(n, template) for n in output_nodes_after_swap + ) + + return cost_after < cost_before + + def _partition_neighbors( + self, + neighbors: list[Node], + wrapper_target: EdgeOpOverload, + hash_fn: Callable[[Node], Hashable], + ) -> tuple[list[Node], list[Node], Hashable | None] | None: + 
"""Partition neighbor nodes into (before_swap, after_swap, hash_value). + + Returns None if the partition is invalid (wrapper with multiple users + or inconsistent hashes across wrappers). + """ + hash_value: Hashable | None = None + before_swap: list[Node] = [] + after_swap: list[Node] = [] + for neighbor in neighbors: + if neighbor.target == wrapper_target: + if len(neighbor.users) != 1: + return None + h = hash_fn(neighbor) + if hash_value is None: + hash_value = h + elif hash_value != h: + return None + before_swap.append(neighbor) + else: + after_swap.append(neighbor) + return before_swap, after_swap, hash_value + + def _resolve_inverse_params( + self, + opposite_before_swap: list[Node], + same_side_before_swap: list[Node], + ) -> tuple[EdgeOpOverload, tuple, dict, dict]: + """Determine (target, args_tail, kwargs, meta) for the inverse wrapper. + + Uses the opposite side's existing wrappers if available, otherwise + falls back to create_inverse_wrapper_args on the same side. + """ + if opposite_before_swap: + template = opposite_before_swap[0] + return ( + template.target, + template.args[1:], + dict(template.kwargs), + template.meta.copy(), + ) + template = same_side_before_swap[0] + inv_target, inv_args_tail, inv_kwargs = self.create_inverse_wrapper_args( + template + ) + return inv_target, inv_args_tail, inv_kwargs, template.meta.copy() + + def _inject_inverse_wrappers( + self, + graph: torch.fx.Graph, + unwrapped_nodes: list[Node], + target_node: Node, + inv_target: EdgeOpOverload, + inv_args_tail: tuple, + inv_kwargs: dict, + inv_meta: dict, + on_inputs: bool, + ) -> None: + """Inject inverse wrapper nodes between target_node and unwrapped_nodes.""" + for unwrapped in list(unwrapped_nodes): + if on_inputs: + source, dest = unwrapped, target_node + else: + source, dest = target_node, unwrapped + with graph.inserting_before(dest): + new_wrapper = graph.call_function( + inv_target, + args=(source,) + inv_args_tail, + kwargs=inv_kwargs, + ) + new_wrapper.meta 
= inv_meta + dest.replace_input_with(source, new_wrapper) + + def _apply_flat_inplace(self, graph_module: torch.fx.GraphModule) -> bool: + changed = False + for target in self.targets: + for node in graph_module.graph.find_nodes( + op="call_function", target=target + ): + # Partition inputs into wrapped (before_swap) and unwrapped (after_swap) + input_result = self._partition_neighbors( + list(node._input_nodes.keys()), + self.input_to_swap, + self.input_hash, + ) + if input_result is None: + continue + input_before, input_after, input_hash_value = input_result + + output_result = self._partition_neighbors( + list(node.users), + self.output_to_swap, + self.output_hash, + ) + if output_result is None: + continue + output_before, output_after, output_hash_value = output_result + + # We need at least one side to have wrappers to perform a swap. + if input_hash_value is None and output_hash_value is None: + continue + + # Check cross-compatibility between input and output hashes + if ( + input_hash_value is not None + and output_hash_value is not None + and not self.hashes_are_compatible( + input_hash_value, output_hash_value + ) + ): + continue + + if not self.should_swap( + input_before, input_after, output_before, output_after + ): + continue + + graph = graph_module.graph + + # Inject inverse wrappers on unwrapped inputs + if input_after: + inv_target, inv_args_tail, inv_kwargs, inv_meta = ( + self._resolve_inverse_params(output_before, input_before) + ) + self._inject_inverse_wrappers( + graph, + input_after, + node, + inv_target, + inv_args_tail, + inv_kwargs, + inv_meta, + on_inputs=True, + ) + + # Inject inverse wrappers on unwrapped outputs + if output_after: + inv_target, inv_args_tail, inv_kwargs, inv_meta = ( + self._resolve_inverse_params(input_before, output_before) + ) + self._inject_inverse_wrappers( + graph, + output_after, + node, + inv_target, + inv_args_tail, + inv_kwargs, + inv_meta, + on_inputs=False, + ) + + # Bypass existing input wrappers 
(e.g., remove dequants). + for wrapper in input_before: + node.replace_input_with(wrapper, wrapper.args[0]) + + # Bypass existing output wrappers (e.g., remove quants). + for wrapper in output_before: + wrapper.replace_all_uses_with(node) + + changed = True + + return changed diff --git a/backends/cadence/aot/replace_ops.py b/backends/cadence/aot/replace_ops.py index e09a6589e76..caaf4406cf4 100644 --- a/backends/cadence/aot/replace_ops.py +++ b/backends/cadence/aot/replace_ops.py @@ -15,7 +15,7 @@ import math import operator from operator import neg -from typing import cast, Dict, Optional, Sequence +from typing import cast, Dict, Hashable, Optional, Sequence import torch import torch.fx @@ -26,6 +26,7 @@ get_arg, register_cadence_pass, RemoveOrReplacePassInterface, + SwapOnCostModelPassInterface, ) from executorch.backends.cadence.aot.utils import is_depthwise_conv from executorch.backends.transforms.replace_scalar_with_tensor import ( @@ -2613,6 +2614,61 @@ def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool: return True +@register_cadence_pass(CadencePassAttribute(opt_level=1)) +class SwapDequantQuantAroundDataMovementOps(SwapOnCostModelPassInterface): + """ + For data movement ops (cat, slice_copy) surrounded by dequant inputs and + quant outputs with matching quantization parameters, swap the wrappers to + the minority side to reduce total quant/dequant op count. 
+ """ + + @property + def targets(self) -> list[EdgeOpOverload]: + return [ + exir_ops.edge.aten.cat.default, + exir_ops.edge.aten.slice_copy.Tensor, + ] + + @property + def input_to_swap(self) -> EdgeOpOverload: + return exir_ops.edge.cadence.dequantize_per_tensor.default + + @property + def output_to_swap(self) -> EdgeOpOverload: + return exir_ops.edge.cadence.quantize_per_tensor.default + + def _quant_key(self, node: torch.fx.Node) -> Hashable: + return ( + get_arg(node, "scale", float), + get_arg(node, "zero_point", int), + get_arg(node, "quant_min", int), + get_arg(node, "quant_max", int), + get_arg(node, "dtype", torch.dtype), + ) + + def input_hash(self, node: torch.fx.Node) -> Hashable: + return self._quant_key(node) + + def output_hash(self, node: torch.fx.Node) -> Hashable: + return self._quant_key(node) + + def hashes_are_compatible( + self, input_hash: Hashable, output_hash: Hashable + ) -> bool: + return input_hash == output_hash + + def create_inverse_wrapper_args( + self, + template: torch.fx.Node, + ) -> tuple[EdgeOpOverload, tuple, dict]: + # dequant template → need a quant; quant template → need a dequant. 
+ if template.target == self.input_to_swap: + target = self.output_to_swap + else: + target = self.input_to_swap + return (target, template.args[1:], dict(template.kwargs)) + + class CommonReplacePasses: passes = [ ReplaceScalarWithTensorArgPass, @@ -2682,4 +2738,5 @@ class CadenceReplaceOpsInGraph: ReplaceAtenAvgPoolWithCadenceAvgPoolPass, ReplaceWhereWithFullArgsWithWhereScalar, ReplaceMulTensorWithMulAndFullOpsPass, + SwapDequantQuantAroundDataMovementOps, ] diff --git a/backends/cadence/aot/tests/test_replace_ops_passes.py b/backends/cadence/aot/tests/test_replace_ops_passes.py index 56a17b73f88..93baf9d15c4 100644 --- a/backends/cadence/aot/tests/test_replace_ops_passes.py +++ b/backends/cadence/aot/tests/test_replace_ops_passes.py @@ -49,6 +49,7 @@ ReplaceTransposedConvWithLinearPass, ReplaceTrivialConvWithLinear, ReplaceWhereWithFullArgsWithWhereScalar, + SwapDequantQuantAroundDataMovementOps, ) from executorch.backends.cadence.aot.typing_stubs import expand @@ -3596,3 +3597,378 @@ def test_no_replacement_without_logical_not(self) -> None: self.assertEqual(node.args[0].name, "bool_cond") self.assertEqual(node.args[1].name, "x") self.assertEqual(node.args[2].name, "y") + + +class SwapDequantQuantAroundDataMovementOpsTest(unittest.TestCase): + """Tests for the SwapDequantQuantAroundDataMovementOps pass.""" + + SCALE = 0.01 + ZERO_POINT = 128 + QUANT_MIN = 0 + QUANT_MAX = 255 + DTYPE: torch.dtype = torch.uint8 + + def _quant_args( + self, + scale: Optional[float] = None, + zero_point: Optional[int] = None, + ) -> tuple[float, int, int, int, torch.dtype]: + return ( + scale if scale is not None else self.SCALE, + zero_point if zero_point is not None else self.ZERO_POINT, + self.QUANT_MIN, + self.QUANT_MAX, + self.DTYPE, + ) + + def test_no_swap_input_hashes_mismatch(self) -> None: + """Two dequant inputs with different scales → swap is illegal.""" + builder = GraphBuilder() + a = builder.placeholder("a", torch.randint(0, 255, (4, 8), dtype=torch.uint8)) + b = 
builder.placeholder("b", torch.randint(0, 255, (4, 8), dtype=torch.uint8)) + + dq_a = builder.call_operator( + op=exir_ops.edge.cadence.dequantize_per_tensor.default, + args=(a,) + self._quant_args(scale=0.01), + ) + dq_b = builder.call_operator( + op=exir_ops.edge.cadence.dequantize_per_tensor.default, + args=(b,) + self._quant_args(scale=0.02), + ) + cat = builder.call_operator( + op=exir_ops.edge.aten.cat.default, + args=([dq_a, dq_b], 1), + ) + q_out = builder.call_operator( + op=exir_ops.edge.cadence.quantize_per_tensor.default, + args=(cat,) + self._quant_args(scale=0.01), + ) + builder.output([q_out]) + gm = builder.get_graph_module() + + p = SwapDequantQuantAroundDataMovementOps() + result = cast(PassResult, p(gm)) + + self.assertFalse(result.modified) + self.assertEqual( + count_node( + result.graph_module, exir_ops.edge.cadence.dequantize_per_tensor.default + ), + 2, + ) + self.assertEqual( + count_node( + result.graph_module, exir_ops.edge.cadence.quantize_per_tensor.default + ), + 1, + ) + + def test_no_swap_output_hashes_mismatch(self) -> None: + """Two quant outputs with different zero_points → swap is illegal.""" + builder = GraphBuilder() + a = builder.placeholder("a", torch.randint(0, 255, (4, 8), dtype=torch.uint8)) + + dq_a = builder.call_operator( + op=exir_ops.edge.cadence.dequantize_per_tensor.default, + args=(a,) + self._quant_args(), + ) + cat = builder.call_operator( + op=exir_ops.edge.aten.cat.default, + args=([dq_a], 0), + ) + q_out1 = builder.call_operator( + op=exir_ops.edge.cadence.quantize_per_tensor.default, + args=(cat,) + self._quant_args(zero_point=128), + ) + q_out2 = builder.call_operator( + op=exir_ops.edge.cadence.quantize_per_tensor.default, + args=(cat,) + self._quant_args(zero_point=64), + ) + builder.output([q_out1, q_out2]) + gm = builder.get_graph_module() + + p = SwapDequantQuantAroundDataMovementOps() + result = cast(PassResult, p(gm)) + + self.assertFalse(result.modified) + self.assertEqual( + count_node( + 
result.graph_module, exir_ops.edge.cadence.dequantize_per_tensor.default + ), + 1, + ) + self.assertEqual( + count_node( + result.graph_module, exir_ops.edge.cadence.quantize_per_tensor.default + ), + 2, + ) + + def test_no_swap_input_output_hashes_incompatible(self) -> None: + """Input dequants all match, output quants all match, but input != output → illegal.""" + builder = GraphBuilder() + a = builder.placeholder("a", torch.randint(0, 255, (4, 8), dtype=torch.uint8)) + b = builder.placeholder("b", torch.randint(0, 255, (4, 8), dtype=torch.uint8)) + + dq_a = builder.call_operator( + op=exir_ops.edge.cadence.dequantize_per_tensor.default, + args=(a,) + self._quant_args(scale=0.01), + ) + dq_b = builder.call_operator( + op=exir_ops.edge.cadence.dequantize_per_tensor.default, + args=(b,) + self._quant_args(scale=0.01), + ) + cat = builder.call_operator( + op=exir_ops.edge.aten.cat.default, + args=([dq_a, dq_b], 1), + ) + q_out = builder.call_operator( + op=exir_ops.edge.cadence.quantize_per_tensor.default, + args=(cat,) + self._quant_args(scale=0.05), + ) + builder.output([q_out]) + gm = builder.get_graph_module() + + p = SwapDequantQuantAroundDataMovementOps() + result = cast(PassResult, p(gm)) + + self.assertFalse(result.modified) + self.assertEqual( + count_node( + result.graph_module, exir_ops.edge.cadence.dequantize_per_tensor.default + ), + 2, + ) + self.assertEqual( + count_node( + result.graph_module, exir_ops.edge.cadence.quantize_per_tensor.default + ), + 1, + ) + + def test_no_swap_equal_cost(self) -> None: + """One dequant input + one fp32 input, no quant outputs → cost 1 before and 1 after.""" + builder = GraphBuilder() + a = builder.placeholder("a", torch.randint(0, 255, (4, 8), dtype=torch.uint8)) + b = builder.placeholder("b", torch.randn(4, 8)) + + dq_a = builder.call_operator( + op=exir_ops.edge.cadence.dequantize_per_tensor.default, + args=(a,) + self._quant_args(), + ) + cat = builder.call_operator( + op=exir_ops.edge.aten.cat.default, + 
args=([dq_a, b], 1), + ) + builder.output([cat]) + gm = builder.get_graph_module() + + p = SwapDequantQuantAroundDataMovementOps() + result = cast(PassResult, p(gm)) + + # No swap: before = 1 (1 dequant), after = 1 (1 quant on b). Equal cost. + self.assertFalse(result.modified) + + def test_no_swap_higher_cost(self) -> None: + """One dequant input + two fp32 inputs, no quant outputs → swap would increase cost.""" + builder = GraphBuilder() + a = builder.placeholder("a", torch.randint(0, 255, (4, 8), dtype=torch.uint8)) + b = builder.placeholder("b", torch.randn(4, 8)) + c = builder.placeholder("c", torch.randn(4, 8)) + + dq_a = builder.call_operator( + op=exir_ops.edge.cadence.dequantize_per_tensor.default, + args=(a,) + self._quant_args(), + ) + cat = builder.call_operator( + op=exir_ops.edge.aten.cat.default, + args=([dq_a, b, c], 1), + ) + builder.output([cat]) + gm = builder.get_graph_module() + + p = SwapDequantQuantAroundDataMovementOps() + result = cast(PassResult, p(gm)) + + self.assertFalse(result.modified) + + def test_swap_reduces_cost(self) -> None: + """Three dequant inputs + one fp32 input, one quant output → 4 before, 1 after.""" + builder = GraphBuilder() + a = builder.placeholder("a", torch.randint(0, 255, (4, 8), dtype=torch.uint8)) + b = builder.placeholder("b", torch.randint(0, 255, (4, 8), dtype=torch.uint8)) + c = builder.placeholder("c", torch.randint(0, 255, (4, 8), dtype=torch.uint8)) + d = builder.placeholder("d", torch.randn(4, 8)) + + dq_a = builder.call_operator( + op=exir_ops.edge.cadence.dequantize_per_tensor.default, + args=(a,) + self._quant_args(), + ) + dq_b = builder.call_operator( + op=exir_ops.edge.cadence.dequantize_per_tensor.default, + args=(b,) + self._quant_args(), + ) + dq_c = builder.call_operator( + op=exir_ops.edge.cadence.dequantize_per_tensor.default, + args=(c,) + self._quant_args(), + ) + cat = builder.call_operator( + op=exir_ops.edge.aten.cat.default, + args=([dq_a, dq_b, dq_c, d], 1), + ) + q_out = 
builder.call_operator( + op=exir_ops.edge.cadence.quantize_per_tensor.default, + args=(cat,) + self._quant_args(), + ) + builder.output([q_out]) + gm = builder.get_graph_module() + + p = SwapDequantQuantAroundDataMovementOps() + result = cast(PassResult, p(gm)) + + self.assertTrue(result.modified) + self.assertEqual( + count_node( + result.graph_module, exir_ops.edge.cadence.dequantize_per_tensor.default + ), + 0, + ) + self.assertEqual( + count_node( + result.graph_module, exir_ops.edge.cadence.quantize_per_tensor.default + ), + 1, + ) + + def test_swap_all_matched_eliminates_all(self) -> None: + """All inputs dequanted, all outputs quanted, same params → eliminate everything.""" + builder = GraphBuilder() + a = builder.placeholder("a", torch.randint(0, 255, (4, 8), dtype=torch.uint8)) + b = builder.placeholder("b", torch.randint(0, 255, (4, 8), dtype=torch.uint8)) + + dq_a = builder.call_operator( + op=exir_ops.edge.cadence.dequantize_per_tensor.default, + args=(a,) + self._quant_args(), + ) + dq_b = builder.call_operator( + op=exir_ops.edge.cadence.dequantize_per_tensor.default, + args=(b,) + self._quant_args(), + ) + cat = builder.call_operator( + op=exir_ops.edge.aten.cat.default, + args=([dq_a, dq_b], 1), + ) + q_out = builder.call_operator( + op=exir_ops.edge.cadence.quantize_per_tensor.default, + args=(cat,) + self._quant_args(), + ) + builder.output([q_out]) + gm = builder.get_graph_module() + + p = SwapDequantQuantAroundDataMovementOps() + result = cast(PassResult, p(gm)) + + self.assertTrue(result.modified) + self.assertEqual( + count_node( + result.graph_module, exir_ops.edge.cadence.dequantize_per_tensor.default + ), + 0, + ) + self.assertEqual( + count_node( + result.graph_module, exir_ops.edge.cadence.quantize_per_tensor.default + ), + 0, + ) + + def test_swap_output_only_wrappers(self) -> None: + """No dequant inputs, 4 quant outputs + 1 fp32 output → push quants to inputs.""" + builder = GraphBuilder() + a = builder.placeholder("a", torch.randn(4, 
8)) + b = builder.placeholder("b", torch.randn(4, 8)) + + cat = builder.call_operator( + op=exir_ops.edge.aten.cat.default, + args=([a, b], 1), + ) + q_out1 = builder.call_operator( + op=exir_ops.edge.cadence.quantize_per_tensor.default, + args=(cat,) + self._quant_args(), + ) + q_out2 = builder.call_operator( + op=exir_ops.edge.cadence.quantize_per_tensor.default, + args=(cat,) + self._quant_args(), + ) + q_out3 = builder.call_operator( + op=exir_ops.edge.cadence.quantize_per_tensor.default, + args=(cat,) + self._quant_args(), + ) + q_out4 = builder.call_operator( + op=exir_ops.edge.cadence.quantize_per_tensor.default, + args=(cat,) + self._quant_args(), + ) + builder.output([q_out1, q_out2, q_out3, q_out4, cat]) + gm = builder.get_graph_module() + + p = SwapDequantQuantAroundDataMovementOps() + result = cast(PassResult, p(gm)) + + # before = 4 (4 quants), after = 3 (quant a + quant b + dequant for fp32 out). + self.assertTrue(result.modified) + + # All inputs should now be quantized (2 quants added). + self.assertEqual( + count_node( + result.graph_module, exir_ops.edge.cadence.quantize_per_tensor.default + ), + 2, + ) + # One dequant for the fp32 consumer. + self.assertEqual( + count_node( + result.graph_module, exir_ops.edge.cadence.dequantize_per_tensor.default + ), + 1, + ) + + def test_swap_fewer_nodes_but_more_bytes(self) -> None: + """ + Before: 2 dequants (small [4, 8] uint8 → fp32) + After would be: 1 quant (large [4, 2048] fp32 → uint8) + + Byte cost: + before = 2 x (4x8x1 + 4x8x4) = 2 x 160 = 320 + after = 1 x (4x2048x4 + 4x2048x1) = 40960 + Byte cost says don't swap. 
+ """ + builder = GraphBuilder() + a = builder.placeholder("a", torch.randint(0, 255, (4, 8), dtype=torch.uint8)) + b = builder.placeholder("b", torch.randint(0, 255, (4, 8), dtype=torch.uint8)) + c = builder.placeholder("c", torch.randn(4, 2048)) + + dq_a = builder.call_operator( + op=exir_ops.edge.cadence.dequantize_per_tensor.default, + args=(a,) + self._quant_args(), + ) + dq_b = builder.call_operator( + op=exir_ops.edge.cadence.dequantize_per_tensor.default, + args=(b,) + self._quant_args(), + ) + cat = builder.call_operator( + op=exir_ops.edge.aten.cat.default, + args=([dq_a, dq_b, c], 1), + ) + q_out = builder.call_operator( + op=exir_ops.edge.cadence.quantize_per_tensor.default, + args=(cat,) + self._quant_args(), + ) + builder.output([q_out]) + gm = builder.get_graph_module() + + p = SwapDequantQuantAroundDataMovementOps() + result = cast(PassResult, p(gm)) + + # Should NOT swap: wrapping the large [4, 2048] input costs far more + # bytes than the 2 small dequants save. + self.assertFalse(result.modified)