Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 32 additions & 2 deletions backends/arm/_passes/arm_pass_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,32 @@ class PassInsertions:
after_passes: list = field(default_factory=list)


_registered_pass_insertions: dict[type, PassInsertions] = {}


def register_pass_insertions_before(
target_pass_type: type, passes: list[ExportPass]
) -> None:
"""Register passes to be inserted before a target pass for all pipelines."""
if target_pass_type not in _registered_pass_insertions:
_registered_pass_insertions[target_pass_type] = PassInsertions()
_registered_pass_insertions[target_pass_type].before_passes.extend(passes)


def register_pass_insertions_after(
target_pass_type: type, passes: list[ExportPass]
) -> None:
"""Register passes to be inserted after a target pass for all pipelines."""
if target_pass_type not in _registered_pass_insertions:
_registered_pass_insertions[target_pass_type] = PassInsertions()
_registered_pass_insertions[target_pass_type].after_passes.extend(passes)

Comment thread
robell marked this conversation as resolved.

def clear_registered_pass_insertions() -> None:
"""Clear all globally registered pass insertions."""
_registered_pass_insertions.clear()
Comment thread
robell marked this conversation as resolved.


class ArmPassManager(PassManager):
def __init__(self, compile_spec: ArmCompileSpec) -> None:
self.compile_spec = compile_spec
Expand Down Expand Up @@ -319,13 +345,17 @@ def _configure_pass_insertions(self, exported_program: ExportedProgram) -> None:
"""Hook for subclasses to configure pass insertions. Called at the START
of pipeline construction, before any passes are added.

Subclasses should override this to call insert_passes_before/after.
Subclasses can override this to call insert_passes_before/after.

Args:
exported_program: The exported program being transformed

"""
pass
for pass_type, insertions in _registered_pass_insertions.items():
if insertions.before_passes:
self.insert_passes_before(pass_type, list(insertions.before_passes))
if insertions.after_passes:
self.insert_passes_after(pass_type, list(insertions.after_passes))

def add_passes(self, passes: Sequence[ExportPass | None]):
for p in passes:
Expand Down
6 changes: 6 additions & 0 deletions docs/source/backends/arm-ethos-u/arm-ethos-u-partitioner.md
Original file line number Diff line number Diff line change
Expand Up @@ -45,3 +45,9 @@ Args:
Returns:
- **PartitionResult**: The input program with nodes tagged for delegation
and a mapping of partition tags to delegation specs.

```python
def EthosUPartitioner.register_custom_partition_op(self, op: torch._ops.OpOverload) -> None:
```
Register a custom op to be considered supported by this
partitioner.
6 changes: 6 additions & 0 deletions docs/source/backends/arm-vgf/arm-vgf-partitioner.md
Original file line number Diff line number Diff line change
Expand Up @@ -45,3 +45,9 @@ Args:
Returns:
- **PartitionResult**: The input program with nodes tagged for delegation
and a mapping of partition tags to delegation specs.

```python
def VgfPartitioner.register_custom_partition_op(self, op: torch._ops.OpOverload) -> None:
```
Register a custom op to be considered supported by this
partitioner.
Loading