Add fused epilogue support to preshuffle GEMM: bias + ReLU/SiLU/GeLU#404
Open
Modified `body_row` to apply bias + activation in registers before the output store, eliminating separate epilogue kernel launches.

MI300X results:
- 2.73x avg speedup vs hipBLAS + bias + SiLU (O-proj: 3.25x, MoE-dn: 3.33x, QKV: 2.92x)
- Zero epilogue overhead (the fused ops are hidden by store latency)

New parameter: `epilogue='none'|'bias'|'bias_relu'|'bias_silu'|'bias_gelu'`
New kernel arg: `arg_bias` (N-element bias tensor)
Collaborator

@andyluo7 CI failed.
The fused epilogue commit added `arg_bias` to `kernel_gemm` and `launch_gemm` unconditionally (needed so that `epilogue='none'` keeps a single kernel definition). The test's `_gemm_args` and `_w4_args` helpers need to pass a dummy bias tensor to match the updated `launch_gemm` signature. Without this fix, the positional arguments shift: `M` lands in the `arg_bias` slot, `N` in `i32_m`, and `stream` in `i32_n`, producing a "missing a required argument: stream" error.
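The argument shift can be reproduced with plain-Python stand-ins (the function names and parameter lists below are illustrative, not the actual FlyDSL signatures):

```python
# Hypothetical stand-ins for the launcher before and after the change; the
# real FlyDSL signatures differ -- this only illustrates the positional shift.

def launch_gemm_old(arg_c, arg_a, arg_b, i32_m, i32_n, stream):
    return {"m": i32_m, "n": i32_n, "stream": stream}

def launch_gemm_new(arg_c, arg_a, arg_b, arg_bias, i32_m, i32_n, stream):
    return {"bias": arg_bias, "m": i32_m, "n": i32_n, "stream": stream}

# A caller built for the old signature passes (c, a, b, M, N, stream).
# Against the new signature, M lands in the arg_bias slot, N in i32_m,
# stream in i32_n, and `stream` itself is never supplied:
try:
    launch_gemm_new("c", "a", "b", 4096, 8192, "stream0")
except TypeError as exc:
    print(exc)  # missing 1 required positional argument: 'stream'
```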
Author

Fixed in 1b04c35 — the test's `_gemm_args` and `_w4_args` helpers now pass a dummy bias tensor.
Verifies that FlyDSL kernels are correctly captured by `torch.cuda.CUDAGraph` when `torch.cuda.current_stream()` is passed as the stream argument. Test flow:

1. Regular execution → reference result
2. CUDAGraph capture on a dedicated stream
3. Graph replay → verify the result matches the reference
4. Assert non-zero output (the kernel was captured)

Tests both BF16 and FP8 paths.
Summary
Add fused epilogue support to the preshuffle GEMM kernel (`preshuffle_gemm.py`), enabling bias addition and activation functions (ReLU, SiLU, GeLU) to be computed directly in the GEMM output store loop (`body_row`). This eliminates separate memory-bound epilogue kernels and their associated kernel launch overhead.

Motivation
In LLM inference (e.g., vLLM, SGLang), GEMM operations are frequently followed by bias addition and activation functions. These post-GEMM operations are typically executed as separate kernels, adding extra kernel launch overhead and an extra round trip of the GEMM output through global memory.
By fusing these into the GEMM body_row, the output values are processed while still in registers — zero additional memory traffic.
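Each supported epilogue reduces to simple per-element math on the still-in-register accumulator. A pure-Python reference sketch (helper names are illustrative, not the in-kernel FlyDSL code; GeLU here uses the exact erf form, which may differ from the kernel's approximation):

```python
import math

# Scalar activations; `acc` is a GEMM accumulator value, `b` the per-column bias.
def relu(x): return max(x, 0.0)
def silu(x): return x / (1.0 + math.exp(-x))
def gelu(x): return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))

EPILOGUES = {
    "none":      lambda acc, b: acc,
    "bias":      lambda acc, b: acc + b,
    "bias_relu": lambda acc, b: relu(acc + b),
    "bias_silu": lambda acc, b: silu(acc + b),
    "bias_gelu": lambda acc, b: gelu(acc + b),
}

def apply_epilogue(acc, bias, mode):
    """Reference for the fused per-element epilogue applied before the store."""
    return EPILOGUES[mode](acc, bias)

print(apply_epilogue(2.0, -2.0, "bias_relu"))  # 0.0
```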
Changes
`kernels/preshuffle_gemm.py`: Added an `epilogue` parameter to `compile_preshuffle_gemm_a8()` supporting `"none"`, `"bias"`, `"bias_relu"`, `"bias_silu"`, `"bias_gelu"`. Added an `arg_bias` tensor argument to the kernel function and launcher. Bias is loaded via `buffer_ops.buffer_load` and activations are computed in place before the output store.

`python/flydsl/_version.py`: Bumped to dev version.

Supported Epilogues
| Epilogue | Output |
| --- | --- |
| `none` | C = GEMM(A,B) |
| `bias` | C = GEMM(A,B) + bias |
| `bias_relu` | C = ReLU(GEMM(A,B) + bias) |
| `bias_silu` | C = SiLU(GEMM(A,B) + bias) |
| `bias_gelu` | C = GeLU(GEMM(A,B) + bias) |

Performance
Benchmarked on MI300X (gfx942) and MI355X (gfx950) with real LLM model shapes:
MI300X — Fused SiLU vs hipBLAS + separate bias + SiLU
MI355X — Fused SiLU
End-to-End vLLM Serving (DeepSeek-R1-0528, MoE)
API
Backward Compatibility
- `epilogue="none"` preserves existing behavior
- `arg_bias` is always present in the kernel signature but unused when `epilogue="none"` — callers can pass a dummy tensor

Testing
Correctness validated against a PyTorch reference (`F.linear` + bias + act) on both gfx942 and gfx950
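The same comparison can be mimicked on the CPU with a tiny GEMM: fusing bias + SiLU into the output loop must match separate bias and activation passes (illustrative code, not the actual test suite):

```python
import math

def silu(x): return x / (1.0 + math.exp(-x))

def gemm_fused(A, B, bias):
    """Bias + SiLU applied on the accumulator inside the output loop."""
    M, K, N = len(A), len(B), len(B[0])
    C = [[0.0] * N for _ in range(M)]
    for i in range(M):
        for j in range(N):
            acc = sum(A[i][k] * B[k][j] for k in range(K))
            C[i][j] = silu(acc + bias[j])  # epilogue fused at the store
    return C

def gemm_unfused(A, B, bias):
    """GEMM, then a separate bias pass, then a separate activation pass."""
    M, K, N = len(A), len(B), len(B[0])
    C = [[sum(A[i][k] * B[k][j] for k in range(K)) for j in range(N)]
         for i in range(M)]
    C = [[c + bias[j] for j, c in enumerate(row)] for row in C]
    return [[silu(c) for c in row] for row in C]

A = [[1.0, 2.0], [3.0, 4.0]]
B = [[0.5, -1.0], [1.0, 0.25]]
bias = [0.1, -0.1]
assert all(math.isclose(f, u)
           for fr, ur in zip(gemm_fused(A, B, bias), gemm_unfused(A, B, bias))
           for f, u in zip(fr, ur))
```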