
[Common][PyTorch] Fix int32 overflow and -1 sentinel handling in moe_permute #2907

Open
jing-4369 wants to merge 3 commits into NVIDIA:main from jing-4369:fix/moe-permute-int-overflow-and-minus-one

Conversation

@jing-4369

@jing-4369 jing-4369 commented Apr 21, 2026

Fixes #2908 — full description, repros, and DeepSeek-V3 context there.

Changes

  • permutation.cu — widen source_token, source_row, dest_row to int64_t inside moe_unpermute_kernel and moe_permute_kernel so row * num_cols stays 64-bit. Simplify moe_permute_row_map to only process the valid [0, num_out_tokens) range; launcher grid becomes num_out_tokens blocks.
  • permutation.cpp — advance sorted_row_id_ptr past the num_minus_ones sentinel prefix left by cub::DeviceRadixSort (signed ascending), and pre-fill row_id_map with -1 via torch::full so dropped slots are marked without the kernel ever dereferencing a sentinel.

No public API / dtype changes. +17 / -18 lines across the two files.

Test plan

@greptile-apps
Contributor

greptile-apps Bot commented Apr 21, 2026

Greptile Summary

This PR fixes two related bugs in moe_permute/moe_unpermute: (1) widens source_token, source_row, and dest_row to int64_t in the CUDA kernels so that row * num_cols pointer arithmetic stays 64-bit for large tensors, and (2) properly handles the -1 sentinel prefix that cub::DeviceRadixSort (signed ascending) places at the head of sorted_row_id by advancing the pointer past it in the host code and pre-filling row_id_map with -1 via torch::full.

Two issues flagged in earlier review rounds are still unresolved in the post-PR code: the int index = source_token narrowing in moe_permute_kernel (backward pass) and the k * num_rows int32 intermediate in the inner-loop index computations — both can overflow for large topK × num_rows configurations. Additionally, the guard NVTE_CHECK(num_out_tokens <= num_tokens * topK …) and the num_minus_ones derivation still evaluate num_tokens * topK as a 32-bit multiplication before the int64_t comparison, so the check is itself unsound for extreme sizes.

Confidence Score: 4/5

The forward-path fixes are correct and well-structured, but two pre-existing P1 issues in the backward path (index narrowing, k*num_rows overflow) and the int32 guard arithmetic remain unresolved.

The primary bugs (int32 overflow in pointer arithmetic and -1 sentinel mishandling) are correctly fixed. However, the int index = source_token narrowing in moe_permute_kernel (backward pass) and the k * num_rows int32 overflow in the inner-loop subscripts are still present and represent real correctness risks for large topK/num_rows configurations. The num_tokens * topK guard is also computed in int32. These were raised in earlier review rounds and should be resolved before merge to fully close the class of overflow bugs this PR targets.

transformer_engine/common/permutation/permutation.cu — backward kernel index handling; transformer_engine/pytorch/csrc/extensions/permutation.cpp — int32 guard arithmetic

Important Files Changed

Filename Overview
transformer_engine/common/permutation/permutation.cu Widens source_token/source_row/dest_row to int64_t in forward and backward kernels (fixes the primary int32 overflow). moe_permute_row_map is simplified to only iterate valid [0, num_out_tokens) range. Remaining issues: int index = source_token narrowing (backward path) and k * num_rows computed as int32 in inner-loop subscripts — both pre-existing but still not addressed.
transformer_engine/pytorch/csrc/extensions/permutation.cpp Adds sentinel-prefix skipping logic: advances sorted_row_id_ptr by num_minus_ones elements, pre-fills row_id_map with -1 via torch::full, and narrows the sorted_row_id TE tensor to num_out_tokens. The new NVTE_CHECK guards against negative num_minus_ones but is computed with num_tokens * topK as int * int = int32, which may overflow before the comparison.

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A["moe_permute_fwd entry"] --> B["cub::DeviceRadixSort signed\nSorts expert_ids and row_ids pairs"]
    B --> C["sorted_row_id has sentinel prefix\nthen valid ids at tail"]
    C --> D["Advance sorted_row_id_ptr past sentinels\nnum_minus_ones elements skipped"]
    D --> E["Pre-fill row_id_map with -1\nvia torch::full dropped slots marked"]
    E --> F["moe_permute_row_map kernel\ngrid sized to num_out_tokens\nprocesses only valid suffix"]
    F --> G["row_id_map entries written\nDropped slots remain -1"]
    G --> H["moe_permute_kernel\nsource_token as int64_t\ndest_row as int64_t\npointer math stays 64-bit"]
    H --> I["permuted_output returned"]

Reviews (5): Last reviewed commit: "Guard against invalid num_out_tokens in ..."

Comment on lines +59 to +60
const int num_minus_ones = num_tokens * topK - num_out_tokens;
sorted_row_id_ptr = reinterpret_cast<char *>(sorted_row_id_ptr) + num_minus_ones * sizeof(int);
Contributor


P2 Negative num_minus_ones becomes enormous size_t offset

num_minus_ones is computed as int. If a caller passes num_out_tokens > num_tokens * topK (which the function does not validate), num_minus_ones is negative. The pointer advance expression:

sorted_row_id_ptr = reinterpret_cast<char *>(sorted_row_id_ptr) + num_minus_ones * sizeof(int);

