Building Large Language Models
I wanted a deep understanding of how LLMs work, one that I felt I could not acquire without writing all of the architecture, training, and inference code myself. Naturally, I gravitated to the very popular book by Sebastian Raschka that tackles this same set of tasks. I read through all of the code in the book and implemented all of the relevant classes and functions myself. While I found the book to be an excellent resource for the architecture, I wanted to provide a more instructive set of implementations. I will explain the design decisions in my implementations and walk through some examples reflecting the benefits of my approach. I hope this can be a useful write-up for others curious to develop a deep understanding of these architectures.
As a clarifying statement, I do not intend to offer the reader a proper explanation of the architecture or how it works. The curious reader unfamiliar with the building blocks of LLMs should read the book, since Sebastian Raschka has already created quality explanations that will surpass what I would create myself. So if I am directing readers to the book, and the book spells out the code (which also is organized into a nice repository), why would I spend my time writing this article or expect any reader to spend their time reading it? The answer comes down to my personal belief that the hardest part of writing deep learning models comes from the simple task of understanding dimensions.
With this last sentence, I was unusually precise in my wording. Dimension accounting is simple but not easy. Tracking dimensions is simple because, well, there is not much complexity here. There is no lack of advanced math in deep learning, but even transformations that may perform complex operations at the level of the elements of the relevant tensors rarely affect the dimensions of these tensors in strange ways. Multi-head attention, for example, contains some non-trivial operations that took researchers years to discover and may take hours to understand. But anyone can understand in a few seconds how the input tensor relates to the output tensor at the level of dimensionality. So I do not think I am making a contentious statement when I declare that dimension accounting in deep learning is simple.
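As a concrete illustration of that last claim, here is a quick check using PyTorch's built-in multi-head attention module (not the book's implementation, and with arbitrary sizes of my own choosing):

import torch

# Multi-head attention preserves the (batch, tokens, dim) shape of its input;
# batch_first is set so the shapes read naturally.
mha = torch.nn.MultiheadAttention(embed_dim=64, num_heads=8, batch_first=True)
x = torch.randn(2, 10, 64)       # (batch, tokens, dim)
out, _ = mha(x, x, x)            # self-attention: query, key, and value are all x
print(out.shape)                 # torch.Size([2, 10, 64]) -- same shape as the input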
The more provocative portion of my stance is likely that keeping track of these dimensions is not easy. If you focus on a single tensor operation, each change in dimensions is in fact easy. But very quickly in any neural network, many such operations are applied sequentially. In most deep learning code, whether PyTorch or TensorFlow, the developer is forced to keep all of these dimensions in their brain cache. Maybe this is just a personal limitation, but my cache quickly becomes too small to remember all of these dimension changes, which requires me to mentally re-derive the dimensions each time I look deeply at the code. Again, none of these mental traversals is challenging or slow, but they need to be completed many times and lack reusability of my previous mental computation. Pretty quickly, I am spending more time thinking about dimensions than about the code execution. Or worse, I fail to think about the dimensions because I incorrectly think I understand how they are changing, producing bugs in my code.
Dimension Accounting
The first package that I used to help address these perceived shortcomings is jaxtyping, which provides type annotations and runtime type-checking for tensor shapes and dtypes. The mental model here is something like pylint: it helps developers both catch mistakes and offload some of the mental work of keeping track of variable information. As a motivating example, let’s look at this snippet of the forward pass from the implementation of MultiHeadAttention from Chapter 3:
def forward(self, x):
    keys = self.W_key(x)
    queries = self.W_query(x)
    values = self.W_value(x)
How long does it take you to come up with the dimension of x? What about the resulting keys, queries, or values? Answering any of these questions is very time-consuming if not nearly impossible without referencing other code in the class. And while this code is mathematically correct, I would argue that it is not maximally instructive for learning the line-by-line operations, which require careful dimension accounting to follow the operations being applied. Now, let’s look at a thoroughly annotated version of this snippet using jaxtyping. This might look something like:
def forward(self, x: Float[torch.Tensor, 'b t din']) -> (
        Float[torch.Tensor, 'b t dout']):
    keys: Float[torch.Tensor, 'b t dout'] = self.W_key(x)
    queries: Float[torch.Tensor, 'b t dout'] = self.W_query(x)
    values: Float[torch.Tensor, 'b t dout'] = self.W_value(x)
Here we immediately understand the input x, which is a tensor of size b x t x din. At the expense of even more verbosity, we could equally well have represented this as batch x tokens x dimension_in. We also understand the returned object from the call signature, and there is no thinking or guesswork required to see the shapes of the keys, queries, or values formed from our inputs. Certainly for a pedagogical implementation such as this, I would argue that this detailed annotation is very helpful. And personally, I would also argue that it improves developer velocity for any future modifications.
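Beyond serving as documentation, jaxtyping can also enforce these annotations at runtime. Here is a minimal sketch, separate from the book's code, using the beartype typechecker; the function and dimension names are my own:

import torch
from beartype import beartype
from jaxtyping import Float, jaxtyped

# A toy projection whose shape annotations are checked at call time.
@jaxtyped(typechecker=beartype)
def project(x: Float[torch.Tensor, 'b t din'],
            w: Float[torch.Tensor, 'din dout']) -> Float[torch.Tensor, 'b t dout']:
    return x @ w

x = torch.randn(2, 4, 8)            # b=2, t=4, din=8
w = torch.randn(8, 16)              # din=8, dout=16
out = project(x, w)                 # passes: out has shape (2, 4, 16)
project(x, torch.randn(7, 16))      # raises a type-checking error: 7 does not match din=8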
Readable Dimension Transforms
The second package that I want to praise is einops which, as its URL suggests, does indeed rock. Although I had not previously used this package, I was excited to find it once I realized that it removes much of the confusion around dimension changes introduced by common transformations such as transpose or view. To demonstrate the utility of einops, let’s continue building out the forward pass of an implementation of multi-head attention. From the official source code, the snippet required to form attention scores is:
def forward(self, x):
    b, num_tokens, d_in = x.shape

    keys = self.W_key(x)  # Shape: (b, num_tokens, d_out)
    queries = self.W_query(x)
    values = self.W_value(x)

    # We implicitly split the matrix by adding a `num_heads` dimension
    # Unroll last dim: (b, num_tokens, d_out) -> (b, num_tokens, num_heads, head_dim)
    keys = keys.view(b, num_tokens, self.num_heads, self.head_dim)
    values = values.view(b, num_tokens, self.num_heads, self.head_dim)
    queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)

    # Transpose: (b, num_tokens, num_heads, head_dim) -> (b, num_heads, num_tokens, head_dim)
    keys = keys.transpose(1, 2)
    queries = queries.transpose(1, 2)
    values = values.transpose(1, 2)

    # Compute scaled dot-product attention (aka self-attention) with a causal mask
    attn_scores = queries @ keys.transpose(2, 3)  # Dot product for each head
I have already complained enough about the opaqueness of the dimensions of all of the intermediate tensors, so I will not belabor the point. But note that this section becomes especially confusing due to the common pattern of reassigning a variable to itself after changing its dimensions with view and transpose. Thankfully, Sebastian was kind enough to write comments to help readers follow along, but the code is still cumbersome to read in my opinion. Now, let’s contrast this with an implementation leveraging both jaxtyping and our new addition, einops:
from einops import rearrange

def forward(self, x: Float[torch.Tensor, 'b t din']) -> (
        Float[torch.Tensor, 'b t dout']):
    num_tokens = x.shape[1]

    keys: Float[torch.Tensor, 'b t dout'] = self.W_key(x)
    queries: Float[torch.Tensor, 'b t dout'] = self.W_query(x)
    values: Float[torch.Tensor, 'b t dout'] = self.W_value(x)

    keysT = rearrange(
        keys,
        'b t (nh hd) -> b nh hd t',
        nh=self.num_heads,
        hd=self.head_dim)
    values = rearrange(
        values,
        'b t (nh hd) -> b nh t hd',
        nh=self.num_heads,
        hd=self.head_dim)
    queries = rearrange(
        queries,
        'b t (nh hd) -> b nh t hd',
        nh=self.num_heads,
        hd=self.head_dim)

    # Compute scaled dot-product attention (aka self-attention) with a
    # causal mask.
    attn_scores: Float[torch.Tensor, 'b nh t t'] = matmul(queries, keysT)
To my programming taste, this is an elegant implementation, thanks in large part to the power of rearrange. Let’s go out of order and start with changing the dimensionality of the original values variable. In the original implementation, we required a view and a transpose to get the tensor into the desired shape, which can only be reasoned about thanks to the comments. Despite the complexity of the original implementation, we are not doing anything complicated: we want to preserve the first dimension, split the last dimension into two intermediate dimensions using the number of heads and the head dimension, and swap the token dimension with the new heads dimension. einops lets us express this as a similarly digestible operation that makes sense at a glance even without comments, and it essentially requires a single line (just split up to conform to the standard maximum terminal width). The formation of keysT is even more powerful: we have contracted a view and two transpose calls into a single operation rearranging the dimensions.
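If the equivalence is not obvious, a small standalone check (my own, with arbitrary sizes) confirms that the rearrange pattern produces exactly the same tensor as the original view plus transpose:

import torch
from einops import rearrange

b, t, nh, hd = 2, 5, 4, 3
x = torch.randn(b, t, nh * hd)

# Original route: unroll the last dimension, then swap tokens and heads.
via_view = x.view(b, t, nh, hd).transpose(1, 2)                    # (b, nh, t, hd)
# einops route: the same transformation in one expression.
via_einops = rearrange(x, 'b t (nh hd) -> b nh t hd', nh=nh, hd=hd)

assert torch.equal(via_view, via_einops)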
Unintended Broadcasting
Finally, I introduce one more helper function to guard against unintended dimension changes. I will have to work even harder to justify this choice, however. Broadcasting is a powerful technique for succinctly performing tensor operations when the input tensors do not perfectly match. It also implicitly relies on vectorization, which leads to more efficient algorithms. If you are writing high-performance code, broadcasting is a powerful ally. In my experience, though, broadcasting has a dark side as well. It can be surprisingly subtle when a developer tries to reason about what will happen. The PyTorch documentation does not even try to cover broadcasting in any depth, and instead redirects readers to the NumPy documentation, which is quite lengthy.
Once you become familiar with broadcasting, you do not go back to the documentation and reference the specific applicable broadcasting rule. Rather, you develop an intuition and just trust that PyTorch does the thing that you hope. And most of the time you will be correct. Sometimes your tensors will have invalid shapes, causing broadcasting to fail. Failing is not good, but this is not really broadcasting's fault, since the original tensors had invalid shapes to begin with. More nefarious is the case where malformed tensors get broadcast into compatibility. The code does not fail, silently or otherwise; the developer will have no idea that anything is awry. But every downstream tensor will contain incorrect information, even though it may have the expected shape. The hardest deep learning debugging that I have performed has always come from identifying these kinds of unintended broadcasts masking earlier mistakes.
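A toy illustration of this failure mode, with shapes of my own choosing:

import torch

b, t, d = 2, 4, 8
correct = torch.randn(b, t, d)
malformed = torch.randn(b, 1, d)    # e.g. a reduction accidentally kept dim 1 at size 1

out = correct + malformed           # no error: (b, 1, d) silently broadcasts against (b, t, d)
print(out.shape)                    # torch.Size([2, 4, 8]) -- the shape looks right, the values are not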
This leads me to my argument: broadcasting is a powerful tool that should be utilized in production deep learning but avoided during development due to the difficulty of interpreting its side effects. This philosophy leads me to define a wrapper around the standard matmul:
import torch

def matmul(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Matmul forbidding any broadcasting in the batch dimensions."""
    if a.shape[:-2] != b.shape[:-2]:
        raise RuntimeError(
            f'Strict matmul disallows broadcasting: '
            f'A batch dims = {a.shape[:-2]}, B batch dims = {b.shape[:-2]}')
    return torch.matmul(a, b)
This will only catch broadcasting in the batch dimensions, but that is the most frequent place where broadcasting is applied.
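As a quick sketch of the wrapper in action (relying on the matmul defined above, with shapes chosen arbitrarily):

import torch

a = torch.randn(2, 4, 5, 3)
b = torch.randn(1, 4, 3, 6)     # batch dims (1, 4) broadcast against (2, 4)

torch.matmul(a, b).shape        # torch.Size([2, 4, 5, 6]) -- the broadcast silently succeeds
matmul(a, b)                    # raises RuntimeError under the strict wrapper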
Taken together, this combination of jaxtyping for annotations, einops for readable rearranging, and restricted broadcasting makes the code much easier to reason about while preventing the unintended consequences that can arise when the dimensions are incorrect.