Add head dim 256 support for SDPA on Blackwell#2906
yaox12 wants to merge 4 commits into NVIDIA:main
Conversation
Greptile Summary

This PR adds head_dim=256 training support for SDPA on Blackwell (SM100+) by wiring the Python-only cuDNN FE CuTe-DSL kernel into TE's dispatch path. The C++ backend filter is widened to emit F16_arbitrary_seqlen for this shape.

Confidence Score: 5/5

Safe to merge; all remaining findings are P2 style/robustness suggestions with no correctness impact. The core routing logic is sound: eligibility is gated in `is_supported` before the sentinel is set, `attn_scale` is always concretised before the Python kernel is called via the normal dispatch path, and env-var mutations in tests are correctly scoped with monkeypatch. The three inline comments are minor and do not affect the primary use path. No files require special attention beyond the P2 notes in cudnn_fe_sdpa.py and fused_attn.py.

Important Files Changed
Sequence Diagram

```mermaid
sequenceDiagram
    participant Caller
    participant get_attention_backend
    participant fused_attn_fwd
    participant cudnn_fe_sdpa
    participant cuDNN_FE_kernel
    Caller->>get_attention_backend: head_dim=256, SM100+, F16
    get_attention_backend->>get_attention_backend: C++ returns F16_arbitrary_seqlen
    get_attention_backend->>cudnn_fe_sdpa: is_supported()?
    cudnn_fe_sdpa-->>get_attention_backend: True
    get_attention_backend-->>Caller: backend = F16_cudnn_fe_sdpa (sentinel)
    Caller->>fused_attn_fwd: backend=F16_cudnn_fe_sdpa
    fused_attn_fwd->>fused_attn_fwd: compute attn_scale, extract qkv_format
    fused_attn_fwd->>cudnn_fe_sdpa: fused_attn_fwd(q,k,v,...)
    cudnn_fe_sdpa->>cudnn_fe_sdpa: _to_kernel_shape (transpose 1-2 for bshd)
    cudnn_fe_sdpa->>cuDNN_FE_kernel: sdpa_fwd_wrapper_sm100_d256(...)
    cuDNN_FE_kernel-->>cudnn_fe_sdpa: o_tensor, lse_tensor
    cudnn_fe_sdpa->>cudnn_fe_sdpa: _from_kernel_shape (transpose back)
    cudnn_fe_sdpa-->>fused_attn_fwd: out, [lse, rng_placeholder]
    fused_attn_fwd-->>Caller: out, aux_ctx_tensors
```
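The routing step in the diagram can be sketched as below. This is a hypothetical illustration only: the helper names `_python_d256_path` and `_cpp_path` are stand-ins for the real CuTe-DSL wrapper and C++ extension, and the real functions operate on tensors rather than shape tuples.

```python
def _python_d256_path(q_shape):
    # stand-in for the cuDNN FE CuTe-DSL d=256 kernel wrapper
    return ("python_kernel", q_shape)

def _cpp_path(q_shape, backend):
    # stand-in for the C++ fused-attention extension
    return ("cpp_kernel", backend)

def dispatch(q_shape, backend):
    """Route to the Python-only kernel when the sentinel backend is set."""
    if backend == "F16_cudnn_fe_sdpa":  # Python-only sentinel, no C++ counterpart
        return _python_d256_path(q_shape)
    return _cpp_path(q_shape, backend)

print(dispatch((2, 1024, 16, 256), "F16_cudnn_fe_sdpa")[0])    # python_kernel
print(dispatch((2, 1024, 16, 128), "F16_arbitrary_seqlen")[0])  # cpp_kernel
```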
Reviews (4): Last reviewed commit: "resolve comments"
Signed-off-by: Xin Yao <xiny@nvidia.com>
I don't think we should use the `use_fused_attention=True` code path for the CuTe DSL kernels.
- The C++ and Python APIs in the cuDNN frontend don't really have anything to do with each other, and they're only put together for organizational reasons.
- This doesn't generalize well if we want to add more backends. I could imagine TE eventually having some DPA kernels written in Triton or maybe some new package will come out with a better attention implementation.
- Putting the C++ and Python cuDNN APIs together means we will need to do redundant feature checks: once to decide `use_fused_attn=True`, again to choose between the C++ or Python cuDNN APIs, and again in the C++ infrastructure to choose the `NVTE_Fused_Attn_Backend`. This extra layer is not so bad now since there's only one Python case, but it will scale poorly if we get closer to a 50/50 split between C++ and Python.
The right approach would be to make a new backend, but we're restricted by the current design of the backend selection logic. In particular, `get_attention_backend` is hard-coded to return `use_flash_attention`, `use_fused_attention`, `use_unfused_attention`, so we're kind of stuck. I propose changing this function to return a list of strings. Adding a new backend just involves adding logic to handle a new string.
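The proposed refactor could look roughly like the sketch below. It is an assumption about the shape of the change, not the actual TE code: the eligibility conditions and backend-name strings are illustrative.

```python
def get_attention_backend(head_dim, sm_major):
    """Return an ordered list of applicable backend names (most preferred first),
    instead of three hard-coded booleans."""
    backends = []
    if head_dim == 256 and sm_major >= 10:
        backends.append("cudnn_fe_sdpa")   # Python CuTe-DSL path
    if head_dim <= 128:
        backends.append("flash_attention")
        backends.append("fused_attention")  # C++ cuDNN path
    backends.append("unfused_attention")    # always-available fallback
    return backends

# Adding a new backend (e.g. a Triton kernel) becomes one more string plus a
# dispatch entry, with no change to the function's return type.
print(get_attention_backend(256, 10))  # ['cudnn_fe_sdpa', 'unfused_attention']
```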
CC @cyanguwa
Yes, I agree. Though that's not an easy refactor.
```python
# Python-only sentinel (no C++ counterpart). Set by ``get_attention_backend``
# when the cuDNN frontend CuTe-DSL d=256 SDPA path is applicable; checked
# in ``fused_attn_fwd`` / ``fused_attn_bwd`` to route to the Python kernel.
"F16_cudnn_fe_sdpa": "F16_cudnn_fe_sdpa",
```
Since `FusedAttnBackend` is no longer guaranteed to return an `NVTE_Fused_Attn_Backend`, the safe thing to do would be to check the return value every time we access this dict. That's really painful, so it's a hint we're doing something wrong.
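The "check every access" concern can be made concrete with a small sketch. The mapping and helper below are hypothetical, made up for illustration; only the sentinel name comes from the diff above.

```python
# Illustrative-only mapping from backend-name strings to NVTE enum values;
# the real values live in the C++ NVTE_Fused_Attn_Backend enum.
NVTE_BACKEND_IDS = {
    "F16_max512_seqlen": 0,
    "F16_arbitrary_seqlen": 1,
}

def to_nvte_backend(name):
    """Guard the lookup: the Python-only sentinel has no C++ counterpart,
    so mapping it to an NVTE enum must fail loudly rather than KeyError."""
    if name == "F16_cudnn_fe_sdpa":
        raise ValueError("Python-only sentinel has no NVTE_Fused_Attn_Backend")
    return NVTE_BACKEND_IDS[name]
```

Every call site that treats the dict value as an enum would need this guard, which is the reviewer's point about the design smell.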
```python
if device_compute_capability[0] < 10:
    return False
```
Does the kernel work with SM 12 or beyond? To be safe:
```suggestion
if device_compute_capability[0] != 10:
    return False
```
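The effect of the suggested tightening can be sketched as follows. This is a minimal stand-in for the real `is_supported` check, assuming (as the suggestion does) that only the SM 10.x family is validated for this kernel; the `head_dim` condition is an illustrative extra.

```python
def is_supported(device_compute_capability, head_dim):
    """Gate the d=256 CuTe-DSL path to SM 10.x only, so untested future
    architectures (e.g. SM 12) fall back to other backends."""
    if device_compute_capability[0] != 10:
        return False
    return head_dim == 256

print(is_supported((10, 0), 256))  # True
print(is_supported((12, 0), 256))  # False: not validated on SM 12
```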
Description
Add head dim 256 support for SDPA on Blackwell via cuDNN FE cuteDSL kernels.
Type of change
Changes
The existing C++ cuDNN SDPA path cannot handle head_dim=256 training on SM100+; cuDNN frontend ≥1.22.1 ships a Python-only (CuTe DSL) kernel for this shape. This PR wires that kernel into TE's FusedAttention dispatch.
- transformer_engine/common/fused_attn/fused_attn.cpp
- transformer_engine/pytorch/attention/dot_product_attention/cudnn_fe_sdpa.py
- transformer_engine/pytorch/cpp_extensions/fused_attn.py
- transformer_engine/pytorch/attention/dot_product_attention/utils.py
- tests/pytorch/attention/test_headdim256_cudnn_fe.py

Checklist: