
Discussion : What is the functionality of pygrain.ShardOption, like pygrain.ShardByJaxProcess? #154

Open
carlesoctav opened this issue Jan 22, 2025 · 0 comments

carlesoctav commented Jan 22, 2025

From what I understand, each process loads a different batch. At the end, we need to use jax.make_array_from_process_local_data to combine the per-process batches into a single global batch with global sharding (across the devices of all processes), especially if the model is also sharded across global_devices.
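
To make sure I'm describing the same pattern, here is a minimal sketch of what I have in mind (the in-memory data source, batch size, and 1-D "data" mesh are placeholders I made up):

```python
import jax
import numpy as np
import grain.python as pygrain
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Stand-in data source; any object with __len__/__getitem__ works.
data_source = np.arange(1024, dtype=np.int32).reshape(-1, 1)

sampler = pygrain.IndexSampler(
    num_records=len(data_source),
    shard_options=pygrain.ShardByJaxProcess(drop_remainder=True),
    shuffle=True,
    num_epochs=1,
    seed=0,
)
loader = pygrain.DataLoader(
    data_source=data_source,
    sampler=sampler,
    operations=[pygrain.Batch(batch_size=8, drop_remainder=True)],  # per-process batch
)

# One global mesh axis for data parallelism across all devices of all processes.
mesh = Mesh(np.array(jax.devices()), axis_names=("data",))
sharding = NamedSharding(mesh, P("data"))

for local_batch in loader:
    # Each process only holds its own slice; assemble the global, sharded batch.
    global_batch = jax.make_array_from_process_local_data(
        sharding,
        local_batch,
        global_shape=(local_batch.shape[0] * jax.process_count(),
                      *local_batch.shape[1:]),
    )
    break
```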

Correct me if I'm wrong, but with this approach we still need to load the entire dataset on each process. This is where I'm confused: is there another way to handle this?

On the other hand, since we load all the data on each process, we can shard the array across global_devices as long as the batch is consistent across processes.
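
If every process really does end up with the same full batch, I imagine jax.make_array_from_callback could be used to place just the per-device slices, roughly like this (shapes are made up; the global batch size must divide evenly across devices):

```python
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Assume every process built the same full batch (same seed, no ShardOption).
full_batch = np.arange(32 * 4, dtype=np.float32).reshape(32, 4)

mesh = Mesh(np.array(jax.devices()), axis_names=("data",))
sharding = NamedSharding(mesh, P("data"))

# The callback is only invoked for shards addressable by this process,
# so each device materializes only the slice it owns.
global_batch = jax.make_array_from_callback(
    full_batch.shape, sharding, lambda idx: full_batch[idx]
)
```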

Even when I want to load a different dataset on each process (for example, when I have multiple files like data_#number.json), I don't think we need to use pygrain.ShardOption. Instead, we can load the per-process batch and create a global batch using jax.make_array_from_process_local_data.
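
Concretely, for the multiple-files case I mean something like this (the file naming, JSON layout, and batch size are hypothetical):

```python
import json
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Hypothetical: one file per process, e.g. data_0.json, data_1.json, ...
with open(f"data_{jax.process_index()}.json") as f:
    records = json.load(f)  # assume a list of fixed-length feature vectors

local_batch = np.asarray(records[:8], dtype=np.float32)  # this process's batch

mesh = Mesh(np.array(jax.devices()), axis_names=("data",))
sharding = NamedSharding(mesh, P("data"))
global_batch = jax.make_array_from_process_local_data(
    sharding,
    local_batch,
    global_shape=(local_batch.shape[0] * jax.process_count(),
                  *local_batch.shape[1:]),
)
```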

I feel like pygrain.ShardOption is inspired by torch.DistributedSampler, but in the PyTorch case it makes sense because there is no well-known API for sharding an array across processes.

Is there any scenario where using pygrain.ShardOption in a DataLoader is the only way?
