How to handle dropout rng_key in model #6
Hi,
Thanks for this nice work. Is there an efficient way to use a different dropout RNG key for each device in data parallelism when a model has a dropout layer?
Thanks!
Comments
I think in most cases you won't need to manage different dropout keys manually. Instead, you can have a single dropout RNG key, replicated across devices, and pass it to the layer just as you would on a single device. Since all arrays in JAX are global, this ensures that devices holding different parts of a layer sample different dropout masks.
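For example, here is a minimal sketch in plain JAX (not this repo's API; the `dropout` function below is a hypothetical standalone version of a dropout layer) showing one replicated key used with a batch sharded for data parallelism:

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Hypothetical standalone dropout; the real layer in this repo may differ.
def dropout(key, x, rate=0.1):
    keep = jax.random.bernoulli(key, 1.0 - rate, x.shape)
    return jnp.where(keep, x / (1.0 - rate), 0.0)

mesh = Mesh(np.array(jax.devices()), ("data",))

# Global batch sharded along the batch axis (assumes the batch size is
# divisible by the number of devices).
x = jax.device_put(jnp.ones((8, 16)),
                   NamedSharding(mesh, P("data", None)))

# One RNG key, replicated on every device, used exactly as on one device.
key = jax.device_put(jax.random.PRNGKey(0), NamedSharding(mesh, P()))

y = jax.jit(dropout)(key, x)  # different masks per example, same single key
```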
In data parallelism, we often replicate a model on multiple devices and split a global batch into sub-batches for training. If the model has a dropout layer, wouldn't it be better to use a different dropout RNG key for each device?
In JAX, shardings specify how tensors are computed and stored across devices, not what the result of the computation should be. The nice thing about JAX is that the correctness of the result is therefore independent of the sharding used. In this case, imagine running a large batch on a single device: you only need one RNG key, and that key generates different masks for different examples in the batch. Once you specify a data-parallel sharding, the results are the same as on a single device, just stored and computed across multiple devices.
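To make that concrete, here is a hedged sketch (again plain JAX rather than this repo's code, with a hypothetical standalone `dropout`) that checks the data-parallel result against the single-device one:

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Hypothetical standalone dropout, kept identical in both runs.
def dropout(key, x, rate=0.1):
    keep = jax.random.bernoulli(key, 1.0 - rate, x.shape)
    return jnp.where(keep, x / (1.0 - rate), 0.0)

key = jax.random.PRNGKey(0)
x = jnp.arange(8 * 4, dtype=jnp.float32).reshape(8, 4)

# Reference: default placement, as if running on a single device.
ref = jax.jit(dropout)(key, x)

# Same computation under a data-parallel sharding of the batch axis
# (assumes the batch size divides evenly across devices).
mesh = Mesh(np.array(jax.devices()), ("data",))
x_dp = jax.device_put(x, NamedSharding(mesh, P("data", None)))
key_dp = jax.device_put(key, NamedSharding(mesh, P()))
out = jax.jit(dropout)(key_dp, x_dp)

# Same values; only the storage/computation layout differs.
np.testing.assert_array_equal(np.asarray(ref), np.asarray(out))
```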