MoE prefill bf16 perf improvement for qwen-3.5-35B-A3B #18829
digantdesai merged 5 commits into main from
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/18829
Note: Links to docs will display an error until the docs builds have been completed.
❗ 1 Active SEV: there is 1 currently active SEV. If your PR is affected, please view it below.
❌ 1 New Failure, 2 Unrelated Failures as of commit 40c6cc2 with merge base 87e65ac:
- NEW FAILURE: the following job has failed.
- BROKEN TRUNK: the following jobs failed but were present on the merge base. 👉 Rebase onto the `viable/strict` branch to avoid these failures.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
This PR needs a
Force-pushed a0d199a to 63548f5 (Compare)
Force-pushed 63548f5 to 5055971 (Compare)
@claude review this PR.
Claude finished @digantdesai's task in 2m 33s - View job
PR Review: MoE prefill bf16 perf improvement for qwen-3.5-35B-A3B
This PR adds a batched tensor-core MoE kernel for prefill (
Overall Architecture
The approach is sound: use the existing
Pull request overview
This PR improves Qwen3.5 MoE prefill throughput by adding a token-sorted, tensor-core (tl.dot) batched MoE Triton kernel and wiring export to use it for the prefill method while keeping the existing vec-mat kernel for decode. It also adds a CUDA AOTI sort_stable shim (Thrust-based) to support the argsort/sort needed by token sorting during export/lowering.
Changes:
- Add a new Triton MoE variant (`fused_moe_batched_gemm`) with token sorting (`moe_align_block_size`) for compute-bound prefill.
- Toggle decode vs. prefill MoE kernel selection during Qwen3.5 export (`use_batched_moe` flag + export-side toggling).
- Add a CUDA runtime shim for the `aten::sort.stable` fallback, plus new CUDA tests and a benchmark script.
Reviewed changes
Copilot reviewed 12 out of 12 changed files in this pull request and generated 5 comments.
| File | Description |
|---|---|
| examples/models/qwen3_5_moe/model.py | Adds use_batched_moe switch to route MoE to the new batched kernel. |
| examples/models/qwen3_5_moe/export.py | Exports decode with old kernel and prefill with batched kernel; updates prefill example shapes. |
| backends/cuda/triton/kernels/fused_moe.py | Implements token sorting helper and new batched tensor-core MoE kernels/wrappers; tweaks scale-load logic. |
| backends/cuda/triton/kernels/__init__.py | Exposes new MoE helpers/ops in kernel package exports. |
| backends/cuda/tests/test_sort_shim.py | Adds export-level tests ensuring sort/argsort can lower with the shim. |
| backends/cuda/tests/test_fused_moe.py | Adds correctness tests for batched MoE and for moe_align_block_size. |
| backends/cuda/runtime/shims/sort.h | Declares the AOTI CUDA sort stable shim API. |
| backends/cuda/runtime/shims/sort.cu | Implements Thrust-based stable sort shim (plus optional transpose-permute path). |
| backends/cuda/runtime/TARGETS | Adds sort shim sources/headers to the runtime build. |
| backends/cuda/cuda_backend.py | Registers at::_ops::sort_stable::call as a supported CUDA fallback kernel. |
| backends/cuda/benchmarks/benchmark_moe.py | Adds a benchmark harness for fused MoE kernels. |
| backends/cuda/CMakeLists.txt | Ensures the sort shim is built into aoti_cuda_shims when CUDA is available. |
Excerpt from backends/cuda/runtime/shims/sort.cu:

```cpp
#include <cuda.h>
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <thrust/device_ptr.h>
#include <thrust/execution_policy.h>
#include <thrust/sort.h>

#include <executorch/backends/aoti/utils.h>
#include <executorch/backends/cuda/runtime/shims/memory.h>
#include <executorch/backends/cuda/runtime/shims/sort.h>
#include <executorch/backends/aoti/slim/cuda/guard.h>
#include <executorch/runtime/platform/log.h>

namespace executorch::backends::cuda {
```
sort.cu uses std::vector later (e.g., contig_strides) but doesn't include `<vector>`, which will fail to compile on toolchains that don't pull it in indirectly. Add the missing standard header include explicitly near the top of the file.
Excerpt from backends/cuda/triton/kernels/fused_moe.py:

```python
        # Load per-group scales and dequantize
        if BLOCK_SIZE_K <= group_size:
            # All K values in this tile share one scale group — load [1, N]
            group_idx = (BLOCK_SIZE_K * k_step) // group_size
            scale_ptrs = (
                B_scale
                + expert_id * stride_bse
                + offs_n[None, :] * stride_bsn
                + group_idx * stride_bsk
            )
            b_scale = tl.load(scale_ptrs, mask=n_mask[None, :], other=0.0).to(
                tl.float32
            )
        else:
```
The fast-path if BLOCK_SIZE_K <= group_size assumes the entire K tile falls within a single quantization group, but that’s only guaranteed when the tile does not straddle a group_size boundary (e.g., when group_size % BLOCK_SIZE_K == 0, or when (BLOCK_SIZE_K*k_step) % group_size + BLOCK_SIZE_K <= group_size). As written, some valid (K, group_size, BLOCK_SIZE_K) combos will load the wrong scale for part of the tile and silently produce incorrect results. Please tighten the condition or fall back to the per-element scale indexing in those cases.
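The boundary condition described above can be checked in plain Python. This is an illustrative sketch, not code from the PR; the function name is hypothetical, and it only demonstrates when a K tile shares a single quantization-group scale.

```python
# Illustrative check (not from the PR): a K tile [start, start + BLOCK_SIZE_K)
# shares one quantization-group scale only if it does not straddle a
# group_size boundary.
def tile_has_single_scale_group(k_step: int, block_size_k: int, group_size: int) -> bool:
    start = k_step * block_size_k
    end = start + block_size_k - 1  # last K index covered by this tile
    return start // group_size == end // group_size

# BLOCK_SIZE_K <= group_size is not sufficient on its own: with group_size=96
# and BLOCK_SIZE_K=64, the second tile covers K in [64, 128), which spans
# groups 0 and 1 even though 64 <= 96.
print(tile_has_single_scale_group(0, 64, 96))  # True: tile [0, 64) lies inside group 0
print(tile_has_single_scale_group(1, 64, 96))  # False: tile [64, 128) straddles groups 0/1

# When group_size % BLOCK_SIZE_K == 0, no tile can straddle a boundary.
assert all(tile_has_single_scale_group(s, 64, 128) for s in range(8))
```

This matches the reviewer's suggested fix: either enforce `group_size % BLOCK_SIZE_K == 0` (or an equivalent alignment check) or fall back to per-lane scale indexing.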
Excerpt from backends/cuda/triton/kernels/fused_moe.py:

```python
        if BLOCK_SIZE_K <= group_size:
            group_idx = (BLOCK_SIZE_K * k_step) // group_size
            scale_ptrs = (
                B_scale
                + expert_id * stride_bse
                + offs_n[None, :] * stride_bsn
                + group_idx * stride_bsk
            )
            b_scale = tl.load(scale_ptrs, mask=n_mask[None, :], other=0.0).to(
                tl.float32
            )
        else:
```
Same issue as above: the if BLOCK_SIZE_K <= group_size scale-loading fast-path is only correct if the K tile cannot cross a quantization-group boundary. For non-multiple group sizes (or misaligned tiles), this can load an incorrect scale for some K lanes and produce wrong outputs. Please either enforce the alignment constraint (with a clear error) or use the general per-lane scale indexing path.
Excerpt from backends/cuda/triton/kernels/fused_moe.py:

```python
    # To avoid argsort (needs sort_stable fallback in AOTI), we use a
    # scatter-based approach:
    # For each pair i, within_expert_rank[i] = sum_{j<i} (flat_ids[j]==e)
    # This is a segmented exclusive prefix sum, computed as:
    # within_expert_rank[i] = cumcount_of_expert[flat_ids[i]] before i
    # We iterate in pair order, incrementing per-expert counters.
    # Since this is a sequential scan, we implement it as a cumsum trick:
    # For each expert e, the pairs assigned to it appear at positions
    # where flat_ids == e. Their within-expert ranks are 0, 1, 2, ...
    # in the order they appear.
    #
    # Vectorized: use argsort on (flat_ids * num_pairs + arange) which
    # gives stable expert-grouped ordering — same result as argsort(flat_ids).
    # The multiplication ensures expert grouping; adding arange breaks ties
    # by original order (equivalent to stable sort).
    sort_keys = flat_ids * num_pairs + torch.arange(
        num_pairs, device=device, dtype=torch.int64
    )
    sorted_order = sort_keys.argsort()  # no `stable` kwarg needed — keys are unique
```
The comment says this avoids argsort due to sort_stable/AOTI constraints, but the implementation still calls argsort() (on sort_keys). If the intent is “stable sort not required because keys are unique”, please rewrite the comment to match the code; otherwise this is misleading for future maintenance/debugging (especially now that a sort shim is being introduced).
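Whichever way the comment is reworded, the underlying trick is easy to verify in plain Python. This sketch (using plain lists instead of tensors) shows why stability is not required when the keys are unique:

```python
# Sketch of the composite-key trick from the diff, in plain Python.
# The keys flat_ids[i] * num_pairs + i are all distinct, so even an unstable
# sort of the keys yields the same ordering as a stable sort of flat_ids
# alone: grouped by expert, original order preserved within each expert.
flat_ids = [2, 0, 2, 1, 0]  # expert id per (token, top-k) pair
num_pairs = len(flat_ids)

sort_keys = [e * num_pairs + i for i, e in enumerate(flat_ids)]
sorted_order = sorted(range(num_pairs), key=lambda i: sort_keys[i])

print(sorted_order)                         # [1, 4, 3, 0, 2]
print([flat_ids[i] for i in sorted_order])  # [0, 0, 1, 2, 2]
```

Within each expert group, the pairs appear in their original order, which is exactly what the within-expert rank computation needs.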
Excerpt from backends/cuda/triton/kernels/fused_moe.py:

```python
    topk_ids: torch.Tensor,
    block_size: int,
    num_experts: int,
) -> tuple[torch.Tensor, torch.Tensor, int]:
```
The return type annotation says the third return value is an int, but moe_align_block_size() actually returns num_tokens_post_padded as a scalar int64 tensor (expert_offsets[num_experts]). Please update the type hint to torch.Tensor (or convert to Python int and adjust the implementation accordingly) to avoid confusing callers and static analysis.
Suggested change:

```diff
-) -> tuple[torch.Tensor, torch.Tensor, int]:
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
```
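For reference, here is a hypothetical pure-Python sketch of what `moe_align_block_size` appears to compute, based on the PR description and this discussion (the exact kernel semantics are an assumption): group the flattened top-k expert assignments by expert, pad each expert's run to a multiple of `block_size`, and report the padded total that the third return value carries.

```python
# Hypothetical reference (not the PR's implementation) for moe_align_block_size:
# returns sorted pair indices with padding, per-expert offsets, and the padded
# total length that the disputed third return value represents.
def moe_align_block_size_ref(flat_expert_ids, block_size, num_experts):
    PAD = -1  # sentinel for padding slots
    buckets = [[] for _ in range(num_experts)]
    for pair_idx, e in enumerate(flat_expert_ids):
        buckets[e].append(pair_idx)

    sorted_ids, expert_offsets = [], [0]
    for bucket in buckets:
        padded = -(-len(bucket) // block_size) * block_size  # ceil to block multiple
        sorted_ids.extend(bucket + [PAD] * (padded - len(bucket)))
        expert_offsets.append(len(sorted_ids))

    num_tokens_post_padded = len(sorted_ids)
    return sorted_ids, expert_offsets, num_tokens_post_padded

ids, offsets, total = moe_align_block_size_ref(
    [0, 0, 1, 2, 2, 2], block_size=4, num_experts=3
)
print(total)    # 12: each of the three experts is padded up to one 4-wide block
print(offsets)  # [0, 4, 8, 12]
```

In the actual kernel the padded total is produced on device (as `expert_offsets[num_experts]`), which is why it comes back as a scalar tensor rather than a Python int.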
This seems like a valid comment; any reason why? How can you be sure about the correctness of the new batched version?
mergennachin left a comment
Please see comments
Excerpt from backends/cuda/benchmarks/benchmark_moe.py:

```python
# LICENSE file in the root directory of this source tree.

"""
Benchmark the Triton fused MoE kernel against eager and torch.compile baselines.
```
Do we need this file to be checked in? I'd rather not check it in if it's not being exercised by CI.
Yeah, this can be run by CI or removed.
Also, it would be good to consistently reproduce this number in CI. We could change the prompt length to 512 so that it is comparable to other frameworks.
Inductor emits aten::sort.stable for ops like argsort, but lacks a native c-shim for it. This adds a thrust-based implementation (aoti_torch_cuda_sort_stable) that handles int64, int32, and float32 dtypes on contiguous innermost-dim tensors. Registered as a supported fallback kernel in CudaBackend so AOTI-compiled models can use sort. This PR was authored with the assistance of Claude.
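The contract the shim must honor can be illustrated in plain Python (this is a semantic sketch, not the Thrust implementation): `sort.stable` returns both sorted values and the permutation indices, and equal keys keep their original relative order.

```python
# Plain-Python illustration of the stable-sort contract behind
# aten::sort.stable: equal keys keep their original relative order, and both
# the sorted values and the permutation indices are returned.
def sort_stable_ref(values):
    # Python's built-in sort is stable, so ties preserve original positions.
    indices = sorted(range(len(values)), key=lambda i: values[i])
    return [values[i] for i in indices], indices

vals, idx = sort_stable_ref([3, 1, 3, 1])
print(vals)  # [1, 1, 3, 3]
print(idx)   # [1, 3, 0, 2]: the two 1s (and the two 3s) keep original order
```

On device, Thrust's `thrust::stable_sort_by_key` provides the same guarantee for the (values, indices) pair.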
Sweeps prompt lengths [1..4095] with Qwen3.5-35B-A3B shapes (256 experts, top-8, INT4 W4A16). Validates correctness against loop-based eager reference at small M, benchmarks vectorized eager, torch.compile, and Triton fused_moe. Handles OOM gracefully at large M where eager/compile dequantize all experts. This PR was authored with the assistance of Claude.
When the Triton tile size fits within a single quantization group, load one scale per N-element instead of per (K, N) element. Reduces scale memory traffic in both GEMM1 and GEMM2 vec-mat kernels. This PR was authored with the assistance of Claude.
Adds a batched (M>1) Triton fused MoE kernel using tensor-core mma instructions for prefill workloads. Includes moe_align_block_size for token-expert sorting and scale broadcast optimization in the batched GEMM inner loops. Weight layout: [E, N, K//2] (packed INT4). This PR was authored with the assistance of Claude.
Add use_batched_moe flag on FusedMoEExperts, toggled by _set_batched_moe in export.py before each method's torch.export call. Decode (T=1) uses the vec-mat fused_moe kernel; prefill (T>=2) uses fused_moe_batched_gemm. This PR was authored with the assistance of Claude.
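The toggle-before-export pattern described above can be sketched minimally. The class and helper below are stand-ins (no torch dependency; the real `FusedMoEExperts` and `_set_batched_moe` live in model.py/export.py), showing only the control flow: flip the flag on every MoE module before tracing each method.

```python
# Minimal sketch (hypothetical stand-in classes) of the toggle-before-export
# pattern: decode traces the vec-mat kernel, prefill the batched GEMM kernel.
class FusedMoEExperts:
    def __init__(self):
        self.use_batched_moe = False

    def forward(self, x):
        # In the real model this dispatches to a Triton custom op; here we
        # just report which kernel would be traced.
        return "fused_moe_batched_gemm" if self.use_batched_moe else "fused_moe"

def _set_batched_moe(modules, enabled: bool):
    # Flip the flag on every MoE module before the method's torch.export call.
    for m in modules:
        if isinstance(m, FusedMoEExperts):
            m.use_batched_moe = enabled

experts = [FusedMoEExperts(), FusedMoEExperts()]

_set_batched_moe(experts, False)  # decode (T == 1): vec-mat kernel
assert experts[0].forward(None) == "fused_moe"

_set_batched_moe(experts, True)   # prefill (T >= 2): batched tensor-core kernel
assert experts[1].forward(None) == "fused_moe_batched_gemm"
```

Because the flag is read at trace time, each exported method bakes in exactly one kernel path, with no runtime branching in the lowered graph.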
Force-pushed 5055971 to 40c6cc2 (Compare)
Songhao has a PR for updating the CI setup.
Just an oversight. I did update the benchmark script but not this - let me update.
Fixes in #18934