A systems-level understanding of LLM inference

A user types in a prompt and hits send. The prompt leaves the device over HTTPS as a payload and stays encrypted until it reaches the provider’s API gateway, where TLS terminates. Here, the payload is checked for authentication, rate limits, quota, abuse filters, etc. Then, the request is routed to an inference serving engine, whose job is to manage a queue of thousands of incoming requests and allocate them to hardware efficiently.

A single replica of an LLM is a distributed program running on a group of accelerators (GPUs, TPUs, etc). Splitting the model across them is called parallelism, and there’s a bunch of strategies for that (tensor, pipeline, expert, and even splitting along the sequence of tokens itself - sequence parallelism). The weights of the model are a set of checkpoint files (TBs of data) produced by training and stored in persistent object storage. There may also be a quantized copy of the weights to save space and speed up memory reads. For example, with an FP4 or INT8 quantized model, the system reads 4x or 2x as many weights from HBM in the same amount of time as with an FP16 model.
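To put rough numbers on that, here's a back-of-envelope sketch. The model size, GPU count, and HBM bandwidth below are assumptions for illustration (roughly an H100-class part), not measured figures:

```python
# Back-of-envelope: how precision changes the weight-read cost per forward pass.
# All numbers below are illustrative assumptions, not measurements.

PARAMS = 405e9                 # parameters in the model (e.g. a 405B model)
HBM_BANDWIDTH = 3.35e12        # bytes/sec per GPU (assumed, H100-class)
NUM_GPUS = 8                   # the replica's accelerators, reading in parallel

BYTES_PER_PARAM = {"FP16": 2.0, "INT8": 1.0, "FP4": 0.5}

for fmt, bpp in BYTES_PER_PARAM.items():
    total_bytes = PARAMS * bpp
    # time to stream every weight out of HBM once (a small-batch decode step
    # is roughly bounded by this read)
    t = total_bytes / (HBM_BANDWIDTH * NUM_GPUS)
    print(f"{fmt}: {total_bytes/1e9:.0f} GB of weights, ~{t*1e3:.1f} ms per full read")
```

Halving the bytes per weight roughly halves the time it takes to stream the model out of HBM, which is exactly the quantization trade-off described above.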

At some point, the model is loaded into the accelerators. The orchestrator first determines the sharding strategy based on the parallelism approach, essentially deciding which GPU gets what part of the model. The serving engine reads the checkpoint files from storage and moves them into the ultra-fast HBM of the accelerator cluster. There’s usually a caching hierarchy here: weights are staged on local storage first instead of repeatedly pulling TBs from remote storage. This loading process can take several minutes for large models. For example, a 405B parameter model in FP16 means moving 810GB of data into HBM across the cluster. Once loaded, the weights do not move. They’re burned into the HBM of this specific cluster of GPUs, creating a stateful replica that can only serve this one model until it's explicitly unloaded. This replica of the model is now ready.
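A rough sketch of why this takes minutes and why the local caching tier helps. The 8-way even sharding and both bandwidth figures are assumptions; real deployments vary a lot:

```python
# Rough sketch of what "loading the replica" costs, assuming an 8-way even
# shard and made-up storage bandwidths.

MODEL_BYTES = 810e9                      # 405B params at FP16
NUM_GPUS = 8
SHARD_BYTES = MODEL_BYTES / NUM_GPUS     # ~101 GB per GPU with even sharding

BANDWIDTHS = {
    "remote object storage": 1e9,        # ~1 GB/s per node (assumed)
    "local NVMe cache":      6e9,        # ~6 GB/s (assumed)
}

for tier, bw in BANDWIDTHS.items():
    # each GPU's host pulls its own shard in parallel
    seconds = SHARD_BYTES / bw
    print(f"{tier}: ~{seconds/60:.1f} min to stage {SHARD_BYTES/1e9:.0f} GB per GPU")
```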

The model router (which maintains a list of all the model replicas that are ready) looks at the payload and forwards it to the correct cluster’s leader machine. The router uses load-balancing strategies, and even predictive routing, to send new jobs to the least busy replicas. The leader machine runs the serving engine software and has a CPU-side process that tokenizes the text into a list of token IDs for the model. That list of token IDs is placed into the scheduler’s queue.
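The tokenization step itself looks roughly like this. The example uses the open-source tiktoken library purely as an illustration; the real serving engine uses whatever tokenizer the model was trained with:

```python
# Illustration of the CPU-side tokenization step using tiktoken
# (the actual tokenizer depends on the model being served).
import tiktoken

enc = tiktoken.get_encoding("cl100k_base")

prompt = "Explain how a jet engine works in one paragraph."
token_ids = enc.encode(prompt)           # text -> list of integer token IDs

print(token_ids[:8], "...", len(token_ids), "tokens")
# These IDs (not the raw text) are what gets placed on the scheduler's queue.
print(enc.decode(token_ids) == prompt)   # detokenization round-trips the text
```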

The job of the scheduler is to maximize GPU efficiency while minimizing latency. When a new request arrives, it enters the prefill phase. The scheduler waits a few ms and batches this request with many others that are also in the prefill queue. This batch is copied from the CPU’s RAM into the GPU’s HBM. This kind of continuous batching is quite useful. In the older static batching scheme, if a 50-token request is batched with a 5000-token request, the 50-token job has to wait for the 5000-token job to finish. That’s terrible for latency, and it also increases GPU idle time. With continuous batching, as soon as a request finishes the prefill phase it moves to the decode phase, and new requests can take its place in the batch.
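Here's a toy sketch of the continuous batching idea. Every name, field, and number here is made up for illustration; the real schedulers (vLLM, TensorRT-LLM, etc.) are far more involved:

```python
# Toy sketch of continuous batching: the batch is rebuilt every step, so a
# finished request frees its slot immediately instead of waiting for the
# longest request in a static batch.
from collections import deque
from dataclasses import dataclass, field

@dataclass
class Request:
    prompt_ids: list
    max_new_tokens: int
    generated: list = field(default_factory=list)

    def done(self) -> bool:
        return len(self.generated) >= self.max_new_tokens

MAX_BATCH = 4
waiting: deque[Request] = deque()     # requests still needing prefill
running: list[Request] = []           # requests in the decode phase

def step(model_prefill, model_decode):
    # model_prefill / model_decode are stand-ins for the engine's kernels.
    # 1. top up the batch with newly arrived requests (prefill them)
    while waiting and len(running) < MAX_BATCH:
        req = waiting.popleft()
        model_prefill(req)            # builds the request's KV cache
        running.append(req)
    # 2. one decode step for every running request, as a single batch
    new_tokens = model_decode(running)
    for req, tok in zip(running, new_tokens):
        req.generated.append(tok)
    # 3. retire finished requests immediately; their slots free up this step
    running[:] = [r for r in running if not r.done()]
```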

Now, for a given batch, the GPU does a parallel forward pass, processing all tokens in the prompt at once. This data flows through the static model weights that are already in HBM. For each attention layer, the model computes key and value vectors for every token in the prompt, and these need to be stored because every future generated token will need to attend to all previous tokens. This stored output is the KV cache: a compact representation of the prompt’s context that the model needs in order to generate future tokens. The KV cache can be quite massive. And of course, the cache is distributed across the cluster in the same pattern as the model itself.
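To get a feel for "quite massive", here's the standard sizing formula. The config below is an assumption, loosely in the shape of a 70B-class transformer rather than any specific model:

```python
# Back-of-envelope for KV cache size. Config values are assumptions.

num_layers   = 80
num_kv_heads = 8        # grouped-query attention; full MHA would use 64 here
head_dim     = 128
bytes_per_el = 2        # FP16/BF16 cache

# Every token stores one K and one V vector per layer, per KV head.
kv_bytes_per_token = 2 * num_layers * num_kv_heads * head_dim * bytes_per_el
print(f"{kv_bytes_per_token/1024:.0f} KiB per token")          # ~320 KiB here

for ctx in (2_000, 32_000, 128_000):
    print(f"{ctx:>7} tokens -> {ctx * kv_bytes_per_token / 1e9:.1f} GB of KV cache")
```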

The inference engine now allocates pages of HBM on each GPU to store this cache. A page is a fixed-size block of HBM holding the KV entries for a fixed number of tokens, and the page table is the per-request list of pointers to those pages; that map is the request’s state. This is the classic OS paging concept applied to GPU memory. It’s quite useful because otherwise, for a 2000-token request, the engine would have had to find a single contiguous block of HBM large enough for a 2000-token KV cache. If the free HBM is scattered in smaller chunks, that leads to fragmentation and wasted memory. With pages, memory utilization can get close to 100%, and it enables things like sharing pages for a common prefix (e.g. 100 users sharing the same long system prompt, or the developer prompt).
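A minimal sketch of the bookkeeping, in the spirit of vLLM's PagedAttention. The class and method names here are invented for illustration, and real allocators only share full pages of a prefix:

```python
# Minimal sketch of paged KV-cache bookkeeping. A "page" holds the KV entries
# for PAGE_TOKENS tokens, and each request's page table is a list of page ids.

PAGE_TOKENS = 16

class PagedKVAllocator:
    def __init__(self, num_pages: int):
        self.free_pages = list(range(num_pages))   # physical HBM pages
        self.page_tables = {}                      # request_id -> [page ids]
        self.ref_counts = [0] * num_pages          # for prefix sharing

    def allocate(self, request_id: str, num_tokens: int, shared_prefix=None):
        """Give a request enough pages for num_tokens, reusing a shared prefix."""
        table = []
        if shared_prefix:                          # e.g. a common system prompt
            for page in self.page_tables[shared_prefix]:
                self.ref_counts[page] += 1         # share the page, don't copy it
                table.append(page)
        pages_needed = -(-num_tokens // PAGE_TOKENS)   # ceiling division
        for _ in range(pages_needed - len(table)):
            page = self.free_pages.pop()           # any free page will do
            self.ref_counts[page] = 1
            table.append(page)
        self.page_tables[request_id] = table
        return table

    def free(self, request_id: str):
        for page in self.page_tables.pop(request_id):
            self.ref_counts[page] -= 1
            if self.ref_counts[page] == 0:         # last user of this page
                self.free_pages.append(page)
```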

Next, the scheduler builds a new batch, called a decode batch, for requests that have already had their cache computed, and the decode kernel is called. The GPU again reads the static model weights from HBM, and the attention kernel uses the page table pointers to read the KV cache pages. Notice that the cache does not move; it’s read in place from whichever GPUs it lives on. With tensor parallelism, this means heavy interconnect traffic, and the interconnect has to keep up with all of it. A compiler stack like TensorRT-LLM fuses these operations into a single kernel, performing them in one pass instead of writing intermediate results back to HBM, which improves performance a lot.
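To see why the interconnect matters, here's a rough estimate of the activation traffic tensor parallelism adds. The config, batch size, and "two all-reduces per layer" (the usual Megatron-style layout) are assumptions for illustration:

```python
# Back-of-envelope for tensor-parallel interconnect traffic per decode step.

num_layers   = 80
hidden_size  = 8192
batch_size   = 64
bytes_per_el = 2            # BF16 activations

# Megatron-style tensor parallelism does roughly two all-reduces per layer
# (after attention and after the MLP), each over a [batch, hidden] activation.
bytes_allreduced = 2 * num_layers * batch_size * hidden_size * bytes_per_el
print(f"~{bytes_allreduced/1e9:.2f} GB of activations all-reduced per decode step")
# Actual wire traffic depends on the all-reduce algorithm and GPU count, but
# repeated every few milliseconds this is why NVLink-class links are needed.
```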

The attention mechanism here is itself quite optimized. Flash attention is the standard. It restructures the computation into blocks that stay in the GPU’s fast but tiny SRAM, avoiding the constant reads/writes to HBM and cutting HBM bandwidth usage by 10x or more. Many models also use variants like Grouped-Query Attention (GQA) or Multi-Query Attention (MQA), which share K and V projections across multiple query heads and shrink the size of the KV cache itself (useful for handling long context!).
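The core trick behind the blocking is an online softmax: you can walk over K/V in chunks while keeping only a running max and running sum, so the full score vector never has to be materialized. Below is a toy single-query NumPy version of that idea; the real kernel does this tiled in SRAM, fused, and batched across heads and requests:

```python
# Toy illustration of the blocked/online-softmax idea behind flash attention.
import numpy as np

def blocked_attention(q, K, V, block=128):
    d = q.shape[0]
    m, l = -np.inf, 0.0                  # running max and softmax denominator
    acc = np.zeros(V.shape[1])           # running weighted sum of values
    for s in range(0, K.shape[0], block):
        scores = K[s:s+block] @ q / np.sqrt(d)
        m_new = max(m, scores.max())
        scale = np.exp(m - m_new)        # rescale previous partial results
        p = np.exp(scores - m_new)
        acc = acc * scale + p @ V[s:s+block]
        l = l * scale + p.sum()
        m = m_new
    return acc / l

# Sanity check against the naive "materialize all the scores" version:
rng = np.random.default_rng(0)
q = rng.normal(size=(64,))
K = rng.normal(size=(1000, 64))
V = rng.normal(size=(1000, 64))
scores = K @ q / np.sqrt(64)
weights = np.exp(scores - scores.max())
weights /= weights.sum()
assert np.allclose(blocked_attention(q, K, V), weights @ V)
```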

Through these optimized kernels, the GPU computes logits for every token in the vocabulary. A temperature is applied to scale the logits, softmax converts them into a probability distribution, and a filtering strategy (top-k, nucleus sampling, etc.) narrows that distribution down to a small set of tokens, from which one single token is sampled. If you want the model to do structured output generation like JSON, you can mask the invalid tokens before sampling at this stage, allowing only tokens that keep the output a valid string. Now we have one new token, randomly sampled from the filtered distribution, for each request in the batch. This batch of new token IDs (a few kBs) is copied from the GPU’s HBM back to the CPU’s RAM.
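A compact sketch of that sampling stage for a single request. The function name and default parameters are illustrative, and the `allowed_ids` set stands in for whatever a grammar/JSON constraint engine would produce:

```python
# Sketch of the sampling stage: temperature on the logits, softmax, top-k and
# nucleus filtering, optional masking for structured output, then one draw.
import numpy as np

def sample_token(logits, temperature=0.8, top_k=50, top_p=0.95, allowed_ids=None):
    logits = np.array(logits, dtype=np.float64)
    if allowed_ids is not None:                  # structured output: only tokens
        mask = np.full_like(logits, -np.inf)     # that keep the output valid
        mask[list(allowed_ids)] = 0.0            # survive the mask
        logits = logits + mask
    logits = logits / temperature                # temperature scales the logits
    probs = np.exp(logits - logits.max())
    probs /= probs.sum()                         # softmax -> distribution

    order = np.argsort(probs)[::-1]              # most likely tokens first
    keep = order[:top_k]                         # top-k cut
    cumulative = np.cumsum(probs[keep])
    keep = keep[: np.searchsorted(cumulative, top_p) + 1]   # nucleus cut

    p = probs[keep] / probs[keep].sum()
    return int(np.random.choice(keep, p=p))      # one token, randomly drawn
```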

Well, generating one token at a time is quite slow and memory-bandwidth bound, as we have to read the entire model weights from HBM for each generated token. Instead, speculative decoding is often used. The system maintains two models - a target model (what the user requested) and a draft model (a 10-50x smaller version of the target). The draft model gets the request and quickly generates a draft of the next few tokens. This entire draft is fed to the target model in a single parallel forward pass. The target model accepts the draft tokens it would’ve generated itself (i.e. the draft token matches what the target would have picked), and if it disagrees with a token, it rejects that token and all tokens after it, and generates its own token instead. This new token then goes back to the draft model, and the whole process repeats.

Since the target model processes the draft tokens in parallel, we’ve saved a lot of sequential computation. It’s quite useful because generating 10 tokens the old way means the target model has to read the entire model weights from HBM 10 times. If the draft model is good, the target model can verify all 10 tokens in a single parallel pass, which costs only one full read of the target model’s weights from HBM. This works especially well for common phrases, where the next tokens are highly predictable. You can even have multiple draft models of different sizes for better results, which is called cascade speculative decoding.
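Here's a sketch of the control flow, using a simplified greedy "agreement" check. Real implementations use a probabilistic accept/reject rule that exactly preserves the target model's output distribution, and both model interfaces below (`next_token`, `next_tokens_parallel`) are hypothetical stand-ins:

```python
# Simplified sketch of one speculative decoding round (greedy agreement only).

def speculative_step(target_model, draft_model, context, k=5):
    """Generate up to k+1 tokens using one target-model forward pass."""
    # 1. the small draft model proposes k tokens, one at a time (cheap)
    draft = []
    for _ in range(k):
        draft.append(draft_model.next_token(context + draft))

    # 2. one parallel pass of the big target model returns its own greedy pick
    #    after context, after context+draft[:1], ..., after context+draft[:k]
    #    (k+1 picks total) -- this is the single expensive HBM read
    target_choices = target_model.next_tokens_parallel(context, draft)

    # 3. accept the longest prefix where the target agrees with the draft;
    #    at the first disagreement, take the target's token and stop
    accepted = []
    for proposed, wanted in zip(draft, target_choices):
        if proposed == wanted:
            accepted.append(proposed)
        else:
            accepted.append(wanted)          # target overrides; rest of draft dropped
            break
    else:
        accepted.append(target_choices[-1])  # all k accepted: free bonus token
    return accepted
```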

Now, the CPU-side engine gets the new token IDs and detokenizes them back into text. It immediately streams these tokens back through the model router to the user via Server-Sent Events (SSE), a one-way protocol over HTTP. The arrival of the first streamed token marks our Time to First Token (TTFT).

The engine might buffer a few tokens before sending to reduce network overhead, which is why tokens sometimes arrive in small bursts when using ChatGPT, for instance.
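For reference, the SSE wire format itself is just "data:" lines separated by blank lines on a kept-open HTTP response. A sketch, with a made-up payload shape and the commonly seen (but provider-specific) `[DONE]` sentinel:

```python
# Sketch of what the SSE stream looks like on the wire.
import json
import time

def sse_stream(token_chunks):
    """Yield Server-Sent Events for chunks of detokenized text."""
    for chunk in token_chunks:
        payload = json.dumps({"delta": chunk})
        yield f"data: {payload}\n\n"        # SSE framing: data line + blank line
    yield "data: [DONE]\n\n"                # common, provider-specific sentinel

# The engine may pack a few tokens into each event to cut per-packet overhead,
# which is why text often arrives in small bursts rather than token by token.
for event in sse_stream(["The", " jet", " engine", " works", " by..."]):
    print(event, end="")
    time.sleep(0.05)   # stand-in for generation latency
```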

While the token is being sent, the GPU has computed the K and V vectors for these new tokens, and the memory manager allocates new pages to append this data to the distributed KV cache. Thus, the “state”, aka the cache, has grown just slightly.

The request is now re-queued in the scheduler, ready to be batched for its next token generation.

This loop repeats, generating one token at a time, until the model generates a stop token. The CPU maintains a list of stop token IDs and checks each generated token against this list. When a match is found, the request is marked as complete, and the streaming connection is closed.

At this point, the memory manager frees all the HBM pages storing this request’s KV cache. That’s a huge amount of memory: GBs for a long conversation, with each generated token having added hundreds of KBs to MBs to the cache, depending on the model. Some systems implement a warm cache where, instead of destroying the KV cache immediately, it’s offloaded to cheap CPU RAM or NVMe storage for a while before destruction. This is useful because if the user sends a follow-up request quickly, the engine can hot-swap this state back in and avoid recomputing the cache.
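A toy sketch of that warm-cache idea. The names, the TTL value, and the `offload_to_cpu()` method are all invented for illustration:

```python
# Toy sketch of a "warm" KV cache: instead of freeing the pages immediately,
# park the state in CPU RAM with a TTL so a quick follow-up can reuse it.
import time

WARM_TTL_SECONDS = 120
warm_cache = {}     # conversation_id -> (offloaded_kv_state, expiry_time)

def on_request_finished(conversation_id, kv_state):
    # kv_state.offload_to_cpu() is a hypothetical "copy pages to host RAM" call
    warm_cache[conversation_id] = (kv_state.offload_to_cpu(),
                                   time.time() + WARM_TTL_SECONDS)

def on_follow_up(conversation_id):
    entry = warm_cache.pop(conversation_id, None)
    if entry and time.time() < entry[1]:
        return entry[0]        # hot-swap the cache back into HBM, skip recompute
    return None                # expired or never cached: redo the prefill
```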

When I have a conversation with a model, I leave it and sometimes come back later (hours or even days) to continue it. It feels like resuming an active session, but it’s not. The KV cache from the chat was likely destroyed milliseconds after the last response, so continuing the chat means triggering a full re-computation of the entire conversation. The text of the conversation is just stored in a persistent database linked to my user account (it’s useful to think of chats as logs, not sessions!).

When I continue a chat, the entire chat history above plus my new prompt is sent as one single massive prompt to the API gateway, tokenized, and routed to a model replica. The KV cache representing the state of the whole conversation has to be rebuilt from scratch. That’s why TTFT for continuing a long chat is slower than for starting a fresh one.

The context window is decoupled from the chat history storage. If the chat history is 200k tokens but the model’s limit is 128k tokens, the oldest 72k tokens are simply truncated, and the model “forgets” the beginning of the chat. Some systems use a sliding window, compress/summarize older messages before dropping them, or use a scratchpad to note down important things in the context (Claude Code does this).
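The simplest version of that truncation policy looks something like this. The function name, the `count_tokens` helper, and the budget are illustrative:

```python
# Sketch of the simplest truncation policy: keep the system prompt, then keep
# as many of the most recent messages as fit in the model's context budget.

def fit_to_context(system_prompt, messages, count_tokens, limit=128_000):
    budget = limit - count_tokens(system_prompt)
    kept = []
    for msg in reversed(messages):          # walk from newest to oldest
        cost = count_tokens(msg)
        if cost > budget:
            break                           # everything older is "forgotten"
        kept.append(msg)
        budget -= cost
    return [system_prompt] + list(reversed(kept))
```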

This is where context caching comes in. You can process a large block of text like a 100-page PDF and cache the KV representation. This can be persisted in a disaggregated memory pool of CPU RAM, HBM, etc. The system generates a handle pointing to this state. When I send a question about this doc, the prompt includes the handle, and the serving engine can skip the entire prefill. It’ll simply locate the cached KV state, load it into the HBM, and begin generating tokens. That’s quite useful for quickly working with large documents. TPUs kind of have an advantage here because they were designed from scratch for distributed stateful workloads with unified memory addressing.