Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 44 additions & 3 deletions iron/operators/swiglu_prefill/op.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,30 @@ def _call_matmul(self, matmul_callable, weight, input_buf, output_buf):

class SwiGLUPrefill(CompositeOperator):
def __init__(
self, seq_len, embedding_dim, hidden_dim, prio_accuracy=False, context=None
self,
seq_len,
embedding_dim,
hidden_dim,
prio_accuracy=False,
tile_m=None,
tile_k=None,
tile_n=None,
context=None,
):
"""
Args:
seq_len, embedding_dim, hidden_dim: SwiGLU FFN shape.
prio_accuracy: propagate the prio-accuracy flags to the inner GEMMs.
tile_m, tile_k, tile_n: optional overrides for the inner GEMMs'
tile sizes. Each override is forwarded to GEMM only when it is
not None; a value left as None falls back to GEMM's default (64).
Lowering tile_m reduces the minimum seq_len that SwiGLUPrefill
can be instantiated with (min_M = tile_m * 4), which unblocks
dispatch for decode-runtime batch sizes like M=32/64/128.
Both inner GEMMs share the same tile triple — this is the
common case; for asymmetric shapes, instantiate per-stage
GEMMs manually and compose.
context: AIEContext.
"""
self.seq_len = seq_len
self.hidden_dim = hidden_dim
self.embedding_dim = embedding_dim
Expand All @@ -39,6 +61,9 @@ def __init__(
self.weights_3 = None

self.prio_accuracy = prio_accuracy
self.tile_m = tile_m
self.tile_k = tile_k
self.tile_n = tile_n
super().__init__(context=context)

def set_up_artifacts(self):
Expand All @@ -53,8 +78,20 @@ def set_up_artifacts(self):
"round_conv_even": True,
}

tile_flags = {}
if self.tile_m is not None:
tile_flags["tile_m"] = self.tile_m
if self.tile_k is not None:
tile_flags["tile_k"] = self.tile_k
if self.tile_n is not None:
tile_flags["tile_n"] = self.tile_n

gemm_1 = GEMM(
M=self.seq_len, K=self.embedding_dim, N=self.hidden_dim, **accuracy_flags
M=self.seq_len,
K=self.embedding_dim,
N=self.hidden_dim,
**accuracy_flags,
**tile_flags,
)
self.gemm_1 = gemm_1
self.seq_len_padded = gemm_1.M
Expand All @@ -78,7 +115,11 @@ def set_up_artifacts(self):
assert eltwise_mul.size == self.seq_len_padded * self.hidden_dim_padded

gemm_2 = GEMM(
M=self.seq_len, K=self.hidden_dim, N=self.embedding_dim, **accuracy_flags
M=self.seq_len,
K=self.hidden_dim,
N=self.embedding_dim,
**accuracy_flags,
**tile_flags,
)
self.gemm_2 = gemm_2
assert gemm_2.M == self.seq_len_padded
Expand Down
19 changes: 16 additions & 3 deletions iron/operators/swiglu_prefill/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,15 @@


def get_params():
params_list = [(256, 2048, 2048, False)]
# (seq_len, embedding_dim, hidden_dim, prio_accuracy, tile_kwargs)
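# The default case passes no tile overrides, so GEMM's native tile defaults
# apply (tile_m=tile_k=tile_n=64, giving min_M = 256). The small-M case
# exercises the tile-override path: tile_m=16 drops min_M to 64, the range
# that real decode-runtime batch sizes (32/64/128) need to be dispatchable.
# NOTE(review): M=32 is still below min_M=64 with tile_m=16 — confirm whether
# that size relies on padding or needs a smaller tile_m.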
params_list = [
(256, 2048, 2048, False, {}),
(64, 1024, 3584, False, {"tile_m": 16, "tile_k": 64, "tile_n": 64}),
]

params = []
for p in params_list:
Expand All @@ -29,8 +37,12 @@ def get_params():
Latency=r"Latency \(us\): (?P<value>[\d\.]+)",
Bandwidth=r"Effective Bandwidth: (?P<value>[\d\.e\+-]+) GB/s",
)
@pytest.mark.parametrize("seq_len,embedding_dim,hidden_dim,prio_accuracy", get_params())
def test_swiglu_prefill(seq_len, embedding_dim, hidden_dim, prio_accuracy, aie_context):
@pytest.mark.parametrize(
"seq_len,embedding_dim,hidden_dim,prio_accuracy,tile_kwargs", get_params()
)
def test_swiglu_prefill(
seq_len, embedding_dim, hidden_dim, prio_accuracy, tile_kwargs, aie_context
):
golden_ref = generate_golden_reference(M=seq_len, K=embedding_dim, N=hidden_dim)

operator = SwiGLUPrefill(
Expand All @@ -39,6 +51,7 @@ def test_swiglu_prefill(seq_len, embedding_dim, hidden_dim, prio_accuracy, aie_c
hidden_dim=hidden_dim,
prio_accuracy=bool(prio_accuracy),
context=aie_context,
**tile_kwargs,
)
operator.weights_1 = golden_ref["w_gate"].T
operator.weights_2 = golden_ref["w_up"].T
Expand Down