involves int * size_t, which promotes num_minus_ones to size_t (unsigned). A value like -4 becomes SIZE_MAX - 3, advancing the pointer far out of the allocation and causing a silent OOB read. A simple clamp or assert before this line would prevent this:

TORCH_CHECK(num_out_tokens <= num_tokens * topK,
            "num_out_tokens (", num_out_tokens, ") cannot exceed num_tokens*topK (",
            num_tokens * topK, ")");

…permute

Two independent bugs in transformer_engine/common/permutation/permutation.cu
and the PyTorch extension caller reproduce on main (264da2b) and v2.13:

1. int32 overflow in moe_unpermute_kernel and moe_permute_kernel.
   `source_token * num_cols` and `source_row * num_cols` are computed with
   int, so for long-sequence MoE workloads where num_out_tokens * num_cols
   reaches 2**31 (e.g. 2**18 tokens x 2**13 hidden), the pointer offset
   wraps and the kernel either reads garbage or raises
   `an illegal memory access was encountered`.
   Widening source_token, source_row and dest_row to int64_t inside the
   kernels keeps the index arithmetic in 64 bits without changing any
   public types.

2. Incorrect handling of -1 sentinels in the routing indices.
   Libraries such as DeepEP (and any expert-parallel mask that sets
   non-local (token, slot) pairs to -1) feed a routing_map that contains
   -1 entries. `cub::DeviceRadixSort::SortPairs` is signed ascending, so
   those sentinels land at the HEAD of sorted_row_id, not the tail.
   moe_permute_row_map currently writes -1 only for idx >= num_out_tokens
   and reads the sentinel prefix as if it were a valid sorted id,
   producing bogus row_id_map writes (for instance
   `source_row / topK == 0, source_row % topK == -1`).

   The caller now advances sorted_row_id_ptr past the num_minus_ones
   prefix and pre-fills row_id_map with -1 via torch::full, so the
   kernel only processes the valid suffix and never dereferences a
   sentinel.  The launcher's grid switches from num_rows*topK blocks
   to num_out_tokens blocks to match the new valid range.

No behaviour change on happy-path routing_map (no -1, no overflow).
Reproducers:

- 8-token, topK=2 routing_map with -1 masking: max |TE - ref| = 4.5e0
  on bf16 with current main; 0.0 with this patch.
- num_tokens=2**18+1, num_cols=2**13, topK=1: current main raises
  CUDA illegal memory access at permutation.cu:252; with this patch
  it succeeds.

Signed-off-by: Jingyi Xi <flotherxi@gmail.com>
@jing-4369 jing-4369 force-pushed the fix/moe-permute-int-overflow-and-minus-one branch from a909254 to ce34774 Compare April 21, 2026 07:58
Comment on lines +61 to +63
const int num_minus_ones = num_tokens * topK - num_out_tokens;
sorted_row_id_ptr = reinterpret_cast<char *>(sorted_row_id_ptr) +
                    static_cast<size_t>(num_minus_ones) * sizeof(int);
Contributor


P1 num_tokens * topK still computed as int * int

num_tokens and topK are both int, so num_tokens * topK on line 61 is evaluated in 32-bit arithmetic before the result feeds the int64_t subtraction. The same expression appears twice in the NVTE_CHECK on lines 59–60. If num_tokens * topK wraps to a negative int (possible when, e.g., num_tokens ≥ 2^31 / topK), the NVTE_CHECK would either spuriously reject a valid num_out_tokens, or the error-message value would be wrong. Casting to int64_t before the multiplication closes this gap:

Suggested change
- const int num_minus_ones = num_tokens * topK - num_out_tokens;
- sorted_row_id_ptr = reinterpret_cast<char *>(sorted_row_id_ptr) +
-                     static_cast<size_t>(num_minus_ones) * sizeof(int);
+ const int64_t total_tokens = static_cast<int64_t>(num_tokens) * topK;
+ NVTE_CHECK(num_out_tokens <= total_tokens, "num_out_tokens (", num_out_tokens,
+            ") must not exceed num_tokens*topK (", total_tokens, ")");
+ const int num_minus_ones = static_cast<int>(total_tokens - num_out_tokens);
+ sorted_row_id_ptr = reinterpret_cast<char *>(sorted_row_id_ptr) +
+                     static_cast<size_t>(num_minus_ones) * sizeof(int);

@jing-4369 jing-4369 force-pushed the fix/moe-permute-int-overflow-and-minus-one branch from ba39fba to cd64aa5 Compare April 21, 2026 08:14
Add an NVTE_CHECK that num_out_tokens <= num_tokens * topK and cast
num_minus_ones to size_t before the pointer advance, so a negative
num_minus_ones (from an invalid num_out_tokens) cannot silently wrap
into a huge pointer offset.

Signed-off-by: Jingyi Xi <flotherxi@gmail.com>
@jing-4369 jing-4369 force-pushed the fix/moe-permute-int-overflow-and-minus-one branch from cd64aa5 to b73a1f9 Compare April 21, 2026 08:22
@ptrendx ptrendx added the community-contribution PRs from external contributor outside the core maintainers, representing community-driven work. label Apr 21, 2026

Labels

community-contribution PRs from external contributor outside the core maintainers, representing community-driven work.


Development

Successfully merging this pull request may close these issues.

[Bug] moe_permute CUDA kernel: int32 overflow and incorrect -1 sentinel handling

3 participants