Optimizing Inference for Router Looped Transformers

Updated: · 17 min read

Looped transformers change the inference problem in a subtle way.

In a normal decoder-only transformer, layer 17 is always layer 17. During generation, its KV cache can be indexed by request, layer, and token position. A serving system can safely say:

cache[request_id][layer_id][token_position]

In a looped transformer, the same physical block can be reused many times. In a router looped transformer, the route can also depend on the input or context. The physical block id is no longer enough.

For example:

route step 0 -> physical block 0
route step 1 -> physical block 1
route step 2 -> physical block 0
route step 3 -> physical block 1

The two visits to physical block 0 share weights, but they do not share hidden states. If they share the same KV slot, the cache is wrong.

The correct serving abstraction is closer to:

cache[request_id][route_step][physical_block_id][token_position]

or, in serving-engine terms:

virtual_layer_id = route_step * num_physical_blocks + physical_block_id

That is the core idea of this note: a looped transformer has fewer unique parameters, but the serving system still needs a virtual layer identity for cache correctness.
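To make the collision concrete, here is a minimal Python sketch. The dictionary caches and the helper name `virtual_layer_id` are illustrative assumptions, not code from the research repo; the strings stand in for real K/V tensors.

```python
NUM_PHYSICAL_BLOCKS = 2
ROUTE = [0, 1, 0, 1]  # physical block chosen at each route step


def virtual_layer_id(route_step: int, physical_block_id: int,
                     num_physical_blocks: int = NUM_PHYSICAL_BLOCKS) -> int:
    """Unique cache identity for one visit to a physical block."""
    return route_step * num_physical_blocks + physical_block_id


by_block = {}    # keyed by physical block only: later visits overwrite earlier ones
by_virtual = {}  # keyed by virtual layer: every visit keeps its own slot

for step, block in enumerate(ROUTE):
    kv = f"kv from route step {step}"  # stand-in for real K/V tensors
    by_block[block] = kv
    by_virtual[virtual_layer_id(step, block)] = kv

print(len(by_block))       # 2 slots for 4 visits, so two visits were overwritten
print(len(by_virtual))     # 4 slots, one per visit
print(sorted(by_virtual))  # [0, 3, 4, 7]
```

The ids are sparse because each route step reserves one id per physical block. That wastes a little id space but keeps the mapping unambiguous without knowing the route in advance.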

What We Are Testing

The current research model is intentionally small. It uses two reusable transformer blocks and runs them recurrently:

n_unique_layers = 2
n_loops = 4
effective fixed depth = 8 block calls
router_max_steps = 8
d_model = 50
n_heads = 2

The router observes a sequence summary:

mean hidden state + last-token hidden state + route-step signal

At each route step it chooses among:

physical block 0, physical block 1, exit

During training, the router is soft. It evaluates all candidate blocks and mixes outputs by probability. That is useful for gradients, but bad for deployment: soft routing does not actually skip the expensive block executions.
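The block-call difference can be seen in a toy sketch. Everything here is an assumption made for illustration: the scalar "hidden state", the fixed `toy_router` probabilities, treating exit as an identity contribution in the soft mix, and exiting when the exit probability reaches a threshold. The real router reads hidden states and may define the threshold differently.

```python
from typing import Callable, List, Tuple

Block = Callable[[float], float]
EXIT = 2       # router choices: block 0, block 1, exit
MAX_STEPS = 8  # matches router_max_steps above


def toy_router(step: int) -> List[float]:
    # Hypothetical fixed probabilities over (block 0, block 1, exit).
    return [0.6, 0.3, 0.1] if step % 2 == 0 else [0.3, 0.6, 0.1]


def run_soft(h: float, blocks: List[Block]) -> Tuple[float, int]:
    calls = 0
    for step in range(MAX_STEPS):
        probs = toy_router(step)
        outs = [block(h) for block in blocks]  # every block executes
        calls += len(blocks)
        # Assumption: "exit" contributes the unchanged hidden state.
        h = sum(p * o for p, o in zip(probs, outs)) + probs[EXIT] * h
    return h, calls


def run_hard(h: float, blocks: List[Block],
             exit_threshold: float = 0.5) -> Tuple[float, int]:
    calls = 0
    for step in range(MAX_STEPS):
        probs = toy_router(step)
        if probs[EXIT] >= exit_threshold:  # assumed exit rule
            break
        choice = max(range(len(blocks)), key=lambda i: probs[i])
        h = blocks[choice](h)              # only the chosen block executes
        calls += 1
    return h, calls


blocks = [lambda h: 0.5 * h + 1.0, lambda h: 0.9 * h]
_, soft_calls = run_soft(1.0, blocks)
_, hard_calls = run_hard(1.0, blocks)
print(soft_calls, hard_calls)  # 16 8
```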

So the inference hypothesis is:

train with soft routing
convert to a hard route at inference
batch requests by route template
store KV cache by virtual route step, not just physical block

The First Result: Soft Routing Is Not an Inference Optimization

On a 300-step single-seed Modal probe, the soft router was slightly more accurate than the fixed loop, but it was slower because it executed roughly twice as many core block calls.

| Model path | Eval accuracy | Latency per batch | Throughput | Core block calls |
|---|---|---|---|---|
| fixed 2x4 | 0.162109 | 27.568 ms | 2321.5 examples/s | 8.000 |
| soft router 2x4 | 0.166016 | 45.420 ms | 1409.1 examples/s | 16.000 |
| hard router, threshold 0.50 | 0.164062 | 34.589 ms | 1850.3 examples/s | 7.982 |
| hard router, threshold 0.30 | 0.134766 | 34.253 ms | 1868.4 examples/s | 7.646 |

Source artifact:

runs/modal-downloads/modal_inference_fixed_router_300s_seed0_20260629/inference_report.md

The important result is not that the router is already faster. It is not.

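The important result is that hard routing at threshold 0.50 kept nearly the same accuracy (0.164 vs 0.162 for the fixed loop) while reducing executed block calls from 16 back to about 8, the fixed-loop budget. Lowering the threshold to 0.30 cut block calls only slightly further and cost accuracy. This suggests the router should be treated as a training-time search policy and converted into a simpler serving-time plan.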

Multi-Seed Modal Gate: Hard Router Can Beat Fixed Accuracy

The next run used 600 steps and three seeds on the easier transfer setting:

num_nodes = 16
train_max_hops = 3
eval_max_hops = 4
seeds = 0, 1, 2

| Inference path | Mean eval acc | Acc std | Latency per batch | Throughput | Mean block calls |
|---|---|---|---|---|---|
| fixed 2x4 | 0.166667 | 0.015666 | 17.433 ms | 3673.7 examples/s | 8.000 |
| soft router | 0.147786 | 0.022637 | 28.148 ms | 2274.2 examples/s | 16.000 |
| hard router t=0.45 | 0.147786 | 0.007394 | 23.985 ms | 2674.6 examples/s | 7.930 |
| hard router t=0.50 | 0.172526 | 0.004066 | 23.276 ms | 2750.2 examples/s | 7.965 |
| hard router t=0.55 | 0.173828 | 0.011880 | 23.381 ms | 2737.3 examples/s | 7.973 |

Source artifact:

runs/modal-downloads/modal_inference_fixed_router_thresholds_600s_seeds012_20260629/inference_summary.json

This is the first useful candidate signal:

hard router t=0.55:
  accuracy beats fixed by about 4.3 percent relative
  block calls stay near 8
  wall-clock latency is still worse than fixed
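For reference, the two ratios behind that summary can be recomputed directly from the table. This is plain arithmetic on the reported means, nothing more:

```python
fixed_acc, hard_acc = 0.166667, 0.173828  # mean eval accuracy, fixed vs hard t=0.55
fixed_ms, hard_ms = 17.433, 23.381        # latency per batch in ms

relative_gain = (hard_acc - fixed_acc) / fixed_acc
latency_ratio = hard_ms / fixed_ms

print(f"relative accuracy gain: {relative_gain:.1%}")  # 4.3%
print(f"latency vs fixed: {latency_ratio:.2f}x")       # 1.34x
```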

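So the architecture signal and systems signal disagree. The router can find a slightly better path on average, although the 0.007 accuracy gap is smaller than the fixed loop's seed-to-seed std (0.016), and the current implementation has too much control-flow overhead.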

That is exactly where serving optimization matters.

Harder Transfer Gate

The harder transfer setting used more nodes and a longer evaluation horizon:

num_nodes = 24
train_max_hops = 4
eval_max_hops = 5
seeds = 0, 1, 2

| Inference path | Mean eval acc | Acc std | Latency per batch | Throughput | Mean block calls |
|---|---|---|---|---|---|
| fixed 2x4 | 0.104167 | 0.013005 | 17.966 ms | 3563.0 examples/s | 8.000 |
| soft router | 0.100911 | 0.022553 | 29.026 ms | 2206.4 examples/s | 16.000 |
| hard router t=0.50 | 0.091797 | 0.031974 | 24.325 ms | 2631.5 examples/s | 7.928 |
| hard router t=0.55 | 0.101562 | 0.001953 | 24.506 ms | 2613.1 examples/s | 7.955 |
| hard router t=0.60 | 0.108073 | 0.002983 | 23.699 ms | 2700.7 examples/s | 7.971 |

Source artifact:

runs/modal-downloads/modal_inference_harder_transfer_thresholds_600s_seeds012_20260629/inference_summary.json

The current harder-transfer candidate is t=0.60. It beats fixed accuracy slightly (0.108 vs 0.104) and keeps fixed-like block calls (7.971 vs 8). It is still slower in wall-clock time (23.7 ms vs 18.0 ms per batch) despite running no more blocks, which suggests the next bottleneck is not model compute. It is routing overhead, dynamic grouping, and cache layout.

Local Serving Benchmark: Route-Template Replay

The first local benchmark does not implement a full KV cache. It measures a simpler serving optimization: freeze the hard route plan and replay it as virtual recurrent layers.

Command:

uv run python scripts/benchmark_inference_cache.py \
  --repeats 30 \
  --warmup 8 \
  --batch-size 16 \
  --prompt-len 128 \
  --decode-tokens 32 \
  --d-model 64 \
  --n-heads 4 \
  --route-steps 8 \
  --output runs/local-cache-benchmarks/inference_cache_benchmark_template_20260629.json

Result:

| Path | Mean latency | Std | Speedup vs generic hard router | Max logit diff |
|---|---|---|---|---|
| generic hard router | 8.996 ms | 0.138 | 1.00x | n/a |
| route-plan replay | 7.527 ms | 0.114 | 1.20x | 0.0 |
| route-template replay | 6.379 ms | 0.089 | 1.41x | 0.0 |

Route template in this run:

0,1,0,1,0,1,0,1

Virtual cache slots:

(route_step=0, block_id=0)
(route_step=1, block_id=1)
(route_step=2, block_id=0)
(route_step=3, block_id=1)
(route_step=4, block_id=0)
(route_step=5, block_id=1)
(route_step=6, block_id=0)
(route_step=7, block_id=1)

The exact-logit match matters. It means route-template replay is not an approximation. It is the same hard-router computation with less runtime overhead.

This suggests a practical serving path:

1. run router once during prefill, or during an early planning stage
2. assign each request a route_template_id
3. group active requests by route_template_id
4. run the corresponding fixed virtual block sequence
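Steps 2 and 3 above can be sketched as a small registry that interns route templates and a grouping pass. The class and function names are hypothetical, and the second template in the example is invented to show more than one group (the benchmark run above only produced one template).

```python
from collections import defaultdict
from typing import Dict, List, Tuple

Template = Tuple[int, ...]


class TemplateRegistry:
    """Interns route templates so equal routes share one small integer id."""

    def __init__(self) -> None:
        self._ids: Dict[Template, int] = {}

    def template_id(self, route: List[int]) -> int:
        # A new template gets the next unused id; a known one keeps its id.
        return self._ids.setdefault(tuple(route), len(self._ids))


def group_by_template(planned: Dict[str, List[int]],
                      registry: TemplateRegistry) -> Dict[int, List[str]]:
    groups: Dict[int, List[str]] = defaultdict(list)
    for request_id, route in planned.items():
        groups[registry.template_id(route)].append(request_id)
    return dict(groups)


planned = {
    "req-a": [0, 1, 0, 1, 0, 1, 0, 1],
    "req-b": [0, 1, 0, 1, 0, 1, 0, 1],
    "req-c": [0, 0, 1, 1, 0, 0, 1, 1],  # hypothetical second template
}
registry = TemplateRegistry()
groups = group_by_template(planned, registry)
print(groups)  # {0: ['req-a', 'req-b'], 1: ['req-c']}
```

Each group can then be run as one dense batch through its fixed virtual block sequence.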

Local KV Benchmark: Virtual-Step Cache

The second local benchmark is a synthetic causal decoder microbench. It compares:

no cache:
  for every generated token, rerun virtual blocks over the whole prefix

virtual-step KV cache:
  prefill creates KV for each (route_step, physical_block_id)
  decode computes only the new token and appends to each virtual-step slot
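The bookkeeping for the cached path looks roughly like the sketch below. It only tracks which entries exist per slot; the `(slot, position)` tuples stand in for real K/V tensors, and the class name is an assumption rather than the benchmark's actual code.

```python
from typing import Dict, List, Tuple

Slot = Tuple[int, int]  # (route_step, physical_block_id)


class VirtualStepKVCache:
    """Per-request KV store with one slot per virtual route step."""

    def __init__(self, route: List[int]) -> None:
        self.slots: List[Slot] = list(enumerate(route))
        self.kv: Dict[Slot, List[Tuple[Slot, int]]] = {s: [] for s in self.slots}

    def prefill(self, prompt_len: int) -> None:
        # One K/V entry per prompt position, for every virtual step.
        for slot in self.slots:
            self.kv[slot] = [(slot, pos) for pos in range(prompt_len)]

    def decode_step(self) -> int:
        # Only the new token is computed; it is appended to every slot.
        for slot in self.slots:
            self.kv[slot].append((slot, len(self.kv[slot])))
        return len(self.slots)  # new entries computed this step


cache = VirtualStepKVCache([0, 1, 0, 1, 0, 1, 0, 1])
cache.prefill(prompt_len=128)
new_entries = sum(cache.decode_step() for _ in range(32))

print(len(cache.kv))                        # 8
print({len(v) for v in cache.kv.values()})  # {160}
print(new_entries)                          # 256
```

With 8 route steps the cache holds 8 slots even though there are only 2 physical blocks, which is the memory cost of keeping each visit's state separate.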

Single setting:

| Batch | Prompt len | Decode tokens | Route steps | No cache | Virtual-step KV | Speedup |
|---|---|---|---|---|---|---|
| 16 | 128 | 32 | 8 | 195.321 ms | 116.726 ms | 1.67x |

Prompt-length matrix:

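| Prompt len | Route steps | No cache | Virtual-step KV | Speedup |
|---|---|---|---|---|
| 64 | 4 | 41.401 ms | 32.638 ms | 1.27x |
| 64 | 8 | 80.746 ms | 64.338 ms | 1.26x |
| 128 | 4 | 49.207 ms | 35.050 ms | 1.40x |
| 128 | 8 | 99.023 ms | 70.293 ms | 1.41x |
| 256 | 4 | 67.680 ms | 40.351 ms | 1.68x |
| 256 | 8 | 132.918 ms | 78.141 ms | 1.70x |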

Source artifacts:

runs/local-cache-benchmarks/inference_cache_benchmark_template_20260629.json
runs/local-cache-benchmarks/decoder_kv_matrix_20260629.json

This is CPU synthetic data, not deployment latency. But it supports the two design decisions above: replay a frozen route template, and key the KV cache by virtual route step.

After the first local benchmark, I ran a small A10G Modal grid to separate two claims:

1. does route-template replay help on GPU?
2. does this simple Python virtual-step KV microbench help on GPU?

Command shape:

batch sizes: 1, 8, 16
prompt lengths: 64, 128, 256
decode tokens: 16, 32
route steps: 8
repeats: 15
warmup: 5
GPU: A10G

Source artifact:

runs/modal-downloads/modal_cache_benchmark_grid_20260629/cache_benchmark_summary.json

The result is mixed, and useful.

Route-template replay is clearly useful on GPU:

| Batch size | Mean template speedup vs generic hard router | Min | Max |
|---|---|---|---|
| 1 | 3.02x | 2.92x | 3.07x |
| 8 | 2.20x | 2.12x | 2.30x |
| 16 | 1.80x | 1.61x | 2.04x |

The logits still matched exactly:

template_max_logit_diff = 0.0

So this is no longer just a CPU-local effect. The serving system should batch hard-router requests by route template.

The virtual-step KV result is less encouraging:

| Batch size | Mean speedup vs no-cache | Min | Max |
|---|---|---|---|
| 1 | 0.93x | 0.89x | 0.98x |
| 8 | 0.99x | 0.90x | 1.03x |
| 16 | 1.01x | 0.97x | 1.04x |

This does not disprove route-aware KV cache. It says the toy Python implementation is not a GPU serving implementation. The next KV experiment needs fused attention or paged KV blocks. In other words:

route-aware KV identity is required for correctness;
fused/paged implementation is required for speed.

Why vLLM Needs Route Identity

vLLM’s PagedAttention design stores KV cache in fixed-size blocks rather than one large contiguous tensor (vLLM PagedAttention). Its prefix caching design hashes full KV blocks using parent hash, block tokens, and extra hashes such as LoRA ids, multimodal hashes, or cache salts (vLLM prefix caching).

That “extra hashes” concept is exactly where route identity belongs.

For a looped transformer, the cache hash should include:

model_id
request prefix hash
route_template_id
virtual_layer_id
physical_block_id
route_step
token block ids
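A minimal sketch of such a hash, using only `hashlib`. This is not vLLM's API: `block_hash` and `route_extra_keys` are hypothetical helpers that mimic the "parent hash + block tokens + extra keys" shape described above, and the model id `"toy-2x4"` is made up.

```python
import hashlib
from typing import Optional, Sequence, Tuple


def route_extra_keys(model_id: str, route_template_id: int, route_step: int,
                     physical_block_id: int, num_physical_blocks: int = 2) -> Tuple:
    virtual_layer_id = route_step * num_physical_blocks + physical_block_id
    return (model_id, route_template_id, virtual_layer_id,
            physical_block_id, route_step)


def block_hash(parent_hash: Optional[str], token_ids: Sequence[int],
               extra_keys: Tuple) -> str:
    """Hash one token block, chained to its prefix through parent_hash."""
    payload = repr((parent_hash, tuple(token_ids), extra_keys)).encode("utf-8")
    return hashlib.sha256(payload).hexdigest()


tokens = [101, 102, 103, 104]
# Same tokens, same physical block 0, but visited at route steps 0 and 2.
visit_a = block_hash(None, tokens, route_extra_keys("toy-2x4", 0, 0, 0))
visit_b = block_hash(None, tokens, route_extra_keys("toy-2x4", 0, 2, 0))
repeat_a = block_hash(None, tokens, route_extra_keys("toy-2x4", 0, 0, 0))

print(visit_a == visit_b)   # False: different virtual steps never share a block
print(visit_a == repeat_a)  # True: same virtual step and tokens can be reused
```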

The minimal vLLM adaptation would be:

1. add looped-transformer metadata to model config
2. expose a route planner that returns route_template_id and virtual_layer_ids
3. include route_template_id / virtual_layer_id in prefix-cache extra hashes
4. map virtual_layer_id to the same physical weights during model execution
5. batch scheduler groups requests by next virtual_layer_id, not only by decode step

The likely code areas to study first are:

vllm/v1/core/kv_cache_manager.py
vllm/v1/core/single_type_kv_cache_manager.py
vllm/v1/core/kv_cache_utils.py

The key rule is simple:

same tokens + same physical block is not enough
same tokens + same virtual route step is required

Why SGLang Needs Route-Aware Radix Keys

SGLang’s RadixAttention keeps KV cache for prompts and generation results in a radix tree, so later requests can reuse matching prefixes. The public SGLang writeup describes this as a mapping from token sequences to KV tensors, with LRU eviction and cache-aware scheduling (SGLang RadixAttention).

For normal transformers, a token-prefix key is enough because layer order is fixed.

For router looped transformers, the radix key should become route-aware:

radix_key = [
  route_template_id,
  virtual_layer_id,
  token_0,
  token_1,
  ...
]

or, if token-level routing is enabled:

radix_key = [
  route_signature_for_token_span,
  virtual_layer_id,
  token_span
]
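The request-level key can be tried out with a toy prefix store. Nested dicts stand in for a radix tree here; `PrefixStore`, `radix_key`, and `reusable_tokens` are illustrative names, not SGLang code.

```python
from typing import Dict, Sequence, Tuple

ROUTE_FIELDS = 2  # route_template_id and virtual_layer_id lead every key


def radix_key(route_template_id: int, virtual_layer_id: int,
              tokens: Sequence[int]) -> Tuple[int, ...]:
    return (route_template_id, virtual_layer_id, *tokens)


class PrefixStore:
    """Toy stand-in for a radix tree: one dict level per key element."""

    def __init__(self) -> None:
        self.root: Dict = {}

    def insert(self, key: Sequence[int]) -> None:
        node = self.root
        for part in key:
            node = node.setdefault(part, {})

    def match_len(self, key: Sequence[int]) -> int:
        node, matched = self.root, 0
        for part in key:
            if part not in node:
                break
            node = node[part]
            matched += 1
        return matched


def reusable_tokens(store: PrefixStore, key: Sequence[int]) -> int:
    # Tokens are reusable only if both route fields matched first.
    return max(0, store.match_len(key) - ROUTE_FIELDS)


store = PrefixStore()
store.insert(radix_key(0, 4, [11, 12, 13]))

same_route = reusable_tokens(store, radix_key(0, 4, [11, 12, 99]))
other_layer = reusable_tokens(store, radix_key(0, 0, [11, 12, 13]))
print(same_route, other_layer)  # 2 0
```

Putting the route fields first means a mismatch in route identity stops the match before any token is compared, so identical token prefixes under different virtual layers never share KV.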

The likely SGLang code areas to study first are:

python/sglang/srt/mem_cache/radix_cache.py
python/sglang/srt/mem_cache/memory_pool.py
python/sglang/srt/managers/schedule_batch.py

The serving scheduler should prefer batches with the same next virtual route step. Otherwise, each request may need a different physical block at the same decode position, which destroys batching efficiency.

Step-by-Step Optimization Plan

The current data suggests this order:

Step 1: Do not serve the soft router

Soft routing is useful for training, but it doubles block calls in the current implementation:

fixed loop: 8 block calls
soft router: 16 block calls
hard router: about 8 block calls

The serving path should use hard routing, top-k routing, or a distilled route policy.

Step 2: Freeze a request-level route template

The route-template replay benchmark shows a 1.41x speedup over generic hard routing on CPU, and 1.80x to 3.02x mean speedup on the A10G grid, with identical logits in both cases. This is the cheapest systems win.

Request-level routing is also easier to implement than token-level routing because every token in the sequence follows the same virtual block plan.

Step 3: Add virtual-layer KV cache

Use:

virtual_layer_id = route_step * num_physical_blocks + physical_block_id

This preserves cache correctness while still reusing the same physical weights.

Step 4: Batch by next virtual layer

The scheduler should group requests by:

next_virtual_layer_id
route_template_id
decode phase

This turns a dynamic router model back into a small set of dense batched kernels.
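A sketch of that grouping, assuming each active request carries its frozen route and a cursor to its next route step. The `ActiveRequest` fields and the two phase labels are assumptions for illustration.

```python
from collections import defaultdict
from dataclasses import dataclass
from typing import Dict, List, Tuple

NUM_PHYSICAL_BLOCKS = 2


@dataclass
class ActiveRequest:
    request_id: str
    route_template_id: int
    route: List[int]  # physical block per route step
    next_step: int    # route step to execute next
    phase: str        # "prefill" or "decode"


def batch_key(req: ActiveRequest) -> Tuple[int, int, str]:
    block = req.route[req.next_step]
    next_virtual_layer_id = req.next_step * NUM_PHYSICAL_BLOCKS + block
    return (next_virtual_layer_id, req.route_template_id, req.phase)


def build_batches(active: List[ActiveRequest]) -> Dict[Tuple[int, int, str], List[str]]:
    batches: Dict[Tuple[int, int, str], List[str]] = defaultdict(list)
    for req in active:
        batches[batch_key(req)].append(req.request_id)
    return dict(batches)


template = [0, 1, 0, 1, 0, 1, 0, 1]
active = [
    ActiveRequest("req-a", 0, template, next_step=2, phase="decode"),
    ActiveRequest("req-b", 0, template, next_step=2, phase="decode"),
    ActiveRequest("req-c", 0, template, next_step=3, phase="decode"),
]
batches = build_batches(active)
print(batches)  # {(4, 0, 'decode'): ['req-a', 'req-b'], (7, 0, 'decode'): ['req-c']}
```

Every request inside one batch needs the same physical block and writes to the same virtual slot, so the batch can run as a single dense kernel call.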

Step 5: Only then try token-level routing

Token-level routing is more powerful, but it complicates the cache. Different tokens in the same request can take different routes, so the runtime needs per-token route signatures and may need to compact tokens by route step.
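To illustrate what compacting tokens by route step involves, here is a toy version for a single route step. The per-token block choices and the scalar blocks are invented; a real runtime would gather tensor rows, run one kernel per group, and scatter the results back.

```python
from collections import defaultdict
from typing import Dict, List


def compact_by_block(token_blocks: List[int]) -> Dict[int, List[int]]:
    """Group token positions by the physical block each token chose."""
    groups: Dict[int, List[int]] = defaultdict(list)
    for pos, block in enumerate(token_blocks):
        groups[block].append(pos)
    return dict(groups)


hidden = [1.0, 2.0, 3.0, 4.0, 5.0]
token_blocks = [0, 1, 1, 0, 1]  # hypothetical per-token choices at one route step
blocks = [lambda x: x + 10.0, lambda x: x * 2.0]

groups = compact_by_block(token_blocks)
out = list(hidden)
for block_id, positions in groups.items():
    for pos in positions:  # stands in for one dense call over the compacted group
        out[pos] = blocks[block_id](hidden[pos])

print(groups)  # {0: [0, 3], 1: [1, 2, 4]}
print(out)     # [11.0, 4.0, 6.0, 14.0, 10.0]
```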

I would not start there. The current data says request-level route-template batching already has measurable upside.

What Is Still Missing

The current evidence is enough to justify the optimization direction. It is not enough to claim that router looped transformers are already faster in a real serving stack.

Three things are still missing:

| Missing evidence | Why it matters | Current status |
|---|---|---|
| GPU decode latency with real KV cache | CPU microbenchmarks can overstate or understate the serving win | A10G route-template replay wins; Python KV microbench is not enough |
| Route-template diversity under varied prompts | Template batching only works if many requests share a small number of route templates | Current synthetic run still has one template: 0,1,0,1,0,1,0,1 |
| Accuracy/latency Pareto across thresholds | A router candidate should not be chosen by accuracy alone | Modal thresholds show t=0.55 and t=0.60 are promising |

The most important missing measurement is:

fixed loop vs hard router vs hard router + template batching + virtual-step KV cache

measured on the same GPU, same batch sizes, same prompt lengths, same decode lengths, and same model checkpoint.

Experiments To Run Next

I would run the next experiments in this order.

| Priority | Experiment | What to measure | Pass condition |
|---|---|---|---|
| P0 | GPU decode benchmark | TTFT, TPOT, end-to-end latency, tokens/s | hard-router optimized path is faster than generic hard router and approaches or beats fixed loop |
| P0 | Route-template histogram | number of unique templates, top-k coverage, template entropy | top 8 to 16 templates cover most requests |
| P1 | Threshold Pareto sweep | accuracy, latency, block calls for t=0.45 to 0.70 | one threshold beats fixed accuracy without exceeding fixed block-call budget |
| P1 | Prompt/decode length matrix | prompt lengths 64/128/256/512, decode lengths 16/32/64 | cache speedup grows with prompt length and remains positive at longer decode |
| P2 | Token-level router scout | per-token route diversity, batching fragmentation, accuracy | token routing improves accuracy enough to justify scheduler complexity |
| P2 | Template distillation | small route predictor vs original router | same template choices with lower planning overhead |
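The route-template histogram metrics are cheap to compute once hard routes are logged. A minimal sketch, assuming routes are available as lists of physical block ids; the function name and the sample routes are invented for illustration.

```python
import math
from collections import Counter
from typing import Dict, Sequence


def template_stats(routes: Sequence[Sequence[int]], top_k: int = 8) -> Dict[str, float]:
    counts = Counter(tuple(route) for route in routes)
    total = sum(counts.values())
    covered = sum(count for _, count in counts.most_common(top_k))
    # Shannon entropy of the template distribution, in bits.
    entropy = sum(-(c / total) * math.log2(c / total) for c in counts.values())
    return {
        "unique_templates": len(counts),
        "top_k_coverage": covered / total,
        "entropy_bits": entropy,
    }


# Made-up sample: six requests on one template, two on another.
routes = [[0, 1, 0, 1]] * 6 + [[0, 0, 1, 1]] * 2
stats = template_stats(routes, top_k=1)
print(stats["unique_templates"])         # 2
print(stats["top_k_coverage"])           # 0.75
print(round(stats["entropy_bits"], 3))   # 0.811
```

Low entropy and high top-k coverage are what make template batching worthwhile; a long tail of rare templates would fragment batches.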

The P0 experiments are the ones needed before making a strong inference claim. The P1 experiments choose a usable router candidate. The P2 experiments are research extensions.

Low-Cost Modal Plan

Because the goal is to control spend, the next Modal run should be a bounded inference-only benchmark, not a long training run.

I would use:

one small GPU
one checkpoint
three seeds only if the first seed is promising
short prompt/decode grid first
hard timeout
write JSON after every benchmark cell

Suggested first grid:

| Setting | Values |
|---|---|
| batch size | 1, 8, 16 |
| prompt length | 64, 128, 256 |
| decode tokens | 16, 32 |
| route steps | 8 |
| inference paths | fixed, soft router, hard router, hard + template, hard + template + KV |
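A sketch of how that grid could be driven so that results survive a hard timeout: enumerate the cells and append one JSON line per cell. The `run_cell` callback is a placeholder for the real benchmark, and the output file name is hypothetical.

```python
import itertools
import json
import os
import tempfile
from typing import Callable, Dict, Iterator, List

GRID: Dict[str, List] = {
    "batch_size": [1, 8, 16],
    "prompt_len": [64, 128, 256],
    "decode_tokens": [16, 32],
    "route_steps": [8],
    "path": ["fixed", "soft", "hard", "hard+template", "hard+template+kv"],
}


def iter_cells(grid: Dict[str, List]) -> Iterator[Dict]:
    keys = list(grid)
    for values in itertools.product(*(grid[k] for k in keys)):
        yield dict(zip(keys, values))


def run_grid(grid: Dict[str, List], out_path: str,
             run_cell: Callable[[Dict], Dict]) -> int:
    done = 0
    with open(out_path, "a", encoding="utf-8") as f:
        for cell in iter_cells(grid):
            f.write(json.dumps({**cell, **run_cell(cell)}) + "\n")
            f.flush()  # a timeout loses at most the cell in flight
            done += 1
    return done


out_path = os.path.join(tempfile.mkdtemp(), "cache_grid.jsonl")
# Placeholder benchmark: records no measurement.
n_cells = run_grid(GRID, out_path, run_cell=lambda cell: {"latency_ms": None})
print(n_cells)  # 90
```

The grid is 3 x 3 x 2 x 1 x 5 = 90 cells, which is small enough to bound the spend before starting.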

The first stopping rule should be simple:

stop if hard + template is not at least 1.20x faster than generic hard router
stop if virtual-step KV does not beat no-cache at prompt_len >= 128
stop if accuracy drops below fixed by more than one seed-level std
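The three rules translate into a small check that can run after each seed. This sketch assumes "one seed-level std" means the fixed loop's accuracy std; the speedup inputs in the example are illustrative values, not measured results, while the accuracy inputs are the harder-transfer numbers for fixed and t=0.60.

```python
from typing import Dict, List


def stop_reasons(template_speedup: float,
                 kv_speedup_by_prompt_len: Dict[int, float],
                 hard_acc: float, fixed_acc: float,
                 fixed_acc_std: float) -> List[str]:
    reasons = []
    if template_speedup < 1.20:
        reasons.append("hard + template below 1.20x over generic hard router")
    if any(s <= 1.0 for plen, s in kv_speedup_by_prompt_len.items() if plen >= 128):
        reasons.append("virtual-step KV does not beat no-cache at prompt_len >= 128")
    if hard_acc < fixed_acc - fixed_acc_std:
        reasons.append("accuracy more than one std below fixed")
    return reasons


reasons = stop_reasons(
    template_speedup=1.80,                                  # illustrative
    kv_speedup_by_prompt_len={64: 0.95, 128: 0.99, 256: 1.02},  # illustrative
    hard_acc=0.108073, fixed_acc=0.104167, fixed_acc_std=0.013005,
)
print(reasons)  # ['virtual-step KV does not beat no-cache at prompt_len >= 128']
```

An empty list means the run may continue to the next seed or the larger grid.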

That keeps the experiment cheap and prevents a long run from chasing a weak systems signal.

What Would Make This Publishable As A Research Result

For a stronger public claim, I would want one table like this:

| Model | Accuracy | TTFT | TPOT | tokens/s | KV memory | Block calls | Notes |
|---|---|---|---|---|---|---|---|
| fixed loop | baseline | baseline | baseline | baseline | baseline | 8 | simple serving path |
| soft router | maybe higher/lower | worse | worse | worse | higher | 16 | training path only |
| hard router | candidate | worse today | worse today | worse today | similar | about 8 | needs systems work |
| hard + template batching | same logits | better | better | better | similar | about 8 | removes routing overhead |
| hard + template + KV | same logits | best expected | best expected | best expected | higher virtual slots | about 8 | correct serving target |

The headline should not be “router is faster” until the last row beats fixed loop on at least one realistic GPU decode setting.

The safer headline today is:

Router looped transformers need route-aware serving.
The model signal is promising, and the cache/scheduler design is clear.

Current Conclusion

The current router looped transformer is not yet an inference win out of the box.

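The best evidence so far is more precise:

hard routing keeps fixed-like accuracy at about 8 block calls, but is slower than the fixed loop today
route-template replay removes routing overhead with identical logits (1.41x on CPU, 1.80x to 3.02x on A10G)
virtual-step KV identity is required for correctness, but the Python microbench shows no GPU speedup yet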

The next experiment should be a small GPU serving benchmark that combines all three pieces:

hard router
route-template batching
virtual-step KV cache

That is the real candidate inference path for router looped transformers.

