tests, feat(attention): GQA/MQA and decode-phase support via IAttentionLayer with HLO tests #4197

Open

yizhuoz004 wants to merge 2 commits into pytorch:main from yizhuoz004:hlo-attention-tests

Conversation

@yizhuoz004
Contributor

Description

Extends the TRT attention converter to support GQA/MQA and decode-phase attention, and adds a comprehensive HLO-level test suite.

Converter changes (aten_ops_converters.py, force_causal_efficient_attention.py):

  • Lifts the enable_gqa=True rejection in all three SDPA validators (scaled_dot_product_attention, flash, efficient). IAttentionLayer natively handles GQA/MQA, so the validator now verifies Hq % Hkv == 0 instead of rejecting outright.
  • Relaxes the shape-equality check to allow decode-phase attention (seq_q != seq_k) and GQA head-count mismatches, while still rejecting incompatible batch/head-dim shapes.
  • Adds attn_bias_is_causal parameter to DispatchTestCase.run_test to control whether the force_causal_efficient_attention lowering pass strips attn_bias before reaching the converter.
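
The relaxed shape check described above can be sketched in pure Python. This is an illustrative reconstruction, not the actual torch_tensorrt validator; the function name `attention_shapes_supported` and the (batch, heads, seq, head_dim) layout are assumptions for the example.

```python
# Hypothetical sketch of the relaxed SDPA validator logic; names are
# illustrative, not the real torch_tensorrt internals.
def attention_shapes_supported(q_shape, k_shape) -> bool:
    """Accept GQA/MQA (Hq != Hkv) and decode-phase (seq_q != seq_k) shapes.

    Layout assumed: (batch, heads, seq, head_dim).
    """
    batch_q, heads_q, _seq_q, dim_q = q_shape
    batch_k, heads_kv, _seq_k, dim_k = k_shape
    # Batch and head_dim must still match exactly; seq is deliberately
    # skipped so decode-phase (seq_q=1 vs. long KV cache) passes.
    if batch_q != batch_k or dim_q != dim_k:
        return False
    # GQA/MQA: query heads must be an integer multiple of KV heads.
    return heads_kv > 0 and heads_q % heads_kv == 0

# MHA, GQA (8 KV-sharing groups of 4), MQA, and decode-phase all pass:
assert attention_shapes_supported((2, 8, 128, 64), (2, 8, 128, 64))
assert attention_shapes_supported((2, 8, 128, 64), (2, 2, 128, 64))
assert attention_shapes_supported((2, 8, 1, 64), (2, 1, 4096, 64))
# Hq not divisible by Hkv, or mismatched batch/head_dim, is rejected:
assert not attention_shapes_supported((2, 8, 128, 64), (2, 3, 128, 64))
assert not attention_shapes_supported((2, 8, 128, 64), (3, 8, 128, 32))
```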

New test suite (tests/py/dynamo/hlo/test_attention.py):

  • Covers all three SDPA kernel variants (standard, flash, efficient) across MHA/GQA/MQA, causal/non-causal, bool/float/broadcast masks, decode-phase (seq_q=1), non-power-of-2 head dims, LLM-realistic configs, and fp16/bf16/fp32.
  • Known bugs are documented inline (large causal sequences, fp32 GQA without decompose_attention).
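
For readers unfamiliar with the decode-phase case the tests exercise (a single new query token attending over a long KV cache), here is a minimal pure-Python sketch of single-query attention. It is illustrative only and stands in for neither the TRT kernel nor the test harness.

```python
import math

# Minimal sketch of decode-phase attention: one query token (seq_q = 1)
# attends over all cached keys/values. Illustrative only.
def decode_attention(q, keys, values):
    """q: [d]; keys, values: [seq_k][d]. Returns the attended output [d]."""
    d = len(q)
    scale = 1.0 / math.sqrt(d)
    # Scaled dot-product scores against every cached key.
    scores = [scale * sum(qi * ki for qi, ki in zip(q, k)) for k in keys]
    # Numerically stable softmax over the seq_k axis.
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    z = sum(exps)
    weights = [e / z for e in exps]
    # Weighted sum of values.
    return [sum(w * v[i] for w, v in zip(weights, values)) for i in range(d)]

# With identical keys the weights are uniform, so the output is the mean
# of the values.
q = [1.0, 0.0]
keys = [[1.0, 0.0], [1.0, 0.0]]
values = [[2.0, 4.0], [4.0, 6.0]]
print(decode_attention(q, keys, values))  # [3.0, 5.0]
```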

Fixes # (issue)

Type of change

Please delete options that are not relevant and/or add your own.

  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)

Checklist:

  • My code follows the style guidelines of this project (You can use the linters)
  • I have performed a self-review of my own code
  • I have commented my code, particularly in hard-to-understand areas and hacks
  • I have made corresponding changes to the documentation
  • I have added tests to verify my fix or my feature
  • New and existing unit tests pass locally with my changes
  • I have added the relevant labels to my PR so that relevant reviewers are notified

… documentation

- Add tests/py/dynamo/hlo/test_attention.py covering all SDPA kernel variants
  (flash, efficient, GQA/MQA, decode-phase, bool/float masks) with LLM-realistic
  shapes and multiple dtypes
- Add tests/py/dynamo/hlo/__init__.py to make hlo/ a Python package
- Extend harness.py DispatchTestCase.run_test() with attn_bias_is_causal param
  to thread through CompilationSettings for efficient-attention bias tests
- File-level docstring documents three known IAttentionLayer bugs (large causal
  seq, GQA, decode-phase) and explains why affected tests use decompose_attention=True

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
@meta-cla

meta-cla Bot commented Apr 20, 2026

Hi @yizhuoz004!

Thank you for your pull request and welcome to our community.

Action Required

In order to merge any pull request (code, docs, etc.), we require contributors to sign our Contributor License Agreement, and we don't seem to have one on file for you.

Process

In order for us to review and merge your suggested changes, please sign at https://code.facebook.com/cla. If you are contributing on behalf of someone else (eg your employer), the individual CLA may not be sufficient and your employer may need to sign the corporate CLA.

Once the CLA is signed, our tooling will perform checks and validations. Afterwards, the pull request will be tagged with CLA signed. The tagging process may take up to 1 hour after signing. Please give it that time before contacting us about it.

If you have received this in error or have any questions, please contact us at cla@meta.com. Thanks!

@github-actions github-actions Bot added component: tests Issues re: Tests component: lowering Issues re: The lowering / preprocessing passes component: conversion Issues re: Conversion stage component: core Issues re: The core compiler component: api [Python] Issues re: Python API component: dynamo Issues relating to the `torch.compile` or `torch._dynamo.export` paths labels Apr 20, 2026
return False


# this method is only used in our converter test to infer the module output dtypes via dummy inference
Contributor Author


This is a duplicate definition of line 42

- Extend flash attention validator to accept GQA shapes (Hq != Hkv):
  IAttentionLayer natively handles non-equal head counts without K/V
  expansion. Requires Hq divisible by Hkv and matching batch/head_dim.
- Add decode-phase support (seq_q != seq_k) to all three attention
  validators; only the seq dimension is skipped in shape checks.
- Document why GQA is not supported in the efficient attention validator:
  PyTorch's eager kernel rejects Hq != Hkv, so no reference output exists;
  GQA models dispatch to flash attention (FP16) or decompose via
  matmul+_safe_softmax (FP32) and never produce this op with GQA shapes.
- Restructure test_attention.py: merge five SDPA classes into TestSDPA,
  expand TestFlashAttention with test_decode and test_gqa methods,
  add TestEfficientAttention.test_with_bias_decode; trim redundant cases
  and remove BUG-1 inline annotations (kept only in module docstring).

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
@yizhuoz004 yizhuoz004 force-pushed the hlo-attention-tests branch from 95e767d to 0817c5a on April 20, 2026 19:21
@narendasan narendasan requested a review from zewenli98 April 21, 2026 15:28