---
title: "Bigram Language Model with JAX"
categories:
- LLM
tags:
- JAX
- TPU
---
![Picture](/assets/bigram/image.png)

# Introduction

This blog post shows how to do language modelling with a simple MLP and JAX. Throughout this post I will assume that you are familiar with my last blog post, so if you didn't read that already, go and check it out first.
# Going further
In the last blog post we built a simple bigram model: a basic neural network without a hidden layer trained to predict the next character given the current one.
Now we take it a step further: instead of a single character, we use multiple characters of context to predict the next character. We will train the model on the same data as the bigram model and compare the performance.
Besides using multiple characters of context, we will also use a much larger neural network. The bigram model was just a single 27 by 27 matrix, where 27 was the size of our vocabulary. Now we will learn an embedding for every character, feed the embeddings through a hidden layer with a non-linearity, and finally produce the output through a linear layer.
But lines of code say more than words, so let's move on to the implementation.
# Implementation
## Data Loading and Preprocessing
The loading of the data as well as the encoding of a word stays the same, so we won't repeat it here. If you didn't read the last blog post, I suggest you do that at this stage.
The first difference we encounter is in the way we build our dataset.
```python
from typing import List, Tuple

import jax.numpy as jnp
from jax import Array


def get_dataset(encoded_words: List[List[int]], block_size: int) -> Tuple[Array, Array]:
    """
    Take block_size letters to predict the next letter.
    """
    X = []
    y = []
    for word in encoded_words:
        # Start every word with a context of block_size EOS tokens (index 0).
        context = [0] * block_size
        for token in word[1:]:
            X.append(context)
            y.append(token)
            # Slide the window: drop the oldest token, append the current one.
            context = context[1:] + [token]
    return jnp.array(X), jnp.array(y)
```
Our function now takes an additional argument, the block size, which is the number of tokens we use to predict the next token. We first initialize the context, that is, what we feed into the neural network to predict the next token, with zeros, which stand for our special EOS token that marks the beginning and the end of each word. This is also called padding. We then go over every token in the current word, append the current context to the training data, append the current token to our targets, and update the context accordingly.
To get a sense of what that means, let's look at the tokenization process for a very simple example.
Let's assume we want to get the training and target data for the word "emma".
This would work as follows:
```
<eos><eos><eos> -> e
<eos><eos>e -> m
<eos>em -> m
emm -> a
mma -> <eos>
```
Above you see the training points on the left and, on the right, the target that corresponds to each training point.
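To make this concrete, here is a small sketch of what calling `get_dataset` on the encoded word "emma" could look like. The integer indices below are purely illustrative and depend on the vocabulary from the last post; only index 0 for `<eos>` matters.
```python
import jax.numpy as jnp

# Hypothetical encoding of "<eos>emma<eos>": 0 = <eos>, 5 = e, 13 = m, 1 = a.
encoded_emma = [[0, 5, 13, 13, 1, 0]]

X, y = get_dataset(encoded_emma, block_size=3)
print(X)
# [[ 0  0  0]
#  [ 0  0  5]
#  [ 0  5 13]
#  [ 5 13 13]
#  [13 13  1]]
print(y)
# [ 5 13 13  1  0]
```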
We then proceed to randomly divide our data into a train, dev and test set.
The code for that is in the repo I will link at the end of the post; I fixed a random seed there to ensure deterministic behavior.
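The split itself is not shown here, but a minimal sketch of how such a seeded random split could look is the following. The 80/10/10 ratios and the seed are my choices for illustration, not necessarily what the repo uses.
```python
import jax


def train_dev_test_split(X, y, key, train_frac=0.8, dev_frac=0.1):
    # Shuffle all indices deterministically given the key, then slice.
    n = X.shape[0]
    perm = jax.random.permutation(key, n)
    n_train = int(train_frac * n)
    n_dev = int(dev_frac * n)
    train_idx = perm[:n_train]
    dev_idx = perm[n_train : n_train + n_dev]
    test_idx = perm[n_train + n_dev :]
    return (X[train_idx], y[train_idx]), (X[dev_idx], y[dev_idx]), (X[test_idx], y[test_idx])


key = jax.random.PRNGKey(42)  # fixed seed for deterministic behavior
(X_train, y_train), (X_dev, y_dev), (X_test, y_test) = train_dev_test_split(X, y, key)
```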
## Model
Now we come to the interesting part: the implementation of our model. Our model has the following parameters:
```python
from typing import NamedTuple

from jax import Array


class MLPParams(NamedTuple):
    embedding: Array  # (vocab_size, embed_size)
    W1: Array         # (block_size * embed_size, hidden_size)
    b1: Array         # (hidden_size,)
    W2: Array         # (hidden_size, vocab_size)
    b2: Array         # (vocab_size,)
```
What each of these means becomes clearer if we look at the forward function.
```python
def forward(params: MLPParams, X: Array) -> Array:
    embedded = params.embedding[X]  # (batch_size, block_size, embed_size)
    embedded = embedded.reshape(X.shape[0], -1)  # (batch_size, block_size * embed_size)
    hidden = jnp.tanh(embedded.dot(params.W1) + params.b1)  # (batch_size, hidden_size)
    output = hidden.dot(params.W2) + params.b2  # (batch_size, vocab_size)
    return output
```
In the first stage we embed the tokens, which is really just a fancy term for looking up their corresponding vectors in a lookup table. For each token we have an embed-size-dimensional vector that we look up in the matrix called `embedding`; during training this matrix learns a meaningful representation of each token.
The next step is a reshape. If X has shape batch size by block size, which is the case here as you can see in the data processing part, then the lookup yields a 3D array of shape (batch size, block size, embed size). We therefore reshape it into an ordinary matrix, using syntax that will be familiar to NumPy users.
We then multiply the resulting matrix by another matrix W1, add the so-called bias b1, and wrap the result in a non-linearity, the hyperbolic tangent (tanh).
Note that here we follow the implementation of the original paper of [Bengio et al.](https://www.jmlr.org/papers/volume3/bengio03a/bengio03a.pdf), which is why we use the tanh non-linearity. Nowadays it would probably be more common to use a ReLU non-linearity at this point.
The last step is to output the logits: we multiply the output of the hidden layer by the second matrix W2 and again add a bias term. This gives us our final logits.
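The post focuses on the forward pass, but for completeness here is a sketch of how the parameters could be initialized. The function name `init_params`, the hyperparameters `embed_size` and `hidden_size`, and the simple scaled-normal initialization are my assumptions for illustration.
```python
import jax
import jax.numpy as jnp


def init_params(key, vocab_size=27, block_size=3, embed_size=10, hidden_size=200):
    k_emb, k_w1, k_w2 = jax.random.split(key, 3)
    return MLPParams(
        embedding=jax.random.normal(k_emb, (vocab_size, embed_size)),
        W1=jax.random.normal(k_w1, (block_size * embed_size, hidden_size)) * 0.1,
        b1=jnp.zeros(hidden_size),
        W2=jax.random.normal(k_w2, (hidden_size, vocab_size)) * 0.1,
        b2=jnp.zeros(vocab_size),
    )


params = init_params(jax.random.PRNGKey(0))
logits = forward(params, X_train[:32])  # (32, vocab_size)
```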
Our loss function will again be the familiar cross-entropy loss, which I described in the first blog post. What follows is the usual training, which you can look up in the repo linked at the end of the post.
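As a reminder, a cross-entropy loss on top of `forward` could be sketched like this; the exact implementation in the repo may differ.
```python
import jax
import jax.numpy as jnp


def loss_fn(params: MLPParams, X: Array, y: Array) -> Array:
    logits = forward(params, X)                      # (batch_size, vocab_size)
    log_probs = jax.nn.log_softmax(logits, axis=-1)  # normalize to log-probabilities
    # Negative average log-probability of the correct next character.
    return -jnp.mean(log_probs[jnp.arange(y.shape[0]), y])
```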
One thing worth pointing out is that we use what is called mini-batch training: in each epoch we sample a small batch of the training data, use that batch to compute the gradients, and use the resulting gradient to update our weights. This makes training more efficient and uses less memory, but it has the disadvantage that the training loss does not decrease as smoothly as if we used the whole dataset. As we will see, it is entirely sufficient here, and it is common practice in general.
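A minimal sketch of such a mini-batch training step, using plain gradient descent, could look like the following. The learning rate, batch size and number of epochs are illustrative values, not the ones from the repo.
```python
from typing import Tuple

import jax

LEARNING_RATE = 0.1  # illustrative
BATCH_SIZE = 128     # illustrative


@jax.jit
def train_step(params: MLPParams, X_batch: Array, y_batch: Array) -> Tuple[MLPParams, Array]:
    loss, grads = jax.value_and_grad(loss_fn)(params, X_batch, y_batch)
    # Plain gradient descent on every leaf of the parameter NamedTuple.
    new_params = jax.tree_util.tree_map(lambda p, g: p - LEARNING_RATE * g, params, grads)
    return new_params, loss


key = jax.random.PRNGKey(1)
for epoch in range(10_000):
    key, subkey = jax.random.split(key)
    # Sample a random mini-batch of training examples.
    idx = jax.random.choice(subkey, X_train.shape[0], (BATCH_SIZE,), replace=False)
    params, loss = train_step(params, X_train[idx], y_train[idx])
```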
Let's look at our loss curves now. Here I plotted the training as well as the validation loss. The training loss doesn't really matter if our model is not able to generalize, so we should always evaluate on the held-out validation set.
![Picture](/assets/mlp_nlp/losses.png)
That's nice: we see a huge improvement in performance compared to our bigram model. We also see what we want, namely that training and validation loss are roughly of the same order, which means we don't overfit. The bumpy behavior of the training loss is easily explained by the mini-batch updates: we don't compute the full gradient, so a little noise is expected.
Let's now move on to the sampling part. The sampling process is largely the same as last time, but we need to take into account that we feed the last three encoded characters into our neural network, because that's how we trained it.
```python
import jax
from jax import random


def sample(params: MLPParams, key: Array, vocab: List[str]) -> str:
    """
    1) Start with a context of three <eos> tokens
    2) Run the forward pass to get the logits for the next character
    3) Sample the next character from that distribution
    4) Append the sampled character to the sampled word and shift the context
    5) Repeat steps 2-4 until <eos> is sampled
    6) Return the sampled word
    """
    current_chars = jnp.array([vocab.index("<eos>"), vocab.index("<eos>"), vocab.index("<eos>")])[
        None, :
    ]
    sampled_word = ["<eos>", "<eos>", "<eos>"]
    while True:
        key, subkey = jax.random.split(key)
        logits = forward(params, current_chars)
        sampled_char = random.categorical(subkey, logits=logits)[0]
        # Shift the context window: drop the oldest character, append the new one.
        current_chars = jnp.concatenate(
            [current_chars[:, 1:], jnp.array([sampled_char])[None, :]], axis=1
        )
        sampled_word.append(vocab[sampled_char])
        if sampled_char == vocab.index("<eos>"):
            break
    return "".join(sampled_word)[len("<eos><eos><eos>") : -len("<eos>")]
```
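As a usage sketch, sampling a handful of words could then look like this, assuming `vocab` is the list of characters including the `<eos>` token as in the last post:
```python
import jax

key = jax.random.PRNGKey(2)
for _ in range(10):
    key, subkey = jax.random.split(key)
    print(sample(params, subkey, vocab))
```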
We can then look at what kinds of words the neural network is able to produce after training.
```
karmelmanie
zaaqa
tri
caurelle
raamia
carlyn
mavin
artha
jaamini
tina
```
This looks much, much better than the results from the bigram model. The words sound really name-like; maybe some of them are not real names, but for example there is tina, there is carlyn, there is caurelle, and all of them sound at least to me very name-ish.
That's it for this blog post. As before, it is heavily influenced by the lecture of [Andrej Karpathy](https://www.youtube.com/watch?v=TCH_1BHY58I&list=PLAqhIrjkxbuWI23v9cThsA9GvCAUhRvKZ&index=3).
The experiments were performed on a TPU provided by the [TRC research program](https://sites.research.google/trc/about/).
If you have questions or suggestions, please let me know and I will try to answer them as best as I can. If you are interested in the code, you can find it on [my GitHub](https://github.com/simveit/mlp_language_modelling_jax/tree/main).
