feat(attention): GQA/MQA and decode-phase support via IAttentionLayer with HLO tests #4197
yizhuoz004 wants to merge 2 commits into pytorch:main
Conversation
… documentation

- Add tests/py/dynamo/hlo/test_attention.py covering all SDPA kernel variants (flash, efficient, GQA/MQA, decode-phase, bool/float masks) with LLM-realistic shapes and multiple dtypes
- Add tests/py/dynamo/hlo/__init__.py to make hlo/ a Python package
- Extend harness.py DispatchTestCase.run_test() with an attn_bias_is_causal param to thread through CompilationSettings for efficient-attention bias tests
- File-level docstring documents three known IAttentionLayer bugs (large causal seq, GQA, decode-phase) and explains why affected tests use decompose_attention=True

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
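The GQA/MQA shapes these tests exercise can be illustrated with a small PyTorch sketch (illustrative shapes, not the test suite's actual code). The `repeat_interleave` expansion below is the eager reference path; per this PR, IAttentionLayer handles `Hq != Hkv` natively and skips it:

```python
import torch
import torch.nn.functional as F

# GQA: 8 query heads share 2 KV heads (requires Hq % Hkv == 0).
# MQA is the special case Hkv == 1.
B, Hq, Hkv, S, D = 1, 8, 2, 16, 64
q = torch.randn(B, Hq, S, D)
k = torch.randn(B, Hkv, S, D)
v = torch.randn(B, Hkv, S, D)

# Eager reference path: expand K/V so every query head has a matching
# KV head, then call plain SDPA.
k_exp = k.repeat_interleave(Hq // Hkv, dim=1)
v_exp = v.repeat_interleave(Hq // Hkv, dim=1)
out = F.scaled_dot_product_attention(q, k_exp, v_exp, is_causal=True)
assert out.shape == (B, Hq, S, D)
```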
Hi @yizhuoz004! Thank you for your pull request and welcome to our community.

Action Required: In order to merge any pull request (code, docs, etc.), we require contributors to sign our Contributor License Agreement, and we don't seem to have one on file for you.

Process: In order for us to review and merge your suggested changes, please sign at https://code.facebook.com/cla. If you are contributing on behalf of someone else (e.g. your employer), the individual CLA may not be sufficient and your employer may need to sign the corporate CLA.

Once the CLA is signed, our tooling will perform checks and validations. Afterwards, the pull request will be tagged. If you have received this in error or have any questions, please contact us at cla@meta.com. Thanks!
return False
# this method is only used in our converter test to infer the module output dtypes via dummy inference
This is a duplicate definition of line 42
- Extend flash attention validator to accept GQA shapes (Hq != Hkv): IAttentionLayer natively handles non-equal head counts without K/V expansion. Requires Hq divisible by Hkv and matching batch/head_dim.
- Add decode-phase support (seq_q != seq_k) to all three attention validators; only the seq dimension is skipped in shape checks.
- Document why GQA is not supported in the efficient attention validator: PyTorch's eager kernel rejects Hq != Hkv, so no reference output exists; GQA models dispatch to flash attention (FP16) or decompose via matmul + _safe_softmax (FP32) and never produce this op with GQA shapes.
- Restructure test_attention.py: merge five SDPA classes into TestSDPA, expand TestFlashAttention with test_decode and test_gqa methods, add TestEfficientAttention.test_with_bias_decode; trim redundant cases and remove BUG-1 inline annotations (kept only in module docstring).

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
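The validator rules described in this commit can be sketched as a plain shape check (a hypothetical helper with illustrative names, not the actual Torch-TensorRT validator code):

```python
def flash_attention_shapes_ok(q_shape, k_shape, v_shape):
    """Sketch of the GQA/decode shape rules (assumptions, not real code).

    Shapes are (batch, heads, seq, head_dim).
      - GQA/MQA: Hq may differ from Hkv as long as Hq % Hkv == 0
        (MQA is the Hkv == 1 special case).
      - Decode phase: seq_q may differ from seq_k, so the seq
        dimension is deliberately skipped in the comparisons below.
    """
    bq, hq, _seq_q, dq = q_shape
    bk, hkv, _seq_k, dk = k_shape
    if k_shape != v_shape:        # K and V must agree in every dimension
        return False
    if (bq, dq) != (bk, dk):      # batch and head_dim must match Q's
        return False
    return hq % hkv == 0          # equal heads, GQA, or MQA
```

Decode-phase inputs such as `q = (1, 8, 1, 64)` against `k = v = (1, 8, 128, 64)` pass because the seq dimension is never compared.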
Force-pushed from 95e767d to 0817c5a
Description
Extends the TRT attention converter to support GQA/MQA and decode-phase attention, and adds a comprehensive HLO-level test suite.
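Decode-phase attention here means a single new query token attending over the full cached key/value sequence (seq_q != seq_k). A minimal PyTorch sketch of the shapes involved (illustrative, not converter code):

```python
import torch
import torch.nn.functional as F

B, H, D = 1, 8, 64
cache_len = 128

q = torch.randn(B, H, 1, D)          # one new token: seq_q = 1
k = torch.randn(B, H, cache_len, D)  # cached keys:   seq_k = 128
v = torch.randn(B, H, cache_len, D)

# The last token may attend to the whole cache, so no causal mask
# is needed when seq_q == 1.
out = F.scaled_dot_product_attention(q, k, v)
assert out.shape == (B, H, 1, D)
```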
Converter changes (aten_ops_converters.py, force_causal_efficient_attention.py):
New test suite (tests/py/dynamo/hlo/test_attention.py):
Fixes # (issue)
Type of change
Please delete options that are not relevant and/or add your own.
Checklist: