Restore device-aware num_aie_columns in SwiGLU operators#104
Open
albiol2004 wants to merge 1 commit intoamd:develfrom
Open
Restore device-aware num_aie_columns in SwiGLU operators#104albiol2004 wants to merge 1 commit intoamd:develfrom
albiol2004 wants to merge 1 commit intoamd:develfrom
Conversation
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Restores device-aware column selection in
SwiGLUDecodeandSwiGLUPrefillso both composite operators run on Phoenix (NPU1, 4 columns) as well as Strix (NPU2, 8 columns). The branching was originally introduced in #89 and inadvertently dropped during the simplifying refactor in #88, which hardcodednum_aie_columns=8across both SwiGLU variants. This PR replaces the literal8withaie_utils.get_current_device().cols, matching the pattern already used byrms_norm,gemm,gemv,mem_copy, etc.While re-threading the column count,
num_aie_columns=n_colsis also now passed to the twoGEMMcalls inSwiGLUPrefill. Previously those defaulted toGEMM's fallback, which meant SwiGLU prefill was implicitly under-parallelized on NPU2 and misaligned with the column count used by the surrounding SiLU and ElementwiseMul sub-ops.A new rectangular FFN shape (
seq_len=256, embedding_dim=1024, hidden_dim=3584) is added toswiglu_prefill/test.pyso real decoder-model FFN dimensions (e.g. Qwen3.5-0.8B) are exercised in CI alongside the existing square2048²smoke test.Added
(256, 1024, 3584, False)rectangular FFN shape iniron/operators/swiglu_prefill/test.py, reflecting real decoder-model dims so non-square paths are covered.Changed
iron/operators/swiglu_decode/op.py: deriven_cols = aie_utils.get_current_device().colsand pass it to the gemv_1 / silu / eltwise_mul / gemv_2 sub-ops in place of the hardcoded8.iron/operators/swiglu_prefill/op.py: same device-aware derivation, applied to silu / eltwise_mul and (newly) threaded through thegemm_1andgemm_2calls, which previously omittednum_aie_columnsentirely.Removed
num_aie_columns=8and associated// 8,// 16literals in both SwiGLU op files.Testing
Verified on NPU2 (Strix,
aie2p) withironenv+ XRT sourced:pytest iron/operators/ -m "not extensive" --iterations 1: all previously-passing tests still pass (pre-existing LeakyReLU skips unchanged).pytest iron/operators/swiglu_decode/test.py -v --iterations 1: square2048²passes.pytest iron/operators/swiglu_prefill/test.py -v --iterations 1: both square2048²and the new rectangular256 × 1024 × 3584pass.NPU1 (Phoenix,
aie2) hardware was not available for local validation; the column-selection logic is structurally identical to the previously-shipped #89 code path, so Phoenix behavior is expected to match that baseline and should be re-confirmed by a reviewer with Phoenix access.During this work a separate, pre-existing numerical issue was observed in
swiglu_decodeat rectangular decode shapes (e.g.1024 × 3584) that is unrelated to this change, the decode failure reproduces withnum_aie_columns=8either before or after this PR, so the rectangular case was not added toswiglu_decode/test.pyin this PR. That issue is being investigated separately.PR Merge Checklist
develcommit and pointing todevel.