Fix hardcoded lock value in fused GEMM+CCL operations#529
Fix hardcoded lock value in fused GEMM+CCL operations#529
Conversation
Replace hardcoded lock signal value `1` with a monotonically increasing call_counter on FusedWorkspace. Each call to matmul_all_reduce or matmul_reduce_scatter increments the counter and passes it as the signal value to both producer (atomic_xchg) and consumer (spin loop) sides. This eliminates the need to zero locks between calls for one_shot, two_shot, and reduce_scatter variants, since each call uses a unique signal value that won't collide with previous calls. The spinlock variant still uses CAS(0→1)/release(0) mutex semantics and continues to require zeroed locks. The signal_value parameter defaults to 1 for backward compatibility with existing test kernels and examples that zero locks manually. Closes #465 Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
There was a problem hiding this comment.
Pull request overview
Note
Copilot was unable to run its full agentic suite in this review.
This PR replaces the hardcoded lock “ready” value (1) used in fused GEMM+CCL producer/consumer signaling with a monotonically increasing per-workspace call_counter, so locks don’t need to be zeroed between calls for one_shot/two_shot/reduce_scatter.
Changes:
- Add
call_countertoFusedWorkspaceand reset it onclear(). - Pass a per-call
signal_valuethrough fused matmul+collective kernels and into Triton context ops. - Remove lock zeroing for reduce_scatter and for all_reduce one_shot/two_shot (keep it for spinlock).
Reviewed changes
Copilot reviewed 4 out of 4 changed files in this pull request and generated 3 comments.
| File | Description |
|---|---|
| iris/ops/workspace.py | Adds call_counter state to track per-call signal values and resets it in clear(). |
| iris/ops/matmul_reduce_scatter.py | Uses signal_value for producer atomic_xchg and consumer wait; removes lock zeroing. |
| iris/ops/matmul_all_reduce.py | Uses per-call signal_value for one_shot/two_shot and conditionally zeroes locks only for spinlock. |
| iris/mem/triton/context.py | Extends Triton context collectives to wait on signal_value instead of hardcoded 1. |
| even_k = K % config.block_size_k == 0 | ||
|
|
||
| # Increment call counter for producer-consumer signal value. | ||
| # Each call uses a unique value so consumers don't see stale signals. |
There was a problem hiding this comment.
Fixed in 1044149 — call_counter now wraps at INT32_MAX (0x7FFFFFFF).
| # Increment call counter for producer-consumer signal value. | ||
| workspace.call_counter += 1 | ||
| signal_value = workspace.call_counter |
There was a problem hiding this comment.
Workspaces are per-process objects — in distributed training, each rank runs its own Python process with its own workspace instance. The call_counter increments identically on all ranks because fused ops are collective (all ranks must call them together). Divergence would indicate a program bug (one rank skipping a collective call), which would deadlock regardless of the signal value.
| BLOCK_SIZE_K: tl.constexpr, | ||
| EVEN_K: tl.constexpr, | ||
| VARIANT: tl.constexpr, | ||
| SIGNAL_VALUE=1, |
There was a problem hiding this comment.
Fixed in 1044149 — renamed to lowercase signal_value.
… INT32_MAX - Rename kernel parameter from SIGNAL_VALUE (constexpr style) to signal_value (runtime parameter style) to avoid confusion with compile-time constants - Wrap call_counter at INT32_MAX since lock tensors are int32 Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Summary
1in fused GEMM+CCL producer-consumer signaling with a monotonically increasingcall_counteronFusedWorkspacematmul_all_reduceormatmul_reduce_scatterincrements the counter and passes it to both producer (atomic_xchg) and consumer (spin loop) sidesone_shot,two_shot, andreduce_scattervariantsspinlockvariant is unaffected (uses CAS mutex pattern that self-resets)signal_valueparameter defaults to1for backward compatibility with existing test kernelsTest plan
Closes #465
🤖 Generated with Claude Code