diff --git a/iron/operators/swiglu_prefill/op.py b/iron/operators/swiglu_prefill/op.py index c6cb413b..d47e7145 100644 --- a/iron/operators/swiglu_prefill/op.py +++ b/iron/operators/swiglu_prefill/op.py @@ -28,8 +28,30 @@ def _call_matmul(self, matmul_callable, weight, input_buf, output_buf): class SwiGLUPrefill(CompositeOperator): def __init__( - self, seq_len, embedding_dim, hidden_dim, prio_accuracy=False, context=None + self, + seq_len, + embedding_dim, + hidden_dim, + prio_accuracy=False, + tile_m=None, + tile_k=None, + tile_n=None, + context=None, ): + """ + Args: + seq_len, embedding_dim, hidden_dim: SwiGLU FFN shape. + prio_accuracy: propagate the prio-accuracy flags to the inner GEMMs. + tile_m, tile_k, tile_n: optional overrides for the inner GEMMs' + tile sizes. If None, each falls back to GEMM's default (64). + Lowering tile_m reduces the minimum seq_len that SwiGLUPrefill + can be instantiated with (min_M = tile_m * 4), which unblocks + dispatch for decode-runtime batch sizes like M=32/64/128. + Both inner GEMMs share the same tile triple — this is the + common case; for asymmetric shapes, instantiate per-stage + GEMMs manually and compose. + context: AIEContext. 
+ """ self.seq_len = seq_len self.hidden_dim = hidden_dim self.embedding_dim = embedding_dim @@ -39,6 +61,9 @@ def __init__( self.weights_3 = None self.prio_accuracy = prio_accuracy + self.tile_m = tile_m + self.tile_k = tile_k + self.tile_n = tile_n super().__init__(context=context) def set_up_artifacts(self): @@ -53,8 +78,20 @@ def set_up_artifacts(self): "round_conv_even": True, } + tile_flags = {} + if self.tile_m is not None: + tile_flags["tile_m"] = self.tile_m + if self.tile_k is not None: + tile_flags["tile_k"] = self.tile_k + if self.tile_n is not None: + tile_flags["tile_n"] = self.tile_n + gemm_1 = GEMM( - M=self.seq_len, K=self.embedding_dim, N=self.hidden_dim, **accuracy_flags + M=self.seq_len, + K=self.embedding_dim, + N=self.hidden_dim, + **accuracy_flags, + **tile_flags, ) self.gemm_1 = gemm_1 self.seq_len_padded = gemm_1.M @@ -78,7 +115,11 @@ def set_up_artifacts(self): assert eltwise_mul.size == self.seq_len_padded * self.hidden_dim_padded gemm_2 = GEMM( - M=self.seq_len, K=self.hidden_dim, N=self.embedding_dim, **accuracy_flags + M=self.seq_len, + K=self.hidden_dim, + N=self.embedding_dim, + **accuracy_flags, + **tile_flags, ) self.gemm_2 = gemm_2 assert gemm_2.M == self.seq_len_padded diff --git a/iron/operators/swiglu_prefill/test.py b/iron/operators/swiglu_prefill/test.py index 6559571b..537c6889 100755 --- a/iron/operators/swiglu_prefill/test.py +++ b/iron/operators/swiglu_prefill/test.py @@ -17,7 +17,15 @@ def get_params(): - params_list = [(256, 2048, 2048, False)] + # (seq_len, embedding_dim, hidden_dim, prio_accuracy, tile_kwargs) + # The default case uses GEMM's native tile defaults (tile_m=tile_k=tile_n=64, + # so min_M = 256). The small-M case exercises the tile-override path: + # tile_m=16 drops min_M to 64, which is what real decode-runtime batch + # sizes (32/64/128) require to be dispatchable. 
+ params_list = [ + (256, 2048, 2048, False, {}), + (64, 1024, 3584, False, {"tile_m": 16, "tile_k": 64, "tile_n": 64}), + ] params = [] for p in params_list: @@ -29,8 +37,12 @@ def get_params(): Latency=r"Latency \(us\): (?P<Latency>[\d\.]+)", Bandwidth=r"Effective Bandwidth: (?P<Bandwidth>[\d\.e\+-]+) GB/s", ) -@pytest.mark.parametrize("seq_len,embedding_dim,hidden_dim,prio_accuracy", get_params()) -def test_swiglu_prefill(seq_len, embedding_dim, hidden_dim, prio_accuracy, aie_context): +@pytest.mark.parametrize( + "seq_len,embedding_dim,hidden_dim,prio_accuracy,tile_kwargs", get_params() +) +def test_swiglu_prefill( + seq_len, embedding_dim, hidden_dim, prio_accuracy, tile_kwargs, aie_context +): golden_ref = generate_golden_reference(M=seq_len, K=embedding_dim, N=hidden_dim) operator = SwiGLUPrefill( @@ -39,6 +51,7 @@ def test_swiglu_prefill(seq_len, embedding_dim, hidden_dim, prio_accuracy, aie_c hidden_dim=hidden_dim, prio_accuracy=bool(prio_accuracy), context=aie_context, + **tile_kwargs, ) operator.weights_1 = golden_ref["w_gate"].T operator.weights_2 = golden_ref["w_up"].T