Discussion: What is the functionality of `pygrain.ShardOption`, e.g. `pygrain.ShardByJaxProcess`?
From what I understand, each process loads a different batch. After loading, we need to use `jax.make_array_from_process_local_data` to combine the per-process batches into a single global batch with global sharding (across multiple devices in multiple processes), especially if the model is also sharded across the global devices.
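To make the pattern concrete, here is a rough sketch of what I mean (the data source, batch size, and mesh layout are made-up placeholders, and I am assuming the `grain.python` API for `ShardByJaxProcess` and `IndexSampler`):

```python
import grain.python as grain
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Hypothetical data source with 1024 records.
data_source = grain.RangeDataSource(start=0, stop=1024, step=1)

# Each JAX process draws a disjoint shard of the record indices.
sampler = grain.IndexSampler(
    num_records=len(data_source),
    shard_options=grain.ShardByJaxProcess(drop_remainder=True),
    shuffle=True,
    seed=0,
)
loader = grain.DataLoader(
    data_source=data_source,
    sampler=sampler,
    operations=[grain.Batch(batch_size=8, drop_remainder=True)],
)

# Shard the global batch along its leading axis over all global devices.
mesh = Mesh(np.asarray(jax.devices()), axis_names=("data",))
sharding = NamedSharding(mesh, P("data"))

for per_process_batch in loader:
    # per_process_batch is local to this process; combine the
    # per-process pieces into one globally sharded jax.Array.
    global_batch = jax.make_array_from_process_local_data(
        sharding, np.asarray(per_process_batch)
    )
    break
```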
Correct me if I'm wrong, but with this approach we still need to load the entire dataset on each process. This is where I'm confused, because there seems to be another way to handle this.
On the other hand, since every process loads all of the data anyway, we could shard the array across the global devices directly, as long as the batch is consistent across processes.
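Something like this sketch, where I assume every process holds an identical copy of the global batch (the shapes are placeholders):

```python
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Assume every process loaded the same full global batch (no ShardOption).
global_batch = np.arange(64 * 128, dtype=np.float32).reshape(64, 128)

mesh = Mesh(np.asarray(jax.devices()), axis_names=("data",))
sharding = NamedSharding(mesh, P("data"))

# Each device keeps only its slice of the (identical) host copy.
sharded = jax.make_array_from_callback(
    global_batch.shape, sharding, lambda idx: global_batch[idx]
)
```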
Even when I want to load a different dataset on each process (for example, when I have multiple files like `data_#number.json`), I don't think we need `pygrain.ShardOption`. Instead, we can load the per-process batch and create a global batch using `jax.make_array_from_process_local_data`, as in the sketch below.
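Roughly what I have in mind (the file naming and shapes are hypothetical):

```python
import json
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Each process reads its own file, e.g. data_0.json, data_1.json, ...
path = f"data_{jax.process_index()}.json"
with open(path) as f:
    local_records = np.asarray(json.load(f), dtype=np.float32)

mesh = Mesh(np.asarray(jax.devices()), axis_names=("data",))
sharding = NamedSharding(mesh, P("data"))

# Per-process data becomes one globally sharded batch, no ShardOption needed.
global_batch = jax.make_array_from_process_local_data(sharding, local_records)
```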
I feel like `pygrain.ShardOption` is inspired by `torch.DistributedSampler`, but in the Torch case it makes sense, because there is no well-known API there for sharding an array across processes.
Is there any scenario where using `pygrain.ShardOption` in the `DataLoader` is the only way?