My DGX Spark was running slower MoE inference than it should have — because of a two-line bug that blocked every SM12x GPU
My DGX Spark was silently running slower MoE inference than it should have been — because a gate check in vLLM's Python layer had been blocking every SM120 and SM121 GPU from the CUTLASS grouped GEMM path since the hardware shipped. No error, no warning. Just quiet fallback to Triton on every single MoE forward pass.
This post is about finding that, fixing it, writing the missing CUTLASS kernel specialization from scratch, validating it on real hardware, and filing the upstream PRs. Start to finish.
Background: what SM121 is and why it matters for MoE
The DGX Spark runs on GB10, NVIDIA's consumer Blackwell, with compute capability SM_121. Not SM_90a (datacenter Hopper), not SM_100 (datacenter Blackwell, GB200), not SM_120 (RTX 5090). SM_121 is a separate target within the Blackwell consumer line: 128 GB unified LPDDR5X, 20 SMs, 294 TFLOPS FP8.
MoE (mixture-of-experts) models like Gemma 4 26B route each token through a small subset of "expert" feed-forward layers. Gemma 4 26B has 26B total parameters, but only ~4B are active per forward pass. The expert computation is implemented as a grouped GEMM: a single CUDA kernel that runs a batch of small matrix multiplications, one per active expert.
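The sketch below is a minimal pure-Python illustration of what an MoE layer computes, making "one small GEMM per active expert" concrete. It assumes top-1 routing and tiny dense matrices to keep it short. Real MoE models usually send each token to several experts and mix their outputs with router weights. The function names (matmul, route_top1, moe_forward) are made up for illustration and are not vLLM APIs.

```python
def matmul(a, b):
    """Plain row-major matrix product: (M x K) @ (K x N) -> (M x N)."""
    k_dim = len(b)
    n_dim = len(b[0])
    return [[sum(row[k] * b[k][j] for k in range(k_dim)) for j in range(n_dim)] for row in a]


def route_top1(router_logits):
    """Pick the highest-scoring expert for each token (top-1 routing, a simplification)."""
    return [max(range(len(scores)), key=scores.__getitem__) for scores in router_logits]


def moe_forward(tokens, router_logits, expert_weights):
    assignment = route_top1(router_logits)

    # Gather token indices per expert; each expert ends up with a different row count (M_i).
    groups = {}
    for token_idx, expert in enumerate(assignment):
        groups.setdefault(expert, []).append(token_idx)

    # One small GEMM per active expert. A grouped GEMM kernel runs all of these in a
    # single launch; this loop only spells out what that launch computes.
    problem_shapes = []
    output = [None] * len(tokens)
    for expert in sorted(groups):
        rows = [tokens[t] for t in groups[expert]]
        w = expert_weights[expert]  # K x N weight matrix for this expert
        problem_shapes.append((len(rows), len(w[0]), len(w)))  # (M_i, N, K)
        for token_idx, out_row in zip(groups[expert], matmul(rows, w)):
            output[token_idx] = out_row
    return output, problem_shapes
```

The part to watch is problem_shapes. Each expert receives a different number of tokens (M_i), so a grouped kernel that accepts a list of differently sized problems fits better than one large GEMM.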
CUTLASS is NVIDIA's template library for high-performance matrix math. vLLM uses it for the grouped GEMM dispatch. The specific kernel is csrc/libtorch_stable/quantization/w8a8/cutlass/moe/grouped_mm_c3x_sm100.cu. Here "c3x" refers to the CUTLASS 3.x API, and "sm100" means it was written for datacenter Blackwell.
There is no equivalent grouped_mm_c3x_sm120.cu in upstream vLLM. I added it.
Bug 1: the Python gate that blocked every SM12x
In vllm/_custom_ops.py, around line 845:
```python
def cutlass_group_gemm_supported(cuda_device_capability: int) -> bool:
    if cuda_device_capability < 90 or cuda_device_capability >= 110:
        return False
    # ... (rest of the function not shown)
```
cuda_device_capability is an integer — it's 121 on my GB10. 121 >= 110 is True. So this function returned False for every SM120 and SM121 device, telling vLLM to skip CUTLASS entirely and use the Triton MoE backend instead.
The correct threshold for "future unsupported architecture" is >= 130. The fix is one character: 110 → 130.
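Here is a small sketch of the range check before and after the fix. It assumes the major * 10 + minor integer encoding, which gives the 121 reported on GB10. It models only the early-return condition quoted above. The real function may apply further checks after that line, and the helper names are hypothetical.

```python
def capability_to_int(major: int, minor: int) -> int:
    """Encode a CUDA compute capability as major * 10 + minor, e.g. (12, 1) -> 121."""
    return major * 10 + minor


def group_gemm_range_ok_before(cap: int) -> bool:
    # Only the early-return condition, with the original upper bound of 110.
    return not (cap < 90 or cap >= 110)


def group_gemm_range_ok_after(cap: int) -> bool:
    # Same condition with the upper bound raised to 130.
    return not (cap < 90 or cap >= 130)


if __name__ == "__main__":
    devices = [("H100", (9, 0)), ("B200", (10, 0)), ("RTX 5090", (12, 0)), ("GB10", (12, 1))]
    for name, (major, minor) in devices:
        cap = capability_to_int(major, minor)
        print(f"{name:8} sm_{cap}: before={group_gemm_range_ok_before(cap)} "
              f"after={group_gemm_range_ok_after(cap)}")
```

Hopper (90) and datacenter Blackwell (100) pass with either bound. The RTX 5090 (120) and GB10 (121) flip from False to True once the bound is 130.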
I found this while tracing why my GPU, which clearly has FP8 tensor cores, was hitting the Triton path. A quick print(cutlass_group_gemm_supported(121)) from inside the vLLM container confirmed it: False.
Bug 2: the missing CUTLASS collective
After fixing the gate, vLLM tried to call the SM120 grouped GEMM path — and it didn't exist. The existing grouped_mm_c3x_sm120.cu in my local tree (a partial attempt from earlier development) had another bug: it referenced KernelPtrArrayTmaWarpSpecializedCooperativeSm120 without the required template argument <2>. That's a compile error.
But even before that: there was no CollectiveMma or CollectiveBuilder specialization in CUTLASS 4.5 for this kernel schedule. The dispatch policy type — KernelPtrArrayTmaWarpSpecializedCooperativeSm120<N> — was defined in CUTLASS's headers, but trying to instantiate it via CollectiveBuilder hit a "Could not build a collective for given parameters" static_assert. The symbol existed, the plumbing did not.
What does "collective" mean here? In the CUTLASS v3 architecture:
- A CollectiveMma specifies how the mainloop runs: how data gets loaded from global memory → shared memory → registers, and how the tensor core instructions fire. It's the innermost kernel loop body.
- A CollectiveBuilder is the template dispatch layer that takes high-level parameters (element types, tile size, arch tag, schedule tag) and produces the right CollectiveMma.
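CUTLASS resolves the builder at compile time through C++ partial template specialization. The failure mode from Bug 2 can still be sketched as a lookup table in Python. This is a toy model, not CUTLASS code. The registry, the function names, and the "dense cooperative" schedule label are hypothetical. Only the KernelPtrArrayTmaWarpSpecializedCooperativeSm120 name comes from this post.

```python
class CollectiveBuildError(Exception):
    """Stands in for CUTLASS's "Could not build a collective" static_assert."""


# Hypothetical registry: each entry plays the role of one CollectiveBuilder
# partial specialization, keyed by (arch tag, schedule tag).
_BUILDERS = {}


def specialization(arch, schedule):
    def register(make_collective):
        _BUILDERS[(arch, schedule)] = make_collective
        return make_collective
    return register


def can_build(arch, schedule):
    return (arch, schedule) in _BUILDERS


def collective_builder(arch, schedule, element_a, element_b, tile_shape):
    """Pick the matching specialization and build a toy CollectiveMma description."""
    make_collective = _BUILDERS.get((arch, schedule))
    if make_collective is None:
        raise CollectiveBuildError("Could not build a collective for given parameters")
    return make_collective(element_a, element_b, tile_shape)


@specialization("Sm120", "dense cooperative")
def _sm120_dense(element_a, element_b, tile_shape):
    return {"mainloop": "sm120 dense", "types": (element_a, element_b), "tile": tile_shape}


# This schedule name exists, but nothing is registered to build it yet.
PTR_ARRAY_SCHEDULE = "KernelPtrArrayTmaWarpSpecializedCooperativeSm120"
```

Here, can_build("Sm120", PTR_ARRAY_SCHEDULE) is False, and collective_builder raises for that schedule. That matches the "symbol existed, the plumbing did not" situation: the fix is to add a new specialization, not to rename anything.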
For SM_120, CUTLASS 4.5 had CollectiveMma specializations for dense GEMM and blockwise-scaled grouped GEMM (for NVFP4). The ptr-array (grouped) FP8 path — tensor/token-scaled FP8, the kind produced by tools like llm-compressor and used by Gemma 4, Mixtral FP8, Qwen FP8 — had no CollectiveMma at all.
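The sketch below shows the arithmetic behind "tensor/token-scaled" FP8 under one common assumption: activations (A) get one scale per token (row), and weights (B) get a single per-tensor scale. Because sa[i] * sb * sum_k (a[i][k] / sa[i]) * (b[k][j] / sb) equals sum_k a[i][k] * b[k][j], a kernel can multiply the quantized operands and apply both scales once per output element in the epilogue. The E4M3 rounding step is deliberately omitted, so the identity holds up to ordinary floating-point error. All names are illustrative.

```python
FP8_E4M3_MAX = 448.0  # largest finite value representable in FP8 E4M3


def amax_scale(values):
    """Scale that maps the largest |value| onto the E4M3 range (1.0 for all-zero input)."""
    amax = max(abs(v) for v in values)
    return amax / FP8_E4M3_MAX if amax > 0 else 1.0


def quantize(matrix, row_scales):
    # Divide each row by its scale. A real FP8 path would also round to E4M3 here;
    # that step is omitted so the algebra stays checkable.
    return [[x / s for x in row] for row, s in zip(matrix, row_scales)]


def scaled_gemm(a_q, a_row_scales, b_q, b_scale):
    """Accumulate on quantized operands, then apply both scales per output element."""
    n_dim = len(b_q[0])
    return [
        [sa * b_scale * sum(row[k] * b_q[k][j] for k in range(len(b_q))) for j in range(n_dim)]
        for row, sa in zip(a_q, a_row_scales)
    ]
```

To use it, compute per-token scales with amax_scale(row) for each row of A and one scale over all of B. Quantize both, then call scaled_gemm. The result matches the unquantized product.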
Writing the SM120 collective
I used the SM100 dense GEMM collective as a reference. The implementation is ~400 lines of C++ templates in two files:
sm120_mma_array_tma.hpp — the CollectiveMma specialization. The mainloop body:
- Loads A and B operand tiles from global memory using TMA (Tensor Memory Accelerator)
- Uses pointer-array indirection: params.ptr_A[problem_idx] gives the A matrix for the current problem in the grouped batch
- Pipelines loads and MMA operations using PipelineTmaUmmaAsync
- Dispatches to UMMA (Unified Matrix Multiply-Accumulate) instructions via rr_op_selector_sm120
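Pointer-array indirection is what makes this a grouped kernel rather than a batched one: each problem can have its own shape and live anywhere in memory. The minimal Python sketch below uses integer offsets into one flat list as stand-ins for device pointers. Apart from params.ptr_A[problem_idx], which is the post's own expression, the names are hypothetical. TMA loads, pipelining, and tiling are not modeled.

```python
from dataclasses import dataclass


@dataclass
class GroupedGemmParams:
    memory: list          # flat buffer standing in for device global memory
    ptr_A: list           # ptr_A[i]: offset of problem i's row-major A (M x K)
    ptr_B: list           # ptr_B[i]: offset of problem i's row-major B (K x N)
    problem_shapes: list  # (M, N, K) for each problem in the group


def load_tile(memory, offset, rows, cols):
    """Read a row-major rows x cols matrix starting at `offset`."""
    return [memory[offset + r * cols: offset + (r + 1) * cols] for r in range(rows)]


def grouped_mainloop(params):
    outputs = []
    for problem_idx, (m, n, k) in enumerate(params.problem_shapes):
        # Pointer-array indirection: the problem index selects this GEMM's operands.
        a = load_tile(params.memory, params.ptr_A[problem_idx], m, k)
        b = load_tile(params.memory, params.ptr_B[problem_idx], k, n)
        outputs.append([[sum(a[i][t] * b[t][j] for t in range(k)) for j in range(n)]
                        for i in range(m)])
    return outputs
```

Two problems with different shapes can share one buffer, for example a 1x2 @ 2x1 and a 2x1 @ 1x2. Only their offsets and shapes distinguish them.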
The key difference from SM100: SM120 has no programmatic multicast and no 2-SM cluster support. Cluster shape is forced to 1×1×1. The UMMA atom layout is 4×2 for Cooperative schedule and 2×2 for Pingpong.
sm120_array_mma_builder.inl — the CollectiveBuilder dispatch layer. It plumbs the schedule tag's template parameter (SchedulerPipelineStageCount) through to the dispatch policy, computes stage count from shared memory capacity, and assembles the CollectiveOp.
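The stage-count computation essentially asks how many A/B tile buffers fit in shared memory. Here is a simplified sketch. The 101376-byte (99 KiB) capacity and the 16-byte-per-stage barrier overhead are assumptions for illustration, not values read from CUTLASS's headers. The function name is hypothetical.

```python
def compute_stage_count(smem_capacity_bytes, carveout_bytes, tile_m, tile_n, tile_k,
                        bytes_per_a, bytes_per_b, barrier_bytes_per_stage=16):
    """Number of A/B tile buffers that fit in shared memory after reserving a carveout."""
    bytes_per_stage = (tile_m * tile_k * bytes_per_a
                       + tile_n * tile_k * bytes_per_b
                       + barrier_bytes_per_stage)
    return (smem_capacity_bytes - carveout_bytes) // bytes_per_stage


if __name__ == "__main__":
    assumed_smem = 101376  # assumed SM120 shared-memory capacity, for illustration only
    print(compute_stage_count(assumed_smem, 0, 128, 128, 128, 1, 1))
```

Consider FP8 operands (1 byte each) and a 128x128x128 tile. Under these assumptions, each stage needs 32,784 bytes, so three stages fit. Reserving an 8 KiB carveout (for example, for the epilogue) drops that to two, and halving tile K to 64 allows six.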
One bug I hit during development: in the builder, I wrote:
```cpp
using StrideA = TagToStrideA_t<GmemLayoutATag>*;
```
For ptr-array kernels, GmemLayoutATag is already RowMajor* — a pointer type. TagToStrideA_t<RowMajor*> already returns tuple<...>*. Adding * made it tuple<...>**. The compiler caught this with 95 errors about incompatible pointer types. The fix was removing the extra *. The comment I left explains the invariant:
```cpp
// GmemLayoutATag is already a pointer type (e.g. RowMajor*) for ptr-array kernels;
// TagToStrideA_t propagates the pointer, so no extra * needed here.
using StrideA = TagToStrideA_t<GmemLayoutATag>;
```
Three other files needed single-line #include additions to wire the new files in: collective_mma.hpp, collective_builder.hpp, and sm120_mma_builder.inl (which also needed its enable_if condition updated to stop claiming ownership of the ptr-array schedules, which now route to the new builder).
Total: 5 files changed, ~940 lines added.
Building and testing
The development cycle took four Docker builds, and each one hit a CMake configure error, a compile error, or the wheel-size limit. The errors, in order:
- CMake tools directory missing: CUTLASS's CMakeLists.txt unconditionally calls add_subdirectory(tools) regardless of CUTLASS_ENABLE_HEADERS_ONLY. Fixed by including tools/, python/, and examples/ in the rsync to the vLLM build context.
- Machete generation failure: vLLM's CMake adds ${CUTLASS_DIR}/python/ to PYTHONPATH for a generation script, and python/ wasn't in the initial rsync.
- 95 compile errors from the double-pointer stride: the StrideA* bug described above.
- Wheel size check: vLLM's check-wheel-size.py enforces a 500 MB limit, and the full wheel is ~678 MB. Added --build-arg RUN_WHEEL_CHECK=false to the build command.
The environment variable VLLM_CUTLASS_SRC_DIR in vLLM's CMakeLists.txt overrides the CUTLASS FetchContent download — this is the clean injection mechanism. I injected a full rsync of CUTLASS 4.5.1 plus the 5 new/modified files as cutlass-patched/ in the build context, with a two-line Dockerfile addition:
```dockerfile
COPY cutlass-patched /workspace/cutlass-patched
ENV VLLM_CUTLASS_SRC_DIR=/workspace/cutlass-patched
```
The resulting image vllm-gemma4-sm121:cutlass-fix is 26.2 GB. Gemma 4 loads, SM120 CUTLASS grouped GEMM activates, inference runs.
Benchmark results
I ran both images concurrently — baseline on port 11437, cutlass-fix on 11440 — with a bench script that sends short/medium/long prompts (128, 512, 1024 token outputs), 3 iterations each.
| Case | Baseline (Triton) | SM120 CUTLASS | Delta |
|---|---|---|---|
| Short (128 tok output) | 76.3 tok/s | 81.9 tok/s | +7.3% |
| Medium (512 tok output) | 89.8 tok/s | 89.8 tok/s | ≈0% |
| Long (1024 tok output) | 87.6 tok/s | 85.5 tok/s | -2.4% |
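The Delta column is the relative change (CUTLASS - baseline) / baseline. A few lines of Python reproduce it from the table's own numbers, which is a handy check when re-running the benchmark with more iterations.

```python
RESULTS = {  # tok/s from the table: (baseline Triton, SM120 CUTLASS)
    "short (128 tok)": (76.3, 81.9),
    "medium (512 tok)": (89.8, 89.8),
    "long (1024 tok)": (87.6, 85.5),
}


def pct_delta(baseline, candidate):
    """Relative throughput change of candidate vs. baseline, in percent."""
    return (candidate - baseline) / baseline * 100.0


for name, (baseline, cutlass) in RESULTS.items():
    print(f"{name:18} {pct_delta(baseline, cutlass):+.1f}%")
```

This prints +7.3%, +0.0%, and -2.4%, matching the table.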
The short-output result is the clearest signal. Decode-dominated workloads, where the MoE expert dispatch runs on every forward pass without prefill diluting the measurement, show a real improvement. The medium and long numbers are within the noise floor. The benchmark runs with speculative decoding and only 3 iterations, and the variance in draft acceptance rate is larger than any real delta.
The honest read: this is a correctness fix that also happens to improve the short-output case. Short requests (the 128-token-output case) get ~7% more throughput, and longer outputs are neutral. The SM120 CUTLASS kernel produces correct results; the Triton fallback was just leaving performance on the table.
Upstream path
The CUTLASS changes went to NVIDIA/cutlass as PR #3280, which closes issue #3263 (the feature request I filed the day before, after confirming the collective was genuinely missing). Once that merges into CUTLASS main and vLLM pins a revision containing it, the vLLM PR (#43814) will compile cleanly from the public repo.
The vLLM PR contains the two bug fixes plus the new kernel file. It references the CUTLASS dependency explicitly.
Why not HuggingFace Kernels
HuggingFace Kernels is for standalone Python-wrapped GPU kernels — the kind you can pip install. Think FlashAttention, xFormers, individual fused ops. What we built is a C++ template specialization inside CUTLASS's header-only library. It instantiates at vLLM build time via nvcc; there's no standalone .so to publish. The distribution channel is CUTLASS upstream, which benefits every downstream consumer — vLLM, Triton, custom frameworks — automatically once merged.
What it means for SM12x owners
Every GB10, RTX 5090, RTX 5080, and RTX 5070 running FP8 MoE models in vLLM has been silently using the Triton path. The CUTLASS path now exists, and in my benchmark it was faster on short outputs and neutral on longer ones. The vLLM PR is pending, and the CUTLASS PR is in review with NVIDIA. Once they merge and vLLM pins a compatible CUTLASS revision, the improvement will be automatic: rebuild the image and you get the upgrade.
If you're building vLLM from source on SM12x hardware today, you can apply the patches manually and use VLLM_CUTLASS_SRC_DIR to inject them. The CUTLASS side of the patch is in the NVIDIA/cutlass PR (#3280). The gate fix and the new kernel file are in the vLLM PR (#43814).