
Add head dim 256 support for SDPA on Blackwell#2906

Open
yaox12 wants to merge 4 commits into NVIDIA:main from yaox12:xiny/headdim256_blackwell

Conversation

@yaox12
Member

@yaox12 yaox12 commented Apr 21, 2026

Description

Add head dim 256 support for SDPA on Blackwell via cuDNN FE cuteDSL kernels.

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

The existing C++ cuDNN SDPA path cannot handle head_dim=256 training on SM100+; cuDNN frontend ≥1.22.1 ships a Python-only (CuTe DSL) kernel for this shape. This PR wires that kernel into TE's FusedAttention dispatch.

  • C++ filter — transformer_engine/common/fused_attn/fused_attn.cpp
  • New Python wrapper — transformer_engine/pytorch/attention/dot_product_attention/cudnn_fe_sdpa.py
  • Dispatcher — transformer_engine/pytorch/cpp_extensions/fused_attn.py
  • Python-side filter — transformer_engine/pytorch/attention/dot_product_attention/utils.py
  • Tests — tests/pytorch/attention/test_headdim256_cudnn_fe.py
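The dispatcher change above can be sketched as follows. This is a hypothetical illustration, not the PR's actual code: `cudnn_fe_sdpa_fwd`, the dict-based stand-in "tensors", and the simplified signature are all assumptions made to keep the sketch self-contained.

```python
# Hypothetical sketch of the early-return routing in fused_attn.py.
# cudnn_fe_sdpa_fwd is an illustrative stand-in for the entry point of
# the new cudnn_fe_sdpa module; real signatures take torch tensors and
# many more arguments.

def cudnn_fe_sdpa_fwd(q, k, v, attn_scale, **kwargs):
    # Stand-in for the Python (CuTe DSL) kernel wrapper.
    return "python_kernel", attn_scale

def fused_attn_fwd(q, k, v, fused_attention_backend, attn_scale=None, **kwargs):
    if fused_attention_backend == "F16_cudnn_fe_sdpa":
        # Python-only sentinel: route to the CuTe DSL kernel before any
        # C++ extension call, concretising attn_scale first.
        if attn_scale is None:
            attn_scale = q["head_dim"] ** -0.5
        return cudnn_fe_sdpa_fwd(q, k, v, attn_scale=attn_scale, **kwargs)
    # The existing C++ cuDNN dispatch would continue here.
    return "cpp_path", attn_scale
```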

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@yaox12 yaox12 changed the title Add head dim 256 support on Blackwell Add head dim 256 support for SDPA on Blackwell Apr 21, 2026
@greptile-apps
Contributor

greptile-apps Bot commented Apr 21, 2026

Greptile Summary

This PR adds head_dim=256 training support for SDPA on Blackwell (SM100+) by wiring the Python-only cuDNN FE CuTe-DSL kernel into TE's dispatch path. The C++ backend filter is widened to emit F16_arbitrary_seqlen for this shape, get_attention_backend then promotes that to a Python-only "F16_cudnn_fe_sdpa" sentinel, and early-return branches in fused_attn_fwd/fused_attn_bwd route to the new cudnn_fe_sdpa module. Remaining findings are all P2 style/robustness suggestions.

Confidence Score: 5/5

Safe to merge; all remaining findings are P2 style/robustness suggestions with no correctness impact.

The core routing logic is sound: eligibility is gated in is_supported before the sentinel is set, attn_scale is always concretised before the Python kernel is called via the normal dispatch path, and env-var mutations in tests are correctly scoped with monkeypatch. The three inline comments are minor and do not affect the primary use path.

No files require special attention beyond the P2 notes in cudnn_fe_sdpa.py and fused_attn.py.
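The backend-promotion step described in the summary could look roughly like this. The helper name, argument list, and the `cute_dsl_ok` flag are illustrative assumptions; only the backend strings come from the PR.

```python
# Hypothetical sketch of the promotion block in get_attention_backend:
# the C++ filter's answer is upgraded to the Python-only sentinel when
# the shape requires the CuTe DSL d=256 kernel. cute_dsl_ok stands in
# for the eligibility check (cudnn_fe_sdpa.is_supported in the PR).

def promote_backend(cpp_backend, head_dim, sm_major, cute_dsl_ok):
    if (cpp_backend == "F16_arbitrary_seqlen"
            and head_dim == 256
            and sm_major >= 10
            and cute_dsl_ok):
        return "F16_cudnn_fe_sdpa"
    return cpp_backend
```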

Important Files Changed

| Filename | Overview |
| --- | --- |
| `transformer_engine/pytorch/attention/dot_product_attention/cudnn_fe_sdpa.py` | New module wiring the cuDNN FE CuTe-DSL SDPA kernel; `is_available`/`is_supported` guards, layout transposition logic, and stream forwarding look correct; minor issue with `attn_scale=None` passthrough to the kernel. |
| `transformer_engine/pytorch/cpp_extensions/fused_attn.py` | Adds the `F16_cudnn_fe_sdpa` string sentinel to `FusedAttnBackend` and early-return routing blocks in both `fused_attn_fwd` and `fused_attn_bwd`; a type-annotation mismatch and fragile `qkv_format` extraction are P2 concerns. |
| `transformer_engine/pytorch/attention/dot_product_attention/utils.py` | Adds a backend-promotion block in `get_attention_backend` that upgrades `F16_arbitrary_seqlen` to the Python-only sentinel for `head_dim=256` on SM100+; the fallback disables fused attention cleanly, and other backends can still serve the request. |
| `transformer_engine/common/fused_attn/fused_attn.cpp` | Explicitly allows `head_dim=256` training on SM100+ through the `F16_arbitrary_seqlen` C++ filter, with a comment noting the Python SDPA intercept; the logic is a narrowly scoped extension of existing conditions. |
| `tests/pytorch/attention/test_headdim256_cudnn_fe.py` | New smoke-test file covering fwd+bwd correctness against an FP32 reference and an integration test via `DotProductAttention`; uses `monkeypatch.setenv` for safe env-var scoping. |
| `qa/L0_pytorch_unittest/test.sh` | Single line added to run the new test file in the CI suite. |
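The "fragile `qkv_format` extraction" P2 note suggests parsing the layout string defensively. A hypothetical robust version is sketched below; the helper name and the exact set of real layout strings are assumptions.

```python
# Illustrative defensive parse of a qkv_layout string such as
# "bshd_bshd_bshd". Rejecting mixed layouts explicitly avoids silently
# taking the first token when q/k/v formats disagree (the "fragile
# extraction" concern raised in the review).

def extract_qkv_format(qkv_layout):
    formats = set(qkv_layout.split("_"))
    if len(formats) != 1:
        raise ValueError(f"mixed qkv layouts not supported: {qkv_layout}")
    return formats.pop()
```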

Sequence Diagram

```mermaid
sequenceDiagram
    participant Caller
    participant get_attention_backend
    participant fused_attn_fwd
    participant cudnn_fe_sdpa
    participant cuDNN_FE_kernel

    Caller->>get_attention_backend: head_dim=256, SM100+, F16
    get_attention_backend->>get_attention_backend: C++ returns F16_arbitrary_seqlen
    get_attention_backend->>cudnn_fe_sdpa: is_supported()?
    cudnn_fe_sdpa-->>get_attention_backend: True
    get_attention_backend-->>Caller: backend = F16_cudnn_fe_sdpa (sentinel)

    Caller->>fused_attn_fwd: backend=F16_cudnn_fe_sdpa
    fused_attn_fwd->>fused_attn_fwd: compute attn_scale, extract qkv_format
    fused_attn_fwd->>cudnn_fe_sdpa: fused_attn_fwd(q,k,v,...)
    cudnn_fe_sdpa->>cudnn_fe_sdpa: _to_kernel_shape (transpose 1-2 for bshd)
    cudnn_fe_sdpa->>cuDNN_FE_kernel: sdpa_fwd_wrapper_sm100_d256(...)
    cuDNN_FE_kernel-->>cudnn_fe_sdpa: o_tensor, lse_tensor
    cudnn_fe_sdpa->>cudnn_fe_sdpa: _from_kernel_shape (transpose back)
    cudnn_fe_sdpa-->>fused_attn_fwd: out, [lse, rng_placeholder]
    fused_attn_fwd-->>Caller: out, aux_ctx_tensors
```
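The `_to_kernel_shape` step in the diagram can be illustrated on plain shape tuples. This is a torch-free sketch under the assumption that the d=256 kernel wants `(b, h, s, d)` inputs; only the `bshd` case from the diagram is shown, other formats pass through unchanged here.

```python
# Illustrative version of the transpose-1-2-for-bshd step shown in the
# sequence diagram, operating on shape tuples instead of tensors.

def to_kernel_shape(shape, qkv_format):
    if qkv_format == "bshd":
        # (batch, seq, heads, dim) -> (batch, heads, seq, dim)
        b, s, h, d = shape
        return (b, h, s, d)
    # Other formats are passed through in this sketch.
    return shape
```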

Reviews (4): Last reviewed commit: "resolve comments"

Comment thread transformer_engine/pytorch/cpp_extensions/fused_attn.py Outdated
Comment thread transformer_engine/pytorch/cpp_extensions/fused_attn.py Outdated
Comment thread transformer_engine/pytorch/cpp_extensions/fused_attn.py
Comment thread tests/pytorch/attention/test_headdim256_cudnn_fe.py Outdated
@yaox12 yaox12 force-pushed the xiny/headdim256_blackwell branch from 6b81cd1 to 87a39a2 Compare April 21, 2026 05:59
yaox12 added 2 commits April 21, 2026 06:01
Signed-off-by: Xin Yao <xiny@nvidia.com>
@yaox12 yaox12 force-pushed the xiny/headdim256_blackwell branch from 87a39a2 to 2db2fee Compare April 21, 2026 06:01
Signed-off-by: Xin Yao <xiny@nvidia.com>
@yaox12 yaox12 force-pushed the xiny/headdim256_blackwell branch from 2db2fee to 4fdcf33 Compare April 21, 2026 06:02
Comment thread transformer_engine/pytorch/cpp_extensions/fused_attn.py Outdated
Signed-off-by: Xin Yao <xiny@nvidia.com>
Collaborator


I don't think we should use the use_fused_attention=True code path for the CuTe DSL kernels.

  • The C++ and Python APIs in the cuDNN frontend don't really have anything to do with each other, and they're only put together for organizational reasons.
  • This doesn't generalize well if we want to add more backends. I could imagine TE eventually having some DPA kernels written in Triton or maybe some new package will come out with a better attention implementation.
  • Putting the C++ and Python cuDNN APIs together means we will need to do redundant feature checks: once to decide use_fused_attention=True, again to choose between the C++ or Python cuDNN APIs, and again in the C++ infrastructure to choose the NVTE_Fused_Attn_Backend. This extra layer is not so bad now since there's only one Python case, but it will scale poorly if we get closer to a 50/50 split between C++ and Python.

The right approach would be to make a new backend, but we're restricted by the current design of the backend selection logic. In particular, get_attention_backend is hard-coded to return use_flash_attention, use_fused_attention, use_unfused_attention, so we're kind of stuck. I propose changing this function to return a list of strings. Adding a new backend just involves adding logic to handle a new string.
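The proposed refactor could be sketched as below. The backend strings and selection conditions here are invented for illustration; only the idea of returning an ordered list instead of the fixed boolean triple comes from the comment.

```python
# Hypothetical sketch: get_attention_backend returns candidate backends
# in priority order rather than (use_flash_attention,
# use_fused_attention, use_unfused_attention). Adding a backend is then
# just adding a string and a handler; conditions below are illustrative.

def get_attention_backend(head_dim, sm_major):
    backends = []
    if head_dim == 256 and sm_major >= 10:
        backends.append("cudnn_fe_cute_dsl")   # Python CuTe DSL kernel
    if head_dim <= 128:
        backends.append("flash_attention")
    backends.append("fused_attention_cpp")     # C++ cuDNN path
    backends.append("unfused_attention")       # always-available fallback
    return backends
```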

CC @cyanguwa

Member Author


Yes, I agree. Though that's not an easy refactor.

```python
# Python-only sentinel (no C++ counterpart). Set by ``get_attention_backend``
# when the cuDNN frontend CuTe-DSL d=256 SDPA path is applicable; checked
# in ``fused_attn_fwd`` / ``fused_attn_bwd`` to route to the Python kernel.
"F16_cudnn_fe_sdpa": "F16_cudnn_fe_sdpa",
```
Collaborator


Since FusedAttnBackend is no longer guaranteed to return an NVTE_Fused_Attn_Backend, the safe thing to do would be to check the return value every time we access this dict. That's really painful, which is a hint we're doing something wrong.
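The "check every access" pattern this comment warns about would look something like the following. The enum values and helper name are illustrative assumptions, not the real NVTE_Fused_Attn_Backend numbering.

```python
# Illustrative only: because the dict can now hold a Python-only
# sentinel, every caller must verify the value maps to a real C++ enum
# before using it. Enum values here are made up for the sketch.
NVTE_ENUMS = {"F16_max512_seqlen": 0, "F16_arbitrary_seqlen": 1}

def to_nvte_enum(backend_name):
    value = NVTE_ENUMS.get(backend_name)
    if value is None:
        raise ValueError(
            f"{backend_name} has no C++ NVTE_Fused_Attn_Backend counterpart"
        )
    return value
```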

Comment on lines +76 to +77

```python
if device_compute_capability[0] < 10:
    return False
```
Collaborator


Does the kernel work with SM 12 or beyond? To be safe:

Suggested change

```diff
-if device_compute_capability[0] < 10:
-    return False
+if device_compute_capability[0] != 10:
+    return False
```



3 participants