The sharding utilities in scalax are contained in the scalax.sharding
submodule. scalax.sharding provides a set of utilities to help users manage
the accelerator device mesh and shard models and computations. The main
object for managing the device mesh and sharding is
scalax.sharding.MeshShardingHelper and its sjit method, which can be composed
with various ShardingRule objects to specify the shardings of a computation.

MeshShardingHelper

MeshShardingHelper is the main object for managing the device mesh and
sharding. It manages the device mesh and provides a convenient interface to
JAX's JIT compiler.
- axis_dims: A list or tuple of integers representing the shape of the device mesh. Its length should match the number of dimensions of the device mesh, and each integer gives the number of devices along the corresponding axis. One of the integers can be -1, which means that the size of the mesh along that axis is inferred from the number of devices.
- axis_names: A list or tuple of strings representing the names of the axes of the device mesh. It should have the same length as axis_dims.
- mesh_axis_splitting: Defaults to False. On certain platforms such as TPU, devices are arranged in a physical mesh with fixed dimensions. This parameter controls whether splitting a physical mesh dimension into multiple logical mesh dimensions is allowed. For example, on a TPU pod with a 4x4x4 physical topology, it is impossible to construct an 8x8 logical mesh without splitting the physical mesh dimensions. Note, however, that splitting a physical mesh dimension may degrade communication bandwidth between devices.
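To make the -1 convention concrete, here is a small pure-Python sketch of how a single -1 entry in axis_dims could be resolved from the total device count. This is an illustration only, not scalax's actual implementation; the function name infer_axis_dims is hypothetical.

```python
def infer_axis_dims(axis_dims, num_devices):
    """Resolve a single -1 entry in axis_dims from the total device count.

    Illustrative sketch only; scalax's real logic lives inside
    MeshShardingHelper and may differ in details.
    """
    dims = list(axis_dims)
    if dims.count(-1) > 1:
        raise ValueError("At most one axis may be -1")
    # Product of the explicitly specified axis sizes.
    known = 1
    for d in dims:
        if d != -1:
            known *= d
    if -1 in dims:
        if num_devices % known != 0:
            raise ValueError("Device count not divisible by known axis sizes")
        dims[dims.index(-1)] = num_devices // known
    elif known != num_devices:
        raise ValueError("axis_dims does not match the number of devices")
    return tuple(dims)

# Example: a 2D mesh over 8 devices with the first axis size inferred.
print(infer_axis_dims((-1, 4), 8))  # (2, 4)
```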
sjit

sjit is a wrapper around jax.jit, with support for ShardingRule objects and
sharding annotations in addition to the usual PartitionSpecs.
Parameters

- fun: The function to be compiled.
- in_shardings: A tuple of ShardingRule objects, PartitionSpecs, or None representing the sharding annotations for the input arguments of the function. None means that the sharding is inferred by XLA.
- out_shardings: A tuple of ShardingRule objects, PartitionSpecs, or None representing the sharding annotations for the output of the function. None means that the sharding is inferred by XLA.
- static_argnums: A tuple of integers representing the argument indices that should be treated as static arguments. Note that static arguments are excluded from in_shardings and args_sharding_constraint.
- args_sharding_constraint: A tuple of ShardingRule objects, PartitionSpecs, or None representing the sharding constraints applied to the arguments after the function begins. This is useful if we want to feed data into the function in one sharding and then reshard it inside the function. None means that the sharding is inferred by XLA.
- annotation_shardings: A dictionary mapping annotation names to ShardingRule or PartitionSpec objects. This is used to specify the shardings for annotations created with scalax.sharding.with_sharding_annotation.
Returns
- A compiled function.
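The note that static arguments are excluded from in_shardings can be made concrete with a small pure-Python sketch of the bookkeeping involved. This is not scalax's code; pair_shardings_with_args is a hypothetical helper, and strings stand in for real arguments and sharding rules.

```python
def pair_shardings_with_args(args, in_shardings, static_argnums):
    """Pair each non-static argument with its sharding annotation.

    Illustrative sketch: static arguments are skipped, so the
    in_shardings tuple only covers the remaining arguments in order.
    """
    static = set(static_argnums)
    dynamic = [a for i, a in enumerate(args) if i not in static]
    if len(dynamic) != len(in_shardings):
        raise ValueError("in_shardings must match the non-static arguments")
    return list(zip(dynamic, in_shardings))

# Argument 0 is static, so the two shardings cover arguments 1 and 2.
pairs = pair_shardings_with_args(
    args=("config", "params", "batch"),
    in_shardings=("fsdp_rule", None),
    static_argnums=(0,),
)
print(pairs)  # [('params', 'fsdp_rule'), ('batch', None)]
```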
match_sharding_rule

match_sharding_rule is a method to apply a sharding rule to a pytree under
the current mesh.

Parameters

- sharding_rules: A ShardingRule or PartitionSpec object to be applied to the pytree. It can also be a pytree of ShardingRule or PartitionSpec objects.
- pytree: The pytree to be sharded.

Returns

- A pytree of concrete NamedSharding objects, with the same structure as the input pytree.
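The "single rule or pytree of rules" behavior can be sketched in pure Python, using nested dicts as pytrees and strings as stand-ins for NamedSharding objects. This is an illustrative reimplementation of the behavior described above, not scalax's code; the real method produces concrete NamedSharding objects under the current mesh.

```python
def match_sharding_rule(sharding_rules, pytree):
    """Apply a sharding rule (or a pytree of rules) to a pytree.

    Sketch only: dicts play the role of pytrees, and the string
    "NamedSharding(rule)" plays the role of a concrete sharding.
    """
    if isinstance(pytree, dict):
        if isinstance(sharding_rules, dict):
            # A pytree of rules: recurse pairwise through both trees.
            return {k: match_sharding_rule(sharding_rules[k], v)
                    for k, v in pytree.items()}
        # A single rule: broadcast it over the whole subtree.
        return {k: match_sharding_rule(sharding_rules, v)
                for k, v in pytree.items()}
    return f"NamedSharding({sharding_rules})"

tree = {"dense": {"kernel": 1, "bias": 2}}
print(match_sharding_rule("fsdp", tree))
# {'dense': {'kernel': 'NamedSharding(fsdp)', 'bias': 'NamedSharding(fsdp)'}}
```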
local_data_to_global_array

local_data_to_global_array is a method to convert host-local data to a global
jax.Array. This is useful for data loading: we assume that each host loads a
portion of the data, and the portions are concatenated along a given axis to
form a global batch array.

Parameters

- pytree: The pytree of local data.
- batch_axis: The axis along which the data should be concatenated. Defaults to 0.
- mesh_axis_subset: A list of mesh axis names representing the subset of mesh dimensions the data should be sharded against. Defaults to None, which means that the data is sharded against all mesh dimensions.

Returns

- A global jax.Array, or a pytree of global jax.Array objects, representing the global batch.
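The underlying data-loading pattern can be illustrated with a minimal pure-Python sketch, using plain lists in place of jax.Array. This only shows the concatenation idea described above; the real method additionally builds a sharded global jax.Array over the mesh, and local_to_global_batch is a hypothetical name.

```python
def local_to_global_batch(per_host_batches, batch_axis=0):
    """Concatenate per-host local batches into one global batch.

    Sketch of the pattern: each host loads a slice of the global
    batch, and the slices are concatenated in host order along the
    batch axis.
    """
    if batch_axis != 0:
        raise NotImplementedError("sketch only handles batch_axis=0")
    global_batch = []
    for host_batch in per_host_batches:  # hosts in process-index order
        global_batch.extend(host_batch)
    return global_batch

# Two hosts each load half of a global batch of 4 examples.
print(local_to_global_batch([[0, 1], [2, 3]]))  # [0, 1, 2, 3]
```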
with_sharding_constraint

Similar to jax.lax.with_sharding_constraint, but takes ShardingRule and
PartitionSpec objects. This function enforces a sharding constraint inside an
sjit'ed function.

Parameters

- pytree: The pytree to be sharded.
- sharding_rule: A ShardingRule or PartitionSpec object to be applied to the pytree. It can also be a pytree of ShardingRule or PartitionSpec objects.

Returns

- The sharded pytree, with the same structure as the input pytree.
with_sharding_annotation

This function provides a named sharding annotation for a pytree. The concrete
sharding for the annotated pytree can then be specified using the
annotation_shardings parameter of sjit.

Parameters

- pytree: The pytree to be annotated.
- sharding_name: The string name of the sharding annotation.

Returns

- The annotated pytree, with the same structure as the input pytree.
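The mechanism can be sketched in pure Python: with_sharding_annotation tags a pytree with a string name, and sjit's annotation_shardings dictionary later maps that name to an actual rule. This is an illustrative sketch of the indirection described above, not scalax's code; resolve_annotations is a hypothetical helper and strings stand in for real rules.

```python
def resolve_annotations(annotated, annotation_shardings):
    """Resolve a named sharding annotation to a concrete rule.

    Sketch: look up the annotation's string name in the
    annotation_shardings dict passed to sjit.
    """
    name = annotated["sharding_name"]
    if name not in annotation_shardings:
        raise KeyError(f"No sharding specified for annotation {name!r}")
    return annotation_shardings[name]

annotated = {"sharding_name": "activations", "pytree": None}
print(resolve_annotations(annotated, {"activations": "PartitionSpec('data')"}))
# PartitionSpec('data')
```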
ShardingRule

ShardingRule is the base class for sharding rules. It provides an apply
method to apply the sharding rule to a pytree.
FSDPShardingRule

FSDPShardingRule is a sharding rule for Fully Sharded Data Parallelism
(FSDP). It analyzes the shape of the tensors in the pytree and finds a
suitable axis for FSDP sharding.

- fsdp_axis_name: The name of the FSDP axis.
- fsdp_axis_size: The size of the FSDP axis. FSDPShardingRule will find a tensor axis divisible by fsdp_axis_size and use that as the FSDP axis. If this is None, FSDPShardingRule will find the axis with the largest power-of-2 divisor and use that as the FSDP axis.
- min_fsdp_size: The minimum size of tensors to be sharded. If the size of a tensor is smaller than min_fsdp_size, it will not be sharded.
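The axis-selection logic described above can be sketched in pure Python. This is an illustrative reimplementation of the described behavior, not scalax's actual code; choose_fsdp_axis is a hypothetical name, and tie-breaking details may differ from the library's.

```python
def choose_fsdp_axis(shape, fsdp_axis_size=None, min_fsdp_size=0):
    """Pick which tensor axis to shard for FSDP.

    Returns the chosen axis index, or None if the tensor should
    stay replicated.
    """
    size = 1
    for d in shape:
        size *= d
    if size < min_fsdp_size:
        return None  # too small to be worth sharding
    if fsdp_axis_size is not None:
        # Use the first axis divisible by the FSDP axis size.
        for axis, d in enumerate(shape):
            if d % fsdp_axis_size == 0:
                return axis
        return None

    # Otherwise pick the axis with the largest power-of-2 divisor.
    def pow2_divisor(n):
        p = 1
        while n % (p * 2) == 0:
            p *= 2
        return p

    return max(range(len(shape)), key=lambda axis: pow2_divisor(shape[axis]))

print(choose_fsdp_axis((3, 1024), fsdp_axis_size=8))  # 1
print(choose_fsdp_axis((768, 96)))                    # 0
```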
apply

Applies the FSDP sharding rule to the pytree.

Parameters

- pytree: The pytree to be sharded.

Returns

- A pytree of NamedSharding objects, with the same structure as the input pytree.
TreePathShardingRule

TreePathShardingRule is a sharding rule that shards a pytree using the tree
paths of its leaves. It is a flexible sharding rule that allows the user to
specify a combination of different shardings.

- *rules: Regex sharding rules to be applied to the leaves of the pytree. Each rule is a pair of a regex pattern and a PartitionSpec object. TreePathShardingRule iterates through the rules in order and applies the PartitionSpec of the first pattern that matches the leaf's tree path.
- strict: Defaults to True. If True, TreePathShardingRule raises an error if a leaf in the pytree does not match any of the patterns. If False, it simply provides a replicated PartitionSpec for the unmatched leaf.
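The first-match-wins behavior for a single leaf can be sketched in pure Python, with strings standing in for PartitionSpec objects. This illustrates the matching rule described above and is not scalax's code; in particular, whether the library uses re.search or full-pattern matching is an assumption here.

```python
import re

def match_tree_path(path, rules, strict=True):
    """Return the first spec whose regex pattern matches the tree path.

    Sketch of the described behavior: rules are tried in order, and
    strict controls what happens when nothing matches.
    """
    for pattern, spec in rules:
        if re.search(pattern, path):
            return spec
    if strict:
        raise ValueError(f"No sharding rule matches path {path!r}")
    return "replicated"  # stand-in for a replicated PartitionSpec

rules = [
    (r"embedding", "PartitionSpec('fsdp', None)"),
    (r"kernel$", "PartitionSpec(None, 'fsdp')"),
]
print(match_tree_path("transformer/dense/kernel", rules))
# PartitionSpec(None, 'fsdp')
```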
apply

Applies the sharding rule to the pytree.

Parameters

- pytree: The pytree to be sharded.

Returns

- A pytree of NamedSharding objects, with the same structure as the input pytree.
PolicyShardingRule

PolicyShardingRule is a sharding rule that shards a pytree using a
user-defined policy function.

- policy: The policy function used for sharding. It should take the tree path and the leaf tensor and return a PartitionSpec object.
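The policy-function idea can be illustrated with a minimal pure-Python sketch, where (path, shape) pairs stand in for a real pytree and strings stand in for PartitionSpec objects. This is not scalax's code; apply_policy and my_policy are hypothetical names.

```python
def apply_policy(policy, named_leaves):
    """Apply a user-defined policy function to every leaf.

    Sketch: the policy receives the tree path and the leaf (here,
    just its shape) and returns a PartitionSpec-like value.
    """
    return {path: policy(path, leaf) for path, leaf in named_leaves.items()}

# A hypothetical policy: shard 2D tensors on their last axis, replicate the rest.
def my_policy(path, shape):
    return "PartitionSpec(None, 'tp')" if len(shape) == 2 else "PartitionSpec()"

leaves = {"dense/kernel": (512, 512), "dense/bias": (512,)}
print(apply_policy(my_policy, leaves))
```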
apply

Applies the sharding rule to the pytree.

Parameters

- pytree: The pytree to be sharded.

Returns

- A pytree of NamedSharding objects, with the same structure as the input pytree.