---
title: "Bigram Language Model with JAX"
categories:
  - LLM
tags:
  - JAX
  - TPU
---
![Picture](/assets/bigram/image.png)
# Introduction

This blog post shows how to do natural language processing with a simple MLP and JAX. Throughout this post I will assume that you are familiar with my last blog post, so if you didn't read that already, go and check it out first.
# Going further

In the last blog post we built a simple bigram model: a basic neural network without a hidden layer that, given a character, predicts the next character.
Now we take it a step further and use multiple characters to predict the next character. We will train the model on the same data as the bigram model and compare the performance.
Besides using multiple characters of context, the other difference is that we will use a much larger neural network. The bigram model consisted of a single matrix of 27 by 27 dimensions, where 27 was the size of our vocabulary. Now we will learn an embedding for every character, feed it through a hidden layer with a non-linearity, and finally produce the output through a linear layer.
But lines of code say more than words, so let's move on to the implementation.
# Implementation

## Data Loading and Preprocessing

The loading of the data as well as the encoding of a word stays the same, so we won't repeat it here. If you didn't read the last blog post, I suggest you do that at this stage.
The first difference comes in the way we build our dataset.
```python
from typing import List, Tuple

import jax.numpy as jnp
from jax import Array


def get_dataset(encoded_words: List[List[int]], block_size: int) -> Tuple[Array, Array]:
    """
    Take block_size letters to predict the next letter.
    """
    X = []
    y = []
    for word in encoded_words:
        # Start every word with a context full of <eos> tokens (encoded as 0).
        context = [0] * block_size
        for token in word[1:]:
            X.append(context)
            y.append(token)
            # Slide the context window: drop the oldest token, append the new one.
            context = context[1:] + [token]
    return jnp.array(X), jnp.array(y)
```
As you can see, our function takes an additional argument called block size, which is the number of tokens we use to predict the next token. We first initialize the context, i.e. what we feed to the neural network to predict the next token, with zeros; zero stands for our special EOS token, which comes at the beginning and at the end of each word. This is also called padding. Then we go over every token in the current word, append the current context to the training data, append the current token to our targets, and update the context accordingly.
To get a sense of what that means, let's look at a very simple example of the tokenization process.
Let's assume we want to build the training and target data for the word "emma" with a block size of 3.
This would work as follows:
```
<eos><eos><eos> -> e
<eos><eos>e -> m
<eos>em -> m
emm -> a
mma -> <eos>
```
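To double-check this, here is a minimal sketch that runs `get_dataset` on "emma". The concrete vocabulary below (`<eos>` encoded as 0, `a` to `z` as 1 to 26) is my assumption about the encoding and is spelled out here just for illustration.

```python
# Hypothetical vocabulary for this example: <eos> = 0, a = 1, ..., z = 26.
vocab = ["<eos>"] + [chr(c) for c in range(ord("a"), ord("z") + 1)]
encoded_emma = [0] + [vocab.index(ch) for ch in "emma"] + [0]  # [0, 5, 13, 13, 1, 0]

X, y = get_dataset([encoded_emma], block_size=3)
print(X)  # the five contexts: [0 0 0], [0 0 5], [0 5 13], [5 13 13], [13 13 1]
print(y)  # the five targets:  5 13 13 1 0  (i.e. e, m, m, a, <eos>)
```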
Above you see the contexts on the left and, on the right, the target which corresponds to each of these training points.
We then proceed to divide our data randomly into a train, dev and test set.
For that code, see the repo I will link at the end of the post. Note that I fixed a random seed to ensure deterministic behavior.
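The exact split lives in the repo, but a minimal sketch of such a seeded random split could look like this (the 80/10/10 ratio is just an illustrative assumption):

```python
import jax

# Fixed seed so the split is deterministic.
key = jax.random.PRNGKey(42)
perm = jax.random.permutation(key, X.shape[0])

n_train = int(0.8 * X.shape[0])
n_dev = int(0.9 * X.shape[0])

X_train, y_train = X[perm[:n_train]], y[perm[:n_train]]
X_dev, y_dev = X[perm[n_train:n_dev]], y[perm[n_train:n_dev]]
X_test, y_test = X[perm[n_dev:]], y[perm[n_dev:]]
```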
## Model

Now we come to the interesting part, the implementation of our model. Our model has the following parameters:
```python
from typing import NamedTuple


class MLPParams(NamedTuple):
    embedding: Array  # (vocab_size, embed_size)
    W1: Array         # (block_size * embed_size, hidden_size)
    b1: Array         # (hidden_size,)
    W2: Array         # (hidden_size, vocab_size)
    b2: Array         # (vocab_size,)
```
What each of these means becomes clearer if we look at the forward function.
```python
def forward(params: MLPParams, X: Array) -> Array:
    embedded = params.embedding[X]  # (batch_size, block_size, embed_size)
    embedded = embedded.reshape(X.shape[0], -1)  # (batch_size, block_size * embed_size)
    hidden = jnp.tanh(embedded.dot(params.W1) + params.b1)  # (batch_size, hidden_size)
    output = hidden.dot(params.W2) + params.b2  # (batch_size, vocab_size)
    return output
```
The first stage is what is called embedding the tokens, which is really just a fancy term for looking up their corresponding vector in a lookup table. For each token we have an embed-size-dimensional vector, which we look up in the matrix called embedding here; during training that matrix learns a meaningful representation of each token.
If X has dimensions batch size by block size, which is the case here as you can see in the data processing part, then the lookup gives a 3D array with dimensions batch size, block size, embed size. We therefore need to reshape it into an ordinary matrix, using syntax that will be familiar to NumPy users.
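As a quick illustration of that reshape (the concrete sizes are just example values):

```python
import jax.numpy as jnp

embedded = jnp.zeros((32, 3, 10))  # (batch_size, block_size, embed_size)
flat = embedded.reshape(embedded.shape[0], -1)
print(flat.shape)  # (32, 30) = (batch_size, block_size * embed_size)
```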
We then multiply the resulting matrix by another matrix W1, add the so-called bias b1, and wrap the result in a non-linearity, namely the hyperbolic tangent.
Note that here we follow the implementation of the original paper by [Bengio et al.](https://www.jmlr.org/papers/volume3/bengio03a/bengio03a.pdf); that's why we use the tanh non-linearity. Nowadays it would probably be more canonical to use a ReLU non-linearity at this point.
The last step is to output the logits, which is done by multiplying the output of the hidden layer with the second matrix W2 and again adding a bias term. This results in our final logits.
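The post doesn't show the parameter initialization at this point, but given the shapes implied by the forward function, a minimal sketch could look as follows (the sizes and the 0.01 scale are my assumptions, not necessarily what the repo uses):

```python
def init_params(key: Array, vocab_size: int = 27, block_size: int = 3,
                embed_size: int = 10, hidden_size: int = 128) -> MLPParams:
    k1, k2, k3 = jax.random.split(key, 3)
    return MLPParams(
        embedding=jax.random.normal(k1, (vocab_size, embed_size)) * 0.01,
        W1=jax.random.normal(k2, (block_size * embed_size, hidden_size)) * 0.01,
        b1=jnp.zeros(hidden_size),
        W2=jax.random.normal(k3, (hidden_size, vocab_size)) * 0.01,
        b2=jnp.zeros(vocab_size),
    )
```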
Our loss function will again be the familiar cross-entropy loss which I described in the first blog post. After that comes the usual training, which you can look up in the repo linked at the end of the post.
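For reference, a minimal sketch of how such a cross-entropy loss could be written on top of the forward function above (the repo code may do this differently, e.g. via a library helper):

```python
def loss_fn(params: MLPParams, X: Array, y: Array) -> Array:
    logits = forward(params, X)                      # (batch_size, vocab_size)
    log_probs = jax.nn.log_softmax(logits, axis=-1)  # normalize to log-probabilities
    # Negative log-likelihood of the correct next token, averaged over the batch.
    return -jnp.mean(log_probs[jnp.arange(y.shape[0]), y])
```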
The only thing that is perhaps interesting is that during training we used so-called mini-batch training: in each epoch we sample a small batch of the training data, use that batch to compute the gradients, and then use the resulting gradient to update our weights. This has the advantage that training becomes more efficient and of course takes less memory, but has the disadvantage that our training loss will not decrease as smoothly as if we used the whole data set. As we will see, that is really sufficient here, and it is also common practice in general.
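A sketch of what one such mini-batch update could look like (the batch size, learning rate and plain SGD update are my assumptions; the repo may use something different):

```python
LEARNING_RATE = 0.1  # assumed value

@jax.jit
def train_step(params: MLPParams, X_batch: Array, y_batch: Array) -> Tuple[MLPParams, Array]:
    loss, grads = jax.value_and_grad(loss_fn)(params, X_batch, y_batch)
    # Plain SGD: move every parameter a small step against its gradient.
    new_params = jax.tree_util.tree_map(lambda p, g: p - LEARNING_RATE * g, params, grads)
    return new_params, loss

# Assuming params = init_params(jax.random.PRNGKey(0)) from the sketch above.
key = jax.random.PRNGKey(1)
key, subkey = jax.random.split(key)
idx = jax.random.randint(subkey, (32,), 0, X_train.shape[0])  # sample a mini-batch
params, loss = train_step(params, X_train[idx], y_train[idx])
```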
Let's look at our loss curves now; here I plotted the training as well as the validation loss. The training loss doesn't really matter if our model is not able to generalize, so we should always evaluate on the held-out validation set.
![Picture](/assets/mlp_nlp/losses.png)
That's nice: we see a huge improvement in performance compared to our bigram model. We also see what we want, namely that the training and the validation loss are roughly of the same order, which means we don't overfit. The bumpy behavior of the training loss is easily explained by the fact that we use mini-batch updates, so we don't compute the full gradient, and a little bit of noise is therefore expected.
Let's now move on to sampling. The sampling process is largely the same as last time, but we need to take into account that we feed three encoded characters into our neural network, because that's how we trained it.
```python
import jax
from jax import random


def sample(params: MLPParams, key: Array, vocab: List[str]) -> str:
    """
    1) Start with a context of three <eos> tokens
    2) Run the forward pass on the current context to get the logits
    3) Sample the next character from the resulting distribution
    4) Append the sampled character to the word and shift the context
    5) Repeat steps 2-4 until <eos> is sampled
    6) Return the sampled word
    """
    current_chars = jnp.array([vocab.index("<eos>"), vocab.index("<eos>"), vocab.index("<eos>")])[
        None, :
    ]
    sampled_word = ["<eos>", "<eos>", "<eos>"]
    while True:
        key, subkey = jax.random.split(key)
        logits = forward(params, current_chars)
        sampled_char = random.categorical(subkey, logits=logits)[0]
        # Shift the context window and append the newly sampled token.
        current_chars = jnp.concatenate(
            [current_chars[:, 1:], jnp.array([sampled_char])[None, :]], axis=1
        )
        sampled_word.append(vocab[sampled_char])
        if sampled_char == vocab.index("<eos>"):
            break
    return "".join(sampled_word)[len("<eos><eos><eos>") : -len("<eos>")]
```
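To draw several words, one could for example split the key once per sample (a small usage sketch, not the exact repo code):

```python
key = jax.random.PRNGKey(0)
for _ in range(10):
    key, subkey = jax.random.split(key)
    print(sample(params, subkey, vocab))
```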
We can then look at what kinds of words the neural network is able to produce after training.
```
karmelmanie
zaaqa
tri
caurelle
raamia
carlyn
mavin
artha
jaamini
tina
```
This looks much, much better than the results from the bigram model if you compare them. They sound really name-like, and maybe some of them are not real names, but for example there is "tina", there is "carlyn", there is "caurelle", and all of these sound, at least to me, very name-ish.
That's it for now with this blog post. As before, it is heavily influenced by the lecture by [Andrej Karpathy](https://www.youtube.com/watch?v=TCH_1BHY58I&list=PLAqhIrjkxbuWI23v9cThsA9GvCAUhRvKZ&index=3).
The experiments were performed on a TPU provided by the [TRC research program](https://sites.research.google/trc/about/).
If you are interested in the code, you can find it at [my github](https://github.com/simveit/mlp_language_modelling_jax/tree/main).
If you have any questions or suggestions, please let me know and I will try to answer them as best as I can.