
How to handle dropout rng_key in model #6

Open
YunxiTang opened this issue Nov 10, 2024 · 3 comments

Comments

@YunxiTang

YunxiTang commented Nov 10, 2024

Hi,
Thanks for this nice work. Is there an efficient way to use a different dropout RNG key for each device in data parallelism if a model has a dropout layer?
Thanks!

@young-geng
Owner

I think in most cases you won't need to manually use different dropout keys. Instead, you can keep a single dropout RNG key replicated and pass it to the layer, as if you were running on a single device. Since all arrays in JAX are global, this ensures that devices holding different parts of a layer sample different dropout masks.
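
A minimal sketch of the idea (the shapes, the mesh axis name, and the hand-rolled `dropout` helper are illustrative, not from this repo): the batch is sharded across devices while a single key is replicated, and each example in the global batch still gets its own mask.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# One data-parallel mesh axis over all local devices (illustrative; assumes the
# device count divides the batch size).
mesh = Mesh(np.array(jax.devices()), ("data",))

# Global batch, sharded along the batch dimension.
batch = jax.device_put(jnp.ones((8, 16)), NamedSharding(mesh, P("data", None)))

# A single dropout key, replicated on every device.
key = jax.device_put(jax.random.PRNGKey(0), NamedSharding(mesh, P()))

def dropout(key, x, rate=0.5):
    # One global mask drawn from one key: each example in the global batch gets
    # its own mask, no matter which device stores it.
    keep = jax.random.bernoulli(key, 1.0 - rate, x.shape)
    return jnp.where(keep, x / (1.0 - rate), 0.0)

out = jax.jit(dropout)(key, batch)
```

Because the mask is computed for the full global shape, the slices held on different devices are just different pieces of that one global mask.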

@YunxiTang
Author

In data parallelism, we often replicate a model on multiple devices and split a global batch into several sub-batches for training. If the model has a dropout layer, wouldn't it be better to use a different dropout RNG key for each device?

@young-geng
Owner

In JAX, shardings specify how tensors are computed and stored across devices, but not what the result of the computation should be. The nice thing about JAX is that the correctness of the result is therefore independent of the sharding used. In this case, imagine running the large batch on a single device: you only need one RNG key, and that key generates different masks for different examples in the batch. Once you specify a data-parallel sharding, the result is the same as on a single device, except that it is stored and computed on multiple devices.
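
As a rough sanity check of that claim (again just a sketch with made-up shapes and a hand-rolled `dropout`), the sharded and unsharded results should match, since the sharding only changes where the values live:

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

def dropout(key, x, rate=0.5):
    keep = jax.random.bernoulli(key, 1.0 - rate, x.shape)
    return jnp.where(keep, x / (1.0 - rate), 0.0)

key = jax.random.PRNGKey(0)
batch = jnp.ones((8, 16))

# Reference result with everything on a single device.
ref = jax.jit(dropout)(key, batch)

# Same computation, but with the batch sharded across a data-parallel mesh.
mesh = Mesh(np.array(jax.devices()), ("data",))
sharded_batch = jax.device_put(batch, NamedSharding(mesh, P("data", None)))
out = jax.jit(dropout)(key, sharded_batch)

# The values are identical; only the storage and computation layout differ.
assert jnp.allclose(ref, out)
```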
