Add CUDA graph capture probe for iris collectives#527

Draft
mawad-amd wants to merge 2 commits into main from muhaawad/graph-capture-probe

Conversation

@mawad-amd
Collaborator

Summary

  • Adds tests/graph_capture_probe.py — a probe script that tests which iris operations can be captured in a CUDA graph
  • Uses hipStreamBeginCapture detection (authoritative from HIP runtime) plus fresh-data replay validation to catch stale results
  • Run with: torchrun --nproc_per_node=2 --standalone tests/graph_capture_probe.py

Results (MI355X, 2 GPUs)

| Operation | Status | Blocker |
|---|---|---|
| device_barrier | CAPTURABLE | |
| host_barrier | NOT CAPTURABLE | Uses NCCL |
| ccl.all_reduce (atomic) | NOT CAPTURABLE | refresh_peer_access CPU↔CUDA copy |
| ccl.all_reduce (two_shot) | NOT CAPTURABLE | same |
| ccl.all_reduce (one_shot) | NOT CAPTURABLE | same |
| ccl.all_gather | NOT CAPTURABLE | same |
| ccl.all_to_all | NOT CAPTURABLE | same |
| ccl.reduce_scatter | NOT CAPTURABLE | same |
| ops.matmul_all_reduce | NOT CAPTURABLE | same |

Root cause

SymmetricHeap.allocate() calls refresh_peer_access() every time, which does:

```python
self.heap_bases[self.cur_rank] = int(all_bases_arr[self.cur_rank])
```

This is a CPU↔CUDA tensor copy, which is illegal during hipStreamBeginCapture. It is triggered whenever a ctx.zeros() allocation happens inside the CCL launch path (workspace creation in the preamble).
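For illustration, here is a minimal toy sketch (plain-Python stand-ins, not the HIP runtime or iris's actual classes) of why this pattern fails: reading a device value on the host, as `int(all_bases_arr[rank])` does, is a synchronization point, and stream capture rejects host↔device synchronization.

```python
class CaptureError(RuntimeError):
    pass


class FakeStream:
    """Toy stand-in for a HIP stream with capture state."""
    def __init__(self):
        self.capturing = False

    def begin_capture(self):   # analogous to hipStreamBeginCapture
        self.capturing = True

    def end_capture(self):
        self.capturing = False


class FakeDeviceTensor:
    """Toy device tensor: item() models a host<->device sync/copy."""
    def __init__(self, value, stream):
        self.value = value
        self.stream = stream

    def item(self):
        if self.stream.capturing:
            raise CaptureError("host<->device copy during stream capture")
        return self.value


stream = FakeStream()
t = FakeDeviceTensor(42, stream)
assert t.item() == 42          # fine outside capture

stream.begin_capture()
try:
    t.item()                   # models int(all_bases_arr[rank]) inside capture
    failed = False
except CaptureError:
    failed = True
stream.end_capture()
assert failed
```

The real failure mode is the same shape: the capture succeeds or fails depending solely on whether the launch path touches the host.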

Fix direction

To make CCL ops graph-capturable, we need to:

  1. Ensure all workspace allocation happens before graph capture (pre-allocate via preamble)
  2. Make get_heap_bases() return a pre-built CUDA tensor without any CPU interaction during the kernel launch path
  3. Add async_op=True to skip the trailing ctx.barrier() (already supported)
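One way point 2 could look, sketched with hypothetical names (this is not iris's actual API): build the heap-bases handle once at init time, before any capture, and have get_heap_bases() return the cached object with zero host work on the launch path.

```python
class SymmetricHeapSketch:
    """Hypothetical sketch: cache heap bases so the launch path is host-free."""

    def __init__(self, all_bases):
        # Done once at init, before any graph capture: a single
        # host->device upload of the base addresses (modeled here
        # as an immutable tuple standing in for a CUDA tensor).
        self._heap_bases_device = tuple(all_bases)

    def get_heap_bases(self):
        # No CPU<->GPU traffic, no int() reads: just hand back the
        # cached device-side handle, which is legal during capture.
        return self._heap_bases_device


heap = SymmetricHeapSketch([0x7000, 0x8000])
# Repeated calls return the same cached object; nothing is rebuilt.
assert heap.get_heap_bases() is heap.get_heap_bases()
```

The design point is that allocation and address exchange move entirely into setup, so replaying the captured graph never re-runs refresh_peer_access().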

🤖 Generated with Claude Code

Copilot AI review requested due to automatic review settings April 30, 2026 07:57
@mawad-amd mawad-amd requested review from BKP and neoblizz as code owners April 30, 2026 07:57
@mawad-amd mawad-amd marked this pull request as draft April 30, 2026 07:57
@github-actions github-actions bot added the in-progress (We are working on it) and iris (Iris project issue) labels on Apr 30, 2026
Contributor

Copilot AI left a comment


Pull request overview

Note

Copilot was unable to run its full agentic suite in this review.

Adds a standalone probe script to detect which iris collective/ops are CUDA-graph capturable by attempting torch.cuda.graph() capture and validating correctness on replay.

Changes:

  • Introduces a try_capture harness that warms up, captures, replays, and validates operations to catch stale-result “false positives”.
  • Adds probe cases for iris barriers, CCL collectives (all_reduce variants, all_gather, all_to_all, reduce_scatter), and ops.matmul_all_reduce.
  • Prints a rank-0 summary table of capturability outcomes and truncates error details for readability.

Comment on lines +413 to +420
```python
def replay_setup():
    # Keep same A, B (matmul is deterministic for same inputs)
    A.copy_(A_ref)
    B.copy_(B_ref)

def validate():
    # Check output is non-zero and finite
    return C.abs().max().item() > 0 and torch.isfinite(C).all().item()
# ---------------------------------------------------------------------------


def try_capture(name, warmup_fn, capture_fn, reset_fn, replay_setup_fn, validate_fn, ctx, rank):
```
Comment on lines +148 to +155
```python
result_buf = ctx.zeros((64,), dtype=torch.float32)

def warmup():
    buf.fill_(float(rank + 1))
    ctx.device_barrier()
    # Read neighbor to prove barrier works
    neighbor = (rank + 1) % world_size
    heap_bases = ctx.get_heap_bases()
```
Comment on lines +34 to +39
```python
def setup():
    """Initialize distributed + iris."""
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    torch.cuda.set_device(local_rank)
    dist.init_process_group(backend="gloo")
    ctx = iris.iris(heap_size=1 << 30)
```
