Jay Mody Blog
https://jaykmody.com
A blog about things and stuff.
Speculative Sampling
https://jaykmody.com/blog/speculative-sampling/
<p>This post provides an overview, implementation, and time complexity analysis of DeepMind's paper <a href="https://arxiv.org/abs/2302.01318">Accelerating Large Language Model Decoding with Speculative Sampling</a>.</p>
<p>Code for this blog post can be found at <a href="https://github.com/jaymody/speculative-sampling">github.com/jaymody/speculative-sampling</a>.</p>
<p><strong>EDIT (Apr 13th, 2023):</strong> Updated code and time complexity to avoid the extra forward pass of the draft model (credits to <a href="https://github.com/jaymody/speculative-sampling/issues/1">KexinFeng</a>).</p>
<h1 id="autoregressive-sampling" tabindex="-1">Autoregressive Sampling</h1>
<p>The standard way of generating text from a language model is <strong>autoregressive sampling</strong>. Here's the algorithm as defined in the paper:</p>
<figure><img src="https://i.imgur.com/YrLebkI.png" alt="" /></figure>
<p>In code:</p>
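<pre><code class="language-python">import numpy as np

def sample(p):
    # draw a token id from the probability distribution p (shape [vocab_size])
    return np.random.choice(np.arange(p.shape[-1]), p=p)

def autoregressive_sampling(x, model, N):
    n = len(x)
    T = len(x) + N
    while n < T:
        # one forward pass of the model per decoded token
        x = np.append(x, sample(model(x)[-1]))
        n += 1
    return x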
</code></pre>
<p>Where:</p>
<ul>
<li><code>x</code> is a list of integers representing the token ids of the input text</li>
<li><code>model</code> is a language model (like GPT-2) that accepts as input a list of token ids of length <code>seq_len</code> and outputs a matrix of probabilities of shape <code>[seq_len, vocab_size]</code>.</li>
<li><code>N</code> is the number of tokens we want to decode.</li>
</ul>
<p>The time complexity of this algorithm is \(O(N \cdot t_{\text{model}})\):</p>
<ul>
<li>\(N\): The number of iterations of our while loop, which is just the number of tokens to decode \(N\).</li>
<li>\(t_{\text{model}}\): The time complexity of each iteration in the loop, which is just the time taken for a single forward pass of our model \(t_{\text{model}}\).</li>
</ul>
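<p>To make the \(O(N \cdot t_{\text{model}})\) count concrete, here's a small sketch that wraps a toy model in a call counter. The <code>CountingModel</code> class is hypothetical (it just returns uniform probabilities over a made-up vocabulary of 10 tokens); the only point is that decoding \(N\) tokens costs exactly \(N\) forward passes.</p>
<pre><code class="language-python">import numpy as np

class CountingModel:
    # hypothetical toy "language model": uniform probabilities, counts forward passes
    def __init__(self, vocab_size):
        self.vocab_size = vocab_size
        self.calls = 0

    def __call__(self, x):
        self.calls += 1
        return np.full((len(x), self.vocab_size), 1.0 / self.vocab_size)

def sample(p):
    # draw a token id from the probability distribution p
    return np.random.choice(np.arange(p.shape[-1]), p=p)

def autoregressive_sampling(x, model, N):
    n = len(x)
    T = len(x) + N
    while n < T:
        x = np.append(x, sample(model(x)[-1]))
        n += 1
    return x

model = CountingModel(vocab_size=10)
out = autoregressive_sampling(np.array([1, 2, 3]), model, N=5)
print(len(out), model.calls)  # 8 5
</code></pre>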
<h1 id="speculative-sampling" tabindex="-1">Speculative Sampling</h1>
<p>In <strong>speculative sampling</strong>, we have two models:</p>
<ol>
<li>A smaller, faster <strong>draft model</strong> (e.g. DeepMind's 7B Chinchilla model)</li>
<li>A larger, slower <strong>target model</strong> (e.g. DeepMind's 70B Chinchilla model)</li>
</ol>
<p>The idea is that the draft model <em>speculates</em> what the output is \(K\) steps into the future, while the target model determines how many of those tokens we should <em>accept</em>. Here's an outline of the algorithm:</p>
<ol>
<li>The draft model decodes \(K\) tokens in the regular autoregressive fashion.</li>
<li>We get the probability outputs of the target and draft model on the new predicted sequence.</li>
<li>We compare the target and draft model probabilities to determine how many of the \(K\) tokens we want to keep based on a <strong>rejection criterion</strong>. If a token is rejected, we <strong>resample</strong> it using a combination of the two distributions and don't accept any more tokens.</li>
<li>If all \(K\) tokens are accepted, we can sample an additional final token from the target model probability output.</li>
</ol>
<p>As such, instead of decoding a single token at each iteration, speculative sampling decodes between 1 and \(K + 1\) tokens per iteration. If no tokens are accepted, we resample, guaranteeing that at least 1 token is decoded. If all \(K\) tokens are accepted, we also sample a final token from the target model's probability distribution, giving us a total of \(K + 1\) tokens decoded.</p>
<p>For example, consider the common idiom "The apple doesn't fall far from the tree". Given just the first part of the phrase, "The apple doesn't fall", in speculative sampling with \(K=4\):</p>
<ol>
<li>The draft model speculates the output to be "far from the tree" (4 tokens)</li>
<li>The target model looks at those tokens, decides to accept them all, and also samples a final token (e.g. maybe it samples a period ".").</li>
</ol>
<p>As such, in a single iteration, we were able to decode 5 tokens instead of just a single token. However, this may not always be the case. Consider instead the input "Not all heroes":</p>
<ol>
<li>The draft model speculates the output to be "wear capes and hats" (4 tokens)</li>
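<li>The target model looks at those tokens, but decides to accept only the first two, "wear capes", and discard the rest.</li>
</ol>
<p>In this case, only 2 of the draft tokens were accepted. Since "and" was rejected, a replacement token is resampled in its place, so 3 tokens are decoded in total for this iteration.</p>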
<p>As long as the draft model is sufficiently faster than the target model <strong>while also</strong> maintaining a high enough <strong>acceptance rate</strong>, speculative sampling should yield a speedup.</p>
<p>The intuition behind speculative sampling is that certain strings of tokens (common phrases, pronouns, punctuation, etc ...) are fairly easy to predict, so a smaller, less powerful, but faster draft model should be able to quickly predict these instead of having our slower target model do all the work.</p>
<p>Another important property of speculative sampling is that it is <strong>mathematically equivalent</strong> to sampling from the target model, due to the way the rejection criteria and resampling method are designed. The <a href="https://arxiv.org/pdf/2302.01318.pdf#page=10">proof for this is shown in the paper (Theorem 1)</a>.</p>
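<p>We can check this numerically for a single position without working through the proof. The probability of emitting token \(j\) is the chance that the draft proposes \(j\) and we accept it, \(p(j)\min(1, q(j)/p(j)) = \min(p(j), q(j))\), plus the total rejection probability times the resampling distribution (the positive part of \(q - p\), normalized). Below is a minimal sketch with made-up draft (\(p\)) and target (\(q\)) distributions; <code>speculative_token_distribution</code> is a hypothetical helper for illustration, not part of the paper's code.</p>
<pre><code class="language-python">import numpy as np

def max_fn(x):
    # keep only the positive part of x, then renormalize it into a distribution
    x_max = np.where(x > 0, x, 0)
    return x_max / np.sum(x_max)

def speculative_token_distribution(p, q):
    # probability that the draft proposes token j AND it gets accepted:
    # p[j] * min(1, q[j] / p[j]) == min(p[j], q[j])
    accept = np.minimum(p, q)
    # total probability that the proposed draft token is rejected
    reject_mass = 1.0 - accept.sum()
    if reject_mass <= 1e-12:
        return accept  # p == q: every draft token is accepted
    # on rejection, a token is resampled from max_fn(q - p)
    return accept + reject_mass * max_fn(q - p)

p = np.array([0.5, 0.3, 0.2])  # hypothetical draft model distribution
q = np.array([0.2, 0.5, 0.3])  # hypothetical target model distribution
print(speculative_token_distribution(p, q))  # ≈ [0.2 0.5 0.3], i.e. exactly q
</code></pre>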
<p>Finally, speculative sampling requires no changes to the model's architecture, training, or anything like that. It can be used with existing models alongside other inference techniques such as quantization, hardware acceleration, flash attention, etc ... It can also be used with top-p/top-k/temperature.</p>
<p>Here's the full algorithm as defined in the paper:</p>
<figure><img src="https://i.imgur.com/rhR3U46.png" alt="" /></figure>
<p>In code (<a href="https://github.com/jaymody/speculative-sampling">full implementation here</a>):</p>
<pre><code class="language-python">import numpy as np

def max_fn(x):
    # keep only the positive part of x, then renormalize it into a distribution
    x_max = np.where(x > 0, x, 0)
    return x_max / np.sum(x_max)

def speculative_sampling(x, draft_model, target_model, N, K):
    # NOTE: paper indexes arrays starting from 1, python indexes from 0, so
    # we have to add an extra -1 term when indexing using n, T, or t
    # NOTE: sample(p) draws a token id from the distribution p (see the full implementation)
    n = len(x)
    T = len(x) + N

    while n < T:
        # Step 1: auto-regressive decode K tokens from draft model and get final p
        x_draft = x
        for _ in range(K):
            p = draft_model(x_draft)
            x_draft = np.append(x_draft, sample(p[-1]))

        # Step 2: a single target model forward pass on x_draft
        q = target_model(x_draft)

        # Step 3: append draft tokens based on rejection criterion and resample
        # a token on rejection
        all_accepted = True
        for _ in range(K):
            i = n - 1
            j = x_draft[i + 1]
            if np.random.random() < min(1, q[i][j] / p[i][j]):  # accepted
                x = np.append(x, j)
                n += 1
            else:  # rejected
                x = np.append(x, sample(max_fn(q[i] - p[i])))  # resample
                n += 1
                all_accepted = False
                break

        # Step 4: if all draft tokens were accepted, sample a final token
        if all_accepted:
            x = np.append(x, sample(q[-1]))
            n += 1

        # just keeping my sanity
        assert n == len(x), f"{n} {len(x)}"

    return x
</code></pre>
<p>The time complexity for this algorithm is \(O(\frac{N}{r(K + 1)} \cdot (t_{\text{draft}}K + t_{\text{target}}))\).</p>
<ul>
<li>\(\frac{N}{r(K+1)}\): The number of iterations in our while loop. This works out to the number of tokens we want to decode \(N\) divided by the average number of tokens that get decoded per iteration \(r(K + 1)\). The paper doesn't directly report the average number of tokens decoded per iteration. Instead, it provides the acceptance rate \(r\) (which is the average number of tokens decoded per iteration divided by \(K + 1\))<sup class="footnote-ref"><a href="https://jaykmody.com/blog/speculative-sampling/#fn1" id="fnref1">[1]</a></sup>. As such, we can recover the average number of tokens decoded per iteration simply by multiplying \(r\) by \(K + 1\).</li>
<li>\(t_{\text{draft}}K + t_{\text{target}}\): The time complexity for each iteration in the loop. The \(t_{\text{target}}\) term is for the single forward pass of the target model in step 2, and \(t_{\text{draft}}K\) is for the \(K\) forward passes of the draft model in step 1.</li>
</ul>
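<p>To get a feel for how \(r\) relates to the chance of accepting an individual draft token, here's a sketch under a simplifying assumption that is <em>not</em> from the paper: each draft token is accepted independently with the same probability \(\alpha\). Under that assumption, an iteration decodes the accepted draft tokens plus one extra (resampled or final) token, so the expected count is \(1 + \alpha + \alpha^2 + \dots + \alpha^K\), and \(r\) is that divided by \(K + 1\). Both function names are hypothetical.</p>
<pre><code class="language-python">import numpy as np

def expected_tokens_per_iteration(alpha, K):
    # ASSUMPTION: each draft token is accepted independently with probability alpha.
    # decoded tokens = accepted draft tokens + 1 (a resampled or final token)
    return sum(alpha**k for k in range(K + 1))

def simulate(alpha, K, n_iters=50_000, seed=0):
    # Monte Carlo estimate of the same quantity under the same assumption
    rng = np.random.default_rng(seed)
    total = 0
    for _ in range(n_iters):
        accepted = 0
        while accepted < K and rng.random() < alpha:
            accepted += 1
        total += accepted + 1  # +1 for the resampled or final token
    return total / n_iters

alpha, K = 0.8, 4
expected = expected_tokens_per_iteration(alpha, K)
r = expected / (K + 1)
print(expected, r, simulate(alpha, K))
</code></pre>
<p>Note that with \(\alpha = 0.8\), \(r\) comes out to roughly 0.67, so \(r\) and the per-token acceptance probability are not the same number.</p>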
<h1 id="speedup-results" tabindex="-1">Speedup Results</h1>
<p>The paper reports the following speedups for their 70B Chinchilla model (using a specially trained 7B Chinchilla as the draft model):</p>
<figure><img src="https://i.imgur.com/3ZcmZfr.png" alt="" /></figure>
<p>The table shows no degradation in benchmark performance, while decoding is roughly 2 to 2.5 times faster than autoregressive decoding (1.92x on XSum and 2.46x on HumanEval).</p>
<p>Let's compare these empirical speedup numbers to theoretical speedup numbers, which we can calculate using our time complexity equations:</p>
<p>\[
\begin{align}
\text{speedup} & = \frac{\text{time complexity of autoregressive}}{\text{time complexity of speculative}} \\
& = \frac{N\cdot t_{\text{target}}}{\frac{N}{r(K + 1)} \cdot (t_{\text{draft}}K + t_{\text{target}})} \\
& = \frac{r(K + 1) \cdot t_{\text{target}}}{t_{\text{draft}}K + t_{\text{target}}}
\end{align}
\]</p>
<p>Using the values provided in the paper:</p>
<ul>
<li>\(K = 4\)</li>
<li>\(t_{\text{draft}} = 1.8\text{ms}\)</li>
<li>\(t_{\text{target}} = 14.1\text{ms}\)</li>
<li>\(r = 0.8\) for HumanEval and \(r = 0.62\) for XSum (see figure 1 in the paper)</li>
</ul>
<p>For HumanEval we get a theoretical speedup of <strong>2.65</strong>, while the paper reports an empirical speedup of <strong>2.46</strong>.</p>
<p>For XSum we get a theoretical speedup of <strong>2.05</strong>, while the paper reports an empirical speedup of <strong>1.92</strong>.</p>
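<p>The arithmetic is easy to check. <code>theoretical_speedup</code> is a hypothetical helper that just evaluates the simplified speedup formula above with the paper's numbers; like the formula, it ignores any overhead outside the two models' forward passes, which is one reason the empirical numbers come in a bit lower.</p>
<pre><code class="language-python">def theoretical_speedup(r, K, t_draft, t_target):
    # autoregressive time:  N * t_target
    # speculative time:     N / (r * (K + 1)) * (t_draft * K + t_target)
    return r * (K + 1) * t_target / (t_draft * K + t_target)

for name, r in [("HumanEval", 0.8), ("XSum", 0.62)]:
    print(name, round(theoretical_speedup(r, K=4, t_draft=1.8, t_target=14.1), 2))
# HumanEval 2.65
# XSum 2.05
</code></pre>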
<p>We can reproduce these results by <a href="https://github.com/jaymody/speculative-sampling">running our implementation with GPT-2 1.5B as our target model and GPT-2 124M as our draft model</a>:</p>
<pre><code class="language-bash">python main.py \
--prompt "Alan Turing theorized that computers would one day become" \
--n_tokens_to_generate 40 \
--draft_model_size "124M" \
--target_model_size "1558M" \
--K 4 \
--temperature 0 \
--seed 123
</code></pre>
<p>This gives a speedup of <strong>2.23</strong> (60.64s vs. 27.15s):</p>
<pre><code class="language-text">Autoregressive Decode
---------------------
Time = 60.64s
Text = Alan Turing theorized that computers would one day become so powerful that they would be able to think like humans.
In the 1950s, he proposed a way to build a computer that could think like a human. He called it the "T
Speculative Decode
------------------
Time = 27.15s
Text = Alan Turing theorized that computers would one day become so powerful that they would be able to think like humans.
In the 1950s, he proposed a way to build a computer that could think like a human. He called it the "T
</code></pre>
<p>Note that the output is exactly the same for both methods because we used <code>temperature = 0</code>, which corresponds to <strong>greedy sampling</strong> (always taking the token with the highest probability). With a non-zero temperature, this would generally not be the case. Although speculative sampling is mathematically equivalent to sampling from the target model directly, the outputs of autoregressive and speculative sampling will differ due to randomness. Speculative sampling giving a different result than autoregressive sampling is akin to running autoregressive sampling with a different seed. When <code>temperature = 0</code>, however, 100% of the probability is assigned to a single token, so sampling from the distribution becomes deterministic, which is why the outputs are identical. If we instead used <code>temperature = 0.5</code>, we'd get different outputs:</p>
<pre><code>Autoregressive Decode
---------------------
Time = 49.06s
Text = Alan Turing theorized that computers would one day become self-aware. This is known as the "Turing Test" and it is a test that has been used to determine if a computer is intelligent.
The Turing Test is based on the
Speculative Decode
------------------
Time = 31.60s
Text = Alan Turing theorized that computers would one day become so powerful that they would be able to simulate the behavior of human minds. The Turing Test is a test that asks a computer to recognize whether a given piece of text is a human or a computer generated
</code></pre>
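<p>Why does greedy decoding make the two methods agree exactly? Here's a minimal sketch, assuming both models use temperature 0 so that each distribution is one-hot. The acceptance ratio \(q(j)/p(j)\) is then either 1 (the models agree) or 0 (they disagree), and on rejection the positive part of \(q - p\) normalizes to a one-hot on the target's choice. So every decoded token is the target model's greedy pick. <code>one_hot</code> and <code>greedy_speculative_step</code> are hypothetical helpers for illustration.</p>
<pre><code class="language-python">import numpy as np

def one_hot(i, n):
    v = np.zeros(n)
    v[i] = 1.0
    return v

def max_fn(x):
    # keep only the positive part of x, then renormalize it into a distribution
    x_max = np.where(x > 0, x, 0)
    return x_max / np.sum(x_max)

def greedy_speculative_step(draft_choice, target_choice, vocab_size, rng):
    p = one_hot(draft_choice, vocab_size)   # temperature 0: draft puts all mass on one token
    q = one_hot(target_choice, vocab_size)  # temperature 0: target puts all mass on one token
    j = draft_choice  # the token the draft model proposed
    if rng.random() < min(1, q[j] / p[j]):  # ratio is 1 if the models agree, else 0
        return j
    # max_fn(q - p) is one-hot on the target's choice, so "resampling" is deterministic
    return int(np.argmax(max_fn(q - p)))

rng = np.random.default_rng(0)
print(greedy_speculative_step(2, 2, 5, rng), greedy_speculative_step(1, 3, 5, rng))  # 2 3
</code></pre>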
<hr class="footnotes-sep" />
<section class="footnotes">
<ol class="footnotes-list">
<li id="fn1" class="footnote-item"><p>The paper's wording for \(r\) is a bit misleading. The paper states that \(r\) is "the average number of tokens <strong>accepted</strong> divided by \(K + 1\)". This gives the impression they are reporting the rate at which <strong>just</strong> the draft tokens are accepted (i.e. excluding the resampled and final sampled tokens). In actuality, \(r\) is "the average number of tokens <strong>decoded</strong> divided by \(K + 1\)", meaning we also include the resampled and final token. This makes sense, since if only draft tokens were counted, at most \(K\) tokens could be accepted per iteration, and they would have divided by \(K\) rather than \(K + 1\). I confirmed this with the authors of the paper. <a href="https://jaykmody.com/blog/speculative-sampling/#fnref1" class="footnote-backref">↩︎</a></p>
</li>
</ol>
</section>
Wed, 08 Feb 2023 00:00:00 +0000
Jay Mody
https://jaykmody.com/blog/speculative-sampling/
GPT in 60 Lines of NumPy
https://jaykmody.com/blog/gpt-from-scratch/
<p>In this post, we'll implement a GPT from scratch in just <a href="https://github.com/jaymody/picoGPT/blob/29e78cc52b58ed2c1c483ffea2eb46ff6bdec785/gpt2_pico.py#L3-L58">60 lines of <code>numpy</code></a>. We'll then load the trained GPT-2 model weights released by OpenAI into our implementation and generate some text.</p>
<p><strong>Note:</strong></p>
<ul>
<li>This post assumes familiarity with Python, NumPy, and some basic experience with neural networks.</li>
<li>This implementation is for educational purposes, so it's missing lots of features/improvements on purpose to keep it as simple as possible while remaining complete.</li>
<li>All the code for this blog post can be found at <a href="https://github.com/jaymody/picoGPT">github.com/jaymody/picoGPT</a>.</li>
<li><a href="https://news.ycombinator.com/item?id=34726115">Hacker News thread</a></li>
<li><a href="https://jiqihumanr.github.io/2023/04/13/gpt-from-scratch/">Chinese translation</a></li>
<li><a href="https://mlwizardry.netlify.app/nlp/gpt-from-scratch/">Japanese translation</a></li>
</ul>
<p><strong>EDIT (Feb 9th, 2023):</strong> Added a "What's Next" section and updated the intro with some notes.<br />
<strong>EDIT (Feb 28th, 2023):</strong> Added some additional sections to "What's Next".</p>
<h2 id="table-of-contents" tabindex="-1">Table of Contents</h2>
<hr />
<p></p><div class="table-of-contents"><ul><li><a href="https://jaykmody.com/blog/gpt-from-scratch/#table-of-contents">Table of Contents</a></li><li><a href="https://jaykmody.com/blog/gpt-from-scratch/#what-is-a-gpt%3F">What is a GPT?</a><ul><li><a href="https://jaykmody.com/blog/gpt-from-scratch/#input-%2F-output">Input / Output</a><ul><li><a href="https://jaykmody.com/blog/gpt-from-scratch/#input">Input</a></li><li><a href="https://jaykmody.com/blog/gpt-from-scratch/#output">Output</a></li></ul></li><li><a href="https://jaykmody.com/blog/gpt-from-scratch/#generating-text">Generating Text</a><ul><li><a href="https://jaykmody.com/blog/gpt-from-scratch/#autoregressive">Autoregressive</a></li><li><a href="https://jaykmody.com/blog/gpt-from-scratch/#sampling">Sampling</a></li></ul></li><li><a href="https://jaykmody.com/blog/gpt-from-scratch/#training">Training</a></li><li><a href="https://jaykmody.com/blog/gpt-from-scratch/#prompting">Prompting</a></li></ul></li><li><a href="https://jaykmody.com/blog/gpt-from-scratch/#setup">Setup</a><ul><li><a href="https://jaykmody.com/blog/gpt-from-scratch/#encoder">Encoder</a></li><li><a href="https://jaykmody.com/blog/gpt-from-scratch/#hyperparameters">Hyperparameters</a></li><li><a href="https://jaykmody.com/blog/gpt-from-scratch/#parameters">Parameters</a></li></ul></li><li><a href="https://jaykmody.com/blog/gpt-from-scratch/#basic-layers">Basic Layers</a><ul><li><a href="https://jaykmody.com/blog/gpt-from-scratch/#gelu">GELU</a></li><li><a href="https://jaykmody.com/blog/gpt-from-scratch/#softmax">Softmax</a></li><li><a href="https://jaykmody.com/blog/gpt-from-scratch/#layer-normalization">Layer Normalization</a></li><li><a href="https://jaykmody.com/blog/gpt-from-scratch/#linear">Linear</a></li></ul></li><li><a href="https://jaykmody.com/blog/gpt-from-scratch/#gpt-architecture">GPT Architecture</a><ul><li><a href="https://jaykmody.com/blog/gpt-from-scratch/#embeddings">Embeddings</a><ul><li><a 
href="https://jaykmody.com/blog/gpt-from-scratch/#token-embeddings">Token Embeddings</a></li><li><a href="https://jaykmody.com/blog/gpt-from-scratch/#positional-embeddings">Positional Embeddings</a></li><li><a href="https://jaykmody.com/blog/gpt-from-scratch/#combined">Combined</a></li></ul></li><li><a href="https://jaykmody.com/blog/gpt-from-scratch/#decoder-stack">Decoder Stack</a></li><li><a href="https://jaykmody.com/blog/gpt-from-scratch/#projection-to-vocab">Projection to Vocab</a></li><li><a href="https://jaykmody.com/blog/gpt-from-scratch/#decoder-block">Decoder Block</a></li><li><a href="https://jaykmody.com/blog/gpt-from-scratch/#position-wise-feed-forward-network">Position-wise Feed Forward Network</a></li><li><a href="https://jaykmody.com/blog/gpt-from-scratch/#multi-head-causal-self-attention">Multi-Head Causal Self Attention</a><ul><li><a href="https://jaykmody.com/blog/gpt-from-scratch/#attention">Attention</a></li><li><a href="https://jaykmody.com/blog/gpt-from-scratch/#self">Self</a></li><li><a href="https://jaykmody.com/blog/gpt-from-scratch/#causal">Causal</a></li><li><a href="https://jaykmody.com/blog/gpt-from-scratch/#multi-head">Multi-Head</a></li></ul></li></ul></li><li><a href="https://jaykmody.com/blog/gpt-from-scratch/#putting-it-all-together">Putting it All Together</a></li><li><a href="https://jaykmody.com/blog/gpt-from-scratch/#what-next%3F">What Next?</a><ul><li><a href="https://jaykmody.com/blog/gpt-from-scratch/#gpu%2Ftpu-support">GPU/TPU Support</a></li><li><a href="https://jaykmody.com/blog/gpt-from-scratch/#backpropagation">Backpropagation</a></li><li><a href="https://jaykmody.com/blog/gpt-from-scratch/#batching">Batching</a></li><li><a href="https://jaykmody.com/blog/gpt-from-scratch/#inference-optimization">Inference Optimization</a></li><li><a href="https://jaykmody.com/blog/gpt-from-scratch/#training-1">Training</a></li><li><a href="https://jaykmody.com/blog/gpt-from-scratch/#evaluation">Evaluation</a></li><li><a 
href="https://jaykmody.com/blog/gpt-from-scratch/#architecture-improvements">Architecture Improvements</a></li><li><a href="https://jaykmody.com/blog/gpt-from-scratch/#stopping-generation">Stopping Generation</a></li><li><a href="https://jaykmody.com/blog/gpt-from-scratch/#fine-tuning">Fine-tuning</a><ul><li><a href="https://jaykmody.com/blog/gpt-from-scratch/#classification-fine-tuning">Classification Fine-tuning</a></li><li><a href="https://jaykmody.com/blog/gpt-from-scratch/#generative-fine-tuning">Generative Fine-tuning</a></li><li><a href="https://jaykmody.com/blog/gpt-from-scratch/#instruction-fine-tuning">Instruction Fine-tuning</a></li><li><a href="https://jaykmody.com/blog/gpt-from-scratch/#parameter-efficient-fine-tuning">Parameter Efficient Fine-tuning</a></li></ul></li></ul></li></ul></div><p></p>
<h2 id="what-is-a-gpt%3F" tabindex="-1">What is a GPT?</h2>
<hr />
<p>GPT stands for <strong>Generative Pre-trained Transformer</strong>. It's a type of neural network architecture based on the <a href="https://arxiv.org/pdf/1706.03762.pdf"><strong>Transformer</strong></a>. <a href="https://jalammar.github.io/how-gpt3-works-visualizations-animations/">Jay Alammar's How GPT3 Works</a> is an excellent introduction to GPTs at a high level, but here's the tl;dr:</p>
<ul>
<li><strong>Generative</strong>: A GPT <em>generates</em> text.</li>
<li><strong>Pre-trained</strong>: A GPT is <em>trained</em> on lots of text from books, the internet, etc ...</li>
<li><strong>Transformer</strong>: A GPT is a decoder-only <em>transformer</em> neural network.</li>
</ul>
<p>Large Language Models (LLMs) like <a href="https://en.wikipedia.org/wiki/GPT-3">OpenAI's GPT-3</a> are just GPTs under the hood. What makes them special is that they happen to be <strong>1)</strong> very big (billions of parameters) and <strong>2)</strong> trained on lots of data (hundreds of gigabytes of text).</p>
<p>Fundamentally, a GPT <strong>generates text</strong> given a <strong>prompt</strong>. Even with this very simple API (input = text, output = text), a well-trained GPT can do some pretty awesome stuff like <a href="https://machinelearningknowledge.ai/ezoimgfmt/b2611031.smushcdn.com/2611031/wp-content/uploads/2022/12/ChatGPT-Demo-of-Drafting-an-Email.png?lossy=0&strip=1&webp=1&ezimgfmt=ng:webp/ngcb1">write your emails</a>, <a href="https://machinelearningknowledge.ai/ezoimgfmt/b2611031.smushcdn.com/2611031/wp-content/uploads/2022/12/ChatGPT-Example-Book-Summarization.png?lossy=0&strip=1&webp=1&ezimgfmt=ng:webp/ngcb1">summarize a book</a>, <a href="https://khrisdigital.com/wp-content/uploads/2022/12/image-1.png">give you instagram caption ideas</a>, <a href="https://machinelearningknowledge.ai/ezoimgfmt/b2611031.smushcdn.com/2611031/wp-content/uploads/2022/12/ChatGPT-Examples-Explaining-Black-Holes.png?lossy=0&strip=1&webp=1&ezimgfmt=ng:webp/ngcb1">explain black holes to a 5 year old</a>, <a href="https://machinelearningknowledge.ai/ezoimgfmt/b2611031.smushcdn.com/2611031/wp-content/uploads/2022/12/ChatGPT-Demo-of-Writing-SQL-Queries.png?lossy=0&strip=1&webp=1&ezimgfmt=ng:webp/ngcb1">code in SQL</a>, and <a href="https://machinelearningknowledge.ai/ezoimgfmt/b2611031.smushcdn.com/2611031/wp-content/uploads/2022/12/Chat-GPT-Example-Writing-a-Will.png?lossy=0&strip=1&webp=1&ezimgfmt=ng:webp/ngcb1">even write your will</a>.</p>
<p>So that's a high-level overview of GPTs and their capabilities. Let's dig into some more specifics.</p>
<h3 id="input-%2F-output" tabindex="-1">Input / Output</h3>
<p>The function signature for a GPT looks roughly like this:</p>
<pre><code class="language-python">def gpt(inputs: list[int]) -> list[list[float]]:
    # inputs has shape [n_seq]
    # output has shape [n_seq, n_vocab]
    output = ...  # beep boop neural network magic
    return output
</code></pre>
<h4 id="input" tabindex="-1">Input</h4>
<p>The input is some text represented by a <strong>sequence of integers</strong> that map to <strong>tokens</strong> in the text:</p>
<pre><code class="language-python"># integers represent tokens in our text, for example:
# text = "not all heroes wear capes":
# tokens = "not" "all" "heroes" "wear" "capes"
inputs = [1, 0, 2, 4, 6]
</code></pre>
<p>Tokens are sub-pieces of the text, which are produced using some kind of <strong>tokenizer</strong>. We can map tokens to integers using a <strong>vocabulary</strong>:</p>
<pre><code class="language-python"># the index of a token in the vocab represents the integer id for that token
# i.e. the integer id for "heroes" would be 2, since vocab[2] = "heroes"
vocab = ["all", "not", "heroes", "the", "wear", ".", "capes"]
# a pretend tokenizer that tokenizes on whitespace
tokenizer = WhitespaceTokenizer(vocab)
# the encode() method converts a str -> list[int]
ids = tokenizer.encode("not all heroes wear") # ids = [1, 0, 2, 4]
# we can see what the actual tokens are via our vocab mapping
tokens = [tokenizer.vocab[i] for i in ids] # tokens = ["not", "all", "heroes", "wear"]
# the decode() method converts back a list[int] -> str
text = tokenizer.decode(ids) # text = "not all heroes wear"
</code></pre>
<p>In short:</p>
<ul>
<li>We have a string.</li>
<li>We use a tokenizer to break it down into smaller pieces called tokens.</li>
<li>We use a vocabulary to map those tokens to integers.</li>
</ul>
<p>In practice, we use more advanced methods of tokenization than simply splitting by whitespace, such as <a href="https://huggingface.co/course/chapter6/5?fw=pt">Byte-Pair Encoding</a> or <a href="https://huggingface.co/course/chapter6/6?fw=pt">WordPiece</a>, but the principle is the same:</p>
<ol>
<li>There is a <code>vocab</code> that maps string tokens to integer indices</li>
<li>There is an <code>encode</code> method that converts <code>str -> list[int]</code></li>
<li>There is a <code>decode</code> method that converts <code>list[int] -> str</code><sup class="footnote-ref"><a href="https://jaykmody.com/blog/gpt-from-scratch/#fn1" id="fnref1">[1]</a></sup></li>
</ol>
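<p>For concreteness, here's a minimal sketch of the pretend <code>WhitespaceTokenizer</code> used above. It's a toy for illustration only: real tokenizers like GPT-2's BPE handle subwords and whitespace, whereas this one simply splits on spaces and raises a <code>KeyError</code> for any token that isn't in the vocab.</p>
<pre><code class="language-python">class WhitespaceTokenizer:
    # hypothetical toy tokenizer: splits on whitespace, maps tokens via the vocab
    def __init__(self, vocab):
        self.vocab = vocab
        self.token_to_id = {token: i for i, token in enumerate(vocab)}

    def encode(self, text):
        # str -> list[int]
        return [self.token_to_id[token] for token in text.split()]

    def decode(self, ids):
        # list[int] -> str
        return " ".join(self.vocab[i] for i in ids)

vocab = ["all", "not", "heroes", "the", "wear", ".", "capes"]
tokenizer = WhitespaceTokenizer(vocab)
ids = tokenizer.encode("not all heroes wear")
print(ids)                    # [1, 0, 2, 4]
print(tokenizer.decode(ids))  # not all heroes wear
</code></pre>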
<h4 id="output" tabindex="-1">Output</h4>
<p>The output is a <strong>2D array</strong>, where <code>output[i][j]</code> is the model's <strong>predicted probability</strong> that the token at <code>vocab[j]</code> is the next token <code>inputs[i+1]</code>. For example:</p>
<pre><code class="language-python">vocab = ["all", "not", "heroes", "the", "wear", ".", "capes"]
inputs = [1, 0, 2, 4] # "not" "all" "heroes" "wear"
output = gpt(inputs)
# ["all", "not", "heroes", "the", "wear", ".", "capes"]
# output[0] = [0.75 0.1 0.0 0.15 0.0 0.0 0.0 ]
# given just "not", the model predicts the word "all" with the highest probability
# ["all", "not", "heroes", "the", "wear", ".", "capes"]
# output[1] = [0.0 0.0 0.8 0.1 0.0 0.0 0.1 ]
# given the sequence ["not", "all"], the model predicts the word "heroes" with the highest probability
# ["all", "not", "heroes", "the", "wear", ".", "capes"]
# output[-1] = [0.0 0.0 0.0 0.1 0.0 0.05 0.85 ]
# given the whole sequence ["not", "all", "heroes", "wear"], the model predicts the word "capes" with the highest probability
</code></pre>
<p>To get a <strong>next token prediction</strong> for the whole sequence, we simply take the token with the highest probability in <code>output[-1]</code>:</p>
<pre><code class="language-python">vocab = ["all", "not", "heroes", "the", "wear", ".", "capes"]
inputs = [1, 0, 2, 4] # "not" "all" "heroes" "wear"
output = gpt(inputs)
next_token_id = np.argmax(output[-1]) # next_token_id = 6
next_token = vocab[next_token_id] # next_token = "capes"
</code></pre>
<p>Taking the token with the highest probability as our prediction is known as <a href="https://docs.cohere.ai/docs/controlling-generation-with-top-k-top-p#1-pick-the-top-token-greedy-decoding"><strong>greedy decoding</strong></a> or <strong>greedy sampling</strong>.</p>
<p>The task of predicting the next logical word in a sequence is called <strong>language modeling</strong>. As such, we can call a GPT a <strong>language model</strong>.</p>
<p>Generating a single word is cool and all, but what about entire sentences, paragraphs, etc ...?</p>
<h3 id="generating-text" tabindex="-1">Generating Text</h3>
<h4 id="autoregressive" tabindex="-1">Autoregressive</h4>
<p>We can generate full sentences by iteratively getting the next token prediction from our model. At each iteration, we append the predicted token back into the input:</p>
<pre><code class="language-python">def generate(inputs, n_tokens_to_generate):
    for _ in range(n_tokens_to_generate):  # auto-regressive decode loop
        output = gpt(inputs)  # model forward pass
        next_id = np.argmax(output[-1])  # greedy sampling
        inputs.append(int(next_id))  # append prediction to input
    return inputs[len(inputs) - n_tokens_to_generate :]  # only return generated ids

input_ids = [1, 0]  # "not" "all"
output_ids = generate(input_ids, 3)  # output_ids = [2, 4, 6]
output_tokens = [vocab[i] for i in output_ids]  # "heroes" "wear" "capes"
</code></pre>
<p>This process of predicting a future value (regression), and adding it back into the input (auto), is why you might see a GPT described as <strong>autoregressive</strong>.</p>
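<p>To see the loop run end to end, here's a sketch with a hypothetical bigram lookup table standing in for the neural network. Unlike a real GPT, this toy model only looks at the current token rather than the whole prefix, but its inputs and outputs have the same shapes. This version of <code>generate</code> also copies <code>inputs</code> so the caller's list isn't modified.</p>
<pre><code class="language-python">import numpy as np

vocab = ["all", "not", "heroes", "the", "wear", ".", "capes"]

# hypothetical bigram table: row t is the next-token distribution after token t
next_id = {1: 0, 0: 2, 2: 4, 4: 6, 6: 5, 5: 3, 3: 0}  # not->all->heroes->wear->capes->.->the
table = np.zeros((len(vocab), len(vocab)))
for cur, nxt in next_id.items():
    table[cur, nxt] = 1.0

def gpt(inputs):
    # output[i] is the predicted distribution for the token after inputs[i]
    return table[inputs]  # shape [n_seq, n_vocab]

def generate(inputs, n_tokens_to_generate):
    inputs = list(inputs)  # copy so the caller's list is left untouched
    for _ in range(n_tokens_to_generate):  # auto-regressive decode loop
        output = gpt(inputs)  # model forward pass
        inputs.append(int(np.argmax(output[-1])))  # greedy sampling
    return inputs[len(inputs) - n_tokens_to_generate :]

output_ids = generate([1, 0], 3)
print(output_ids, [vocab[i] for i in output_ids])  # [2, 4, 6] ['heroes', 'wear', 'capes']
</code></pre>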
<h4 id="sampling" tabindex="-1">Sampling</h4>
<p>We can introduce some <strong>stochasticity</strong> (randomness) to our generations by sampling from the probability distribution instead of being greedy:</p>
<pre><code class="language-python">inputs = [1, 0, 2, 4] # "not" "all" "heroes" "wear"
output = gpt(inputs)
np.random.choice(np.arange(vocab_size), p=output[-1]) # capes
np.random.choice(np.arange(vocab_size), p=output[-1]) # hats
np.random.choice(np.arange(vocab_size), p=output[-1]) # capes
np.random.choice(np.arange(vocab_size), p=output[-1]) # capes
np.random.choice(np.arange(vocab_size), p=output[-1]) # pants
</code></pre>
<p>This allows us to generate different sentences given the same input. When combined with techniques like <a href="https://docs.cohere.ai/docs/controlling-generation-with-top-k-top-p#2-pick-from-amongst-the-top-tokens-top-k"><strong>top-k</strong></a>, <a href="https://docs.cohere.ai/docs/controlling-generation-with-top-k-top-p#3-pick-from-amongst-the-top-tokens-whose-probabilities-add-up-to-15-top-p"><strong>top-p</strong></a>, and <a href="https://docs.cohere.ai/docs/temperature"><strong>temperature</strong></a>, which modify the distribution prior to sampling, the quality of our outputs is greatly increased. These techniques also introduce some hyperparameters that we can play around with to get different generation behaviors (for example, increasing temperature makes our model take more risks and thus be more "creative").</p>
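<p>Here's a minimal sketch of these three techniques, assuming (as in this post) that the model returns probabilities rather than logits. Raising the probabilities to the power \(1/T\) and renormalizing is equivalent to dividing the logits by \(T\) before the softmax, so \(T\) must be positive. The helper names are hypothetical, and real libraries differ in details such as how ties are handled.</p>
<pre><code class="language-python">import numpy as np

def apply_temperature(probs, temperature):
    # equivalent to softmax(logits / temperature); temperature must be > 0
    scaled = probs ** (1.0 / temperature)
    return scaled / scaled.sum()

def top_k(probs, k):
    # keep only the k most likely tokens, then renormalize
    keep = np.argsort(probs)[-k:]
    filtered = np.zeros_like(probs)
    filtered[keep] = probs[keep]
    return filtered / filtered.sum()

def top_p(probs, p):
    # keep the smallest set of most likely tokens whose total probability is >= p
    order = np.argsort(probs)[::-1]  # token ids, most likely first
    cumulative = np.cumsum(probs[order])
    n_keep = np.searchsorted(cumulative, p) + 1
    keep = order[:n_keep]
    filtered = np.zeros_like(probs)
    filtered[keep] = probs[keep]
    return filtered / filtered.sum()

# output[-1] from the "not all heroes wear" example above
probs = np.array([0.0, 0.0, 0.0, 0.1, 0.0, 0.05, 0.85])
rng = np.random.default_rng(0)
token_id = rng.choice(len(probs), p=top_k(apply_temperature(probs, 0.7), k=2))
</code></pre>
<p>With <code>k=2</code>, only "the" and "capes" can ever be sampled; lowering the temperature shifts even more of the remaining probability onto "capes".</p>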
<h3 id="training" tabindex="-1">Training</h3>
<p>We train a GPT like any other neural network, using <a href="https://arxiv.org/pdf/1609.04747.pdf"><strong>gradient descent</strong></a> with respect to some <strong>loss function</strong>. In the case of a GPT, we take the <strong><a href="https://www.youtube.com/watch?v=ErfnhcEV1O8">cross entropy loss</a> over the language modeling task</strong>:</p>
<pre><code class="language-python">def lm_loss(inputs: list[int], params) -> float:
    # the labels y are just the input shifted 1 to the left
    #
    # inputs = [not, all, heroes, wear, capes]
    # x = [not, all, heroes, wear]
    # y = [all, heroes, wear, capes]
    #
    # of course, we don't have a label for inputs[-1], so we exclude it from x
    #
    # as such, for N inputs, we have N - 1 language modeling example pairs
    x, y = inputs[:-1], inputs[1:]  # both have shape [num_tokens_in_seq - 1]

    # forward pass
    # all the predicted next token probability distributions at each position
    output = gpt(x, params)  # has shape [num_tokens_in_seq - 1, num_tokens_in_vocab]

    # cross entropy loss
    # we take the average over all N-1 examples
    loss = np.mean(-np.log(output[np.arange(len(output)), y]))

    return loss

def train(texts: list[str], params):
    for text in texts:
        inputs = tokenizer.encode(text)
        loss = lm_loss(inputs, params)
        gradients = compute_gradients_via_backpropagation(loss, params)
        params = gradient_descent_update_step(gradients, params)
    return params
</code></pre>
<p>This is a heavily simplified training setup, but it illustrates the point. Notice the addition of <code>params</code> to our <code>gpt</code> function signature (we left this out in the previous sections for simplicity). During each iteration of the training loop:</p>
<ol>
<li>We compute the language modeling loss for the given input text example</li>
<li>The loss determines our gradients, which we compute via backpropagation</li>
<li>We use the gradients to update our model parameters such that the loss is minimized (gradient descent)</li>
</ol>
<p>Notice, we don't use explicitly labelled data. Instead, we are able to produce the input/label pairs from just the raw text itself. This is referred to as <strong><a href="https://en.wikipedia.org/wiki/Self-supervised_learning">self-supervised learning</a></strong>.</p>
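<p>Here's a runnable sketch of that idea using the same loss, with a hypothetical <code>uniform_gpt</code> model that assigns equal probability to all 7 tokens in our toy vocab (this variant takes the model function directly instead of <code>params</code>). The input/label pairs come straight from the raw token sequence, and a model that knows nothing should score a loss of \(\ln 7 \approx 1.95\), a handy sanity check for any language modeling loss.</p>
<pre><code class="language-python">import numpy as np

def uniform_gpt(x, n_vocab=7):
    # hypothetical stand-in model: predicts every token with equal probability
    return np.full((len(x), n_vocab), 1.0 / n_vocab)

def lm_loss(inputs, model):
    # labels are the inputs shifted one position to the left
    x, y = inputs[:-1], inputs[1:]
    output = model(x)  # shape [len(x), n_vocab]
    # probability assigned to each correct next token, averaged as cross entropy
    return np.mean(-np.log(output[np.arange(len(output)), y]))

inputs = [1, 0, 2, 4, 6]  # "not" "all" "heroes" "wear" "capes"
print(list(zip(inputs[:-1], inputs[1:])))  # [(1, 0), (0, 2), (2, 4), (4, 6)]
print(lm_loss(inputs, uniform_gpt))        # ln(7), roughly 1.9459
</code></pre>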
<p>Self-supervision enables us to massively scale training data. Just get our hands on as much raw text as possible and throw it at the model. For example, GPT-3 was trained on <strong>300 billion tokens</strong> of text from the internet and books:</p>
<figure><img src="https://miro.medium.com/max/1400/1*Sc3Gi73hepgrOLnx8bXFBA.png" alt="" /><figcaption>Table 2.2 from GPT-3 paper</figcaption></figure>
<p>Of course, you need a sufficiently large model to be able to learn from all this data, which is why GPT-3 has <strong>175 billion parameters</strong> and probably cost somewhere around <a href="https://twitter.com/eturner303/status/1266264358771757057">$1m-10m in compute to train</a>.<sup class="footnote-ref"><a href="https://jaykmody.com/blog/gpt-from-scratch/#fn2" id="fnref2">[2]</a></sup></p>
<p>This self-supervised training step is called <strong>pre-training</strong>, since we can reuse the "pre-trained" model's weights to further train the model on downstream tasks, such as classifying whether a tweet is toxic. Pre-trained models are also sometimes called <strong>foundation models</strong>.</p>
<p>Training the model on downstream tasks is called <strong>fine-tuning</strong>: the model weights have already been pre-trained to understand language, so they are just being fine-tuned to the specific task at hand.</p>
<p>The "pre-training on a general task + fine-tuning on a specific task" strategy is called <a href="https://en.wikipedia.org/wiki/Transfer_learning">transfer learning</a>.</p>
<h3 id="prompting" tabindex="-1">Prompting</h3>
<p>In principle, the original <a href="https://s3-us-west-2.amazonaws.com/openai-assets/research-covers/language-unsupervised/language_understanding_paper.pdf">GPT</a> paper was only about the benefits of pre-training a transformer model for transfer learning. The paper showed that pre-training a 117M GPT achieved state-of-the-art performance on various <strong>NLP</strong> (natural language processing) tasks when fine-tuned on labelled datasets.</p>
<p>It wasn't until the <a href="https://d4mucfpksywv.cloudfront.net/better-language-models/language_models_are_unsupervised_multitask_learners.pdf">GPT-2</a> and <a href="https://arxiv.org/abs/2005.14165">GPT-3</a> papers that we realized a GPT model pre-trained on enough data with enough parameters was capable of performing any arbitrary task <strong>by itself</strong>, no fine-tuning needed. Just prompt the model, perform autoregressive language modeling, and voila, the model magically gives us an appropriate response. This is referred to as <strong>in-context learning</strong>, because the model is using just the context of the prompt to perform the task. In-context learning can be zero shot, one shot, or few shot:</p>
<figure><img src="https://i.imgur.com/VKZXC0K.png" alt="" /><figcaption>Figure 2.1 from the GPT-3 Paper</figcaption></figure>
<p>Generating text given a prompt is also sometimes referred to as <strong>conditional generation</strong>, since our model is generating some output <em>conditioned</em> on some input.</p>
<p>GPTs are not limited to NLP tasks. You can condition the model on anything you want. For example, you can turn a GPT into a <strong>chatbot</strong> (e.g. <a href="https://openai.com/blog/chatgpt/">ChatGPT</a>) by conditioning it on the conversation history. You can also further condition the chatbot to behave a certain way by prepending the prompt with some kind of description (e.g. "You are a chatbot. Be polite, speak in full sentences, don't say harmful things, etc ..."). Conditioning the model like this can even give your <a href="https://imgur.com/a/AbDFcgk">chatbot a persona</a>. This is often referred to as a <strong>system prompt</strong>. However, this is not robust: you can still <a href="https://twitter.com/zswitten/status/1598380220943593472">"jailbreak" the model and make it misbehave</a>.</p>
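<p>In code, this kind of conditioning is nothing more than string construction before tokenization. Below is a minimal sketch; the <code>build_chat_prompt</code> helper and its "Role: text" layout are hypothetical illustrations, not the format used by ChatGPT or any particular model.</p>
<pre><code class="language-python">def build_chat_prompt(system_prompt, history, user_message):
    # system prompt first, then the conversation so far, then an open
    # "Assistant:" turn that the model is expected to continue
    lines = [system_prompt, ""]
    for role, text in history:
        lines.append(f"{role}: {text}")
    lines.append(f"User: {user_message}")
    lines.append("Assistant:")
    return "\n".join(lines)

prompt = build_chat_prompt(
    "You are a chatbot. Be polite and speak in full sentences.",
    [("User", "Hi!"), ("Assistant", "Hello! How can I help you today?")],
    "What is a GPT?",
)
print(prompt)
</code></pre>
<p>The resulting string would then be encoded with the tokenizer and decoded autoregressively like any other prompt.</p>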
<p>With that out of the way, let's finally get to the actual implementation.</p>
<h2 id="setup" tabindex="-1">Setup</h2>
<hr />
<p>Clone the repository for this tutorial:</p>
<pre><code class="language-bash">git clone https://github.com/jaymody/picoGPT
cd picoGPT
</code></pre>
<p>Then let's install our dependencies:</p>
<pre><code class="language-bash">pip install -r requirements.txt
</code></pre>
<p>Note: This code was tested with <code>Python 3.9.10</code>.</p>
<p>A quick breakdown of each of the files:</p>
<ul>
<li><strong><code>encoder.py</code></strong> contains the code for OpenAI's BPE Tokenizer, taken straight from their <a href="https://github.com/openai/gpt-2/blob/master/src/encoder.py">gpt-2 repo</a>.</li>
<li><strong><code>utils.py</code></strong> contains the code to download and load the GPT-2 model weights, tokenizer, and hyperparameters.</li>
<li><strong><code>gpt2.py</code></strong> contains the actual GPT model and generation code, which we can run as a python script.</li>
<li><strong><code>gpt2_pico.py</code></strong> is the same as <code>gpt2.py</code>, but in even fewer lines of code. Why? Because why not.</li>
</ul>
<p>We'll be reimplementing <code>gpt2.py</code> from scratch, so let's delete it and recreate it as an empty file:</p>
<pre><code class="language-bash">rm gpt2.py
touch gpt2.py
</code></pre>
<p>As a starting point, paste the following code into <code>gpt2.py</code>:</p>
<pre><code class="language-python">import numpy as np
def gpt2(inputs, wte, wpe, blocks, ln_f, n_head):
    pass # TODO: implement this

def generate(inputs, params, n_head, n_tokens_to_generate):
    from tqdm import tqdm

    for _ in tqdm(range(n_tokens_to_generate), "generating"): # auto-regressive decode loop
        logits = gpt2(inputs, **params, n_head=n_head) # model forward pass
        next_id = np.argmax(logits[-1]) # greedy sampling
        inputs.append(int(next_id)) # append prediction to input

    return inputs[len(inputs) - n_tokens_to_generate :] # only return generated ids

def main(prompt: str, n_tokens_to_generate: int = 40, model_size: str = "124M", models_dir: str = "models"):
    from utils import load_encoder_hparams_and_params

    # load encoder, hparams, and params from the released open-ai gpt-2 files
    encoder, hparams, params = load_encoder_hparams_and_params(model_size, models_dir)

    # encode the input string using the BPE tokenizer
    input_ids = encoder.encode(prompt)

    # make sure we are not surpassing the max sequence length of our model
    assert len(input_ids) + n_tokens_to_generate < hparams["n_ctx"]

    # generate output ids
    output_ids = generate(input_ids, params, hparams["n_head"], n_tokens_to_generate)

    # decode the ids back into a string
    output_text = encoder.decode(output_ids)

    return output_text

if __name__ == "__main__":
    import fire

    fire.Fire(main)
</code></pre>
<p>Breaking down each of the 4 sections:</p>
<ol>
<li>The <code>gpt2</code> function is the actual GPT code we'll be implementing. You'll notice that the function signature includes some extra stuff in addition to <code>inputs</code>:
<ul>
<li><code>wte</code>, <code>wpe</code>, <code>blocks</code>, and <code>ln_f</code> are the parameters of our model.</li>
<li><code>n_head</code> is a hyperparameter that is needed during the forward pass.</li>
</ul>
</li>
<li>The <code>generate</code> function is the autoregressive decoding algorithm we saw earlier. We use greedy sampling for simplicity. <a href="https://www.google.com/search?q=tqdm"><code>tqdm</code></a> is a progress bar to help us visualize the decoding process as it generates tokens one at a time.</li>
<li>The <code>main</code> function handles:
<ol>
<li>Loading the tokenizer (<code>encoder</code>), model weights (<code>params</code>), and hyperparameters (<code>hparams</code>)</li>
<li>Encoding the input prompt into token IDs using the tokenizer</li>
<li>Calling the generate function</li>
<li>Decoding the output IDs into a string</li>
</ol>
</li>
<li><a href="https://github.com/google/python-fire"><code>fire.Fire(main)</code></a> just turns our file into a CLI application, so we can eventually run our code with: <code>python gpt2.py "some prompt here"</code></li>
</ol>
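<p>To see the decode loop in action without downloading any weights, here is a minimal sketch that swaps <code>gpt2</code> for a hypothetical <code>toy_gpt2</code> whose logits always favor the token id one greater than the current token (modulo a vocabulary of 10). It also drops <code>tqdm</code> and <code>params</code> to stay self-contained.</p>
<pre><code class="language-python">import numpy as np

def toy_gpt2(inputs, n_vocab=10):
    # hypothetical stand-in for gpt2(): returns [n_seq, n_vocab] logits where,
    # at every position, the highest logit is at (token + 1) % n_vocab
    logits = np.zeros((len(inputs), n_vocab))
    for i, tok in enumerate(inputs):
        logits[i, (tok + 1) % n_vocab] = 1.0
    return logits

def generate(inputs, n_tokens_to_generate):
    inputs = list(inputs)  # copy so the caller's list is not mutated
    for _ in range(n_tokens_to_generate):  # auto-regressive decode loop
        logits = toy_gpt2(inputs)  # model forward pass
        next_id = np.argmax(logits[-1])  # greedy sampling
        inputs.append(int(next_id))  # append prediction to input
    return inputs[len(inputs) - n_tokens_to_generate :]  # only return generated ids

print(generate([3, 4], 5))  # [5, 6, 7, 8, 9]
</code></pre>
<p>One design difference: this sketch copies <code>inputs</code> before appending, whereas the <code>generate</code> above appends to the caller's list in place. Both return only the newly generated ids.</p>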
<p>Let's take a closer look at <code>encoder</code>, <code>hparams</code>, and <code>params</code>. In a notebook or an interactive Python session, run:</p>
<pre><code class="language-python">from utils import load_encoder_hparams_and_params
encoder, hparams, params = load_encoder_hparams_and_params("124M", "models")
</code></pre>
<p>This will <a href="https://github.com/jaymody/picoGPT/blob/a750c145ba4d09d5764806a6c78c71ffaff88e64/utils.py#L13-L40">download the necessary model and tokenizer files</a> to <code>models/124M</code> and <a href="https://github.com/jaymody/picoGPT/blob/a750c145ba4d09d5764806a6c78c71ffaff88e64/utils.py#L68-L82">load <code>encoder</code>, <code>hparams</code>, and <code>params</code></a> into our code.</p>
<h3 id="encoder" tabindex="-1">Encoder</h3>
<p><code>encoder</code> is the BPE tokenizer used by GPT-2:</p>
<pre><code class="language-python">>>> ids = encoder.encode("Not all heroes wear capes.")
>>> ids
[3673, 477, 10281, 5806, 1451, 274, 13]
>>> encoder.decode(ids)
"Not all heroes wear capes."
</code></pre>
<p>Using the vocabulary of the tokenizer (stored in <code>encoder.decoder</code>), we can take a peek at what the actual tokens look like:</p>
<pre><code class="language-python">>>> [encoder.decoder[i] for i in ids]
['Not', 'Ġall', 'Ġheroes', 'Ġwear', 'Ġcap', 'es', '.']
</code></pre>
<p>Notice, sometimes our tokens are words (e.g. <code>Not</code>), sometimes they are words but with a space in front of them (e.g. <code>Ġall</code>, the <a href="https://github.com/karpathy/minGPT/blob/37baab71b9abea1b76ab957409a1cc2fbfba8a26/mingpt/bpe.py#L22-L33"><code>Ġ</code> represents a space</a>), sometimes they are part of a word (e.g. capes is split into <code>Ġcap</code> and <code>es</code>), and sometimes they are punctuation (e.g. <code>.</code>).</p>
<p>One nice thing about BPE is that it can encode any arbitrary string. If it encounters something that is not present in the vocabulary, it just breaks it down into substrings it does understand:</p>
<pre><code class="language-python">>>> [encoder.decoder[i] for i in encoder.encode("zjqfl")]
['z', 'j', 'q', 'fl']
</code></pre>
<p>We can also check the size of the vocabulary:</p>
<pre><code class="language-python">>>> len(encoder.decoder)
50257
</code></pre>
<p>The vocabulary, as well as the byte-pair merges which determine how strings are broken down, is obtained by <em>training</em> the tokenizer. When we load the tokenizer, we're loading the already trained vocab and byte-pair merges from some files, which were downloaded alongside the model files when we ran <code>load_encoder_hparams_and_params</code>. See <code>models/124M/encoder.json</code> (the vocabulary) and <code>models/124M/vocab.bpe</code> (byte-pair merges).</p>
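<p>For intuition about how the byte-pair merges are used at encode time, here is a heavily simplified sketch. It assumes a tiny, made-up table of merge ranks (lower rank means higher priority) and works on plain characters; GPT-2's real <code>encoder.py</code> additionally maps bytes to unicode characters and pre-splits the text with a regex, none of which is shown here.</p>
<pre><code class="language-python">def bpe(word, merge_ranks):
    # start from single characters and repeatedly apply the highest-priority
    # (lowest-rank) merge among adjacent pairs until no known merge applies
    parts = list(word)
    while len(parts) > 1:
        pairs = [(parts[i], parts[i + 1]) for i in range(len(parts) - 1)]
        best = min(pairs, key=lambda p: merge_ranks.get(p, float("inf")))
        if best not in merge_ranks:
            break  # no adjacent pair can be merged any further
        merged, i = [], 0
        while i < len(parts):
            if i < len(parts) - 1 and (parts[i], parts[i + 1]) == best:
                merged.append(parts[i] + parts[i + 1])  # merge this occurrence
                i += 2
            else:
                merged.append(parts[i])
                i += 1
        parts = merged
    return parts

# made-up merge table for illustration only (not GPT-2's real vocab.bpe)
merge_ranks = {("c", "a"): 0, ("ca", "p"): 1, ("e", "s"): 2, ("f", "l"): 3}
print(bpe("capes", merge_ranks))  # ['cap', 'es']
print(bpe("zjqfl", merge_ranks))  # ['z', 'j', 'q', 'fl']
</code></pre>
<p>Unknown character sequences simply stay split into smaller pieces, which is how BPE can encode any string.</p>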
<h3 id="hyperparameters" tabindex="-1">Hyperparameters</h3>
<p><code>hparams</code> is a dictionary that contains the hyperparameters of our model:</p>
<pre><code class="language-python">>>> hparams
{
"n_vocab": 50257, # number of tokens in our vocabulary
"n_ctx": 1024, # maximum possible sequence length of the input
"n_embd": 768, # embedding dimension (determines the "width" of the network)
"n_head": 12, # number of attention heads (n_embd must be divisible by n_head)
"n_layer": 12 # number of layers (determines the "depth" of the network)
}
</code></pre>
<p>We'll use these symbols in our code's comments to show the underlying shape of things. We'll also use <code>n_seq</code> to denote the length of our input sequence (i.e. <code>n_seq = len(inputs)</code>).</p>
<h3 id="parameters" tabindex="-1">Parameters</h3>
<p><code>params</code> is a nested, JSON-like dictionary that holds the trained weights of our model. Its leaf nodes are NumPy arrays. If we print <code>params</code>, replacing the arrays with their shapes, we get:</p>
<pre><code class="language-python">>>> import numpy as np
>>> def shape_tree(d):
...     if isinstance(d, np.ndarray):
...         return list(d.shape)
...     elif isinstance(d, list):
...         return [shape_tree(v) for v in d]
...     elif isinstance(d, dict):
...         return {k: shape_tree(v) for k, v in d.items()}
...     else:
...         raise ValueError("uh oh")
...
>>> print(shape_tree(params))
{
    "wpe": [1024, 768],
    "wte": [50257, 768],
    "ln_f": {"b": [768], "g": [768]},
    "blocks": [
        {
            "attn": {
                "c_attn": {"b": [2304], "w": [768, 2304]},
                "c_proj": {"b": [768], "w": [768, 768]},
            },
            "ln_1": {"b": [768], "g": [768]},
            "ln_2": {"b": [768], "g": [768]},
            "mlp": {
                "c_fc": {"b": [3072], "w": [768, 3072]},
                "c_proj": {"b": [768], "w": [3072, 768]},
            },
        },
        ... # repeat for n_layers
    ]
}
</code></pre>
<p>These are loaded from the original OpenAI tensorflow checkpoint:</p>
<pre><code class="language-python">>>> import tensorflow as tf
>>> tf_ckpt_path = tf.train.latest_checkpoint("models/124M")
>>> for name, _ in tf.train.list_variables(tf_ckpt_path):
...     arr = tf.train.load_variable(tf_ckpt_path, name).squeeze()
...     print(f"{name}: {arr.shape}")
model/h0/attn/c_attn/b: (2304,)
model/h0/attn/c_attn/w: (768, 2304)
model/h0/attn/c_proj/b: (768,)
model/h0/attn/c_proj/w: (768, 768)
model/h0/ln_1/b: (768,)
model/h0/ln_1/g: (768,)
model/h0/ln_2/b: (768,)
model/h0/ln_2/g: (768,)
model/h0/mlp/c_fc/b: (3072,)
model/h0/mlp/c_fc/w: (768, 3072)
model/h0/mlp/c_proj/b: (768,)
model/h0/mlp/c_proj/w: (3072, 768)
model/h1/attn/c_attn/b: (2304,)
model/h1/attn/c_attn/w: (768, 2304)
...
model/h9/mlp/c_proj/b: (768,)
model/h9/mlp/c_proj/w: (3072, 768)
model/ln_f/b: (768,)
model/ln_f/g: (768,)
model/wpe: (1024, 768)
model/wte: (50257, 768)
</code></pre>
<p>The <a href="https://github.com/jaymody/picoGPT/blob/29e78cc52b58ed2c1c483ffea2eb46ff6bdec785/utils.py#L43-L65">following code</a> converts the above tensorflow variables into our <code>params</code> dictionary.</p>
<p>For reference, here are the shapes of <code>params</code>, but with the numbers replaced by the <code>hparams</code> they represent:</p>
<pre><code class="language-python">{
    "wpe": [n_ctx, n_embd],
    "wte": [n_vocab, n_embd],
    "ln_f": {"b": [n_embd], "g": [n_embd]},
    "blocks": [
        {
            "attn": {
                "c_attn": {"b": [3*n_embd], "w": [n_embd, 3*n_embd]},
                "c_proj": {"b": [n_embd], "w": [n_embd, n_embd]},
            },
            "ln_1": {"b": [n_embd], "g": [n_embd]},
            "ln_2": {"b": [n_embd], "g": [n_embd]},
            "mlp": {
                "c_fc": {"b": [4*n_embd], "w": [n_embd, 4*n_embd]},
                "c_proj": {"b": [n_embd], "w": [4*n_embd, n_embd]},
            },
        },
        ... # repeat for n_layers
    ]
}
</code></pre>
<p>You'll probably want to come back to reference this dictionary to check the shape of the weights as we implement our GPT. We'll match the variable names in our code with the keys of this dictionary for consistency.</p>
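<p>As a sanity check on these shapes, we can add up the sizes of every array. The sketch below (the <code>count_params</code> helper is hypothetical) reproduces the "124M" in the model's name. Note that there is no separate output projection matrix in <code>params</code>, so nothing beyond the arrays listed above is counted.</p>
<pre><code class="language-python">def count_params(n_vocab, n_ctx, n_embd, n_layer):
    # sum the sizes of every array in the params shape tree
    embeddings = n_vocab * n_embd + n_ctx * n_embd  # wte + wpe
    attn = (n_embd * 3 * n_embd + 3 * n_embd) + (n_embd * n_embd + n_embd)  # c_attn + c_proj
    mlp = (n_embd * 4 * n_embd + 4 * n_embd) + (4 * n_embd * n_embd + n_embd)  # c_fc + c_proj
    layer_norms = 2 * (2 * n_embd)  # ln_1 + ln_2, each with g and b
    per_block = attn + mlp + layer_norms
    ln_f = 2 * n_embd  # final layer norm g and b
    return embeddings + n_layer * per_block + ln_f

print(count_params(n_vocab=50257, n_ctx=1024, n_embd=768, n_layer=12))  # 124439808
</code></pre>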
<h2 id="basic-layers" tabindex="-1">Basic Layers</h2>
<hr />
<p>One last thing before we get into the GPT architecture itself: let's implement some of the more basic neural network layers that are not specific to GPTs.</p>
<h3 id="gelu" tabindex="-1">GELU</h3>
<p>The non-linearity (<strong>activation function</strong>) of choice for GPT-2 is <a href="https://arxiv.org/pdf/1606.08415.pdf">GELU (Gaussian Error Linear Units)</a>, an alternative to ReLU:</p>
<figure><img src="https://miro.medium.com/max/491/1*kwHcbpKUNLda8tvCiwudqQ.png" alt="" /><figcaption>Figure 1 from the GELU paper</figcaption></figure>
<p>It is approximated by the following function:</p>
<pre><code class="language-python">def gelu(x):
    return 0.5 * x * (1 + np.tanh(np.sqrt(2 / np.pi) * (x + 0.044715 * x**3)))
</code></pre>
<p>Like ReLU, GELU operates element-wise on the input:</p>
<pre><code class="language-python">>>> gelu(np.array([[1, 2], [-2, 0.5]]))
array([[ 0.84119, 1.9546 ],
[-0.0454 , 0.34571]])
</code></pre>
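<p>The formula above is the tanh approximation; the exact GELU is \(x \cdot \Phi(x)\), where \(\Phi\) is the standard normal CDF. A minimal sketch comparing the two over a range of inputs, using Python's <code>math.erf</code> for the exact version:</p>
<pre><code class="language-python">import math
import numpy as np

def gelu(x):
    # tanh approximation (same as above)
    return 0.5 * x * (1 + np.tanh(np.sqrt(2 / np.pi) * (x + 0.044715 * x**3)))

def gelu_exact(x):
    # exact definition: x * Phi(x), with Phi written via the error function
    erf = np.vectorize(math.erf)
    return 0.5 * x * (1 + erf(x / np.sqrt(2)))

x = np.linspace(-5, 5, 1001)
max_err = np.max(np.abs(gelu(x) - gelu_exact(x)))
print(max_err)
</code></pre>
<p>The maximum absolute difference on this range is well under 0.001, which is why the cheaper approximation is a reasonable stand-in.</p>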
<h3 id="softmax" tabindex="-1">Softmax</h3>
<p>Good ole <a href="https://en.wikipedia.org/wiki/Softmax_function">softmax</a>:</p>
<p>\[
\text{softmax}(x)_i = \frac{e^{x_i}}{\sum_j e^{x_j}}
\]</p>
<pre><code class="language-python">def softmax(x):
    exp_x = np.exp(x - np.max(x, axis=-1, keepdims=True))
    return exp_x / np.sum(exp_x, axis=-1, keepdims=True)
</code></pre>
<p>We use the <a href="https://jaykmody.com/blog/stable-softmax/"><code>max(x)</code> trick for numerical stability</a>.</p>
<p>Softmax is used to convert a set of real numbers (between \(-\infty\) and \(\infty\)) to probabilities (between 0 and 1, with the numbers all summing to 1). We apply <code>softmax</code> over the last axis of the input.</p>
<pre><code class="language-python">>>> x = softmax(np.array([[2, 100], [-5, 0]]))
>>> x
array([[0.00034, 0.99966],
[0.26894, 0.73106]])
>>> x.sum(axis=-1)
array([1., 1.])
</code></pre>
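<p>To see why the <code>max(x)</code> trick matters, here is a minimal sketch comparing a naive softmax (the name <code>softmax_naive</code> is just for illustration) with the stable version on large logits:</p>
<pre><code class="language-python">import numpy as np

def softmax_naive(x):
    # no max subtraction: np.exp overflows to inf for large inputs
    exp_x = np.exp(x)
    return exp_x / np.sum(exp_x, axis=-1, keepdims=True)

def softmax(x):
    # subtracting the row max does not change the result mathematically,
    # but keeps every exponent at or below 0 so np.exp cannot overflow
    exp_x = np.exp(x - np.max(x, axis=-1, keepdims=True))
    return exp_x / np.sum(exp_x, axis=-1, keepdims=True)

x = np.array([1000.0, 1000.0])
with np.errstate(over="ignore", invalid="ignore"):
    print(softmax_naive(x))  # [nan nan], since inf / inf is nan
print(softmax(x))  # [0.5 0.5]
</code></pre>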
<h3 id="layer-normalization" tabindex="-1">Layer Normalization</h3>
<p><a href="https://arxiv.org/pdf/1607.06450.pdf">Layer normalization</a> standardizes values to have a mean of 0 and a variance of 1:</p>
<p>\[
\text{LayerNorm}(x) = \gamma \cdot \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}} + \beta
\]where \(\mu\) is the mean of \(x\), \(\sigma^2\) is the variance of \(x\), \(\epsilon\) is a small constant for numerical stability, and \(\gamma\) and \(\beta\) are learnable parameters.</p>
<pre><code class="language-python">def layer_norm(x, g, b, eps: float = 1e-5):
    mean = np.mean(x, axis=-1, keepdims=True)
    variance = np.var(x, axis=-1, keepdims=True)
    x = (x - mean) / np.sqrt(variance + eps) # normalize x to have mean=0 and var=1 over last axis
    return g * x + b # scale and offset with gamma/beta params
</code></pre>
<p>Layer normalization ensures that the inputs for each layer are always within a consistent range, which is supposed to speed up and stabilize the training process. Like <a href="https://arxiv.org/pdf/1502.03167.pdf">Batch Normalization</a>, the normalized output is then scaled and offset with two learnable vectors gamma and beta. The small epsilon term in the denominator is used to avoid a division by zero error.</p>
<p>Layer norm is used instead of batch norm in the transformer for <a href="https://stats.stackexchange.com/questions/474440/why-do-transformers-use-layer-norm-instead-of-batch-norm">various reasons</a>. The differences between various normalization techniques are outlined <a href="https://tungmphung.com/deep-learning-normalization-methods/">in this excellent blog post</a>.</p>
<p>We apply layer normalization over the last axis of the input.</p>
<pre><code class="language-python">>>> x = np.array([[2, 2, 3], [-5, 0, 1]])
>>> x = layer_norm(x, g=np.ones(x.shape[-1]), b=np.zeros(x.shape[-1]))
>>> x
array([[-0.70709, -0.70709, 1.41418],
[-1.397 , 0.508 , 0.889 ]])
>>> x.var(axis=-1)
array([0.99996, 1.     ]) # slightly below 1 because of eps in the denominator
>>> x.mean(axis=-1)
array([-0., -0.])
</code></pre>
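<p>To contrast layer norm with batch norm, here is a minimal sketch on random data. The <code>batch_norm_no_affine</code> helper is a hypothetical, stripped-down batch norm (no learnable parameters, no running statistics) included only to show which axis each method normalizes over.</p>
<pre><code class="language-python">import numpy as np

def layer_norm(x, g, b, eps: float = 1e-5):
    # statistics over the last axis: each row (one token) is normalized on its own
    mean = np.mean(x, axis=-1, keepdims=True)
    variance = np.var(x, axis=-1, keepdims=True)
    return g * (x - mean) / np.sqrt(variance + eps) + b

def batch_norm_no_affine(x, eps: float = 1e-5):
    # statistics over the batch axis: each column (feature) is normalized across rows
    mean = np.mean(x, axis=0, keepdims=True)
    variance = np.var(x, axis=0, keepdims=True)
    return (x - mean) / np.sqrt(variance + eps)

x = np.random.default_rng(0).normal(loc=3.0, scale=2.0, size=(4, 8))  # [n_seq, n_embd]
g, b = np.ones(8), np.zeros(8)

ln = layer_norm(x, g, b)
bn = batch_norm_no_affine(x)
print(np.allclose(ln.mean(axis=-1), 0))  # True: every row has mean ~0
print(np.allclose(bn.mean(axis=0), 0))  # True: every column has mean ~0

# layer norm of a single row does not depend on the other rows
print(np.allclose(layer_norm(x[:1], g, b), ln[:1]))  # True
</code></pre>
<p>The last check is the practical difference: layer norm works the same on one token or a whole batch, while batch norm statistics depend on what else is in the batch.</p>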
<h3 id="linear" tabindex="-1">Linear</h3>
<p>Your standard matrix multiplication + bias:</p>
<pre><code class="language-python">def linear(x, w, b): # [m, in], [in, out], [out] -> [m, out]
    return x @ w + b
</code></pre>
<p>Linear layers are often referred to as <strong>projections</strong> (since they are projecting from one vector space to another vector space).</p>
<pre><code class="language-python">>>> x = np.random.normal(size=(64, 784)) # input dim = 784, batch/sequence dim = 64
>>> w = np.random.normal(size=(784, 10)) # output dim = 10
>>> b = np.random.normal(size=(10,))
>>> x.shape # shape before linear projection
(64, 784)
>>> linear(x, w, b).shape # shape after linear projection
(64, 10)
</code></pre>
<h2 id="gpt-architecture" tabindex="-1">GPT Architecture</h2>
<hr />
<p>The GPT architecture follows that of the <a href="https://arxiv.org/pdf/1706.03762.pdf">transformer</a>:</p>
<figure><img src="https://machinelearningmastery.com/wp-content/uploads/2021/08/attention_research_1.png" alt="" /><figcaption>Figure 1 from Attention is All You Need</figcaption></figure>
<p>But uses only the decoder stack (the right part of the diagram):</p>
<figure><img src="https://i.imgur.com/c4Z6PG8.png" alt="" /><figcaption>GPT Architecture</figcaption></figure>
<p>Note, the middle "cross-attention" layer is also removed since we got rid of the encoder.</p>
<p>At a high level, the GPT architecture has three sections:</p>
<ul>
<li>Text + positional <strong>embeddings</strong></li>
<li>A transformer <strong>decoder stack</strong></li>
<li>A <strong>projection to vocab</strong> step</li>
</ul>
<p>In code, it looks like this:</p>
<pre><code class="language-python">def gpt2(inputs, wte, wpe, blocks, ln_f, n_head): # [n_seq] -> [n_seq, n_vocab]
    # token + positional embeddings
    x = wte[inputs] + wpe[range(len(inputs))] # [n_seq] -> [n_seq, n_embd]

    # forward pass through n_layer transformer blocks
    for block in blocks:
        x = transformer_block(x, **block, n_head=n_head) # [n_seq, n_embd] -> [n_seq, n_embd]

    # projection to vocab
    x = layer_norm(x, **ln_f) # [n_seq, n_embd] -> [n_seq, n_embd]
    return x @ wte.T # [n_seq, n_embd] -> [n_seq, n_vocab]
</code></pre>
<p>Let's break down each of these three sections into more detail.</p>
<h3 id="embeddings" tabindex="-1">Embeddings</h3>
<h4 id="token-embeddings" tabindex="-1">Token Embeddings</h4>
<p>Token IDs by themselves are not very good representations for a neural network. For one, the relative magnitudes of the token IDs falsely communicate information (for example, if <code>Apple = 5</code> and <code>Table = 10</code> in our vocab, then we are implying that <code>2 * Apple = Table</code>). Secondly, a single number is not a lot of <em>dimensionality</em> for a neural network to work with.</p>
<p>To address these limitations, we'll take advantage of <a href="https://jaykmody.com/blog/attention-intuition/#word-vectors-and-similarity">word vectors</a>, specifically via a learned embedding matrix:</p>
<pre><code class="language-python">wte[inputs] # [n_seq] -> [n_seq, n_embd]
</code></pre>
<p>Recall, <code>wte</code> is a <code>[n_vocab, n_embd]</code> matrix. It acts as a lookup table, where the \(i\)th row in the matrix corresponds to the learned vector for the \(i\)th token in our vocabulary. <code>wte[inputs]</code> uses <a href="https://numpy.org/doc/stable/user/basics.indexing.html#integer-array-indexing">integer array indexing</a> to retrieve the vectors corresponding to each token in our input.</p>
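<p>The lookup is equivalent to multiplying a one-hot encoding of the token IDs by <code>wte</code>, just without materializing the mostly-zero one-hot matrix. A minimal sketch with a toy, randomly initialized <code>wte</code>:</p>
<pre><code class="language-python">import numpy as np

n_vocab, n_embd = 6, 4
wte = np.random.default_rng(0).normal(size=(n_vocab, n_embd))  # toy embedding matrix
inputs = [2, 5, 2]

lookup = wte[inputs]  # integer array indexing: [n_seq, n_embd]
one_hot = np.eye(n_vocab)[inputs]  # [n_seq, n_vocab], a single 1 per row
matmul = one_hot @ wte  # selects the same rows via a matrix multiply

print(np.allclose(lookup, matmul))  # True
</code></pre>
<p>Repeated tokens (here, id 2) get identical vectors, as expected from a lookup table.</p>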
<p>Like any other parameter in our network, <code>wte</code> is learned. That is, it is randomly initialized at the start of training and then updated via gradient descent.</p>
<h4 id="positional-embeddings" tabindex="-1">Positional Embeddings</h4>
<p>One quirk of the transformer architecture is that it doesn't take into account position. That is, if we randomly shuffled our input and then accordingly unshuffled the output, the output would be the same as if we never shuffled the input in the first place (the ordering of inputs doesn't have any effect on the output).</p>
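<p>We can check this claim numerically for plain self attention. This sketch makes simplifying assumptions: no causal mask, no positional embeddings, and no learned projections. Shuffling the input rows and then attending gives the same result as attending and then shuffling the output rows.</p>
<pre><code class="language-python">import numpy as np

def softmax(x):
    exp_x = np.exp(x - np.max(x, axis=-1, keepdims=True))
    return exp_x / np.sum(exp_x, axis=-1, keepdims=True)

def self_attention(x):
    # plain (unmasked) self attention with no positional information
    return softmax(x @ x.T / np.sqrt(x.shape[-1])) @ x

rng = np.random.default_rng(0)
x = rng.normal(size=(5, 8))  # [n_seq, n_embd]
perm = rng.permutation(5)  # a random shuffle of the positions

shuffled_then_attended = self_attention(x[perm])
attended_then_shuffled = self_attention(x)[perm]
print(np.allclose(shuffled_then_attended, attended_then_shuffled))  # True
</code></pre>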
<p>Of course, the ordering of words is a crucial part of language (duh), so we need some way to encode positional information into our inputs. For this, we can just use another learned embedding matrix:</p>
<pre><code class="language-python">wpe[range(len(inputs))] # [n_seq] -> [n_seq, n_embd]
</code></pre>
<p>Recall, <code>wpe</code> is a <code>[n_ctx, n_embd]</code> matrix. The \(i\)th row of the matrix contains a vector that encodes information about the \(i\)th position in the input. Similar to <code>wte</code>, this matrix is learned during gradient descent.</p>
<p>Notice, this restricts our model to a maximum sequence length of <code>n_ctx</code>.<sup class="footnote-ref"><a href="https://jaykmody.com/blog/gpt-from-scratch/#fn3" id="fnref3">[3]</a></sup> That is, <code>len(inputs) <= n_ctx</code> must hold.</p>
<h4 id="combined" tabindex="-1">Combined</h4>
<p>We can add our token and positional embeddings to get a combined embedding that encodes both token and positional information.</p>
<pre><code class="language-python"># token + positional embeddings
x = wte[inputs] + wpe[range(len(inputs))] # [n_seq] -> [n_seq, n_embd]
# x[i] represents the word embedding for the ith word + the positional
# embedding for the ith position
</code></pre>
<h3 id="decoder-stack" tabindex="-1">Decoder Stack</h3>
<p>This is where all the magic happens and the "deep" in deep learning comes in. We pass our embedding through a stack of <code>n_layer</code> transformer decoder blocks.</p>
<pre><code class="language-python"># forward pass through n_layer transformer blocks
for block in blocks:
    x = transformer_block(x, **block, n_head=n_head) # [n_seq, n_embd] -> [n_seq, n_embd]
</code></pre>
<p>Stacking more layers is what allows us to control how <em>deep</em> our network is. GPT-3, for example, has a <a href="https://preview.redd.it/n9fgba8b0qr01.png?auto=webp&s=e86d2d3447c777d3222016e81a0adfaec1a95592">whopping 96 layers</a>. On the other hand, choosing a larger <code>n_embd</code> value allows us to control how <em>wide</em> our network is (for example, GPT-3 uses an embedding size of 12288).</p>
<h3 id="projection-to-vocab" tabindex="-1">Projection to Vocab</h3>
<p>In our final step, we project the output of the final transformer block to a probability distribution over our vocab:</p>
<pre><code class="language-python"># projection to vocab
x = layer_norm(x, **ln_f) # [n_seq, n_embd] -> [n_seq, n_embd]
return x @ wte.T # [n_seq, n_embd] -> [n_seq, n_vocab]
</code></pre>
<p>Couple things to note here:</p>
<ol>
<li>We first pass <code>x</code> through a <strong>final layer normalization</strong> layer before doing the projection to vocab. This is specific to the GPT-2 architecture (this is not present in the original GPT and Transformer papers).</li>
<li>We are <strong>reusing the embedding matrix</strong> <code>wte</code> for the projection. Other GPT implementations may choose to use a separate learned weight matrix for the projection, however sharing the embedding matrix has a couple of advantages:
<ul>
<li>You save some parameters (although at GPT-3 scale, this is negligible).</li>
<li>Since the matrix is responsible for mapping both <em>to</em> words and <em>from</em> words, in theory, it <em>may</em> learn a richer representation compared to having two separate matrices.</li>
</ul>
</li>
<li>We <strong>don't apply <code>softmax</code></strong> at the end, so our outputs will be <a href="https://developers.google.com/machine-learning/glossary/#logits">logits</a> instead of probabilities between 0 and 1. This is done for several reasons:
<ul>
<li><code>softmax</code> is <a href="https://en.wikipedia.org/wiki/Monotonic_function">monotonic</a>, so for greedy sampling <code>np.argmax(logits)</code> is equivalent to <code>np.argmax(softmax(logits))</code> making <code>softmax</code> redundant</li>
<li><code>softmax</code> is irreversible, meaning we can always go from <code>logits</code> to <code>probabilities</code> by applying <code>softmax</code>, but we can't go back to <code>logits</code> from <code>probabilities</code>, so for maximum flexibility, we output the <code>logits</code></li>
<li>Numerical stability (for example, to compute cross entropy loss, taking <a href="https://jaykmody.com/blog/stable-softmax/#cross-entropy-and-log-softmax"><code>log(softmax(logits))</code> is numerically unstable compared to <code>log_softmax(logits)</code></a>)</li>
</ul>
</li>
</ol>
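<p>The first and last points can both be checked with a minimal sketch. The <code>log_softmax</code> helper below is written out by hand using the same max trick as <code>softmax</code> (it is not a NumPy function).</p>
<pre><code class="language-python">import numpy as np

def softmax(x):
    exp_x = np.exp(x - np.max(x, axis=-1, keepdims=True))
    return exp_x / np.sum(exp_x, axis=-1, keepdims=True)

def log_softmax(x):
    # log(softmax(x)) rewritten as x - max - log(sum(exp(x - max))),
    # so no probability underflows to 0 before the log is taken
    x_max = np.max(x, axis=-1, keepdims=True)
    return x - x_max - np.log(np.sum(np.exp(x - x_max), axis=-1, keepdims=True))

logits = np.array([2.0, -1000.0, 0.5])

# softmax is monotonic, so greedy sampling picks the same token either way
print(np.argmax(logits) == np.argmax(softmax(logits)))  # True

with np.errstate(divide="ignore"):
    print(np.log(softmax(logits)))  # second entry is -inf: exp(-1002) underflowed to 0
print(log_softmax(logits))  # second entry stays finite (about -1002.2)
</code></pre>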
<p>The projection to vocab step is also sometimes called the <strong>language modeling head</strong>. What does "head" mean? Once your GPT is pre-trained, you can swap out the language modeling head with some other kind of projection, like a <strong>classification head</strong> for fine-tuning the model on some classification task. So your model can have multiple heads, kind of like a <a href="https://en.wikipedia.org/wiki/Lernaean_Hydra">hydra</a>.</p>
<p>So that's the GPT architecture at a high level. Let's dig a bit deeper into what the decoder blocks are doing.</p>
<h3 id="decoder-block" tabindex="-1">Decoder Block</h3>
<p>The transformer decoder block consists of two sublayers:</p>
<ol>
<li>Multi-head causal self attention</li>
<li>Position-wise feed forward neural network</li>
</ol>
<pre><code class="language-python">def transformer_block(x, mlp, attn, ln_1, ln_2, n_head): # [n_seq, n_embd] -> [n_seq, n_embd]
    # multi-head causal self attention
    x = x + mha(layer_norm(x, **ln_1), **attn, n_head=n_head) # [n_seq, n_embd] -> [n_seq, n_embd]

    # position-wise feed forward network
    x = x + ffn(layer_norm(x, **ln_2), **mlp) # [n_seq, n_embd] -> [n_seq, n_embd]

    return x
</code></pre>
<p>Each sublayer applies layer normalization to its input and uses a residual connection (i.e. adds the input of the sublayer to the output of the sublayer).</p>
<p>Some things to note:</p>
<ol>
<li><strong>Multi-head causal self attention</strong> is what facilitates the communication between the inputs. Nowhere else in the network does the model allow inputs to "see" each other. The embeddings, position-wise feed forward network, layer norms, and projection to vocab all operate on our inputs position-wise. Modeling relationships between inputs is tasked solely to attention.</li>
<li>The <strong>Position-wise feed forward neural network</strong> is just a regular 2 layer fully connected neural network. This just adds a bunch of learnable parameters for our model to work with to facilitate learning.</li>
<li>In the original transformer paper, layer norm is placed on the output <code>layer_norm(x + sublayer(x))</code> while we place layer norm on the input <code>x + sublayer(layer_norm(x))</code> to match GPT-2. This is referred to as <strong>pre-norm</strong> and has been shown to be <a href="https://arxiv.org/pdf/2002.04745.pdf">important in improving the performance of the transformer</a>.</li>
<li><strong>Residual connections</strong> (popularized by <a href="https://arxiv.org/pdf/1512.03385.pdf">ResNet</a>) serve a couple of different purposes:
<ol>
<li>Makes it easier to optimize neural networks that are deep (i.e. networks that have lots of layers). The idea here is that we are providing "shortcuts" for the gradients to flow back through the network, making it easier to optimize the earlier layers in the network.</li>
<li>Without residual connections, deeper models see a degradation in performance when adding more layers (possibly because it's hard for the gradients to flow all the way back through a deep network without losing information). Residual connections seem to give a bit of an accuracy boost for deeper networks.</li>
<li>Can help with the <a href="https://programmathically.com/understanding-the-exploding-and-vanishing-gradients-problem/">vanishing/exploding gradients problem</a>.</li>
</ol>
</li>
</ol>
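<p>To make the pre-norm vs post-norm distinction concrete, here is a minimal sketch with a hypothetical stand-in sublayer (a random linear map followed by tanh) and gamma/beta fixed to 1/0. It shows that post-norm outputs are normalized while pre-norm outputs are not, which is one common explanation for why GPT-2 adds the final <code>ln_f</code> before the projection to vocab.</p>
<pre><code class="language-python">import numpy as np

def layer_norm(x, eps=1e-5):
    # gamma = 1 and beta = 0 for simplicity
    mean = np.mean(x, axis=-1, keepdims=True)
    variance = np.var(x, axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(variance + eps)

def post_norm_sublayer(x, sublayer):
    # original transformer: normalize after the residual add
    return layer_norm(x + sublayer(x))

def pre_norm_sublayer(x, sublayer):
    # GPT-2: normalize the sublayer input; the residual path is left untouched
    return x + sublayer(layer_norm(x))

rng = np.random.default_rng(0)
w = rng.normal(size=(8, 8)) * 0.1

def sublayer(h):
    # hypothetical stand-in for attention or the feed forward network
    return np.tanh(h @ w)

x = rng.normal(loc=5.0, size=(4, 8))  # [n_seq, n_embd]
post = post_norm_sublayer(x, sublayer)
pre = pre_norm_sublayer(x, sublayer)
print(np.allclose(post.mean(axis=-1), 0, atol=1e-6))  # True: output is normalized
print(np.allclose(pre.mean(axis=-1), 0, atol=1e-6))  # False: output is not normalized
</code></pre>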
<p>Let's dig a little deeper into the 2 sublayers.</p>
<h3 id="position-wise-feed-forward-network" tabindex="-1">Position-wise Feed Forward Network</h3>
<p>This is just a simple multi-layer perceptron with 2 layers:</p>
<pre><code class="language-python">def ffn(x, c_fc, c_proj): # [n_seq, n_embd] -> [n_seq, n_embd]
    # project up
    a = gelu(linear(x, **c_fc)) # [n_seq, n_embd] -> [n_seq, 4*n_embd]

    # project back down
    x = linear(a, **c_proj) # [n_seq, 4*n_embd] -> [n_seq, n_embd]

    return x
</code></pre>
<p>Nothing super fancy here, we just project from <code>n_embd</code> up to a higher dimension <code>4*n_embd</code> and then back down to <code>n_embd</code><sup class="footnote-ref"><a href="https://jaykmody.com/blog/gpt-from-scratch/#fn4" id="fnref4">[4]</a></sup>.</p>
<p>Recall, from our <code>params</code> dictionary, that our <code>mlp</code> params look like this:</p>
<pre><code class="language-python">"mlp": {
"c_fc": {"b": [4*n_embd], "w": [n_embd, 4*n_embd]},
"c_proj": {"b": [n_embd], "w": [4*n_embd, n_embd]},
}
</code></pre>
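<p>A minimal sketch that runs <code>ffn</code> end to end with random weights of the shapes above (the 0.02 scale and <code>n_seq = 3</code> are arbitrary choices) and counts its parameters. At <code>n_embd = 768</code> the feed forward network holds 4,722,432 parameters, roughly two thirds of each decoder block's parameters in GPT-2 124M.</p>
<pre><code class="language-python">import numpy as np

def gelu(x):
    return 0.5 * x * (1 + np.tanh(np.sqrt(2 / np.pi) * (x + 0.044715 * x**3)))

def linear(x, w, b):  # [m, in], [in, out], [out] -> [m, out]
    return x @ w + b

def ffn(x, c_fc, c_proj):  # [n_seq, n_embd] -> [n_seq, n_embd]
    a = gelu(linear(x, **c_fc))  # project up to 4*n_embd
    return linear(a, **c_proj)  # project back down to n_embd

n_seq, n_embd = 3, 768
rng = np.random.default_rng(0)
mlp = {  # random weights with the shapes from the params dictionary
    "c_fc": {"w": rng.normal(size=(n_embd, 4 * n_embd)) * 0.02, "b": np.zeros(4 * n_embd)},
    "c_proj": {"w": rng.normal(size=(4 * n_embd, n_embd)) * 0.02, "b": np.zeros(n_embd)},
}
out = ffn(rng.normal(size=(n_seq, n_embd)), **mlp)
n_params = sum(p.size for layer in mlp.values() for p in layer.values())
print(out.shape)  # (3, 768)
print(n_params)  # 4722432
</code></pre>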
<h3 id="multi-head-causal-self-attention" tabindex="-1">Multi-Head Causal Self Attention</h3>
<p>This layer is probably the most difficult part of the transformer to understand. So let's work our way up to "Multi-Head Causal Self Attention" by breaking each word down into its own section:</p>
<ol>
<li>Attention</li>
<li>Self</li>
<li>Causal</li>
<li>Multi-Head</li>
</ol>
<h4 id="attention" tabindex="-1">Attention</h4>
<p>I have another <a href="https://jaykmody.com/blog/attention-intuition/">blog post</a> on this topic, where we derive the scaled dot product equation proposed in the <a href="https://arxiv.org/pdf/1706.03762.pdf">original transformer paper</a> from the ground up:<br />
\[\text{attention}(Q, K, V) = \text{softmax}(\frac{QK^T}{\sqrt{d_k}})V\]As such, I'm going to skip an explanation for attention in this post. You can also reference <a href="https://lilianweng.github.io/posts/2018-06-24-attention/">Lilian Weng's Attention? Attention!</a> and <a href="https://jalammar.github.io/visualizing-neural-machine-translation-mechanics-of-seq2seq-models-with-attention/">Jay Alammar's The Illustrated Transformer</a> which are also great explanations for attention.</p>
<p>We'll just adapt our attention implementation from my blog post:</p>
<pre><code class="language-python">def attention(q, k, v): # [n_q, d_k], [n_k, d_k], [n_k, d_v] -> [n_q, d_v]
    return softmax(q @ k.T / np.sqrt(q.shape[-1])) @ v
</code></pre>
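<p>As a quick sanity check, the sketch below shows that each row of attention weights is a probability distribution over the keys, and that the output has one <code>d_v</code>-dimensional row per query. It assumes a max-subtracted (numerically stable) <code>softmax</code>; the random inputs are made up.</p>
<pre><code class="language-python">import numpy as np

def softmax(x):
    # subtract the row max for numerical stability (doesn't change the result)
    exp_x = np.exp(x - np.max(x, axis=-1, keepdims=True))
    return exp_x / np.sum(exp_x, axis=-1, keepdims=True)

def attention(q, k, v):  # [n_q, d_k], [n_k, d_k], [n_k, d_v] -> [n_q, d_v]
    return softmax(q @ k.T / np.sqrt(q.shape[-1])) @ v

rng = np.random.default_rng(0)
q = rng.normal(size=(2, 4))  # 2 queries, d_k = 4
k = rng.normal(size=(5, 4))  # 5 keys
v = rng.normal(size=(5, 3))  # 5 values, d_v = 3

weights = softmax(q @ k.T / np.sqrt(q.shape[-1]))  # [n_q, n_k]
out = attention(q, k, v)
print(weights.sum(axis=-1))  # each row sums to 1
print(out.shape)  # (2, 3)
</code></pre>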
<h4 id="self" tabindex="-1">Self</h4>
<p>When <code>q</code>, <code>k</code>, and <code>v</code> all come from the same source, we are performing <a href="https://lilianweng.github.io/posts/2018-06-24-attention/#self-attention">self-attention</a> (i.e. letting our input sequence attend to itself):</p>
<pre><code class="language-python">def self_attention(x): # [n_seq, n_embd] -> [n_seq, n_embd]
    return attention(q=x, k=x, v=x)
</code></pre>
<p>For example, if our input is <code>"Jay went to the store, he bought 10 apples."</code>, we would be letting the word "he" attend to all the other words, including "Jay", meaning the model can learn to recognize that "he" is referring to "Jay".</p>
<p>We can enhance self attention by introducing projections for <code>q</code>, <code>k</code>, <code>v</code> and the attention output:</p>
<pre><code class="language-python">def self_attention(x, w_k, w_q, w_v, w_proj): # [n_seq, n_embd] -> [n_seq, n_embd]
    # qkv projections
    q = x @ w_q # [n_seq, n_embd] @ [n_embd, n_embd] -> [n_seq, n_embd]
    k = x @ w_k # [n_seq, n_embd] @ [n_embd, n_embd] -> [n_seq, n_embd]
    v = x @ w_v # [n_seq, n_embd] @ [n_embd, n_embd] -> [n_seq, n_embd]
    # perform self attention
    x = attention(q, k, v) # [n_seq, n_embd] -> [n_seq, n_embd]
    # out projection
    x = x @ w_proj # [n_seq, n_embd] @ [n_embd, n_embd] -> [n_seq, n_embd]
    return x
</code></pre>
<p>This enables our model to learn a mapping for <code>q</code>, <code>k</code>, and <code>v</code> that best helps attention distinguish relationships between inputs.</p>
<p>We can reduce the number of matrix multiplications from 4 to just 2 if we combine <code>w_q</code>, <code>w_k</code> and <code>w_v</code> into a single matrix <code>w_fc</code>, perform the projection, and then split the result:</p>
<pre><code class="language-python">def self_attention(x, w_fc, w_proj): # [n_seq, n_embd] -> [n_seq, n_embd]
    # qkv projections
    x = x @ w_fc # [n_seq, n_embd] @ [n_embd, 3*n_embd] -> [n_seq, 3*n_embd]
    # split into qkv
    q, k, v = np.split(x, 3, axis=-1) # [n_seq, 3*n_embd] -> 3 of [n_seq, n_embd]
    # perform self attention
    x = attention(q, k, v) # [n_seq, n_embd] -> [n_seq, n_embd]
    # out projection
    x = x @ w_proj # [n_seq, n_embd] @ [n_embd, n_embd] -> [n_seq, n_embd]
    return x
</code></pre>
<p>This is a bit more efficient as modern accelerators (GPUs) can take better advantage of one large matrix multiplication rather than 3 separate small ones happening sequentially.</p>
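<p>To convince yourself that the fused projection computes exactly the same <code>q</code>, <code>k</code>, and <code>v</code>, here is a small check with made-up random weights. Stacking the three weight matrices side by side means each third of the output columns is one of the original projections.</p>
<pre><code class="language-python">import numpy as np

rng = np.random.default_rng(0)
n_seq, n_embd = 4, 6
x = rng.normal(size=(n_seq, n_embd))
w_q, w_k, w_v = (rng.normal(size=(n_embd, n_embd)) for _ in range(3))

# three separate projections
q1, k1, v1 = x @ w_q, x @ w_k, x @ w_v

# one fused projection: concatenate the weights column-wise, then split the result
w_fc = np.concatenate([w_q, w_k, w_v], axis=-1)  # [n_embd, 3*n_embd]
q2, k2, v2 = np.split(x @ w_fc, 3, axis=-1)  # 3 of [n_seq, n_embd]

print(np.allclose(q1, q2), np.allclose(k1, k2), np.allclose(v1, v2))  # True True True
</code></pre>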
<p>Finally, we add bias vectors to match the implementation of GPT-2, use our <code>linear</code> function, and rename our parameters to match our <code>params</code> dictionary:</p>
<pre><code class="language-python">def self_attention(x, c_attn, c_proj): # [n_seq, n_embd] -> [n_seq, n_embd]
    # qkv projections
    x = linear(x, **c_attn) # [n_seq, n_embd] -> [n_seq, 3*n_embd]
    # split into qkv
    q, k, v = np.split(x, 3, axis=-1) # [n_seq, 3*n_embd] -> 3 of [n_seq, n_embd]
    # perform self attention
    x = attention(q, k, v) # [n_seq, n_embd] -> [n_seq, n_embd]
    # out projection
    x = linear(x, **c_proj) # [n_seq, n_embd] -> [n_seq, n_embd]
    return x
</code></pre>
<p>Recall, from our <code>params</code> dictionary, our <code>attn</code> params look like this:</p>
<pre><code class="language-python">"attn": {
"c_attn": {"b": [3*n_embd], "w": [n_embd, 3*n_embd]},
"c_proj": {"b": [n_embd], "w": [n_embd, n_embd]},
},
</code></pre>
<h4 id="causal" tabindex="-1">Causal</h4>
<p>There is a bit of an issue with our current self-attention setup: our inputs can see into the future! For example, if our input is <code>["not", "all", "heroes", "wear", "capes"]</code>, during self attention we are allowing "wear" to see "capes". This means our output probabilities for "wear" will be biased since the model already knows the correct answer is "capes". This is no good since our model will just learn that the correct answer for input \(i\) can be taken from input \(i+1\).</p>
<p>To prevent this, we need to somehow modify our attention matrix to <em>hide</em> or <strong>mask</strong> our inputs from being able to see into the future. For example, let's pretend our attention matrix looks like this:</p>
<pre><code> not all heroes wear capes
not 0.116 0.159 0.055 0.226 0.443
all 0.180 0.397 0.142 0.106 0.175
heroes 0.156 0.453 0.028 0.129 0.234
wear 0.499 0.055 0.133 0.017 0.295
capes 0.089 0.290 0.240 0.228 0.153
</code></pre>
<p>Each row corresponds to a query and the columns to a key. In this case, looking at the row for "wear", you can see that it is attending to "capes" in the last column with a weight of 0.295. To prevent this, we want to set that entry to <code>0.0</code>:</p>
<pre><code> not all heroes wear capes
not 0.116 0.159 0.055 0.226 0.443
all 0.180 0.397 0.142 0.106 0.175
heroes 0.156 0.453 0.028 0.129 0.234
wear 0.499 0.055 0.133 0.017 0.
capes 0.089 0.290 0.240 0.228 0.153
</code></pre>
<p>In general, to prevent all the queries in our input from looking into the future, we set all positions \(i, j\) where \(j > i\) to <code>0</code>:</p>
<pre><code> not all heroes wear capes
not 0.116 0. 0. 0. 0.
all 0.180 0.397 0. 0. 0.
heroes 0.156 0.453 0.028 0. 0.
wear 0.499 0.055 0.133 0.017 0.
capes 0.089 0.290 0.240 0.228 0.153
</code></pre>
<p>We call this <strong>masking</strong>. One issue with our above masking approach is that our rows no longer sum to 1 (since we are setting entries to 0 after the <code>softmax</code> has been applied). To make sure our rows still sum to 1, we need to modify our attention matrix before the <code>softmax</code> is applied.</p>
<p>This can be achieved by setting entries that are to be masked to \(-\infty\) prior to the <code>softmax</code><sup class="footnote-ref"><a href="https://jaykmody.com/blog/gpt-from-scratch/#fn5" id="fnref5">[5]</a></sup>:</p>
<pre><code class="language-python">def attention(q, k, v, mask): # [n_q, d_k], [n_k, d_k], [n_k, d_v], [n_q, n_k] -> [n_q, d_v]
    return softmax(q @ k.T / np.sqrt(q.shape[-1]) + mask) @ v
</code></pre>
<p>where <code>mask</code> is the matrix (for <code>n_seq=5</code>):</p>
<pre><code>0 -1e10 -1e10 -1e10 -1e10
0 0 -1e10 -1e10 -1e10
0 0 0 -1e10 -1e10
0 0 0 0 -1e10
0 0 0 0 0
</code></pre>
<p>We use <code>-1e10</code> instead of <code>-np.inf</code> as <code>-np.inf</code> can cause <code>nans</code>.</p>
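<p>Adding <code>mask</code> to our attention matrix, instead of explicitly setting the masked values to <code>-1e10</code>, works because in practice any attention score of ordinary magnitude plus <code>-1e10</code> is still so negative that it behaves like <code>-inf</code>: its exponential underflows to 0 in the <code>softmax</code>.</p>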
<p>We can compute the <code>mask</code> matrix in NumPy with <code>(1 - np.tri(n_seq)) * -1e10</code>.</p>
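<p>Here is a quick check, with random scores standing in for <code>q @ k.T / np.sqrt(d_k)</code>, that adding the mask before the softmax zeroes out every future position while each row still sums to 1. It assumes a max-subtracted <code>softmax</code> over the last axis.</p>
<pre><code class="language-python">import numpy as np

def softmax(x):
    exp_x = np.exp(x - np.max(x, axis=-1, keepdims=True))
    return exp_x / np.sum(exp_x, axis=-1, keepdims=True)

n_seq = 5
mask = (1 - np.tri(n_seq)) * -1e10  # 0 on/below the diagonal, -1e10 above it

rng = np.random.default_rng(0)
scores = rng.normal(size=(n_seq, n_seq))  # stand-in for q @ k.T / sqrt(d_k)
weights = softmax(scores + mask)

print(np.round(weights, 3))  # upper triangle is exactly 0
print(weights.sum(axis=-1))  # every row still sums to 1
</code></pre>
<p>The first row can only attend to the first token, so its weights are <code>[1, 0, 0, 0, 0]</code> regardless of the scores.</p>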
<p>Putting it all together, we get:</p>
<pre><code class="language-python">def attention(q, k, v, mask): # [n_q, d_k], [n_k, d_k], [n_k, d_v], [n_q, n_k] -> [n_q, d_v]
    return softmax(q @ k.T / np.sqrt(q.shape[-1]) + mask) @ v

def causal_self_attention(x, c_attn, c_proj): # [n_seq, n_embd] -> [n_seq, n_embd]
    # qkv projections
    x = linear(x, **c_attn) # [n_seq, n_embd] -> [n_seq, 3*n_embd]
    # split into qkv
    q, k, v = np.split(x, 3, axis=-1) # [n_seq, 3*n_embd] -> 3 of [n_seq, n_embd]
    # causal mask to hide future inputs from being attended to
    causal_mask = (1 - np.tri(x.shape[0], dtype=x.dtype)) * -1e10 # [n_seq, n_seq]
    # perform causal self attention
    x = attention(q, k, v, causal_mask) # [n_seq, n_embd] -> [n_seq, n_embd]
    # out projection
    x = linear(x, **c_proj) # [n_seq, n_embd] -> [n_seq, n_embd]
    return x
</code></pre>
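<p>One way to convince yourself the mask works: change only the last token's input and check that the outputs at every earlier position stay the same. The sketch below re-declares <code>softmax</code> (a max-subtracted version), <code>linear</code>, <code>attention</code>, and <code>causal_self_attention</code> so it runs on its own. The random weights and toy sizes are made up for illustration.</p>
<pre><code class="language-python">import numpy as np

def softmax(x):
    exp_x = np.exp(x - np.max(x, axis=-1, keepdims=True))
    return exp_x / np.sum(exp_x, axis=-1, keepdims=True)

def linear(x, w, b):
    return x @ w + b

def attention(q, k, v, mask):
    return softmax(q @ k.T / np.sqrt(q.shape[-1]) + mask) @ v

def causal_self_attention(x, c_attn, c_proj):
    x = linear(x, **c_attn)
    q, k, v = np.split(x, 3, axis=-1)
    causal_mask = (1 - np.tri(x.shape[0], dtype=x.dtype)) * -1e10
    x = attention(q, k, v, causal_mask)
    return linear(x, **c_proj)

# hypothetical random weights
rng = np.random.default_rng(0)
n_seq, n_embd = 5, 8
c_attn = {"w": rng.normal(size=(n_embd, 3 * n_embd)), "b": np.zeros(3 * n_embd)}
c_proj = {"w": rng.normal(size=(n_embd, n_embd)), "b": np.zeros(n_embd)}

x = rng.normal(size=(n_seq, n_embd))
x_changed = x.copy()
x_changed[-1] += 10.0  # change only the last token's input

out = causal_self_attention(x, c_attn, c_proj)
out_changed = causal_self_attention(x_changed, c_attn, c_proj)

# earlier positions can't see the last token, so their outputs are unchanged
print(np.allclose(out[:-1], out_changed[:-1]))  # True
</code></pre>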
<h4 id="multi-head" tabindex="-1">Multi-Head</h4>
<p>We can further improve our implementation by performing <code>n_head</code> separate attention computations, splitting our queries, keys, and values into <strong>heads</strong>:</p>
<pre><code class="language-python">def mha(x, c_attn, c_proj, n_head): # [n_seq, n_embd] -> [n_seq, n_embd]
    # qkv projection
    x = linear(x, **c_attn) # [n_seq, n_embd] -> [n_seq, 3*n_embd]
    # split into qkv
    qkv = np.split(x, 3, axis=-1) # [n_seq, 3*n_embd] -> [3, n_seq, n_embd]
    # split into heads
    qkv_heads = list(map(lambda x: np.split(x, n_head, axis=-1), qkv)) # [3, n_seq, n_embd] -> [3, n_head, n_seq, n_embd/n_head]
    # causal mask to hide future inputs from being attended to
    causal_mask = (1 - np.tri(x.shape[0], dtype=x.dtype)) * -1e10 # [n_seq, n_seq]
    # perform attention over each head
    out_heads = [attention(q, k, v, causal_mask) for q, k, v in zip(*qkv_heads)] # [3, n_head, n_seq, n_embd/n_head] -> [n_head, n_seq, n_embd/n_head]
    # merge heads
    x = np.hstack(out_heads) # [n_head, n_seq, n_embd/n_head] -> [n_seq, n_embd]
    # out projection
    x = linear(x, **c_proj) # [n_seq, n_embd] -> [n_seq, n_embd]
    return x
</code></pre>
<p>There are three steps added here:</p>
<ol>
<li>Split <code>q, k, v</code> into <code>n_head</code> heads:</li>
</ol>
<pre><code class="language-python"># split into heads
qkv_heads = list(map(lambda x: np.split(x, n_head, axis=-1), qkv)) # [3, n_seq, n_embd] -> [3, n_head, n_seq, n_embd/n_head]
</code></pre>
<ol start="2">
<li>Compute attention for each head:</li>
</ol>
<pre><code class="language-python"># perform attention over each head
out_heads = [attention(q, k, v, causal_mask) for q, k, v in zip(*qkv_heads)] # [3, n_head, n_seq, n_embd/n_head] -> [n_head, n_seq, n_embd/n_head]
</code></pre>
<ol start="3">
<li>Merge the outputs of each head:</li>
</ol>
<pre><code class="language-python"># merge heads
x = np.hstack(out_heads) # [n_head, n_seq, n_embd/n_head] -> [n_seq, n_embd]
</code></pre>
<p>Notice, this reduces the dimension from <code>n_embd</code> to <code>n_embd/n_head</code> for each attention computation. This is a tradeoff: in exchange for reduced dimensionality, our model gets additional <em>subspaces</em> to work with when modeling relationships via attention. For example, maybe one attention head is responsible for connecting pronouns to the person the pronoun is referencing. Maybe another might be responsible for grouping sentences by periods. Another could simply be identifying which words are entities, and which are not. In practice, though, it's probably just another neural network black box.</p>
<p>The code we wrote performs the attention computations over each head sequentially in a loop (one at a time), which is not very efficient. In practice, you'd want to do these in parallel. For simplicity, we'll just leave this sequential.</p>
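<p>As a sketch of what doing the heads in parallel could look like in plain NumPy (not the post's implementation), we can reshape <code>q</code>, <code>k</code>, <code>v</code> to <code>[n_head, n_seq, n_embd/n_head]</code> and let batched matrix multiplication handle every head at once. It assumes <code>n_embd</code> is divisible by <code>n_head</code>; the function names <code>heads_loop</code> and <code>heads_parallel</code> are hypothetical, and the check at the end confirms both give the same result.</p>
<pre><code class="language-python">import numpy as np

def softmax(x):
    # subtract the row max for numerical stability
    exp_x = np.exp(x - np.max(x, axis=-1, keepdims=True))
    return exp_x / np.sum(exp_x, axis=-1, keepdims=True)

def attention(q, k, v, mask):  # [n_q, d_k], [n_k, d_k], [n_k, d_v], [n_q, n_k] -> [n_q, d_v]
    return softmax(q @ k.T / np.sqrt(q.shape[-1]) + mask) @ v

def heads_loop(q, k, v, n_head, mask):
    # one head at a time, like the mha function above
    qkv_heads = [np.split(a, n_head, axis=-1) for a in (q, k, v)]
    out_heads = [attention(qh, kh, vh, mask) for qh, kh, vh in zip(*qkv_heads)]
    return np.hstack(out_heads)  # [n_seq, n_embd]

def heads_parallel(q, k, v, n_head, mask):
    # all heads at once, using batched matmuls over a leading head axis
    n_seq, n_embd = q.shape

    def split_heads(a):  # [n_seq, n_embd] -> [n_head, n_seq, n_embd/n_head]
        return a.reshape(n_seq, n_head, n_embd // n_head).transpose(1, 0, 2)

    qh, kh, vh = split_heads(q), split_heads(k), split_heads(v)
    scores = qh @ kh.transpose(0, 2, 1) / np.sqrt(qh.shape[-1]) + mask  # [n_head, n_seq, n_seq]
    out = softmax(scores) @ vh  # [n_head, n_seq, n_embd/n_head]
    # merge heads: [n_head, n_seq, n_embd/n_head] -> [n_seq, n_embd]
    return out.transpose(1, 0, 2).reshape(n_seq, n_embd)

rng = np.random.default_rng(0)
n_seq, n_embd, n_head = 5, 12, 3
q, k, v = (rng.normal(size=(n_seq, n_embd)) for _ in range(3))
mask = (1 - np.tri(n_seq)) * -1e10
print(np.allclose(heads_loop(q, k, v, n_head, mask), heads_parallel(q, k, v, n_head, mask)))  # True
</code></pre>
<p>The reshape puts columns <code>h*d</code> through <code>h*d + d - 1</code> into head <code>h</code>, which is exactly what <code>np.split</code> does, so the two versions agree. The causal mask broadcasts across the head axis.</p>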
<p>With that, we're finally done with our GPT implementation! Now, all that's left to do is put it all together and run our code.</p>
<h2 id="putting-it-all-together" tabindex="-1">Putting it All Together</h2>
<hr />
<p>Putting everything together, we get <a href="https://github.com/jaymody/picoGPT/blob/main/gpt2.py">gpt2.py</a>, which in its entirety is a mere 120 lines of code (<a href="https://github.com/jaymody/picoGPT/blob/a750c145ba4d09d5764806a6c78c71ffaff88e64/gpt2_pico.py#L3-L58">60 lines if you remove comments and whitespace</a>).</p>
<p>We can test our implementation with:</p>
<pre><code class="language-bash">python gpt2.py \
"Alan Turing theorized that computers would one day become" \
--n_tokens_to_generate 8
</code></pre>
<p>which gives the output:</p>
<pre><code class="language-text">the most powerful machines on the planet.
</code></pre>
<p>It works!!!</p>
<p>We can test that our implementation gives identical results to <a href="https://github.com/openai/gpt-2">OpenAI's official GPT-2 repo</a> using the following <a href="https://gist.github.com/jaymody/9054ca64eeea7fad1b58a185696bb518">Dockerfile</a> (Note: this won't work on M1 MacBooks because of TensorFlow shenanigans. Also, be warned that it downloads all 4 GPT-2 model sizes, which is a lot of GBs of stuff to download):</p>
<pre><code class="language-bash">docker build -t "openai-gpt-2" "https://gist.githubusercontent.com/jaymody/9054ca64eeea7fad1b58a185696bb518/raw/Dockerfile"
docker run -dt --name "openai-gpt-2-app" "openai-gpt-2"
docker exec -it "openai-gpt-2-app" /bin/bash -c 'python3 src/interactive_conditional_samples.py --length 8 --model_type 124M --top_k 1'
# paste "Alan Turing theorized that computers would one day become" when prompted
</code></pre>
<p>which should give an identical result:</p>
<pre><code class="language-text">the most powerful machines on the planet.
</code></pre>
<h2 id="what-next%3F" tabindex="-1">What Next?</h2>
<hr />
<p>This implementation is cool and all, but it's missing a ton of bells and whistles:</p>
<h3 id="gpu%2Ftpu-support" tabindex="-1">GPU/TPU Support</h3>
<p>Replace NumPy with <a href="https://github.com/google/jax">JAX</a>:</p>
<pre><code class="language-python">import jax.numpy as np
</code></pre>
<p>That's it. You can now use the code with GPUs and even <a href="https://cloud.google.com/tpu/docs/system-architecture-tpu-vm">TPUs</a>! Just make sure you <a href="https://github.com/google/jax#installation">install JAX correctly</a>.</p>
<h3 id="backpropagation" tabindex="-1">Backpropagation</h3>
<p>Again, if we replace NumPy with <a href="https://github.com/google/jax">JAX</a>:</p>
<pre><code class="language-python">import jax.numpy as np
</code></pre>
<p>Then computing the gradients is as easy as:</p>
<pre><code class="language-python">import jax

def lm_loss(params, inputs, n_head) -> float:
    x, y = inputs[:-1], inputs[1:]
    logits = gpt2(x, **params, n_head=n_head) # [n_seq, n_vocab]
    # log-prob of the target token y[i] at each position i
    loss = np.mean(-jax.nn.log_softmax(logits)[np.arange(len(y)), y])
    return loss

grads = jax.grad(lm_loss)(params, inputs, n_head)
</code></pre>
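<p>In the language modeling loss, each position <code>i</code> should contribute the log-probability of its own target <code>y[i]</code>. With NumPy-style indexing that means pairing a row index with each target id, since indexing a <code>[seq_len, vocab_size]</code> array with <code>y</code> alone selects whole rows. Here is a small NumPy sketch with made-up logits and a hand-written <code>log_softmax</code> (an assumption, standing in for <code>jax.nn.log_softmax</code>):</p>
<pre><code class="language-python">import numpy as np

def log_softmax(x):
    # log(softmax(x)) along the last axis, using the max-subtraction trick
    x = x - np.max(x, axis=-1, keepdims=True)
    return x - np.log(np.sum(np.exp(x), axis=-1, keepdims=True))

# toy logits for a sequence of 3 positions over a vocabulary of 4 tokens
logits = np.array([[2.0, 0.5, -1.0, 0.0],
                   [0.1, 0.2, 3.0, -0.5],
                   [1.0, 1.0, 1.0, 1.0]])
y = np.array([1, 0, 2])  # target (next) token id at each position

log_probs = log_softmax(logits)  # [seq_len, vocab_size]
target_log_probs = log_probs[np.arange(len(y)), y]  # [seq_len], log p(y[i]) at position i
loss = np.mean(-target_log_probs)

print(log_probs[y].shape)  # (3, 4): indexing with y alone selects whole rows
print(target_log_probs.shape)  # (3,)
</code></pre>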
<h3 id="batching" tabindex="-1">Batching</h3>
<p>Once again, if we replace NumPy with <a href="https://github.com/google/jax">JAX</a><sup class="footnote-ref"><a href="https://jaykmody.com/blog/gpt-from-scratch/#fn6" id="fnref6">[6]</a></sup>:</p>
<pre><code class="language-python">import jax.numpy as np
</code></pre>
<p>Then, making our <code>gpt2</code> function batched is as easy as:</p>
<pre><code class="language-python">gpt2_batched = jax.vmap(gpt2, in_axes=[0, None, None, None, None, None])
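# pass the other arguments positionally: vmap always maps keyword arguments over axis 0
gpt2_batched(batched_inputs, params["wte"], params["wpe"], params["blocks"], params["ln_f"], n_head) # [batch, seq_len] -> [batch, seq_len, vocab]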
</code></pre>
<h3 id="inference-optimization" tabindex="-1">Inference Optimization</h3>
<p>Our implementation is quite inefficient. The quickest and most impactful optimization you can make (outside of GPU + batching support) would be to implement a <a href="https://kipp.ly/blog/transformer-inference-arithmetic/#kv-cache">kv cache</a>. Also, we implemented our attention head computations sequentially, when we should really be doing them in parallel<sup class="footnote-ref"><a href="https://jaykmody.com/blog/gpt-from-scratch/#fn7" id="fnref7">[7]</a></sup>.</p>
<p>There are many, many more inference optimizations. I recommend <a href="https://lilianweng.github.io/posts/2023-01-10-inference-optimization/">Lilian Weng's Large Transformer Model Inference Optimization</a> and <a href="https://kipp.ly/blog/transformer-inference-arithmetic/">Kipply's Transformer Inference Arithmetic</a> as a starting point.</p>
<h3 id="training-1" tabindex="-1">Training</h3>
<p>Training a GPT is pretty standard for a neural network (gradient descent w.r.t. a loss function). Of course, you also need to use the standard bag of tricks when training a GPT (e.g. use the Adam optimizer, find the optimal learning rate, regularization via dropout and/or weight decay, use a learning rate scheduler, use the correct weight initialization, batching, etc ...).</p>
<p>The real secret sauce to training a good GPT model is the ability to <strong>scale the data and the model</strong>, which is where the real challenge is.</p>
<p>For scaling data, you'll want a corpus of text that is big, high quality, and diverse.</p>
<ul>
<li>Big means billions of tokens (terabytes of data). For example, check out <a href="https://pile.eleuther.ai/">The Pile</a>, which is an open source pre-training dataset for large language models.</li>
<li>High quality means you want to filter out duplicate examples, unformatted text, incoherent text, garbage text, etc ...</li>
<li>Diverse means varying sequence lengths, about lots of different topics, from different sources, with differing perspectives, etc ... Of course, if there are any biases in the data, it will reflect in the model, so you need to be careful of that as well.</li>
</ul>
<p>Scaling the model to billions of parameters involves a cr*p ton of engineering (and money lol). Training frameworks can get <a href="https://github.com/NVIDIA/Megatron-LM">absurdly long and complex</a>. A good place to start would be <a href="https://lilianweng.github.io/posts/2021-09-25-train-large/">Lilian Weng's How to Train Really Large Models on Many GPUs</a>. On the topic, there's also <a href="https://arxiv.org/pdf/1909.08053.pdf">NVIDIA's Megatron Framework</a>, <a href="https://arxiv.org/pdf/2204.06514.pdf">Cohere's Training Framework</a>, <a href="https://arxiv.org/pdf/2204.02311.pdf">Google's PaLM</a>, the open source <a href="https://github.com/kingoflolz/mesh-transformer-jax">mesh-transformer-jax</a> (used to train EleutherAI's open source models), and <a href="https://arxiv.org/pdf/2203.15556.pdf">many</a> <a href="https://www.microsoft.com/en-us/research/blog/turing-nlg-a-17-billion-parameter-language-model-by-microsoft/">many</a> <a href="https://arxiv.org/pdf/2005.14165.pdf">more</a>.</p>
<h3 id="evaluation" tabindex="-1">Evaluation</h3>
<p>Oh boy, how does one even evaluate LLMs? Honestly, it's a really hard problem. <a href="https://arxiv.org/abs/2211.09110">HELM</a> is pretty comprehensive and a good place to start, but you should always be skeptical of <a href="https://en.wikipedia.org/wiki/Goodhart%27s_law">benchmarks and evaluation metrics</a>.</p>
<h3 id="architecture-improvements" tabindex="-1">Architecture Improvements</h3>
<p>I recommend taking a look at <a href="https://github.com/lucidrains/x-transformers">Phil Wang's X-Transformers</a>. It has the latest and greatest research on the transformer architecture. <a href="https://arxiv.org/pdf/2102.11972.pdf">This paper</a> is also a pretty good summary (see Table 1). Facebook's recent <a href="https://arxiv.org/pdf/2302.13971.pdf">LLaMA paper</a> is also probably a good reference for standard architecture improvements (as of February 2023).</p>
<h3 id="stopping-generation" tabindex="-1">Stopping Generation</h3>
<p>Our current implementation requires us to specify the exact number of tokens we'd like to generate ahead of time. This is not a very good approach as our generations end up being too long, too short, or cut off mid-sentence.</p>
<p>To resolve this, we can introduce a special <strong>end of sentence (EOS) token</strong>. During pre-training, we append the EOS token to the end of our input (i.e. <code>tokens = ["not", "all", "heroes", "wear", "capes", ".", "<|EOS|>"]</code>). During generation, we simply stop whenever we encounter the EOS token (or if we hit some maximum sequence length):</p>
<pre><code class="language-python">def generate(inputs, eos_id, max_seq_len):
    prompt_len = len(inputs)
    while inputs[-1] != eos_id and len(inputs) < max_seq_len:
        output = gpt(inputs)
        next_id = np.argmax(output[-1])
        inputs.append(int(next_id))
    return inputs[prompt_len:]
</code></pre>
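<p>To see the stopping logic run end to end, here is a self-contained sketch that passes the model in as an argument and uses a hypothetical <code>toy_model</code> (not a real language model) that always predicts "previous token id + 1". It also copies <code>inputs</code> so the caller's list isn't mutated, which is a design choice of this sketch.</p>
<pre><code class="language-python">import numpy as np

def generate(inputs, model, eos_id, max_seq_len):
    # greedy decoding that stops at the EOS token or at max_seq_len
    inputs = list(inputs)  # copy so the caller's list isn't mutated
    prompt_len = len(inputs)
    while inputs[-1] != eos_id and len(inputs) < max_seq_len:
        logits = model(inputs)  # [len(inputs), vocab_size]
        next_id = int(np.argmax(logits[-1]))
        inputs.append(next_id)
    return inputs[prompt_len:]

# hypothetical toy "model": always predicts (token id + 1) mod vocab_size
def toy_model(inputs, vocab_size=10):
    logits = np.zeros((len(inputs), vocab_size))
    for i, t in enumerate(inputs):
        logits[i, (t + 1) % vocab_size] = 1.0
    return logits

print(generate([3, 4], toy_model, eos_id=7, max_seq_len=20))  # [5, 6, 7]
print(generate([3, 4], toy_model, eos_id=7, max_seq_len=4))  # [5, 6]
</code></pre>
<p>The first call stops because it generated the EOS id <code>7</code> (which is kept in the output), while the second stops early because it hit <code>max_seq_len</code>.</p>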
<p>GPT-2 was not pre-trained with an EOS token, so we can't use this approach in our code, but most LLMs nowadays use an EOS token.</p>
<h3 id="fine-tuning" tabindex="-1">Fine-tuning</h3>
<p>We briefly touched on fine-tuning in the training section. Recall, fine-tuning is when we re-use the pre-trained weights to train the model on some downstream task. We call this process transfer-learning.</p>
<p>In theory, we could use zero-shot or few-shot prompting to get the model to complete our task, however, if you have access to a labelled dataset, fine-tuning a GPT is going to yield better results (results that can scale given additional data and higher quality data).</p>
<p>There are a few different topics related to fine-tuning, which I've broken down below:</p>
<h4 id="classification-fine-tuning" tabindex="-1">Classification Fine-tuning</h4>
<p>In classification fine-tuning, we give the model some text and we ask it to predict which class it belongs to. For example, consider the <a href="https://huggingface.co/datasets/imdb">IMDB dataset</a>, which contains movie reviews that rate the movie as either good or bad:</p>
<pre><code class="language-text">--- Example 1 ---
Text: I wouldn't rent this one even on dollar rental night.
Label: Bad
--- Example 2 ---
Text: I don't know why I like this movie so well, but I never get tired of watching it.
Label: Good
--- Example 3 ---
...
</code></pre>
<p>To fine-tune our model, we replace the language modeling head with a classification head, which we apply to the last token output:</p>
<pre><code class="language-python">def gpt2(inputs, wte, wpe, blocks, ln_f, cls_head, n_head):
    x = wte[inputs] + wpe[range(len(inputs))]
    for block in blocks:
        x = transformer_block(x, **block, n_head=n_head)
    x = layer_norm(x, **ln_f)
    # project to n_classes
    # [n_embd] @ [n_embd, n_classes] -> [n_classes]
    return x[-1] @ cls_head
</code></pre>
<p>We only use the last token output <code>x[-1]</code> because we only need to produce a single probability distribution for the entire input instead of <code>n_seq</code> distributions as in the case of language modeling. We take the last token in particular (instead of say the first token or a combination of all the tokens) because the last token is the only token that is allowed to attend to the entire sequence and thus has information about the input text as a whole.</p>
<p>As per usual, we optimize w.r.t. the cross entropy loss:</p>
<pre><code class="language-python">def single_example_loss_fn(inputs: list[int], label: int, params) -> float:
    logits = gpt(inputs, **params)
    probs = softmax(logits)
    loss = -np.log(probs[label]) # cross entropy loss
    return loss
</code></pre>
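<p>One caveat with computing <code>-np.log(probs[label])</code> directly: for large logits, <code>np.exp</code> overflows and the loss becomes <code>nan</code>. A common alternative, sketched below under the assumption that <code>logits</code> is a 1D array of class scores, uses the identity \(-\log \text{softmax}(x)_y = \log\sum_j e^{x_j} - x_y\) and subtracts the max before exponentiating.</p>
<pre><code class="language-python">import numpy as np

def cross_entropy_naive(logits, label):
    probs = np.exp(logits) / np.sum(np.exp(logits))
    return -np.log(probs[label])

def cross_entropy_stable(logits, label):
    # -log(softmax(logits)[label]) == logsumexp(logits) - logits[label]
    m = np.max(logits)  # subtracting the max keeps np.exp from overflowing
    return m + np.log(np.sum(np.exp(logits - m))) - logits[label]

logits = np.array([2.0, -1.0])  # e.g. n_classes = 2 (Good, Bad)
print(np.isclose(cross_entropy_naive(logits, 0), cross_entropy_stable(logits, 0)))  # True

# the naive version would compute np.exp(1000.0) = inf here
big = np.array([1000.0, -1000.0])
print(cross_entropy_stable(big, 1))  # 2000.0
</code></pre>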
<h4 id="generative-fine-tuning" tabindex="-1">Generative Fine-tuning</h4>
<p>Some tasks can't be neatly categorized into classes. For example, consider the task of summarization. We can fine-tune on these types of tasks by simply performing language modeling on the input concatenated with the label. For example, here's what a single summarization training sample might look like:</p>
<pre><code class="language-text">--- Article ---
This is an article I would like to summarize.
--- Summary ---
This is the summary.
</code></pre>
<p>We train the model as we do during pre-training (optimize w.r.t language modeling loss).</p>
<p>At predict time, we feed the model everything up to <code>--- Summary ---</code> and then perform auto-regressive language modeling to generate the summary.</p>
<p>The choice of the delimiters <code>--- Article ---</code> and <code>--- Summary ---</code> is arbitrary. How you choose to format the text is up to you, as long as it is consistent between training and inference.</p>
<p>Notice, we can also formulate classification tasks as generative tasks (for example with IMDB):</p>
<pre><code class="language-text">--- Text ---
I wouldn't rent this one even on dollar rental night.
--- Label ---
Bad
</code></pre>
<p>However, this will probably perform worse than doing classification fine-tuning directly (the loss includes language modeling on the entire sequence, not just the final prediction, so the loss specific to the prediction will get diluted).</p>
<h4 id="instruction-fine-tuning" tabindex="-1">Instruction Fine-tuning</h4>
<p>Most state-of-the-art large language models these days also undergo an additional <strong>instruction fine-tuning</strong> step after being pre-trained. In this step, the model is fine-tuned (generative) on thousands of instruction prompt + completion pairs that were <strong>human labeled</strong>. Instruction fine-tuning can also be referred to as <strong>supervised fine-tuning</strong>, since the data is human labelled (i.e. <strong>supervised</strong>).</p>
<p>So what's the benefit of instruction fine-tuning? While predicting the next word in a Wikipedia article makes the model good at continuing sentences, it doesn't make it particularly good at following instructions, having a conversation, or summarizing a document (all the things we would like a GPT to do). Fine-tuning on human labelled instruction + completion pairs is a way to teach the model how it can be more useful, and makes it easier to interact with. We call this <strong>AI alignment</strong>, as we are aligning the model to do and behave as we want it to. Alignment is an active area of research, and includes more than just following instructions (bias, safety, intent, etc ...).</p>
<p>What does this instruction data look like exactly? Google's <a href="https://arxiv.org/pdf/2109.01652.pdf">FLAN</a> models were trained on various academic NLP datasets (which are already human labelled):</p>
<figure><img src="https://i.imgur.com/9W2bwJF.png" alt="" /><figcaption>Figure 3 from FLAN paper</figcaption></figure>
<p>OpenAI's <a href="https://arxiv.org/pdf/2203.02155.pdf">InstructGPT</a> on the other hand was trained on prompts collected from their own API. They then paid workers to write completions for those prompts. Here's a breakdown of the data:</p>
<figure><img src="https://i.imgur.com/FaRRbCa.png" alt="" /><figcaption>Table 1 and 2 from InstructGPT paper</figcaption></figure>
<h4 id="parameter-efficient-fine-tuning" tabindex="-1">Parameter Efficient Fine-tuning</h4>
<p>When we talk about fine-tuning in the above sections, it is assumed that we are updating all of the model parameters. While this yields the best performance, it is costly both in terms of compute (need to back propagate over the entire model) and in terms of storage (for each fine-tuned model, you need to store a completely new copy of the parameters). For instruction fine-tuning this is fine, since we want maximum performance, but if you then wanted to fine-tune 100 different models for various downstream tasks, you'd have a problem.</p>
<p>The simplest approach to this problem is to <strong>only update the head</strong> and <strong>freeze</strong> (i.e. make untrainable) the rest of the model. This would speed up training and greatly reduce the number of new parameters, however it would not perform nearly as well as a full fine-tune (we are lacking the <em>deep</em> in deep learning). We could instead <strong>selectively freeze</strong> specific layers (e.g. freeze all layers except the last 4, or freeze every other layer, or freeze all parameters except multi-head attention parameters), which would help restore some of the depth. This will perform a lot better, but we become a lot less parameter efficient and lose some of our training speedup.</p>
<p>Instead, we can utilize <strong>parameter-efficient fine-tuning</strong> (PEFT) methods. PEFT is an active area of research, and there are <a href="https://aclanthology.org/2021.emnlp-main.243.pdf">lots</a> <a href="https://arxiv.org/pdf/2110.07602.pdf">of</a> <a href="https://arxiv.org/pdf/2101.00190.pdf">different</a> <a href="https://arxiv.org/pdf/2103.10385.pdf">methods</a> <a href="https://arxiv.org/pdf/2106.09685.pdf">to</a> <a href="https://arxiv.org/pdf/1902.00751.pdf">choose</a> <a href="https://arxiv.org/abs/2205.05638">from</a>.</p>
<p>As an example, take the <a href="https://arxiv.org/pdf/1902.00751.pdf">Adapters paper</a>. In this approach, we add an additional "adapter" layer after the FFN and MHA layers in the transformer block. The adapter layer is just a simple 2-layer fully connected neural network, where the input and output dimensions are <code>n_embd</code>, and the hidden dimension is smaller than <code>n_embd</code>:</p>
<figure><img src="https://miro.medium.com/max/633/0*Z2FMWTCmdkgevHr-.png" alt="" /><figcaption>Figure 2 from the Adapters paper</figcaption></figure>
<p>The size of the hidden dimension is a hyper-parameter that we can set, enabling us to tradeoff parameters for performance. For a BERT model, the paper showed that using this approach can reduce the number of trained parameters to 2% while only sustaining a small hit in performance (<1%) when compared to a full fine-tune.</p>
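<p>Here is a minimal sketch of such an adapter layer, to make the parameter tradeoff concrete. Several details are assumptions rather than taken from the post: ReLU as the nonlinearity, the hypothetical sizes <code>n_embd = 768</code> and <code>n_hidden = 64</code>, and a zero-initialized up-projection (the Adapters paper describes a skip connection and a near-identity initialization; exact zeros are this sketch's simplification).</p>
<pre><code class="language-python">import numpy as np

def linear(x, w, b):
    return x @ w + b

def relu(x):
    return np.maximum(x, 0)

def adapter(x, down, up):  # [n_seq, n_embd] -> [n_seq, n_embd]
    # bottleneck n_embd -> n_hidden -> n_embd, plus a skip connection
    return x + linear(relu(linear(x, **down)), **up)

n_embd, n_hidden = 768, 64  # hypothetical sizes; n_hidden is the hyper-parameter
rng = np.random.default_rng(0)
down = {"w": rng.normal(scale=1e-3, size=(n_embd, n_hidden)), "b": np.zeros(n_hidden)}
up = {"w": np.zeros((n_hidden, n_embd)), "b": np.zeros(n_embd)}  # zero init: adapter starts as identity

x = rng.normal(size=(4, n_embd))
print(np.allclose(adapter(x, down, up), x))  # True

# trainable parameters per adapter: 768*64 + 64 + 64*768 + 768
n_params = sum(p.size for layer in (down, up) for p in layer.values())
print(n_params)  # 99136
</code></pre>
<p>Because the adapter starts out as the identity, inserting it doesn't change the pre-trained model's behavior before fine-tuning begins, and only <code>down</code> and <code>up</code> need to be stored per task.</p>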
<hr class="footnotes-sep" />
<section class="footnotes">
<ol class="footnotes-list">
<li id="fn1" class="footnote-item"><p>For certain applications, the tokenizer doesn't require a <code>decode</code> method. For example, if you want to classify if a movie review is saying the movie was good or bad, you only need to be able to <code>encode</code> the text and do a forward pass of the model, there is no need for <code>decode</code>. For generating text however, <code>decode</code> is a requirement. <a href="https://jaykmody.com/blog/gpt-from-scratch/#fnref1" class="footnote-backref">↩︎</a></p>
</li>
<li id="fn2" class="footnote-item"><p>Although, with the <a href="https://arxiv.org/pdf/2210.11416.pdf">InstructGPT</a> and <a href="https://arxiv.org/pdf/2203.15556.pdf">Chinchilla</a> papers, we've realized that we don't actually need to train models that big. An optimally trained and instruction fine-tuned GPT at 1.3B parameters can outperform GPT-3 at 175B parameters. <a href="https://jaykmody.com/blog/gpt-from-scratch/#fnref2" class="footnote-backref">↩︎</a></p>
</li>
<li id="fn3" class="footnote-item"><p>The original transformer paper used a <a href="https://nlp.seas.harvard.edu/2018/04/03/attention.html#positional-encoding">calculated positional embedding</a>, which they found performed just as well as learned positional embeddings but has the distinct advantage that you can input any arbitrarily long sequence (you are not restricted by a maximum sequence length). However, in practice, your model is only going to be as good as the sequence lengths it was trained on. You can't just train a GPT on sequences that are 1024 tokens long and then expect it to perform well at 16k tokens long. Recently, however, there has been some success with relative positional embeddings, such as <a href="https://arxiv.org/pdf/2108.12409.pdf">ALiBi</a> and <a href="https://arxiv.org/pdf/2104.09864v4.pdf">RoPE</a>. <a href="https://jaykmody.com/blog/gpt-from-scratch/#fnref3" class="footnote-backref">↩︎</a></p>
</li>
<li id="fn4" class="footnote-item"><p>Different GPT models may choose a different hidden width that is not <code>4*n_embd</code>, however this is the common practice for GPT models. Also, we give the multi-head attention layer a lot of <em>attention</em> (pun intended) for driving the success of the transformer, but at the scale of GPT-3, <a href="https://twitter.com/stephenroller/status/1579993017234382849">80% of the model parameters are contained in the feed forward layer</a>. Just something to think about. <a href="https://jaykmody.com/blog/gpt-from-scratch/#fnref4" class="footnote-backref">↩︎</a></p>
</li>
<li id="fn5" class="footnote-item"><p>If you're not convinced, stare at the softmax equation and convince yourself this is true (maybe even pull out a pen and paper):<br />
\[
\text{softmax}(\vec{x})_i=\frac{e^{x_i}}{\sum_je^{x_j}}
\] <a href="https://jaykmody.com/blog/gpt-from-scratch/#fnref5" class="footnote-backref">↩︎</a></p>
</li>
<li id="fn6" class="footnote-item"><p>I love JAX ❤️. <a href="https://jaykmody.com/blog/gpt-from-scratch/#fnref6" class="footnote-backref">↩︎</a></p>
</li>
<li id="fn7" class="footnote-item"><p>Using JAX, this is as simple as <code>heads = jax.vmap(attention, in_axes=(0, 0, 0, None))(q, k, v, causal_mask)</code>. <a href="https://jaykmody.com/blog/gpt-from-scratch/#fnref7" class="footnote-backref">↩︎</a></p>
</li>
</ol>
</section>
Mon, 30 Jan 2023 00:00:00 +0000
Jay Mody
https://jaykmody.com/blog/gpt-from-scratch/
Numerically Stable Softmax and Cross Entropy
https://jaykmody.com/blog/stable-softmax/
<p>In this post, we'll take a look at softmax and cross entropy loss, two very common mathematical functions used in deep learning. We'll see that naive implementations are numerically unstable, and then we'll derive implementations that are numerically stable.</p>
<h2 id="symbols" tabindex="-1">Symbols</h2>
<hr />
<ul>
<li>\(x\): Input vector of dimensionality \(d\).</li>
<li>\(y\): Correct class, an integer in the range \(y \in [1\ldots K]\).</li>
<li>\(\hat{y}\): Raw outputs (i.e. logits) of our neural network, vector of dimensionality \(K\).</li>
<li>We use \(\log\) to denote the natural logarithm.</li>
</ul>
<h2 id="softmax" tabindex="-1">Softmax</h2>
<hr />
<p>The softmax function is defined as:<br />
\[
\text{softmax}(x)_i = \frac{e^{x_i}}{\sum_j e^{x_j}}
\]<br />
The softmax function converts a vector of real numbers (\(x\)) to a vector of probabilities (such that \(\sum_i \text{softmax}(x)_i = 1\) and \(0 \leq \text{softmax}(x)_i \leq 1\)). This is useful for converting the raw final output of a neural network (often referred to as <strong>logits</strong>) into probabilities.</p>
<p>In code:</p>
<pre><code class="language-python">import numpy as np

def softmax(x):
    # assumes x is a vector
    return np.exp(x) / np.sum(np.exp(x))

x = np.array([1.2, 2, -4, 0.0])  # might represent raw output logits of a neural network
softmax(x)
# outputs: [0.28310553, 0.63006295, 0.00156177, 0.08526975]
</code></pre>
<p>For very large inputs, we start seeing some numerical instability:</p>
<pre><code class="language-python">x = np.array([1.2, 2000, -4000, 0.0])
softmax(x)
# outputs: [0., nan, 0., 0.]
</code></pre>
<p>Why? Because floating point numbers aren't magic; they have limits:</p>
<pre><code class="language-python">np.finfo(np.float64).max
# 1.7976931348623157e+308, largest positive number
np.finfo(np.float64).tiny
# 2.2250738585072014e-308, smallest positive number at full precision
np.finfo(np.float64).smallest_subnormal
# 5e-324, smallest positive number
</code></pre>
<p>When we go beyond these limits, we start seeing funky behavior:</p>
<pre><code class="language-python">np.finfo(np.float64).max * 2
# inf, overflow error
np.inf - np.inf
# nan, not a number error
np.finfo(np.float64).smallest_subnormal / 2
# 0.0, underflow error
</code></pre>
<p>Looking back at our softmax example that resulted in <code>[0., nan, 0., 0.]</code>, we can see that the overflow of <code>np.exp(2000) = np.inf</code> is causing the <code>nan</code>, since we end up with <code>np.inf / np.inf = nan</code>.</p>
<p>If we want to avoid <code>nans</code>, we need to avoid <code>infs</code>.</p>
<p>To avoid <code>infs</code>, we need to avoid overflows.</p>
<p>To avoid overflows, we need to prevent our numbers from growing too large.</p>
<p>Underflows, on the other hand, don't seem quite as detrimental. In the worst case, we get the result <code>0</code> and lose all precision (e.g. <code>np.exp(-4000) = 0</code>). While this is not ideal, it is a lot better than running into <code>inf</code> and <code>nan</code>.</p>
<p>Given the relative stability of floating point underflows vs overflows, how can we fix softmax?</p>
<p>Let's revisit our softmax equation and apply some tricks:<br />
\[
\begin{align}
\text{softmax}(x)_i
&= \frac{e^{x_i}}{\sum_j e^{x_j}} \\
&= 1\cdot \frac{e^{x_i}}{\sum_j e^{x_j}} \\
&= \frac{C}{C}\frac{e^{x_i}}{\sum_j e^{x_j}} \\
&= \frac{Ce^{x_i}}{\sum_j Ce^{x_j}} \\
&= \frac{e^{x_i + \log C}}{\sum_j e^{x_j + \log C}} \\
\end{align}
\]<br />
Here, we're taking advantage of the rule \(a\cdot b^x = b^{x + \log_b a}\). As a result, we are given the ability to offset our inputs by any constant of our choosing. For example, if we set that constant to \(\log C = -\max(x)\):<br />
\[
\text{softmax}(x)_i = \frac{e^{x_i - \max(x)}}{\sum_j e^{x_j - \max(x)}}
\]</p>
<p>We get a numerically stable version of softmax:</p>
<ul>
<li>All exponentiated values will be between 0 and 1 (\(0 \leq e^{x_i - \max(x)} \leq 1\)) since the value in the exponent is never positive (\(x_i - \max(x) \leq 0\))
<ul>
<li>This prevents overflow errors (but we are still prone to underflows)</li>
</ul>
</li>
<li>At least one of the exponentiated values is 1 in the case when \(x_i = \max(x)\): \(e^{ \max(x)- \max(x)} = e^0 = 1\)
<ul>
<li>i.e. at least one value is guaranteed not to underflow</li>
<li>Thus, our denominator will always be \(\geq 1\), preventing division by zero errors</li>
<li>We have at least one non-zero numerator, so softmax can't result in a zero vector</li>
</ul>
</li>
</ul>
<p>In code:</p>
<pre><code class="language-python">def softmax(x):
    # assumes x is a vector
    x = x - np.max(x)
    return np.exp(x) / np.sum(np.exp(x))

x = np.array([1.2, 2, -4, 0])
softmax(x)
# outputs: [0.28310553, 0.63006295, 0.00156177, 0.08526975]

# works for large numbers!!!
x = np.array([1.2, 2, -4, 0]) * 1000
softmax(x)
# outputs: [0., 1., 0., 0.]
</code></pre>
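<p>As a quick sanity check on the derivation above, here is a minimal sketch that confirms numerically that shifting the input by a constant leaves softmax unchanged, and that the stable and naive versions agree whenever the naive one doesn't overflow. The names <code>naive_softmax</code> and <code>stable_softmax</code> are just the two versions from this post renamed so they can live side by side.</p>
<pre><code class="language-python">import numpy as np

def naive_softmax(x):
    # direct translation of the definition (can overflow for large x)
    return np.exp(x) / np.sum(np.exp(x))

def stable_softmax(x):
    # shift by max(x) so the largest exponent is e^0 = 1
    x = x - np.max(x)
    return np.exp(x) / np.sum(np.exp(x))

x = np.array([1.2, 2, -4, 0])

# softmax(x + c) == softmax(x) for any constant c (small enough not to overflow)
for c in [-5.0, 0.0, 3.7, 100.0]:
    assert np.allclose(naive_softmax(x + c), naive_softmax(x))

# the two versions agree when the naive one doesn't overflow
assert np.allclose(stable_softmax(x), naive_softmax(x))

# after shifting, one term of the denominator is exactly e^0 = 1, so the sum is >= 1
shifted = x - np.max(x)
print(np.sum(np.exp(shifted)))
</code></pre>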
<h2 id="cross-entropy-and-log-softmax" tabindex="-1">Cross Entropy and Log Softmax</h2>
<hr />
<p>The cross entropy between two probability distributions is defined as:<br />
\[
H(p, q) = -\sum_i p_i\log(q_i)
\]<br />
where \(p\) and \(q\) are our probability distributions represented as probability vectors (that is \(p_i\) and \(q_i\) are the probabilities of event \(i\) occurring for \(p\) and \(q\) respectively). This <a href="https://www.youtube.com/watch?v=ErfnhcEV1O8">video has a great explanation for cross entropy</a>.</p>
<p>Roughly speaking, cross entropy measures how dissimilar two probability distributions are (for a fixed \(p\), it is smallest when \(q = p\)). In the context of neural networks, it's common to use cross entropy as a loss function for classification problems where:</p>
<ul>
<li>\(q\) is our predicted probabilities vector (i.e. the softmax of our raw network outputs, also called <strong>logits</strong>, denoted as \(\hat{y}\)), that is \(q = \text{softmax}(\hat{y})\)</li>
<li>\(p\) is a one-hot encoded vector of our label, that is a probability vector that assigns 100% probability to the position \(y\) (our label for the correct class): \(p_i = \begin{cases} 1 & i = y \\ 0 & i \neq y \end{cases}\)</li>
</ul>
<p>In this setup, cross entropy simplifies to:<br />
\[
\begin{align}
H(p, q)
&= -\sum_i p_i\log(q_i) \\
&= -p_y\cdot\log(q_y) -\sum_{i \neq y} p_i\log(q_i) \\
&= -1\cdot\log(q_y) -\sum_{i \neq y} 0\cdot\log(q_i) \\
&= -\log(q_y) - 0 \sum_{i \neq y} \log(q_i) \\
&= -\log(q_y) \\
&= -\log(\text{softmax}(\hat{y})_y)
\end{align}
\]</p>
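<p>To make the simplification above concrete, the sketch below computes the general definition \(-\sum_i p_i\log(q_i)\) with an explicit one-hot \(p\) and compares it to \(-\log(q_y)\). The logits and number of classes are made up for illustration, and the label is 0-indexed (as in the code in this post) rather than 1-indexed.</p>
<pre><code class="language-python">import numpy as np

def softmax(x):
    # numerically stable softmax from above
    x = x - np.max(x)
    return np.exp(x) / np.sum(np.exp(x))

K = 5                                         # number of classes (made up)
y = 2                                         # correct class, 0-indexed
y_hat = np.array([0.5, -1.0, 2.0, 0.3, 1.1])  # made-up logits

q = softmax(y_hat)  # predicted probabilities
p = np.zeros(K)
p[y] = 1.0          # one-hot encoding of the label

general = -np.sum(p * np.log(q))  # H(p, q) = -sum_i p_i log(q_i)
simplified = -np.log(q[y])        # -log(q_y)
print(general, simplified)        # same value
</code></pre>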
<p>In code:</p>
<pre><code class="language-python">def cross_entropy(y_hat, y_true):
    # assume y_hat is a vector and y_true is an integer
    return -np.log(softmax(y_hat)[y_true])

cross_entropy(
    y_hat=np.random.normal(size=(10)),
    y_true=3,
)
# 2.580982279204241
</code></pre>
<p>For large numbers in <code>y_hat</code>, we start seeing <code>inf</code>:</p>
<pre><code class="language-python">cross_entropy(
    y_hat=np.array([-1000, 1000]),
    y_true=0,
)
# inf
</code></pre>
<p>The problem is that <code>softmax([-1000, 1000]) = [0, 1]</code>, and since <code>y_true = 0</code>, we get <code>-log(0) = inf</code>. So we need some way to avoid taking the log of zero. To prevent this, we can rearrange our equation for <code>log(softmax(x))</code>:<br />
\[
\begin{align}
\log(\text{softmax}(x)_i)
& = \log(\frac{e^{x_i - \max(x)}}{\sum_j e^{x_j - \max(x)}}) \\
&= \log(e^{x_i - \max(x)}) - \log(\sum_j e^{x_j - \max(x)}) \\
&= (x_i - \max(x))\log(e) - \log(\sum_j e^{x_j - \max(x)}) \\
&= (x_i - \max(x))\cdot 1 - \log(\sum_j e^{x_j - \max(x)}) \\
&= x_i - \max(x) - \log(\sum_j e^{x_j - \max(x)}) \\
\end{align}
\]<br />
This new equation guarantees that the sum inside the log will always be \(\geq 1\), so we no longer need to worry about <code>log(0)</code> errors.</p>
<p>In code:</p>
<pre><code class="language-python">def log_softmax(x):
    # assumes x is a vector
    x_max = np.max(x)
    return x - x_max - np.log(np.sum(np.exp(x - x_max)))

def cross_entropy(y_hat, y_true):
    return -log_softmax(y_hat)[y_true]

cross_entropy(
    y_hat=np.random.normal(size=(10)),
    y_true=3,
)
# 2.580982279204241

# works for large inputs!!!!
cross_entropy(
    y_hat=np.array([-1000, 1000]),
    y_true=0,
)
# 2000.0
</code></pre>
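<p>The functions above assume a single vector of logits. In practice, logits usually come in batches of shape <code>(batch, K)</code>. Here is a minimal sketch of the same stable computation applied row-wise, assuming we want the mean loss over the batch (a common convention, though some code uses the sum instead). The max and sum are taken along the last axis with <code>keepdims=True</code> so they broadcast back against each row.</p>
<pre><code class="language-python">import numpy as np

def log_softmax(x):
    # x has shape (batch, K); each row is handled independently
    x_max = np.max(x, axis=-1, keepdims=True)
    shifted = x - x_max
    return shifted - np.log(np.sum(np.exp(shifted), axis=-1, keepdims=True))

def cross_entropy(y_hat, y_true):
    # y_hat: (batch, K) logits, y_true: (batch,) integer class indices
    log_probs = log_softmax(y_hat)
    # pick out log(q_y) for each row, then average over the batch
    picked = log_probs[np.arange(len(y_true)), y_true]
    return -np.mean(picked)

y_hat = np.array([[-1000.0, 1000.0],
                  [2.0, 1.0]])
y_true = np.array([0, 0])
print(cross_entropy(y_hat, y_true))  # finite, no inf or nan
</code></pre>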
Thu, 15 Dec 2022 00:00:00 +0000Jay Modyhttps://jaykmody.com/blog/stable-softmax/An Intuition for Attention
https://jaykmody.com/blog/attention-intuition/
<p>ChatGPT and other large language models use a special type of neural network called the transformer. The transformer's defining feature is the <em>attention</em> mechanism. Attention is defined by the equation:</p>
<p>\[\text{attention}(Q, K, V) = \text{softmax}(\frac{QK^T}{\sqrt{d_k}})V\]</p>
<p>Attention can come in different forms, but this version of attention (known as scaled dot product attention) was first proposed in the original <a href="https://arxiv.org/pdf/1706.03762.pdf">transformer paper</a>. In this post, we'll build an intuition for the above equation by deriving it from the ground up.</p>
<p>To start, let's take a look at the problem attention aims to solve, the key-value lookup.</p>
<h2 id="key-value-lookups" tabindex="-1">Key-Value Lookups</h2>
<hr />
<p>A key-value (kv) lookup involves three components:</p>
<ol>
<li>A list of \(n_k\) <strong>keys</strong></li>
<li>A list of \(n_k\) <strong>values</strong> (that map 1-to-1 with the keys, forming key-value pairs)</li>
<li>A <strong>query</strong>, for which we want to <em>match</em> with the keys and get some value based on the match</li>
</ol>
<p>You're probably familiar with this concept as a dictionary or hash map:</p>
<pre><code class="language-python">>>> d = {
...     "apple": 10,
...     "banana": 5,
...     "chair": 2,
... }
>>> d.keys()
dict_keys(['apple', 'banana', 'chair'])
>>> d.values()
dict_values([10, 5, 2])
>>> query = "apple"
>>> d[query]
10
</code></pre>
<p>Dictionaries let us perform lookups based on an <em>exact</em> string match.</p>
<p>What if instead we wanted to do a lookup based on the <em>meaning</em> of a word?</p>
<h2 id="key-value-lookups-based-on-meaning" tabindex="-1">Key-Value Lookups based on Meaning</h2>
<hr />
<p>Say we wanted to look up the word "fruit" in our previous example, how do we choose which key is the best match?</p>
<p>It's obviously not "chair", but both "apple" and "banana" seem like good matches. It's hard to choose one or the other; fruit feels more like a combination of apple and banana than a strict match for either.</p>
<p>So, let's not choose. Instead, we'll take a combination of apple and banana. For example, say we assign a 60% meaning-based match to apple, a 40% match to banana, and a 0% match to chair. We compute our final output value as the <strong>weighted sum</strong> of the values with the percentages:</p>
<pre><code class="language-python">>>> query = "fruit"
>>> d = {"apple": 10, "banana": 5, "chair": 2}
>>> 0.6 * d["apple"] + 0.4 * d["banana"] + 0.0 * d["chair"]
8.0
</code></pre>
<p>In a sense, we are determining how much <strong>attention</strong> our query should be paying to each key-value pair based on <em>meaning</em>. The amount of "attention" is represented as a decimal percentage, called an <strong>attention score</strong>. Mathematically, we can define our output as a simple weighted sum:<br />
\[
\sum_{i} \alpha_iv_i
\]where \(\alpha_i\) is our attention score for the \(i\)th kv pair and \(v_i\) is the \(i\)th value. Remember, the attention scores are decimal percentages, that is, they must be between 0 and 1 inclusive (\(0 \leq \alpha_i \leq 1\)) and their sum must be 1 (\(\sum_i \alpha_i = 1\)).</p>
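<p>The weighted sum above is just a dot product between the vector of attention scores and the vector of values. A small sketch using the hand-picked scores from the "fruit" example:</p>
<pre><code class="language-python">import numpy as np

alpha = np.array([0.6, 0.4, 0.0])  # attention scores for apple, banana, chair
v = np.array([10.0, 5.0, 2.0])     # the corresponding values

# sum_i alpha_i * v_i written as a dot product
output = np.dot(alpha, v)
print(output)  # 8.0
</code></pre>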
<p>Okay, but where did we get these attention scores from? In our example, I just kind of chose them based on what I <em>felt</em>. While I think I did a pretty good job, this approach doesn't seem sustainable (unless you can find a way to make a copy of me inside your computer).</p>
<p>Instead, let's take a look at how <strong>word vectors</strong> can help solve our problem of determining attention scores.</p>
<h2 id="word-vectors-and-similarity" tabindex="-1">Word Vectors and Similarity</h2>
<hr />
<p>Imagine we represent a word with a vector of numbers. Ideally, the values in the vector should in some way capture the <em>meaning</em> of the word it represents. For example, imagine we have the following word vectors (visualized in 2D space):</p>
<figure><img src="https://i.imgur.com/VDnSf7P.png" alt="" /></figure>
<p>You can see that words that are <em>similar</em> are clustered together. Fruits are clustered at the top right, vegetables are clustered at the top left, and furniture is clustered at the bottom. In fact, you can even see that the vegetable and fruit clusters are closer to each other than they are to the furniture cluster, since they are more closely related things.</p>
<p>You can even imagine doing arithmetic on word vectors. For example, given the words "king", "queen", "man", and "woman" and their respective vector representations \(\boldsymbol{v}_{\text{king}}, \boldsymbol{v}_{\text{queen}}, \boldsymbol{v}_{\text{man}}, \boldsymbol{v}_{\text{woman}}\), we can imagine that:<br />
\[\boldsymbol{v}_{\text{queen}} - \boldsymbol{v}_{\text{woman}} + \boldsymbol{v}_{\text{man}} \sim \boldsymbol{v}_{\text{king}}\]That is, the vector for "queen" minus "woman" plus "man" should result in a vector that is <em>similar</em> to the vector for "king".</p>
<p>But what does it exactly mean for two vectors to be <em>similar</em>? In the fruits/vegetables example, we were using distance as a measure of similarity (in particular, <a href="https://en.wikipedia.org/wiki/Euclidean_distance">euclidean distance</a>).</p>
<p>There are also <a href="https://towardsdatascience.com/9-distance-measures-in-data-science-918109d069fa">other ways to measure similarity between two vectors</a>, each with its own advantages and disadvantages. Possibly the simplest measure of similarity between two vectors is their dot product:<br />
\[\boldsymbol{v} \cdot \boldsymbol{w} = \sum_{i}v_i w_i\]<a href="https://www.youtube.com/watch?v=LyGKycYT2v0">3blue1brown has a great video on the intuition behind dot product</a>, but for our purposes all we need to know is:</p>
<ul>
<li>If two vectors are pointing in the same direction, the dot product will be > 0 (i.e. similar)</li>
<li>If they are pointing in opposing directions, the dot product will be < 0 (i.e. dissimilar)</li>
<li>If they are exactly perpendicular, the dot product will be 0 (i.e. neutral)</li>
</ul>
<p>Using this information, we can define a simple heuristic to determine the similarity between two word vectors: The greater the dot product, the more similar two words are in <em>meaning</em>.<sup class="footnote-ref"><a href="https://jaykmody.com/blog/attention-intuition/#fn1" id="fnref1">[1]</a></sup></p>
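<p>Here is a quick sketch of those three cases with some made-up 2D vectors:</p>
<pre><code class="language-python">import numpy as np

v = np.array([1.0, 2.0])
same = np.array([2.0, 1.0])        # roughly the same direction as v
opposite = np.array([-1.0, -2.0])  # the exact opposite direction
perp = np.array([-2.0, 1.0])       # exactly perpendicular to v

print(np.dot(v, same))      # 4.0  -> similar
print(np.dot(v, opposite))  # -5.0 -> dissimilar
print(np.dot(v, perp))      # 0.0  -> neutral
</code></pre>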
<p>Okay cool, but where do these word vectors actually come from? In the context of neural networks, they usually come from some kind of learned embedding or latent representation. That is, initially the word vectors are just random numbers, but as the neural network is trained, their values are adjusted to become better and better representations for words. How does a neural network learn these better representations? That is beyond the scope of this blog post; you'll have to take an intro to deep learning course for that. For now, we just need to accept that word vectors exist, and that they somehow are able to capture the meaning of words.</p>
<h2 id="attention-scores-using-the-dot-product" tabindex="-1">Attention Scores using the Dot Product</h2>
<hr />
<p>Let's return to our example of fruits, but this time around using word vectors to represent our words. That is \(\boldsymbol{q} = \boldsymbol{v}_{\text{fruit}}\) and \(\boldsymbol{k} = [\boldsymbol{v}_{\text{apple}} \ \boldsymbol{v}_{\text{banana}} \ \boldsymbol{v}_{\text{chair}}]\), such that \(\boldsymbol{v} \in \mathbb{R}^{d_k}\) (that is each vector has the same dimensionality of \(d_k\), which is a value we choose when training a neural network).</p>
<p>Using our new dot product similarity measure, we can compute the similarity between the query and the \(i\)th key as:<br />
\[
x_i = \boldsymbol{q} \cdot \boldsymbol{k}_i
\]<br />
Generalizing this further, we can compute the dot product for all \(n_k\) keys with:<br />
\[
\boldsymbol{x} = \boldsymbol{q}{K}^T
\]where \(\boldsymbol{x}\) is our vector of dot products \(\boldsymbol{x} = [x_1, x_2, \ldots, x_{n_k - 1}, x_{n_k}]\) and \(K\) is a row-wise matrix of our key vectors (i.e. our key vectors stacked on-top of each-other to form a \(n_k\) by \(d_k\) matrix such that \(k_i\) is the \(i\)th row of \(K\)). If you're having trouble understanding this, see the following footnote <sup class="footnote-ref"><a href="https://jaykmody.com/blog/attention-intuition/#fn2" id="fnref2">[2]</a></sup>.</p>
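<p>As a sketch, here is the same computation in NumPy using the example query and keys from the footnote, checking that computing each dot product separately and computing them all at once with <code>q @ K.T</code> give the same result:</p>
<pre><code class="language-python">import numpy as np

q = np.array([2.0, 1.0, 3.0])       # query vector, d_k = 3
K = np.array([[-1.0, 2.0, -1.0],    # k_1
              [1.5, 0.0, -1.0],     # k_2
              [4.0, -2.0, -1.0]])   # k_3; K has shape (n_k, d_k) = (3, 3)

# one dot product at a time
x_loop = np.array([np.dot(q, k_i) for k_i in K])

# all at once
x = q @ K.T
print(x)  # [-3.  0.  3.]
</code></pre>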
<p>Recall that our attention scores need to be decimal percentages (between 0 and 1 and sum to 1). Our dot product values however can be any real number (i.e. between \(-\infty\) and \(\infty\)). To transform our dot product values to decimal percentages, we'll use the <a href="https://en.wikipedia.org/wiki/Softmax_function">softmax function</a>:<br />
\[
\text{softmax}(\boldsymbol{x})_i = \frac{e^{x_i}}{\sum_j e^{x_j}}
\]</p>
<pre><code class="language-python">>>> import numpy as np
>>> def softmax(x):
...     # assumes x is a vector
...     return np.exp(x) / np.sum(np.exp(x))
...
>>> softmax(np.array([4.0, -1.0, 2.1]))
[0.8648, 0.0058, 0.1294]
</code></pre>
<p>Notice:</p>
<ul>
<li>✅ Each number is between 0 and 1</li>
<li>✅ The numbers sum to 1</li>
<li>✅ The larger valued inputs get more "weight"</li>
<li>✅ The sorted order is preserved (i.e. the 4.0 is still the largest after softmax, and -1.0 is still the lowest), this is because softmax is a <a href="https://en.wikipedia.org/wiki/Monotonic_function">monotonic</a> function</li>
</ul>
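<p>These properties are easy to check programmatically. A minimal sketch, reusing the (non-stabilized) <code>softmax</code> and the example input from above:</p>
<pre><code class="language-python">import numpy as np

def softmax(x):
    # assumes x is a vector
    return np.exp(x) / np.sum(np.exp(x))

x = np.array([4.0, -1.0, 2.1])
p = softmax(x)

assert np.all((p >= 0) & (p <= 1))                   # each number is between 0 and 1
assert np.isclose(np.sum(p), 1.0)                    # the numbers sum to 1
assert np.array_equal(np.argsort(x), np.argsort(p))  # sorted order is preserved
</code></pre>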
<p>This satisfies all the desired properties of attention scores. Thus, we can compute the attention score for the \(i\)th key-value pair with:<br />
\[
\alpha_i = \text{softmax}(\boldsymbol{x})_i = \text{softmax}(\boldsymbol{q}K^T)_i
\]Plugging this into our weighted sum we get:<br />
\[
\begin{align}
\sum_{i}\alpha_iv_i
= & \sum_i \text{softmax}(\boldsymbol{x})_iv_i\\
= & \sum_i \text{softmax}(\boldsymbol{q}K^T)_iv_i\\
= &\ \text{softmax}(\boldsymbol{q}K^T)\boldsymbol{v}
\end{align}
\]<br />
Note: In the last step, we pack our values into a vector \(\boldsymbol{v} = [v_1, v_2, ..., v_{n_k -1}, v_{n_k}]\), which allows us to get rid of the summation notation in favor of a dot product.</p>
<p>And that's it, we have a full working definition for attention:<br />
\[
\text{attention}(\boldsymbol{q}, K, \boldsymbol{v}) = \text{softmax}(\boldsymbol{q}K^T)\boldsymbol{v}
\]In code:</p>
<pre><code class="language-python">import numpy as np

def get_word_vector(word, d_k=8):
    """Hypothetical mapping that returns a word vector of size
    d_k for the given word. For demonstrative purposes, we initialize
    this vector randomly, but in practice this would come from a learned
    embedding or some kind of latent representation."""
    return np.random.normal(size=(d_k,))

def softmax(x):
    # assumes x is a vector
    return np.exp(x) / np.sum(np.exp(x))

def attention(q, K, v):
    # assumes q is a vector of shape (d_k)
    # assumes K is a matrix of shape (n_k, d_k)
    # assumes v is a vector of shape (n_k)
    return softmax(q @ K.T) @ v

def kv_lookup(query, keys, values):
    return attention(
        q=get_word_vector(query),
        K=np.array([get_word_vector(key) for key in keys]),
        v=values,
    )

# returns some float number
print(kv_lookup("fruit", ["apple", "banana", "chair"], [10, 5, 2]))
</code></pre>
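<p>Because the attention scores are non-negative and sum to 1, the output of <code>kv_lookup</code> is always a weighted average of the values, so with values <code>[10, 5, 2]</code> it can never land outside the range \([2, 10]\). The sketch below checks this with a seeded random generator; the random query and key vectors are stand-ins for word vectors, just like <code>get_word_vector</code> above.</p>
<pre><code class="language-python">import numpy as np

def softmax(x):
    # assumes x is a vector
    return np.exp(x) / np.sum(np.exp(x))

def attention(q, K, v):
    return softmax(q @ K.T) @ v

rng = np.random.default_rng(0)  # seeded so the run is repeatable
d_k = 8
q = rng.normal(size=(d_k,))     # stand-in for the "fruit" vector
K = rng.normal(size=(3, d_k))   # stand-ins for "apple", "banana", "chair"
v = np.array([10.0, 5.0, 2.0])

weights = softmax(q @ K.T)      # attention scores, one per key
out = attention(q, K, v)
print(weights, out)             # out lies between min(v) and max(v)
</code></pre>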
<h2 id="scaled-dot-product-attention" tabindex="-1">Scaled Dot Product Attention</h2>
<hr />
<p>In principle, the attention equation we derived in the last section is complete. However, we'll need to make a couple of changes to match the version in <a href="https://arxiv.org/pdf/1706.03762.pdf">Attention is All You Need</a>.</p>
<h4 id="values-as-vectors" tabindex="-1">Values as Vectors</h4>
<p>Currently, our values in the key-value pairs are just numbers. However, we could also instead replace them with vectors of some size \(d_v\). For example, with \(d_v = 4\), you might have:</p>
<pre><code class="language-python">d = {
    "apple": [0.9, 0.2, -0.5, 1.0],
    "banana": [1.2, 2.0, 0.1, 0.2],
    "chair": [-1.2, -2.0, 1.0, -0.2],
}
</code></pre>
<p>When we compute our output via a weighted sum, we'd be doing a weighted sum over vectors instead of numbers (i.e. scalar-vector multiplication instead of scalar-scalar multiplication). This is desirable because vectors let us hold/convey more information than just a single number.</p>
<p>To adjust for this change in our equation, instead of multiplying our attention scores by a vector \(\boldsymbol{v}\), we multiply them by the row-wise matrix of our value vectors \(V\) (similar to how we stacked our keys to form \(K\)):<br />
\[
\text{attention}(\boldsymbol{q}, K, V) = \text{softmax}(\boldsymbol{q}K^T)V
\]Of course, our output is no longer a scalar, instead it would be a vector of dimensionality \(d_v\).</p>
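<p>A minimal sketch of this change, reusing the value vectors from the dictionary above stacked row-wise into \(V\). The query and key vectors are random stand-ins (hypothetical, since we don't have real word vectors), so only the shapes matter here: the output is a vector of size \(d_v = 4\).</p>
<pre><code class="language-python">import numpy as np

def softmax(x):
    # assumes x is a vector
    return np.exp(x) / np.sum(np.exp(x))

# value vectors from the example above, stacked row-wise into V (n_k, d_v)
V = np.array([
    [0.9, 0.2, -0.5, 1.0],    # apple
    [1.2, 2.0, 0.1, 0.2],     # banana
    [-1.2, -2.0, 1.0, -0.2],  # chair
])

rng = np.random.default_rng(0)
d_k = 8
q = rng.normal(size=(d_k,))    # hypothetical query vector
K = rng.normal(size=(3, d_k))  # hypothetical key vectors

weights = softmax(q @ K.T)     # shape (n_k,)
out = weights @ V              # weighted sum of value vectors, shape (d_v,)
print(out.shape)               # (4,)
</code></pre>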
<h4 id="scaling" tabindex="-1">Scaling</h4>
<p>The dot product between our query and keys can get really large in magnitude if \(d_k\) is large. This makes the output of softmax more <em>extreme</em>. For example, <code>softmax([3, 2, 1]) = [0.665, 0.244, 0.090]</code>, but with larger values <code>softmax([30, 20, 10]) = [9.99954600e-01, 4.53978686e-05, 2.06106005e-09]</code>. When training a neural network, this would mean the gradients would become really small which is undesirable. As a solution, we scale our pre-softmax scores by \(\frac{1}{\sqrt{d_k}}\):</p>
<p>\[
\text{attention}(\boldsymbol{q}, K, V) = \text{softmax}(\frac{\boldsymbol{q}K^T}{\sqrt{d_k}})V
\]</p>
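<p>Why \(\sqrt{d_k}\) specifically? The original paper's argument is that if the components of \(\boldsymbol{q}\) and \(\boldsymbol{k}\) are independent random variables with mean 0 and variance 1, their dot product has mean 0 and variance \(d_k\), so dividing by \(\sqrt{d_k}\) brings the variance back to 1. A quick empirical sketch under that assumption (standard normal vectors):</p>
<pre><code class="language-python">import numpy as np

rng = np.random.default_rng(0)
n_samples = 2000
results = {}

for d_k in [4, 64, 1024]:
    q = rng.normal(size=(n_samples, d_k))
    k = rng.normal(size=(n_samples, d_k))
    dots = np.sum(q * k, axis=-1)             # n_samples independent dot products
    raw_std = dots.std()                      # grows roughly like sqrt(d_k)
    scaled_std = (dots / np.sqrt(d_k)).std()  # stays roughly 1
    results[d_k] = (raw_std, scaled_std)
    print(d_k, raw_std, scaled_std)
</code></pre>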
<h4 id="multiple-queries" tabindex="-1">Multiple Queries</h4>
<p>In practice, we often want to perform multiple lookups for \(n_q\) different queries rather than just a single query. Of course, we could always do this one at a time, plugging each query individually into the above equation. However, if we stack our query vectors row-wise as a matrix \(Q\) (in the same way we did for \(K\) and \(V\)), we can compute our output as an \(n_q\) by \(d_v\) matrix where row \(i\) is the output vector for the attention on the \(i\)th query:<br />
\[
\text{attention}(Q, K, V) = \text{softmax}(\frac{QK^T}{\sqrt{d_k}})V
\]that is, \(\text{attention}(Q, K, V)_i = \text{attention}(q_i, K, V)\).</p>
<p>This makes computation faster than if we ran attention for each query sequentially (say, in a for loop) since we can parallelize calculations (particularly when using a GPU).</p>
<p>Note, our input to softmax becomes a matrix instead of a vector. When we write softmax here, we mean that we are taking the softmax along each row of the matrix independently, as if we were doing things sequentially.</p>
<h4 id="result" tabindex="-1">Result</h4>
<p>With that, we have our final equation for scaled dot product attention as it's written in the original paper:<br />
\[
\text{attention}(Q, K, V) = \text{softmax}(\frac{QK^T}{\sqrt{d_k}})V
\]In code:</p>
<pre><code class="language-python">import numpy as np

def softmax(x):
    # assumes x is a matrix and we want to take the softmax along each row
    # (which is achieved using axis=-1 and keepdims=True)
    return np.exp(x) / np.sum(np.exp(x), axis=-1, keepdims=True)

def attention(Q, K, V):
    # assumes Q is a matrix of shape (n_q, d_k)
    # assumes K is a matrix of shape (n_k, d_k)
    # assumes V is a matrix of shape (n_k, d_v)
    # output is a matrix of shape (n_q, d_v)
    d_k = K.shape[-1]
    return softmax(Q @ K.T / np.sqrt(d_k)) @ V
</code></pre>
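<p>As a sketch of the claim that \(\text{attention}(Q, K, V)_i = \text{attention}(q_i, K, V)\), the snippet below runs the batched function once and compares it with a loop over the queries one at a time. The sizes and random inputs are arbitrary. Because <code>softmax</code> uses <code>axis=-1</code>, the same function also handles a single query vector.</p>
<pre><code class="language-python">import numpy as np

def softmax(x):
    return np.exp(x) / np.sum(np.exp(x), axis=-1, keepdims=True)

def attention(Q, K, V):
    d_k = K.shape[-1]
    return softmax(Q @ K.T / np.sqrt(d_k)) @ V

rng = np.random.default_rng(0)
n_q, n_k, d_k, d_v = 2, 5, 8, 3  # arbitrary sizes for this example
Q = rng.normal(size=(n_q, d_k))
K = rng.normal(size=(n_k, d_k))
V = rng.normal(size=(n_k, d_v))

out = attention(Q, K, V)  # all queries at once, shape (n_q, d_v)

# one query at a time: Q[i] has shape (d_k,), result has shape (d_v,)
out_loop = np.stack([attention(Q[i], K, V) for i in range(n_q)])
print(out.shape)  # (2, 3)
</code></pre>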
<hr class="footnotes-sep" />
<section class="footnotes">
<ol class="footnotes-list">
<li id="fn1" class="footnote-item"><p>You'll note that the magnitudes of the vectors influence the output of the dot product. For example, given 3 vectors, \(a=[1, 1, 1]\), \(b=[1000, 0, 0]\), and \(c=[2, 2, 2]\), our dot product heuristic would tell us that because \(a \cdot b = 1000 > a \cdot c = 6\), \(a\) is more similar to \(b\) than \(a\) is to \(c\). This doesn't seem right, since \(c\) and \(a\) are pointing in the exact same direction, while \(b\) and \(a\) are not. <a href="https://en.wikipedia.org/wiki/Cosine_similarity">Cosine similarity</a> accounts for this by normalizing the vectors to unit vectors before taking the dot product, essentially ignoring the magnitudes and only caring about the direction. So why don't we take the cosine similarity? In a deep learning setting, the magnitude of a vector might actually contain information we care about (and we shouldn't get rid of it). Also, if we regularize our networks properly, outlier examples like the above should not occur. <a href="https://jaykmody.com/blog/attention-intuition/#fnref1" class="footnote-backref">↩︎</a></p>
</li>
<li id="fn2" class="footnote-item"><p>Basically, instead of computing each dot product separately:<br />
\[
\begin{align}
x_1 = & \ \boldsymbol{q} \cdot \boldsymbol{k}_1 = [2, 1, 3] \cdot [-1, 2, -1] = -3\\
x_2 = & \ \boldsymbol{q} \cdot \boldsymbol{k}_2 = [2, 1, 3] \cdot [1.5, 0, -1] = 0\\
x_3 = & \ \boldsymbol{q} \cdot \boldsymbol{k}_3 = [2, 1, 3] \cdot [4, -2, -1] = 3
\end{align}
\]<br />
You compute it all at once:<br />
\[
\begin{align}
\boldsymbol{x} & = \boldsymbol{q}{K}^T \\
& = \begin{bmatrix}2 & 1 & 3\end{bmatrix}\begin{bmatrix}-1 & 2 & -1\\1.5 & 0 & -1\\4 & -2 & -1\end{bmatrix}^T\\
& = \begin{bmatrix}2 & 1 & 3\end{bmatrix}\begin{bmatrix}-1 & 1.5 & 4\\2 & 0 & -2\\-1 & -1 & -1\end{bmatrix}\\
& = [-3, 0, 3]\\
& = [x_1, x_2, x_3]
\end{align}
\] <a href="https://jaykmody.com/blog/attention-intuition/#fnref2" class="footnote-backref">↩︎</a></p>
</li>
</ol>
</section>
Sat, 22 Oct 2022 00:00:00 +0000Jay Modyhttps://jaykmody.com/blog/attention-intuition/Computing Distance Matrices with NumPy
https://jaykmody.com/blog/distance-matrices-with-numpy/
<h2 id="background" tabindex="-1">Background</h2>
<p>A <a href="https://en.wikipedia.org/wiki/Distance_matrix#:~:text=In%20mathematics%2C%20computer%20science%20and,may%20not%20be%20a%20metric.">distance matrix</a> is a square matrix that captures the pairwise distances between a set of vectors. More formally:</p>
<blockquote>
<p>Given a set of vectors \(v_1, v_2, \ldots, v_n\) and its distance matrix \(\text{dist}\), the element \(\text{dist}_{ij}\) in the matrix represents the distance between \(v_i\) and \(v_j\). Notice, this means the matrix is symmetric since \(\text{dist}_{ij} = \text{dist}_{ji}\), and the dimensionality (size) of the matrix is \((n, n)\).</p>
</blockquote>
<p>The above definition, however, doesn't define what <em>distance</em> means. There are <a href="https://numerics.mathdotnet.com/Distance.html">many ways to define and compute the distance between two vectors</a>, but usually, when speaking of the distance between vectors, we are referring to their <em>euclidean distance</em>. Euclidean distance is our intuitive notion of what distance is (i.e. the length of the shortest line between two points on a map). Mathematically, we can define the euclidean distance between two vectors \(u, v\) as,</p>
<p>\[|| u - v ||_2 = \sqrt{\sum_{k=1}^d (u_k - v_k)^2}\]</p>
<p>where \(d\) is the dimensionality (size) of the vectors.</p>
<p>Distance matrices are highly useful in their own right, in all kinds of applications, from math, to computer science, to graph theory, to bioinformatics. Let's explore one particular application for distance matrices: machine learning.</p>
<h2 id="motivating-example%3A-k-nearest-neighbors" tabindex="-1">Motivating Example: k-Nearest Neighbors</h2>
<p><a href="https://cs231n.github.io/classification/#k---nearest-neighbor-classifier">k-Nearest Neighbour</a> (kNN) is a machine learning classification algorithm that utilizes distance matrices under the hood. The idea is simple: we can predict the class of any given data point by looking at the classes of the \(k\) nearest neighboring labelled data points. Whichever class is most common within the neighbors is the class we predict for the data point.</p>
<p>How do you determine which labelled points are the "nearest"? Well, if we represent each data point as a vector, we can compute their euclidean distance.</p>
<p>Let's say instead of just predicting for a single point, you want to predict for multiple points. More formally, you are given \(n\) labelled data points (train data) and \(m\) unlabelled data points (test data, which you would like to classify). The data points are represented as vectors of dimensionality \(d\). In order to implement the kNN classifier, you'll need to compute the distances between all labelled-unlabelled pairs. These distances can be stored in an \((m, n)\) matrix \(\text{dist}\), where \(\text{dist}_{ij}\) represents the distance between the \(i\)th unlabelled point and the \(j\)th labelled point. If we represent our labelled data points by the \((n, d)\) matrix \(Y\), and our unlabelled data points by the \((m, d)\) matrix \(X\), the distance matrix can be formulated as:</p>
<p>\[\text{dist}_{ij} = \sqrt{\sum_{k=1}^d (X_{ik} - Y_{jk})^2}\]</p>
<p>This distance computation is really the meat of the algorithm, and what I'll be focusing on for this post. Let's implement it.</p>
<p><strong>Note:</strong> I use the term distance matrix here even though the matrix is no longer square (since we are computing the distances between two sets of vectors and not just one).</p>
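<p>Once the distance matrix is available, the voting step described above takes only a few lines. Here is a minimal sketch; <code>knn_predict</code> is a hypothetical helper (not part of NumPy), it assumes integer class labels starting at 0, and it breaks ties toward the smaller label because <code>argmax</code> returns the first maximum.</p>
<pre><code class="language-python">import numpy as np

def knn_predict(dists, y_train, k):
    """Hypothetical helper: majority vote over the k nearest neighbours.

    dists:   (m, n) distance matrix, dists[i, j] = distance(test i, train j)
    y_train: (n,) non-negative integer class labels
    """
    # indices of the k smallest distances in each row, shape (m, k)
    nearest = np.argsort(dists, axis=1)[:, :k]
    # labels of those neighbours, shape (m, k)
    nearest_labels = y_train[nearest]
    # most common label in each row
    return np.array([np.bincount(row).argmax() for row in nearest_labels])

dists = np.array([[0.1, 0.5, 0.2, 3.0],
                  [2.0, 0.3, 4.0, 0.1]])
y_train = np.array([0, 1, 0, 1])
print(knn_predict(dists, y_train, k=3))  # [0 1]
</code></pre>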
<h2 id="three-loop" tabindex="-1">Three Loop</h2>
<p>The simplest way to compute our distance matrix is to just loop over all the pairs and elements:</p>
<pre><code class="language-python">import numpy as np

X = np.random.rand(5, 3)        # test data (m, d), placeholder values
X_train = np.random.rand(8, 3)  # train data (n, d), placeholder values
m = X.shape[0]
n = X_train.shape[0]
d = X.shape[1]
dists = np.zeros((m, n))  # distance matrix (m, n)
for i in range(m):
    for j in range(n):
        val = 0
        for k in range(d):
            val += (X[i][k] - X_train[j][k]) ** 2
        dists[i][j] = np.sqrt(val)
</code></pre>
<p>While this works, it's quite inefficient and doesn't take advantage of numpy's efficient vectorized operations. Let's change that.</p>
<h2 id="two-loops" tabindex="-1">Two Loops</h2>
<pre><code class="language-python">for i in range(m):
    for j in range(n):
        # element-wise subtract, element-wise square, take the sum and sqrt
        dists[i][j] = np.sqrt(np.sum((X[i] - X_train[j]) ** 2))
</code></pre>
<p>That wasn't too bad, and if you ask me, it's even easier to read. But we can do better.</p>
<h2 id="one-loop" tabindex="-1">One Loop</h2>
<pre><code class="language-python">for i in range(m):
    dists[i, :] = np.sqrt(np.sum((X[i] - X_train) ** 2, axis=1))
</code></pre>
<p>What the hell is going on here?! Ok let's break it down.</p>
<p>Firstly, shouldn't <code>X[i] - X_train</code> result in an error? <code>X[i]</code> has shape \((d)\) while <code>X_train</code> has shape \((n, d)\). Element-wise operations only work if both parties have the same shape, so what's happening here?</p>
<p>Numpy is automatically <a href="https://numpy.org/doc/stable/user/basics.broadcasting.html">broadcasting</a> <code>X[i]</code> to match the shape of <code>X_train</code>. You can think of this as stacking <code>X[i]</code> \(n\) times to produce an \((n, d)\) matrix where each row is just a copy of <code>X[i]</code>. This way, when performing the subtraction, each row of <code>X_train</code> is subtracted from <code>X[i]</code> (or the other way around; it doesn't matter since we'll be taking the square of the result). If you wanted, you could create the "stacked" matrix yourself in numpy using <code>np.tile</code>, but it would be <a href="https://gist.github.com/jaymody/9d7dec07300f817ddd40b74b1d648a34">slower than if you let numpy handle it with broadcasting</a>. So now we have an \((n, d)\) matrix where each row \(j\) is <code>X[i] - X_train[j]</code>, sick.</p>
<p>The next step is easy: we perform an element-wise square. Then, we need to take the sum of each row, so we use <code>np.sum</code> with the argument <code>axis=1</code>, which tells numpy to sum along axis 1 (the columns), producing one sum per row. Without the axis argument, <code>np.sum</code> will take the sum of every element in the matrix and output a single scalar value. The result of <code>np.sum</code> with <code>axis=1</code> is a vector of size \(n\).</p>
<p>Finally, we take the element-wise square root of this vector and store it in <code>dists[i]</code>.</p>
<p>So here's a better annotated version of the code that's much easier to understand:</p>
<pre><code class="language-python">for i in range(m):
    # X[i] gets broadcasted (d) -> (n, d)
    # (each row is a copy of X[i])
    diffs = X[i] - X_train
    # element wise square
    squared = diffs ** 2
    # take the sum of each row (n, d) -> (n)
    sums = np.sum(squared, axis=1)
    # take the element-wise square root and store them in dists
    dists[i, :] = np.sqrt(sums)
</code></pre>
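<p>The loop above assumes <code>X</code>, <code>X_train</code>, <code>m</code>, and <code>dists</code> already exist. As a self-contained sketch, here's the same one-loop approach wrapped in a helper (<code>one_loop_dists</code> is a hypothetical name, and the random data is only a stand-in for real features), with one entry spot-checked against <code>np.linalg.norm</code>:</p>
<pre><code class="language-python">import numpy as np

def one_loop_dists(X, X_train):
    """Euclidean distances between rows of X (m, d) and X_train (n, d) -> (m, n)."""
    m, n = X.shape[0], X_train.shape[0]
    dists = np.zeros((m, n))
    for i in range(m):
        # X[i] is broadcast from (d) to (n, d) against X_train
        diffs = X[i] - X_train
        # square, sum each row (n, d) -> (n), then take the square root
        dists[i, :] = np.sqrt(np.sum(diffs ** 2, axis=1))
    return dists

# random stand-in data: 5 "test" and 7 "training" examples with 8 features
rng = np.random.default_rng(0)
X = rng.standard_normal((5, 8))
X_train = rng.standard_normal((7, 8))
dists = one_loop_dists(X, X_train)

# spot-check one entry against numpy's built-in vector norm
expected = np.linalg.norm(X[2] - X_train[4])
print(dists.shape)                        # (5, 7)
print(np.isclose(dists[2, 4], expected))  # True
</code></pre>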
<h2 id="no-loops%3F!" tabindex="-1">No Loops?!</h2>
<p>We can do even better and use only vector/matrix operations, no loops needed. How, you ask? Let's take a closer look at our equation:</p>
<p>\[\text{dist}_{ij} = \sqrt{\sum_{k=1}^d (x_{ik} - y_{jk})^2}\]</p>
<p>What happens if we expand out the expression in the sum?</p>
<p>\[
\text{dist}_{ij} = \sqrt{\sum_{k=1}^d \left(x^2_{ik} - 2x_{ik}y_{jk} + y^2_{jk}\right)}
\]</p>
<p>Interesting. Now let's distribute the sum over the three terms:</p>
<p>\[
\text{dist}_{ij} = \sqrt{\sum_{k=1}^d x^2_{ik} - 2 \sum_{k=1}^d x_{ik}y_{jk} + \sum_{k=1}^d y^2_{jk}}
\]</p>
<p>You'll notice that each of these sums is just a dot product, so let's replace the ugly notation with a much cleaner expression:</p>
<p>\[
\text{dist}_{ij} = \sqrt{x_i \cdot x_i - 2x_i \cdot y_j + y_j \cdot y_j}
\]</p>
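<p>As a quick sanity check of this identity, here it is evaluated on two small hand-picked vectors standing in for \(x_i\) and \(y_j\). Both sides come out to \(\sqrt{29}\): the difference is \((-3, 2, 4)\), whose squares sum to 29, and \(14 - 2 \cdot 1 + 17 = 29\).</p>
<pre><code class="language-python">import numpy as np

# two small vectors standing in for x_i and y_j
x = np.array([1.0, 2.0, 3.0])
y = np.array([4.0, 0.0, -1.0])

# definition: square root of the summed squared differences
direct = np.sqrt(np.sum((x - y) ** 2))

# expanded form: sqrt(x.x - 2 x.y + y.y)
expanded = np.sqrt(np.dot(x, x) - 2 * np.dot(x, y) + np.dot(y, y))

print(np.isclose(direct, expanded))       # True
print(np.isclose(direct, np.sqrt(29.0)))  # True
</code></pre>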
<p>Notice that across all combinations of \(i, j\), the middle term is unique, but the left and right terms are repeated. Fix \(i\) and iterate over every \(j\), and you'll see \(x_i \cdot x_i\) show up \(n\) times (once per training example); fix \(j\) and iterate over every \(i\), and \(y_j \cdot y_j\) shows up \(m\) times (once per test example). So our challenge is to compute all possible \(x_i \cdot x_i\), \(x_i \cdot y_j\), and \(y_j \cdot y_j\), and then add them together in the right way, all without loops. Let's try it:</p>
<pre><code class="language-python"># summing the squares of each row has the same effect as taking
# the dot product of each row with itself
x2 = np.sum(X**2, axis=1) # shape of (m)
y2 = np.sum(X_train**2, axis=1) # shape of (n)
# we can compute every dot(x_i, y_j) and store it in xy[i][j] by
# taking the matrix multiplication of X and X_train transposed
# if you're struggling to understand this, draw out the matrices and
# do the matrix multiplication by hand
# (m, d) x (d, n) -> (m, n)
xy = np.matmul(X, X_train.T)
# row i of xy needs x2[i] added to it
# column j of xy needs y2[j] added to it
# to get everything to play well, we'll need to reshape
# x2 from (m) -> (m, 1), numpy will handle the rest of the broadcasting for us
# see: https://numpy.org/doc/stable/user/basics.broadcasting.html
x2 = x2.reshape(-1, 1)
# (m, 1) repeats column-wise + (m, n) + (n) repeats row-wise -> (m, n)
# floating-point rounding can make near-zero values slightly negative,
# so clamp at 0 before the square root to avoid NaNs
dists = np.sqrt(np.maximum(x2 - 2*xy + y2, 0))
</code></pre>
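<p>Here's a self-contained sketch that wraps the no-loop computation in a function and checks it against a one-loop reference on random data. The helper names <code>no_loop_dists</code> and <code>one_loop_dists</code> are hypothetical. The <code>np.maximum(..., 0.0)</code> clamp is a defensive choice in this sketch: when a test point coincides with a training point, rounding can make the expanded expression slightly negative, and <code>np.sqrt</code> would then return <code>nan</code>. The self-distance check at the end exercises exactly that case.</p>
<pre><code class="language-python">import numpy as np

def no_loop_dists(X, X_train):
    """Euclidean distances between rows of X (m, d) and X_train (n, d) -> (m, n)."""
    x2 = np.sum(X ** 2, axis=1).reshape(-1, 1)  # (m, 1)
    y2 = np.sum(X_train ** 2, axis=1)           # (n,)
    xy = X @ X_train.T                          # (m, d) x (d, n) -> (m, n)
    # clamp at 0: rounding can make near-zero squared distances slightly negative
    return np.sqrt(np.maximum(x2 - 2 * xy + y2, 0.0))

def one_loop_dists(X, X_train):
    """Reference version: broadcast one row of X against X_train at a time."""
    return np.array([np.sqrt(np.sum((row - X_train) ** 2, axis=1)) for row in X])

rng = np.random.default_rng(1)
X = rng.standard_normal((6, 10))
X_train = rng.standard_normal((9, 10))

fast = no_loop_dists(X, X_train)
slow = one_loop_dists(X, X_train)
print(fast.shape)               # (6, 9)
print(np.allclose(fast, slow))  # True

# distances from each point to itself should be (close to) 0, not nan
self_dists = no_loop_dists(X, X)
print(np.isnan(self_dists).any())                        # False
print(np.allclose(np.diag(self_dists), 0.0, atol=1e-5))  # True
</code></pre>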
<h2 id="-1-loops%3F!!%3F!-%F0%9F%A4%94" tabindex="-1">-1 Loops?!!?! 🤔</h2>
<pre><code class="language-python">from sklearn.neighbors import KNeighborsClassifier
</code></pre>
<h2 id="speed-comparison" tabindex="-1">Speed Comparison</h2>
<p>To test the speed of each implementation, we can run it against a small subset of the CIFAR-10 dataset, as in the <a href="https://github.com/jaymody/cs231n/blob/master/assignment1/knn.ipynb">cs231n assignment 1 knn notebook</a>:</p>
<pre><code class="language-python">Two loop version took 39.707250 seconds
One loop version took 28.705156 seconds
No loop version took 0.218127 seconds
</code></pre>
<p>Clearly, the no loop version is the winner, beating the two loop and one loop implementations by roughly two orders of magnitude (about 180x and 130x faster, respectively). Notice that I didn't include the three loop implementation because it would have taken hours to run! On just <code>10</code> training and <code>10</code> test examples, the three loop implementation took <code>0.5</code> seconds. The timings above are for <code>5000</code> training and <code>500</code> test examples, which is 25,000 times as many pairs; assuming the runtime scales linearly with the number of pairs, that works out to roughly 12,500 seconds (about 3.5 hours), yikes! +1 for vector operations!</p>
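<p>If you want to reproduce a comparison like this on your own machine, here's a rough timing sketch. It uses random data rather than CIFAR-10 (3072 features per example, the size of a flattened 32×32×3 image, with smaller example counts so it finishes quickly), <code>time.perf_counter</code> for wall-clock timing, and hypothetical helper names. Absolute numbers will depend on your hardware and BLAS library, so no expected timings are shown.</p>
<pre><code class="language-python">import time

import numpy as np

def one_loop_dists(X, X_train):
    # broadcast one test row against all training rows per iteration
    dists = np.zeros((X.shape[0], X_train.shape[0]))
    for i in range(X.shape[0]):
        dists[i, :] = np.sqrt(np.sum((X[i] - X_train) ** 2, axis=1))
    return dists

def no_loop_dists(X, X_train):
    # expand ||x - y||^2 = x.x - 2 x.y + y.y and use a single matrix multiply
    x2 = np.sum(X ** 2, axis=1).reshape(-1, 1)
    y2 = np.sum(X_train ** 2, axis=1)
    return np.sqrt(np.maximum(x2 - 2 * (X @ X_train.T) + y2, 0.0))

def time_it(fn, *args):
    # return the function's result and the wall-clock seconds it took
    start = time.perf_counter()
    result = fn(*args)
    return result, time.perf_counter() - start

# random stand-in data (NOT CIFAR-10): 50 test and 500 training examples,
# 3072 features each (the size of a flattened 32x32x3 image)
rng = np.random.default_rng(0)
X = rng.standard_normal((50, 3072))
X_train = rng.standard_normal((500, 3072))

d1, t1 = time_it(one_loop_dists, X, X_train)
d0, t0 = time_it(no_loop_dists, X, X_train)
print(f"One loop version took {t1:.6f} seconds")
print(f"No loop version took {t0:.6f} seconds")
print("Results match:", np.allclose(d1, d0))
</code></pre>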
Sun, 04 Apr 2021 00:00:00 +0000Jay Modyhttps://jaykmody.com/blog/distance-matrices-with-numpy/