Commit 2673462

Simplifies logical dispatch by direct global array construction. (#1040)
* Simplifies logical dispatch by direct global array construction.

* Address comments.

* Fixes non-compliant testing inputs.
markblee authored Mar 8, 2025
1 parent 771926e commit 2673462
Showing 15 changed files with 588 additions and 125 deletions.
1 change: 1 addition & 0 deletions axlearn/common/config.py
@@ -807,6 +807,7 @@ def __init__(self, cfg):

@property
def config(self: C) -> Config[C]:
# TODO(markblee): Consider supporting copy-on-write behavior.
return copy.deepcopy(self._config)

def __repr__(self):
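
A minimal sketch of the semantics the new TODO refers to: each `config` access returns an independent deep copy, so callers can mutate the result freely, but every access pays the full copy cost (hence the interest in copy-on-write). `layer` is a hypothetical, already-instantiated Module.

cfg = layer.config           # deep copy of the module's config.
cfg.set(name="scratch")      # mutation stays local to this copy.
assert layer.config.name != "scratch"
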
22 changes: 18 additions & 4 deletions axlearn/common/evaler.py
@@ -15,6 +15,7 @@
from absl import logging
from jax import numpy as jnp
from jax.experimental.pjit import pjit
from jax.sharding import PartitionSpec

from axlearn.common import input_base, struct, summary_writer, utils
from axlearn.common.base_model import BaseModel
@@ -188,11 +189,11 @@ def _pjit(self, fn: Callable) -> Callable:
in_shardings=(
self._model_param_partition_specs, # model_params.
None, # replicated_inputs (e.g., prng_key).
utils.input_partition_spec(), # per_example_inputs.
self._input_partition_spec(), # per_example_inputs.
),
out_shardings=dict(
replicated=None,
per_example=utils.input_partition_spec(),
per_example=self._input_partition_spec(),
),
)

@@ -240,6 +241,14 @@ def _call_model(
is_training=False,
)

def _input_partition_spec(self) -> PartitionSpec:
module = self.parent
while module is not None and not isinstance(module, SpmdEvaler):
module = module.parent
if module is not None and hasattr(module.input, "partition_spec"):
return module.input.partition_spec
return utils.input_partition_spec()

def _dispatch_global_batch(self, input_batch: NestedTensor) -> NestedTensor:
module = self.parent
while module is not None and not isinstance(module, SpmdEvaler):
@@ -592,7 +601,9 @@ def __init__(
if cfg.eval_dtype is not None:
utils.validate_float_dtype(cfg.eval_dtype)

self._add_child("input", maybe_set_config(cfg.input, is_training=False))
self.input: input_base.Input = self._add_child( # pytype: disable=annotation-type-mismatch
"input", maybe_set_config(cfg.input, is_training=False)
)
self._add_child(
"metric_calculator",
cfg.metric_calculator.set(eval_dtype=cfg.eval_dtype),
@@ -691,7 +702,10 @@ def eval_step(

with jax.profiler.StepTraceAnnotation(cfg.name, step_num=step):
with jax.profiler.TraceAnnotation(f"{cfg.name}.forward"):
global_input_batch = utils.host_to_global_device_array(input_batch)
global_input_batch = utils.host_to_global_device_array(
input_batch,
partition=self.input.partition_spec,
)
forward_outputs = self.metric_calculator.forward(
global_input_batch,
model_params=model_params,
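
A sketch (not part of the diff) of what changes in practice: where the evaler previously always used `utils.input_partition_spec()`, it now prefers the owning SpmdEvaler input's `partition_spec` when one is exposed, and that spec flows both into `host_to_global_device_array` and into the pjit shardings. The mesh and axis names below are hypothetical, and `utils.input_partition_spec()` is assumed to shard the leading (batch) axis across all mesh axes.

import jax
import numpy as np
from jax.sharding import Mesh, PartitionSpec

from axlearn.common import utils

# Hypothetical single-host mesh; axis names are illustrative.
devices = np.array(jax.devices()).reshape(-1, 1)
with Mesh(devices, ("data", "model")):
    default_spec = utils.input_partition_spec()  # assumed: batch dim over all mesh axes.
    configured_spec = PartitionSpec("data")      # e.g. set via Input.Config.partition_spec.
    # The evaler now threads whichever spec it resolves through both
    #   utils.host_to_global_device_array(per_feed_batch, partition=spec)
    # and the `in_shardings` / `out_shardings` of the pjit'ed metric computation.
    print(default_spec, configured_spec)
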
5 changes: 3 additions & 2 deletions axlearn/common/evaler_test.py
@@ -43,6 +43,7 @@
OutputRecordWriter,
TfExampleRecordSink,
)
from axlearn.common.input_base import Input
from axlearn.common.layers import Linear
from axlearn.common.metrics import WeightedScalar
from axlearn.common.module import REQUIRED, Module, OutputCollection, Required
@@ -63,11 +64,11 @@
]


class DummyInput(Module):
class DummyInput(Input):
"""A dummy input."""

@config_class
class Config(Module.Config):
class Config(Input.Config):
"""Configures DummyInput."""

is_training: Required[bool] = REQUIRED
4 changes: 3 additions & 1 deletion axlearn/common/host_array_test.py
@@ -62,7 +62,9 @@ def _is_supported(*, platform: str, mesh_shape: MeshShape) -> bool:
def _ordered_devices(mesh_shape: MeshShape, process_shape: MeshShape) -> np.ndarray:
"""Returns devices of shape `mesh_shape` with consistent host ordering.
process_shape indicates how the hosts should be laid out.
`process_shape` indicates how the hosts should be laid out. For example, if `mesh_shape` is
(4,4) and `process_shape` is (2,2), the top-left quadrant will be assigned device IDs from
process 0, the top-right quadrant from process 1, etc.
"""
assert len(mesh_shape) == len(process_shape), "ndim should match"

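
The expanded docstring's example can be checked with a short sketch (not the test helper's actual implementation): with `mesh_shape=(4, 4)` and `process_shape=(2, 2)`, each of the four processes owns one contiguous 2x2 quadrant of the device grid.

import numpy as np

mesh_shape, process_shape = (4, 4), (2, 2)
per_process = tuple(m // p for m, p in zip(mesh_shape, process_shape))  # (2, 2) devices per process.
process_index = np.arange(np.prod(process_shape)).reshape(process_shape)  # [[0, 1], [2, 3]]
owner = np.kron(process_index, np.ones(per_process, dtype=int))  # which process owns each device slot.
print(owner)
# [[0 0 1 1]
#  [0 0 1 1]
#  [2 2 3 3]
#  [2 2 3 3]]
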
83 changes: 63 additions & 20 deletions axlearn/common/input_base.py
@@ -2,22 +2,24 @@

"""Base Input interface."""

import math
import re
from typing import Iterable, Iterator, NamedTuple, Optional, Protocol, Sequence, Union
from typing import Iterable, Iterator, NamedTuple, Optional, Protocol, Union

import jax
from absl import logging
from jax._src.mesh import thread_resources
from jax.sharding import PartitionSpec

from axlearn.common.config import ConfigOr, config_class, maybe_instantiate
from axlearn.common.input_dispatch import InputDispatcher
from axlearn.common.config import ConfigOr, config_class, maybe_instantiate, maybe_set_config
from axlearn.common.input_dispatch import BaseInputDispatcher, InputDispatcher
from axlearn.common.module import Module
from axlearn.common.utils import (
Nested,
Tensor,
as_numpy_array,
dispatch_input_batch,
input_partition_spec,
tree_paths,
with_sharding_constraint,
)
@@ -48,7 +50,7 @@ class PathAndRank(NamedTuple):
def partition_by_path_rank(
path_rank_to_partition: dict[PathAndRank, PartitionSpec],
) -> InputPartitionFn:
"""Partitions the keys in the input batch by Tensor path and rank (ndim).
"""Partitions the paths in the input batch by regex and rank (ndim).
If not within a mesh, the partition fn is a no-op.
@@ -139,7 +141,9 @@ def train_step(global_physical_batch):
...
for per_feed_physical_batch in input.batches(input_iter):
global_physical_batch = host_to_global_device_array(per_feed_physical_batch)
global_physical_batch = host_to_global_device_array(
per_feed_physical_batch, partition=input.partition_spec
)
... = pjit(train_step)(global_physical_batch)
```
"""
@@ -149,21 +153,34 @@ class Config(Module.Config):
"""Configures Input.
Attributes:
partition_spec: If not None, configures the partition specs for the input batch used in
`host_to_global_device_array` and `jit`. Note that these specs may be different from
those constrained by `input_partitioner`, as they depend on the host-local shapes of
each input feed. For example, it is common to first form global batches from
uniformly batch-sharded host-local arrays by only configuring the batch axes of
`partition_spec`, and then further partition the batches within `jit` via
`input_partitioner`.
If None, defaults to `input_partition_spec()`.
input_dispatcher: If not None, creates an InputDispatcher and uses it for dispatching
per-feed batches to global batches.
input_partitioner: If not None, applies additional sharding constraints on each input
batch during `dispatch_global_batch`.
"""

partition_spec: Optional[PartitionSpec] = None
input_dispatcher: Optional[InputDispatcher.Config] = None
input_partitioner: Optional[ConfigOr[InputPartitionFn]] = None

def __init__(self, cfg: Config, *, parent: Optional[Module]):
super().__init__(cfg, parent=parent)
cfg = self.config
self._partition_spec = cfg.partition_spec or input_partition_spec()
if cfg.input_dispatcher is not None:
self.input_dispatcher: InputDispatcher = ( # pytype: disable=annotation-type-mismatch
self._add_child("input_dispatcher", cfg.input_dispatcher)
self.input_dispatcher: BaseInputDispatcher = (
self._add_child( # pytype: disable=annotation-type-mismatch
"input_dispatcher",
maybe_set_config(cfg.input_dispatcher, partition_spec=cfg.partition_spec),
)
)
self._input_partitioner: Optional[InputPartitionFn] = maybe_instantiate(
cfg.input_partitioner
@@ -224,12 +241,7 @@ def check_per_feed_batch(x: Tensor):
input_batch = self.input_dispatcher.logical_to_physical_batch(input_batch)
yield input_batch

def dispatch_global_batch(
self,
global_physical_batch: Nested[Tensor],
*,
batch_axis_names: Union[str, Sequence[str]] = "data",
) -> Nested[Tensor]:
def dispatch_global_batch(self, global_physical_batch: Nested[Tensor]) -> Nested[Tensor]:
"""Converts a global physical batch to a global logical batch.
The leaves of the output logical batch are partitioned across `batch_axis_names` along the
@@ -240,22 +252,39 @@ def dispatch_global_batch(
constraining `batch_axis_names`.
"""

def constrain_batch_axis(batch):
return jax.tree.map(
lambda x: with_sharding_constraint(x, PartitionSpec(batch_axis_names)),
batch,
def constrain_batch_axis(path: str, value: Tensor):
mesh = thread_resources.env.physical_mesh
batch_partitions = math.prod(
mesh.shape[axis] for axis in jax.tree.leaves(self._partition_spec[0])
)
# Warn if an invalid constraint is applied, since by default this can silently be
# ignored, potentially leading to unexpected OOMs.
if value.shape[0] % batch_partitions != 0:
logging.warning(
"Attempting to constrain path=%s (with batch dim %d) over %d partitions (%s).",
path,
value.shape[0],
batch_partitions,
self._partition_spec,
)
return with_sharding_constraint(value, self._partition_spec)

if "input_dispatcher" in self.children:
global_logical_batch = self.input_dispatcher.physical_to_logical_batch(
constrain_batch_axis(global_physical_batch)
jax.tree.map(
constrain_batch_axis,
tree_paths(global_physical_batch),
global_physical_batch,
)
)
else:
global_logical_batch = dispatch_input_batch(
global_physical_batch, batch_axis_names=batch_axis_names
global_physical_batch, batch_axis_names=self._partition_spec[0]
)

global_logical_batch = constrain_batch_axis(global_logical_batch)
global_logical_batch = jax.tree.map(
constrain_batch_axis, tree_paths(global_logical_batch), global_logical_batch
)

# Further constrain based on user-configured partitioning rules.
if self._input_partitioner is not None:
@@ -269,3 +298,17 @@ def element_spec(self) -> Nested[jax.ShapeDtypeStruct]:
This is used e.g. for AOT compilation and is not strictly required for training.
"""
raise NotImplementedError(type(self))

@property
def partition_spec(self) -> PartitionSpec:
"""Returns the input partition spec for `host_to_global_device_array` and for `jit`.
Depending on the dispatch implementation, it may be possible to directly form the global
logical batch from feed logical batches via `host_to_global_device_array`. In these cases,
we can use an input partition spec that follows `cfg.partition_spec`.
In all other cases we default to `input_partition_spec()`.
"""
if "input_dispatcher" in self.children:
return self.input_dispatcher.partition_spec
return input_partition_spec()
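
A sketch of how the new pieces compose, per the updated class and Config docstrings. Building the config from `Input.default_config()` follows the pattern used in the tests; the regex and the "seq" axis passed to `partition_by_path_rank` are illustrative assumptions, not values from this commit.

from jax.sharding import PartitionSpec

from axlearn.common.input_base import Input, PathAndRank, partition_by_path_rank

cfg: Input.Config = Input.default_config().set(
    name="train_input",
    # Batch axis only: lets host_to_global_device_array stitch uniformly
    # batch-sharded per-host arrays into a single global array.
    partition_spec=PartitionSpec("data"),
    # Further constrain inside jit; the regex and the "seq" axis are assumptions.
    input_partitioner=partition_by_path_rank(
        {PathAndRank("input_ids", 2): PartitionSpec("data", "seq")}
    ),
)
# Per the updated class docstring, the training loop then looks like:
#   global_physical_batch = host_to_global_device_array(
#       per_feed_physical_batch, partition=input.partition_spec
#   )
#   global_logical_batch = input.dispatch_global_batch(global_physical_batch)  # inside pjit.
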
5 changes: 4 additions & 1 deletion axlearn/common/input_base_test.py
@@ -141,7 +141,10 @@ def test_dispatch_global_batch(self):
"target_labels": jnp.ones((batch_size, seq_len), dtype=jnp.int32),
"target_num_bytes": jnp.ones((batch_size,), dtype=jnp.int32),
}
input_cfg = Input.default_config().set(name="test")
input_cfg = Input.default_config().set(
name="test",
partition_spec=PartitionSpec("data"),
)

with jax.sharding.Mesh(
np.array(jax.devices()).reshape(4, 2)[..., None],
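
The warning added in `constrain_batch_axis` boils down to a divisibility check; a worked sketch with assumed mesh axis sizes:

import math

mesh_axis_sizes = {"data": 4, "fsdp": 2}   # assumed mesh.
batch_axes = ("data", "fsdp")              # leaves of partition_spec[0].
batch_partitions = math.prod(mesh_axis_sizes[a] for a in batch_axes)  # 4 * 2 = 8
for batch_dim in (12, 16):
    if batch_dim % batch_partitions != 0:
        print(f"batch dim {batch_dim}: would log the new warning (not divisible by {batch_partitions})")
    else:
        print(f"batch dim {batch_dim}: constraint applies cleanly")
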
(Diffs for the remaining 9 changed files are not shown.)
