[Common][PyTorch] Fix int32 overflow and -1 sentinel handling in moe_permute #2907
jing-4369 wants to merge 3 commits into NVIDIA:main
Conversation
Greptile Summary

This PR fixes two related bugs in moe_permute. Two issues flagged in earlier review rounds are still unresolved in the post-PR code.

Confidence Score: 4/5. The forward-path fixes are correct and well-structured, but two pre-existing P1 issues in the backward path (index narrowing, k*num_rows overflow) and the int32 guard arithmetic remain unresolved. The primary bugs (int32 overflow in pointer arithmetic and -1 sentinel mishandling) are correctly fixed. However, the `int index = source_token` narrowing in moe_permute_kernel (backward pass) and the `k * num_rows` int32 overflow in the inner-loop subscripts are still present and represent real correctness risks for large topK/num_rows configurations. The `num_tokens * topK` guard is also computed in int32. These were raised in earlier review rounds and should be resolved before merge to fully close the class of overflow bugs this PR targets.

Important Files Changed: transformer_engine/common/permutation/permutation.cu (backward kernel index handling); transformer_engine/pytorch/csrc/extensions/permutation.cpp (int32 guard arithmetic).
Flowchart

```mermaid
%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A["moe_permute_fwd entry"] --> B["cub::DeviceRadixSort signed\nSorts expert_ids and row_ids pairs"]
    B --> C["sorted_row_id has sentinel prefix\nthen valid ids at tail"]
    C --> D["Advance sorted_row_id_ptr past sentinels\nnum_minus_ones elements skipped"]
    D --> E["Pre-fill row_id_map with -1\nvia torch::full, dropped slots marked"]
    E --> F["moe_permute_row_map kernel\ngrid sized to num_out_tokens\nprocesses only valid suffix"]
    F --> G["row_id_map entries written\nDropped slots remain -1"]
    G --> H["moe_permute_kernel\nsource_token as int64_t\ndest_row as int64_t\npointer math stays 64-bit"]
    H --> I["permuted_output returned"]
```
Reviews (5). Last reviewed commit: "Guard against invalid num_out_tokens in ..."
```cpp
const int num_minus_ones = num_tokens * topK - num_out_tokens;
sorted_row_id_ptr = reinterpret_cast<char *>(sorted_row_id_ptr) + num_minus_ones * sizeof(int);
```
Negative num_minus_ones becomes an enormous size_t offset

num_minus_ones is computed as int. If a caller passes num_out_tokens > num_tokens * topK (which the function does not validate), num_minus_ones is negative. The pointer advance expression:

```cpp
sorted_row_id_ptr = reinterpret_cast<char *>(sorted_row_id_ptr) + num_minus_ones * sizeof(int);
```

involves int * size_t, which promotes num_minus_ones to size_t (unsigned). A value like -4 becomes SIZE_MAX - 3, advancing the pointer far out of the allocation and causing a silent OOB read. A simple clamp or assert before this line would prevent this:
```cpp
TORCH_CHECK(num_out_tokens <= num_tokens * topK,
            "num_out_tokens (", num_out_tokens, ") cannot exceed num_tokens*topK (",
            num_tokens * topK, ")");
```

…permute

Two independent bugs in transformer_engine/common/permutation/permutation.cu and the PyTorch extension caller reproduce on main (264da2b) and v2.13:

1. int32 overflow in moe_unpermute_kernel and moe_permute_kernel. `source_token * num_cols` and `source_row * num_cols` are computed with int, so for long-sequence MoE workloads where num_out_tokens * num_cols reaches 2**31 (e.g. 2**18 tokens x 2**13 hidden), the pointer offset wraps and the kernel either reads garbage or raises `an illegal memory access was encountered`. Widening source_token, source_row and dest_row to int64_t inside the kernels keeps the index arithmetic in 64 bits without changing any public types.

