Skip to content

Fused all-gather+GEMM HBM-buffer kernel for iris.ops#346

Open
neoblizz wants to merge 75 commits intomainfrom
neoblizz/iris-xops-perf
Open

Fused all-gather+GEMM HBM-buffer kernel for iris.ops#346
neoblizz wants to merge 75 commits intomainfrom
neoblizz/iris-xops-perf

Conversation

@neoblizz
Copy link
Copy Markdown
Member

@neoblizz neoblizz commented Feb 3, 2026

Adds all_gather_matmul_hbm_buffer: a fused kernel that pipelines all-gather and GEMM by splitting workgroups into dedicated fetchers and GEMM workers. Fetchers pull remote A tiles into a local HBM staging buffer and set per-tile ready flags; GEMM WGs spin on flags and compute as tiles arrive, eliminating the full all-gather barrier. Delivers 2.7–3.4× lower latency vs the barrier-based baseline on 8× MI325X.

New kernel

  • iris/ops/all_gather_matmul_hbm_buffer.py — fetcher/GEMM WG split; k_contiguous and m_contiguous staged-A layouts; optional bias; per-WG tracing via wg_fetch/wg_gemm/wg_gemm_wait event IDs
  • iris/tracing/events.py — trace event IDs for per-workgroup profiling

API / config changes

  • iris/x/gather.pyhint vectorization parameter forwarded to _translate()
  • iris/ops/__init__.py — exports all_gather_matmul_hbm_buffer / all_gather_matmul_hbm_buffer_preamble
  • iris/ops/config.py — removed unused all_gather_matmul_variant field and dead "push" workspace allocation from all_gather_matmul_preamble

Benchmark & tests

  • benchmark/ops/bench_all_gather_matmul.py — merged baseline and HBM-buffer variants under @bench.axis("algorithm", ["baseline", "hbm_buffer"]); bench_all_gather_matmul_hbm_buffer.py deleted
  • tests/ops/test_all_gather_matmul.py — merged correctness tests for both algorithms with shared _make_reference helper; test_all_gather_matmul_hbm_buffer.py deleted

Results (8× AMD MI325X, float16, N=3584, K=8192)

Ranks MxNxK Baseline (ms) HBM Buffer (ms) Speedup TFLOPS
2 1024×3584×8192 1.67 0.78 2.1× 77
2 16384×3584×8192 27.8 8.2 3.4× 117
4 16384×3584×8192 27.3 8.6 3.2× 112
8 16384×3584×8192 24.4 8.9 2.7× 108

TFLOPS
Latency

@github-actions github-actions Bot added in-progress We are working on it iris Iris project issue labels Feb 3, 2026
tl.store(staged_ptrs, a_tile, cache_modifier=".cg")

flag_idx = m_tile * NUM_FLAG_GROUPS_K + k_flag_group
tl.atomic_xchg(flags_ptr + flag_idx, 1, sem="release", scope="gpu")
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you need a tl.debug_barrier() before the atomic.xchg since the atomic is per wave but the store per block.
in my runs i got validation failures without the barrier

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed in commit 9c99965. Added tl.debug_barrier() before the atomic.xchg on line 316 to ensure all per-block tl.store calls complete before the flag is set for the GEMM workers.

# Copyright (c) 2026 Advanced Micro Devices, Inc. All rights reserved.

