Ryan Zhu — transformers

transformers

February 23, 2024


I’m about 7 years late to the transformer. I’ve avoided it for a while, not for any particular reason, perhaps out of laziness (that’s irresponsible). I’ve written convnets like ResNet and AlexNet before, taken a class proving convergence bounds for SGD under convexity assumptions, and even gone as far as writing autograd from scratch. Each time, the idea that the model simply “fits the data” rings truer and truer. How do we know it’s a dog? It has features similar to other dog pictures we’ve looked at. How do I know what token comes next? Whatever token best matches the distribution. It’s really just memorizing the data, so it’s pretty useless and uninteresting.

Anyways, here’s my take on transformers.

Prologue

GPT and friends are all next token prediction, so we’ll do the same. Let’s take our input as a sequence of (embedded) tokens $$ \mathbf{x} = (x_1, x_2, \ldots, x_n), $$ where each $x_i \in \mathbb{R}^d$ and $d$ is our internal embedding dimension. We want to find the conditional probability $\mathbb{P}(y \mid \mathbf{x}),$ where $y$ represents the next token. We’ll do this using a transformer and some inexplicable hole of understanding that interpretability people are losing their minds over (circuits where?).
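To make the setup concrete, here is a minimal NumPy sketch of how token ids become the rows $x_i$ and how $\mathbb{P}(y \mid \mathbf{x})$ ends up as a distribution over the vocabulary. The vocabulary size, the embedding table `E`, and the unembedding matrix `W_out` are hypothetical toy choices, and the transformer itself is skipped entirely: the last row of `x` just stands in for the final hidden state.

```python
import numpy as np

rng = np.random.default_rng(0)
vocab_size, d = 10, 8  # hypothetical toy sizes

# Embedding table: row t is the d-dimensional vector for token id t
E = rng.normal(size=(vocab_size, d))
token_ids = np.array([3, 1, 4, 1, 5])
x = E[token_ids]  # shape (n, d): one row x_i per token

def softmax(v):
    v = v - v.max()  # subtract the max for numerical stability
    e = np.exp(v)
    return e / e.sum()

# Stand-in for the transformer: use the last row of x as the final hidden
# state (a real model would first run x through its blocks).
h_last = x[-1]
W_out = rng.normal(size=(d, vocab_size))  # hypothetical unembedding matrix
p_next = softmax(h_last @ W_out)  # P(y | x) over the whole vocabulary
```

The predicted next token would then be sampled from `p_next` (or taken as its argmax).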

Transformers are mostly made up of blocks of self attention layers, which are (extremely) nonlinear functions $\mathbb{R}^{n \times d} \to \mathbb{R}^{n \times d},$ where $n$ is the context length and $d$ is the embedding dimension from earlier. This imposes a hard limit: the model looks at no more than $n$ tokens at a time. The dimension $d$ probably controls how much information and complexity each token can carry.

These blocks also include skip (residual) connections, a feedforward layer, and two layernorms, but the layernorms just normalize their inputs and the rest is simple enough that you can go read Attention Is All You Need for it. We’ll be focusing on the self attention part.
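For orientation, here is a rough NumPy sketch of one block with the post-LN ordering from Attention Is All You Need (add the skip connection, then layernorm, once around attention and once around the feedforward layer). Assumptions: single-head attention with linear $Q, K, V$, no biases, no learnable layernorm scale and shift, and hypothetical toy sizes with the feedforward width set to $4d$ as in the paper. The point is that every piece maps $n \times d$ to $n \times d,$ so blocks can be stacked.

```python
import numpy as np

def layer_norm(x, eps=1e-5):
    # normalize each row (token) to zero mean and unit variance
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mu) / np.sqrt(var + eps)

def softmax_rows(s):
    s = s - s.max(axis=-1, keepdims=True)
    e = np.exp(s)
    return e / e.sum(axis=-1, keepdims=True)

def self_attention(x, Wq, Wk, Wv):
    # single-head scaled dot product attention with linear Q, K, V maps
    d = x.shape[-1]
    A = softmax_rows((x @ Wq) @ (x @ Wk).T / np.sqrt(d))
    return A @ (x @ Wv)

def feed_forward(x, W1, W2):
    # two-layer MLP with a ReLU, applied to every row independently
    return np.maximum(0, x @ W1) @ W2

def transformer_block(x, p):
    # post-LN: each sublayer gets a skip connection, then a layernorm
    x = layer_norm(x + self_attention(x, p["Wq"], p["Wk"], p["Wv"]))
    x = layer_norm(x + feed_forward(x, p["W1"], p["W2"]))
    return x

rng = np.random.default_rng(0)
n, d, d_ff = 5, 8, 32  # hypothetical toy sizes, d_ff = 4d
shapes = {"Wq": (d, d), "Wk": (d, d), "Wv": (d, d), "W1": (d, d_ff), "W2": (d_ff, d)}
p = {name: rng.normal(size=s) / np.sqrt(s[0]) for name, s in shapes.items()}
x = rng.normal(size=(n, d))
y = transformer_block(x, p)  # same shape as x
```

Because the output shape matches the input shape, `transformer_block(transformer_block(x, p), p)` also works, which is all "stacking blocks" means.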

Self Attention

$$SA(\mathbf{x}, \theta) : \mathbb{R}^{n \times d} \to \mathbb{R}^{n \times d}$$

The attention mechanism is something like a weighted sum. I actually have no idea which part is officially called the attention and which isn’t, because I really don’t care either way (good news for you: this is definitely nonstandard notation). We’re going to calculate the output of this layer and ignore the rest of the transformer block for now, so let’s just talk about $z = SA(\mathbf{x}, \theta) \in \mathbb{R}^{n \times d},$ whose $i$-th row $z_i$ is the output for token $i.$

We first have three feedforward maps $Q, K, V : \mathbb{R}^{d} \to \mathbb{R}^{d}$ (in the standard transformer these are just linear layers). These will be used to spice up the attention calculation. In my interpretation, $Q(x_i)$ is the query of $x_i$: what token $i$ wants to match against, while $K(x_j)$ is the key of $x_j$: what token $j$ has to offer. If this sounds like something Karpathy said, it’s because it is. $V(x_j)$ is the value of $x_j$: the content of this token, which gets scaled by how much attention there is between the two tokens. Our output is more or less (with heavy asterisks) $$ z_i = \sum_{j=1}^n \langle Q(x_i), K(x_j) \rangle V(x_j), $$ where $\langle \cdot, \cdot \rangle : \mathbb{R}^d \times \mathbb{R}^d \to \mathbb{R}$ is our attention mechanism. The only relevant one (ignorance is bliss) is scaled dot product attention, with $$ \langle a_i, b_j \rangle = \text{softmax}\left(\frac{1}{\sqrt{d}} a_i^\top b_j\right). $$

Careful readers will note that this makes no sense and is an extreme abuse of notation, since softmax returns a distribution and we’re expecting a scalar. You’re absolutely right, but for a single token, we will normalize its attention to sum to 1 and treat it as a probability distribution. We basically want to enforce that for a fixed $k,$ $$ \sum_{i=1}^n \langle Q(x_k), K(x_i) \rangle = 1. $$ What this should really say is that each token $x_k$ gets an attention vector, let’s call it $a_k \in \mathbb{R}^n.$ We compute $$ a_k = \text{softmax}\left(\frac{1}{\sqrt{d}} Q(x_k)^\top K(x_1), \ldots, \frac{1}{\sqrt{d}} Q(x_k)^\top K(x_n)\right), $$ and our output $z_i$ is given by $$ z_i = \sum_{j=1}^n a_{i,j} V(x_j). $$
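Written out one token at a time, the two formulas for $a_k$ and $z_i$ look like the NumPy sketch below. It assumes $Q, K, V$ are linear maps given by hypothetical matrices `Wq`, `Wk`, `Wv` (so $Q(x_i)$ is `x[i] @ Wq`), and it deliberately uses explicit loops so each line matches a symbol in the math. It returns both $z$ and the stacked attention vectors.

```python
import numpy as np

def softmax(v):
    e = np.exp(v - v.max())  # shift by the max for numerical stability
    return e / e.sum()

def attention_loop(x, Wq, Wk, Wv):
    # x: (n, d); Q, K, V assumed linear: Q(x_i) = x_i @ Wq, etc.
    n, d = x.shape
    z = np.zeros((n, Wv.shape[1]))
    A = np.zeros((n, n))  # row i will hold the attention vector a_i
    for i in range(n):
        q_i = x[i] @ Wq  # Q(x_i)
        # Q(x_i)^T K(x_j) / sqrt(d) for every j, then softmax into a_i
        scores = np.array([q_i @ (x[j] @ Wk) for j in range(n)]) / np.sqrt(d)
        A[i] = softmax(scores)
        # z_i = sum_j a_{i,j} V(x_j)
        z[i] = sum(A[i, j] * (x[j] @ Wv) for j in range(n))
    return z, A

rng = np.random.default_rng(0)
n, d = 4, 6
x = rng.normal(size=(n, d))
Wq, Wk, Wv = (rng.normal(size=(d, d)) for _ in range(3))
z, A = attention_loop(x, Wq, Wk, Wv)
```

Each row of `A` sums to 1, which is exactly the "normalize a single token’s attention" condition. A handy sanity check: if all scores are equal (for instance with `Wq` set to zeros), every $a_i$ is uniform and every $z_i$ is just the average of the values.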

Note that we can just precompute every $\langle Q(x_i), K(x_j) \rangle$ and put it into a matrix, which gives us the attention matrix $A \in \mathbb{R}^{n \times n},$ with $A_{i,j} = \langle Q(x_i), K(x_j) \rangle$ (so row $i$ of $A$ is $a_i$). Then, if we have a matrix $B \in \mathbb{R}^{n \times d}$ whose rows are the values, $B_{i,j} = V(x_i)_j,$ the matmul $A B$ gives the whole output at once, since its $i$-th row is $\sum_{j} A_{i,j} V(x_j) = z_i.$ That is one big matrix multiplication, which is exactly what GPUs are good at.
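A quick NumPy check that the matmul really is the same as the per-token weighted sums. Here `A` is a hypothetical stand-in attention matrix (random rows normalized to sum to 1) rather than one computed from queries and keys, since only its row-stochastic shape matters for this identity.

```python
import numpy as np

rng = np.random.default_rng(1)
n, d = 4, 6
# Stand-in attention matrix: any n x n matrix whose rows are distributions
A = rng.random(size=(n, n))
A = A / A.sum(axis=1, keepdims=True)
# B holds the values: row i is V(x_i)
B = rng.normal(size=(n, d))

# Per-token weighted sums, z_i = sum_j A[i, j] * V(x_j)
z_loop = np.stack([sum(A[i, j] * B[j] for j in range(n)) for i in range(n)])
# The same thing as a single matrix multiplication
z_matmul = A @ B
```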

To actually get $A,$ since everything is a dot product, I’m pretty sure we can use a matrix multiplication (yes we can). By our definition above, we want that (modulo the softmax) $$A_{i,j} = \frac{1}{\sqrt{d}}Q(x_i)^\top K(x_j).$$ This is easy enough: we’ll just make matrices $C \in \mathbb{R}^{n \times d}$ and $D \in \mathbb{R}^{n \times d}$ whose $i$-th rows are $Q(x_i)$ and $K(x_i)$ respectively. Then $(C D^\top)_{i,j}$ is the dot product of row $i$ of $C$ with row $j$ of $D,$ which is exactly $Q(x_i)^\top K(x_j)$ (the other order, $D C^\top,$ would give the transpose, with the roles of queries and keys swapped). We scale by $\frac{1}{\sqrt{d}}$ and take a softmax along each row. Thus, the full self attention is $$ SA(\mathbf{x}, \theta) = \text{softmax}\left( \frac{C D^\top}{\sqrt{d}}\right) B. $$
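To settle the "which order" question numerically, here is a small NumPy sketch with random stand-in matrices `C`, `D`, `B` (rows playing the roles of $Q(x_i)$, $K(x_j)$, $V(x_j)$). It checks one entry of $C D^\top$ against the dot-product definition, confirms that $D C^\top$ is the transpose, and then evaluates the full formula with a row-wise softmax.

```python
import numpy as np

rng = np.random.default_rng(2)
n, d = 5, 6
C = rng.normal(size=(n, d))  # row i plays the role of Q(x_i)
D = rng.normal(size=(n, d))  # row j plays the role of K(x_j)
B = rng.normal(size=(n, d))  # row j plays the role of V(x_j)

S = C @ D.T / np.sqrt(d)  # S[i, j] = Q(x_i)^T K(x_j) / sqrt(d)
# Check one entry against the dot-product definition
i, j = 1, 3
assert np.isclose(S[i, j], C[i] @ D[j] / np.sqrt(d))
# The other order gives the transpose (queries and keys swapped)
assert np.allclose(D @ C.T / np.sqrt(d), S.T)

# Row-wise softmax: row i is token i's attention distribution
P = np.exp(S - S.max(axis=1, keepdims=True))
P = P / P.sum(axis=1, keepdims=True)
out = P @ B  # softmax(C D^T / sqrt(d)) B, shape (n, d)
```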

In Python (with PyTorch), this looks like

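```python
import math
import torch

def self_attention(x, theta):
    # x: (n, d) tensor; theta = (Q, K, V), each a row-wise map R^d -> R^d
    Q, K, V = theta
    d = x.shape[-1]
    C = Q(x)  # queries, shape (n, d)
    D = K(x)  # keys, shape (n, d)
    A = torch.softmax(C @ D.T / math.sqrt(d), dim=-1)  # softmax along each row, (n, n)
    return A @ V(x)  # weighted sum of the values, shape (n, d)
```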

where x is a tensor of shape $n \times d,$ and theta is a tuple of three functions, each mapping $\mathbb{R}^{d} \to \mathbb{R}^{d}$ and applied to every row of x (for example, three torch.nn.Linear(d, d) layers).

Multihead Attention

This gives us single head attention. To get multihead attention, we run $h$ heads in parallel, where each head typically works in dimension $d/h$ instead of $d,$ and at the end we concatenate the heads’ outputs and project them with a linear layer. Apparently (Prince) this is important in giving comprehensible results, so I buy it. It can similarly be implemented with batched matmuls, except now we need to reshape the matrices a bit more.
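The "split $d$ into $h$ pieces, then concatenate back" bookkeeping is just a reshape. A minimal NumPy sketch with hypothetical sizes, where `C` stands in for an $n \times d$ matrix of projected queries:

```python
import numpy as np

n, d, h = 4, 8, 2  # hypothetical sizes; d must be divisible by h
d_head = d // h
C = np.arange(n * d, dtype=float).reshape(n, d)  # stand-in projected queries, (n, d)

# Split the feature dimension into h chunks of size d/h: (n, d) -> (h, n, d/h)
C_heads = C.reshape(n, h, d_head).transpose(1, 0, 2)
# Head i sees columns i*d_head : (i+1)*d_head of every row
assert np.array_equal(C_heads[1], C[:, d_head:2 * d_head])

# Concatenating the heads along the feature axis undoes the split
C_back = np.concatenate(list(C_heads), axis=1)
```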

Going forward, we’re going to adopt PyTorch notation, where we represent lists of matrices as tensors, so $\mathbb{R}^{k \times n \times d}$ now represents a list of $k$ matrices of size $n \times d.$ When we multiply tensors $A \in \mathbb{R}^{k \times n \times d}$ and $B \in \mathbb{R}^{k \times d \times m},$ we get a tensor $C \in \mathbb{R}^{k \times n \times m},$ where $C_{b,i,j} = \sum_{l=1}^d A_{b,i,l} B_{b,l,j}.$ Copilot wrote that last formula for me; the way I understand it better is that each $C_b$ is a matrix in $\mathbb{R}^{n \times m},$ and $C_b = A_b B_b.$
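NumPy’s `@` (`np.matmul`) follows the same convention: for 3-D inputs it treats the leading axis as a batch and multiplies matching slices. This sketch checks that against both readings above, the slice-by-slice loop $C_b = A_b B_b$ and the index formula (via `np.einsum`).

```python
import numpy as np

rng = np.random.default_rng(3)
k, n, d, m = 3, 4, 5, 2
A = rng.normal(size=(k, n, d))
B = rng.normal(size=(k, d, m))

C = A @ B  # batched over the leading axis: shape (k, n, m)
# Slice by slice: C[b] = A[b] @ B[b]
C_loop = np.stack([A[b] @ B[b] for b in range(k)])
# Index formula: C[b, i, j] = sum_l A[b, i, l] * B[b, l, j]
C_einsum = np.einsum("bil,blj->bij", A, B)
```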

Moving to multihead attention, we’ll now take $d' = d/h.$ Each head $i$ gets its own projections $Q_i, K_i, V_i : \mathbb{R}^d \to \mathbb{R}^{d'},$ so we need $3h$ of them (if they’re linear, each is a $d \times d'$ matrix and the $h$ of them together hold $d \times d$ parameters, so the number of parameters is still the same). Applying them to every row of $\mathbf{x}$ gives $C_i = Q_i(\mathbf{x}),$ $D_i = K_i(\mathbf{x}),$ $B_i = V_i(\mathbf{x}) \in \mathbb{R}^{n \times d'},$ and stacking over heads gives $C, D, B \in \mathbb{R}^{h \times n \times d'}.$ We’ll also need a linear layer to project the concatenated outputs back, but this is just a row-wise layer $f : \mathbb{R}^{n \times d} \to \mathbb{R}^{n \times d}$ (a single $d \times d$ matrix). So our calculation is $$ SA(\mathbf{x}, \theta) = f\left(\text{concat}\left(\text{head}_1, \ldots, \text{head}_h\right)\right), $$ where $\text{head}_i \in \mathbb{R}^{n \times d'}$ is given by $$ \text{head}_i = \text{softmax}\left(\frac{C_i D_i^\top}{\sqrt{d'}}\right) B_i, $$ and the concatenation is along the feature dimension, giving back an $n \times d$ matrix.
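Here is one way to do all heads at once with batched matmuls, as a NumPy sketch. It assumes linear projections stored as hypothetical $d \times d$ matrices `Wq`, `Wk`, `Wv`: columns `i*d' : (i+1)*d'` of `Wq` are head $i$’s $Q_i,$ so projecting once and then splitting is the same as applying the $h$ separate $d \times d'$ projections. `Wo` is the output layer $f.$

```python
import numpy as np

def softmax_rows(s):
    s = s - s.max(axis=-1, keepdims=True)
    e = np.exp(s)
    return e / e.sum(axis=-1, keepdims=True)

def multihead_attention(x, Wq, Wk, Wv, Wo, h):
    # x: (n, d); Wq, Wk, Wv: (d, d) = h stacked (d, d') per-head projections
    # Wo: (d, d), the output projection f
    n, d = x.shape
    dh = d // h

    def split(M):  # (n, d) -> (h, n, d')
        return M.reshape(n, h, dh).transpose(1, 0, 2)

    C, D, B = split(x @ Wq), split(x @ Wk), split(x @ Wv)
    # batched over heads: (h, n, d') @ (h, d', n) -> (h, n, n)
    A = softmax_rows(C @ D.transpose(0, 2, 1) / np.sqrt(dh))
    heads = A @ B  # (h, n, d')
    concat = heads.transpose(1, 0, 2).reshape(n, d)  # concat(head_1, ..., head_h)
    return concat @ Wo

rng = np.random.default_rng(4)
n, d, h = 5, 8, 2
x = rng.normal(size=(n, d))
Wq, Wk, Wv, Wo = (rng.normal(size=(d, d)) for _ in range(4))
out = multihead_attention(x, Wq, Wk, Wv, Wo, h)
```

With `h=1` and `Wo` set to the identity, this reduces to the single head formula $\text{softmax}(C D^\top / \sqrt{d}) B$ from before.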

In python (ew), this is

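```python
import math
import torch

def multihead_attention(x, theta, h):
    # x: (n, d); Q, K, V, f: row-wise maps R^d -> R^d (e.g. torch.nn.Linear(d, d))
    Q, K, V, f = theta
    n, d = x.shape
    d_head = d // h  # d' = d / h
    # project, then split the features into h heads: (n, d) -> (h, n, d')
    C = Q(x).view(n, h, d_head).transpose(0, 1)
    D = K(x).view(n, h, d_head).transpose(0, 1)
    B = V(x).view(n, h, d_head).transpose(0, 1)
    heads = []
    for i in range(h):
        head_i = torch.softmax(C[i] @ D[i].T / math.sqrt(d_head), dim=-1) @ B[i]
        heads.append(head_i)
    return f(torch.cat(heads, dim=1))  # concatenate back to (n, d), then project
```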

This is ridiculous code (a Python loop over heads where one batched matmul would do), but I’m not bringing any more torch into this. We’re basically projecting our list of embeddings into queries, keys, and values, slicing each of those into $h$ pieces, running the groups of pieces through their own attention computation, and then putting them back together.

Attention Masking

For decoder-only next token prediction models like GPT, we are trying to predict the next token given the previous tokens. But above, we totally cheated: the self attention attends to all tokens in the sequence, including future ones! No fear, we’ll just zero out the attention to future tokens with an attention mask. Given our matrix $$ A = \frac{C D^\top}{\sqrt{d}}$$ (the attention matrix before softmax), we go through and, for each $i < j$ (token $i$ looking at a later token $j$), set $A_{i,j} = -\infty.$ This zeroes those entries out after the softmax, since $\exp(-\infty) = 0,$ and each row still sums to 1. Problem solved. Does this actually train what we want it to do? The evidence says yes, but I haven’t actually looked at the conditional loss here, so I can’t help you with that.
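A minimal NumPy sketch of the causal mask, again assuming single-head attention with hypothetical linear weight matrices `Wq`, `Wk`, `Wv`. The entries with $j > i$ (strictly above the diagonal) are set to $-\infty$ before the row-wise softmax. The diagonal is never masked, so every row keeps at least one finite score and the max-shift stays finite.

```python
import numpy as np

def causal_self_attention(x, Wq, Wk, Wv):
    n, d = x.shape
    S = (x @ Wq) @ (x @ Wk).T / np.sqrt(d)
    # Mask the future: entries with j > i (strictly above the diagonal)
    future = np.triu(np.ones((n, n), dtype=bool), k=1)
    S = np.where(future, -np.inf, S)
    # Row-wise softmax; exp(-inf) = 0, so masked entries get zero weight
    S = S - S.max(axis=1, keepdims=True)
    P = np.exp(S)
    P = P / P.sum(axis=1, keepdims=True)
    return P @ (x @ Wv), P

rng = np.random.default_rng(5)
n, d = 5, 4
x = rng.normal(size=(n, d))
Wq, Wk, Wv = (rng.normal(size=(d, d)) for _ in range(3))
z, P = causal_self_attention(x, Wq, Wk, Wv)
```

One way to see the mask doing its job: changing the last token leaves the outputs for every earlier position untouched, and the first token can only attend to itself.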

Epilogue

Thank you for reading this far. This was a long-winded road taking many liberties in notation, but this is how I understand transformers (for the worse). It is frightening how much of this was written by pressing tab and accepting a Copilot suggestion. Some day (soon) I’ll write up more with batching, some more implementation, and how to make this faster. I sure love catching up to SOTA in 2021 in 3 weeks, although with the state of open source models (Starcoder2, DeepSeek, Llama2), we’re just catching up to OpenAI in 2022 anyways, so maybe I’m not that behind.

Built with Pollen and Racket, inspired by
  Eric Zhang