2. Incorrect handling of -1 sentinels in the routing indices. Libraries such as DeepEP (and any expert-parallel mask that sets non-local (token, slot) pairs to -1) feed a routing_map that contains -1 entries. `cub::DeviceRadixSort::SortPairs` is signed ascending, so those sentinels land at the HEAD of sorted_row_id, not the tail. moe_permute_row_map currently writes -1 only for idx >= num_out_tokens and reads the sentinel prefix as if it were a valid sorted id, producing bogus row_id_map writes (for instance `source_row / topK == 0, source_row % topK == -1`). The caller now advances sorted_row_id_ptr past the num_minus_ones prefix and pre-fills row_id_map with -1 via torch::full, so the kernel only processes the valid suffix and never dereferences a sentinel. The launcher's grid switches from num_rows*topK blocks to num_out_tokens blocks to match the new valid range.

No behaviour change on happy-path routing_map (no -1, no overflow).

Reproducers:
- 8-token, topK=2 routing_map with -1 masking: max |TE - ref| = 4.5e0 on bf16 with current main; 0.0 with this patch.
- num_tokens=2**18+1, num_cols=2**13, topK=1: current main raises CUDA illegal memory access at permutation.cu:252; with this patch it succeeds.

Signed-off-by: Jingyi Xi <flotherxi@gmail.com>
a909254 to ce34774
for more information, see https://pre-commit.ci
```cpp
const int num_minus_ones = num_tokens * topK - num_out_tokens;
sorted_row_id_ptr = reinterpret_cast<char *>(sorted_row_id_ptr) +
                    static_cast<size_t>(num_minus_ones) * sizeof(int);
```
num_tokens * topK still computed as int * int
num_tokens and topK are both int, so num_tokens * topK on line 61 is evaluated in 32-bit arithmetic before the result feeds the int64_t subtraction. The same expression appears twice in the NVTE_CHECK on lines 59–60. If num_tokens * topK wraps to a negative int (possible when, e.g., num_tokens ≥ 2^31 / topK), the NVTE_CHECK would either spuriously reject a valid num_out_tokens, or the error-message value would be wrong. Casting to int64_t before the multiplication closes this gap:
Suggested change:

```cpp
// before
const int num_minus_ones = num_tokens * topK - num_out_tokens;
sorted_row_id_ptr = reinterpret_cast<char *>(sorted_row_id_ptr) +
                    static_cast<size_t>(num_minus_ones) * sizeof(int);

// after
const int64_t total_tokens = static_cast<int64_t>(num_tokens) * topK;
NVTE_CHECK(num_out_tokens <= total_tokens, "num_out_tokens (", num_out_tokens,
           ") must not exceed num_tokens*topK (", total_tokens, ")");
const int num_minus_ones = static_cast<int>(total_tokens - num_out_tokens);
sorted_row_id_ptr = reinterpret_cast<char *>(sorted_row_id_ptr) +
                    static_cast<size_t>(num_minus_ones) * sizeof(int);
```
ba39fba to cd64aa5
Add an NVTE_CHECK that num_out_tokens <= num_tokens * topK and cast num_minus_ones to size_t before the pointer advance, so a negative num_minus_ones (from an invalid num_out_tokens) cannot silently wrap into a huge pointer offset. Signed-off-by: Jingyi Xi <flotherxi@gmail.com>
cd64aa5 to b73a1f9
Fixes #2908 — full description, repros, and DeepSeek-V3 context there.
Changes

- permutation.cu — widen source_token, source_row, dest_row to int64_t inside moe_unpermute_kernel and moe_permute_kernel so row * num_cols stays 64-bit. Simplify moe_permute_row_map to only process the valid [0, num_out_tokens) range; launcher grid becomes num_out_tokens blocks.
- permutation.cpp — advance sorted_row_id_ptr past the num_minus_ones sentinel prefix left by cub::DeviceRadixSort (signed ascending), and pre-fill row_id_map with -1 via torch::full so dropped slots are marked without the kernel ever dereferencing a sentinel.
- No public API / dtype changes.
- +17 / -18 lines across the two files.

Test plan

- Happy-path routing_map (no -1, offsets within int32) — unchanged.
- -1-sentinel repro from [Bug] moe_permute CUDA kernel: int32 overflow and incorrect -1 sentinel handling #2908 → max |TE - ref| = 0.0 on bf16 (was 4.56e0).
- int32-boundary repro from [Bug] moe_permute CUDA kernel: int32 overflow and incorrect -1 sentinel handling #2908 → no longer raises illegal memory access; matches reference.
- tests/pytorch/test_permutation.py via CI.