closes #2, propagate typechecking through vectorizing transformations (#4)

Rewrite train.Model to use `Array['layers', TransformerLayer]`
zhangir-azerbayev authored Jul 17, 2024
1 parent f89f6e4 commit 3d8aba2
Showing 2 changed files with 63 additions and 30 deletions.
30 changes: 29 additions & 1 deletion shardlib/shardtypes.py
@@ -47,13 +47,14 @@ def center_channels(x: f32[b'batch/d channels']) -> f32[b'batch/d channels']:
from contextvars import ContextVar
from enum import IntEnum
from typing import Any, Union
from typing import get_args, get_origin
from typeguard import check_type_internal, typechecked
import jax
import jax.numpy as jnp
from types import GenericAlias
from typeguard import TypeCheckError, TypeCheckerCallable
import dataclasses
from dataclasses import dataclass
from dataclasses import dataclass, make_dataclass
from typeguard import checker_lookup_functions


@@ -286,6 +287,33 @@ def unflatten(_aux, fields):
_PYTREE_DATACLASSES.add(cls)
return cls

class Array:
"""If `cls` is an array type or a `pytree_dataclass` of array types,
`Array[axes, cls]` will extend `cls` with leading axes `axes`.
    For example, `Array['layers', f32['batch d_model']]` returns `f32['layers batch d_model']`.
"""
def __class_getitem__(cls, x):
axes, input_cls = x
if isinstance(axes, str):
axes = axes.encode('utf-8')
elif isinstance(axes, bytes):
pass
else:
raise ValueError(f"input axes to {cls} must be Union[bytes, str]")

if dataclasses.is_dataclass(input_cls):
extended_fields = []
for fld in dataclasses.fields(input_cls):
extended_type = Array[axes, fld.type]
extended_fields.append((fld.name, extended_type))

extended_cls = make_dataclass(input_cls.__name__, extended_fields, bases=(input_cls,))
pytree_dataclass(extended_cls)
return extended_cls
else:
number_type, shape = get_origin(input_cls), get_args(input_cls)
extended_shape = (axes + b' ' + shape[0],)
return GenericAlias(number_type, extended_shape)

def make_partition_specs(cls):
"""Instantiates a pytree dataclass with a PartitionSpec at array type."""
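For context, a minimal sketch of how the new `Array` helper composes, assuming the `f32[...]` shape aliases and `pytree_dataclass` exported by `shardlib.shardtypes`; the `Block` class and its field names below are hypothetical, not part of this commit:

```python
from shardlib.shardtypes import Array, f32, pytree_dataclass

# Plain array type: the leading axis is prepended to the shape string.
Stacked = Array['layers', f32['batch d_model']]  # behaves like f32['layers batch d_model']

# Dataclass of array types: every field gains the leading axis, and the result
# is itself registered as a pytree_dataclass that subclasses the input class.
@pytree_dataclass
class Block:
    w: f32['d_model d_ff']
    b: f32['d_ff']

StackedBlocks = Array['layers', Block]
# StackedBlocks has fields w: f32['layers d_model d_ff'] and b: f32['layers d_ff'].
```

train.py below uses exactly this second form: `Transformer = Array['layers', TransformerLayer]` stacks every per-layer weight along a leading `layers` axis.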
63 changes: 34 additions & 29 deletions train.py
@@ -21,7 +21,7 @@
import jax.numpy as jnp
import math
from input_loader import FlatTokensParams, HuggingFaceDataParams, TokenBatch, TokenBatchParams, get_loader
from shardlib.shardtypes import bf16, bool_, f32, pytree_dataclass, u32, make_shardings
from shardlib.shardtypes import bf16, bool_, f32, pytree_dataclass, u32, make_shardings, Array
import shardlib.shardops as shardops
P = PartitionSpec
import einops
@@ -47,19 +47,24 @@ class Hparams:
d_ff: int
rope_max_timescale: int

@pytree_dataclass
class TransformerLayer:
ln1: f32['d_model/t/d']
ln2: f32['d_model/t/d']
w_q: f32['d_model/d n_q_per_kv n_kv/t d_head']
w_kv: f32['2 d_model/d n_kv/t d_head']
w_o: f32['d_model/d n_q_per_kv n_kv/t d_head']
w_gate: f32['d_model/d d_ff/t']
w_up: f32['d_model/d d_ff/t']
w_down: f32['d_model/d d_ff/t']

Transformer = Array['layers', TransformerLayer]

@pytree_dataclass
class Model:
embed: f32['vocab/t d_model/d']
unembed: f32['vocab/t d_model/d']
ln1: f32['layers d_model/t/d']
ln2: f32['layers d_model/t/d']
w_q: f32['layers d_model/d n_q_per_kv n_kv/t d_head']
w_kv: f32['layers 2 d_model/d n_kv/t d_head']
w_o: f32['layers d_model/d n_q_per_kv n_kv/t d_head']
w_gate: f32['layers d_model/d d_ff/t']
w_up: f32['layers d_model/d d_ff/t']
w_down: f32['layers d_model/d d_ff/t']
transformer: Transformer
final_layer_norm: f32['d_model/d/t']

@staticmethod
@@ -105,14 +110,16 @@ def init(h: Hparams, rng: PRNGKey) -> 'Model':
arrays = Model(
embed=embed,
unembed=unembed,
ln1=ln1,
ln2=ln2,
w_q=w_q,
w_kv=w_kv,
w_o=w_o,
w_gate=w_gate,
w_up=w_up,
w_down=w_down,
transformer=Transformer(
ln1=ln1,
ln2=ln2,
w_q=w_q,
w_kv=w_kv,
w_o=w_o,
w_gate=w_gate,
w_up=w_up,
w_down=w_down,
),
final_layer_norm=final_layer_norm,
)
shardings = make_shardings(Model)
@@ -138,19 +145,17 @@ def forward_pass(self, h: Hparams, ids: u32[b'B/d L'], is_seq_start: bool_[b'B/d
##### Transformer blocks.
@explicit_activation_checkpointing
@typechecked
def loop_body(x: bf16[b'B/d L M/t'], layer_weights: Any) -> Tuple[bf16[b'B/d L M/t'], Tuple[()]]:
w_q, w_kv, w_o, w_gate, w_up, w_down, ln1, ln2 = layer_weights

def loop_body(x: bf16[b'B/d L M/t'], layer_weights: TransformerLayer) -> Tuple[bf16[b'B/d L M/t'], Tuple[()]]:
# Pre-attention RMSNorm
ln1 = shardops.all_gather('M/t/d -> M', jnp.float32(ln1))
ln1 = shardops.all_gather('M/t/d -> M', jnp.float32(layer_weights.ln1))
gx = shardops.all_gather('B/d L M/t -> B/d L M', x)
nx = jnp.bfloat16(rms_norm(gx) * ln1)

# Attention, using Grouped Query Attention and RoPE position embeddings.
w_q = shardops.all_gather('M/d Q K/t D -> M Q K/t D', jnp.bfloat16(w_q))
w_q = shardops.all_gather('M/d Q K/t D -> M Q K/t D', jnp.bfloat16(layer_weights.w_q))
q = save_for_backward(shardops.einsum_unreduced('B/d L M, M Q K/t D -> B/d L Q K/t D', nx, w_q))
q = rope_table.apply('L D -> 1 L 1 1 D', q)
w_kv = shardops.all_gather('2 M/d K/t D -> 2 M K/t D', jnp.bfloat16(w_kv))
w_kv = shardops.all_gather('2 M/d K/t D -> 2 M K/t D', jnp.bfloat16(layer_weights.w_kv))
k, v = shardops.einsum_unreduced('B/d L M, k_v M K/t D -> k_v B/d L K/t D', nx, w_kv)
k = save_for_backward(k)
v = save_for_backward(v)
@@ -161,29 +166,29 @@ def loop_body(x: bf16[b'B/d L M/t'], layer_weights: Any) -> Tuple[bf16[b'B/d L M
probs = jnp.bfloat16(jax.nn.softmax(logits, axis=2))
attn_out = shardops.einsum_unreduced(
'B/d Qlen Klen Q K/t, B/d Klen K/t D -> B/d Qlen Q K/t D', probs, v)
w_o = shardops.all_gather('M/d Q K/t D -> M Q K/t D', jnp.bfloat16(w_o))
w_o = shardops.all_gather('M/d Q K/t D -> M Q K/t D', jnp.bfloat16(layer_weights.w_o))
attn_out = shardops.einsum_unreduced('B/d Qlen Q K/t D, M Q K/t D -> B/d Qlen M', attn_out, w_o)
attn_out = shardops.psum_scatter('B/d Qlen M -> B/d Qlen M/t', attn_out)
x = save_for_backward(x + attn_out)

# Pre-FFN RMSNorm
ln2 = save_for_backward(shardops.all_gather('M/t/d -> M', jnp.float32(ln2)))
ln2 = save_for_backward(shardops.all_gather('M/t/d -> M', jnp.float32(layer_weights.ln2)))
gx = shardops.all_gather('B/d L M/t -> B/d L M', x)
nx = jnp.bfloat16(rms_norm(gx) * ln2)

# FFN, using SwiGLU
w_gate = shardops.all_gather('M/d F/t -> M F/t', jnp.bfloat16(w_gate))
w_gate = shardops.all_gather('M/d F/t -> M F/t', jnp.bfloat16(layer_weights.w_gate))
gate_proj = save_for_backward(shardops.einsum_unreduced('B/d L M, M F/t -> B/d L F/t', nx, w_gate))
w_up = shardops.all_gather('M/d F/t -> M F/t', jnp.bfloat16(w_up))
w_up = shardops.all_gather('M/d F/t -> M F/t', jnp.bfloat16(layer_weights.w_up))
up_proj = save_for_backward(shardops.einsum_unreduced('B/d L M, M F/t -> B/d L F/t', nx, w_up))
y = jax.nn.swish(gate_proj) * up_proj
w_down = shardops.all_gather('M/d F/t -> M F/t', jnp.bfloat16(w_down))
w_down = shardops.all_gather('M/d F/t -> M F/t', jnp.bfloat16(layer_weights.w_down))
ffn_out = shardops.einsum_unreduced('B/d L F/t, M F/t -> B/d L M', y, w_down)
ffn_out = shardops.psum_scatter('B/d L M -> B/d L M/t', ffn_out)

return jnp.bfloat16(x + ffn_out), ()

x, () = jax.lax.scan(loop_body, jnp.bfloat16(x), (self.w_q, self.w_kv, self.w_o, self.w_gate, self.w_up, self.w_down, self.ln1, self.ln2))
x, () = jax.lax.scan(loop_body, jnp.bfloat16(x), self.transformer)

##### Final layernorm and output projection.
x = shardops.all_gather('B/d L M/t -> B/d L M', x)
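For context on the `jax.lax.scan(loop_body, jnp.bfloat16(x), self.transformer)` call above: `scan` accepts any pytree as its `xs` argument and slices every leaf along its leading axis, so the stacked `Transformer` (one leading `layers` axis per field) is handed to `loop_body` one `TransformerLayer` slice at a time, which is what lets the `layer_weights: TransformerLayer` annotation be typechecked per layer. A minimal, self-contained sketch of that pattern using a plain dict pytree (all names and shapes here are illustrative, not from the diff):

```python
import jax
import jax.numpy as jnp

n_layers, d_model = 4, 8

# Stacked per-layer weights: every leaf carries a leading `layers` axis.
stacked = {
    'ln': jnp.ones((n_layers, d_model)),
    'w': 0.01 * jnp.ones((n_layers, d_model, d_model)),
}

def loop_body(x, layer):
    # `layer` is one slice of the pytree: layer['ln'] has shape (d_model,)
    # and layer['w'] has shape (d_model, d_model).
    x = x + (x * layer['ln']) @ layer['w']
    return x, ()

x0 = jnp.ones((d_model,))
x, () = jax.lax.scan(loop_body, x0, stacked)  # x has shape (d_model,)
```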
