From 8226f5033e418e7f85981d34f107ec94551e2c6f Mon Sep 17 00:00:00 2001 From: Digant Desai Date: Wed, 15 Apr 2026 20:39:22 -0700 Subject: [PATCH 1/2] [aoti-cuda] remove op/kernel level benchmarking scripts --- backends/cuda/benchmarks/benchmark_moe.py | 423 --------------------- backends/cuda/benchmarks/benchmark_sdpa.py | 308 --------------- 2 files changed, 731 deletions(-) delete mode 100644 backends/cuda/benchmarks/benchmark_moe.py delete mode 100644 backends/cuda/benchmarks/benchmark_sdpa.py diff --git a/backends/cuda/benchmarks/benchmark_moe.py b/backends/cuda/benchmarks/benchmark_moe.py deleted file mode 100644 index 79484df0174..00000000000 --- a/backends/cuda/benchmarks/benchmark_moe.py +++ /dev/null @@ -1,423 +0,0 @@ -#!/usr/bin/env python3 -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -""" -Benchmark the Triton fused MoE kernel against eager and torch.compile baselines. - -Measures latency across prompt lengths matching the Qwen3.5-35B-A3B model -(hidden_size=2048, num_experts=256, top_k=8, intermediate_size=512, -INT4 weight-only quantization with group_size=128). - -Usage: - python benchmark_moe.py - python benchmark_moe.py --prompt-lengths 1,8,64,512 --num_iters 200 -""" - -import argparse -from functools import partial - -import executorch.backends.cuda.triton.kernels # noqa: F401 — registers triton ops - -import torch -from triton.testing import do_bench - - -# -- Qwen3.5-35B-A3B defaults ------------------------------------------------ - -DEFAULTS = { - "num_experts": 256, - "top_k": 8, - "hidden_size": 2048, - "intermediate_size": 512, - "group_size": 128, -} - -PROMPT_LENGTHS = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4095] - - -# -- Weight / input generation ----------------------------------------------- - - -def _make_int4_weights(E, N, K, group_size, device="cuda"): - """Generate random packed INT4 weights and per-group scales. - - Returns: - w: [E, N, K//2] int8 — two INT4 values packed per byte - scale: [E, N, K//group_size] bf16 - """ - vals = torch.randint(0, 16, (E, N, K), dtype=torch.uint8, device=device) - low = vals[:, :, 0::2] - high = vals[:, :, 1::2] - packed = (high << 4) | low - w = packed.to(torch.int8) - - scale = ( - torch.randn(E, N, K // group_size, device=device, dtype=torch.bfloat16) * 0.01 - ) - return w, scale - - -# -- Dequantization ---------------------------------------------------------- - - -def _dequant_int4(w_packed, scale, group_size): - """Unpack INT4 weights and dequantize. - - w_packed: [E, N, K//2] int8 - scale: [E, N, K//group_size] bf16 - Returns: [E, N, K] bf16 - """ - w_uint8 = w_packed.to(torch.uint8) - low = (w_uint8 & 0xF).to(torch.float32) - high = ((w_uint8 >> 4) & 0xF).to(torch.float32) - E, N, Khalf = w_packed.shape - K = Khalf * 2 - vals = torch.empty(E, N, K, device=w_packed.device, dtype=torch.float32) - vals[:, :, 0::2] = low - vals[:, :, 1::2] = high - vals = vals - 8.0 - scale_expanded = scale.float().repeat_interleave(group_size, dim=2)[:, :, :K] - return (vals * scale_expanded).to(torch.bfloat16) - - -# -- Backends ----------------------------------------------------------------- - - -def _run_eager( - hidden_states, - w1, - w1_scale, - w2, - w2_scale, - topk_weights, - topk_ids, - top_k, - num_experts, - group_size, -): - """Loop-based eager MoE — correctness reference only (not benchmarked).""" - M, K = hidden_states.shape - inter = w2.shape[2] * 2 - - w1_deq = _dequant_int4(w1, w1_scale, group_size) - w2_deq = _dequant_int4(w2, w2_scale, group_size) - - output = torch.zeros(M, K, device=hidden_states.device, dtype=torch.bfloat16) - for i in range(M): - for j in range(top_k): - expert_id = topk_ids[i, j].item() - weight = topk_weights[i, j] - x = hidden_states[i : i + 1] @ w1_deq[expert_id].T - gate = x[:, :inter] - up = x[:, inter:] - x = torch.nn.functional.silu(gate) * up - x = x @ w2_deq[expert_id].T - output[i] += weight * x.squeeze(0) - return output - - -def _run_eager_vectorized( - hidden_states, - w1, - w1_scale, - w2, - w2_scale, - topk_weights, - topk_ids, - top_k, - num_experts, - group_size, -): - """Vectorized eager — gather + bmm, no Python loops.""" - M, K = hidden_states.shape - inter = w2.shape[2] * 2 - - w1_deq = _dequant_int4(w1, w1_scale, group_size) - w2_deq = _dequant_int4(w2, w2_scale, group_size) - - flat_ids = topk_ids.reshape(-1) - hs_rep = hidden_states.unsqueeze(1).expand(-1, top_k, -1).reshape(M * top_k, K) - gemm1_out = torch.bmm( - hs_rep.unsqueeze(1), w1_deq[flat_ids].transpose(1, 2) - ).squeeze(1) - - gate = gemm1_out[:, :inter] - up = gemm1_out[:, inter:] - act = torch.nn.functional.silu(gate) * up - - gemm2_out = torch.bmm(act.unsqueeze(1), w2_deq[flat_ids].transpose(1, 2)).squeeze(1) - - return (gemm2_out.view(M, top_k, K) * topk_weights.unsqueeze(-1)).sum(dim=1) - - -_compiled_fn = None - - -def _run_compiled( - hidden_states, - w1, - w1_scale, - w2, - w2_scale, - topk_weights, - topk_ids, - top_k, - num_experts, - group_size, -): - global _compiled_fn - if _compiled_fn is None: - _compiled_fn = torch.compile(_run_eager_vectorized) - return _compiled_fn( - hidden_states, - w1, - w1_scale, - w2, - w2_scale, - topk_weights, - topk_ids, - top_k, - num_experts, - group_size, - ) - - -def _run_triton( - hidden_states, - w1, - w1_scale, - w2, - w2_scale, - topk_weights, - topk_ids, - top_k, - num_experts, - group_size, -): - return torch.ops.triton.fused_moe( - hidden_states, - w1, - w1_scale, - w2, - w2_scale, - topk_weights, - topk_ids, - top_k=top_k, - num_experts=num_experts, - group_size=group_size, - ) - - -BACKENDS = { - "eager_vec": ("Eager (vec)", _run_eager_vectorized), - "compile": ("Compile", _run_compiled), - "triton": ("Triton fused", _run_triton), -} - -try: - from executorch.backends.cuda.triton.kernels.fused_moe import fused_moe_batched - - def _run_triton_batched( - hidden_states, - w1, - w1_scale, - w2, - w2_scale, - topk_weights, - topk_ids, - top_k, - num_experts, - group_size, - ): - return fused_moe_batched( - hidden_states, - w1, - w1_scale, - w2, - w2_scale, - topk_weights, - topk_ids, - top_k=top_k, - num_experts=num_experts, - group_size=group_size, - ) - - BACKENDS["triton_batched"] = ("Triton batched", _run_triton_batched) -except ImportError: - pass - - -# -- Helpers ------------------------------------------------------------------ - - -def _max_abs_error(out, ref): - return (out.float() - ref.float()).abs().max().item() - - -def _bench_ms(fn, num_warmup, num_iters): - return do_bench(fn, warmup=num_warmup, rep=num_iters, return_mode="median") - - -def _try_bench(run_fn, args, num_warmup, num_iters): - fn = partial(run_fn, **args) - try: - fn() - return _bench_ms(fn, num_warmup, num_iters) - except torch.OutOfMemoryError: - torch.cuda.empty_cache() - return None - - -# -- Main --------------------------------------------------------------------- - - -@torch.inference_mode() -def run_benchmark( - prompt_lengths, - num_experts, - top_k, - hidden_size, - intermediate_size, - group_size, - num_warmup, - num_iters, -): - backends = [(name, *BACKENDS[name]) for name in BACKENDS] - - device_name = torch.cuda.get_device_name() - print() - print("=" * 100) - print("Fused MoE Benchmark — Qwen3.5-35B-A3B (W4A16)") - print(f" Device: {device_name}") - print( - f" Experts: {num_experts}, Top-K: {top_k}, Hidden: {hidden_size}, " - f"Intermediate: {intermediate_size}, Group: {group_size}" - ) - print(f" Warmup: {num_warmup}, Iters: {num_iters}") - print(f" Backends: {', '.join(label for _, label, _ in backends)}") - print("=" * 100) - - # Generate weights once (shared across prompt lengths) - w1, w1_scale = _make_int4_weights( - num_experts, 2 * intermediate_size, hidden_size, group_size - ) - w2, w2_scale = _make_int4_weights( - num_experts, hidden_size, intermediate_size, group_size - ) - - # Column layout: Shape | backend1 | backend2 | ... (dynamic widths) - col_specs = [("M (tokens)", "", 10)] - for _, label, _ in backends: - col_specs.append((label, "(ms)", max(8, len(label)))) - - col_widths = [max(len(h), len(u), mw) for h, u, mw in col_specs] - - header = " | ".join( - f"{h:<{w}}" if i == 0 else f"{h:>{w}}" - for i, ((h, _, _), w) in enumerate(zip(col_specs, col_widths)) - ) - units = " | ".join( - f"{'':>{w}}" if i == 0 else f"{u:>{w}}" - for i, ((_, u, _), w) in enumerate(zip(col_specs, col_widths)) - ) - print(header) - print(units) - print("-" * len(header)) - - for M in prompt_lengths: - hidden_states = torch.randn(M, hidden_size, device="cuda", dtype=torch.bfloat16) - router_logits = torch.randn(M, num_experts, device="cuda", dtype=torch.float32) - topk_w, topk_i = torch.topk(router_logits, top_k, dim=-1) - topk_w = torch.softmax(topk_w, dim=-1) - topk_i = topk_i.to(torch.int64) - - common_args = { - "hidden_states": hidden_states, - "w1": w1, - "w1_scale": w1_scale, - "w2": w2, - "w2_scale": w2_scale, - "topk_weights": topk_w, - "topk_ids": topk_i, - "top_k": top_k, - "num_experts": num_experts, - "group_size": group_size, - } - - # Correctness: triton vs loop-based eager reference. - # Only check at small M to avoid slow eager loop + OOM on large M. - if M <= 64: - ref_out = _run_eager(**common_args) - tri_out = _run_triton(**common_args) - err = _max_abs_error(tri_out, ref_out) - assert err < 2.0e-1, ( - f"Triton vs eager mismatch at M={M}: " - f"max abs error {err:.3e} >= 2.0e-1" - ) - del ref_out, tri_out - - # Benchmark - times = {} - for name, _label, run_fn in backends: - times[name] = _try_bench(run_fn, common_args, num_warmup, num_iters) - - ci = 0 - row_parts = [f"{f'M={M}':<{col_widths[ci]}}"] - ci += 1 - for name, _, _ in backends: - t = times[name] - w = col_widths[ci] - row_parts.append(f"{t:>{w}.3f}" if t is not None else f"{'OOM':>{w}}") - ci += 1 - print(" | ".join(row_parts)) - - del hidden_states, topk_w, topk_i - torch.cuda.empty_cache() - - print("-" * len(header)) - print() - - -def main(): - parser = argparse.ArgumentParser( - description="Benchmark Triton fused MoE vs eager/compile baselines" - ) - parser.add_argument("--num-experts", type=int, default=DEFAULTS["num_experts"]) - parser.add_argument("--top-k", type=int, default=DEFAULTS["top_k"]) - parser.add_argument("--hidden-size", type=int, default=DEFAULTS["hidden_size"]) - parser.add_argument( - "--intermediate-size", type=int, default=DEFAULTS["intermediate_size"] - ) - parser.add_argument("--group-size", type=int, default=DEFAULTS["group_size"]) - parser.add_argument("--num_warmup", type=int, default=25) - parser.add_argument("--num_iters", type=int, default=100) - parser.add_argument( - "--prompt-lengths", - type=str, - default=None, - help="Comma-separated list of prompt lengths (default: standard sweep)", - ) - args = parser.parse_args() - - prompt_lengths = PROMPT_LENGTHS - if args.prompt_lengths: - prompt_lengths = [int(x.strip()) for x in args.prompt_lengths.split(",")] - - run_benchmark( - prompt_lengths=prompt_lengths, - num_experts=args.num_experts, - top_k=args.top_k, - hidden_size=args.hidden_size, - intermediate_size=args.intermediate_size, - group_size=args.group_size, - num_warmup=args.num_warmup, - num_iters=args.num_iters, - ) - - -if __name__ == "__main__": - main() diff --git a/backends/cuda/benchmarks/benchmark_sdpa.py b/backends/cuda/benchmarks/benchmark_sdpa.py deleted file mode 100644 index 3c117f4574f..00000000000 --- a/backends/cuda/benchmarks/benchmark_sdpa.py +++ /dev/null @@ -1,308 +0,0 @@ -#!/usr/bin/env python3 -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -""" -Benchmark the Triton SDPA kernel against PyTorch SDPA backends. - -Measures latency across decode shapes matching the Qwen3.5 MoE model -(B=1, H_q=16, H_kv=2, D=256). The ET Triton kernel uses native GQA -(2 KV heads), while Flash/Efficient/Math require pre-expanded KV -(16 heads) since they lack native GQA support. - -""" - -import argparse -import warnings -from functools import partial - -import torch -import torch.nn.functional as F - -from executorch.backends.cuda.triton.kernels.sdpa import ( - sdpa as triton_sdpa, - sdpa_decode_splitk as triton_splitk, -) -from torch.nn.attention import sdpa_kernel, SDPBackend -from triton.testing import do_bench - - -# PyTorch's Flash/Efficient backends don't support GQA (H_q != H_kv) directly. -# We expand KV heads via repeat_interleave so they can run, matching what -# the test reference does. This is fair: it measures the kernel itself, not -# the GQA dispatch overhead. - - -def _expand_kv(k, v, num_groups): - if num_groups > 1: - k = k.repeat_interleave(num_groups, dim=1) - v = v.repeat_interleave(num_groups, dim=1) - return k, v - - -def _expand_mask(mask, H_q): - if mask is not None and mask.shape[1] == 1 and H_q > 1: - mask = mask.expand(-1, H_q, -1, -1) - return mask - - -def _run_triton(q, k, v, attn_mask, enable_gqa): - return triton_sdpa(q, k, v, attn_mask=attn_mask, enable_gqa=enable_gqa) - - -def _run_splitk(q, k, v, attn_mask, enable_gqa): - return triton_splitk(q, k, v, attn_mask=attn_mask, enable_gqa=enable_gqa) - - -def _run_pytorch_default(q, k, v, attn_mask, enable_gqa): - return F.scaled_dot_product_attention( - q, - k, - v, - attn_mask=attn_mask, - enable_gqa=enable_gqa, - ) - - -def _make_pytorch_runner(backend: SDPBackend): - def run(q, k, v, attn_mask, enable_gqa): - with sdpa_kernel(backend): - return F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask) - - return run - - -# Flash doesn't support attn_mask at all, only is_causal. -# Our benchmark mask is all-ones, so no mask is equivalent. -def _run_flash(q, k, v, attn_mask, enable_gqa): - with sdpa_kernel(SDPBackend.FLASH_ATTENTION): - return F.scaled_dot_product_attention(q, k, v) - - -BACKENDS = { - "triton": ("ET Triton (GQA)", _run_triton), - "splitk": ("ET Split-K (GQA)", _run_splitk), - "pytorch": ("PyTorch", _run_pytorch_default), - "flash": ("Flash (expanded KV)", _run_flash), - "efficient": ( - "Efficient (expanded KV)", - _make_pytorch_runner(SDPBackend.EFFICIENT_ATTENTION), - ), - "math": ("Math (expanded KV)", _make_pytorch_runner(SDPBackend.MATH)), -} - -# Backends that need KV heads expanded before calling (no native GQA support) -_NEEDS_KV_EXPAND = {"flash", "efficient", "math"} - -# -- Shapes ------------------------------------------------------------------ - -# Qwen3.5 MoE: B=1, H_q=16, H_kv=2, D=256 -QWEN35_BASE = {"B": 1, "H_q": 16, "H_kv": 2, "D": 256} - -DECODE_SHAPES = [ - dict(**QWEN35_BASE, Lq=1, Lk=64), - dict(**QWEN35_BASE, Lq=1, Lk=128), - dict(**QWEN35_BASE, Lq=1, Lk=256), - dict(**QWEN35_BASE, Lq=1, Lk=512), - dict(**QWEN35_BASE, Lq=1, Lk=1024), - dict(**QWEN35_BASE, Lq=1, Lk=2048), - dict(**QWEN35_BASE, Lq=1, Lk=4096), - dict(**QWEN35_BASE, Lq=1, Lk=8192), - dict(**QWEN35_BASE, Lq=1, Lk=16384), -] - -SCENARIOS = { - "decode": DECODE_SHAPES, -} - -# -- Helpers ----------------------------------------------------------------- - - -def _make_tensors(B, H_q, H_kv, Lq, Lk, D, device="cuda", dtype=torch.bfloat16): - q = torch.randn(B, H_q, Lq, D, device=device, dtype=dtype) - k = torch.randn(B, H_kv, Lk, D, device=device, dtype=dtype) - v = torch.randn(B, H_kv, Lk, D, device=device, dtype=dtype) - mask = torch.ones(B, 1, Lq, Lk, dtype=torch.bool, device=device) - enable_gqa = H_q != H_kv - num_groups = H_q // H_kv - # Pre-expanded versions for backends without native GQA - k_exp, v_exp = _expand_kv(k, v, num_groups) - mask_exp = _expand_mask(mask, H_q) - return q, k, v, k_exp, v_exp, mask, mask_exp, enable_gqa - - -def _max_abs_error(out, ref): - return (out.float() - ref.float()).abs().max().item() - - -# Cross-backend validation tolerance (bf16 vs bf16). -MAX_ABS_TOL = 1e-2 - - -def _bench_us(fn, num_warmup, num_iters): - """Return median latency in microseconds using triton.testing.do_bench.""" - ms = do_bench(fn, warmup=num_warmup, rep=num_iters, return_mode="median") - return ms * 1000.0 - - -def _try_run(run_fn, q, k, v, mask, enable_gqa): - """Run a backend, returning output or None on failure.""" - try: - return run_fn(q, k, v, mask, enable_gqa) - except RuntimeError: - return None - - -def _try_bench(run_fn, q, k, v, mask, enable_gqa, num_warmup, num_iters): - """Benchmark a backend, returning median us or None on failure.""" - fn = partial(run_fn, q, k, v, mask, enable_gqa) - try: - run_fn(q, k, v, mask, enable_gqa) - return _bench_us(fn, num_warmup, num_iters) - except RuntimeError: - return None - - -# -- Main -------------------------------------------------------------------- - - -def _shape_label(shape): - return ( - f"B={shape['B']} Hq={shape['H_q']} Hkv={shape['H_kv']} " - f"D={shape['D']} Lq={shape['Lq']} Lk={shape['Lk']}" - ) - - -def _short_label(shape, scenario="decode"): - return f"Lq={shape['Lq']},Lk={shape['Lk']}" - - -@torch.inference_mode() -def run_benchmark( - scenario: str = "decode", - num_warmup: int = 25, - num_iters: int = 100, -): - shapes = SCENARIOS[scenario] - backends = [(name, *BACKENDS[name]) for name in BACKENDS] - - device_name = torch.cuda.get_device_name() - print() - print("=" * 100) - print(f"SDPA Benchmark Qwen3.5-35B-A3B — {scenario}") - print(f" Device: {device_name}") - print(f" Warmup: {num_warmup}, Iters: {num_iters}") - print(f" Backends: {', '.join(label for _, label, _ in backends)}") - print("=" * 100) - - # Build column specs: (header_text, unit_text, min_width) - # Each column gets width = max(len(header), len(unit), min_width) - max_label = max(len(_short_label(s, scenario)) for s in shapes) - col_specs = [("Shape", "", max(8, max_label))] - for _, label, _ in backends: - col_specs.append((label, "(us)", 8)) - - col_widths = [max(len(h), len(u), mw) for h, u, mw in col_specs] - - header = " | ".join( - f"{h:<{w}}" if i == 0 else f"{h:>{w}}" - for i, ((h, _, _), w) in enumerate(zip(col_specs, col_widths)) - ) - units = " | ".join( - f"{'':>{w}}" if i == 0 else f"{u:>{w}}" - for i, ((_, u, _), w) in enumerate(zip(col_specs, col_widths)) - ) - print(header) - print(units) - print("-" * len(header)) - - for shape in shapes: - q, k, v, k_exp, v_exp, mask, mask_exp, enable_gqa = _make_tensors(**shape) - - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - - # Validate outputs across backends before benchmarking - outputs = {} - for name, _label, run_fn in backends: - if name in _NEEDS_KV_EXPAND: - bk, bv, bmask = k_exp, v_exp, mask_exp - else: - bk, bv, bmask = k, v, mask - outputs[name] = _try_run(run_fn, q, bk, bv, bmask, enable_gqa) - - # Use PyTorch F.sdpa as the trusted reference — never validate - # against our own Triton kernels. - ref_name, ref_out = None, None - if outputs.get("pytorch") is not None: - ref_name, ref_out = "pytorch", outputs["pytorch"] - - if ref_out is not None: - for name, label, _ in backends: - if name == ref_name or outputs[name] is None: - continue - err = _max_abs_error(outputs[name], ref_out) - assert err < MAX_ABS_TOL, ( - f"Output mismatch for {_shape_label(shape)}: " - f"{label} vs {BACKENDS[ref_name][0]}, " - f"max abs error {err:.3e} >= 1e-2" - ) - del outputs - - # Benchmark all backends - times = {} - for name, _label, run_fn in backends: - if name in _NEEDS_KV_EXPAND: - bk, bv, bmask = k_exp, v_exp, mask_exp - else: - bk, bv, bmask = k, v, mask - times[name] = _try_bench( - run_fn, q, bk, bv, bmask, enable_gqa, num_warmup, num_iters - ) - - # Format row using col_widths - ci = 0 - row_parts = [f"{_short_label(shape, scenario):<{col_widths[ci]}}"] - ci += 1 - for name, _, _ in backends: - t = times[name] - w = col_widths[ci] - row_parts.append(f"{t:>{w}.1f}" if t is not None else f"{'N/A':>{w}}") - ci += 1 - print(" | ".join(row_parts)) - - del q, k, v, k_exp, v_exp, mask, mask_exp - torch.cuda.empty_cache() - - print("-" * len(header)) - print() - - -def main(): - parser = argparse.ArgumentParser( - description="Benchmark Triton SDPA vs PyTorch backends" - ) - parser.add_argument( - "--scenario", - choices=list(SCENARIOS.keys()) + ["all"], - default="all", - help="Which shape set to benchmark (default: all)", - ) - parser.add_argument("--num_warmup", type=int, default=25) - parser.add_argument("--num_iters", type=int, default=100) - args = parser.parse_args() - - scenarios = list(SCENARIOS.keys()) if args.scenario == "all" else [args.scenario] - for s in scenarios: - run_benchmark( - scenario=s, - num_warmup=args.num_warmup, - num_iters=args.num_iters, - ) - - -if __name__ == "__main__": - main() From 75a2348ad9f8e0fb73035c2846764202c87accf9 Mon Sep 17 00:00:00 2001 From: Digant Desai Date: Wed, 15 Apr 2026 21:20:22 -0700 Subject: [PATCH 2/2] Tighten fused MoE test tolerances from 5% to 2% test_eager_correctness, test_single_expert, and test_batched_correctness used 5% relative tolerance for INT4 kernel-vs-dequant comparison. Tighten to 2% to match the e2e runner bar (fe71bd48). --- backends/cuda/tests/test_fused_moe.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/backends/cuda/tests/test_fused_moe.py b/backends/cuda/tests/test_fused_moe.py index e23832b89ea..63865dc2302 100644 --- a/backends/cuda/tests/test_fused_moe.py +++ b/backends/cuda/tests/test_fused_moe.py @@ -302,7 +302,7 @@ def test_eager_correctness(self): rel = diff / (ref.float().abs().max().item() + 1e-10) self.assertLess( rel, - 0.05, + 0.02, f"{desc}: relative diff {rel:.4f} (abs {diff:.6f})", ) @@ -332,7 +332,7 @@ def test_single_expert(self): ref = w2_dq[1] @ activated diff = (out[t].float() - ref.float()).abs().max().item() rel = diff / (ref.float().abs().max().item() + 1e-10) - self.assertLess(rel, 0.05, f"token {t}: relative diff {rel:.4f}") + self.assertLess(rel, 0.02, f"token {t}: relative diff {rel:.4f}") def test_batched_correctness(self): """Batched kernel matches reference across M values.""" @@ -390,7 +390,7 @@ def test_batched_correctness(self): rel = diff / (ref.float().abs().max().item() + 1e-10) self.assertLess( rel, - 0.05, + 0.02, f"{desc}: relative diff {rel:.4f} (abs {diff:.6f})", )