Skip to content

Implement MLA decode fwd nh128 a8w8 kernel with FlyDSL.#403

Open
ruanjm wants to merge 6 commits intomainfrom
jruan/mla_h128_a8w8
Open

Implement MLA decode fwd nh128 a8w8 kernel with FlyDSL.#403
ruanjm wants to merge 6 commits intomainfrom
jruan/mla_h128_a8w8

Conversation

@ruanjm
Copy link
Copy Markdown
Contributor

@ruanjm ruanjm commented Apr 15, 2026

ATT.

Copilot AI review requested due to automatic review settings April 15, 2026 09:38
Copy link
Copy Markdown
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

Adds a FlyDSL implementation of an MLA decode forward kernel specialized for nhead=128 with FP8 Q/KV inputs and BF16 output, plus a Python launcher and a corresponding correctness/perf test driver.

Changes:

  • Introduce kn_mla_fwd_decode_m16x8_fp8_fp8 FlyDSL kernel and JIT launcher for the nh=128 FP8/FP8 decode path.
  • Add a thin Python dispatcher (flydsl_mla_fwd_decode) that flattens inputs/outputs and launches the specialized kernel.
  • Add a new MLA decode test/benchmark script with a PyTorch reference + aiter metadata/reduce integration.

Reviewed changes

Copilot reviewed 2 out of 3 changed files in this pull request and generated 8 comments.

File Description
tests/kernels/test_mla_decode.py New MLA decode reference + kernel launch driver (currently not pytest-integrated and has a couple runtime issues).
kernels/mla_fwd_decode_m16x8_fp8_fp8.py New FlyDSL kernel implementation + JIT launcher for FP8/FP8 nh=128 decode.
kernels/mla_fwd_decode.py Public Python launcher/dispatcher for the new specialized kernel.

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

Comment thread kernels/mla_fwd_decode_m16x8_fp8_fp8.py Outdated
Comment thread kernels/mla_fwd_decode_m16x8_fp8_fp8.py
Comment thread tests/kernels/test_mla_decode.py Outdated
Comment thread tests/kernels/test_mla_decode.py Outdated
Comment thread tests/kernels/test_mla_decode.py
Comment thread kernels/mla_fwd_decode.py
Comment thread kernels/mla_fwd_decode.py
Comment thread kernels/mla_fwd_decode_m16x8_fp8_fp8.py
@coderfeli
Copy link
Copy Markdown
Collaborator

Align code style, type, IR usage with other kernels. Try to remove or reduce arith/std_arith and similar native mlir low level uses.

@ruanjm
Copy link
Copy Markdown
Contributor Author

ruanjm commented Apr 16, 2026

Align code style, type, IR usage with other kernels. Try to remove or reduce arith/std_arith and similar native mlir low level uses.

Done

@ruanjm ruanjm force-pushed the jruan/mla_h128_a8w8 branch from d5b5369 to ab46f7f Compare April 16, 2026 09:59
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants