`_pages/dat450/assignment2.md`
The figure below shows what we will have to implement.
Continuing to work in `forward`, now compute query, key, and value representations; don't forget the normalizers after the query and key representations.
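As a rough sketch of what this step could look like (the attribute names `W_q`, `W_k`, `W_v`, `q_norm`, `k_norm` and the input name `x` are placeholders; use whatever you defined in your `__init__`):

```
# hypothetical attribute names -- adapt to your own module
q = self.q_norm(self.W_q(x))   # query projection, then the query normalizer
k = self.k_norm(self.W_k(x))   # key projection, then the key normalizer
v = self.W_v(x)                # value projection (no normalizer)
```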
Now, we need to reshape the query, key, and value tensors so that the individual attention heads are stored separately. Assume your tensors have the shape \((b, m, d)\), where \(b\) is the batch size, \(m\) the text length, and \(d\) the hidden layer size. We now need to reshape and transpose so that we get \((b, n_h, m, d_h)\) where \(n_h\) is the number of attention heads and \(d_h\) the attention head dimensionality. Your code could be something like the following (apply this to queries, keys, and values):
```
# reshape (b, m, d) -> (b, m, n_h, d_h), then swap to (b, n_h, m, d_h)
q = q.view(b, m, n_h, d_h).transpose(1, 2)
```
In that case, the <a href="https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html">documentation of the PyTorch implementation</a> includes a piece of code that can give you some inspiration and that you can simplify somewhat.
Assuming your query, key, and value tensors are called \(q\), \(k\), and \(v\), then the computations you should carry out are the following. First, we compute the *attention pre-activations*, which are computed by multiplying the query and key representations and scaling.
The transposition of the key tensor can be carried out by calling <code>k.transpose(-2, -1)</code>.
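Putting this together, a sketch of the computation could look like the following, assuming the scaling factor is \(1/\sqrt{d_h}\) as in standard scaled dot-product attention, and storing the result in a (hypothetically named) tensor `scores`:

```
# attention pre-activations, shape (b, n_h, m, m)
scores = q @ k.transpose(-2, -1) / d_h**0.5
```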
Second, add a *causal mask* to the pre-activations. This mask is necessary for autoregressive (left-to-right) language models: it ensures that the attention heads cannot look at tokens after the current one. The mask should have the shape \((m, m)\); its lower triangle including the diagonal should be 0 and the upper triangle \(-\infty\). PyTorch's <a href="https://docs.pytorch.org/docs/stable/generated/torch.tril.html"><code>tril</code></a> or <a href="https://docs.pytorch.org/docs/stable/generated/torch.triu.html"><code>triu</code></a> can be convenient here.
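For instance, a sketch using `triu`, continuing with the hypothetical `scores` tensor from the sketch above:

```
# strictly upper triangle is -inf, diagonal and below are 0
mask = torch.triu(torch.full((m, m), float('-inf'), device=scores.device), diagonal=1)
scores = scores + mask
```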
Then apply the softmax to get the attention weights.
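For example (again with hypothetical names; `attn_out` is the tensor that the snippet below reshapes):

```
attn_weights = torch.softmax(scores, dim=-1)   # attention weights, shape (b, n_h, m, m)
attn_out = attn_weights @ v                    # apply the weights to the values, shape (b, n_h, m, d_h)
```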
```
# merge the heads again: (b, n_h, m, d_h) -> (b, m, n_h, d_h) -> (b, m, d)
attn_out = attn_out.transpose(1, 2).reshape(b, m, d)
```
Then compute the final output representation (by applying the linear layer we called \(W_O\) above) and return the result.
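A sketch of this last step, assuming the output projection layer is stored in an attribute called `W_o` (adapt the name to your own code):

```
return self.W_o(attn_out)   # apply the output projection W_O and return
```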
**Sanity check steps 2 and 3.**
Once again create an MHA layer for testing and apply it to an input tensor of the same shape as before. Assuming you don't get any crashes here, the output should be of the same shape as the input. If it crashes or your output has the wrong shape, insert `print` statements along the way, or use an editor with step-by-step debugging, to check the shapes at each step.
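A minimal sketch of such a check, with hypothetical names and sizes (adapt the class name, constructor arguments, and dimensions to your own implementation):

```
import torch

# hypothetical constructor -- use your own class and arguments
mha = MultiHeadAttention(hidden_size=512, n_heads=8)
x = torch.randn(2, 16, 512)    # (batch size, text length, hidden size)
out = mha(x)
print(out.shape)               # should be torch.Size([2, 16, 512]), same as the input
```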