diff --git a/backends/arm/_passes/arm_pass_manager.py b/backends/arm/_passes/arm_pass_manager.py index 31cb7a2e2c7..29946fb988b 100644 --- a/backends/arm/_passes/arm_pass_manager.py +++ b/backends/arm/_passes/arm_pass_manager.py @@ -171,6 +171,32 @@ class PassInsertions: after_passes: list = field(default_factory=list) +_registered_pass_insertions: dict[type, PassInsertions] = {} + + +def register_pass_insertions_before( + target_pass_type: type, passes: list[ExportPass] +) -> None: + """Register passes to be inserted before a target pass for all pipelines.""" + if target_pass_type not in _registered_pass_insertions: + _registered_pass_insertions[target_pass_type] = PassInsertions() + _registered_pass_insertions[target_pass_type].before_passes.extend(passes) + + +def register_pass_insertions_after( + target_pass_type: type, passes: list[ExportPass] +) -> None: + """Register passes to be inserted after a target pass for all pipelines.""" + if target_pass_type not in _registered_pass_insertions: + _registered_pass_insertions[target_pass_type] = PassInsertions() + _registered_pass_insertions[target_pass_type].after_passes.extend(passes) + + +def clear_registered_pass_insertions() -> None: + """Clear all globally registered pass insertions.""" + _registered_pass_insertions.clear() + + class ArmPassManager(PassManager): def __init__(self, compile_spec: ArmCompileSpec) -> None: self.compile_spec = compile_spec @@ -319,13 +345,17 @@ def _configure_pass_insertions(self, exported_program: ExportedProgram) -> None: """Hook for subclasses to configure pass insertions. Called at the START of pipeline construction, before any passes are added. - Subclasses should override this to call insert_passes_before/after. + Subclasses can override this to call insert_passes_before/after. Args: exported_program: The exported program being transformed """ - pass + for pass_type, insertions in _registered_pass_insertions.items(): + if insertions.before_passes: + self.insert_passes_before(pass_type, list(insertions.before_passes)) + if insertions.after_passes: + self.insert_passes_after(pass_type, list(insertions.after_passes)) def add_passes(self, passes: Sequence[ExportPass | None]): for p in passes: diff --git a/docs/source/backends/arm-ethos-u/arm-ethos-u-partitioner.md b/docs/source/backends/arm-ethos-u/arm-ethos-u-partitioner.md index 10a28b8e785..ea315d567de 100644 --- a/docs/source/backends/arm-ethos-u/arm-ethos-u-partitioner.md +++ b/docs/source/backends/arm-ethos-u/arm-ethos-u-partitioner.md @@ -45,3 +45,9 @@ Args: Returns: - **PartitionResult**: The input program with nodes tagged for delegation and a mapping of partition tags to delegation specs. + +```python +def EthosUPartitioner.register_custom_partition_op(self, op: torch._ops.OpOverload) -> None: +``` +Register a custom op to be considered supported by this +partitioner. diff --git a/docs/source/backends/arm-vgf/arm-vgf-partitioner.md b/docs/source/backends/arm-vgf/arm-vgf-partitioner.md index e3cbd2f9d22..811696aa1b7 100644 --- a/docs/source/backends/arm-vgf/arm-vgf-partitioner.md +++ b/docs/source/backends/arm-vgf/arm-vgf-partitioner.md @@ -45,3 +45,9 @@ Args: Returns: - **PartitionResult**: The input program with nodes tagged for delegation and a mapping of partition tags to delegation specs. + +```python +def VgfPartitioner.register_custom_partition_op(self, op: torch._ops.OpOverload) -> None: +``` +Register a custom op to be considered supported by this +partitioner.