Small addition
simveit committed Mar 29, 2024
1 parent b39d597 commit c08c17d
Showing 1 changed file with 1 addition and 1 deletion.
_posts/2024-03-29-multi-chip-performance.md (2 changes: 1 addition & 1 deletion)
@@ -99,7 +99,7 @@ What we see is that initially we distribute the weights also over all available
But for the calculation we need the weight for the current layer to be on all chips. How can this be achieved?
It turns out the algorithm is quite simple:
Let `L_i`, `A_i`, `W_i` be the i-th layer, activation and weight.
- While calculating `L_{i+1}`, i.e. multiplying `A_i` with `W_i`, we have `W_i` unsharded (i.e. present in full on all chips). At the same time we unshard `W_{i+1}`. If this process is faster than the matrix multiplication, we only need to keep 2 weights unsharded instead of `N_layer` weights!
+ While calculating `L_{i+1}`, i.e. multiplying `A_i` with `W_i`, we have `W_i` unsharded (i.e. present in full on all chips). At the same time we unshard `W_{i+1}`. If this process is faster than the matrix multiplication, we only need to keep 2 weights unsharded instead of `N_layer` weights, without decreasing performance!
Let's see how we can implement that in JAX:
```
import jax
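# ------------------------------------------------------------------
# NOTE: the rest of the file is collapsed in this diff ("Expand Down"),
# so what follows is a hedged sketch of the idea described above, not
# the post's actual code. All names and sizes (n_layers, d_model,
# batch, the "chips" mesh axis) are illustrative assumptions.
# ------------------------------------------------------------------
import numpy as np
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

n_devices = len(jax.devices())
mesh = Mesh(np.array(jax.devices()), axis_names=("chips",))

# Illustrative sizes; d_model and batch should be divisible by the chip count.
n_layers, d_model, batch = 4, 512, 8 * n_devices

# Weights are stored sharded over all chips; activations are sharded over the batch.
w_sharded = NamedSharding(mesh, P("chips", None))
a_sharded = NamedSharding(mesh, P("chips", None))
replicated = NamedSharding(mesh, P())

key = jax.random.PRNGKey(0)
weights = [
    jax.device_put(jax.random.normal(k, (d_model, d_model)), w_sharded)
    for k in jax.random.split(key, n_layers)
]
x = jax.device_put(jnp.ones((batch, d_model)), a_sharded)

@jax.jit
def forward(x, weights):
    for w in weights:
        # The matmul needs the full W_i on every chip, so the compiler
        # inserts an all-gather here. With asynchronous collectives it can
        # already start gathering W_{i+1} while the matmul with W_i runs,
        # so only about 2 weights are unsharded at any time.
        w_full = jax.lax.with_sharding_constraint(w, replicated)
        x = jnp.dot(x, w_full)
    return x

print(forward(x, weights).shape)  # (batch, d_model)
```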
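With purely illustrative numbers (not from the post): for `N_layer = 32` layers of shape `d x d`, keeping every layer's weight unsharded would cost `32 * d^2` parameters per chip, while the overlapped scheme only ever holds about `2 * d^2` unsharded parameters on top of the sharded copies.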
