diff --git a/shardlib/shardtypes.py b/shardlib/shardtypes.py
index cd9ad7e..757681f 100644
--- a/shardlib/shardtypes.py
+++ b/shardlib/shardtypes.py
@@ -47,13 +47,14 @@ def center_channels(x: f32[b'batch/d channels']) -> f32[b'batch/d channels']:
 from contextvars import ContextVar
 from enum import IntEnum
 from typing import Any, Union
+from typing import get_args, get_origin
 from typeguard import check_type_internal, typechecked
 import jax
 import jax.numpy as jnp
 from types import GenericAlias
 from typeguard import TypeCheckError, TypeCheckerCallable
 import dataclasses
-from dataclasses import dataclass
+from dataclasses import dataclass, make_dataclass
 from typeguard import checker_lookup_functions
@@ -286,6 +287,33 @@ def unflatten(_aux, fields):
   _PYTREE_DATACLASSES.add(cls)
   return cls
 
+class Array:
+  """If `cls` is an array type or a `pytree_dataclass` of array types,
+  `Array[axes, cls]` will extend `cls` with leading axes `axes`.
+  For example, `Array['layers', f32['batch d_model']]` returns `f32['layers batch d_model']`.
+  """
+  def __class_getitem__(cls, x):
+    axes, input_cls = x
+    if isinstance(axes, str):
+      axes = axes.encode('utf-8')
+    elif isinstance(axes, bytes):
+      pass
+    else:
+      raise ValueError(f"input axes to {cls} must be Union[bytes, str]")
+
+    if dataclasses.is_dataclass(input_cls):
+      extended_fields = []
+      for fld in dataclasses.fields(input_cls):
+        extended_type = Array[axes, fld.type]
+        extended_fields.append((fld.name, extended_type))
+
+      extended_cls = make_dataclass(input_cls.__name__, extended_fields, bases=(input_cls,))
+      pytree_dataclass(extended_cls)
+      return extended_cls
+    else:
+      number_type, shape = get_origin(input_cls), get_args(input_cls)
+      extended_shape = (axes + b' ' + shape[0],)
+      return GenericAlias(number_type, extended_shape)
 
 def make_partition_specs(cls):
   """Instantiates a pytree dataclass with a PartitionSpec at array type."""
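A quick usage sketch of the new `Array` helper (not part of the diff; `Block` and `Stacked` are hypothetical names, assuming the `f32` and `pytree_dataclass` exports from shardlib.shardtypes shown above):

    from shardlib.shardtypes import f32, pytree_dataclass, Array

    @pytree_dataclass
    class Block:
      ln: f32['d_model']
      w: f32['d_model d_ff']

    # Array recursively prepends the leading axis to every field:
    # Stacked.ln is f32['layers d_model'], Stacked.w is f32['layers d_model d_ff'].
    Stacked = Array['layers', Block]

    # For a plain array type, Array simply prepends the axis to the shape spec:
    # Array['layers', f32['batch d_model']] -> f32['layers batch d_model']
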
diff --git a/train.py b/train.py
index cc0a27b..efb8ade 100644
--- a/train.py
+++ b/train.py
@@ -21,7 +21,7 @@
 import jax.numpy as jnp
 import math
 from input_loader import FlatTokensParams, HuggingFaceDataParams, TokenBatch, TokenBatchParams, get_loader
-from shardlib.shardtypes import bf16, bool_, f32, pytree_dataclass, u32, make_shardings
+from shardlib.shardtypes import bf16, bool_, f32, pytree_dataclass, u32, make_shardings, Array
 import shardlib.shardops as shardops
 P = PartitionSpec
 import einops
@@ -47,19 +47,24 @@ class Hparams:
   d_ff: int
   rope_max_timescale: int
 
+@pytree_dataclass
+class TransformerLayer:
+  ln1: f32['d_model/t/d']
+  ln2: f32['d_model/t/d']
+  w_q: f32['d_model/d n_q_per_kv n_kv/t d_head']
+  w_kv: f32['2 d_model/d n_kv/t d_head']
+  w_o: f32['d_model/d n_q_per_kv n_kv/t d_head']
+  w_gate: f32['d_model/d d_ff/t']
+  w_up: f32['d_model/d d_ff/t']
+  w_down: f32['d_model/d d_ff/t']
+
+Transformer = Array['layers', TransformerLayer]
 
 @pytree_dataclass
 class Model:
   embed: f32['vocab/t d_model/d']
   unembed: f32['vocab/t d_model/d']
-  ln1: f32['layers d_model/t/d']
-  ln2: f32['layers d_model/t/d']
-  w_q: f32['layers d_model/d n_q_per_kv n_kv/t d_head']
-  w_kv: f32['layers 2 d_model/d n_kv/t d_head']
-  w_o: f32['layers d_model/d n_q_per_kv n_kv/t d_head']
-  w_gate: f32['layers d_model/d d_ff/t']
-  w_up: f32['layers d_model/d d_ff/t']
-  w_down: f32['layers d_model/d d_ff/t']
+  transformer: Transformer
   final_layer_norm: f32['d_model/d/t']
 
   @staticmethod
@@ -105,14 +110,16 @@ def init(h: Hparams, rng: PRNGKey) -> 'Model':
     arrays = Model(
       embed=embed,
       unembed=unembed,
-      ln1=ln1,
-      ln2=ln2,
-      w_q=w_q,
-      w_kv=w_kv,
-      w_o=w_o,
-      w_gate=w_gate,
-      w_up=w_up,
-      w_down=w_down,
+      transformer=Transformer(
+        ln1=ln1,
+        ln2=ln2,
+        w_q=w_q,
+        w_kv=w_kv,
+        w_o=w_o,
+        w_gate=w_gate,
+        w_up=w_up,
+        w_down=w_down,
+      ),
       final_layer_norm=final_layer_norm,
     )
     shardings = make_shardings(Model)
@@ -138,19 +145,17 @@ def forward_pass(self, h: Hparams, ids: u32[b'B/d L'], is_seq_start: bool_[b'B/d
     ##### Transformer blocks.
     @explicit_activation_checkpointing
     @typechecked
-    def loop_body(x: bf16[b'B/d L M/t'], layer_weights: Any) -> Tuple[bf16[b'B/d L M/t'], Tuple[()]]:
-      w_q, w_kv, w_o, w_gate, w_up, w_down, ln1, ln2 = layer_weights
-
+    def loop_body(x: bf16[b'B/d L M/t'], layer_weights: TransformerLayer) -> Tuple[bf16[b'B/d L M/t'], Tuple[()]]:
       # Pre-attention RMSNorm
-      ln1 = shardops.all_gather('M/t/d -> M', jnp.float32(ln1))
+      ln1 = shardops.all_gather('M/t/d -> M', jnp.float32(layer_weights.ln1))
       gx = shardops.all_gather('B/d L M/t -> B/d L M', x)
       nx = jnp.bfloat16(rms_norm(gx) * ln1)
 
       # Attention, using Grouped Query Attention and RoPE position embeddings.
-      w_q = shardops.all_gather('M/d Q K/t D -> M Q K/t D', jnp.bfloat16(w_q))
+      w_q = shardops.all_gather('M/d Q K/t D -> M Q K/t D', jnp.bfloat16(layer_weights.w_q))
       q = save_for_backward(shardops.einsum_unreduced('B/d L M, M Q K/t D -> B/d L Q K/t D', nx, w_q))
       q = rope_table.apply('L D -> 1 L 1 1 D', q)
-      w_kv = shardops.all_gather('2 M/d K/t D -> 2 M K/t D', jnp.bfloat16(w_kv))
+      w_kv = shardops.all_gather('2 M/d K/t D -> 2 M K/t D', jnp.bfloat16(layer_weights.w_kv))
       k, v = shardops.einsum_unreduced('B/d L M, k_v M K/t D -> k_v B/d L K/t D', nx, w_kv)
       k = save_for_backward(k)
       v = save_for_backward(v)
@@ -161,29 +166,29 @@ def loop_body(x: bf16[b'B/d L M/t'], layer_weights: Any) -> Tuple[bf16[b'B/d L M
       probs = jnp.bfloat16(jax.nn.softmax(logits, axis=2))
       attn_out = shardops.einsum_unreduced(
         'B/d Qlen Klen Q K/t, B/d Klen K/t D -> B/d Qlen Q K/t D', probs, v)
-      w_o = shardops.all_gather('M/d Q K/t D -> M Q K/t D', jnp.bfloat16(w_o))
+      w_o = shardops.all_gather('M/d Q K/t D -> M Q K/t D', jnp.bfloat16(layer_weights.w_o))
       attn_out = shardops.einsum_unreduced('B/d Qlen Q K/t D, M Q K/t D -> B/d Qlen M', attn_out, w_o)
       attn_out = shardops.psum_scatter('B/d Qlen M -> B/d Qlen M/t', attn_out)
       x = save_for_backward(x + attn_out)
 
       # Pre-FFN RMSNorm
-      ln2 = save_for_backward(shardops.all_gather('M/t/d -> M', jnp.float32(ln2)))
+      ln2 = save_for_backward(shardops.all_gather('M/t/d -> M', jnp.float32(layer_weights.ln2)))
       gx = shardops.all_gather('B/d L M/t -> B/d L M', x)
       nx = jnp.bfloat16(rms_norm(gx) * ln2)
 
       # FFN, using SwiGLU
-      w_gate = shardops.all_gather('M/d F/t -> M F/t', jnp.bfloat16(w_gate))
+      w_gate = shardops.all_gather('M/d F/t -> M F/t', jnp.bfloat16(layer_weights.w_gate))
       gate_proj = save_for_backward(shardops.einsum_unreduced('B/d L M, M F/t -> B/d L F/t', nx, w_gate))
-      w_up = shardops.all_gather('M/d F/t -> M F/t', jnp.bfloat16(w_up))
+      w_up = shardops.all_gather('M/d F/t -> M F/t', jnp.bfloat16(layer_weights.w_up))
       up_proj = save_for_backward(shardops.einsum_unreduced('B/d L M, M F/t -> B/d L F/t', nx, w_up))
       y = jax.nn.swish(gate_proj) * up_proj
-      w_down = shardops.all_gather('M/d F/t -> M F/t', jnp.bfloat16(w_down))
+      w_down = shardops.all_gather('M/d F/t -> M F/t', jnp.bfloat16(layer_weights.w_down))
       ffn_out = shardops.einsum_unreduced('B/d L F/t, M F/t -> B/d L M', y, w_down)
       ffn_out = shardops.psum_scatter('B/d L M -> B/d L M/t', ffn_out)
       return jnp.bfloat16(x + ffn_out), ()
 
-    x, () = jax.lax.scan(loop_body, jnp.bfloat16(x), (self.w_q, self.w_kv, self.w_o, self.w_gate, self.w_up, self.w_down, self.ln1, self.ln2))
+    x, () = jax.lax.scan(loop_body, jnp.bfloat16(x), self.transformer)
 
     ##### Final layernorm and output projection.
     x = shardops.all_gather('B/d L M/t -> B/d L M', x)
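Why scanning over `self.transformer` works: `jax.lax.scan` slices every pytree leaf of its `xs` argument along the leading axis, so each step of `loop_body` receives a `TransformerLayer` whose fields are one per-layer slice of the stacked weights. A minimal standalone sketch of the same pattern (the `Layer` class and shapes are hypothetical, and `jax.tree_util.register_dataclass`, available in recent JAX releases, stands in for the repo's `pytree_dataclass`):

    import jax
    import jax.numpy as jnp
    from dataclasses import dataclass

    @dataclass
    class Layer:
      w: jax.Array  # stacked as [layers, d, d]

    # Register the dataclass as a pytree so scan can slice its leaves.
    jax.tree_util.register_dataclass(Layer, data_fields=['w'], meta_fields=[])

    def body(x, layer):
      # layer.w is one [d, d] slice of the stacked [layers, d, d] array.
      return x @ layer.w, ()

    stack = Layer(w=jnp.stack([jnp.eye(8)] * 4))  # 4 layers
    x, () = jax.lax.scan(body, jnp.ones((2, 8)), stack)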