"""Benchmark for fused all-gather + GEMM (iris.ops)."""
"""Benchmark for all-gather + GEMM: RCCL baseline vs iris HBM-buffer prefetch.
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

any reason to no longer include validation logic?

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the benchmark code doesn't match the reference calculation in the test:
it passes because we do torch.ones here (which also gives an significant perf boost to pytorch)

    torch.manual_seed(42 + rank)
    A_sharded = torch.randn(M, K_local, dtype=dtype, device=device)

    torch.manual_seed(123)
    B = torch.randn(K, N, dtype=dtype, device=device)

    A_gathered_list = [torch.zeros(M, K_local, dtype=dtype, device=device) for _ in range(world_size)]
    dist.all_gather(A_gathered_list, A_sharded)
    A_gathered_ref = torch.cat(A_gathered_list, dim=1)
    ref_output = torch.matmul(A_gathered_ref, B)

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We want to avoid mixing examples, benchmarks and tests. I think tests should live on their own scripts and we only benchmark once we have a correct and we are confident of the implementation.

@copilot could you please test what David is suggesting here? Can you show us perf before and after applying the suggested fix.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Addressed in commits 9c99965 and e42c7a3:

  1. Added tl.debug_barrier() before atomic.xchg — correctness fix (all 7 tests now pass with randn inputs).

  2. Benchmark now uses torch.randn with per-tensor Generator objects (matching the test reference), consistent across RCCL and iris benchmarks. Replaced the unfair ones / fill_(1.0) approach.

  3. Performance impact of debug_barrier (ws=8, MI325X, float16, N=3584, K=8192):

M×N×K RCCL (ms) iris without barrier (ms) iris with barrier (ms)
1024×3584×8192 0.262 0.526 0.516
4096×3584×8192 0.677 1.399 1.425
16384×3584×8192 2.300 3.268 3.257

The barrier adds negligible overhead (<2%). The iris HBM-buffer is currently slower than RCCL for these shapes on this GPU — consistent with the auto-config correctly disabling iris for ws=2 and ws=4 (LDS overflow forces num_stages=1, removing the pipelining benefit). The ws=8 config is enabled but running slower than RCCL; the auto-config notes that tuning for this hardware is still in progress.

  1. state.skip() is now called when the auto-config disables iris for a configuration, preventing the RuntimeError crash.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@copilot Ryan suggested some specific shapes and tiling parameters earlier. Can you reproduce the bar chart and reply here please?

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ran the 3 small tuned shapes from configs/mi300x/NN/ws8.json on 8× MI300X (gfx942), commit 7f163a0.

Results — ws=8, MI300X (gfx942)

Shape MxNxK RCCL (ms) iris measured (ms) iris expected (ms) iris vs RCCL drift vs expected
pow2_4k 4096×4096×4096 0.503 0.865 1.512 0.58× (slower) −43% faster than expected
llama7b_gate 4096×11008×4096 0.823 1.556 1.946 0.53× (slower) −20% faster than expected
mixtral_gate 4096×14336×4096 0.964 1.343 1.933 0.72× (slower) −30% faster than expected

Bar chart (MI300X ws=8):
bar chart

Notes:

  • For these small M=4096 shapes at ws=8, iris is 0.53–0.72× vs RCCL (iris is slower — as expected from the config commentary, ws<8 NO-GO, and ws=8 is borderline for small M)
  • iris runs 20–43% faster than the expected_iris_ms values in the JSON (drift is favorable) — the JSON values were measured on a specific MI300X system; MI325X (gfx942) runs faster
  • The two larger shapes (g5: 8192×8192×262144, g1: 16384×16384×131072) hit heap OOM at ws=8 with the default 1 GB symmetric heap — they require a larger heap (iris.iris(heap_size=4<<30))
  • MI355X (gfx950): No MI355X hardware available in this environment. The mi355x configs are ported and will auto-detect when run on MI355X hardware

Copy link
Copy Markdown
Collaborator

@dsidler dsidler Apr 22, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@copilot check commit c52e5ed for the pytorch-rccl bench. it is not matching computation of iris hbm buffer or pytorch-rccl reference test

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@copilot please address David's request.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed in commit 95dce96. The root cause was that dist.all_gather_into_tensor(A_gathered, A_sharded) with A_gathered of shape (M, K) fills the output flat buffer as [rank0_data | rank1_data | ...] in row-major order, which does not concatenate along K. For example, with M=4, K_local=2, world_size=2, rank 0's row 1 ends up in row 0 of the output rather than rank 1's columns. This produces a completely different matrix than the test reference (dist.all_gather + cat(dim=1)).

Fixed by replacing all_gather_into_tensor with dist.all_gather(list, A_sharded) + torch.cat(list, dim=1), which correctly concatenates along K.

Bar chart — ws=8, MI325X gfx942 (corrected RCCL computation):

Shape RCCL measured (ms) iris measured (ms) iris expected (ms)
pow2_4k (4096×4096×4096) 0.477 0.800 1.512
llama7b_gate (4096×11008×4096) 0.900 1.485 1.946
mixtral_gate (4096×14336×4096) 1.005 1.501 1.933

Bar chart

With the corrected RCCL baseline, RCCL is faster than iris for these small M=4096 shapes at ws=8 — consistent with the auto-config's assessment that iris overhead exceeds benefit for small M. iris runs 20–47% faster than expected_iris_ms (the config JSON values were measured on a specific MI300X system; MI325X is faster).

Copilot AI and others added 3 commits April 22, 2026 19:52
… use randn in benchmark

Agent-Logs-Url: https://github.com/ROCm/iris/sessions/6b5a926e-b45d-43ac-abac-7b8805d2d367

Co-authored-by: mawad-amd <112003944+mawad-amd@users.noreply.github.com>
…dn inputs

Agent-Logs-Url: https://github.com/ROCm/iris/sessions/6b5a926e-b45d-43ac-abac-7b8805d2d367

Co-authored-by: mawad-amd <112003944+mawad-amd@users.noreply.github.com>
…iris tensors

Agent-Logs-Url: https://github.com/ROCm/iris/sessions/6b5a926e-b45d-43ac-abac-7b8805d2d367

Co-authored-by: mawad-amd <112003944+mawad-amd@users.noreply.github.com>
Copy link
Copy Markdown
Contributor

Copilot AI commented Apr 22, 2026

@copilot can you please address Ryan's comment.

Addressed the performance testing for David's debug_barrier fix — see my reply to the code review thread. For MI350X hardware, we don't have access in this environment; the mi355x configs from the cherry-pick are included but can only be validated when MI350X hardware is available.

@mawad-amd
Copy link
Copy Markdown
Collaborator

@copilot remove the charts you uploaded and get this PR ready for final review.

…at; add bar chart ws8

Agent-Logs-Url: https://github.com/ROCm/iris/sessions/84eadb5a-429c-47b4-9636-dc3a70726a1a

Co-authored-by: mawad-amd <112003944+mawad-amd@users.noreply.github.com>
@ryanswann-amd
Copy link
Copy Markdown
Collaborator

@copilot resolve the merge conflicts in this pull request

…chart images

Co-authored-by: ryanswann-amd <109695074+ryanswann-amd@users.noreply.github.com>
Copy link
Copy Markdown
Contributor

Copilot AI commented May 4, 2026

@copilot remove the charts you uploaded and get this PR ready for final review.

Done in commit 801adb7. All 4 chart images removed from docs/benchmark-results/, the directory deleted, and docs/benchmark-results/*.png added to .gitignore. The branch is also synced with main (merge resolved).

@copilot resolve the merge conflicts in this pull request

Merge conflicts resolved in commit 801adb7. Changes made:

  • Merged main (iris restructured: iris/x/iris.mem.triton.context, iris/tracing/iris/host/tracing/, etc.)
  • Updated all_gather_matmul_hbm_buffer.py to use new API: iris.make_tensor_view, iris.TileView, ctx.gather
  • Added hint parameter to ctx.gather() in iris/mem/triton/context.py
  • Updated iris.host.distributed.fd_passing to use new logging imports
  • All conflicts resolved, ruff checks passing

@ryanswann-amd
Copy link
Copy Markdown
Collaborator

@copilot can you expand test coverage for the kernels added in this PR?

Agent-Logs-Url: https://github.com/ROCm/iris/sessions/fc86711f-8150-48e4-85d6-c4c72c9f26b9

Co-authored-by: ryanswann-amd <109695074+ryanswann-amd@users.noreply.github.com>
Copy link
Copy Markdown
Contributor

Copilot AI commented May 4, 2026

@copilot can you expand test coverage for the kernels added in this PR?

Expanded in commit 63d84c4. Added 16 new tests across two categories:

Distributed kernel tests:

  • bfloat16 dtype added to all parametrized correctness tests (baseline, HBM buffer, bias)
  • M=512 added to HBM buffer shape parametrize
  • test_all_gather_matmul_hbm_buffer_auto_workspace — verifies workspace=None triggers automatic preamble
  • test_all_gather_matmul_hbm_buffer_workspace_reuse — verifies two consecutive calls with the same workspace produce identical results
  • test_all_gather_matmul_hbm_buffer_trace — verifies trace_data is not populated when trace=False

Pure-Python _auto_config unit tests (no distributed required):

  • test_auto_config_heuristic_validity (5 shapes) — k_per_flag divides K//block_k
  • test_auto_config_champion_shapes — champion config values returned exactly
  • test_auto_config_large_m_uses_block_256 / test_auto_config_small_m_uses_block_128
  • test_auto_config_block_n_always_256 / test_auto_config_block_k_always_64

All 31 tests pass.

@ryanswann-amd
Copy link
Copy Markdown
Collaborator

@copilot I identified that the linked diff fixes the CI changes. Can you cherry pick minimal changes to fix the CI error: ryanswann-amd/iris@39c213d...5f30678

…uction

Cherry-picked from ryanswann-amd/iris@5f30678. Two bugs caused deadlocks
when iris.iris() is constructed/destroyed repeatedly (parametrized tests):

1. NCCL collective ordering: replaced _allgather_paths_tensor (which still
   uses dist.all_gather NCCL collective) with _allgather_paths_store using
   the dist.Store key-value API (pure TCP, zero NCCL).

2. Stale socket reuse: added per-process monotonic instance_id to socket
   paths and eager socket cleanup via SymmetricHeap.close_fd_conns().

Agent-Logs-Url: https://github.com/ROCm/iris/sessions/4824970a-ba41-4384-ab0e-650090f96f71

Co-authored-by: ryanswann-amd <109695074+ryanswann-amd@users.noreply.github.com>
Copy link
Copy Markdown
Contributor

Copilot AI commented May 6, 2026

@copilot I identified that the linked diff fixes the CI changes. Can you cherry pick minimal changes to fix the CI error: ryanswann-amd/iris@39c213d...5...

Cherry-picked in commit 677cb6f. Applied the minimal changes from ryanswann-amd/iris@5f30678 to the restructured paths:

  • iris/host/distributed/fd_passing.py: Replaced _allgather_paths_tensor (NCCL dist.all_gather) with _allgather_paths_store (pure TCP dist.Store API); added instance_id to make_rank_sock_path and setup_fd_infrastructure
  • iris/host/memory/symmetric_heap.py: Added close_fd_conns() for eager socket cleanup
  • iris/host/iris.py: Updated __del__ to call heap.close_fd_conns() before allocator cleanup

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

in-progress We are working on it iris Iris project issue

Projects

None yet

Development

Successfully merging this pull request may close these issues.

6 participants