Restore device-aware num_aie_columns in SwiGLU operators #104

Open
albiol2004 wants to merge 1 commit into amd:devel from albiol2004:swiglu-num-aie-columns
Restores device-aware column selection in SwiGLUDecode and SwiGLUPrefill so both composite operators run on Phoenix (NPU1, 4 columns) as well as Strix (NPU2, 8 columns). The branching was originally introduced in #89 and inadvertently dropped during the simplifying refactor in #88, which hardcoded num_aie_columns=8 across both SwiGLU variants. This PR replaces the literal 8 with aie_utils.get_current_device().cols, matching the pattern already used by rms_norm, gemm, gemv, mem_copy, etc.
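The selection described above can be sketched in isolation. This is a minimal illustration, not the code in the PR: `FakeDevice` and `select_num_aie_columns` are hypothetical names, and `FakeDevice` merely stands in for whatever object `aie_utils.get_current_device()` returns, of which only the `.cols` attribute is known from this description. The column counts (4 for NPU1, 8 for NPU2) are the ones stated above.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class FakeDevice:
    """Hypothetical stand-in for the object returned by aie_utils.get_current_device()."""

    name: str
    cols: int


# Column counts as stated in this PR description.
PHOENIX = FakeDevice(name="NPU1", cols=4)
STRIX = FakeDevice(name="NPU2", cols=8)


def select_num_aie_columns(device: FakeDevice) -> int:
    """Return the column count to pass to sub-operators instead of a literal 8."""
    return device.cols


if __name__ == "__main__":
    for device in (PHOENIX, STRIX):
        print(device.name, select_num_aie_columns(device))
```

The point of the pattern is that no operator file names a device or a literal column count; it only reads the attribute from the current device.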

While re-threading the column count, num_aie_columns=n_cols is also now passed to the two GEMM calls in SwiGLUPrefill. Previously those defaulted to GEMM's fallback, which meant SwiGLU prefill was implicitly under-parallelized on NPU2 and misaligned with the column count used by the surrounding SiLU and ElementwiseMul sub-ops.

A new rectangular FFN shape (seq_len=256, embedding_dim=1024, hidden_dim=3584) is added to swiglu_prefill/test.py so real decoder-model FFN dimensions (e.g. Qwen3.5-0.8B) are exercised in CI alongside the existing square 2048² smoke test.

Added

  • (256, 1024, 3584, False) rectangular FFN shape in iron/operators/swiglu_prefill/test.py, reflecting real decoder-model dims so non-square paths are covered.
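To show why a rectangular shape exercises different paths than a square one, the sketch below traces tensor shapes through a textbook SwiGLU feed-forward block. It assumes the common formulation (gate and up projections, SiLU, elementwise product, down projection) and a row-major weight layout; the operator in this repository may fuse or order these steps differently. The helper names are hypothetical, and the fourth element of the test tuple (`False`) is not modeled because its meaning is not stated here.

```python
def matmul_shape(a: tuple, b: tuple) -> tuple:
    """Shape of a @ b for 2-D operands; raises if the inner dimensions differ."""
    if a[1] != b[0]:
        raise ValueError(f"inner dimensions differ: {a} @ {b}")
    return (a[0], b[1])


def swiglu_ffn_shapes(seq_len: int, embedding_dim: int, hidden_dim: int) -> dict:
    """Trace intermediate shapes of a textbook SwiGLU FFN (assumed layout)."""
    x = (seq_len, embedding_dim)
    gate = matmul_shape(x, (embedding_dim, hidden_dim))  # gate projection
    up = matmul_shape(x, (embedding_dim, hidden_dim))    # up projection
    if gate != up:
        raise ValueError("elementwise multiply needs equal shapes")
    hidden = gate                                        # silu(gate) * up keeps the shape
    out = matmul_shape(hidden, (hidden_dim, embedding_dim))  # down projection
    return {"input": x, "hidden": hidden, "output": out}


if __name__ == "__main__":
    print(swiglu_ffn_shapes(256, 1024, 3584))
```

With `embedding_dim != hidden_dim`, the two projections have different inner and outer dimensions, so any tiling or column-splitting logic that silently assumed a square problem is exercised.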

Changed

  • iron/operators/swiglu_decode/op.py: derive n_cols = aie_utils.get_current_device().cols and pass it to the gemv_1 / silu / eltwise_mul / gemv_2 sub-ops in place of the hardcoded 8.
  • iron/operators/swiglu_prefill/op.py: same device-aware derivation, applied to silu / eltwise_mul and (newly) threaded through the gemm_1 and gemm_2 calls, which previously omitted num_aie_columns entirely.
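The prefill change listed above can be pictured as a before/after comparison of the keyword arguments each sub-op receives. This is an illustrative sketch only: the dictionaries model the argument lists rather than call any real operator constructor, and the helper names are hypothetical. The "before" state follows the description (silu and eltwise_mul hardcoded to 8, the GEMM calls omitting `num_aie_columns`).

```python
SUB_OPS = ("gemm_1", "silu", "eltwise_mul", "gemm_2")


def kwargs_before() -> dict:
    """Arguments as described before this PR: GEMMs omit num_aie_columns."""
    return {
        "gemm_1": {},
        "silu": {"num_aie_columns": 8},
        "eltwise_mul": {"num_aie_columns": 8},
        "gemm_2": {},
    }


def kwargs_after(n_cols: int) -> dict:
    """Arguments after this PR: every sub-op gets the same device-derived count."""
    return {name: {"num_aie_columns": n_cols} for name in SUB_OPS}


def explicit_column_counts(kwargs_by_op: dict) -> set:
    """Distinct num_aie_columns values; None marks a sub-op left to its default."""
    return {kw.get("num_aie_columns") for kw in kwargs_by_op.values()}


if __name__ == "__main__":
    print(explicit_column_counts(kwargs_before()))
    print(explicit_column_counts(kwargs_after(4)))
```

A single-element set after the change means all four sub-ops agree on the column count, which is the alignment property the description calls out.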

Removed

  • Hardcoded num_aie_columns=8 and associated // 8, // 16 literals in both SwiGLU op files.

Testing

Verified on NPU2 (Strix, aie2p) with ironenv + XRT sourced:

  • pytest iron/operators/ -m "not extensive" --iterations 1 : all previously-passing tests still pass (pre-existing LeakyReLU skips unchanged).
  • pytest iron/operators/swiglu_decode/test.py -v --iterations 1 : square 2048² passes.
  • pytest iron/operators/swiglu_prefill/test.py -v --iterations 1 : both square 2048² and the new rectangular 256 × 1024 × 3584 pass.

NPU1 (Phoenix, aie2) hardware was not available for local validation; the column-selection logic is structurally identical to the previously-shipped #89 code path, so Phoenix behavior is expected to match that baseline and should be re-confirmed by a reviewer with Phoenix access.

During this work, a separate, pre-existing numerical issue was observed in swiglu_decode at rectangular decode shapes (e.g. 1024 × 3584). It is unrelated to this change: the decode failure reproduces with num_aie_columns=8 both before and after this PR. For that reason the rectangular case was not added to swiglu_decode/test.py here, and the issue is being investigated separately.

PR Merge Checklist

  1. The PR is rebased on the latest devel commit and pointing to devel.
  2. Your PR has been reviewed and approved.
  3. All checks are passing.
