diff --git a/README.md b/README.md
index c58cc88..88ca58a 100644
--- a/README.md
+++ b/README.md
@@ -1,27 +1 @@
-# Minimal Mistakes remote theme starter
-
-Click [**Use this template**](https://github.com/mmistakes/mm-github-pages-starter/generate) button above for the quickest method of getting started with the [Minimal Mistakes Jekyll theme](https://github.com/mmistakes/minimal-mistakes).
-
-Contains basic configuration to get you a site with:
-
-- Sample posts.
-- Sample top navigation.
-- Sample author sidebar with social links.
-- Sample footer links.
-- Paginated home page.
-- Archive pages for posts grouped by year, category, and tag.
-- Sample about page.
-- Sample 404 page.
-- Site wide search.
-
-Replace sample content with your own and [configure as necessary](https://mmistakes.github.io/minimal-mistakes/docs/configuration/).
-
----
-
-## Troubleshooting
-
-If you have a question about using Jekyll, start a discussion on the [Jekyll Forum](https://talk.jekyllrb.com/) or [StackOverflow](https://stackoverflow.com/questions/tagged/jekyll). Other resources:
-
-- [Ruby 101](https://jekyllrb.com/docs/ruby-101/)
-- [Setting up a Jekyll site with GitHub Pages](https://jekyllrb.com/docs/github-pages/)
-- [Configuring GitHub Metadata](https://github.com/jekyll/github-metadata/blob/master/docs/configuration.md#configuration) to work properly when developing locally and avoid `No GitHub API authentication could be found. Some fields may be missing or have incorrect data.` warnings.
+# Simon's blog
diff --git a/_posts/2024-29-03-multi-chip-performance.md b/_posts/2024-29-03-multi-chip-performance.md
index fcaa419..83e62ce 100644
--- a/_posts/2024-29-03-multi-chip-performance.md
+++ b/_posts/2024-29-03-multi-chip-performance.md
@@ -12,14 +12,14 @@ In this blog post we will explain how to efficiently use Google's TPU. TPUs are
 
 In this tutorial we will take a simple layerwise matrix multiplication of activations with weights as our running example. The workload may be visualized like this:
 
-![Layer-wise Matrix Multiplication](/assets/images/LayerwiseMatMul.png)
+![Layer-wise Matrix Multiplication](/assets/multi_chip_processing/LayerwiseMatMul.png)
 
 In the above diagram the activations have shape `B*E x E` and the weights have shape `E x E`. The question is now how we can distribute this workload efficiently across the different TPU chips.
 
 For the activations it is pretty obvious how to distribute them onto different chips: just put each batch onto one chip and run the calculation for each batch independently, that is, multiply each batch with the weight matrix. This can be visualized as follows:
 
-![Layer-wise Matrix Multiplication](/assets/images/LayerwiseMatMulSharded.png)
+![Layer-wise Matrix Multiplication](/assets/multi_chip_processing/LayerwiseMatMulSharded.png)
 
 The different colors visualize the fact that the activations are distributed batchwise over the different chips while the weights are copied onto all chips. In JAX we can accomplish this distribution as follows:
 ```
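The hunk above ends right where the post's JAX snippet begins, so the snippet itself falls outside the diff context. For orientation, a minimal sketch of batchwise-sharded activations with replicated weights is given below; all names and sizes (`mesh`, `B`, `E`, ...) are illustrative, not the post's actual code:

```
# Sketch only: shard the activations batchwise across all chips
# and replicate the weights onto every chip.
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec

mesh = Mesh(np.array(jax.devices()), ('axis',))  # 1-D mesh over all chips

B, E = 8, 512                                    # illustrative sizes
activations = jnp.ones((B * E, E))
weights = jnp.ones((E, E))

# Split the leading (batch) dimension over 'axis'; keep columns whole.
activations = jax.device_put(
    activations, NamedSharding(mesh, PartitionSpec('axis', None)))
# An empty PartitionSpec replicates the weights onto every chip.
weights = jax.device_put(weights, NamedSharding(mesh, PartitionSpec()))

# Each chip multiplies its activation slice with its local weight copy.
out = activations @ weights
```

With this layout `out` inherits the batchwise sharding of the activations, matching the picture above.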
@@ -94,7 +94,7 @@ We see that we want to use the partition `p=PartitionSpec('axis', None)` for the
 
 So far so good, but this still doesn't leverage the full power of having multiple chips. What if the weight matrices are very large? So large that we can't fit all of them on each chip?
 It turns out we can do the following:
 
-![Layer-wise Matrix Multiplication](/assets/images/LayerwiseMatMulFullShard.png)
+![Layer-wise Matrix Multiplication](/assets/multi_chip_processing/LayerwiseMatMulFullShard.png)
 
 What we see is that initially the weights are also distributed over all available chips. But for the calculation, the weights of the current layer need to be present on every chip. How can this be achieved? It turns out the algorithm is quite simple:
@@ -142,7 +142,7 @@ print(f"Average time for forward pass: {average_time:.2f} ms")
 ```
 For the above setting we achieved an average time of 39.82 ms on Google's `TPU v4-8`. Let's look at the trace viewer to get more insight into how JAX compiled the matmul function:
 
-![Profiler](/assets/images/fdsp.png)
+![Profiler](/assets/multi_chip_processing/fdsp.png)
 
 We see that JAX does exactly what we described above! Only the first all-gather takes a "long" time. Afterwards, the gathering gets fused with the matrix multiplication, which gives a huge speedup compared to the naive approach of running a separate all-gather before each matrix multiplication; at the same time we save a lot of memory by keeping most of the weights sharded across all chips. I hope this post was insightful and you liked it.
diff --git a/assets/images/LayerwiseMatMul.png b/assets/multi_chip_processing/LayerwiseMatMul.png
similarity index 100%
rename from assets/images/LayerwiseMatMul.png
rename to assets/multi_chip_processing/LayerwiseMatMul.png
diff --git a/assets/images/LayerwiseMatMulFullShard.png b/assets/multi_chip_processing/LayerwiseMatMulFullShard.png
similarity index 100%
rename from assets/images/LayerwiseMatMulFullShard.png
rename to assets/multi_chip_processing/LayerwiseMatMulFullShard.png
diff --git a/assets/images/LayerwiseMatMulSharded.png b/assets/multi_chip_processing/LayerwiseMatMulSharded.png
similarity index 100%
rename from assets/images/LayerwiseMatMulSharded.png
rename to assets/multi_chip_processing/LayerwiseMatMulSharded.png
diff --git a/assets/images/bio-photo.jpg b/assets/multi_chip_processing/bio-photo.jpg
similarity index 100%
rename from assets/images/bio-photo.jpg
rename to assets/multi_chip_processing/bio-photo.jpg
diff --git a/assets/images/fdsp.png b/assets/multi_chip_processing/fdsp.png
similarity index 100%
rename from assets/images/fdsp.png
rename to assets/multi_chip_processing/fdsp.png
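A closing note on the algorithm the post describes above (its code is likewise elided by the hunk context): the idea of keeping every layer's weights sharded and gathering just the current layer's weights before its matmul can be sketched roughly as follows in JAX. All names and sizes here are hypothetical, not the post's code:

```
# Sketch only (hypothetical names): keep every layer's weights sharded
# across the chips and gather the current layer's weights just in time.
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec

mesh = Mesh(np.array(jax.devices()), ('axis',))
E, num_layers = 512, 8

split_rows = NamedSharding(mesh, PartitionSpec('axis', None))
replicated = NamedSharding(mesh, PartitionSpec())

activations = jax.device_put(jnp.ones((8 * E, E)), split_rows)
# Each weight matrix is itself sharded row-wise over the chips.
weights = [jax.device_put(jnp.ones((E, E)), split_rows)
           for _ in range(num_layers)]

@jax.jit
def forward(x, weights):
    for w in weights:
        # Constraining w to a replicated layout makes XLA insert the
        # all-gather, which it can overlap with the preceding matmul.
        x = x @ jax.lax.with_sharding_constraint(w, replicated)
    return x

out = forward(activations, weights)
```

Leaving the gather to the compiler instead of calling an explicit collective is what can let XLA overlap it with the previous layer's matrix multiplication, in line with the trace the post describes.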