
Commit c0b7397

Merge pull request #114 from ricj/master
attention preliminary
2 parents 88989e5 + 66c7412 commit c0b7397

_pages/dat450/assignment2.md

Lines changed: 57 additions & 7 deletions
@@ -56,26 +56,76 @@ Create an untrained MLP layer. Create some 3-dimensional tensor where the last d
### Normalization

To stabilize gradients during training, deep learning models with many layers often include some *normalization* (such as batch normalization or layer normalization). Transformers typically include normalization layers at several places in the stack.

OLMo 2 uses a type of normalization called [Root Mean Square layer normalization](https://arxiv.org/pdf/1910.07467).

You can either implement your own normalization layer, or use the built-in [`RMSNorm`](https://docs.pytorch.org/docs/stable/generated/torch.nn.RMSNorm.html) from PyTorch. In the PyTorch implementation, `eps` corresponds to `rms_norm_eps` from our model configuration, while `normalized_shape` should be equal to the hidden layer size. The hyperparameter `elementwise_affine` should be set to `True`, meaning that we include some learnable weights in this layer instead of a pure normalization.

If you want to make your own layer, the PyTorch documentation shows the formula you should implement. (The $\gamma_i$ parameters are the learnable weights.)
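As an illustration, a minimal sketch of such a layer could look like the following. (This is our own example rather than the exact OLMo 2 code; the class name `MyRMSNorm` is made up, and the `eps` argument is where `rms_norm_eps` from the configuration goes.)

```
import torch
import torch.nn as nn

class MyRMSNorm(nn.Module):
    """RMS layer normalization: x_i / RMS(x) * gamma_i, with learnable gamma."""
    def __init__(self, hidden_size, eps):
        super().__init__()
        self.eps = eps   # pass rms_norm_eps from the model configuration here
        # The learnable per-dimension weights (the gamma_i in the formula).
        self.weight = nn.Parameter(torch.ones(hidden_size))

    def forward(self, x):
        # Root mean square over the last (hidden) dimension.
        rms = torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return x / rms * self.weight
```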
**Sanity check.**

You can test this in the same way as you tested the MLP previously.
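For example, a quick check along these lines should run without errors (the sizes here are made up; we use the built-in `RMSNorm`, but your own layer should work the same way):

```
import torch

hidden_size = 64   # use the hidden layer size from your model configuration
norm = torch.nn.RMSNorm(hidden_size, eps=1e-5)   # eps: rms_norm_eps from the configuration

x = torch.randn(2, 5, hidden_size)   # (batch size, text length, hidden size)
out = norm(x)
print(out.shape)   # should be the same shape as the input: (2, 5, 64)
```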
### Multi-head attention

Now, let's turn to the tricky part!

The smaller versions of the OLMo 2 model, which we will follow here, use the same implementation of *multi-head attention* as the original Transformer, plus a couple of additional normalizers. (The bigger OLMo 2 models use [grouped-query attention](https://sebastianraschka.com/llms-from-scratch/ch04/04_gqa/) rather than standard MHA; GQA is also used in various Llama and Qwen models and in some other popular LLMs.)

The figure below shows what we will have to implement.

**Hyperparameters:** The hyperparameters you will need to consider when implementing the MHA are `hidden_size`, which defines the input dimensionality as in the MLP and normalizer above, and `num_attention_heads`, which defines the number of attention heads. **Note** that `hidden_size` has to be evenly divisible by `num_attention_heads`. (Below, we will refer to `hidden_size // num_attention_heads` as the head dimensionality $d_h$.)

**Defining MHA components.** In `__init__`, define the `nn.Linear` components (square matrices) that compute query, key, and value representations, and the final outputs. (They correspond to what we called $W_Q$, $W_K$, $W_V$, and $W_O$ in [the lecture on Transformers](https://www.cse.chalmers.se/~richajo/dat450/lectures/l4/m4_2.pdf).) OLMo 2 also applies layer normalizers after the query and key representations.
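As a starting point, `__init__` could look roughly like the following sketch. (The class name `MyMultiHeadAttention`, the attribute names such as `config.hidden_size`, `config.num_attention_heads`, and `config.rms_norm_eps`, and the choice `bias=False` are our own assumptions; adapt them to your code skeleton.)

```
import torch.nn as nn

class MyMultiHeadAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        # hidden_size must be evenly divisible by the number of heads.
        assert self.hidden_size % self.num_heads == 0
        self.head_dim = self.hidden_size // self.num_heads

        # The square projection matrices W_Q, W_K, W_V, and W_O.
        self.q_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
        self.k_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
        self.v_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
        self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)

        # Normalizers applied to the query and key representations.
        self.q_norm = nn.RMSNorm(self.hidden_size, eps=config.rms_norm_eps)
        self.k_norm = nn.RMSNorm(self.hidden_size, eps=config.rms_norm_eps)
```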
**MHA computation, step 1.** The `forward` method takes two inputs, `hidden_states` and `position_embedding`.

Continuing to work in `forward`, now compute query, key, and value representations; don't forget the normalizers after the query and key representations.

Now, we need to reshape the query, key, and value tensors so that the individual attention heads are stored separately. Assume your tensors have the shape $(b, m, d)$, where $b$ is the batch size, $m$ the text length, and $d$ the hidden layer size. We now need to reshape and transpose so that we get $(b, n_h, m, d_h)$, where $n_h$ is the number of attention heads and $d_h$ the attention head dimensionality. Your code could be something like the following (apply this to queries, keys, and values):

```
q = q.view(b, m, n_h, d_h).transpose(1, 2)
```

Now apply the RoPE rotations to the query and key representations. Use the utility function `apply_rotary_pos_emb` provided in the code skeleton and just provide the `position_embedding` that you received as an input to `forward`. The utility function returns the modified query and key representations.
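Putting step 1 together, the beginning of `forward` in the hypothetical `MyMultiHeadAttention` class above might look roughly like this. (In particular, the call `apply_rotary_pos_emb(q, k, position_embedding)` is a guess at the utility function's signature; check how it is actually called in your code skeleton.)

```
    def forward(self, hidden_states, position_embedding):
        b, m, _ = hidden_states.shape

        # Query, key, and value projections, with normalizers on queries and keys.
        q = self.q_norm(self.q_proj(hidden_states))
        k = self.k_norm(self.k_proj(hidden_states))
        v = self.v_proj(hidden_states)

        # Split into attention heads: (b, m, d) -> (b, n_h, m, d_h).
        q = q.view(b, m, self.num_heads, self.head_dim).transpose(1, 2)
        k = k.view(b, m, self.num_heads, self.head_dim).transpose(1, 2)
        v = v.view(b, m, self.num_heads, self.head_dim).transpose(1, 2)

        # Apply the RoPE rotations to the queries and keys.
        q, k = apply_rotary_pos_emb(q, k, position_embedding)

        ...  # step 2: the attention computation itself
```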
**Sanity check step 1.**
Create an untrained MHA layer. Create some 3-dimensional tensor where the last dimension has the same size as `hidden_size`, as you did in the previous sanity checks. Apply the MHA layer with what you have implemented so far and make sure it does not crash. (It is common to see errors related to tensor shapes here.)
**MHA computation, step 2.** Now, implement the attention mechanism itself.
We will explain the exact computations in the hint below, but conveniently enough PyTorch's [`scaled_dot_product_attention`](https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html) (with `is_causal=True`) implements everything that we have to do here. Optionally, implement your own solution.
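With the built-in function, the rest of `forward` could then look roughly like this, continuing the sketch from step 1 (this assumes `import torch.nn.functional as F`; merging the heads back together and applying $W_O$ is our reading of what comes next, so adapt it to your own plan):

```
        # q, k, v have shape (b, n_h, m, d_h) after step 1.
        attn_out = F.scaled_dot_product_attention(q, k, v, is_causal=True)

        # Merge the heads back again: (b, n_h, m, d_h) -> (b, m, d).
        attn_out = attn_out.transpose(1, 2).reshape(b, m, self.hidden_size)

        # The final output projection (the W_O matrix).
        return self.o_proj(attn_out)
```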
<details>
<summary><b>Hint</b>: Some advice if you want to implement your own attention.</summary>
<div style="margin-left: 10px; border-radius: 4px; background: #ddfff0; border: 1px solid black; padding: 5px;">
In that case, the <a href="https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html">documentation of the PyTorch implementation</a> includes a piece of code that can give you some inspiration and that you can simplify somewhat.

Assuming your query, key, and value tensors are called $q$, $k$, and $v$, the computations you should carry out are the following. First, we compute the *attention pre-activations* by multiplying the query and key representations, and scaling:

$$
\alpha(q, k) = \frac{q \cdot k^{\top}}{\sqrt{d_h}}
$$

Second, add a *causal mask* to the pre-activations. This mask is necessary for autoregressive (left-to-right) language models: it ensures that the attention heads can only consider tokens before the current one. The mask should have the shape $(m, m)$; its lower triangle including the diagonal should be 0 and the upper triangle $-\infty$. PyTorch's <a href="https://docs.pytorch.org/docs/stable/generated/torch.tril.html"><code>tril</code></a> can be convenient here.

Then apply the softmax to get the attention weights:

$$
A(q, k) = \text{softmax}(\alpha(q, k) + \text{mask})
$$

Finally, multiply the attention weights by the value tensor and return the result:

$$
\text{Attention}(q, k, v) = A(q, k) \cdot v
$$

(A code sketch of these computations follows right after this hint.)
</div>
</details>
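If you instead implement the attention computation yourself, the formulas in the hint translate into something like the following sketch. (The function name is made up, and we build the causal mask with `triu` rather than the `tril` mentioned in the hint; both work. This function would replace the `scaled_dot_product_attention` call above.)

```
import math
import torch

def my_scaled_dot_product_attention(q, k, v):
    # q, k, v have shape (b, n_h, m, d_h).
    m = q.shape[-2]
    d_h = q.shape[-1]

    # Attention pre-activations: scaled dot products between queries and keys.
    scores = q @ k.transpose(-2, -1) / math.sqrt(d_h)   # (b, n_h, m, m)

    # Causal mask: 0 on and below the diagonal, -infinity above it.
    mask = torch.triu(torch.full((m, m), float("-inf"), device=q.device), diagonal=1)

    # Softmax over the masked pre-activations gives the attention weights.
    weights = torch.softmax(scores + mask, dim=-1)

    # Multiply the attention weights by the values.
    # To check this function, compare its output against
    # F.scaled_dot_product_attention(q, k, v, is_causal=True) on a small random example.
    return weights @ v
```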
**Sanity check step 2.**

### The full Transformer block
