This blog post shows how to perform a simple natural language processing task using an MLP and JAX. Throughout this post I will assume that you are familiar with my last blog post, so if you haven't read it yet, go and check it out here.
# Going further
In the last blog post, we explored a simple bigram model where we trained a basic neural network without a hidden layer to perform the following task: given a word—or, in our case, a character—predict the next character.

Now, we will take it a step further by using multiple characters to predict the next character. We will train the model on the same data as the bigram model and compare their performances.

Another key difference is that we will use a much larger neural network. For the bigram model, we simply had one matrix with dimensions 27 by 27, where 27 was the size of our vocabulary. Now, we will find an embedding for every character, pass this through a hidden layer with a non-linear activation function, and finally output everything through a linear layer.

But lines of code say more than words, so let’s move on to the implementation part now.
# Implementation
## Data Loading and Preprocessing
The data loading process remains the same, as does the encoding of a word. We won’t repeat that here. If you haven’t read the last blog post, I suggest you do so at this stage.

The first difference is in how we build our dataset.
```python
def get_dataset(encoded_words: List[List[int]], block_size: int) -> Tuple[Array, Array]:
"""
Expand All @@ -34,22 +38,21 @@ def get_dataset(encoded_words: List[List[int]], block_size: int) -> Tuple[Array,
context = context[1:] + [token]
return jnp.array(X), jnp.array(y)
```
As you can see, our function now takes an additional argument called block_size, which is the number of tokens we use to predict the next token. We initialize the context, which is what we provide to the neural network to predict the next token, with zeros representing our special <eos> token used at the beginning and end of each word. This is also known as padding. We then iterate through every token in the current word, append the current context to the training data, append the current token to our targets, and update the context accordingly.

To illustrate this, let’s look at a simple example of the tokenization process for the word “emma”:
```
<eos><eos><eos> -> e
<eos><eos>e -> m
<eos>em -> m
emm -> a
mma -> <eos>
```
Above, you can see on the left side the corresponding training points and on the right side the targets for each training point.
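
For instance, assuming the `encode` helper from the last post maps a word to its token indices and appends a trailing `<eos>`, building the dataset for this single word yields exactly the five pairs shown above:

```python
# Hypothetical usage; `encode` is the helper from the previous post.
X, y = get_dataset([encode("emma")], block_size=3)
print(X.shape, y.shape)  # expected: (5, 3) and (5,)
```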

We then proceed to divide our training set randomly into train, dev, and test sets. You can find the code for that in the repository, which I will link later. I fixed a random seed to ensure deterministic behavior.
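
A minimal sketch of how such a split could look is the following; the 80/10/10 ratios and the helper name are illustrative assumptions, not necessarily what the repository uses:

```python
def train_dev_test_split(X: Array, y: Array, key: Array):
    """Illustrative 80/10/10 split; a fixed key keeps it deterministic."""
    perm = jax.random.permutation(key, X.shape[0])  # shuffle all indices once
    n_train, n_dev = int(0.8 * X.shape[0]), int(0.9 * X.shape[0])
    train, dev, test = perm[:n_train], perm[n_train:n_dev], perm[n_dev:]
    return (X[train], y[train]), (X[dev], y[dev]), (X[test], y[test])
```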
## Model
Now we come to the interesting part: the implementation of our model. Our model has the following parameters:
```python
class MLPParams(NamedTuple):
    embedding: Array
    W1: Array
    b1: Array
    W2: Array
    b2: Array
```
What each of these means becomes clearer when we look at the forward function.
```python
def forward(params: MLPParams, X: Array) -> Array:
    embedded = params.embedding[X]  # (batch_size, block_size, embed_size)
    embedded = embedded.reshape(embedded.shape[0], -1)  # (batch_size, block_size * embed_size)
    hidden = jnp.tanh(embedded.dot(params.W1) + params.b1)  # (batch_size, hidden_size)
    output = hidden.dot(params.W2) + params.b2  # (batch_size, vocab_size)
    return output
```
At the first stage, we perform what is called embedding the tokens, which is essentially looking up their corresponding vectors in a lookup table. For each token, we have an embed_size-dimensional vector, which we look up in the matrix called embedding. During training, this matrix learns a meaningful representation of each token.

Next, if X has the dimensions (batch_size, block_size), as shown in the data processing part, the lookup will result in a 3D array with dimensions (batch_size, block_size, embed_size). We need to reshape this into a 2D matrix to use familiar NumPy syntax.
We then multiply the resulting matrix by another matrix W1, add the bias b1, and apply a non-linear activation function, which is the hyperbolic tangent (tanh). Note that we follow the implementation from the original paper by [Bengio et al.](https://www.jmlr.org/papers/volume3/bengio03a/bengio03a.pdf), which is why we use the tanh activation function. Nowadays, it is more common to use a ReLU activation function at this stage.

The final step is to output the logits by multiplying the hidden layer’s output with the second matrix W2 and adding the bias term b2. This results in our final logits.
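
Putting the shapes together, a parameter initialization could look roughly like the sketch below. The concrete sizes (`embed_size`, `hidden_size`) and the simple scaled Gaussian initialization are assumptions for illustration, not necessarily the values used in the repository:

```python
def init_params(key: Array, vocab_size: int = 27, block_size: int = 3,
                embed_size: int = 10, hidden_size: int = 200) -> MLPParams:
    """Illustrative initialization; the shapes follow from the forward pass above."""
    k1, k2, k3 = jax.random.split(key, 3)
    return MLPParams(
        embedding=jax.random.normal(k1, (vocab_size, embed_size)),
        W1=jax.random.normal(k2, (block_size * embed_size, hidden_size)) * 0.01,
        b1=jnp.zeros(hidden_size),
        W2=jax.random.normal(k3, (hidden_size, vocab_size)) * 0.01,
        b2=jnp.zeros(vocab_size),
    )
```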

Our loss function is the familiar cross-entropy loss, which I described in the first blog post. The training process is also standard and can be found in the repository linked at the end of this post.
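
For reference, a minimal cross-entropy sketch on top of the forward pass could look like this; the function name is chosen for illustration:

```python
def cross_entropy_loss(params: MLPParams, X: Array, y: Array) -> Array:
    """Mean cross-entropy between the predicted logits and the integer targets."""
    logits = forward(params, X)                      # (batch_size, vocab_size)
    log_probs = jax.nn.log_softmax(logits, axis=-1)  # normalize logits to log-probabilities
    return -jnp.mean(log_probs[jnp.arange(y.shape[0]), y])
```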

One interesting aspect is that during training, we used mini-batch training. In each epoch, we sample a small batch of the training data, use that batch to compute the gradients, and then use the resulting gradients to update our weights. This approach makes training more efficient and requires less memory. However, the training loss may not decrease as smoothly as when using the entire dataset. As we will see, this approach is sufficient here and is also a common practice in general.
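
A hedged sketch of one such mini-batch update, using the loss sketch above and plain SGD, could look like this. It assumes `X_train`, `y_train`, `params`, and `key` are set up as in the earlier sketches; the batch size and learning rate are arbitrary:

```python
@jax.jit
def train_step(params: MLPParams, X_batch: Array, y_batch: Array, lr: float) -> MLPParams:
    """One SGD update computed on a mini-batch only."""
    grads = jax.grad(cross_entropy_loss)(params, X_batch, y_batch)
    return jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)

# In each epoch, sample a random mini-batch of indices from the training set.
key, subkey = jax.random.split(key)
idx = jax.random.choice(subkey, X_train.shape[0], shape=(32,), replace=False)
params = train_step(params, X_train[idx], y_train[idx], lr=0.1)
```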

Let’s look at our loss curves now. Here, I plotted both training and validation loss. The training loss doesn’t matter much if our model cannot generalize, so we should always evaluate on the holdout validation set.
![Picture](/assets/mlp_nlp/losses.png)
This is encouraging because we see a significant improvement in performance compared to our bigram model. Additionally, we observe that the training and validation loss are roughly of the same order, indicating that we don’t overfit. The bumpy behavior in the training loss is easily explained by the use of mini-batch updates, which introduce some noise since we don’t compute the full gradient.

Now, let’s move on to the sampling part. The sampling process is largely the same as before, but we need to input three encoded characters into our neural network because that’s how we trained it.
```python
def sample(params: MLPParams, key: Array, vocab: List[str]) -> str:
"""
Expand All @@ -104,7 +112,7 @@ def sample(params: MLPParams, key: Array, vocab: List[str]) -> str:
break
return "".join(sampled_word)[len("<eos><eos><eos>") : -len("<eos>")]
```
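
To generate a handful of words, we can call the function repeatedly with fresh PRNG keys, for example like this (the seed and the number of samples are arbitrary):

```python
key = jax.random.PRNGKey(0)
for _ in range(10):
    key, subkey = jax.random.split(key)
    print(sample(params, subkey, vocab))
```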
We can then examine the kinds of words the neural network is able to produce after training:
```
karmelmanie
zaaqa
artha
jaamini
tina
```
These results are much better than those from the bigram model. The generated words sound more name-like. While some may not be real names, examples like tine, carlyn, and caurelle certainly sound plausible.
That's it for this blog post. As before, it is heavily influenced by the lecture by [Andrej Karpathy](https://www.youtube.com/watch?v=TCH_1BHY58I&list=PLAqhIrjkxbuWI23v9cThsA9GvCAUhRvKZ&index=3).
The experiments were performed on a TPU provided by the [TRC research program](https://sites.research.google/trc/about/).
If you have questions or suggestions, please let me know. If you are interested in the code, you can find it on [my github](https://github.com/simveit/mlp_language_modelling_jax/tree/main).