Two days, six hours of Triton tuning, one GB10, and a whole lot of nothing
Or how I learned that a pure-kernel benchmark doesn’t equal serving gains
This is not going to be an “I achieved an x-fold speed-up” article. It’s the opposite: a two-day tuning marathon whose final production config ended up 5-7% worse than the version I’d been running for the past two weeks — and I’m going to tell you about it, because I think this is more useful than yet another triumphant blog post.
If you just want the punchline: on DGX Spark, the Triton MoE tuning via benchmark_moe.py --tune against Qwen3.5-35B-A3B-FP8 runs for 6 hours and regresses prefill TTFT by ~10%. The remaining 9,000 words are optional.
The setup
I have a project called DocAI — a document processing SaaS for Hungarian SMBs, accounting firms, and law offices. The hardware is an NVIDIA DGX Spark with a GB10 chip (Blackwell, SM 12.1, 128 GB unified LPDDR5x), and the model is Qwen/Qwen3.5-35B-A3B-FP8 served via vLLM.
My config had been running fine for a few weeks — --max-num-seqs 32, --gpu-memory-utilization 0.45, --max-model-len 131072, plus the usual FP8 KV cache. The vLLM image was cu130-nightly, pulled about 3 weeks ago. The workload is mixed: chat agent (long system prompt + tool schemas), KIE batch (OCR → JSON document processing), plus a Marker-based PDF pipeline with a section header processor.
Then I reread my own notes on an incident from earlier this month where a one-page PDF got stuck for 15 minutes around an LLM call due to queue contention. That prompted me to collect a list of things that could be sped up. At the top of the list:
Triton fused_moe tuning for GB10 (medium effort, estimated 1.5-2x speed-up)
Because the vLLM log was whispering this at every startup:
Using default MoE config. Performance might be sub-optimal!
Config file not found: E=256,N=512,device_name=NVIDIA_GB10,dtype=fp8_w8a8,block_shape=[128,128].json
Alright, let’s do this, I thought. Two days and 6 hours of compute later, here we are.
Baseline measurement, because I didn’t want an anecdotal article
First I needed a proper benchmark. Using vllm bench serve, I measured three scenarios so I wouldn’t be waving a single number around:
- A: single decode — 512 input, 512 output, batch=1. The classic “one user generating” measurement
- B: prefill-bound — 8192 input, 256 output, 4 concurrent. Simulating the KIE workload
- C: concurrent contention — 2048 input, 512 output, 16 concurrent. The chat agent + KIE race
I turned on --ignore-eos everywhere (otherwise output-len is only an upper bound — the model might stop early and skew throughput), ran --num-warmups 3 before each test, and fixed seed=42. If we're benchmarking, let's benchmark properly.
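For the record, a single invocation looked roughly like this (scenario B shown; a sketch from memory — vllm bench serve flag names drift between nightlies, so check --help on your build):

# scenario B: prefill-bound — 8192 in, 256 out, 4 concurrent
vllm bench serve \
  --model Qwen/Qwen3.5-35B-A3B-FP8 \
  --base-url http://localhost:8000 \
  --dataset-name random \
  --random-input-len 8192 \
  --random-output-len 256 \
  --max-concurrency 4 \
  --num-prompts 32 \
  --num-warmups 3 \
  --ignore-eos \
  --seed 42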
Baseline results (3-week-old nightly image)
| Test | Output tok/s | Mean TTFT | Mean TPOT |
|---|---|---|---|
| A: single decode | 48.92 | 190 ms | 20.11 ms |
| B: prefill (8k) | 75.88 | 4855 ms | 33.87 ms |
| C: concurrent (16) | 208.04 | 4828 ms | 67.61 ms |
This is the reference point. 48.92 tok/s single decode, 4855 ms Mean TTFT on the KIE workload, 208 tok/s total throughput with 16 concurrent. Remember these numbers — we’ll be seeing them a lot.
First surprise: the nightly update made things worse
Before starting the tune, I figured let’s update the image. Maybe something new matured in those 3 weeks — FlashInfer MoE backend support, perhaps the CUTLASS improvements. This will be another data point for the article: “how much does a pure nightly pull buy you?”
docker pull vllm/vllm-openai:cu130-nightly
New hash, new version: 0.19.1rc1.dev231+g9dd5ee011. Boot up, check the log, and four new lines jump out that weren't there before:
Selected CutlassFp8BlockScaledMMKernel for Fp8LinearMethod
Using Triton/FLA GDN prefill kernel
Asynchronous scheduling is enabled.
Mamba cache mode is set to 'align' for Qwen3_5MoeForConditionalGeneration
by default when prefix caching is enabled
Nice — new kernels for the dense FP8 linear layers, a new prefill kernel for the linear-attention layers (the Qwen3.5-A3B hybrid arch has 30 linear-attention layers out of 40), and async scheduling as a new default. Plus a mamba cache mode set to 'align', labelled experimental with a "Please report any issues you may observe." Noted.
Running the benchmark:
| Test | Baseline | Phase 2 (new image) | Δ |
|---|---|---|---|
| A: tok/s | 48.92 | 49.76 | +1.7% |
| B: Mean TTFT | 4855 | 5941 | +22% ⚠️ |
| B: tok/s | 75.88 | 70.06 | −7.7% |
| C: Mean TTFT | 4828 | 6893 | +43% ⚠️ |
| C: tok/s | 208.04 | 197.67 | −5.0% |
Wait. +22% TTFT? +43%? The new nightly got slower. That’s not what I expected.
Investigation 1: async scheduling
Digging through the vLLM source, I found two new flags:
--async-scheduling, --no-async-scheduling
--mamba-cache-mode {align,all,none}
I tried turning async off first. It should have been a one-line sed, but of course I couldn't get the command right — it took two tries, because the first didn't match the intended line (line breaks inside a command: > folded scalar behave strangely). In the end I edited the compose file by hand to add --no-async-scheduling.
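For anyone doing the same: in compose the whole command is a single folded scalar, so the flag is just one more line inside it (fragment, trimmed):

# the flag goes on its own line inside the folded scalar
command: >
  Qwen/Qwen3.5-35B-A3B-FP8
  --enable-chunked-prefill
  --enable-prefix-caching
  --no-async-scheduling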
| Test | Phase 2 (async ON) | Phase 2.5a (async OFF) | Δ |
|---|---|---|---|
| B: Mean TTFT | 5941 | 5249 | −12% ✅ |
| C: Mean TTFT | 6893 | 5153 | −25% ✅ |
OK, that helped. Async scheduling is a new default in v0.19, and it causes a regression on prefill-heavy concurrent workloads. I don’t know exactly why — in theory async improves GPU utilisation, but if the workload doesn’t exploit it, all that’s left is synchronisation overhead.
Throughput didn't fully recover (71.73 vs the baseline 75.88 tok/s). A ~5% deficit remained — I did track down its cause, but, as it turned out, I couldn't do anything about it.
Investigation 2: the mamba_cache_mode=align mystery
Qwen3.5-A3B is a hybrid attention arch: 10 full-attention layers + 30 linear attention layers. The linear attention layers carry recurrent state (mamba-like), and vLLM caches that separately alongside prefix caching.
New in v0.19: there’s an automatic override to mamba_cache_mode="align" for Qwen3.5-MoE when prefix caching is on. I tried disabling it with --mamba-cache-mode none, but it didn’t take — the model config hook (in vLLM at /usr/local/lib/.../models/config.py:310-348) overrides whatever the CLI flag provides. The boot log kept showing:
WARNING Mamba cache mode is set to 'align' for Qwen3_5MoeForConditionalGeneration
by default when prefix caching is enabled
So the remaining ~5% regression is probably from this, and it’s not CLI-overridable. Fully turning off prefix caching would eliminate it (no prefix cache means no mamba align), but our chat agent needs the prefix cache. So we lose 5% here with nowhere to move. Noted, but staying put.
The main event: Triton MoE tuning
OK, now for the adventure. benchmark_moe.py --tune tunes the Triton fused_moe kernel parameters (block_m, block_n, block_k, num_stages, num_warps) for the Qwen3.5-A3B architecture (E=256 experts, moe_intermediate_size=512, topk=8). Expected result: a 1.5-2× kernel-level speed-up.
First attempt:
ModuleNotFoundError: No module named 'ray'
Easy fix: pip install --quiet ray && python3 benchmark_moe.py --tune. Second attempt. My first estimate for tuning time: 25-40 minutes, based on the default sweep running for batch sizes 1-4096 (I went with a medium sweep: 1, 4, 8, 16, 32, 64, 128, since with max-num-seqs 32 the realistic serving batch size doesn’t go above 128).
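For reference, the full invocation was roughly this (a sketch — benchmark_moe.py's argument names shift between vLLM versions, and whether --batch-size accepts a list is version-dependent, so verify with --help):

# inside the vLLM container; one tuning sweep per batch size in the medium sweep
python3 benchmark_moe.py \
  --model Qwen/Qwen3.5-35B-A3B-FP8 \
  --tp-size 1 \
  --dtype fp8_w8a8 \
  --batch-size 1 4 8 16 32 64 128 \
  --seed 42 \
  --trust-remote-code \
  --tune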
Start tuning over 640 configurations...
That's 640 kernel-config permutations per batch size. The first batch size took 10 minutes. The second, 20. At minute nine I looked up and asked Claude whether I had to wait it out or could go to sleep.
Claude said: don’t interrupt it, restart as docker run -d detached with --name and a persistent log file. I killed the 45-minute run (brutal), restarted in detached mode, accepted my fate, and went to bed.
In the morning: the tuning ran for 5 hours 48 minutes. Nowhere near the original 25-40 minute estimate. But the JSON output looks nice:
Writing best config to /save/E=256,N=512,device_name=NVIDIA_GB10,
dtype=fp8_w8a8,block_shape=[128,128].json...
Tuning took 20905.83 seconds
Tuned configs for all 7 batch sizes. The JSON contains BLOCK_SIZE_M=16/32, BLOCK_SIZE_N=128/256, num_warps=8, and num_stages=2-4, with a workload-dependent progression across batch sizes. A properly tuned config.
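To give a feel for the format (an illustrative excerpt with plausible values in the ranges above, not my exact file — the real one has one entry per tuned batch size, seven in total):

{
  "1":   {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1,  "num_warps": 8, "num_stages": 4},
  "128": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, "num_warps": 8, "num_stages": 2}
}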
Next, bind-mount the file into the vLLM image's fused_moe/configs/ directory — just the single file, so nothing else in that directory gets masked. In my compose file (host path hypothetical, container path as in the boot log):
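volumes:
  # mount only the single tuned file, read-only, so the rest of configs/ stays intact
  - ./tuned/E=256,N=512,device_name=NVIDIA_GB10,dtype=fp8_w8a8,block_shape=[128,128].json:/usr/local/lib/python3.12/dist-packages/vllm/model_executor/layers/fused_moe/configs/E=256,N=512,device_name=NVIDIA_GB10,dtype=fp8_w8a8,block_shape=[128,128].json:ro

Restart, and the boot log says: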
Using configuration from /usr/local/lib/python3.12/dist-packages/vllm/
model_executor/layers/fused_moe/configs/E=256,N=512,device_name=NVIDIA_GB10,
dtype=fp8_w8a8,block_shape=[128,128].json for MoE layer.
No more “Config file not found” warning. The tuned config is live. Time to measure.
Phase 3 results — after 6 hours of tuning
| Test | Phase 2.5a (untuned) | Phase 3 (tuned) | Δ |
|---|---|---|---|
| A: tok/s | 49.01 | 48.94 | ≈0% |
| B: Mean TTFT | 5249 | 5857 | +11.6% 😬 |
| B: tok/s | 71.73 | 69.32 | −3.4% |
| C: Mean TTFT | 5153 | 5302 | +2.9% |
| C: tok/s | 204.39 | 204.46 | ≈0% |
B TTFT got WORSE. A and C barely moved.
Suspecting variance, I ran two extra replications (seed=123, seed=456):
| Seed | Phase 3 B Mean TTFT |
|---|---|
| 42 | 5857 ms |
| 123 | 5678 ms |
| 456 | 5793 ms |
| Average | 5776 ms (±90 ms) |
Nope, not variance. Standard deviation ±90 ms, the mean is 527 ms higher than phase 2.5a. Statistically reliably worse.
Fun fact: throughput is also consistently worse, 69.96 vs 71.73 tok/s (−2.5%).
6 hours of tuning, and it hurts a little.
Hey — why does it hurt at all?
This is the moment I sat down and honestly asked myself the question. benchmark_moe.py is a pure-kernel benchmark — it measures only the MoE forward pass, in isolation. But in the vLLM serving context this kernel doesn’t run in isolation:
- Chunked prefill slices the prefill into varying batch sizes, of which the tuned JSON covers only a few discrete points. As far as I can tell, vLLM picks the tuned entry whose batch size is nearest to the actual one, so between two tuned sizes (say 8 and 16) the default heuristic may well beat a config that is optimal only at those fixed points.
- CUDA graph capture captures several batch sizes, and kernel choice “baked into” the graph may behave differently than a freshly launched kernel.
- And with mamba cache align in the mix, the MoE forward pass never runs on its own — a complex scheduler decides which request lands in which step and which batch.
So benchmark_moe.py is measuring something that isn't what serving actually does. On GB10, the default Triton heuristic makes a better choice over the real workload distribution than a workload-agnostic tune does at its fixed points.
This is the core lesson of the article, which I’ll try to phrase as sharply as possible: pure-kernel benchmark does not equal serving benchmark. benchmark_moe.py --tune is a micro-benchmark that optimises kernel execution time — but in a real serving workload the kernel latency is only one piece of the total latency, and the 5% kernel speed-up from tuning can evaporate in scheduler interactions.
A few more experiments, because I didn’t give up
gpu-memory-utilization 0.55
“Maybe the KV pool wasn’t enough.” Raised it 0.45 → 0.55. The KV pool grew 174k → 467k tokens (2.7×). Let’s see:
| Test | Phase 3 (0.45) | Phase 4 (0.55) | Δ |
|---|---|---|---|
| A: tok/s | 48.94 | 48.67 | ≈0% |
| B: Mean TTFT | 5776 avg | 5744 avg | ≈0% |
| C: Mean TTFT | 5302 | 5571 | +5% 🤔 |
| C: tok/s | 204.46 | 204.01 | ≈0% |
Nothing. The 174k-token KV pool was never the bottleneck (16 concurrent × 2.5k = 40k active max). 0.55 just locks memory away from the rest of the doc pipeline. If I ship this, the Marker/Surya models could get squeezed.
enforce-eager (CUDA graph OFF)
“Maybe the graph compile is doing something stupid.” I tried --enforce-eager.
A: tok/s: 30.02 (vs 48.94)
B: Mean TTFT: 7460 (vs 5776)
−39% decode throughput. CUDA graphs matter materially. Reverted immediately. (Bonus: boot is ~70 seconds faster in eager mode because there's no torch.compile, but that's nowhere near worth a ~39% perf hit.)
VLLM_MEMORY_PROFILER_ESTIMATE_CUDAGRAPHS=1
The log itself recommended it. This gives more accurate CUDA graph memory estimation. Expected effect: slightly more KV pool. A free win, I thought.
Reality: not a win, but an accounting fix. The log is explicit:
The current --gpu-memory-utilization=0.4500 is equivalent to
--gpu-memory-utilization=0.4434 without CUDA graph memory profiling.
To maintain the same effective KV cache size as before, increase
--gpu-memory-utilization to 0.4566.
The env var is more accurate, yes. But at the same 0.45 you get a smaller KV pool, because CUDA graph memory is accounted for more precisely. If you want the same as before, raise it to 0.4566.
There is no free lunch. The log was honest; I was the one hoping for something else.
The final production config
Two days, a lot of coffee, 6 hours of tuning compute, nine measurement phases. In the end:
services:
qwen35_vllm:
image: vllm/vllm-openai:cu130-nightly
environment:
HF_TOKEN: ${HF_TOKEN}
TORCHINDUCTOR_COMPILE_THREADS: "1"
VLLM_MEMORY_PROFILER_ESTIMATE_CUDAGRAPHS: "1" # more accurate accounting
volumes:
- ~/.cache/huggingface:/root/.cache/huggingface
- triton-cache:/root/.cache/torch/triton
# NO MoE tuned config bind-mount — it didn't pay off
command: >
Qwen/Qwen3.5-35B-A3B-FP8
--max-model-len 131072
--max-num-batched-tokens 65536
--gpu-memory-utilization 0.45
--max-num-seqs 32
--kv-cache-dtype fp8_e4m3
--enable-chunked-prefill
--enable-prefix-caching
--no-async-scheduling # the only real win
--reasoning-parser qwen3
--enable-auto-tool-choice
--tool-call-parser qwen3_coder
--limit-mm-per-prompt '{"image":1}'
--trust-remote-code
Final numbers (phase 6) vs. baseline
| Test | Baseline (3 weeks old) | Phase 6 (production v1) | Δ |
|---|---|---|---|
| A: tok/s | 48.92 | 49.31 | +0.8% |
| B: Mean TTFT | 4855 | 5170 | +6.5% |
| B: tok/s | 75.88 | 73.33 | −3.4% |
| C: Mean TTFT | 4828 | 4943 | +2.4% |
| C: tok/s | 208.04 | 193.24 | −7.1% |
Yes, phase 6 is slightly worse than the 3-week-old image was. This was the first “end” of the two days of tuning. But it didn’t end up being the final production config — see the phase 7 section below.
So what did I actually learn?
Quite a lot, actually. Here are six lessons, in order of importance:
1. Pure-kernel benchmark ≠ serving benchmark
benchmark_moe.py --tune measures how fast an isolated MoE kernel runs at a given batch size. In a real serving workload that kernel doesn’t run in isolation: the scheduler, chunked prefill, CUDA graph capture, prefix caching, mamba cache align — each of them affects at what batch size the kernel is called, and when. The pure-kernel optimum and the serving optimum are not the same point.
Concrete takeaway: if you run benchmark_moe.py, take its output with a grain of salt. Measure on the actual serving workload before and after. The pure-kernel win might be 1.5×; the serving win might be negative.
2. vLLM nightly’s new defaults are sneaky
v0.19 has two new defaults, both of which hurt certain workloads:
- async_scheduling=True causes a ~12% TTFT regression on prefill-heavy concurrent workloads. Override: --no-async-scheduling.
- mamba_cache_mode="align" (Qwen3.5-MoE + prefix caching) adds another ~5-10% regression. Not overridable on the CLI, only by disabling prefix caching.
When you upgrade, measure both directions, especially TTFT.
3. CUDA graphs matter materially
--enforce-eager chops decode throughput by nearly 40% (30.02 vs 48.94 tok/s). The ~70-second torch.compile overhead at startup is more than worth it. If you want a smaller graph pool, narrow the cudagraph_capture_sizes list; don't turn eager mode on.
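If graph memory is the concern, a narrower capture list is the lever — something like this (a sketch I didn't ship; the JSON schema accepted by --compilation-config depends on your vLLM version):

# keep CUDA graphs, but capture fewer batch sizes instead of disabling them
--compilation-config '{"cudagraph_capture_sizes": [1, 2, 4, 8, 16, 32]}'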
4. More accurate accounting ≠ more resources
VLLM_MEMORY_PROFILER_ESTIMATE_CUDAGRAPHS=1 improves accounting but does not grow the KV pool. Want the same as before? Raise gpu-memory-utilization. The log doesn’t lie.
5. gpu-memory-utilization only matters if the KV pool is tight
I couldn’t saturate the 174k-token KV pool with 16 concurrent requests (max 40-80k active). The 467k-token pool (0.55 mem util) behaved the same — it just walled memory off from the rest of the pipeline. If num_preemptions_total is 0 and kv_cache_usage_perc is low, don’t raise it.
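Checking is cheap (metric names as my build exports them — they've been renamed across vLLM versions, so grep loosely):

# Prometheus metrics from the running server
curl -s http://localhost:8000/metrics | grep -E 'num_preemptions_total|kv_cache_usage_perc'
# preemptions pinned at 0 + low kv_cache usage => the pool is not the bottleneck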
6. A 6-hour negative result is still a result
Nobody else wants to spend 6 hours of tuning compute to find out it doesn’t pay off. I’ve just saved other people those 6 hours. That’s the value of the article.
Post-hoc update: phase 7 — the max-num-batched-tokens experiment
After sitting down to write this article, a thought nagged at me. In the phase 6 config, --max-num-batched-tokens 65536 stayed at its original value — this parameter controls the number of tokens chunked prefill can process in a single forward pass. Simple hypothesis:
Smaller chunks → finer-grained interleaving with decode. Large 65k chunks suppress prefill/decode interleaving, which can cause TTFT tail variance; smaller chunks yield finer scheduling, shorter preemption windows, and better tail latency. Concretely: sixteen 2048-token prompts arriving together fit into a single 32k-token prefill step under a 65,536-token budget, while a 16,384-token budget splits them across at least two steps, with decode tokens scheduled in between.
Single change from phase 6: 65536 → 16384. Everything else stayed.
Phase 7 results — a surprising amount of improvement
| Test | Phase 6 | Phase 7 (16k chunks) | Δ |
|---|---|---|---|
| A: Output tok/s | 49.31 | 49.15 | noise |
| A: P99 TTFT | 668 ms | 176 ms | −74% 🎯 |
| B: Mean TTFT avg | 5170 | 4935 | −4.5% |
| B: P99 TPOT avg | ~39 ms | ~50 ms | +28% ⚠️ |
| C: Output tok/s | 193.24 | 207.92 | +7.6% ✅ |
| C: Mean TTFT avg (2 seed) | 4943 | 4146 | −16% ✅ |
| C: P99 TPOT | 97.07 ms | 77.34 ms | −20% ✅ |
Test C at seed=42 threw an outlier P99 TTFT (8269 ms) — on its own that would have cast doubt on the result. But at seed=123 the P99 TTFT is 5283 ms, better than phase 6’s 5917 ms. So the seed=42 spike was noise, not a refutation.
Clean evidence for the chat workload — test D
Test C is too harsh a proxy for a realistic chat workload (16 concurrent is an extreme case). So I added a test D — 2048 input, 256 output, 2 concurrent — which simulates the actual DocAI chat agent profile:
| Metric | Phase 7 D avg (2 seeds) |
|---|---|
| Mean TTFT | 719 ms |
| P99 TTFT | 902 ms |
| Mean TPOT | 24.97 ms |
| P99 TPOT | 25.67 ms (variance ±1 ms!) |
| Output tok/s | 72.26 |
These are production-quality numbers: even in the worst case the chat user gets their first token in 1 second, and the streaming cadence is a linear ~40 tok/s — TPOT variance is negligible.
Takeaway — the “mundane” parameter
max-num-batched-tokens is a basic vLLM scheduler flag, no Triton kernel tuning, no 6-hour compute. Just a single value change. And it delivered what the entire MoE tuning marathon could not:
- Single-user P99 TTFT −74%
- Concurrent stress throughput +7.6%
- Tail TPOT variance −20%
The lesson: the biggest wins often don’t come from exotic kernel tuning but from a well-chosen scheduler parameter. Chunked prefill granularity, batch size, KV pool sizing — these are all “mundane” config values that the vLLM docs don’t advertise flashily, but their impact on actual serving performance is bigger than kernel optimisations.
Phase 7 is the real production. The phase 6 config file is archived for reference.
New final config (phase 7)
--max-model-len 131072
--max-num-batched-tokens 16384 # <-- single change from phase 6
--gpu-memory-utilization 0.45
--max-num-seqs 32
--kv-cache-dtype fp8_e4m3
--enable-chunked-prefill
--enable-prefix-caching
--no-async-scheduling
The full results matrix, extended
| Phase | Config | A tok/s | A P99 TTFT | B TTFT avg | C TTFT | C tok/s | C P99 TPOT |
|---|---|---|---|---|---|---|---|
| Baseline | old image | 48.92 | — | 4855 | 4828 | 208.04 | — |
| Phase 2 | new image | 49.76 | — | 5941 | 6893 | 197.67 | — |
| Phase 2.5a | async OFF | 49.01 | — | 5249 | 5153 | 204.39 | 68.34 |
| Phase 3 | + MoE tuned | 48.94 | — | 5776 avg | 5302 | 204.46 | 68.02 |
| Phase 4 | + 0.55 util | 48.67 | — | 5744 avg | 5571 | 204.01 | 67.67 |
| Phase 5a | + eager | 30.02 | — | 7460 | — | — | — |
| Phase 6 | v1 production | 49.31 | 668 | 5170 | 4943 | 193.24 | 97.07 |
| Phase 7 | v2 production | 49.15 | 176 | 4935 | 4146 avg | 207.92 avg | 77.34 |
What else could be done?
A few ideas that came up but I didn’t chase, because at some point the article has to close:
- Prefix caching OFF experiment: if I split the chat agent and KIE workload across two vLLM instances, I could disable prefix caching on the KIE side (there’s no hit anyway) → the mamba align overhead disappears. Probably recovers the 5-10% regression.
- Speculative decoding (EAGLE / Qwen MTP): potential 2-3× decode throughput on the single-request path. There's an MTP drafter for Qwen3.5-MoE (qwen3_5_mtp.py). A separate research project.
- Waiting for FlashInfer MoE SM 12.1 support: --use-flashinfer-moe-fp8 still errors out on A3B block-wise FP8, but if it lands in the coming months, that will be the real jump — into the 48 → 100-150 tok/s range.
Those are topics for another article. I’m closing this one.
Acknowledgements
Throughout the investigation, Claude was my assistant in the terminal, in a log-parser and experiment-proposer role — it suggested or validated the benchmark script parameterisation, the sed tricks (which sometimes didn't run), and the statistical interpretations. If you're facing a two-day tuning marathon and don't want to stare at logs alone, it's a good companion.
System: NVIDIA DGX Spark, GB10 (SM 12.1), 128 GB LPDDR5x unified memory
Model: Qwen/Qwen3.5-35B-A3B-FP8 (hybrid attention, 40 layers, 10 full-attn + 30 linear-attn, E=256 MoE experts)
vLLM: 0.19.1rc1.dev231+g9dd5ee011 (cu130-nightly image, 2026-04-13)
Benchmark tool: vllm bench serve (built-in)
All JSON results, docker-compose files, and the tuned MoE config are available — if you’re interested in reproduction or the detailed percentile distributions, get in touch.