
Commit

Small change in wording
simveit committed Mar 29, 2024
1 parent fd5301b commit dd42a71
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion _posts/2024-03-29-multi-chip-performance.md
@@ -140,7 +140,7 @@ average_time = simple_timeit(matmul, ACTIVATION, WEIGHTS, task="matmul")
print(f"Average time for forward pass: {average_time:.2f} ms")
```
For the above setting we achieved an average time of $39.82 ms$ on Google's `TPU-v4-8`.
For the above setting we achieved an average time of $39.82 ms$ on Google's `TPU-v4-8` (that is a TPU with 8/2=4 chips).
Let's look at the trace viewer to get more insight into how JAX compiled the matmul function:
![Profiler](/assets/multi_chip_processing/fdsp.png)
We see that JAX does exactly what we described above! Only the first all-gather runs for a "long" time. Afterwards, the gathering gets fused with the matrix multiplication, which gives a huge speedup compared to the naive approach of performing an all-gather before every matrix multiplication, and at the same time it lets us save lots of memory by sharding most of the weights over all chips.
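To make this concrete, here is a minimal sketch of the sharding setup (not the post's original code; the sizes and names `BATCH`, `EMBED` and `N_LAYERS` are illustrative placeholders): the weights are sharded over all chips with a `NamedSharding` while the activations are kept replicated, so the compiler must all-gather each weight matrix and can overlap that gather with the previous layer's compute.
```
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Illustrative sizes, not the ones used in the post.
BATCH, EMBED, N_LAYERS = 1024, 8192, 4

mesh = Mesh(jax.devices(), axis_names=("devices",))
replicated = NamedSharding(mesh, P())              # activations live on every chip
sharded = NamedSharding(mesh, P("devices", None))  # weights split over the chips

key = jax.random.PRNGKey(0)
activation = jax.device_put(jax.random.normal(key, (BATCH, EMBED)), replicated)
weights = [
    jax.device_put(
        jax.random.normal(jax.random.fold_in(key, i), (EMBED, EMBED)), sharded
    )
    for i in range(N_LAYERS)
]

@jax.jit
def matmul(activation, weights):
    for w in weights:
        # Weights are sharded, activations are constrained to stay replicated,
        # so each weight matrix has to be all-gathered; the compiler can
        # overlap that gather with the previous layer's multiplication.
        activation = jax.lax.with_sharding_constraint(activation @ w, replicated)
    return activation

out = matmul(activation, weights)
print(out.shape)  # (BATCH, EMBED)
```
Running a sketch like this under `jax.profiler.trace` on a multi-chip TPU should show the same overlapped all-gather pattern in the trace viewer as above.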
