Validate symmetric heap placement in CCL collectives #526
Conversation
Each collective now checks that remotely-accessed tensors are on the symmetric heap. Input tensors that are remote-read get auto-imported via `as_symmetric()`; output tensors that are remote-written must be pre-allocated on the heap (auto-import would silently discard results with the torch allocator).

Per-collective rules based on kernel access patterns:

- all_gather: output validated (remote write), input unchecked (local)
- all_to_all: input auto-imported (remote read), output validated
- all_reduce: variant-dependent (e.g. two_shot: both, ring: neither)
- reduce_scatter: input auto-imported (remote read), output unchecked

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
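As an illustration only (not the PR's actual code), the per-collective rules above can be sketched as a policy table; the `SYMMETRIC_HEAP_POLICY` name and `policy_for` helper are hypothetical:

```python
# Hypothetical sketch of the access-pattern policy described above;
# the names here are illustrative, not the PR's implementation.
IMPORT, VALIDATE, NONE = "auto-import", "validate", "unchecked"

# (input_policy, output_policy) per collective / all_reduce variant.
SYMMETRIC_HEAP_POLICY = {
    "all_gather": (NONE, VALIDATE),        # output is remote-written
    "all_to_all": (IMPORT, VALIDATE),      # input remote-read, output remote-written
    "reduce_scatter": (IMPORT, NONE),      # input remote-read, output local
    "all_reduce/one_shot": (IMPORT, NONE),
    "all_reduce/two_shot": (IMPORT, VALIDATE),
    "all_reduce/atomic": (NONE, VALIDATE),
    "all_reduce/spinlock": (NONE, VALIDATE),
    "all_reduce/ring": (NONE, NONE),       # internal workspace only
}


def policy_for(name):
    """Return (input_policy, output_policy) for a collective name."""
    return SYMMETRIC_HEAP_POLICY[name]
```

The table makes the variant-dependent all_reduce behavior explicit instead of scattering it across `if variant in (...)` checks.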
Pull request overview
Note
Copilot was unable to run its full agentic suite in this review.
Adds symmetric-heap validation/import helpers and applies them across CCL collectives so tensors that are remotely accessed use the symmetric heap (auto-import for remote-reads, strict validation for remote-writes), plus new all_gather tests to enforce expected behavior.
Changes:

- Introduce `_ensure_symmetric()` and `_validate_output_symmetric()` utilities in `iris/ccl/utils.py`.
- Update CCL collectives to auto-import remote-read inputs and reject non-symmetric remote-write outputs.
- Add tests ensuring `all_gather` rejects non-symmetric outputs but allows non-symmetric inputs.
Reviewed changes
Copilot reviewed 6 out of 6 changed files in this pull request and generated 12 comments.
| File | Description |
|---|---|
| tests/ccl/test_all_gather.py | Adds test coverage for symmetric-heap enforcement semantics in all_gather. |
| iris/ccl/utils.py | Adds helper utilities to import/validate symmetric-heap placement. |
| iris/ccl/reduce_scatter.py | Auto-imports remote-read input tensor onto symmetric heap. |
| iris/ccl/all_to_all.py | Auto-imports remote-read input and validates remote-written output on symmetric heap. |
| iris/ccl/all_reduce.py | Variant-specific symmetric-heap import/validation based on kernel access patterns. |
| iris/ccl/all_gather.py | Validates output tensor is symmetric-heap allocated (remote-writes). |
```python
"""Return tensor on symmetric heap, importing if needed.

For input tensors that are only read: the kernel reads from the heap
copy, so the caller doesn't need the returned tensor back.

Do NOT use this for output tensors — see _validate_output_symmetric.
"""
if ctx.is_symmetric(tensor):
    return tensor
return ctx.as_symmetric(tensor)
```
`name` is currently unused in `_ensure_symmetric`, which can trigger lint warnings and is misleading. Either remove the parameter, or use it (e.g., in a debug/log message, or in an error if `as_symmetric` can fail). Also consider rewording the docstring: while users may not need the returned tensor, the caller's kernel launch must use the returned value—otherwise the import has no effect.
Suggested change:

```python
"""Return ``tensor`` on the symmetric heap, importing if needed.

If ``tensor`` is not already symmetric, this returns the imported heap
copy. The caller must use the returned tensor for the subsequent kernel
launch; otherwise the import has no effect.

For read-only inputs, users may not need to keep the returned tensor after
the kernel launch completes.

Do NOT use this for output tensors — see _validate_output_symmetric.
"""
if ctx.is_symmetric(tensor):
    return tensor
try:
    return ctx.as_symmetric(tensor)
except Exception as exc:
    raise type(exc)(f"Failed to import {name} to the symmetric heap: {exc}") from exc
```
```python
if not ctx.is_symmetric(tensor):
    raise ValueError(
        f"{name} must be on the symmetric heap. "
        f"Allocate with ctx.zeros() or import with ctx.as_symmetric() before calling."
```
The message suggests ctx.as_symmetric() can be used as a pre-step, but it does not make it explicit that as_symmetric returns a new tensor that must be passed to the collective. Consider clarifying to avoid callers doing ctx.as_symmetric(output_tensor) without using the returned tensor (which would still fail validation or would lead to confusion).
Suggested change:

```python
f"Allocate it with ctx.zeros(), or assign the result of "
f"ctx.as_symmetric({name}) to a new tensor and pass that returned "
f"tensor to the collective."
```
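To illustrate why the returned tensor must actually be used, here is a minimal sketch with a stub context; `StubCtx` and its list-based "heap" are illustrations for demonstration only, not the iris API:

```python
# Stub context mimicking the import semantics under discussion:
# as_symmetric() returns a NEW object; the original is unchanged.
class StubCtx:
    def __init__(self):
        self._heap = set()

    def is_symmetric(self, tensor):
        return id(tensor) in self._heap

    def as_symmetric(self, tensor):
        heap_copy = list(tensor)  # a fresh copy, as on the real heap
        self._heap.add(id(heap_copy))
        return heap_copy


ctx = StubCtx()
output = [0.0, 0.0]

# Wrong: the returned heap copy is discarded; `output` is still non-symmetric.
ctx.as_symmetric(output)
assert not ctx.is_symmetric(output)

# Right: rebind to the returned tensor and pass *that* to the collective.
output = ctx.as_symmetric(output)
assert ctx.is_symmetric(output)
```

This is exactly the confusion the error message should head off: calling `ctx.as_symmetric(output_tensor)` without using the return value changes nothing.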
```python
from iris.ccl.utils import extract_group_info, _ensure_symmetric


def reduce_scatter(output_tensor, input_tensor, ctx, op=None, group=None, async_op=False, config=None):
```
The function docstring says input_tensor ... must be on symmetric heap, but the implementation now auto-imports via _ensure_symmetric. Update the docstring/comments to reflect the actual behavior (e.g., 'will be imported to the symmetric heap if needed') so callers understand non-symmetric inputs are accepted and will incur a copy.
```python
Args:
    output_tensor: Shape (M, N)
    input_tensor: Shape (M, N) — must be on symmetric heap
```
```python
# Input is remote-read by all ranks — must be on symmetric heap
input_tensor = _ensure_symmetric(ctx, input_tensor, "input_tensor")
```
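Per that suggestion, a reworded docstring could look like the following sketch; the exact wording is an assumption, not the PR's actual text:

```python
def reduce_scatter(output_tensor, input_tensor, ctx, op=None, group=None,
                   async_op=False, config=None):
    """Reduce-scatter across ranks.

    Args:
        output_tensor: Shape (M, N)
        input_tensor: Shape (M, N). Will be imported to the symmetric heap
            if needed; non-symmetric inputs are accepted and incur a copy.
    """
    ...
```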
```python
heap_size = 2**30
shmem = iris.iris(heap_size)
rank = shmem.get_rank()
```
These tests select the CUDA device using the global rank (cuda:{rank}), which will fail on multi-node runs (global rank can exceed the local GPU count on a node). Prefer using the local device mapping (e.g., LOCAL_RANK / torch.cuda.current_device() or a helper/fixture already used elsewhere in the test suite) to make the tests robust across single-node and multi-node distributed environments.
```python
# Input on symmetric heap (fine)
iris_input = shmem.zeros((M, N), dtype=torch.float32)
# Output on regular CUDA memory (NOT on symmetric heap)
bad_output = torch.zeros(world_size * M, N, dtype=torch.float32, device=f"cuda:{rank}")
```
```python
M, N = 128, 64

# Input on regular CUDA memory (NOT on symmetric heap — that's fine for all_gather)
external_input = torch.randn(M, N, dtype=torch.float32, device=f"cuda:{rank}")
```
```python
iris_output = shmem.zeros((world_size * M, N), dtype=torch.float32)

# Reference via PyTorch
pytorch_output = torch.zeros(world_size * M, N, dtype=torch.float32, device=f"cuda:{rank}")
```
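One way to address this is a small helper that maps the process to a node-local GPU; `LOCAL_RANK` is the environment variable torchrun sets per process on each node, while the helper name itself is an assumption for this sketch:

```python
import os


def local_device(default: int = 0) -> str:
    """Map this process to a node-local CUDA device string.

    torchrun sets LOCAL_RANK per process on each node, so it stays
    within the node's GPU count even when the global rank does not.
    """
    local_rank = int(os.environ.get("LOCAL_RANK", default))
    return f"cuda:{local_rank}"


# In the tests, replace device=f"cuda:{rank}" with device=local_device().
```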
```python
# Ensure/validate symmetric only for tensors that are remotely accessed per variant
if variant in ("one_shot", "two_shot"):
    # Input is remote-read — auto-import if needed
    input_tensor = _ensure_symmetric(ctx, input_tensor, "input_tensor")
if variant in ("atomic", "spinlock", "two_shot"):
    # Output is remote-written — must be pre-allocated on heap
    _validate_output_symmetric(ctx, output_tensor, "output_tensor")
```
New variant-dependent behavior was introduced (auto-import for one_shot/two_shot inputs; strict validation for atomic/spinlock/two_shot outputs). Consider adding focused tests for at least one variant in each category to prevent regressions and to verify the intended access-pattern policy.
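A focused test for the output-validation category could assert raise/no-raise behavior per variant. This harness-agnostic sketch uses a stub context and a validator that mirrors the `_validate_output_symmetric` body shown earlier; both are assumptions, not the PR's actual test code:

```python
class NonSymmetricCtx:
    """Stub context whose tensors are never on the symmetric heap."""

    def is_symmetric(self, tensor):
        return False


def _validate_output_symmetric(ctx, tensor, name):
    # Mirrors the validator shown in the diff above (message abbreviated).
    if not ctx.is_symmetric(tensor):
        raise ValueError(f"{name} must be on the symmetric heap.")


def variant_validates_output(variant) -> bool:
    """True iff the variant's output path raises for a non-symmetric tensor."""
    ctx = NonSymmetricCtx()
    if variant in ("atomic", "spinlock", "two_shot"):
        try:
            _validate_output_symmetric(ctx, None, "output_tensor")
        except ValueError:
            return True
    return False
```

A pytest version would parametrize over the variants and use `pytest.raises(ValueError)` around the real collective call with a non-symmetric output.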
Summary

- Add `_ensure_symmetric()` and `_validate_output_symmetric()` helpers to `iris/ccl/utils.py`
- all_gather: output validated (remote-written via `iris.store`), input unchecked (local read)
- all_reduce: `one_shot`/`two_shot` ensure input, `atomic`/`spinlock`/`two_shot` validate output, `ring` checks neither (workspace is internal)
- reduce_scatter: input auto-imported (remote-read via `iris.load`), output unchecked (local write)
- Auto-import uses `as_symmetric()` (no-op if already on heap)

Test plan

- `test_all_gather_rejects_non_symmetric_output` — validates `ValueError` for non-symmetric output
- `test_all_gather_non_symmetric_input_ok` — validates non-symmetric input works (local-only access)
- Run: `torchrun --nproc_per_node=4 tests/run_tests_distributed.py tests/ccl/ -v`

🤖 Generated with Claude Code