From 6dbcf440461bcca9c40491ba626293cd84ff3a2c Mon Sep 17 00:00:00 2001
From: albiol2004
Date: Tue, 14 Apr 2026 14:21:36 +0200
Subject: [PATCH] Add tile_m/tile_k/tile_n overrides to SwiGLUPrefill

SwiGLUPrefill currently uses GEMM's default tile triple (64/64/64),
which forces min_M = tile_m * num_aie_rows = 256. Real-world prefill
batch sizes from decoder-model runtimes (llama.cpp ubatch=32/64/128)
fall well below that threshold, leaving the fused SwiGLU path
unreachable in practice.

Add optional tile_m/tile_k/tile_n kwargs that pass through to both
inner GEMMs. When None (default), each falls back to GEMM's native
default, so existing callers and the existing (256, 2048, 2048) test
are unchanged.

Add a small-M test case (M=64, K=1024, N=3584, tile_m=16) that
exercises the override path at the Qwen3.5-0.8B FFN shape.
---
 iron/operators/swiglu_prefill/op.py   | 47 +++++++++++++++++++++++++--
 iron/operators/swiglu_prefill/test.py | 19 +++++++++--
 2 files changed, 60 insertions(+), 6 deletions(-)

diff --git a/iron/operators/swiglu_prefill/op.py b/iron/operators/swiglu_prefill/op.py
index c6cb413b..d47e7145 100644
--- a/iron/operators/swiglu_prefill/op.py
+++ b/iron/operators/swiglu_prefill/op.py
@@ -28,8 +28,30 @@ def _call_matmul(self, matmul_callable, weight, input_buf, output_buf):
 
 class SwiGLUPrefill(CompositeOperator):
     def __init__(
-        self, seq_len, embedding_dim, hidden_dim, prio_accuracy=False, context=None
+        self,
+        seq_len,
+        embedding_dim,
+        hidden_dim,
+        prio_accuracy=False,
+        tile_m=None,
+        tile_k=None,
+        tile_n=None,
+        context=None,
     ):
+        """
+        Args:
+            seq_len, embedding_dim, hidden_dim: SwiGLU FFN shape.
+            prio_accuracy: propagate the prio-accuracy flags to the inner GEMMs.
+            tile_m, tile_k, tile_n: optional overrides for the inner GEMMs'
+                tile sizes. If None, each falls back to GEMM's default (64).
+                Lowering tile_m reduces the minimum seq_len that SwiGLUPrefill
+                can be instantiated with (min_M = tile_m * 4), which unblocks
+                dispatch for decode-runtime batch sizes like M=32/64/128.
+                Both inner GEMMs share the same tile triple — this is the
+                common case; for asymmetric shapes, instantiate per-stage
+                GEMMs manually and compose.
+            context: AIEContext.
+        """
         self.seq_len = seq_len
         self.hidden_dim = hidden_dim
         self.embedding_dim = embedding_dim
@@ -39,6 +61,9 @@ def __init__(
         self.weights_3 = None
 
         self.prio_accuracy = prio_accuracy
+        self.tile_m = tile_m
+        self.tile_k = tile_k
+        self.tile_n = tile_n
         super().__init__(context=context)
 
     def set_up_artifacts(self):
@@ -53,8 +78,20 @@ def set_up_artifacts(self):
             "round_conv_even": True,
         }
 
+        tile_flags = {}
+        if self.tile_m is not None:
+            tile_flags["tile_m"] = self.tile_m
+        if self.tile_k is not None:
+            tile_flags["tile_k"] = self.tile_k
+        if self.tile_n is not None:
+            tile_flags["tile_n"] = self.tile_n
+
         gemm_1 = GEMM(
-            M=self.seq_len, K=self.embedding_dim, N=self.hidden_dim, **accuracy_flags
+            M=self.seq_len,
+            K=self.embedding_dim,
+            N=self.hidden_dim,
+            **accuracy_flags,
+            **tile_flags,
         )
         self.gemm_1 = gemm_1
         self.seq_len_padded = gemm_1.M
@@ -78,7 +115,11 @@ def set_up_artifacts(self):
         assert eltwise_mul.size == self.seq_len_padded * self.hidden_dim_padded
 
         gemm_2 = GEMM(
-            M=self.seq_len, K=self.hidden_dim, N=self.embedding_dim, **accuracy_flags
+            M=self.seq_len,
+            K=self.hidden_dim,
+            N=self.embedding_dim,
+            **accuracy_flags,
+            **tile_flags,
         )
         self.gemm_2 = gemm_2
         assert gemm_2.M == self.seq_len_padded
diff --git a/iron/operators/swiglu_prefill/test.py b/iron/operators/swiglu_prefill/test.py
index 6559571b..537c6889 100755
--- a/iron/operators/swiglu_prefill/test.py
+++ b/iron/operators/swiglu_prefill/test.py
@@ -17,7 +17,15 @@
 
 
 def get_params():
-    params_list = [(256, 2048, 2048, False)]
+    # (seq_len, embedding_dim, hidden_dim, prio_accuracy, tile_kwargs)
+    # The default case uses GEMM's native tile defaults (tile_m=tile_k=tile_n=64,
+    # so min_M = 256). The small-M case exercises the tile-override path:
+    # tile_m=16 drops min_M to 64, which is what real decode-runtime batch
+    # sizes (32/64/128) require to be dispatchable.
+    params_list = [
+        (256, 2048, 2048, False, {}),
+        (64, 1024, 3584, False, {"tile_m": 16, "tile_k": 64, "tile_n": 64}),
+    ]
 
     params = []
     for p in params_list:
@@ -29,8 +37,12 @@ def get_params():
     Latency=r"Latency \(us\): (?P<latency>[\d\.]+)",
     Bandwidth=r"Effective Bandwidth: (?P<bandwidth>[\d\.e\+-]+) GB/s",
 )
-@pytest.mark.parametrize("seq_len,embedding_dim,hidden_dim,prio_accuracy", get_params())
-def test_swiglu_prefill(seq_len, embedding_dim, hidden_dim, prio_accuracy, aie_context):
+@pytest.mark.parametrize(
+    "seq_len,embedding_dim,hidden_dim,prio_accuracy,tile_kwargs", get_params()
+)
+def test_swiglu_prefill(
+    seq_len, embedding_dim, hidden_dim, prio_accuracy, tile_kwargs, aie_context
+):
     golden_ref = generate_golden_reference(M=seq_len, K=embedding_dim, N=hidden_dim)
 
     operator = SwiGLUPrefill(
@@ -39,6 +51,7 @@ def test_swiglu_prefill(seq_len, embedding_dim, hidden_dim, prio_accuracy, aie_c
         hidden_dim=hidden_dim,
         prio_accuracy=bool(prio_accuracy),
         context=aie_context,
+        **tile_kwargs,
    )
     operator.weights_1 = golden_ref["w_gate"].T
     operator.weights_2 = golden_ref["w_up"].T