diff --git a/examples/jax/encoder/test_model_parallel_encoder.py b/examples/jax/encoder/test_model_parallel_encoder.py index f02cc562b5..228105d553 100644 --- a/examples/jax/encoder/test_model_parallel_encoder.py +++ b/examples/jax/encoder/test_model_parallel_encoder.py @@ -56,7 +56,6 @@ def __call__(self, x, mask, disable_dropout=False): self_attn_mask_type="padding", enable_relative_embedding=False, enable_sequence_parallel=self.enable_seq_paral, - dtype=jnp.bfloat16, ) x = te_Encoder()(x, attention_mask=mask, deterministic=disable_dropout) @@ -72,17 +71,15 @@ def __call__(self, x, mask, disable_dropout=False): features=256, kernel_axes=(NAMED_BROADCAST_AXIS, NAMED_TP_AXIS), bias_axes=(NAMED_TP_AXIS,), - dtype=jnp.bfloat16, )(x) x = te_flax.DenseGeneral( features=256, kernel_axes=(NAMED_TP_AXIS, NAMED_BROADCAST_AXIS), bias_axes=(NAMED_BROADCAST_AXIS,), - dtype=jnp.bfloat16, )(x) - x = nn.Dense(features=2, dtype=jnp.bfloat16)(x) + x = nn.Dense(features=2)(x) return x @@ -91,7 +88,7 @@ def train_step(state, inputs, masks, labels, var_collect, rngs): def loss_fn(var_collect, disable_dropout=False): logits = state.apply_fn(var_collect, inputs, masks, disable_dropout, rngs=rngs) - one_hot = jax.nn.one_hot(labels, 2) + one_hot = jax.nn.one_hot(labels.astype(jnp.int32), 2) loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=one_hot)) return loss, logits @@ -136,7 +133,7 @@ def eval_step(state, inputs, masks, labels, var_collect): def loss_fn(var_collect, disable_dropout=False): logits = state.apply_fn(var_collect, inputs, masks, disable_dropout) - one_hot = jax.nn.one_hot(labels, 2) + one_hot = jax.nn.one_hot(labels.astype(jnp.int32), 2) loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=one_hot)) return loss, logits diff --git a/examples/jax/encoder/test_multigpu_encoder.py b/examples/jax/encoder/test_multigpu_encoder.py index eb4a1d0afb..0dab636718 100644 --- a/examples/jax/encoder/test_multigpu_encoder.py +++ b/examples/jax/encoder/test_multigpu_encoder.py @@ -51,17 +51,16 @@ def __call__(self, x, mask, disable_dropout=False): layer_type=te_flax.TransformerLayerType.ENCODER, self_attn_mask_type="padding", enable_relative_embedding=False, - dtype=jnp.bfloat16, ) x = te_Encoder()(x, attention_mask=mask, deterministic=disable_dropout) x = x.reshape(x.shape[0], -1) - x = te_flax.DenseGeneral(features=256, dtype=jnp.bfloat16)(x) + x = te_flax.DenseGeneral(features=256)(x) - x = te_flax.DenseGeneral(features=256, dtype=jnp.bfloat16)(x) + x = te_flax.DenseGeneral(features=256)(x) - x = nn.Dense(features=2, dtype=jnp.bfloat16)(x) + x = nn.Dense(features=2)(x) return x @@ -70,7 +69,7 @@ def train_step(state, inputs, masks, labels, var_collect, rngs): def loss_fn(var_collect, disable_dropout=False): logits = state.apply_fn(var_collect, inputs, masks, disable_dropout, rngs=rngs) - one_hot = jax.nn.one_hot(labels, 2) + one_hot = jax.nn.one_hot(labels.astype(jnp.int32), 2) loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=one_hot)) return loss, logits @@ -115,7 +114,7 @@ def eval_step(state, inputs, masks, labels, var_collect): def loss_fn(var_collect, disable_dropout=False): logits = state.apply_fn(var_collect, inputs, masks, disable_dropout) - one_hot = jax.nn.one_hot(labels, 2) + one_hot = jax.nn.one_hot(labels.astype(jnp.int32), 2) loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=one_hot)) return loss, logits diff --git a/examples/jax/encoder/test_multiprocessing_encoder.py b/examples/jax/encoder/test_multiprocessing_encoder.py index 
91186a15c4..6522ed896a 100644 --- a/examples/jax/encoder/test_multiprocessing_encoder.py +++ b/examples/jax/encoder/test_multiprocessing_encoder.py @@ -57,7 +57,6 @@ def __call__(self, x, mask, disable_dropout=False): layer_type=te_flax.TransformerLayerType.ENCODER, self_attn_mask_type="padding", enable_relative_embedding=False, - dtype=jnp.bfloat16, ) x = te_Encoder()(x, attention_mask=mask, deterministic=disable_dropout) @@ -67,17 +66,15 @@ def __call__(self, x, mask, disable_dropout=False): features=256, kernel_axes=(NAMED_BROADCAST_AXIS, NAMED_TP_AXIS), bias_axes=(NAMED_TP_AXIS,), - dtype=jnp.bfloat16, )(x) x = te_flax.DenseGeneral( features=256, kernel_axes=(NAMED_TP_AXIS, NAMED_BROADCAST_AXIS), bias_axes=(NAMED_BROADCAST_AXIS,), - dtype=jnp.bfloat16, )(x) - x = nn.Dense(features=2, dtype=jnp.bfloat16)(x) + x = nn.Dense(features=2)(x) return x diff --git a/examples/jax/encoder/test_single_gpu_encoder.py b/examples/jax/encoder/test_single_gpu_encoder.py index dd1997fe6f..cfbd30b767 100644 --- a/examples/jax/encoder/test_single_gpu_encoder.py +++ b/examples/jax/encoder/test_single_gpu_encoder.py @@ -46,17 +46,16 @@ def __call__(self, x, mask, disable_dropout=False): layer_type=te_flax.TransformerLayerType.ENCODER, self_attn_mask_type="padding", enable_relative_embedding=False, - dtype=jnp.bfloat16, ) x = te_Encoder()(x, attention_mask=mask, deterministic=disable_dropout) x = x.reshape(x.shape[0], -1) - x = te_flax.DenseGeneral(features=256, dtype=jnp.bfloat16)(x) + x = te_flax.DenseGeneral(features=256)(x) - x = te_flax.DenseGeneral(features=256, dtype=jnp.bfloat16)(x) + x = te_flax.DenseGeneral(features=256)(x) - x = nn.Dense(features=2, dtype=jnp.bfloat16)(x) + x = nn.Dense(features=2)(x) return x @@ -66,7 +65,7 @@ def train_step(state, inputs, masks, labels, var_collect, rngs): def loss_fn(var_collect, disable_dropout=False): logits = state.apply_fn(var_collect, inputs, masks, disable_dropout, rngs=rngs) - one_hot = jax.nn.one_hot(labels, 2) + one_hot = jax.nn.one_hot(labels.astype(jnp.int32), 2) loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=one_hot)) return loss, logits @@ -112,7 +111,7 @@ def eval_step(state, inputs, masks, labels, var_collect): def loss_fn(var_collect, disable_dropout=False): logits = state.apply_fn(var_collect, inputs, masks, disable_dropout) - one_hot = jax.nn.one_hot(labels, 2) + one_hot = jax.nn.one_hot(labels.astype(jnp.int32), 2) loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=one_hot)) return loss, logits @@ -217,6 +216,7 @@ def train_and_evaluate(args): with te.fp8_autocast(enabled=args.use_fp8): encoder = Net(num_embed) + # nn.Embed requires integer token ids as inputs inputs = jnp.zeros(input_shape, dtype=jnp.int32) masks = jnp.zeros(mask_shape, dtype=jnp.uint8) var_collect = encoder.init(init_rngs, inputs, masks) diff --git a/examples/jax/mnist/test_single_gpu_mnist.py b/examples/jax/mnist/test_single_gpu_mnist.py index 54ecadeee8..9d8f51cc16 100644 --- a/examples/jax/mnist/test_single_gpu_mnist.py +++ b/examples/jax/mnist/test_single_gpu_mnist.py @@ -36,6 +36,8 @@ def __call__(self, x, disable_dropout=False): nn_Dense = te_flax.DenseGeneral else: nn_Dense = nn.Dense + # In TE modules, dtype controls param init; in flax.linen it controls computation + dtype = jnp.float32 if self.use_te else jnp.bfloat16 x = nn.Conv(features=32, kernel_size=(3, 3), strides=1, dtype=jnp.bfloat16)(x) x = nn.relu(x) @@ -44,11 +46,13 @@ def __call__(self, x, disable_dropout=False): x = nn.max_pool(x, window_shape=(2, 2), strides=(2, 2)) x =
nn.Dropout(rate=0.25)(x, deterministic=disable_dropout) x = x.reshape(x.shape[0], -1) - x = nn_Dense(features=128, dtype=jnp.bfloat16)(x) + assert x.dtype == jnp.bfloat16 + x = nn_Dense(features=128, dtype=dtype)(x) x = nn.relu(x) x = nn.Dropout(rate=0.5)(x, deterministic=disable_dropout) - x = nn_Dense(features=16, dtype=jnp.bfloat16)(x) - x = nn.Dense(features=10, dtype=jnp.bfloat16)(x) + x = nn_Dense(features=16, dtype=dtype)(x) + x = nn_Dense(features=10, dtype=dtype)(x) + assert x.dtype == jnp.bfloat16 return x diff --git a/tests/jax/test_distributed_layernorm_mlp.py b/tests/jax/test_distributed_layernorm_mlp.py index 87a5145c65..77b299e5bf 100644 --- a/tests/jax/test_distributed_layernorm_mlp.py +++ b/tests/jax/test_distributed_layernorm_mlp.py @@ -271,7 +271,6 @@ def _test_layernorm_mlp( transpose_batch_sequence=False, # input: [batch, seqlen, hidden] intermediate_dim=INTERMEDIATE, activations=activation_type, - dtype=dtype, use_bias=use_bias, ) params_single = ln_mlp_single.init(init_rngs, x) @@ -289,7 +288,6 @@ def _test_layernorm_mlp( transpose_batch_sequence=False, intermediate_dim=INTERMEDIATE, activations=activation_type, - dtype=dtype, scale_axes=(W_NO_SHARD_AXES,), ln_bias_axes=(W_NO_SHARD_AXES,), kernel_axes_1=(W_FSDP_AXES, W_JOINED_AXES, W_TP_AXES), diff --git a/tests/jax/test_layer.py b/tests/jax/test_layer.py index a67335236d..ed15913f38 100644 --- a/tests/jax/test_layer.py +++ b/tests/jax/test_layer.py @@ -265,8 +265,8 @@ def test_forward( """Test only the forward""" inputs, (ref_masks, test_masks) = self.generate_inputs(data_shape, dtype) - ref_layer_cls = partial(self.reference_layer, dtype=dtype, **self.attrs) - layer_cls = partial(TransformerLayer, layer_type=self.layer_type, dtype=dtype, **self.attrs) + ref_layer_cls = partial(self.reference_layer, **self.attrs) + layer_cls = partial(TransformerLayer, layer_type=self.layer_type, **self.attrs) ref_layer, ref_params, ref_others = self._generate_layer(ref_layer_cls, inputs, ref_masks) test_layer, test_params, test_others = self._generate_layer(layer_cls, inputs, test_masks) @@ -288,8 +288,8 @@ def test_backward( """Test forward and backward through value_and_grad()""" inputs, (ref_masks, test_masks) = self.generate_inputs(data_shape, dtype) - ref_layer_cls = partial(self.reference_layer, dtype=dtype, **self.attrs) - layer_cls = partial(TransformerLayer, layer_type=self.layer_type, dtype=dtype, **self.attrs) + ref_layer_cls = partial(self.reference_layer, **self.attrs) + layer_cls = partial(TransformerLayer, layer_type=self.layer_type, **self.attrs) ref_layer, ref_params, ref_others = self._generate_layer(ref_layer_cls, inputs, ref_masks) test_layer, test_params, test_others = self._generate_layer(layer_cls, inputs, test_masks) diff --git a/tests/jax/utils.py b/tests/jax/utils.py index 554def2c3f..dba7cb64fc 100644 --- a/tests/jax/utils.py +++ b/tests/jax/utils.py @@ -110,7 +110,7 @@ class DotProductAttention(nn.Module): Args: dropout_rate: dropout rate - dtype: the dtype of the computation (default: float32) + dtype: the data type used to allocate the initial parameters (default: float32). float32_logits: bool, if True then compute logits in float32 to avoid numerical issues with bfloat16. """ @@ -195,6 +195,7 @@ def __call__( attn_weights = attn_weights * multiplier attn_weights = attn_weights.reshape(attn_weights_with_groups_shape) + attn_weights = attn_weights.astype(value.dtype) # Take the linear combination of `value`. 
if self.transpose_batch_sequence: @@ -209,7 +210,7 @@ class DenseGeneral(nn.Module): Attributes: features: tuple with numbers of output features. axis: tuple with axes to apply the transformation on. - dtype: the dtype of the computation (default: float32). + dtype: the data type used to allocate the initial parameters (default: float32). kernel_init: initializer function for the weight matrix. use_bias: whether to add a bias to the output (default: False). bias_init: initializer function for the bias vector. @@ -226,7 +227,9 @@ class DenseGeneral(nn.Module): def __post_init__(self): if self.kernel_init is None: - self.kernel_init = nn.initializers.variance_scaling(1.0, "fan_in", "truncated_normal") + self.kernel_init = nn.initializers.variance_scaling( + 1.0, "fan_in", "truncated_normal", dtype=self.dtype + ) super().__post_init__() @nn.compact @@ -239,6 +242,7 @@ def __call__(self, inputs: Array) -> Array: Returns: The transformed input. """ + input_dtype = inputs.dtype features = _canonicalize_tuple(self.features) axis = _canonicalize_tuple(self.axis) @@ -248,23 +252,24 @@ def __call__(self, inputs: Array) -> Array: kernel_shape = tuple(inputs.shape[ax] for ax in axis) + features kernel_param_shape = (np.prod([inputs.shape[ax] for ax in axis]), np.prod(features)) kernel = nn_partitioning.param_with_axes( - "kernel", self.kernel_init, kernel_param_shape, jnp.float32, axes=self.kernel_axes + "kernel", self.kernel_init, kernel_param_shape, self.dtype, axes=self.kernel_axes ) - kernel = jnp.asarray(kernel, self.dtype) + kernel = jnp.asarray(kernel, input_dtype) kernel = jnp.reshape(kernel, kernel_shape) if self.use_bias: bias = nn_partitioning.param_with_axes( - "bias", self.bias_init, self.features, jnp.float32, axes=self.bias_axes + "bias", self.bias_init, self.features, self.dtype, axes=self.bias_axes ) - bias = bias.astype(self.dtype) + bias = bias.astype(input_dtype) else: bias = None contract_ind = tuple(range(0, len(axis))) y = lax.dot_general(inputs, kernel, ((axis, contract_ind), ((), ()))) + y = y.astype(input_dtype) if bias is not None: y += jnp.reshape(bias, (1,) * (y.ndim - 1) + (-1,)) @@ -281,7 +286,7 @@ class MlpBlock(nn.Module): kernel_init: Kernel function, passed to the dense layers. deterministic: Whether the dropout layers should be deterministic. intermediate_dropout_rate: Dropout rate used after the intermediate layers. - dtype: Type for the dense layer. + dtype: the data type used to allocate the initial parameters (default: float32). """ transpose_batch_sequence: bool @@ -296,7 +301,9 @@ class MlpBlock(nn.Module): def __post_init__(self): if self.kernel_init is None: - self.kernel_init = nn.initializers.variance_scaling(1.0, "fan_in", "truncated_normal") + self.kernel_init = nn.initializers.variance_scaling( + 1.0, "fan_in", "truncated_normal", dtype=self.dtype + ) super().__post_init__() @nn.compact @@ -358,6 +365,9 @@ def __call__(self, inputs, deterministic: bool = False): bias_axes="embed", name="wo", )(x) + assert ( + output.dtype == inputs.dtype + ), f"input.dtype={inputs.dtype}, output.dtype={output.dtype}" return output @@ -429,7 +439,7 @@ class MultiHeadAttention(nn.Module): should be divisible by the number of heads. num_gqa_groups: number of kv attention heads head_dim: dimension of each head. - dtype: the dtype of the computation. + dtype: the data type used to allocate the initial parameters (default: float32). dropout_rate: dropout rate kernel_init: initializer for the kernel of the Dense layers.
float32_logits: bool, if True then compute logits in float32 to avoid @@ -453,7 +463,9 @@ class MultiHeadAttention(nn.Module): def __post_init__(self): if self.kernel_init is None: - self.kernel_init = nn.initializers.variance_scaling(1.0, "fan_in", "normal") + self.kernel_init = nn.initializers.variance_scaling( + 1.0, "fan_in", "normal", dtype=self.dtype + ) if self.num_gqa_groups is None: self.num_gqa_groups = self.num_attention_heads super().__post_init__() @@ -738,6 +750,9 @@ def qkv_init(key, shape, dtype): dtype=self.dtype, name="out", )(x) + assert ( + inputs_q.dtype == inputs_kv.dtype == out.dtype + ), f"q.dtype={inputs_q.dtype}, kv.dtype={inputs_kv.dtype}, out.dtype={out.dtype}" return out @@ -763,13 +778,13 @@ def __post_init__(self): def __call__(self, x: jnp.ndarray) -> jnp.ndarray: """Applies layer normalization on the input.""" - x = jnp.asarray(x, jnp.float32) + input_dtype = x.dtype features = x.shape[-1] scale = nn_partitioning.param_with_axes( - "scale", self.scale_init, (features,), jnp.float32, axes=("embed",) + "scale", self.scale_init, (features,), self.dtype, axes=("embed",) ) - scale = jnp.asarray(scale, self.dtype) + scale = jnp.asarray(scale, input_dtype) if self.layernorm_type == "layernorm": mean = jnp.mean(x, axis=-1, keepdims=True) @@ -777,9 +792,9 @@ def __call__(self, x: jnp.ndarray) -> jnp.ndarray: y = (x - mean) * lax.rsqrt(var + self.epsilon) bias = nn_partitioning.param_with_axes( - "ln_bias", self.bias_init, (features,), jnp.float32, axes=("embed",) + "ln_bias", self.bias_init, (features,), self.dtype, axes=("embed",) ) - bias = jnp.asarray(bias, self.dtype) + bias = jnp.asarray(bias, input_dtype) if not self.zero_centered_gamma: z = y * scale + bias @@ -792,7 +807,8 @@ def __call__(self, x: jnp.ndarray) -> jnp.ndarray: y = x * lax.rsqrt(mean2 + self.epsilon) z = y * scale - return jnp.asarray(z, self.dtype) + assert z.dtype == x.dtype, f"output_dtype={z.dtype}, input_dtype={x.dtype}" + return z class RelativePositionBiases(nn.Module): @@ -805,7 +821,7 @@ class RelativePositionBiases(nn.Module): distance bucket. num_heads: Number of heads in the attention layer. Each head will get a different relative position weighting. - dtype: Type of arrays through this module. + dtype: the data type used to allocate the initial parameters (default: float32). embedding_init: initializer for relative embedding table. 
""" @@ -1087,6 +1103,7 @@ def __call__(self, inputs, encoder_mask=None, deterministic=False): dtype=self.dtype, name="output_layernorm", )(y) + assert y.dtype == inputs.dtype, f"output_dtype={y.dtype}, input_dtype={inputs.dtype}" return y @@ -1293,6 +1310,7 @@ def __call__( name="output_layernorm", )(z) + assert z.dtype == inputs.dtype, f"output_dtype={z.dtype}, input_dtype={inputs.dtype}" return z diff --git a/transformer_engine/jax/flax/module.py b/transformer_engine/jax/flax/module.py index 23bc8d3602..d814c2d4df 100644 --- a/transformer_engine/jax/flax/module.py +++ b/transformer_engine/jax/flax/module.py @@ -57,19 +57,15 @@ def _obtain_default_layernorm_scale_init_if_need(original_init, zero_centered_ga def _create_layernorm_parameters( - layernorm_type, shape, scale_init, scale_axes, bias_init, bias_axes, dtype, weight_dtype + layernorm_type, shape, scale_init, scale_axes, bias_init, bias_axes, input_dtype, dtype ): - scale = nn_partitioning.param_with_axes( - "scale", scale_init, shape, weight_dtype, axes=scale_axes - ) - scale = scale.astype(dtype) + scale = nn_partitioning.param_with_axes("scale", scale_init, shape, dtype, axes=scale_axes) + scale = scale.astype(input_dtype) layernorm_type = canonicalize_layernorm_type(layernorm_type) if layernorm_type == "layernorm": - bias = nn_partitioning.param_with_axes( - "ln_bias", bias_init, shape, weight_dtype, axes=bias_axes - ) - bias = bias.astype(dtype) + bias = nn_partitioning.param_with_axes("ln_bias", bias_init, shape, dtype, axes=bias_axes) + bias = bias.astype(input_dtype) else: assert layernorm_type == "rmsnorm" bias = None @@ -158,15 +154,15 @@ def __call__(self, inputs: Array, mask: Array = None, bias: Array = None) -> jnp heads = inputs.shape[1] q_seqlen = inputs.shape[2] k_seqlen = inputs.shape[3] - dtype = inputs.dtype + input_dtype = inputs.dtype logits = inputs if self.softmax_type is not SoftmaxType.SCALED and is_softmax_kernel_available( - self.softmax_type, batch, heads, q_seqlen, k_seqlen, inputs.dtype + self.softmax_type, batch, heads, q_seqlen, k_seqlen, input_dtype ): if bias is not None: - logits = logits + bias.astype(dtype) + logits = logits + bias.astype(input_dtype) mask_ = mask if self.softmax_type is not SoftmaxType.SCALED_MASKED: @@ -178,25 +174,27 @@ def __call__(self, inputs: Array, mask: Array = None, bias: Array = None) -> jnp if mask is not None: attention_bias = lax.select( mask > 0, - jnp.full(mask.shape, -1e10).astype(dtype), - jnp.full(mask.shape, 0.0).astype(dtype), + jnp.full(mask.shape, -1e10), + jnp.full(mask.shape, 0.0), ) + attention_bias = attention_bias.astype(input_dtype) if bias is not None: attention_bias = _combine_biases(attention_bias, bias) if attention_bias is not None: - logits = logits + attention_bias.astype(dtype) + logits = logits + attention_bias.astype(input_dtype) # For the case that self.softmax == SoftmaxType.SCALED_UPPER_TRIANG_MASKED # and kernel is unavailable, then try on pure scaled softmax custom calls. if is_softmax_kernel_available( - SoftmaxType.SCALED, batch, heads, q_seqlen, k_seqlen, dtype + SoftmaxType.SCALED, batch, heads, q_seqlen, k_seqlen, input_dtype ): outputs = softmax(logits, None, self.scale_factor, SoftmaxType.SCALED) else: outputs = jax_nn.softmax(logits * self.scale_factor) + assert input_dtype == outputs.dtype return outputs @@ -261,9 +259,7 @@ class LayerNorm(nn.Module): # pylint: disable=too-few-public-methods Optimization parameters ----------------------- dtype: jax.numpy.dtype, default = jax.numpy.float32 - The data type used for computation. 
- weight_dtype: jax.numpy.dtype, default = jax.numpy.float32 - The data type of the module parameters. + The data type used to allocate the initial parameters. transpose_batch_sequence : bool, default = False Indicate whether the input tensors were switched axis of batch and sequence length dimension. If set to True, the input tensors @@ -278,7 +274,6 @@ class LayerNorm(nn.Module): # pylint: disable=too-few-public-methods bias_init: Initializer = nn.initializers.zeros bias_axes: Tuple[str, ...] = ("embed",) dtype: DType = jnp.float32 - weight_dtype: DType = jnp.float32 transpose_batch_sequence: bool = False def __post_init__(self): @@ -303,7 +298,7 @@ def __call__(self, x: jnp.ndarray) -> jnp.ndarray: outputs : jax.numpy.ndarray Output tensors. """ - x = x.astype(self.dtype) + input_dtype = x.dtype features = x.shape[-1] scale, ln_bias = _create_layernorm_parameters( @@ -313,10 +308,10 @@ def __call__(self, x: jnp.ndarray) -> jnp.ndarray: self.scale_axes, self.bias_init, self.bias_axes, + input_dtype, self.dtype, - self.weight_dtype, ) - return layernorm( + out = layernorm( x, scale, ln_bias, @@ -324,6 +319,8 @@ def __call__(self, x: jnp.ndarray) -> jnp.ndarray: zero_centered_gamma=self.zero_centered_gamma, epsilon=self.epsilon, ) + assert out.dtype == input_dtype + return out class TransformerEngineBase(nn.Module): # pylint: disable=too-few-public-methods @@ -408,9 +405,7 @@ class DenseGeneral(TransformerEngineBase): Optimization parameters ----------------------- dtype: jax.numpy.dtype, default = jax.numpy.float32 - The data type used for computation. - weight_dtype: jax.numpy.dtype, default = jax.numpy.float32 - The data type of the module parameters. + The data type used to allocate the initial parameters. transpose_batch_sequence : bool, default = True Indicate whether the input tensors were switched axis of batch and sequence length dimension. If set to True, the input tensors @@ -428,13 +423,12 @@ class DenseGeneral(TransformerEngineBase): low_rank_adaptation_alpha: float = None axis: Union[Iterable[int], int] = -1 dtype: DType = jnp.float32 - weight_dtype: DType = jnp.float32 transpose_batch_sequence: bool = False def __post_init__(self): if self.kernel_init is None: self.kernel_init = nn.initializers.variance_scaling( - 1.0, "fan_in", "truncated_normal", dtype=self.weight_dtype + 1.0, "fan_in", "truncated_normal", dtype=self.dtype ) super().__post_init__() @@ -454,24 +448,25 @@ def __call__(self, inputs: Array) -> Array: Output tensors. 
""" + input_dtype = inputs.dtype features = _canonicalize_tuple(self.features) axis = _canonicalize_tuple(self.axis) - inputs = jnp.asarray(inputs, self.dtype) axis = _normalize_axes(axis, inputs.ndim) kernel_shape = tuple(inputs.shape[ax] for ax in axis) + features kernel_param_shape = (np.prod([inputs.shape[ax] for ax in axis]),) + features kernel = nn_partitioning.param_with_axes( - "kernel", self.kernel_init, kernel_shape, self.weight_dtype, axes=self.kernel_axes + "kernel", self.kernel_init, kernel_shape, self.dtype, axes=self.kernel_axes ) - kernel = kernel.astype(self.dtype) + if not FP8Helper.is_fp8_enabled(): + kernel = kernel.astype(input_dtype) if self.use_bias: bias = nn_partitioning.param_with_axes( - "bias", self.bias_init, features, self.weight_dtype, axes=self.bias_axes + "bias", self.bias_init, features, self.dtype, axes=self.bias_axes ) - bias = bias.astype(self.dtype) + bias = bias.astype(input_dtype) else: bias = None @@ -500,11 +495,11 @@ def __call__(self, inputs: Array) -> Array: "lora_a_kernel", self.kernel_init, lora_a_kernel_init_shape, - self.weight_dtype, + self.dtype, axes=lora_a_kernel_axes, ) lora_a_kernel = jnp.reshape(lora_a_kernel, lora_a_kernel_shape) - lora_a_kernel = lora_a_kernel.astype(self.dtype) + lora_a_kernel = lora_a_kernel.astype(input_dtype) lora_b_kernel_shape = (*features[:-1], self.low_rank_adaptation_dim, features[-1]) lora_b_kernel_axes = (None,) * len(lora_b_kernel_shape) @@ -512,10 +507,10 @@ def __call__(self, inputs: Array) -> Array: "lora_b_kernel", nn.initializers.zeros, lora_b_kernel_shape, - self.weight_dtype, + self.dtype, axes=lora_b_kernel_axes, ) - lora_b_kernel = lora_b_kernel.astype(self.dtype) + lora_b_kernel = lora_b_kernel.astype(input_dtype) y += _apply_low_rank_adaptation( inputs, axis, features, lora_a_kernel, lora_b_kernel, self.low_rank_adaptation_alpha @@ -524,6 +519,8 @@ def __call__(self, inputs: Array) -> Array: if bias is not None: bias_shape = (1,) * (y.ndim - bias.ndim) + bias.shape y += jnp.reshape(bias, bias_shape) + + assert y.dtype == input_dtype return y @@ -606,9 +603,7 @@ class LayerNormDenseGeneral(TransformerEngineBase): Optimization parameters ----------------------- dtype: jax.numpy.dtype, default = jax.numpy.float32 - The data type used for computation. - weight_dtype: jax.numpy.dtype, default = jax.numpy.float32 - The data type of the module parameters. + The data type used to allocate the initial parameters. transpose_batch_sequence : bool, default = True Indicate whether the input tensors were switched axis of batch and sequence length dimension. If set to True, the input tensors @@ -638,7 +633,6 @@ class LayerNormDenseGeneral(TransformerEngineBase): low_rank_adaptation_alpha: float = None axis: Union[Iterable[int], int] = -1 dtype: DType = jnp.float32 - weight_dtype: DType = jnp.float32 transpose_batch_sequence: bool = True layernorm_input_axes: Tuple[str, ...] = None dot_input_axes: Tuple[str, ...] = None @@ -650,7 +644,7 @@ def __post_init__(self): 1.0, "fan_in", "truncated_normal", - dtype=self.weight_dtype, + dtype=self.dtype, ) self.scale_init = _obtain_default_layernorm_scale_init_if_need( self.scale_init, @@ -677,6 +671,7 @@ def __call__(self, inputs: Array) -> Array: If :attr:`return_layernorm_output=False`, then this would be None. 
""" + input_dtype = inputs.dtype ln_output = None fuse_layernorm = ( @@ -684,7 +679,6 @@ def __call__(self, inputs: Array) -> Array: and not self.return_layernorm_output and self.enable_layernorm ) - inputs = inputs.astype(self.dtype) if self.enable_layernorm: inputs = with_sharding_constraint_by_logical_axes(inputs, self.layernorm_input_axes) @@ -699,8 +693,8 @@ def __call__(self, inputs: Array) -> Array: self.scale_axes, self.ln_bias_init, self.ln_bias_axes, + input_dtype, self.dtype, - self.weight_dtype, ) if not fuse_layernorm: @@ -730,9 +724,10 @@ def __call__(self, inputs: Array) -> Array: kernel_shape = tuple(y.shape[ax] for ax in axis) + features kernel_param_shape = (np.prod([inputs.shape[ax] for ax in axis]),) + features kernel = nn_partitioning.param_with_axes( - "kernel", self.kernel_init, kernel_shape, self.weight_dtype, axes=self.kernel_axes + "kernel", self.kernel_init, kernel_shape, self.dtype, axes=self.kernel_axes ) - kernel = kernel.astype(self.dtype) + if not FP8Helper.is_fp8_enabled(): + kernel = kernel.astype(input_dtype) contract_ind = tuple(range(0, len(axis))) @@ -775,11 +770,11 @@ def __call__(self, inputs: Array) -> Array: "lora_a_kernel", self.kernel_init, lora_a_kernel_init_shape, - self.weight_dtype, + self.dtype, axes=lora_a_kernel_axes, ) lora_a_kernel = jnp.reshape(lora_a_kernel, lora_a_kernel_shape) - lora_a_kernel = lora_a_kernel.astype(self.dtype) + lora_a_kernel = lora_a_kernel.astype(input_dtype) lora_b_kernel_shape = (*features[:-1], self.low_rank_adaptation_dim, features[-1]) lora_b_kernel_axes = (None,) * len(lora_b_kernel_shape) @@ -787,10 +782,10 @@ def __call__(self, inputs: Array) -> Array: "lora_b_kernel", nn.initializers.zeros, lora_b_kernel_shape, - self.weight_dtype, + self.dtype, axes=lora_b_kernel_axes, ) - lora_b_kernel = lora_b_kernel.astype(self.dtype) + lora_b_kernel = lora_b_kernel.astype(input_dtype) z += _apply_low_rank_adaptation( y, axis, features, lora_a_kernel, lora_b_kernel, self.low_rank_adaptation_alpha @@ -799,9 +794,9 @@ def __call__(self, inputs: Array) -> Array: bias = None if self.use_bias: bias = nn_partitioning.param_with_axes( - "bias", self.bias_init, features, self.weight_dtype, axes=self.bias_axes + "bias", self.bias_init, features, self.dtype, axes=self.bias_axes ) - bias = bias.astype(self.dtype) + bias = bias.astype(input_dtype) if bias is not None: bias_shape = (1,) * (z.ndim - bias.ndim) + bias.shape @@ -810,6 +805,7 @@ def __call__(self, inputs: Array) -> Array: if self.depth_scaling is not None: z = z / self.depth_scaling + assert z.dtype == input_dtype return z, ln_output # dense_output, layer_norm_output @@ -915,9 +911,7 @@ class LayerNormMLP(TransformerEngineBase): Optimization parameters ----------------------- dtype: jax.numpy.dtype, default = jax.numpy.float32 - The data type used for computation. - weight_dtype: jax.numpy.dtype, default = jax.numpy.float32 - The data type of the module parameters. + The data type used to allocate the initial parameters. transpose_batch_sequence : bool, default = True Indicate whether the input tensors were switched axis of batch and sequence length dimension. If set to True, the input tensors @@ -950,7 +944,6 @@ class LayerNormMLP(TransformerEngineBase): low_rank_adaptation_alpha: float = None axis: Union[Iterable[int], int] = -1 dtype: DType = jnp.float32 - weight_dtype: DType = jnp.float32 transpose_batch_sequence: bool = True layernorm_input_axes: Tuple[str, ...] = None dot_1_input_axes: Tuple[str, ...] 
= None @@ -959,7 +952,7 @@ class LayerNormMLP(TransformerEngineBase): def __post_init__(self): if self.kernel_init is None: self.kernel_init = nn.initializers.variance_scaling( - 1.0, "fan_in", "truncated_normal", dtype=self.weight_dtype + 1.0, "fan_in", "truncated_normal", dtype=self.dtype ) self.scale_init = _obtain_default_layernorm_scale_init_if_need( self.scale_init, @@ -988,6 +981,7 @@ def __call__(self, inputs: Array, deterministic: bool = False) -> Array: If :attr:`return_layernorm_output=False`, then this would be None. """ + input_dtype = inputs.dtype ln_output = None fuse_layernorm = ( @@ -996,8 +990,6 @@ def __call__(self, inputs: Array, deterministic: bool = False) -> Array: and self.enable_layernorm ) - inputs = inputs.astype(self.dtype) - gated_act_pool = [ ("gelu", "linear"), ("silu", "linear"), @@ -1035,8 +1027,8 @@ def __call__(self, inputs: Array, deterministic: bool = False) -> Array: self.scale_axes, self.ln_bias_init, self.ln_bias_axes, + input_dtype, self.dtype, - self.weight_dtype, ) if not fuse_layernorm: @@ -1083,11 +1075,12 @@ def kernel_1_init(key, num_kernels, stack_axis, *init_args): num_activations, -2, kernel_1_each_shape, - self.weight_dtype, + self.dtype, axes=self.kernel_axes_1, ) kernel_1 = jnp.reshape(kernel_1, kernel_1_shape) - kernel_1 = kernel_1.astype(self.dtype) + if not FP8Helper.is_fp8_enabled(): + kernel_1 = kernel_1.astype(input_dtype) hidden_size = inputs.shape[-1] hidden_size_tuple = _canonicalize_tuple(hidden_size) kernel_2_shape = (self.intermediate_dim,) + hidden_size_tuple @@ -1096,11 +1089,12 @@ def kernel_1_init(key, num_kernels, stack_axis, *init_args): "wo_kernel", self.kernel_init, kernel_2_param_shape, - self.weight_dtype, + self.dtype, axes=self.kernel_axes_2, ) kernel_2 = jnp.reshape(kernel_2, kernel_2_shape) - kernel_2 = kernel_2.astype(self.dtype) + if not FP8Helper.is_fp8_enabled(): + kernel_2 = kernel_2.astype(input_dtype) contract_ind = tuple(range(0, len(axis))) ffn1_ckpt_name = "ffn1" @@ -1115,20 +1109,20 @@ def kernel_1_init(key, num_kernels, stack_axis, *init_args): "wi_bias", self.bias_init, bias_1_shape, - self.weight_dtype, + self.dtype, axes=self.bias_axes_1, ) - bias_1 = bias_1.astype(self.dtype) + bias_1 = bias_1.astype(input_dtype) bias_2_shape = (hidden_size,) bias_2 = nn_partitioning.param_with_axes( "wo_bias", self.bias_init, bias_2_shape, - self.weight_dtype, + self.dtype, axes=self.bias_axes_2, ) - bias_2 = bias_2.astype(self.dtype) + bias_2 = bias_2.astype(input_dtype) else: bias_1 = None bias_2 = None @@ -1195,11 +1189,11 @@ def kernel_1_init(key, num_kernels, stack_axis, *init_args): num_activations, -2, wi_lora_a_kernel_init_each_shape, - self.weight_dtype, + self.dtype, axes=wi_lora_a_kernel_axes, ) wi_lora_a_kernel = jnp.reshape(wi_lora_a_kernel, wi_lora_a_kernel_shape) - wi_lora_a_kernel = wi_lora_a_kernel.astype(self.dtype) + wi_lora_a_kernel = wi_lora_a_kernel.astype(input_dtype) wi_lora_b_kernel_shape = ( num_activations, @@ -1211,10 +1205,10 @@ def kernel_1_init(key, num_kernels, stack_axis, *init_args): "wi_lora_b_kernel", nn.initializers.zeros, wi_lora_b_kernel_shape, - self.weight_dtype, + self.dtype, axes=wi_lora_b_kernel_axes, ) - wi_lora_b_kernel = wi_lora_b_kernel.astype(self.dtype) + wi_lora_b_kernel = wi_lora_b_kernel.astype(input_dtype) x += _apply_low_rank_adaptation( y, @@ -1231,11 +1225,11 @@ def kernel_1_init(key, num_kernels, stack_axis, *init_args): "wi_bias", self.bias_init, intermediate_dim, - self.weight_dtype, + self.dtype, axes=self.bias_axes_1, ) bias_1_shape = (1,) * (x.ndim - 
bias_1.ndim) + bias_1.shape - bias_1 = bias_1.astype(self.dtype) + bias_1 = bias_1.astype(input_dtype) x += jnp.reshape(bias_1, bias_1_shape) x = checkpoint_name(x, ffn1_ckpt_name) @@ -1250,7 +1244,7 @@ def kernel_1_init(key, num_kernels, stack_axis, *init_args): z = functools.reduce(operator.mul, activations) # Remove act axis z = jnp.reshape(z, (*z.shape[:-2], -1)) - z = z.astype(self.dtype) + z = z.astype(input_dtype) z = nn.Dropout( rate=self.intermediate_dropout_rate, @@ -1259,7 +1253,7 @@ def kernel_1_init(key, num_kernels, stack_axis, *init_args): )(z, deterministic=deterministic) z = with_sharding_constraint_by_logical_axes(z, self.dot_2_input_axes) - z = z.astype(self.dtype) + z = z.astype(input_dtype) # DenseGeneral 2 out = type_safe_dot_general( @@ -1273,10 +1267,10 @@ def kernel_1_init(key, num_kernels, stack_axis, *init_args): "wo_lora_a_kernel", self.kernel_init, wo_lora_a_kernel_shape, - self.weight_dtype, + self.dtype, axes=wo_lora_a_kernel_axes, ) - wo_lora_a_kernel = wo_lora_a_kernel.astype(self.dtype) + wo_lora_a_kernel = wo_lora_a_kernel.astype(input_dtype) wo_lora_b_kernel_shape = (self.low_rank_adaptation_dim, hidden_size) wo_lora_b_kernel_axes = (None,) * len(wo_lora_b_kernel_shape) @@ -1284,10 +1278,10 @@ def kernel_1_init(key, num_kernels, stack_axis, *init_args): "wo_lora_b_kernel", nn.initializers.zeros, wo_lora_b_kernel_shape, - self.weight_dtype, + self.dtype, axes=wo_lora_b_kernel_axes, ) - wo_lora_b_kernel = wo_lora_b_kernel.astype(self.dtype) + wo_lora_b_kernel = wo_lora_b_kernel.astype(input_dtype) out += _apply_low_rank_adaptation( z, @@ -1304,12 +1298,13 @@ def kernel_1_init(key, num_kernels, stack_axis, *init_args): "wo_bias", self.bias_init, (hidden_size,), - self.weight_dtype, + self.dtype, axes=self.bias_axes_2, ) - bias_2 = bias_2.astype(self.dtype) + bias_2 = bias_2.astype(input_dtype) out += jnp.reshape(bias_2, (1,) * (out.ndim - 1) + (-1,)) out = checkpoint_name(out, ffn2_ckpt_name) + assert out.dtype == input_dtype return out, ln_output # Output, layner_norm_output diff --git a/transformer_engine/jax/flax/transformer.py b/transformer_engine/jax/flax/transformer.py index 100557404b..69fb74ba31 100644 --- a/transformer_engine/jax/flax/transformer.py +++ b/transformer_engine/jax/flax/transformer.py @@ -115,7 +115,6 @@ class _UnfusedDotProductAttention(nn.Module): # pylint: disable=too-few-public- attn_mask_type: AttnMaskType = AttnMaskType.CAUSAL_MASK attn_bias_type: Optional[AttnBiasType] = None dtype: DType = jnp.float32 - weight_dtype: DType = jnp.float32 float32_logits: bool = False scale_factor: Optional[float] = None transpose_batch_sequence: bool = True @@ -143,6 +142,8 @@ def __call__( assert key.shape[-2] == value.shape[-2], "k, v num_attention_heads must match." assert query.shape[-1] == key.shape[-1], "q, k head_dim must match." + input_dtype = query.dtype + if self.scale_factor is None: scale_factor = 1.0 / sqrt(query.shape[-1]) else: @@ -150,8 +151,8 @@ def __call__( del self.scale_factor if self.float32_logits: - query = query.astype(self.dtype) - key = key.astype(self.dtype) + query = query.astype(jnp.float32) + key = key.astype(jnp.float32) h_q, h_kv = query.shape[-2], key.shape[-2] # The generated GQA kernels are slower than normal MHA kernels even when h_q == h_kv. # Therefore, we have to maintain two code paths. 
@@ -234,7 +235,7 @@ def convert_to_softmax_type(attn_mask_type, mask): attn_weights = Softmax(softmax_type=softmax_type, scale_factor=fused_scale_factor)( attn_weights, mask, bias - ).astype(self.dtype) + ).astype(input_dtype) if is_gqa: attn_weights = attn_weights.reshape(attn_weights_with_groups_shape) @@ -244,9 +245,12 @@ def convert_to_softmax_type(attn_mask_type, mask): dropout_shape = list(attn_weights.shape) # TODO(rewang): add attention dropout broadcast dimension arguments for users keep = jax_random.bernoulli(dropout_rng, keep_prob, dropout_shape) - multiplier = keep.astype(attn_weights.dtype) / jnp.asarray(keep_prob, dtype=self.dtype) + multiplier = keep.astype(input_dtype) / jnp.asarray(keep_prob, dtype=input_dtype) attn_weights = attn_weights * multiplier + assert ( + attn_weights.dtype == input_dtype + ), f"output={attn_weights.dtype}, input={input_dtype}" if self.transpose_batch_sequence: if is_gqa: return jnp.einsum("bhgqk,kbhd->qbhgd", attn_weights, value).reshape(query.shape) @@ -254,6 +258,7 @@ def convert_to_softmax_type(attn_mask_type, mask): if is_gqa: return jnp.einsum("bhgqk,bkhd->bqhgd", attn_weights, value).reshape(query.shape) + return jnp.einsum("bhqk,bkhd->bqhd", attn_weights, value) @@ -262,7 +267,6 @@ class _FusedDotProductAttention(nn.Module): # pylint: disable=too-few-public-me attn_mask_type: AttnMaskType = AttnMaskType.CAUSAL_MASK attn_bias_type: Optional[AttnBiasType] = None dtype: DType = jnp.float32 - weight_dtype: DType = jnp.float32 qkv_layout: QKVLayout = QKVLayout.BSHD_BSHD_BSHD scale_factor: Optional[float] = None transpose_batch_sequence: bool = False @@ -372,6 +376,7 @@ def __call__( if self.transpose_batch_sequence: x = x.transpose([1, 0, 2, 3]) + assert x.dtype == query.dtype return x @@ -492,9 +497,7 @@ class DotProductAttention(nn.Module): # pylint: disable=too-few-public-methods Optimization parameters ----------------------- dtype: jax.numpy.dtype, default = jax.numpy.float32 - The data type used for computation. - weight_dtype: jax.numpy.dtype, default = jax.numpy.float32 - The data type of the module parameters. + The data type used to allocate the initial parameters. """ head_dim: int @@ -504,7 +507,6 @@ class DotProductAttention(nn.Module): # pylint: disable=too-few-public-methods attn_mask_type: AttnMaskType = "causal" attn_bias_type: AttnBiasType = None dtype: DType = jnp.float32 - weight_dtype: DType = jnp.float32 dropout_rng_name: str = "dropout" float32_logits: bool = False qkv_layout: str = "bshd_bshd_bshd" @@ -552,6 +554,7 @@ def __call__( outputs: jax.numpy.ndarray Output tensors. 
""" + input_dtype = query.dtype if mask is not None: if sequence_descriptor is not None: @@ -642,7 +645,6 @@ def __call__( attn_mask_type=attn_mask_type, attn_bias_type=attn_bias_type, dtype=self.dtype, - weight_dtype=self.weight_dtype, float32_logits=self.float32_logits, scale_factor=scale_factor, transpose_batch_sequence=self.transpose_batch_sequence, @@ -662,7 +664,6 @@ def __call__( attn_mask_type=attn_mask_type, attn_bias_type=attn_bias_type, dtype=self.dtype, - weight_dtype=self.weight_dtype, scale_factor=scale_factor, transpose_batch_sequence=self.transpose_batch_sequence, qkv_layout=qkv_layout, @@ -679,7 +680,7 @@ def __call__( dropout_rng=dropout_rng, deterministic=deterministic, ) - + assert x.dtype == input_dtype, f"output_dtype={x.dtype}, input_dtype={input_dtype}" return x @@ -720,10 +721,10 @@ def alternate_impl(): sin, cos = generate_sin_cos(time_scales) x1, x2 = jnp.split(x, 2, axis=-1) - part_1 = (x1 * cos - x2 * sin).astype(x.dtype) - part_2 = (x2 * cos + x1 * sin).astype(x.dtype) + part_1 = (x1 * cos - x2 * sin).astype(dtype=x.dtype) + part_2 = (x2 * cos + x1 * sin).astype(dtype=x.dtype) - output = jnp.concatenate([part_1, part_2], axis=-1) + output = jnp.concatenate([part_1, part_2], axis=-1, dtype=x.dtype) return output def consecutive_impl(): @@ -928,8 +929,6 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods ----------------------- dtype: jax.numpy.dtype, default = jax.numpy.float32 The data type used for computation. - weight_dtype: jax.numpy.dtype, default = jax.numpy.float32 - The data type of the module parameters. fuse_qkv_params: bool, default = True If set to True, this module exposes a single fused parameter for query-key-value for self-attention and key-value for @@ -975,7 +974,6 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods low_rank_adaptation_dim: int = 32 low_rank_adaptation_alpha: float = None dtype: DType = jnp.float32 - weight_dtype: DType = jnp.float32 fuse_qkv_params: bool = True transpose_batch_sequence: bool = True enable_sequence_parallel: bool = False @@ -1026,7 +1024,7 @@ def __post_init__(self): if self.kernel_init is None: self.kernel_init = nn.initializers.variance_scaling( - 1.0, "fan_in", "normal", dtype=self.weight_dtype + 1.0, "fan_in", "normal", dtype=self.dtype ) if self.num_gqa_groups is None: self.num_gqa_groups = self.num_attention_heads @@ -1071,6 +1069,11 @@ def __call__( Output tensors. 
""" + assert ( + inputs_q.dtype == inputs_kv.dtype + ), f"q.dtype = {inputs_q.dtype}, kv.dtype = {inputs_kv.dtype}" + input_dtype = inputs_q.dtype + def query_init(*args): depth_scaling = jnp.sqrt(self.head_dim).astype(self.dtype) return self.kernel_init(*args) / (depth_scaling if self.scaled_query_init else 1.0) @@ -1154,7 +1157,6 @@ def generate_batch_seqlen_logical_axes(is_sharded_seq): dot_input_axes=inputs_logical_axes_no_sp, name="qkv", dtype=self.dtype, - weight_dtype=self.weight_dtype, )(inputs_q) qkv_proj = checkpoint_name(qkv_proj, "combined_qkv_proj") qkv_layout = QKVLayout.BS3HD @@ -1178,7 +1180,6 @@ def generate_batch_seqlen_logical_axes(is_sharded_seq): low_rank_adaptation_dim=self.low_rank_adaptation_dim, low_rank_adaptation_alpha=self.low_rank_adaptation_alpha, dtype=self.dtype, - weight_dtype=self.weight_dtype, kernel_init=query_init, layernorm_input_axes=inputs_logical_axes_maybe_sp, dot_input_axes=inputs_logical_axes_no_sp, @@ -1203,7 +1204,6 @@ def generate_batch_seqlen_logical_axes(is_sharded_seq): low_rank_adaptation_alpha=self.low_rank_adaptation_alpha, name="kv", dtype=self.dtype, - weight_dtype=self.weight_dtype, )(inputs_kv) kv_proj = checkpoint_name(kv_proj, "combined_kv_proj") qkv_layout = QKVLayout.BSHD_BS2HD @@ -1221,7 +1221,6 @@ def generate_batch_seqlen_logical_axes(is_sharded_seq): low_rank_adaptation_dim=self.low_rank_adaptation_dim, low_rank_adaptation_alpha=self.low_rank_adaptation_alpha, dtype=self.dtype, - weight_dtype=self.weight_dtype, ) query, ln_out = LayerNormDenseGeneral( enable_layernorm=self.input_layernorm, @@ -1242,7 +1241,6 @@ def generate_batch_seqlen_logical_axes(is_sharded_seq): low_rank_adaptation_dim=self.low_rank_adaptation_dim, low_rank_adaptation_alpha=self.low_rank_adaptation_alpha, dtype=self.dtype, - weight_dtype=self.weight_dtype, kernel_init=query_init, layernorm_input_axes=inputs_logical_axes_maybe_sp, dot_input_axes=inputs_logical_axes_no_sp, @@ -1253,9 +1251,11 @@ def generate_batch_seqlen_logical_axes(is_sharded_seq): assert ln_out is not None inputs_kv = ln_out + query = query.astype(input_dtype) key = kv_projection(kernel_init=self.kernel_init, name="key")(inputs_kv) - key = key.astype(self.dtype) + key = key.astype(input_dtype) value = kv_projection(kernel_init=self.kernel_init, name="value")(inputs_kv) + value = value.astype(input_dtype) query = checkpoint_name(query, "query_proj") key = checkpoint_name(key, "key_proj") value = checkpoint_name(value, "value_proj") @@ -1380,7 +1380,6 @@ def generate_batch_seqlen_logical_axes(is_sharded_seq): attn_bias_type=self.attn_bias_type, attention_dropout=self.attention_dropout, dtype=self.dtype, - weight_dtype=self.weight_dtype, dropout_rng_name=self.dropout_rng_name, float32_logits=self.float32_logits, qkv_layout=qkv_layout.name, @@ -1406,11 +1405,13 @@ def generate_batch_seqlen_logical_axes(is_sharded_seq): low_rank_adaptation_dim=self.low_rank_adaptation_dim, low_rank_adaptation_alpha=self.low_rank_adaptation_alpha, dtype=self.dtype, - weight_dtype=self.weight_dtype, name="out", )(x) out = checkpoint_name(out, "out_proj") + assert ( + inputs_q.dtype == out.dtype + ), f"output_dtype={out.dtype}, input_dtype={inputs_q.dtype}" return out, ln_out @@ -1435,9 +1436,7 @@ class RelativePositionBiases(nn.Module): # pylint: disable=too-few-public-metho Optimization parameters ----------------------- dtype: jax.numpy.dtype, default = jax.numpy.float32 - The data type used for computation. - weight_dtype: jax.numpy.dtype, default = jax.numpy.float32 - The data type of the module parameters. 
+ The data type used to allocate the initial parameters. """ num_buckets: int @@ -1446,7 +1445,6 @@ class RelativePositionBiases(nn.Module): # pylint: disable=too-few-public-metho embedding_init: Callable[..., Array] = nn.linear.default_embed_init embedding_axes: Tuple[str, ...] = ("heads", "relpos_buckets") dtype: DType = jnp.float32 - weight_dtype: DType = jnp.float32 @nn.compact def __call__(self, q_seqlen, k_seqlen, bidirectional=True): @@ -1499,7 +1497,7 @@ def __call__(self, q_seqlen, k_seqlen, bidirectional=True): "rel_embedding", self.embedding_init, (self.num_attention_heads, self.num_buckets), - self.weight_dtype, + self.dtype, axes=self.embedding_axes, ) @@ -1672,9 +1670,7 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods Optimization parameters ----------------------- dtype: jax.numpy.dtype, default = jax.numpy.float32 - The data type used for computation. - weight_dtype: jax.numpy.dtype, default = jax.numpy.float32 - The data type of the module parameters. + The data type used to allocate the initial parameters. drop_path: float, default = 0.0 When > 0.0, applies stochastic depth per sample in the main path of the residual block. @@ -1727,7 +1723,6 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods low_rank_adaptation_dim: int = 32 low_rank_adaptation_alpha: float = None dtype: DType = jnp.float32 - weight_dtype: DType = jnp.float32 drop_path: float = 0.0 fuse_qkv_params: bool = True transpose_batch_sequence: bool = False @@ -1739,11 +1734,11 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods def __post_init__(self): if self.mha_kernel_init is None: self.mha_kernel_init = nn.initializers.variance_scaling( - 1.0, "fan_in", "normal", dtype=self.weight_dtype + 1.0, "fan_in", "normal", dtype=self.dtype ) if self.mlp_kernel_init is None: self.mlp_kernel_init = nn.initializers.variance_scaling( - 1.0, "fan_in", "truncated_normal", dtype=self.weight_dtype + 1.0, "fan_in", "truncated_normal", dtype=self.dtype ) if self.num_gqa_groups is None: self.num_gqa_groups = self.num_attention_heads @@ -1793,9 +1788,7 @@ def __call__( outputs: jax.numpy.ndarray Output tensors. """ - - inputs = inputs.astype(self.dtype) - + input_dtype = inputs.dtype assert ( self.layer_type in TransformerLayerType ), f"layer_type should be one of TransformerLayerType, but got {self.layer_type}." 
@@ -1833,8 +1826,9 @@ def generate_batch_seqlen_logical_axes(is_shared_seq=None): max_distance=128, num_attention_heads=self.num_attention_heads, dtype=self.dtype, - weight_dtype=self.weight_dtype, - embedding_init=nn.initializers.variance_scaling(1.0, "fan_avg", "uniform"), + embedding_init=nn.initializers.variance_scaling( + 1.0, "fan_avg", "uniform", dtype=self.dtype + ), name="relpos_bias", ) else: @@ -1867,7 +1861,6 @@ def generate_batch_seqlen_logical_axes(is_shared_seq=None): x, ln_out = MultiHeadAttention( num_attention_heads=self.num_attention_heads, dtype=self.dtype, - weight_dtype=self.weight_dtype, head_dim=head_dim, num_gqa_groups=self.num_gqa_groups, transpose_batch_sequence=self.transpose_batch_sequence, @@ -1946,7 +1939,6 @@ def hidden_dropout(x, deterministic): y, ln_out = MultiHeadAttention( num_attention_heads=self.num_attention_heads, dtype=self.dtype, - weight_dtype=self.weight_dtype, head_dim=head_dim, num_gqa_groups=self.num_gqa_groups, transpose_batch_sequence=self.transpose_batch_sequence, @@ -2012,7 +2004,6 @@ def hidden_dropout(x, deterministic): intermediate_dropout_rate=self.intermediate_dropout, intermediate_hidden_dropout_dims=self.intermediate_dropout_dims, dtype=self.dtype, - weight_dtype=self.weight_dtype, scale_axes=(W_NO_SHARD_AXES,), ln_bias_axes=(W_NO_SHARD_AXES,), kernel_init=self.mlp_kernel_init, @@ -2062,8 +2053,7 @@ def hidden_dropout(x, deterministic): bias_axes=(W_NO_SHARD_AXES,), transpose_batch_sequence=self.transpose_batch_sequence, dtype=self.dtype, - weight_dtype=self.weight_dtype, name="output_layernorm", )(z) - + assert z.dtype == input_dtype, f"output_dtype={z.dtype}, input_dtype={input_dtype}" return z
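
The net effect of these changes is that `dtype` on the TE Flax modules now only sets the dtype in which parameters are initialized, while the computation dtype follows the input tensors. A minimal sketch of that contract, assuming the modified te_flax.DenseGeneral above (shapes and the RNG key are illustrative, not taken from the diff):

import jax
import jax.numpy as jnp
import transformer_engine.jax.flax as te_flax

# dtype defaults to jnp.float32, so the parameters are allocated in float32.
layer = te_flax.DenseGeneral(features=256)

# The computation dtype follows the input: bf16 in, bf16 out.
x = jnp.ones((8, 128), dtype=jnp.bfloat16)
variables = layer.init(jax.random.PRNGKey(0), x)
assert variables["params"]["kernel"].dtype == jnp.float32  # param-init dtype

y = layer.apply(variables, x)
assert y.dtype == jnp.bfloat16  # output matches the input dtype, not the param dtype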