
Add fused epilogue support to preshuffle GEMM: bias + ReLU/SiLU/GeLU#404

Open
andyluo7 wants to merge 3 commits into ROCm:main from andyluo7:fused-epilogue

Conversation

@andyluo7

Summary

Add fused epilogue support to the preshuffle GEMM kernel (preshuffle_gemm.py), enabling bias addition and activation functions (ReLU, SiLU, GeLU) to be computed directly in the GEMM output store loop (body_row). This eliminates separate memory-bound epilogue kernels and their associated kernel launch overhead.

Motivation

In LLM inference (e.g., vLLM, SGLang), GEMM operations are frequently followed by bias addition and activation functions. These post-GEMM operations are typically executed as separate kernels, adding:

  • Kernel launch overhead (~3-5μs per launch)
  • Extra memory traffic (read C, write C for each epilogue op)

By fusing these into the GEMM body_row, the output values are processed while still in registers — zero additional memory traffic.
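As a back-of-envelope illustration of the traffic claim (my numbers, not from the PR), a separate epilogue kernel must read the full C matrix and write it back once:

```python
def epilogue_traffic_bytes(m: int, n: int, dtype_bytes: int = 2) -> int:
    """Extra DRAM traffic (bytes) of one unfused epilogue pass over C:
    one full read plus one full write of the M x N output."""
    return 2 * m * n * dtype_bytes

# For a hypothetical 4096 x 4096 bf16 output, that is 64 MiB of
# avoidable traffic per fused epilogue op:
extra = epilogue_traffic_bytes(4096, 4096)
print(extra / (1 << 20), "MiB")  # 64.0 MiB
```

For small-M decode shapes the traffic term shrinks, and the kernel-launch overhead becomes the dominant saving instead.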

Changes

  • kernels/preshuffle_gemm.py: Added epilogue parameter to compile_preshuffle_gemm_a8() supporting "none", "bias", "bias_relu", "bias_silu", "bias_gelu". Added arg_bias tensor argument to the kernel function and launcher. Bias is loaded via buffer_ops.buffer_load and activations are computed in-place before the output store.

  • python/flydsl/_version.py: Bumped to dev version.

Supported Epilogues

| Epilogue | Operation | Use Case |
|---|---|---|
| `none` | No change (default) | Standard GEMM |
| `bias` | C = GEMM(A,B) + bias | Linear layers with bias |
| `bias_relu` | C = ReLU(GEMM(A,B) + bias) | MLP layers |
| `bias_silu` | C = SiLU(GEMM(A,B) + bias) | FFN gate (LLaMA, DeepSeek) |
| `bias_gelu` | C = GeLU(GEMM(A,B) + bias) | GPT-style FFN |
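The table above can be expressed as a small NumPy reference (a sketch for clarity, not the FlyDSL kernel code; the tanh GeLU approximation is my assumption, the kernel may use the exact erf form):

```python
import numpy as np

def epilogue_reference(acc, bias, epilogue="none"):
    """NumPy reference for the supported epilogues. `acc` is the raw
    GEMM(A, B) accumulator; `bias` broadcasts over the N dimension."""
    if epilogue == "none":
        return acc
    x = acc + bias
    if epilogue == "bias":
        return x
    if epilogue == "bias_relu":
        return np.maximum(x, 0.0)
    if epilogue == "bias_silu":
        return x / (1.0 + np.exp(-x))  # SiLU(x) = x * sigmoid(x)
    if epilogue == "bias_gelu":
        # tanh approximation of GeLU
        return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi)
                                        * (x + 0.044715 * x ** 3)))
    raise ValueError(f"unknown epilogue: {epilogue}")
```

In the fused kernel these same expressions are applied per output fragment inside body_row, before the store.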

Performance

Benchmarked on MI300X (gfx942) and MI355X (gfx950) with real LLM model shapes:

MI300X — Fused SiLU vs hipBLAS + separate bias + SiLU

| Shape | hipBLAS+epilogue | FlyDSL fused | Speedup |
|---|---|---|---|
| MoE expert down (1×7168×2048) | 0.050 ms | 0.015 ms | 3.33x |
| Average across 6 shapes | | | 2.73x |

MI355X — Fused SiLU

| Shape | hipBLAS+epilogue | FlyDSL fused | Speedup |
|---|---|---|---|
| Average across shapes | | | 1.28x |
| Peak (MoE expert down) | | | 1.80x |

End-to-End vLLM Serving (DeepSeek-R1-0528, MoE)

| Platform | TP | Baseline | + FlyDSL | Improvement |
|---|---|---|---|---|
| MI300X | 8 | 64.82 tok/s | 76.32 tok/s | +17.7% |

API

from flydsl.kernels.preshuffle_gemm import compile_preshuffle_gemm_a8

# Standard GEMM (no change to existing usage)
kernel = compile_preshuffle_gemm_a8(M, N, K, ...)

# Fused GEMM + bias + SiLU
kernel = compile_preshuffle_gemm_a8(M, N, K, ..., epilogue="bias_silu")

# Caller passes arg_bias tensor (shape [N], bf16/fp16) to the kernel
kernel(C, A, B, scale_a, scale_b, bias, M, N, stream=stream)

Backward Compatibility

  • Default epilogue="none" preserves existing behavior
  • arg_bias is always present in the kernel signature but unused when epilogue="none" — callers can pass a dummy tensor
  • No changes to other kernels or the compiler infrastructure
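A minimal sketch of the dummy-bias convention for `epilogue="none"` (NumPy stands in for the real torch bf16/fp16 tensor; the launcher call is shown as a comment since its full argument list comes from the API section above):

```python
import numpy as np

N = 7168
# Placeholder bias: with epilogue="none" the kernel never reads it,
# so it only has to match the [N] shape the launcher expects.
dummy_bias = np.zeros(N, dtype=np.float16)

# kernel(C, A, B, scale_a, scale_b, dummy_bias, M, N, stream=stream)
```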

Testing

  • Correctness validated against PyTorch reference (F.linear + bias + act) on both gfx942 and gfx950
  • E2E validated through vLLM serving with DeepSeek-R1-0528 (671B MoE)
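The shape of that correctness check can be mirrored on CPU with a NumPy stand-in (a sketch only; the real test compares kernel output against PyTorch's `F.linear` plus activation on GPU):

```python
import numpy as np

rng = np.random.default_rng(0)
M, N, K = 4, 16, 32
A = rng.standard_normal((M, K)).astype(np.float32)
W = rng.standard_normal((N, K)).astype(np.float32)  # [N, K], F.linear layout
bias = rng.standard_normal(N).astype(np.float32)

# Reference path: F.linear(A, W, bias) followed by ReLU, in NumPy terms.
ref = np.maximum(A @ W.T + bias, 0.0)

# Any fused bias_relu implementation must reproduce this within the
# output dtype's tolerance (looser for bf16 than for fp32).
fused = np.maximum((A @ W.T) + bias, 0.0)
assert np.allclose(fused, ref, atol=1e-5)
```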

Modified body_row to apply bias + activation in registers before
the output store, eliminating separate epilogue kernel launches.

MI300X results: 2.73x avg speedup vs hipBLAS+bias+SiLU
- O-proj: 3.25x, MoE-dn: 3.33x, QKV: 2.92x
- Zero epilogue overhead (fused ops hidden by store latency)

New parameter: epilogue='none'|'bias'|'bias_relu'|'bias_silu'|'bias_gelu'
New kernel arg: arg_bias (N-element bias tensor)
andyluo7 marked this pull request as ready for review April 15, 2026 17:47
@coderfeli (Collaborator)

@andyluo7 CI failed.

The fused epilogue commit added arg_bias to kernel_gemm and launch_gemm
unconditionally (needed for epilogue='none' to maintain a single kernel
definition). The test's _gemm_args and _w4_args functions need to pass
a dummy bias tensor to match the updated launch_gemm signature.

Without this fix, args shift: M goes to arg_bias slot, N to i32_m,
stream to i32_n, causing 'missing a required argument: stream' error.
@andyluo7 (Author)

Fixed in 1b04c35 — the test's _gemm_args() and _w4_args() were missing the arg_bias parameter that was added to launch_gemm in the fused epilogue commit. Added dummy bias tensors to both. CI should pass now.

Verifies that FlyDSL kernels are correctly captured by torch.cuda.CUDAGraph
when torch.cuda.current_stream() is passed as the stream argument.

Test flow:
1. Regular execution → reference result
2. CUDAGraph capture on a dedicated stream
3. Graph replay → verify result matches reference
4. Assert non-zero output (kernel was captured)

Tests both BF16 and FP8 paths.