
Fix flash attention version check. #2910

Merged
ptrendx merged 1 commit into NVIDIA:main from bbuschkaemper:fix-fa-version-check
Apr 22, 2026
Conversation

@bbuschkaemper
Contributor

Description

The current FA4 beta release, 4.0.0b9, sorts below version 4.0.0 according to PEP 440, so the existing range checks incorrectly enable both use_flash_attn_3 and use_flash_attn_4 for it.

Related to #2432
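The ordering pitfall can be reproduced directly with the packaging library (a minimal sketch; PkgVersion is assumed to be an alias for packaging.version.Version, as used in backends.py):

```python
from packaging.version import Version as PkgVersion

fa4_beta = PkgVersion("4.0.0b9")

# Under PEP 440, a pre-release sorts below the corresponding final release,
# so an FA3 upper-bound check like "< 4.0.0" also matches FA4 betas.
print(fa4_beta < PkgVersion("4.0.0"))  # True: slips into the "< 4.0.0" FA3 range
print(fa4_beta.major)                  # 4: the major component is still 4
```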

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • backends.py now checks the flash attention major version directly

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Signed-off-by: Björn Buschkämper <bjoern.buschkaemper@gmail.com>
Copilot AI review requested due to automatic review settings April 21, 2026 09:50
@greptile-apps
Contributor

greptile-apps Bot commented Apr 21, 2026

Greptile Summary

This PR fixes a bug where FA4 pre-release versions like 4.0.0b9 — which sort below 4.0.0 under PEP 440 — would satisfy both the use_flash_attn_4 and use_flash_attn_3 range checks simultaneously. The fix replaces both range-based PkgVersion comparisons with .major == 4 / .major == 3 equality checks, which correctly dispatch pre-release and stable builds of the same major line to a single API path.

Confidence Score: 5/5

Safe to merge — targeted one-liner fix with no side effects on stable or pre-release FA2/FA3/FA4 paths.

The change is minimal and correct: packaging.version.Version.major is a well-defined property that returns the first component of the release tuple, so 4.0.0b9.major == 4 is True, whereas under the old range logic the upper-bound check 4.0.0b9 < PkgVersion("4.0.0") was also True, letting the prerelease satisfy the FA3 range. Both flags are now mutually exclusive by construction. All remaining observations are P2 or lower.

No files require special attention.

Important Files Changed

Filename: transformer_engine/pytorch/attention/dot_product_attention/backends.py
Overview: Replaces PEP 440 range-based version checks with .major equality checks to correctly identify FA3/FA4 pre-release packages like 4.0.0b9.

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A[flash_attention_backend version] --> B{is not None?}
    B -- No --> C[use_flash_attn_4 = False\nuse_flash_attn_3 = False]
    B -- Yes --> D{.major == 4?}
    D -- Yes --> E[use_flash_attn_4 = True]
    D -- No --> F{.major == 3?}
    F -- Yes --> G[use_flash_attn_3 = True]
    F -- No --> H[FA2 or unknown path]

    style E fill:#90EE90
    style G fill:#90EE90

Reviews (1): Last reviewed commit: "Fix flash attention version check."

@bbuschkaemper
Contributor Author

@yaox12 Could you take a look here as well? It's a separate issue from #2909, so I created two PRs in accordance with the contributing guidelines.

@yaox12
Member

yaox12 commented Apr 21, 2026

LGTM

Contributor

Copilot AI left a comment


Pull request overview

Fixes FlashAttention backend selection when using FA4 prereleases that sort below 4.0.0 under PEP 440, preventing FA3/FA4 API selection from becoming inconsistent.

Changes:

  • Switch FA3/FA4 selection logic to use the package major version rather than a < 4.0.0 range check.
  • Add clarifying inline comments explaining the prerelease sorting pitfall.


Comment on lines +941 to +946
use_flash_attn_4 = (
    flash_attention_backend is not None and flash_attention_backend.major == 4
)
use_flash_attn_3 = (
    flash_attention_backend is not None and flash_attention_backend.major == 3
)

Copilot AI Apr 21, 2026


flash_attention_backend is a packaging.version.Version (PkgVersion) instance. That type is not guaranteed to expose a .major attribute across packaging versions; relying on it can raise AttributeError at runtime. Prefer using the stable API (flash_attention_backend.release[0] / flash_attention_backend.release[:1]) to derive the major version (with a safe default when release is empty).

Suggested change
use_flash_attn_4 = (
    flash_attention_backend is not None and flash_attention_backend.major == 4
)
use_flash_attn_3 = (
    flash_attention_backend is not None and flash_attention_backend.major == 3
)
flash_attention_backend_major = (
    flash_attention_backend.release[0]
    if flash_attention_backend is not None and flash_attention_backend.release
    else None
)
use_flash_attn_4 = flash_attention_backend_major == 4
use_flash_attn_3 = flash_attention_backend_major == 3
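The release-tuple variant suggested above behaves identically for the prerelease case; a minimal sketch using packaging directly:

```python
from packaging.version import Version

v = Version("4.0.0b9")

# Version.release holds only the numeric release components (4, 0, 0);
# the b9 pre-release suffix is carried separately, so release[0] is the major.
major = v.release[0] if v.release else None
print(major)            # 4
print(v.is_prerelease)  # True
```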

Comment on lines +939 to +946
# FA4 prereleases such as 4.0.0b8 sort below 4.0.0, so key off the major
# version instead of a stable-version range check when selecting the API.
use_flash_attn_4 = (
    flash_attention_backend is not None and flash_attention_backend.major == 4
)
use_flash_attn_3 = (
    flash_attention_backend is not None and flash_attention_backend.major == 3
)

Copilot AI Apr 21, 2026


This change fixes a subtle prerelease ordering issue (e.g., 4.0.0b9 sorting below 4.0.0) that previously enabled both FA3 and FA4 code paths. There should be a regression test that exercises this selection logic with a prerelease FA4 version string to prevent future regressions (ideally without requiring flash-attn-4 to be installed, e.g., by unit-testing a small helper that maps PkgVersion -> selected API).

@ptrendx
Member

ptrendx commented Apr 21, 2026

/te-ci pytorch

@ptrendx ptrendx added the 2.15.0 and community-contribution (PRs from external contributors outside the core maintainers, representing community-driven work) labels Apr 21, 2026
@ptrendx ptrendx self-assigned this Apr 21, 2026
@ptrendx ptrendx merged commit 4014f7f into NVIDIA:main Apr 22, 2026
25 of 28 checks passed
KshitijLakhani pushed a commit that referenced this pull request Apr 22, 2026
Signed-off-by: Björn Buschkämper <bjoern.buschkaemper@gmail.com>
YigongQin pushed a commit to YigongQin/TransformerEngine that referenced this pull request Apr 23, 2026
Signed-off-by: Björn Buschkämper <bjoern.buschkaemper@gmail.com>
