Fix hardcoded lock value in fused GEMM+CCL operations#529

Open
mawad-amd wants to merge 3 commits into main from muhaawad/fix-hardcoded-lock

Conversation

@mawad-amd
Collaborator

Summary

  • Replace hardcoded lock signal value 1 in fused GEMM+CCL producer-consumer signaling with a monotonically increasing call_counter on FusedWorkspace
  • Each call to matmul_all_reduce or matmul_reduce_scatter increments the counter and passes it to both producer (atomic_xchg) and consumer (spin loop) sides
  • Eliminates the need to zero locks between calls for one_shot, two_shot, and reduce_scatter variants
  • spinlock variant is unaffected (uses CAS mutex pattern that self-resets)
  • signal_value parameter defaults to 1 for backward compatibility with existing test kernels
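The counter mechanism described in the bullets above can be sketched in plain Python. This is an illustrative stand-in, not the actual Iris API: the class name `FusedWorkspace` matches the PR, but the `next_signal_value` helper and its exact wrap logic are assumptions (the PR only states that the counter wraps at INT32_MAX because lock tensors are int32):

```python
# Illustrative sketch of the per-workspace call counter described above.
# Names other than FusedWorkspace/call_counter are hypothetical.
INT32_MAX = 0x7FFFFFFF  # lock tensors are int32, so the counter wraps here


class FusedWorkspace:
    def __init__(self):
        self.call_counter = 0

    def next_signal_value(self) -> int:
        # Each fused-op call gets a fresh, non-zero signal value, so a
        # consumer can never mistake a stale lock value for a new signal.
        self.call_counter = self.call_counter % INT32_MAX + 1
        return self.call_counter

    def clear(self):
        # Resetting the workspace also resets the counter.
        self.call_counter = 0


ws = FusedWorkspace()
signals = [ws.next_signal_value() for _ in range(5)]
print(signals)  # counter advances 1..5 across consecutive calls
```

Note that the wrap maps INT32_MAX back to 1, never to 0, so a wrapped signal value still differs from a freshly zeroed lock.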

Test plan

  • 25/25 fused ops tests pass on MI300X (8 GPUs)
  • 54/54 context CCL tests pass on MI300X (8 GPUs)
  • Custom workspace reuse test: 5 consecutive calls without lock zeroing, all correct (max_diff=0.125, call_counter=1→5)

Closes #465

🤖 Generated with Claude Code

mawad-amd and others added 2 commits on May 2, 2026 at 08:14
Replace hardcoded lock signal value `1` with a monotonically increasing
call_counter on FusedWorkspace. Each call to matmul_all_reduce or
matmul_reduce_scatter increments the counter and passes it as the signal
value to both producer (atomic_xchg) and consumer (spin loop) sides.

This eliminates the need to zero locks between calls for one_shot,
two_shot, and reduce_scatter variants, since each call uses a unique
signal value that won't collide with previous calls.

The spinlock variant still uses CAS(0→1)/release(0) mutex semantics and
continues to require zeroed locks.

The signal_value parameter defaults to 1 for backward compatibility with
existing test kernels and examples that zero locks manually.

Closes #465

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
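The producer/consumer handshake described in the commit message above can be mimicked in plain Python. The real code runs inside Triton kernels and uses `atomic_xchg` on a device-side lock tensor; this threading-based sketch is only an analogue, and all names in it are illustrative:

```python
# Minimal Python analogue of the producer/consumer signaling described
# above. A lock-protected integer stands in for one int32 lock slot that
# the Triton kernels would update with atomic_xchg.
import threading

lock_cell = 0                     # one lock slot in the workspace
cell_guard = threading.Lock()     # simulates atomicity of the device op


def producer(signal_value: int):
    # Producer side: atomic_xchg(lock, signal_value) publishes the tile.
    global lock_cell
    with cell_guard:
        lock_cell = signal_value


def consumer(signal_value: int):
    # Consumer side: spin until the lock holds this call's signal value.
    # Because each call uses a unique value, a stale value left by a
    # previous call can never satisfy the wait -- no zeroing needed.
    while True:
        with cell_guard:
            if lock_cell == signal_value:
                return


for signal_value in (1, 2, 3):    # three consecutive fused-op calls
    t = threading.Thread(target=producer, args=(signal_value,))
    t.start()
    consumer(signal_value)        # returns once the producer has signaled
    t.join()
print("lock_cell after three calls:", lock_cell)
```

With the old hardcoded value of 1, the second iteration's consumer would have returned immediately on the stale `1` left by the first call unless the lock was zeroed in between; the per-call value removes that requirement.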
Copilot AI review requested due to automatic review settings May 2, 2026 15:33
@mawad-amd mawad-amd requested review from BKP and neoblizz as code owners May 2, 2026 15:33
@github-actions bot added the in-progress (We are working on it) and iris (Iris project issue) labels on May 2, 2026
Contributor

Copilot AI left a comment


Pull request overview

Note

Copilot was unable to run its full agentic suite in this review.

This PR replaces the hardcoded lock “ready” value (1) used in fused GEMM+CCL producer/consumer signaling with a monotonically increasing per-workspace call_counter, so locks don’t need to be zeroed between calls for one_shot/two_shot/reduce_scatter.

Changes:

  • Add call_counter to FusedWorkspace and reset it on clear().
  • Pass a per-call signal_value through fused matmul+collective kernels and into Triton context ops.
  • Remove lock zeroing for reduce_scatter and for all_reduce one_shot/two_shot (keep it for spinlock).

Reviewed changes

Copilot reviewed 4 out of 4 changed files in this pull request and generated 3 comments.

File Description
iris/ops/workspace.py Adds call_counter state to track per-call signal values and resets it in clear().
iris/ops/matmul_reduce_scatter.py Uses signal_value for producer atomic_xchg and consumer wait; removes lock zeroing.
iris/ops/matmul_all_reduce.py Uses per-call signal_value for one_shot/two_shot and conditionally zeroes locks only for spinlock.
iris/mem/triton/context.py Extends Triton context collectives to wait on signal_value instead of hardcoded 1.

even_k = K % config.block_size_k == 0

# Increment call counter for producer-consumer signal value.
# Each call uses a unique value so consumers don't see stale signals.
Collaborator Author


Fixed in 1044149 — call_counter now wraps at INT32_MAX (0x7FFFFFFF).

Comment on lines +260 to +262
# Increment call counter for producer-consumer signal value.
workspace.call_counter += 1
signal_value = workspace.call_counter
Collaborator Author


Workspaces are per-process objects — in distributed training, each rank runs its own Python process with its own workspace instance. The call_counter increments identically on all ranks because fused ops are collective (all ranks must call them together). Divergence would indicate a program bug (one rank skipping a collective call), which would deadlock regardless of the signal value.

Comment thread on iris/ops/matmul_all_reduce.py (Outdated)
BLOCK_SIZE_K: tl.constexpr,
EVEN_K: tl.constexpr,
VARIANT: tl.constexpr,
SIGNAL_VALUE=1,
Collaborator Author


Fixed in 1044149 — renamed to lowercase signal_value.

… INT32_MAX

- Rename kernel parameter from SIGNAL_VALUE (constexpr style) to signal_value
  (runtime parameter style) to avoid confusion with compile-time constants
- Wrap call_counter at INT32_MAX since lock tensors are int32

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

Labels

in-progress (We are working on it), iris (Iris project issue)

Projects

None yet

Development

Successfully merging this pull request may close these issues.

all_reduce_one_shot / all_reduce_two_shot use hardcoded lock value 1

2 participants