I spent the morning trying to understand Self-attention. Here's what I got:

Carlos Gruss

October 2024

Let's say we have an input consisting of a certain sequence of vectors. We can organize these into a matrix \( X \):

$$ X = \begin{pmatrix} \vec{E}_1 \\ \vec{E}_2 \\ \vdots \\ \vec{E}_N \end{pmatrix} $$

where \( N \) is the length of the input sequence and each \( \vec{E}_i \in \mathbb{R}^{d_{\text{model}}} \) is the embedding of the \( i \)th input, written as a row of \( X \). An embedding is just a vector representation of whatever input we originally had (for example, in an NLP application, a token embedding obtained with word2vec, or in a Vision Transformer the output of a convolutional layer). Note that \( d_{\text{model}} \) is generally large: in GPT-3, for example, \( d_{\text{model}} = 12{,}288 \).

We then want to linearly project each of these embeddings into a smaller space of dimension \( d_k \) (for GPT-3, \( d_k = 128 \)). We can do this by performing a set of matrix multiplications:

$$ \begin{align*} Q &= X W_Q \\ K &= X W_K \\ V &= X W_V \end{align*} $$

where each of the projection matrices \( W_Q, W_K, W_V \) has dimensions \( d_{\text{model}} \times d_k \), and their values are learned during training. \( Q, K, V \) stand for Query, Key, and Value respectively. We thus obtain the matrices:

$$ Q = \begin{pmatrix} \vec{Q}_1 \\ \vec{Q}_2 \\ \vdots \\ \vec{Q}_N \end{pmatrix} $$

$$ K = \begin{pmatrix} \vec{K}_1 \\ \vec{K}_2 \\ \vdots \\ \vec{K}_N \end{pmatrix} $$

$$ V = \begin{pmatrix} \vec{V}_1 \\ \vec{V}_2 \\ \vdots \\ \vec{V}_N \end{pmatrix} $$

where for instance we have \( \vec{Q}_i = \vec{E}_i W_Q \).
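As a quick sanity check of the projection step, here is a minimal sketch in plain Python (lists of rows, no libraries). The sizes \( N = 3 \), \( d_{\text{model}} = 4 \), \( d_k = 2 \) and every number in `X` and `W_Q` are made up for illustration; in a real model `W_Q` would be learned. The helper `matmul` is my own, not from any library.

```python
def matmul(a, b):
    """Multiply two matrices stored as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]

# Toy input (assumed values): N = 3 embeddings of size d_model = 4.
X = [
    [1.0, 0.0, 2.0, 0.0],  # E_1
    [0.0, 1.0, 0.0, 1.0],  # E_2
    [1.0, 1.0, 1.0, 1.0],  # E_3
]

# Toy projection matrix (assumed values), shape d_model x d_k = 4 x 2.
W_Q = [
    [1.0, 0.0],
    [0.0, 1.0],
    [0.5, 0.0],
    [0.0, 0.5],
]

Q = matmul(X, W_Q)  # shape N x d_k = 3 x 2

# Row i of Q is just E_i projected on its own: Q_i = E_i W_Q.
Q_1 = matmul([X[0]], W_Q)[0]
print(Q)
print(Q_1)
```

The same helper would give \( K \) and \( V \) from their own weight matrices.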

Put simply, given a Query vector, we want to find the Key vector or vectors that most resemble it. To do this, we take the dot product of each Query/Key pair, which gives a measure of similarity between the two vectors. For vectors of comparable length, a larger dot product indicates that they are more closely aligned (i.e., pointing in a similar direction) in the \( d_k \)-dimensional space.

$$ QK^\top = \begin{pmatrix} \vec{Q}_1 \cdot \vec{K}_1 & \vec{Q}_1 \cdot \vec{K}_2 & \cdots & \vec{Q}_1 \cdot \vec{K}_N \\ \vec{Q}_2 \cdot \vec{K}_1 & \vec{Q}_2 \cdot \vec{K}_2 & \cdots & \vec{Q}_2 \cdot \vec{K}_N \\ \vdots & \vdots & \ddots & \vdots \\ \vec{Q}_N \cdot \vec{K}_1 & \vec{Q}_N \cdot \vec{K}_2 & \cdots & \vec{Q}_N \cdot \vec{K}_N \end{pmatrix} $$

We can observe that in the \( i \)th row we're trying to "match" \( \vec{Q}_i \) to each of the Key vectors we have, by performing the dot products between \( \vec{Q}_i \) and each of the \( \vec{K}_j \)s. We'll have large values where \( \vec{Q}_i \) and \( \vec{K}_j \) are closely aligned and small or negative values where they're not.
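To make the "matching" concrete, the sketch below computes the score matrix entry by entry. The values in `Q` and `K` are invented for illustration (with \( N = 3 \), \( d_k = 2 \)), and `dot` is a small helper of my own.

```python
def dot(u, v):
    """Dot product of two equal-length vectors."""
    return sum(a * b for a, b in zip(u, v))

# Toy queries and keys (assumed values), one row per position.
Q = [[2.0, 0.0], [0.0, 1.5], [1.5, 1.5]]
K = [[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0]]

# scores[i][j] = Q_i . K_j, i.e. the (i, j) entry of Q K^T.
scores = [[dot(q, k) for k in K] for q in Q]

for row in scores:
    print(row)
```

In the first row, \( \vec{Q}_1 \) points the same way as \( \vec{K}_1 \) (large positive score), is orthogonal to \( \vec{K}_2 \) (zero), and points opposite to \( \vec{K}_3 \) (negative).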

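For numerical stability, we divide the whole matrix by the scalar \( \sqrt{d_k} \) (the dot products tend to grow in magnitude with \( d_k \), which would push the softmax towards nearly one-hot rows with very small gradients), and then apply a row-wise softmax: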

$$ A = \text{Softmax}\left( \frac{QK^\top}{\sqrt{d_k}} \right) $$

This matrix \( A \) is a square \( N \times N \) matrix where the \( i \)th row is a probability distribution expressing the importance of each of the \( N \) keys with respect to the \( i \)th query. In other words, \( A_{ij} \) represents how much attention the \( i \)th position pays to the \( j \)th position.
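Here is a small sketch of the scaling and row-wise softmax, again in plain Python. The `scores` values are made up, and subtracting the row maximum inside `softmax` is a standard stability trick that I've added; it doesn't change the result.

```python
import math

def softmax(row):
    """Softmax of one row; subtracting the max avoids overflow in exp."""
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]

def attention_weights(scores, d_k):
    """Scale every score by 1/sqrt(d_k), then apply softmax to each row."""
    scale = math.sqrt(d_k)
    return [softmax([s / scale for s in row]) for row in scores]

# Toy score matrix Q K^T (assumed values), N = 3, d_k = 2.
scores = [[2.0, 0.0, -2.0], [0.0, 1.5, 0.0], [1.5, 1.5, -1.5]]
A = attention_weights(scores, d_k=2)

for row in A:
    print([round(a, 3) for a in row], "sum =", round(sum(row), 6))
```

Each row sums to 1, and larger scores get larger weights, so the first position attends mostly to itself here.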

The final step is integrating \( V \), which we haven't used yet. Conceptually, we can imagine the value vectors hold the actual information we want to pay attention to when computing the output. So in \( V \) we have a sequence of encoded input information, while in \( A \) we have the importance of each input \( j \) when computing the output \( i \). To combine the two, we compute the product:

$$ Z = A V = \text{Softmax}\left( \frac{QK^\top}{\sqrt{d_k}} \right) V $$

To break down this matrix product, note that the \( i \)th row of \( Z \) is computed as:

$$ \vec{Z}_i = \sum_{j=1}^N A_{ij} \vec{V}_j $$

So each output \( \vec{Z}_i \) is incorporating information from each input \( \vec{V}_j \) weighted by how much the output \( i \) should pay attention to the input \( j \).
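Putting the pieces together, this is a minimal sketch of the whole single-head computation \( Z = \text{Softmax}(QK^\top / \sqrt{d_k}) V \). It assumes \( Q \), \( K \), \( V \) are already projected; all numbers are invented, and the helper names (`matmul`, `softmax`, `self_attention`) are mine.

```python
import math

def matmul(a, b):
    """Multiply two matrices stored as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]

def softmax(row):
    """Numerically stable softmax of one row."""
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]

def self_attention(Q, K, V):
    """Return (Z, A) with A = softmax(Q K^T / sqrt(d_k)) and Z = A V."""
    d_k = len(Q[0])
    K_T = [list(col) for col in zip(*K)]
    scores = matmul(Q, K_T)
    A = [softmax([s / math.sqrt(d_k) for s in row]) for row in scores]
    return matmul(A, V), A

# Toy inputs (assumed values), N = 3, d_k = 2.
Q = [[2.0, 0.0], [0.0, 1.5], [1.5, 1.5]]
K = [[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0]]
V = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]

Z, A = self_attention(Q, K, V)

# Row i of Z is the weighted sum of the value vectors: sum_j A[i][j] * V_j.
Z_1_by_hand = [sum(A[0][j] * V[j][c] for j in range(3)) for c in range(2)]
print(Z[0], Z_1_by_hand)
```

Because each row of \( A \) is a probability distribution, every \( \vec{Z}_i \) is a weighted average of the value vectors.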

Observe that, starting from an input \( X \in \mathbb{R}^{N \times d_{\text{model}}} \), we're obtaining a matrix \( Z \in \mathbb{R}^{N \times d_k} \). Each row of this matrix is a vector \( \vec{Z}_i \) which is a contextually enriched representation of the input embedding \( \vec{E}_i \). This means that \( \vec{Z}_i \) not only contains information about the \( i \)th input but also incorporates relevant information from other positions in the sequence, as determined by the attention weights \( A_{ij} \).
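For bookkeeping, here is how the shapes line up at each step, using only the symbols defined so far (the same pattern holds for \( K \) and \( V \) as for \( Q \)):

$$ \begin{align*} \underbrace{X}_{N \times d_{\text{model}}} \, \underbrace{W_Q}_{d_{\text{model}} \times d_k} &= \underbrace{Q}_{N \times d_k} \\ \underbrace{Q}_{N \times d_k} \, \underbrace{K^\top}_{d_k \times N} &= \underbrace{QK^\top}_{N \times N} \\ \underbrace{A}_{N \times N} \, \underbrace{V}_{N \times d_k} &= \underbrace{Z}_{N \times d_k} \end{align*} $$

The softmax and the division by \( \sqrt{d_k} \) act element-wise or row-wise, so they leave the \( N \times N \) shape of the score matrix unchanged.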

These vectors can then be used for further processing within the model or directly for tasks such as classification, translation, or sequence generation, depending on the application.

In practice, the self-attention mechanism is often extended to multi-head attention. Instead of computing attention once, the model computes it multiple times in parallel, each with different learned projection matrices. This allows the model to capture information from different representation subspaces.

For each of the \( H \) heads, we perform the self-attention operations:

$$ \begin{align*} Q^{(h)} &= X W_Q^{(h)} \\ K^{(h)} &= X W_K^{(h)} \\ V^{(h)} &= X W_V^{(h)} \end{align*} $$

$$ Z^{(h)} = \text{Softmax}\left( \frac{Q^{(h)} {K^{(h)}}^\top}{\sqrt{d_k}} \right) V^{(h)} $$

We then concatenate the outputs from each head:

$$ Z = \text{Concat}\left( Z^{(1)},\ Z^{(2)},\ \dots,\ Z^{(H)} \right) $$

This approach allows the model to attend to multiple aspects of the input simultaneously, enhancing its ability to capture complex patterns and dependencies.
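The multi-head version can be sketched the same way: run the single-head computation once per head with its own weights, then glue the outputs together along the feature axis. Everything below is illustrative. The sizes (\( N = 3 \), \( d_{\text{model}} = 4 \), \( d_k = 2 \), \( H = 2 \)) are made up, the weights are random stand-ins for learned ones, and the helper names are mine. Note that the concatenated output has \( H \cdot d_k \) columns; with the GPT-3 figures quoted earlier and its reported 96 heads, that is \( 96 \times 128 = 12{,}288 = d_{\text{model}} \).

```python
import math
import random

def matmul(a, b):
    """Multiply two matrices stored as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]

def softmax(row):
    """Numerically stable softmax of one row."""
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]

def attention_head(X, W_Q, W_K, W_V):
    """Single-head self-attention: project X, then softmax(Q K^T / sqrt(d_k)) V."""
    Q, K, V = matmul(X, W_Q), matmul(X, W_K), matmul(X, W_V)
    d_k = len(Q[0])
    scores = matmul(Q, [list(col) for col in zip(*K)])
    A = [softmax([s / math.sqrt(d_k) for s in row]) for row in scores]
    return matmul(A, V)

def multi_head_attention(X, heads):
    """Run every head on X and concatenate the per-position outputs."""
    outputs = [attention_head(X, W_Q, W_K, W_V) for W_Q, W_K, W_V in heads]
    # Row i of the result is Z_i^(1) followed by Z_i^(2), ..., Z_i^(H).
    return [[v for Z_h in outputs for v in Z_h[i]] for i in range(len(X))]

rng = random.Random(0)  # fixed seed so the toy example is reproducible

def random_matrix(rows, cols):
    """Random stand-in for learned weights or embeddings (illustration only)."""
    return [[rng.uniform(-1, 1) for _ in range(cols)] for _ in range(rows)]

N, d_model, d_k, H = 3, 4, 2, 2
X = random_matrix(N, d_model)
heads = [tuple(random_matrix(d_model, d_k) for _ in range(3)) for _ in range(H)]

Z = multi_head_attention(X, heads)
print(len(Z), "x", len(Z[0]))  # N x (H * d_k)
```

The heads never interact before the concatenation, so in a real implementation they can be computed in parallel.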