
Commit

[JAX] Flax with compute dtype inferred from input dtype. (#1485)
Flax modules with compute dtype inferred from the inputs

Signed-off-by: Phuong Nguyen <[email protected]>
phu0ngng authored Feb 18, 2025
1 parent eb9857d commit 6673f16
Showing 10 changed files with 178 additions and 180 deletions.
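
The change across these files follows one pattern: instead of forcing a compute dtype with dtype=jnp.bfloat16 on each TE layer, the layers now compute in whatever dtype the activations arrive in. A minimal sketch of that idea (illustrative only, not TransformerEngine's actual implementation; the module below is hypothetical):

import jax
import jax.numpy as jnp
from flax import linen as nn


class InputDtypeDense(nn.Module):
    """Illustrative only: compute dtype follows the input, params stay float32."""

    features: int

    @nn.compact
    def __call__(self, x):
        kernel = self.param(
            "kernel",
            nn.initializers.lecun_normal(),
            (x.shape[-1], self.features),
            jnp.float32,
        )
        # Cast the parameters to the activation dtype so the matmul runs in
        # x.dtype (e.g. bfloat16 when the caller feeds bfloat16 activations).
        return jnp.dot(x, kernel.astype(x.dtype))


x = jnp.ones((4, 16), dtype=jnp.bfloat16)
model = InputDtypeDense(features=8)
params = model.init(jax.random.PRNGKey(0), x)
y = model.apply(params, x)
assert y.dtype == jnp.bfloat16  # compute dtype followed the input

Under this pattern, casting the network input once is enough to run the whole stack in that precision, which is why the dtype arguments below can simply be deleted.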
9 changes: 3 additions & 6 deletions examples/jax/encoder/test_model_parallel_encoder.py
@@ -56,7 +56,6 @@ def __call__(self, x, mask, disable_dropout=False):
self_attn_mask_type="padding",
enable_relative_embedding=False,
enable_sequence_parallel=self.enable_seq_paral,
- dtype=jnp.bfloat16,
)
x = te_Encoder()(x, attention_mask=mask, deterministic=disable_dropout)

@@ -72,17 +71,15 @@ def __call__(self, x, mask, disable_dropout=False):
features=256,
kernel_axes=(NAMED_BROADCAST_AXIS, NAMED_TP_AXIS),
bias_axes=(NAMED_TP_AXIS,),
- dtype=jnp.bfloat16,
)(x)

x = te_flax.DenseGeneral(
features=256,
kernel_axes=(NAMED_TP_AXIS, NAMED_BROADCAST_AXIS),
bias_axes=(NAMED_BROADCAST_AXIS,),
- dtype=jnp.bfloat16,
)(x)

- x = nn.Dense(features=2, dtype=jnp.bfloat16)(x)
+ x = nn.Dense(features=2)(x)
return x

@@ -91,7 +88,7 @@ def train_step(state, inputs, masks, labels, var_collect, rngs):

def loss_fn(var_collect, disable_dropout=False):
logits = state.apply_fn(var_collect, inputs, masks, disable_dropout, rngs=rngs)
- one_hot = jax.nn.one_hot(labels, 2)
+ one_hot = jax.nn.one_hot(labels.astype(jnp.int32), 2)
loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=one_hot))
return loss, logits

@@ -136,7 +133,7 @@ def eval_step(state, inputs, masks, labels, var_collect):

def loss_fn(var_collect, disable_dropout=False):
logits = state.apply_fn(var_collect, inputs, masks, disable_dropout)
- one_hot = jax.nn.one_hot(labels, 2)
+ one_hot = jax.nn.one_hot(labels.astype(jnp.int32), 2)
loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=one_hot))
return loss, logits
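
One consequence visible above is the labels.astype(jnp.int32) cast: jax.nn.one_hot builds the encoding by comparing the labels against the class indices, so passing integer labels keeps the encoding exact regardless of which floating-point dtype the rest of the pipeline now runs in. A tiny standalone check (values made up for illustration):

import jax
import jax.numpy as jnp

labels = jnp.array([0, 1, 1], dtype=jnp.float32)        # labels sometimes arrive as floats
one_hot = jax.nn.one_hot(labels.astype(jnp.int32), 2)   # cast before encoding
print(one_hot)
# [[1. 0.]
#  [0. 1.]
#  [0. 1.]]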
11 changes: 5 additions & 6 deletions examples/jax/encoder/test_multigpu_encoder.py
@@ -51,17 +51,16 @@ def __call__(self, x, mask, disable_dropout=False):
layer_type=te_flax.TransformerLayerType.ENCODER,
self_attn_mask_type="padding",
enable_relative_embedding=False,
- dtype=jnp.bfloat16,
)
x = te_Encoder()(x, attention_mask=mask, deterministic=disable_dropout)

x = x.reshape(x.shape[0], -1)

- x = te_flax.DenseGeneral(features=256, dtype=jnp.bfloat16)(x)
+ x = te_flax.DenseGeneral(features=256)(x)

- x = te_flax.DenseGeneral(features=256, dtype=jnp.bfloat16)(x)
+ x = te_flax.DenseGeneral(features=256)(x)

- x = nn.Dense(features=2, dtype=jnp.bfloat16)(x)
+ x = nn.Dense(features=2)(x)
return x

@@ -70,7 +69,7 @@ def train_step(state, inputs, masks, labels, var_collect, rngs):

def loss_fn(var_collect, disable_dropout=False):
logits = state.apply_fn(var_collect, inputs, masks, disable_dropout, rngs=rngs)
- one_hot = jax.nn.one_hot(labels, 2)
+ one_hot = jax.nn.one_hot(labels.astype(jnp.int32), 2)
loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=one_hot))
return loss, logits

@@ -115,7 +114,7 @@ def eval_step(state, inputs, masks, labels, var_collect):

def loss_fn(var_collect, disable_dropout=False):
logits = state.apply_fn(var_collect, inputs, masks, disable_dropout)
- one_hot = jax.nn.one_hot(labels, 2)
+ one_hot = jax.nn.one_hot(labels.astype(jnp.int32), 2)
loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=one_hot))
return loss, logits
5 changes: 1 addition & 4 deletions examples/jax/encoder/test_multiprocessing_encoder.py
@@ -57,7 +57,6 @@ def __call__(self, x, mask, disable_dropout=False):
layer_type=te_flax.TransformerLayerType.ENCODER,
self_attn_mask_type="padding",
enable_relative_embedding=False,
- dtype=jnp.bfloat16,
)
x = te_Encoder()(x, attention_mask=mask, deterministic=disable_dropout)

@@ -67,17 +66,15 @@ def __call__(self, x, mask, disable_dropout=False):
features=256,
kernel_axes=(NAMED_BROADCAST_AXIS, NAMED_TP_AXIS),
bias_axes=(NAMED_TP_AXIS,),
- dtype=jnp.bfloat16,
)(x)

x = te_flax.DenseGeneral(
features=256,
kernel_axes=(NAMED_TP_AXIS, NAMED_BROADCAST_AXIS),
bias_axes=(NAMED_BROADCAST_AXIS,),
- dtype=jnp.bfloat16,
)(x)

- x = nn.Dense(features=2, dtype=jnp.bfloat16)(x)
+ x = nn.Dense(features=2)(x)
return x
12 changes: 6 additions & 6 deletions examples/jax/encoder/test_single_gpu_encoder.py
@@ -46,17 +46,16 @@ def __call__(self, x, mask, disable_dropout=False):
layer_type=te_flax.TransformerLayerType.ENCODER,
self_attn_mask_type="padding",
enable_relative_embedding=False,
- dtype=jnp.bfloat16,
)
x = te_Encoder()(x, attention_mask=mask, deterministic=disable_dropout)

x = x.reshape(x.shape[0], -1)

- x = te_flax.DenseGeneral(features=256, dtype=jnp.bfloat16)(x)
+ x = te_flax.DenseGeneral(features=256)(x)

- x = te_flax.DenseGeneral(features=256, dtype=jnp.bfloat16)(x)
+ x = te_flax.DenseGeneral(features=256)(x)

- x = nn.Dense(features=2, dtype=jnp.bfloat16)(x)
+ x = nn.Dense(features=2)(x)
return x

@@ -66,7 +65,7 @@ def train_step(state, inputs, masks, labels, var_collect, rngs):

def loss_fn(var_collect, disable_dropout=False):
logits = state.apply_fn(var_collect, inputs, masks, disable_dropout, rngs=rngs)
- one_hot = jax.nn.one_hot(labels, 2)
+ one_hot = jax.nn.one_hot(labels.astype(jnp.int32), 2)
loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=one_hot))
return loss, logits

@@ -112,7 +111,7 @@ def eval_step(state, inputs, masks, labels, var_collect):

def loss_fn(var_collect, disable_dropout=False):
logits = state.apply_fn(var_collect, inputs, masks, disable_dropout)
- one_hot = jax.nn.one_hot(labels, 2)
+ one_hot = jax.nn.one_hot(labels.astype(jnp.int32), 2)
loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=one_hot))
return loss, logits

@@ -217,6 +216,7 @@ def train_and_evaluate(args):

with te.fp8_autocast(enabled=args.use_fp8):
encoder = Net(num_embed)
+ # We use nn.Embed, thus inputs need to be in int
inputs = jnp.zeros(input_shape, dtype=jnp.int32)
masks = jnp.zeros(mask_shape, dtype=jnp.uint8)
var_collect = encoder.init(init_rngs, inputs, masks)
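
The comment added in train_and_evaluate ("We use nn.Embed, thus inputs need to be in int") reflects that flax.linen.Embed is a table lookup over token ids, so its inputs stay int32 even when the rest of the model computes in bfloat16 or FP8. A small illustration (sizes are made up):

import jax
import jax.numpy as jnp
from flax import linen as nn

embed = nn.Embed(num_embeddings=32, features=8)
token_ids = jnp.zeros((2, 5), dtype=jnp.int32)            # integer ids, as in the init above
variables = embed.init(jax.random.PRNGKey(0), token_ids)
out = embed.apply(variables, token_ids)                   # floating-point embeddings
assert out.shape == (2, 5, 8)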
10 changes: 7 additions & 3 deletions examples/jax/mnist/test_single_gpu_mnist.py
@@ -36,6 +36,8 @@ def __call__(self, x, disable_dropout=False):
nn_Dense = te_flax.DenseGeneral
else:
nn_Dense = nn.Dense
+ # dtype is used for param init in TE but computation in Linen.nn
+ dtype = jnp.float32 if self.use_te else jnp.bfloat16

x = nn.Conv(features=32, kernel_size=(3, 3), strides=1, dtype=jnp.bfloat16)(x)
x = nn.relu(x)
@@ -44,11 +46,13 @@
x = nn.max_pool(x, window_shape=(2, 2), strides=(2, 2))
x = nn.Dropout(rate=0.25)(x, deterministic=disable_dropout)
x = x.reshape(x.shape[0], -1)
- x = nn_Dense(features=128, dtype=jnp.bfloat16)(x)
+ assert x.dtype == jnp.bfloat16
+ x = nn_Dense(features=128, dtype=dtype)(x)
x = nn.relu(x)
x = nn.Dropout(rate=0.5)(x, deterministic=disable_dropout)
- x = nn_Dense(features=16, dtype=jnp.bfloat16)(x)
- x = nn.Dense(features=10, dtype=jnp.bfloat16)(x)
+ x = nn_Dense(features=16, dtype=dtype)(x)
+ x = nn_Dense(features=10, dtype=dtype)(x)
+ assert x.dtype == jnp.bfloat16
return x
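
The new comment in this file, "dtype is used for param init in TE but computation in Linen.nn", is why the test now picks dtype per backend. The Linen half can be checked directly; the TE half (dtype now governing parameter initialization while computation follows the input) is taken from this commit's description. A sketch of the Linen behavior:

import jax
import jax.numpy as jnp
from flax import linen as nn

x = jnp.ones((4, 16), dtype=jnp.float32)
dense = nn.Dense(features=8, dtype=jnp.bfloat16)   # dtype = computation dtype in Linen
params = dense.init(jax.random.PRNGKey(0), x)
y = dense.apply(params, x)

print(y.dtype)                            # bfloat16: inputs and params are cast for the matmul
print(params["params"]["kernel"].dtype)   # float32: param_dtype is untouched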
2 changes: 0 additions & 2 deletions tests/jax/test_distributed_layernorm_mlp.py
@@ -271,7 +271,6 @@ def _test_layernorm_mlp(
transpose_batch_sequence=False,  # input: [batch, seqlen, hidden]
intermediate_dim=INTERMEDIATE,
activations=activation_type,
- dtype=dtype,
use_bias=use_bias,
)
params_single = ln_mlp_single.init(init_rngs, x)
@@ -289,7 +288,6 @@
transpose_batch_sequence=False,
intermediate_dim=INTERMEDIATE,
activations=activation_type,
- dtype=dtype,
scale_axes=(W_NO_SHARD_AXES,),
ln_bias_axes=(W_NO_SHARD_AXES,),
kernel_axes_1=(W_FSDP_AXES, W_JOINED_AXES, W_TP_AXES),
8 changes: 4 additions & 4 deletions tests/jax/test_layer.py
@@ -265,8 +265,8 @@ def test_forward(
"""Test only the forward"""
inputs, (ref_masks, test_masks) = self.generate_inputs(data_shape, dtype)

- ref_layer_cls = partial(self.reference_layer, dtype=dtype, **self.attrs)
- layer_cls = partial(TransformerLayer, layer_type=self.layer_type, dtype=dtype, **self.attrs)
+ ref_layer_cls = partial(self.reference_layer, **self.attrs)
+ layer_cls = partial(TransformerLayer, layer_type=self.layer_type, **self.attrs)

ref_layer, ref_params, ref_others = self._generate_layer(ref_layer_cls, inputs, ref_masks)
test_layer, test_params, test_others = self._generate_layer(layer_cls, inputs, test_masks)
@@ -288,8 +288,8 @@ def test_backward(
"""Test forward and backward through value_and_grad()"""
inputs, (ref_masks, test_masks) = self.generate_inputs(data_shape, dtype)

- ref_layer_cls = partial(self.reference_layer, dtype=dtype, **self.attrs)
- layer_cls = partial(TransformerLayer, layer_type=self.layer_type, dtype=dtype, **self.attrs)
+ ref_layer_cls = partial(self.reference_layer, **self.attrs)
+ layer_cls = partial(TransformerLayer, layer_type=self.layer_type, **self.attrs)

ref_layer, ref_params, ref_others = self._generate_layer(ref_layer_cls, inputs, ref_masks)
test_layer, test_params, test_others = self._generate_layer(layer_cls, inputs, test_masks)
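
With dtype dropped from the layer constructors, the test's dtype parameter presumably reaches the layers only through generate_inputs, so each case still exercises the intended precision via the activations. Roughly (the helper below is a stand-in for generate_inputs, not the test's real code):

import jax
import jax.numpy as jnp

def make_inputs(data_shape, dtype):
    # Stand-in helper: random activations in the requested dtype.
    return jax.random.normal(jax.random.PRNGKey(0), data_shape).astype(dtype)

x = make_inputs((2, 8, 64), jnp.bfloat16)
# A layer built without a dtype argument is now expected to compute in x.dtype.
assert x.dtype == jnp.bfloat16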