Commit 51c6381

Merge pull request #119 from ricj/master
trying to fix bugs
2 parents c745ac7 + 70c649f commit 51c6381

1 file changed: +2 -2 lines changed


_pages/dat450/assignment2.md

Lines changed: 2 additions & 2 deletions
@@ -103,7 +103,7 @@ We will explain the exact computations in the hint below, but conveniently enough
<div style="margin-left: 10px; border-radius: 4px; background: #ddfff0; border: 1px solid black; padding: 5px;">
In that case, the <a href="https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html">documentation of the PyTorch implementation</a> includes a piece of code that can give you some inspiration and that you can simplify somewhat.

-Assuming your query, key, and value tensors are called $$q$$, $$k$$, and $$v$$, the computations you should carry out are the following. First, we compute the <em>attention pre-activations</em>, which are computed by multiplying query and key representations and scaling:
+Assuming your query, key, and value tensors are called \(q\), \(k\), and \(v\), the computations you should carry out are the following. First, we compute the <em>attention pre-activations</em>, which are computed by multiplying query and key representations and scaling:

$$
\alpha(q, k) = \frac{q \cdot k^{\top}}{\sqrt{d_h}}
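For reference, a minimal sketch of how this pre-activation step might look in PyTorch. The example sizes and the explicit heads dimension are assumptions made for illustration; only `q`, `k`, and `d_h` follow the notation above.

```python
import math
import torch

# Assumed illustrative shapes: batch 2, 4 heads, sequence length 5, d_h = 8.
b, n_heads, m, d_h = 2, 4, 5, 8
q = torch.randn(b, n_heads, m, d_h)
k = torch.randn(b, n_heads, m, d_h)

# Attention pre-activations: alpha(q, k) = q k^T / sqrt(d_h),
# one score for every pair of positions in each head.
alpha = q @ k.transpose(-2, -1) / math.sqrt(d_h)   # shape (b, n_heads, m, m)
```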
@@ -131,7 +131,7 @@ $$
```
attn_out = attn_out.transpose(1, 2).reshape(b, m, d)
```
-Then compute the final output representation (by applying the linear layer we called \(W_O\) above) and return the result.
+Then compute the final output representation (by applying the linear layer we called $$W_O$$ above) and return the result.
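A rough, self-contained sketch of these last two steps (merging the heads and applying the output projection). The names `b`, `m`, and `d` match the snippet above; the number of heads, the concrete sizes, and the plain `nn.Linear` used as a stand-in for the layer called W_O are assumptions.

```python
import torch
import torch.nn as nn

# Assumed example sizes: batch b, n_heads heads, sequence length m,
# per-head dimension d_h, model dimension d = n_heads * d_h.
b, n_heads, m, d_h = 2, 4, 5, 8
d = n_heads * d_h

# attn_out as produced by the attention step: (b, n_heads, m, d_h).
attn_out = torch.randn(b, n_heads, m, d_h)

# Merge the heads back into one model-dimensional vector per position,
# as in the snippet above: (b, n_heads, m, d_h) -> (b, m, d).
attn_out = attn_out.transpose(1, 2).reshape(b, m, d)

# Apply the output projection (stand-in for W_O) and return the result.
W_O = nn.Linear(d, d)
output = W_O(attn_out)   # shape (b, m, d)
```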

**Sanity check steps 2 and 3.**
Once again create an MHA layer for testing and apply it to an input tensor of the same shape as before. Assuming you don't get any crashes here, the output should be of the same shape as the input. If it crashes or your output has the wrong shape, insert `print` statements along the way, or use an editor with step-by-step debugging, to check the shapes at each step.
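One possible shape for such a sanity check, using PyTorch's built-in `nn.MultiheadAttention` purely as a stand-in for your own layer; replace the construction and the call with whatever your own implementation uses, and the sizes below are arbitrary examples.

```python
import torch
import torch.nn as nn

# Stand-in for your own MHA layer; adapt the class and arguments to your code.
mha = nn.MultiheadAttention(embed_dim=32, num_heads=4, batch_first=True)

x = torch.randn(2, 5, 32)     # (batch, sequence length, model dimension)
out, _ = mha(x, x, x)         # self-attention: query = key = value = x

# The output should have exactly the same shape as the input.
print(out.shape)              # torch.Size([2, 5, 32])
assert out.shape == x.shape
```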
