
Beating CUDA with Triton: A Fused MoE Dispatch Kernel for Mixtral and DeepSeek

In my last post on Triton kernels, I optimized individual operations: RMSNorm, SwiGLU, INT8 GEMM. Single kernels, single operations. That was useful for learning Triton, but the real bottleneck in modern LLM inference isn’t any single operation. It’s the expert routing in Mixture-of-Experts models.

Over 60% of open-source model releases in 2025-2026 use MoE architectures: Mixtral, DeepSeek-V3, Qwen2-MoE, Grok. And MoE inference is hard. Not because the math is complicated, but because the memory access patterns are terrible: tokens scatter to different experts, each expert gets a different number of tokens, and you need to gather everything back together afterward.

So I tried something more ambitious: a fused MoE dispatch kernel that handles the entire forward pass (router scoring, token permutation, expert GEMMs, and output combination) in pure Triton. No CUDA, no vendor-specific code.

The result surprised me. At inference-relevant batch sizes, it’s faster than Megablocks, Stanford’s CUDA-optimized MoE library. And it runs on AMD GPUs without any changes.

Code: github.com/bassrehab/triton-kernels

Why MoE Dispatch is the Hard Part

A standard MoE forward pass looks simple on paper:

For each token:
    1. Compute router scores (which experts should handle this token?)
    2. Select top-k experts
    3. Send token to selected experts
    4. Run expert FFN
    5. Combine outputs weighted by router scores

The problem is steps 3-5. In a Mixtral model with 8 experts and top-2 routing, each token goes to 2 of the 8 experts, but which 2 varies per token. So you can’t batch the expert GEMMs naively: each expert gets a different-sized batch.

The naive PyTorch implementation loops over experts in Python:

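for expert_id in range(num_experts):
    # expert_offsets: prefix sum of per-expert token counts (this expert's rows after permutation)
    start, end = expert_offsets[expert_id], expert_offsets[expert_id + 1]
    tokens_for_this_expert = permuted_tokens[start:end]  # variable size
    output[start:end] = expert_ffn(tokens_for_this_expert)  # separate cuBLAS calls (gate, up, down)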

For Mixtral, that’s 8 experts × 3 matmuls each = 24 separate kernel launches per MoE layer. For DeepSeek-V3 with 256 experts, it’s 768 launches. Each one underutilizes the GPU because the per-expert batch is small.
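To make the five steps and the per-expert loop concrete, here is a minimal NumPy reference, assuming a SwiGLU expert FFN with weights stored as (out_features, in_features) like nn.Linear. The function name naive_moe_forward and the matmul counter are hypothetical, for illustration only; the counter tallies one matmul per would-be kernel launch, reproducing the "3 matmuls per expert" arithmetic. Routing weights are not renormalized over the top-k here, which some models do.

```python
import numpy as np

def softmax(x, axis=-1):
    # Numerically stable softmax over the expert dimension
    z = x - x.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def silu(x):
    return x / (1.0 + np.exp(-x))

def naive_moe_forward(x, w_router, w_gate, w_up, w_down, top_k):
    """x: (T, H); w_router: (E, H); w_gate, w_up: (E, F, H); w_down: (E, H, F)."""
    num_experts = w_router.shape[0]
    probs = softmax(x @ w_router.T)                       # 1. router scores (T, E)
    topk_ids = np.argsort(-probs, axis=1)[:, :top_k]      # 2. top-k experts per token
    topk_w = np.take_along_axis(probs, topk_ids, axis=1)  # routing weights (T, k)
    out = np.zeros_like(x)
    matmuls = 0
    for e in range(num_experts):                          # Python loop over experts
        tok, slot = np.nonzero(topk_ids == e)             # 3. tokens routed to expert e
        if tok.size == 0:
            continue
        xe = x[tok]                                       # variable-size batch
        h = silu(xe @ w_gate[e].T) * (xe @ w_up[e].T)     # 4. expert FFN: gate, up ...
        ye = h @ w_down[e].T                              # ... and down: 3 matmuls
        matmuls += 3
        out[tok] += topk_w[tok, slot][:, None] * ye       # 5. weighted combine
    return out, matmuls
```

Each token appears at most once per expert (top-k picks distinct experts), so the fancy-indexed += never collides.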

The Design

I ended up with a pipeline of 5 Triton kernel launches (down from 24+ in the naive approach):

  1. Router kernel: fused softmax + top-k selection
  2. Permute kernel: scatter tokens to expert-contiguous layout
  3. Fused gate+up GEMM: both projections from shared A-tile loads, SiLU in registers
  4. Down GEMM: grouped GEMM with block scheduling
  5. Unpermute kernel: gather + weighted combine
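Stages 2 and 5 are pure data movement. A NumPy sketch of what they compute, assuming the expert-contiguous layout is built with a stable sort on expert id (one common approach; the repo's kernels may build it differently). permute_tokens and unpermute_combine are hypothetical names, not the repo's API.

```python
import numpy as np

def permute_tokens(x, topk_ids, num_experts):
    """Scatter each (token, slot) assignment into an expert-contiguous buffer."""
    T, k = topk_ids.shape
    flat_expert = topk_ids.reshape(-1)                  # (T*k,) expert id per assignment
    order = np.argsort(flat_expert, kind="stable")      # assignments grouped by expert
    src_token = order // k                              # which token each row came from
    permuted = x[src_token]                             # (T*k, H) expert-contiguous rows
    counts = np.bincount(flat_expert, minlength=num_experts)
    offsets = np.concatenate(([0], np.cumsum(counts)))  # expert e owns rows offsets[e]:offsets[e+1]
    return permuted, order, offsets

def unpermute_combine(expert_out, order, topk_w):
    """Gather expert outputs back to token order and apply routing weights."""
    T, k = topk_w.shape
    unsorted = np.empty_like(expert_out)
    unsorted[order] = expert_out                        # undo the permutation
    unsorted = unsorted.reshape(T, k, -1)               # (T, k, H)
    return (topk_w[:, :, None] * unsorted).sum(axis=1)  # weighted sum over top-k slots
```

The offsets array is the per-expert row range that the grouped GEMMs in stages 3 and 4 index into.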

Let me walk through the two most interesting parts.

Block-Scheduled Grouped GEMM

The central problem is: how do you run a matmul where different “groups” (experts) have different batch sizes, in a single kernel launch?

My approach: precompute a mapping from Triton program blocks to (expert, token_offset) pairs. Each block looks up which expert it serves and where its tokens start:

@triton.jit
def _grouped_gemm_kernel(A, B, C, ExpertOffsets, BlockToExpert, BlockToM, ...):
    pid = tl.program_id(0)

    # Which expert am I working on?
    expert_id = tl.load(BlockToExpert + pid)
    m_start = tl.load(BlockToM + pid)
    expert_token_start = tl.load(ExpertOffsets + expert_id)

    # Standard tiled GEMM from here, just with offset pointers
    global_m_start = expert_token_start + m_start
    # ... load A tile, load B tile for this expert, accumulate, store

The schedule is built on CPU in ~0.1ms (trivial loop over experts). The key constraint I learned the hard way: BLOCK_M must be fixed, not autotuned. If you autotune BLOCK_M independently of the schedule, the kernel and schedule disagree on how many rows each block covers. I spent an hour debugging 30-45% element mismatches before realizing autotune had picked BLOCK_M=128 while the schedule used 64.
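A pure-Python sketch of that CPU-side schedule, assuming ExpertOffsets is the exclusive prefix sum of per-expert token counts (with a trailing entry for the total) and that each block covers BLOCK_M rows of a single expert. build_block_schedule is a hypothetical helper, not the repo's function.

```python
def build_block_schedule(tokens_per_expert, block_m):
    """Map each program id to (expert_id, row offset within that expert)."""
    expert_offsets = [0]
    block_to_expert, block_to_m = [], []
    for expert_id, n in enumerate(tokens_per_expert):
        expert_offsets.append(expert_offsets[-1] + n)  # prefix sum of token counts
        for m_start in range(0, n, block_m):            # ceil(n / block_m) blocks; none if n == 0
            block_to_expert.append(expert_id)
            block_to_m.append(m_start)
    return expert_offsets, block_to_expert, block_to_m
```

The launch grid is then len(block_to_expert), and the last block of each expert has to mask rows at or beyond expert_offsets[expert_id + 1]. Because block_m is baked into the schedule, the kernel must use exactly the same value.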

Fused Gate+Up Projection

This is where the real memory savings come from. In a SwiGLU FFN, you compute:

\[\text{output} = (\text{SiLU}(x W_\text{gate}^T) \odot x W_\text{up}^T) \cdot W_\text{down}^T\]

The unfused version does two separate grouped GEMMs (gate and up), writes both results to global memory, reads them back for SiLU + multiply, writes the intermediate, then does the down projection. That’s a lot of memory traffic.

The fused kernel computes both projections in the same tile loop. The trick is that both GEMMs share the same input tile — we load A once from L2 cache and compute two dot products:

# Two accumulators in registers
acc_gate = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
acc_up = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)

for k_start in range(0, K, BLOCK_K):
    # Load A tile ONCE (shared between gate and up)
    # a_mask / b_mask are assumed to include the K-bound check for this k_start
    a = tl.load(a_ptrs, mask=a_mask, other=0.0)

    # Load both weight tiles
    b_gate = tl.load(bg_ptrs, mask=b_mask, other=0.0)
    b_up = tl.load(bu_ptrs, mask=b_mask, other=0.0)

    # Two matmuls from the same A tile
    acc_gate += tl.dot(a, b_gate, out_dtype=tl.float32)
    acc_up += tl.dot(a, b_up, out_dtype=tl.float32)

    # Advance all three pointer blocks to the next K slice
    # (stride_ak / stride_bk: K-dimension strides of A and the weight tiles)
    a_ptrs += BLOCK_K * stride_ak
    bg_ptrs += BLOCK_K * stride_bk
    bu_ptrs += BLOCK_K * stride_bk

# SiLU + multiply IN REGISTERS, never written to global memory
silu_gate = acc_gate * tl.sigmoid(acc_gate)
result = silu_gate * acc_up
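To sanity-check the fused loop outside Triton, here is a NumPy emulation of one output tile, assuming weights pre-transposed to (K, N) and a K that divides evenly by block_k (so no masks). It only shows that one A-tile load can feed both accumulators and still match the unfused SiLU(x W_gate^T) ⊙ x W_up^T; it says nothing about performance. fused_gate_up_tile is a hypothetical name.

```python
import numpy as np

def fused_gate_up_tile(a, b_gate, b_up, block_k):
    """Emulate one (BLOCK_M, BLOCK_N) output tile of the fused kernel.

    a: (BLOCK_M, K) input rows; b_gate, b_up: (K, BLOCK_N) weight tiles.
    """
    m, K = a.shape
    n = b_gate.shape[1]
    acc_gate = np.zeros((m, n), dtype=np.float32)
    acc_up = np.zeros((m, n), dtype=np.float32)
    for k_start in range(0, K, block_k):
        a_tile = a[:, k_start:k_start + block_k]       # loaded once, used twice
        acc_gate += a_tile @ b_gate[k_start:k_start + block_k]
        acc_up += a_tile @ b_up[k_start:k_start + block_k]
    silu_gate = acc_gate / (1.0 + np.exp(-acc_gate))   # SiLU(x) = x * sigmoid(x)
    return silu_gate * acc_up
```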

This eliminates gate_out and up_out from global memory entirely. For Mixtral (ffn_dim=14336, 4096 tokens × top-2), that’s ~470 MB of memory traffic saved per forward pass, or about a 35% reduction in global memory traffic overall.
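The ~470 MB figure can be reproduced as the combined size of the two eliminated buffers, assuming 2-byte FP16/BF16 elements and counting each buffer once:

```python
tokens, top_k, ffn_dim = 4096, 2, 14336
bytes_per_elem = 2                                # assumption: FP16/BF16 activations
rows = tokens * top_k                             # 8192 token-expert rows
buffer_bytes = rows * ffn_dim * bytes_per_elem    # one of gate_out / up_out
eliminated = 2 * buffer_bytes                     # both buffers, each counted once
print(f"{eliminated / 1e6:.0f} MB")               # prints 470 MB
```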

I also tried to fuse the down projection with the output scatter (writing directly to the final token positions with gating weights applied via tl.atomic_add), but Triton doesn’t support scalar indexing into 2D accumulators (acc[m, :] fails to compile). The fused gate+up alone captures most of the win.

Results

All benchmarks on NVIDIA A100-SXM4-80GB (2039 GB/s bandwidth, 312 FP16 TFLOPS). PyTorch 2.4.1, Triton 3.0.0.

Mixtral-8x7B (8 experts, top-2, hidden=4096, ffn=14336)

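| Tokens | PyTorch Ref | Megablocks | Triton Fused | Speedup vs PyTorch | vs Megablocks (>100% = Triton faster) |
|---|---|---|---|---|---|
| 1 | 9.32 ms | - | 1.02 ms | 9.1x | - |
| 32 | 10.44 ms | 2.78 ms | 2.13 ms | 4.9x | 131% |
| 128 | 13.14 ms | 2.77 ms | 2.27 ms | 5.8x | 124% |
| 512 | 25.92 ms | 3.57 ms | 3.99 ms | 6.5x | 89% |
| 2048 | 66.22 ms | 9.08 ms | 16.48 ms | 4.0x | 56% |
| 4096 | 122.82 ms | - | 32.31 ms | 3.8x | - |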

At 32 and 128 tokens, which is where most inference happens (single-user or small-batch serving), we’re actually faster than Megablocks. This probably comes from lower kernel launch overhead (5 launches vs. Megablocks’ more complex dispatch).

At 512 tokens we’re at 89% of Megablocks, well above the 70% target I set at the start. At 2048+ tokens, Megablocks pulls ahead because its hand-tuned CUDA block-sparse matmul better saturates tensor cores at scale.

DeepSeek-V3 (256 experts, top-8, hidden=7168, ffn=2048)

| Tokens | Triton Unfused | Triton Fused | Fused Speedup |
|---|---|---|---|
| 1 | 4.56 ms | 3.27 ms | 1.40x |
| 32 | 13.65 ms | 11.53 ms | 1.18x |
| 128 | 19.46 ms | 16.74 ms | 1.16x |
| 512 | 25.66 ms | 20.16 ms | 1.27x |

DeepSeek-V3 is the hardest configuration. With 256 experts and top-8 routing, 512 tokens produce 512 × 8 = 4096 token-expert assignments, so each expert gets only ~16 tokens on average. The per-expert GEMMs are tiny (~16 rows against hidden=7168 and ffn=2048 weights), too small to fill tensor cores efficiently. This is fundamentally a memory-bound regime regardless of implementation.
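A back-of-envelope comparison of the average per-expert batch at 512 tokens for both configurations, assuming perfectly balanced routing (real routers are skewed, so some experts see more rows and many see fewer) and BLOCK_M = 64 as in the block schedule. The helper names are hypothetical.

```python
def avg_rows_per_expert(tokens, top_k, num_experts):
    # Each token is routed to top_k experts, so there are tokens * top_k assignments
    return tokens * top_k / num_experts

def tile_fill(rows, block_m):
    # Fraction of a BLOCK_M-row tile holding real rows (capped at one full tile)
    return min(rows, block_m) / block_m

deepseek = avg_rows_per_expert(512, 8, 256)
mixtral = avg_rows_per_expert(512, 2, 8)
print(deepseek, tile_fill(deepseek, 64))   # prints 16.0 0.25
print(mixtral, tile_fill(mixtral, 64))     # prints 128.0 1.0
```

Under these assumptions a DeepSeek-V3 expert fills about a quarter of one tile, while a Mixtral expert fills two full tiles.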

Roofline Analysis

The roofline for Mixtral at 512 tokens shows the expected picture: the expert FFN stages are compute-bound (high arithmetic intensity, near the compute ceiling), while the permute/unpermute stages are memory-bound (low arithmetic intensity, limited by bandwidth). The fused kernel pushes the expert FFN from 38% to 43% of the compute ceiling — modest but real.

DeepSeek-V3 tells a different story. With 256 experts and tiny per-expert batches, even the expert FFN is memory-bound, sitting on the bandwidth slope, not the compute plateau. The unpermute kernel actually hits 54% of peak bandwidth, which is decent for an irregular scatter operation.

The AMD Surprise

One of my design goals was cross-platform portability: use only Triton primitives, no inline CUDA. So I spun up an AMD MI300X pod on RunPod to test.

162 out of 162 tests passed. Zero code changes.

No #ifdef, no platform-specific paths, no vendor intrinsics. The same .py files that run on A100 run on MI300X. Triton’s ROCm backend handled the compilation transparently.

I didn’t benchmark performance on AMD (that’s future work), but correctness validated cleanly across all the model configurations tested, including Mixtral, DeepSeek-V3, and Qwen2-MoE. This is the promise of Triton over CUDA: write once, run on both vendors.

Things I Got Wrong Along the Way

The -1.0 masking bug. The top-k kernel selects experts iteratively: find the max, store it, mask it out, repeat. I initially masked selected experts with 0.0. This works fine for 8 experts, where softmax scores are spread out. But with 256 experts, most softmax scores are ~0.0 anyway. Masking to 0.0 doesn’t distinguish the selected expert from the unselected ones, so argmax kept returning the same index. It took me a while to figure out. The fix: mask with -1.0 instead. Softmax outputs lie in [0, 1], so -1.0 is strictly below every real score and an already-selected expert can never win again.
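A pure-Python sketch of the failure mode, assuming an argmax that resolves ties to the lowest index (as Python’s max does) and a toy score vector whose tail has collapsed to 0.0. iterative_topk is a hypothetical helper, not the kernel itself.

```python
def iterative_topk(scores, k, mask_value):
    """Select k indices by repeated argmax, masking each winner with mask_value."""
    scores = list(scores)
    picked = []
    for _ in range(k):
        # max() returns the first maximal index on ties
        best = max(range(len(scores)), key=lambda i: scores[i])
        picked.append(best)
        scores[best] = mask_value  # take the winner out of contention
    return picked

scores = [0.9, 0.1, 0.0, 0.0]
print(iterative_topk(scores, 3, 0.0))   # prints [0, 1, 0]: expert 0 picked twice
print(iterative_topk(scores, 3, -1.0))  # prints [0, 1, 2]
```

Once the masked slot ties with the zeroed tail, the first-index tie-break hands the win back to an expert that was already selected.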

The BLOCK_M autotune disaster. I mentioned this above, but it’s worth emphasizing. If you’re building a block-scheduled grouped GEMM, the schedule’s tile size and the kernel’s tile size must agree. I autotuned BLOCK_M thinking “let Triton pick the best tile size.” But the schedule was pre-built with BLOCK_M=64. When autotune picked 128, blocks overlapped. When it picked 32, rows were skipped. The output looked plausible (most elements correct) but ~30-45% of values were wrong. Fix: don’t autotune BLOCK_M, fix it to match the schedule.

Triton doesn’t support continue. My first attempt at a fused down+scatter kernel looped over rows with for m in range(BLOCK_M) and skipped invalid ones with if invalid: continue. Triton doesn’t support continue statements; compilation fails with “unsupported AST node type.” I rewrote it with conditional masks instead.
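The mask-based rewrite is easiest to see with NumPy standing in for Triton. This sketch (hypothetical helper names) checks that a per-row loop with continue and a branch-free np.where over a row-validity mask give the same result; in Triton the same idea would be written with tl.where or the mask= argument of tl.store rather than a Python-level branch.

```python
import numpy as np

def scatter_rows_loop(acc, valid, weights):
    # The shape of the original attempt: per-row branch with continue
    out = np.zeros_like(acc)
    for m in range(acc.shape[0]):
        if not valid[m]:
            continue
        out[m, :] = weights[m] * acc[m, :]
    return out

def scatter_rows_masked(acc, valid, weights):
    # Branch-free form: compute every row, then zero the invalid ones with a mask
    return np.where(valid[:, None], weights[:, None] * acc, 0.0)
```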

Triton doesn’t support 2D scalar indexing. acc[m, :] where m is a loop variable doesn’t compile: “unsupported tensor index: int32[].” This killed my fused down+scatter design, which is why the down projection uses a separate grouped GEMM kernel.

What’s Next

I’m planning to write this up as an arXiv technical report. The gaps to fill before then:

  • vLLM FusedMoE comparison: it’s also Triton-based, so it’s the most apples-to-apples baseline
  • AMD performance benchmarks: not just correctness
  • End-to-end integration: benchmark inside an actual serving framework, measure time-to-first-token
  • Full single-kernel fusion: persistent kernel approach to eliminate all intermediate buffers

Code

Everything is on GitHub: github.com/bassrehab/triton-kernels

The MoE-specific files:

Takeaways

  1. Triton can compete with CUDA for real workloads. Not just toy kernels: a full MoE dispatch pipeline that beats the CUDA-optimized baseline at inference batch sizes.

  2. Fusion is about eliminating buffers, not reducing kernel launches. The biggest win (a 35% reduction in global memory traffic) came from keeping the gate+up intermediates in registers. Reducing from 7 to 5 kernel launches helped too, but it’s secondary.

  3. Cross-platform is real but unfinished. The code runs on AMD with no changes, which is a strong validation of the Triton-only approach. But “runs correctly” and “runs fast” are different things. AMD performance optimization is future work.

  4. Block scheduling is the key abstraction for grouped GEMM. Triton doesn’t have native grouped GEMM. The block_id → (expert_id, offset) mapping is simple but powerful: it lets you handle variable-sized expert batches in a single kernel launch without padding every expert to the same batch size (only each expert’s last tile is partially filled).

  5. MoE inference at small batch sizes is surprisingly tractable. The conventional wisdom is that MoE is hard because of irregular access patterns. But at inference batch sizes (1-128 tokens), the overhead is dominated by weight loading, not routing. A clean Triton implementation can match or beat CUDA here because the simpler dispatch has less overhead.


_This is Part 3 of my LLM inference series. Part 1: speculative decoding. Part 2: custom Triton kernels. The code, benchmarks, and technical writeup are all in the repo._

Cite this article

Mitra, Subhadip. (2026, March). Beating CUDA with Triton: A Fused MoE Dispatch Kernel for Mixtral and DeepSeek. Subhadip Mitra. Retrieved from https://subhadipmitra.com/blog/2026/fused-moe-dispatch-triton/

@article{mitra2026beating-cuda-with-triton-a-fused-moe-dispatch-kernel-for-mixtral-and-deepseek,
  title   = {Beating CUDA with Triton: A Fused MoE Dispatch Kernel for Mixtral and DeepSeek},
  author  = {Mitra, Subhadip},
  journal = {Subhadip Mitra},
  year    = {2026},
  month   = {Mar},
  url     = {https://subhadipmitra.com/blog/2026/fused-moe-dispatch-triton/}
}

Subhadip Mitra