Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 52 additions & 2 deletions kernels/preshuffle_gemm.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from flydsl.expr import arith, vector
from flydsl.expr import gpu
from flydsl.expr import buffer_ops, rocdl
from flydsl.expr import math


from flydsl.expr.typing import T
Expand Down Expand Up @@ -142,7 +143,8 @@ def compile_preshuffle_gemm_a8(
waves_per_eu: Optional[int] = None,
use_async_copy: bool = False,
dsrd_preload: int = -1,
dvmem_preload: int = -1
dvmem_preload: int = -1,
epilogue: str = "none", # "none", "bias", "bias_relu", "bias_silu", "bias_gelu"
):
"""Compile the preshuffle GEMM kernel using the @flyc.kernel API.

Expand Down Expand Up @@ -306,13 +308,20 @@ def _out_elem():
allocator_pong.ptr = lds_alloc_offset + lds_total_elems * elem_bytes

# ── Kernel function ────────────────────────────────────────────────────
_has_epilogue = epilogue != "none"
_has_bias = epilogue in ("bias", "bias_relu", "bias_silu", "bias_gelu")
_has_relu = epilogue == "bias_relu"
_has_silu = epilogue == "bias_silu"
_has_gelu = epilogue == "bias_gelu"

@flyc.kernel
def kernel_gemm(
arg_c: fx.Tensor,
arg_a: fx.Tensor,
arg_b: fx.Tensor,
arg_scale_a: fx.Tensor,
arg_scale_b: fx.Tensor,
arg_bias: fx.Tensor,
i32_m: fx.Int32,
i32_n: fx.Int32,
):
Expand Down Expand Up @@ -395,6 +404,13 @@ def kernel_gemm(
_needs_per_token_scale = not is_f16_or_bf16 and not is_fp4
scale_a_rsrc = None if (is_f16_or_bf16) else buffer_ops.create_buffer_resource(
arg_scale_a, max_size=False)

# ---- Bias buffer resource (for fused epilogue) ----
bias_rsrc = None
if _has_bias:
_bias_nrec = arith.index_cast(T.i64, c_n * 2) # N elements * 2 bytes (bf16/fp16)
bias_rsrc = buffer_ops.create_buffer_resource(arg_bias, max_size=False,
num_records_bytes=_bias_nrec)
b_rsrc = buffer_ops.create_buffer_resource(arg_b, max_size=True)
scale_b_rsrc = None if (is_f16_or_bf16) else buffer_ops.create_buffer_resource(
arg_scale_b, max_size=True)
Expand Down Expand Up @@ -985,6 +1001,39 @@ def body_row(*, mi, ii, row_in_tile, row):
val_s = (val * s_a) * s_b_vals[ni]
else:
val_s = val

# ── Fused epilogue: bias + activation ──
if _has_bias and bias_rsrc is not None:
col_idx = col_base + (ni * 16)
bias_val_f16 = buffer_ops.buffer_load(
bias_rsrc, col_idx, vec_width=1,
dtype=_out_elem())
bias_val_f32 = arith.extf(T.f32, bias_val_f16)
val_s = val_s + bias_val_f32

if _has_relu:
zero_f32 = arith.constant(0.0, type=T.f32)
cmp = arith.cmpf("ogt", val_s, zero_f32)
val_s = arith.select(cmp, val_s, zero_f32)
elif _has_silu:
# SiLU(x) = x * sigmoid(x) = x / (1 + exp(-x))
neg_one = arith.constant(-1.0, type=T.f32)
neg_val = val_s * neg_one
exp_neg = math.exp(neg_val)
one_f32 = arith.constant(1.0, type=T.f32)
denom = one_f32 + exp_neg
val_s = val_s / denom
elif _has_gelu:
# GeLU approx: 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
half_f32 = arith.constant(0.5, type=T.f32)
coeff_f32 = arith.constant(0.044715, type=T.f32)
sqrt2pi_f32 = arith.constant(0.7978845608, type=T.f32)
x3 = val_s * val_s * val_s
inner = sqrt2pi_f32 * (val_s + coeff_f32 * x3)
tanh_inner = math.tanh(inner)
one_f32 = arith.constant(1.0, type=T.f32)
val_s = half_f32 * val_s * (one_f32 + tanh_inner)

val_f16 = arith.trunc_f(_out_elem(), val_s)
idx_out = idx_base + (ni * 16)
buffer_ops.buffer_store(val_f16, c_rsrc, idx_out)
Expand Down Expand Up @@ -1384,6 +1433,7 @@ def launch_gemm(
arg_b: fx.Tensor,
arg_scale_a: fx.Tensor,
arg_scale_b: fx.Tensor,
arg_bias: fx.Tensor,
i32_m: fx.Int32,
i32_n: fx.Int32,
stream: fx.Stream,
Expand All @@ -1398,7 +1448,7 @@ def launch_gemm(
gx = (i32_m + (tile_m - 1)) // tile_m
gy = i32_n // tile_n

launcher = kernel_gemm(arg_c, arg_a, arg_b, arg_scale_a, arg_scale_b, i32_m, i32_n)
launcher = kernel_gemm(arg_c, arg_a, arg_b, arg_scale_a, arg_scale_b, arg_bias, i32_m, i32_n)
if waves_per_eu is not None:
_wpe = int(waves_per_eu)
if _wpe >= 1:
Expand Down
112 changes: 112 additions & 0 deletions tests/kernels/test_preshuffle_gemm.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,12 +202,16 @@ def _pack_shuffled_int8_to_packed_int4_no_perm(x_shuf_i8):
def _as_i8(t):
return t.view(torch.int8) if "float8" in str(t.dtype) else t

# Create a dummy bias tensor (unused when epilogue="none")
_dummy_bias = torch.empty(0, dtype=out_dtype, device=a_q.device)

def _gemm_args(c, a, b, sa, sb):
    # Assemble the positional launch tuple: flattened operands, the dummy
    # bias placeholder, and the (M, N, stream) launch metadata.
    # Scale tensors may be empty (non-quantized path); leave those as-is.
    def _flat(t):
        return t.contiguous().view(-1)

    def _flat_scale(t):
        return _flat(t) if t.numel() > 0 else t

    return (_flat(c),
            _as_i8(_flat(a)),
            _as_i8(_flat(b)),
            _flat_scale(sa),
            _flat_scale(sb),
            _dummy_bias,
            M, N, torch.cuda.current_stream())

compiled_fn = flyc.compile(launch_fn, *_gemm_args(c_out_raw, a_q, b_input, sa_flat, sb_flat))
Expand Down Expand Up @@ -362,12 +366,16 @@ def _to_bytes(t):
return t
return t.view(torch.uint8)

# Create a dummy bias tensor (unused when epilogue="none")
_dummy_bias_w4 = torch.empty(0, dtype=torch.bfloat16, device=a_q.device)

def _w4_args(c, a, b, sa, sb):
    # Byte-reinterpret the quantized operands and scales, flatten everything,
    # then append the dummy bias plus the (M, N, stream) launch metadata.
    def _flat(t):
        return t.contiguous().view(-1)

    return (_flat(c),
            _flat(_to_bytes(a)),
            _flat(_to_bytes(b)),
            _flat(_to_bytes(sa)),
            _flat(_to_bytes(sb)),
            _dummy_bias_w4,
            M, N, torch.cuda.current_stream())

compiled_fn = flyc.compile(launch_fn, *_w4_args(c_out, a_q, b_shuffled, scale_a, scale_b_shuffled))
Expand Down Expand Up @@ -464,3 +472,107 @@ def launch_kernel(c, a, b, sa, sb):
)
except pytest.skip.Exception as e:
print(f"Skipped: {e}")


# ── CUDAGraph Capture Test ────────────────────────────────────────────────

@pytest.mark.parametrize("in_dtype", ["bf16", "fp8"])
def test_cudagraph_capture_preshuffle(in_dtype):
    """Verify FlyDSL preshuffle GEMM kernels are captured by CUDAGraph.

    This test ensures that passing torch.cuda.current_stream() correctly
    routes the kernel launch to the capture stream during graph recording.
    Without proper stream handling, CUDAGraph replay produces all-zeros.
    """
    device = "cuda:0"
    M, N, K = 1, 8192, 8192
    tile_m, tile_n, tile_k = 16, 64, 256

    arch = str(get_rocm_arch())
    if not arch.startswith("gfx94") and not arch.startswith("gfx95"):
        pytest.skip(f"Unsupported arch: {arch}")

    # ── Prepare data ──
    a_raw = torch.randn(M, K, dtype=torch.bfloat16, device=device)
    b_raw = torch.randn(N, K, dtype=torch.bfloat16, device=device)

    if in_dtype == "fp8":
        # Per-token quantize to fp8, then view as int8 for the kernel ABI.
        a_q, scale_a = pertoken_quant(a_raw, dtype=torch.float8_e4m3fnuz)
        b_q, scale_b = pertoken_quant(b_raw, dtype=torch.float8_e4m3fnuz)
        a_q = a_q.view(torch.int8)
        b_input = shuffle_weight(b_q.view(torch.int8), layout=(16, 16)).contiguous().view(-1)
        sa_flat = scale_a.contiguous().view(-1)
        sb_flat = scale_b.contiguous().view(-1)
    else:
        # bf16 path uses no per-token scales; pass empty placeholders.
        a_q = a_raw
        b_input = shuffle_weight(b_raw.contiguous(), layout=(16, 16)).contiguous().view(-1)
        sa_flat = torch.empty(0, dtype=torch.float32, device=device)
        sb_flat = torch.empty(0, dtype=torch.float32, device=device)

    c_out = torch.empty(M, N, dtype=torch.bfloat16, device=device)
    # Unused when epilogue="none", but required by the kernel signature.
    _dummy_bias = torch.empty(0, dtype=torch.bfloat16, device=device)

    # Compile kernel
    launch_fn = compile_preshuffle_gemm_a8(
        M=M, N=N, K=K,
        tile_m=tile_m, tile_n=tile_n, tile_k=tile_k,
        in_dtype=in_dtype,
        epilogue="none",
    )

    def _args(c, a, b, sa, sb):
        # Flatten operands; scales may be empty (bf16 path), so guard on numel.
        # NOTE: the original conditional on a.dtype was dead code — both
        # branches were identical — so the flatten is now unconditional.
        return (c.contiguous().view(-1),
                a.contiguous().view(-1),
                b,
                sa.contiguous().view(-1) if sa.numel() > 0 else sa,
                sb.contiguous().view(-1) if sb.numel() > 0 else sb,
                _dummy_bias,
                M, N, torch.cuda.current_stream())

    compiled_fn = flyc.compile(launch_fn, *_args(c_out, a_q, b_input, sa_flat, sb_flat))

    # Warmup
    compiled_fn(*_args(c_out, a_q, b_input, sa_flat, sb_flat))
    torch.cuda.synchronize()

    # ── Regular execution (reference) ──
    c_out.zero_()
    compiled_fn(*_args(c_out, a_q, b_input, sa_flat, sb_flat))
    torch.cuda.synchronize()
    ref = c_out.clone()
    assert ref.abs().max().item() > 0, "Regular execution produced all zeros"

    # ── CUDAGraph capture ──
    g = torch.cuda.CUDAGraph()
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())

    # Warmup on capture stream (required before recording a graph).
    with torch.cuda.stream(s):
        compiled_fn(*_args(c_out, a_q, b_input, sa_flat, sb_flat))
    torch.cuda.current_stream().wait_stream(s)
    torch.cuda.synchronize()

    # Record
    c_out.zero_()
    with torch.cuda.graph(g, stream=s):
        compiled_fn(*_args(c_out, a_q, b_input, sa_flat, sb_flat))
    torch.cuda.synchronize()

    # ── Replay ──
    c_out.zero_()
    g.replay()
    torch.cuda.synchronize()
    graph_result = c_out.clone()

    # ── Verify: replay must reproduce the eager result, not zeros ──
    max_diff = (ref - graph_result).abs().max().item()
    assert graph_result.abs().max().item() > 0, (
        f"CUDAGraph replay produced all zeros — kernel was NOT captured! "
        f"ref max={ref.abs().max().item():.4f}"
    )
    assert torch.allclose(ref, graph_result, atol=1e-2), (
        f"CUDAGraph result mismatch: max_diff={max_diff:.6f}, "
        f"ref max={ref.abs().max().item():.4f}, graph max={graph_result.abs().max().item():.4f}"
    )
    print(f"✓ CUDAGraph capture verified ({in_dtype}): max_diff={max_diff:.6f}")
Loading