diff --git a/_posts/2024-03-29-multi-chip-performance.md b/_posts/2024-03-29-multi-chip-performance.md
index 8cfd52d..6dace31 100644
--- a/_posts/2024-03-29-multi-chip-performance.md
+++ b/_posts/2024-03-29-multi-chip-performance.md
@@ -99,7 +99,7 @@ What we see is that initially we distribute the weights also over all available
 
-But for the calculation we need the weight for the current layer to be on all chips. How can this be archieved? It turns out the algorithm is quiete simple:
+But for the calculation we need the weight for the current layer to be on all chips. How can this be achieved? It turns out the algorithm is quite simple:
 Let `L_i, A_i, W_i` be i-th layer, activation and weight.
-While calculating `L_{i+1}`, i.e. multiplying `A_i` with `W_i` we have `W_i` ungathered (i.e. distributed over all chips). At the same time we ungather `W_{i+1}`. If this process is faster than the matrix multiplication we only need to keep 2 weights unsharded instead of `N_layer` weights!
+While calculating `L_{i+1}`, i.e. multiplying `A_i` with `W_i`, we have `W_i` ungathered (i.e. unsharded, so the full weight is available on every chip). At the same time we ungather `W_{i+1}`. If this process is faster than the matrix multiplication, we only need to keep 2 weights unsharded at a time instead of all `N_layer` weights, without decreasing performance!
 Let's see how we can implement that in JAX:
 ```
 import jax