
Commit

Adjusted name of image folder
simveit committed Mar 29, 2024
1 parent a96f6d3 commit 66750cf
Showing 7 changed files with 5 additions and 31 deletions.
28 changes: 1 addition & 27 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,27 +1 @@
# Minimal Mistakes remote theme starter

Click [**Use this template**](https://github.com/mmistakes/mm-github-pages-starter/generate) button above for the quickest method of getting started with the [Minimal Mistakes Jekyll theme](https://github.com/mmistakes/minimal-mistakes).

Contains basic configuration to get you a site with:

- Sample posts.
- Sample top navigation.
- Sample author sidebar with social links.
- Sample footer links.
- Paginated home page.
- Archive pages for posts grouped by year, category, and tag.
- Sample about page.
- Sample 404 page.
- Site wide search.

Replace sample content with your own and [configure as necessary](https://mmistakes.github.io/minimal-mistakes/docs/configuration/).

---

## Troubleshooting

If you have a question about using Jekyll, start a discussion on the [Jekyll Forum](https://talk.jekyllrb.com/) or [StackOverflow](https://stackoverflow.com/questions/tagged/jekyll). Other resources:

- [Ruby 101](https://jekyllrb.com/docs/ruby-101/)
- [Setting up a Jekyll site with GitHub Pages](https://jekyllrb.com/docs/github-pages/)
- [Configuring GitHub Metadata](https://github.com/jekyll/github-metadata/blob/master/docs/configuration.md#configuration) to work properly when developing locally and avoid `No GitHub API authentication could be found. Some fields may be missing or have incorrect data.` warnings.
# Simon's blog
8 changes: 4 additions & 4 deletions _posts/2024-29-03-multi-chip-performance.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,14 @@ In this blog post we will explain how to efficiently use Google's TPU. TPUs are

In this tutorial we will take a simple layerwise matrix multiplication of activations with weights as our running example. The workload may be visualized like this:

![Layer-wise Matrix Multiplication](/assets/images/LayerwiseMatMul.png)
![Layer-wise Matrix Multiplication](/assets/multi_chip_processing/LayerwiseMatMul.png)

In the above diagram the activations have shape `B*E x E` and the weights have shape `E x E`.
The question is now how to distribute this workload efficiently across the different TPU chips.

For the activations the distribution is straightforward: put each batch onto one chip and run the calculation for each batch independently, i.e. multiply each batch with the weight matrix.
This can be visualized as follows:
![Layer-wise Matrix Multiplication](/assets/images/LayerwiseMatMulSharded.png)
![Layer-wise Matrix Multiplication](/assets/multi_chip_processing/LayerwiseMatMulSharded.png)
The different colors indicate that the activations are distributed batch-wise over the different chips while the weights are replicated on every chip.
In JAX we can accomplish distribution onto different chips as follows:
```
Expand Down Expand Up @@ -94,7 +94,7 @@ We see that we want to use the partition `p=PartitionSpec('axis', None)` for the
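The batch-wise placement with `PartitionSpec('axis', None)` for the activations and fully replicated weights can be sketched as follows. This is a minimal sketch with illustrative shapes and variable names, not the post's actual code:

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# Illustrative sizes (assumption): one batch per device, feature dim E.
B = len(jax.devices())
E = 8
mesh = Mesh(np.array(jax.devices()), axis_names=("axis",))

acts = jnp.ones((B * E, E))
weights = jnp.ones((E, E))

# Shard the activations batch-wise over the mesh axis;
# an empty PartitionSpec replicates the weights on every chip.
acts = jax.device_put(acts, NamedSharding(mesh, PartitionSpec("axis", None)))
weights = jax.device_put(weights, NamedSharding(mesh, PartitionSpec()))

out = jnp.matmul(acts, weights)  # each device multiplies its own slice
print(out.shape)
```

With this placement the matmul needs no communication at all: every chip already holds its activation slice and a full copy of the weights.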
So far so good, but this still doesn't leverage the full power of having multiple chips. What if the weight matrices are very large, so large that we can't replicate all of them on every chip?

It turns out we can do the following:
![Layer-wise Matrix Multiplication](/assets/images/LayerwiseMatMulFullShard.png)
![Layer-wise Matrix Multiplication](/assets/multi_chip_processing/LayerwiseMatMulFullShard.png)
What we see is that initially the weights are also distributed over all available chips.
But for the calculation the weights of the current layer need to be present on every chip. How can this be achieved?
It turns out the algorithm is quite simple:
Expand Down Expand Up @@ -142,7 +142,7 @@ print(f"Average time for forward pass: {average_time:.2f} ms")
```
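The per-layer all-gather pattern described above can be sketched in a few lines. This is an illustrative sketch under assumed shapes, not the post's exact benchmark code; in practice XLA decides where to insert the gathers:

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec

n = len(jax.devices())
E, num_layers = 8, 4  # illustrative sizes (assumption)
mesh = Mesh(np.array(jax.devices()), axis_names=("axis",))

# At rest, every layer's weights are sharded across all chips,
# so no single chip holds a full weight matrix.
shard = NamedSharding(mesh, PartitionSpec("axis", None))
weights = [jax.device_put(jnp.eye(E), shard) for _ in range(num_layers)]
acts = jax.device_put(jnp.ones((n * E, E)), shard)

@jax.jit
def forward(x, ws):
    for w in ws:
        # The matmul contracts over the sharded dimension of w, so the
        # compiler materializes (all-gathers) each layer's weights just
        # before they are needed, then frees the gathered copy.
        x = x @ w
    return x

out = forward(acts, weights)
print(out.shape)
```

Because each gathered weight matrix is discarded after its layer's matmul, peak memory per chip stays close to the sharded footprint plus one full layer.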
For the above setting we achieved an average time of $39.82 ms$ on Google's `TPU-v4-8`.
Let's look at the trace viewer to get more insight into how JAX compiled the matmul function:
![Profiler](/assets/images/fdsp.png)
![Profiler](/assets/multi_chip_processing/fdsp.png)
We see that JAX does exactly what we described above! Only the first all-gather takes a "long" time. Afterwards the gathering is fused with the matrix multiplication, which gives a huge speedup compared to the naive approach of running a full all-gather before every matrix multiplication, and at the same time lets us save a lot of memory by keeping most of the weights sharded across all chips.
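Timing numbers like the one quoted above are typically obtained with a warm-up call (to exclude compilation) followed by a synchronized loop. A minimal sketch with illustrative sizes and function names, not the post's exact benchmark:

```python
import time
import jax
import jax.numpy as jnp

def benchmark(fn, *args, iters=10):
    jax.block_until_ready(fn(*args))  # warm-up: trigger JIT compilation
    start = time.perf_counter()
    for _ in range(iters):
        # block_until_ready forces the async dispatch to finish,
        # so we measure actual device time, not queueing time.
        jax.block_until_ready(fn(*args))
    return (time.perf_counter() - start) / iters * 1000.0  # milliseconds

x = jnp.ones((256, 256))
matmul = jax.jit(lambda a: a @ a)
average_time = benchmark(matmul, x)
print(f"Average time for forward pass: {average_time:.2f} ms")
```

Without the `block_until_ready` calls, JAX's asynchronous dispatch would return immediately and the loop would mostly measure Python overhead.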

I hope this post was insightful and you liked it.
Expand Down
File renamed without changes
File renamed without changes
File renamed without changes
File renamed without changes
File renamed without changes
