
Improvements and fixes to gradient accumulation
- Fix the with_minibatch_steps decorator to generate correct primal output shapes.
- Improve with_minibatch_steps to take a minibatch_partitioner that constrains the input batch to the same PartitionSpec as the input partitioner (a usage sketch follows the changed-file summary below).
apoorvtintin committed Feb 14, 2025
1 parent 31e8da0 commit 9b0f9a3
Showing 3 changed files with 70 additions and 21 deletions.
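As context for the diff below, here is a minimal sketch (not part of the commit) of how the new `minibatch_partitioner` argument can be passed to `with_minibatch_steps`. The partition spec mirrors the default added in this commit; `forward`, the step count, and the `accumulate` wrapper are illustrative assumptions:

```python
from jax.sharding import PartitionSpec

from axlearn.common.config import config_for_function
from axlearn.common.gradient_accumulation import with_minibatch_steps
from axlearn.common.input_base import partition_by_path_rank
from axlearn.common.metrics import MetricAccumulator

# Constrain rank-1 inputs to the batch axes and rank-2 inputs to
# (batch axes, "seq"), mirroring the default partitioner added in this commit.
minibatch_partitioner = config_for_function(partition_by_path_rank).set(
    path_rank_to_partition={
        (None, 1): PartitionSpec(("data", "expert", "fsdp")),
        (None, 2): PartitionSpec(("data", "expert", "fsdp"), "seq"),
    }
)

def accumulate(forward):
    """Decorates a ForwardFn (assumed to be defined elsewhere) with gradient accumulation."""
    return with_minibatch_steps(
        steps=4,
        metric_accumulator=MetricAccumulator.default_config(),
        minibatch_partitioner=minibatch_partitioner,
    )(forward)
```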
72 changes: 56 additions & 16 deletions axlearn/common/gradient_accumulation.py
@@ -8,12 +8,14 @@
import jax
import numpy as np
from jax import numpy as jnp
from jax.sharding import PartitionSpec

from axlearn.common import utils
from axlearn.common.config import ConfigOr, maybe_instantiate
from axlearn.common.config import ConfigOr, config_for_function, maybe_instantiate
from axlearn.common.input_base import InputPartitionFn, partition_by_path_rank
from axlearn.common.metrics import MetricAccumulator
from axlearn.common.update_transformation import ForwardFn, ForwardOutputs
from axlearn.common.utils import Nested, Tensor, input_partition_spec, with_sharding_constraint
from axlearn.common.utils import Nested, Tensor


def _compute_minibatch_size(input_batch: Nested[Tensor], *, steps: int) -> int:
@@ -57,39 +59,38 @@ def _make_scan_minibatch_inputs(
param_noise_key: Tensor,
minibatch_size: int,
minibatch_index: int,
minibatch_partitioner: Optional[InputPartitionFn],
) -> tuple[Nested[Tensor], Tensor, Tensor]:
"""Creates minibatch inputs from inputs.
This is a utility function that is only meant to be called from
within a scan function body and is meant to slice the inputs
into `minibatch_size` sized slices to run the ForwardFn on.
Note that this only preserves the input sharding if the `input_partition_spec`
returns the correct partition spec to shard the input slices with.
Args:
inputs: Same pytree as ForwardFn inputs.
forward_key: The `forward_key` from the ForwardFn inputs.
param_noise_key: The `param_noise_key` from the ForwardFn inputs.
minibatch_size: Size of the minibatch.
minibatch_index: Current scan minibatch index.
minibatch_partitioner: If not None, applies additional sharding constraints
on each minibatch created.
Returns:
A tuple of minibatch inputs, which have the same structure as `inputs`,
and the new (carry) forward_key and param_noise_key.
"""
minibatch_input = with_sharding_constraint(
    jax.tree.map(
        lambda x: jax.lax.dynamic_slice_in_dim(
            x,
            start_index=minibatch_index * minibatch_size,
            slice_size=minibatch_size,
            axis=0,
        ),
        inputs["input_batch"],
    ),
    input_partition_spec(),
)
minibatch_input = jax.tree.map(
    lambda x: jax.lax.dynamic_slice_in_dim(
        x,
        start_index=minibatch_index * minibatch_size,
        slice_size=minibatch_size,
        axis=0,
    ),
    inputs["input_batch"],
)

minibatch_input = minibatch_partitioner(minibatch_input)
next_forward_key, forward_key = jax.random.split(forward_key)
next_param_noise_key, param_noise_key = jax.random.split(param_noise_key)

@@ -106,6 +107,7 @@ def with_minibatch_steps(
steps: int,
metric_accumulator: ConfigOr[MetricAccumulator],
grad_dtype: Optional[jnp.dtype] = None,
minibatch_partitioner: Optional[ConfigOr[InputPartitionFn]] = None,
) -> Callable[[ForwardFn], ForwardFn]:
"""Decorate a ForwardFn to accumulate gradients over minibatch steps.
@@ -134,16 +136,37 @@ def with_minibatch_steps(
TODO(cemkoc): Investigate the slight difference in loss curves when decorated.
A `minibatch_partitioner` is used to partition the minibatch inputs passed to `original_func`.
Note that if `minibatch_partitioner` is None, a default partitioner is used, which partitions
each minibatch along (("data", "expert", "fsdp"), "seq"); otherwise the `minibatch_partitioner`
passed in is used.
Args:
steps: Number of gradient accumulation steps.
metric_accumulator: A `MetricAccumulator` to accumulate minibatch summaries from the
forward output.
grad_dtype: Optional dtype to cast the grads back to after accumulating in fp32.
minibatch_partitioner: If not None, contains config for a partitioner that applies
additional sharding constraints on each minibatch created.
Returns:
Decorated ForwardFn.
"""

# Default partitioner for minibatches.
if not minibatch_partitioner:
    minibatch_partitioner = config_for_function(partition_by_path_rank).set(
        path_rank_to_partition={
            # Note: the batch axes are different here than in `cfg.batch_axis_names`,
            # as we partition the sequence dim over `seq`.
            (None, 1): PartitionSpec(("data", "expert", "fsdp")),
            (None, 2): PartitionSpec(("data", "expert", "fsdp"), "seq"),
        }
    )

def decorator(fn: ForwardFn) -> ForwardFn:
# We define a positional arg only version of the original function
# that is passed because jax.value_and_grad does not accept
@@ -171,13 +194,29 @@ def fwd_helper(
and second is the accumulated grads (if `compute_grad` is True)
otherwise None.
"""
partitioner = maybe_instantiate(minibatch_partitioner)
minibatch_size = _compute_minibatch_size(inputs["input_batch"], steps=steps)

# Create a sample minibatch for the carry buffer creation below.
(
sample_minibatch_inputs,
_,
_,
) = _make_scan_minibatch_inputs(
inputs,
forward_key=inputs["forward_key"],
param_noise_key=inputs["param_noise_key"],
minibatch_size=minibatch_size,
minibatch_index=0,
minibatch_partitioner=partitioner,
)

# Carry initialization for the lax.scan procedure. Since we are passing a
# `MetricAccumulator` into the carry, and carry input/output shapes must match,
# we need to initialize the `MetricAccumulator` summary with the right PyTree
# structure.
_, primal_output_shape = jax.eval_shape(
original_func_positional_args, model_params, inputs
original_func_positional_args, model_params, sample_minibatch_inputs
)
init_primal_out = jax.tree.map(jnp.zeros_like, primal_output_shape)
init_accumulator = maybe_instantiate(metric_accumulator)
Expand Down Expand Up @@ -213,6 +252,7 @@ def scan_body(
param_noise_key=param_noise_key,
minibatch_size=minibatch_size,
minibatch_index=minibatch_index,
minibatch_partitioner=partitioner,
)
minibatch_args = (model_params, minibatch_inputs)

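As an aside, here is a small self-contained sketch of the slicing pattern used in `_make_scan_minibatch_inputs` above: `jax.lax.dynamic_slice_in_dim` takes `minibatch_size` rows from the global batch along axis 0 at each scan step. The toy batch, shapes, and step count below are illustrative only:

```python
import jax
import jax.numpy as jnp

# Toy global batch of 8 examples, accumulated over 4 steps => minibatch_size = 2.
input_batch = {"input_ids": jnp.arange(8 * 3).reshape(8, 3)}
steps = 4
minibatch_size = input_batch["input_ids"].shape[0] // steps

def slice_minibatch(batch, minibatch_index):
    # Mirrors the pattern above: take `minibatch_size` rows starting at
    # `minibatch_index * minibatch_size` along axis 0, for every leaf in the pytree.
    return jax.tree.map(
        lambda x: jax.lax.dynamic_slice_in_dim(
            x,
            start_index=minibatch_index * minibatch_size,
            slice_size=minibatch_size,
            axis=0,
        ),
        batch,
    )

for i in range(steps):
    print(slice_minibatch(input_batch, i)["input_ids"].shape)  # (2, 3) at each step
```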
8 changes: 7 additions & 1 deletion axlearn/common/trainer_config_modifier.py
@@ -2,7 +2,7 @@

"""Defines trainer config modifiers, which will be used in model definitions."""

from typing import Dict, Sequence, Union
from typing import Dict, Optional, Sequence, Union

from axlearn.common import config
from axlearn.common.base_layer import RematSpec
@@ -16,6 +16,7 @@
maybe_instantiate,
)
from axlearn.common.gradient_accumulation import with_minibatch_steps
from axlearn.common.input_base import InputPartitionFn
from axlearn.common.metrics import MetricAccumulator
from axlearn.common.trainer import SpmdTrainer
from axlearn.common.utils import HybridMeshShape, MeshShape, PartitionSpec
@@ -29,18 +30,22 @@ class Config(ConfigModifier.Config):
"""Configure GradientAccumulationModifier.
Attributes:
grad_acc_steps: The number of steps to accumulate the gradients from mini-batches.
grad_acc_steps: The number of steps to accumulate the gradients from mini-batches.
metric_accumulator: The metric accumulator to export the metrics.
minibatch_partitioner: Constrains each minibatch to a PartitionSpec.
"""

grad_acc_steps: Required[int] = REQUIRED
metric_accumulator: MetricAccumulator.Config = MetricAccumulator.default_config()
minibatch_partitioner: Optional[ConfigOr[InputPartitionFn]] = None

def __init__(self, cfg: Config):
super().__init__(cfg)
cfg = self.config
self._grad_acc_steps = cfg.grad_acc_steps
self._metric_accumulator = cfg.metric_accumulator
self._minibatch_partitioner = cfg.minibatch_partitioner

def __call__(self, cfg: SpmdTrainer.Config) -> SpmdTrainer.Config:
"""Overwrite the forward_fn_transformation to accumulate gradients for grad_acc_steps steps.
@@ -63,6 +68,7 @@ def __call__(self, cfg: SpmdTrainer.Config) -> SpmdTrainer.Config:
).set(
steps=self._grad_acc_steps,
metric_accumulator=self._metric_accumulator,
minibatch_partitioner=self._minibatch_partitioner,
)
return cfg

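For illustration, a hedged sketch of setting the new `minibatch_partitioner` field on `GradientAccumulationModifier`; the step count, partition spec, and the `apply_grad_accumulation` helper (which expects an `SpmdTrainer.Config` built elsewhere) are assumptions, not part of this commit:

```python
from jax.sharding import PartitionSpec

from axlearn.common.config import config_for_function
from axlearn.common.input_base import partition_by_path_rank
from axlearn.common.trainer_config_modifier import GradientAccumulationModifier

# Accumulate over 4 steps and constrain each minibatch like the input partitioner would,
# mirroring the default partitioner added in this commit.
grad_acc = GradientAccumulationModifier.default_config().set(
    grad_acc_steps=4,
    minibatch_partitioner=config_for_function(partition_by_path_rank).set(
        path_rank_to_partition={
            (None, 1): PartitionSpec(("data", "expert", "fsdp")),
            (None, 2): PartitionSpec(("data", "expert", "fsdp"), "seq"),
        }
    ),
)

def apply_grad_accumulation(trainer_cfg):
    """Applies the modifier to an SpmdTrainer.Config built elsewhere (hypothetical)."""
    return grad_acc.instantiate()(trainer_cfg)
```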
11 changes: 7 additions & 4 deletions axlearn/experiments/text/gpt/fuji.py
@@ -326,18 +326,18 @@ def get_trainer_kwargs(
elif model_size == "3B":
trainer_kwargs = dict(
model_kwargs=dict(
num_layers=28,
num_layers=1,
hidden_dim=3072,
num_heads=24,
num_heads=8,
num_kv_heads=num_kv_heads,
ffn_dim=8192,
ffn_dim=512,
rope_theta=rope_theta,
shared_lm_head=True,
flash_attention=flash_attention,
),
learner_kwargs=dict(peak_lr=3e-4, weight_decay=0.1),
max_sequence_length=max_sequence_length,
train_batch_size=train_batch_size,
train_batch_size=16,
max_step=max_step,
mesh_shape=mesh_shape_from_axes(data=-1, fsdp=8),
mesh_rules=(
@@ -352,6 +352,9 @@
),
*trn2_config.module_modifications,
*trn2_config.partition_spec_modifications,
GradientAccumulationModifier.default_config().set(
grad_acc_steps=4,
),
],
),
),
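Roughly, the modifier added above slots into a mesh rule alongside other config modifiers. The sketch below shows the general shape of such an entry; the rule string, mesh shape, and step count are illustrative assumptions rather than the exact values in fuji.py:

```python
from axlearn.common.trainer_config_modifier import (
    ChainConfigModifier,
    GradientAccumulationModifier,
    MeshShapeModifier,
)
from axlearn.common.utils import mesh_shape_from_axes

# Hypothetical mesh rule entry: (mesh selector regex, chained config modifiers).
mesh_rule = (
    "neuron-(trn2|trn2n).48xlarge-64",
    ChainConfigModifier.default_config().set(
        config_modifiers=[
            MeshShapeModifier.default_config().set(
                mesh_shape=mesh_shape_from_axes(data=-1, fsdp=8)
            ),
            GradientAccumulationModifier.default_config().set(grad_acc_steps=4),
        ],
    ),
)
```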
