Enable to assign different parameters dtype during training #1037

Merged
merged 2 commits into from
Mar 7, 2025

Contributor

@jialingt jialingt commented Mar 5, 2025

This PR enables train_dtype to accept either a jnp.dtype or a PerParamFn[jnp.dtype]:

  1. jnp.dtype, where both float inputs and model parameters will be cast to this dtype.

  2. ConfigOr[PerParamFn[jnp.dtype]], allowing different dtypes to be applied to different parameters during training.
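
The two accepted forms can be illustrated with a minimal sketch. This is not the AXLearn implementation: cast_floats, bf16_except_norm, and the parameter names are hypothetical, and the sketch assumes that a per-parameter function returns a tree of dtypes with the same structure as the parameters and that non-float leaves are left untouched.

```python
import jax
import jax.numpy as jnp


def _as_dtype_or_none(x):
    """Returns a dtype if `x` can be interpreted as one, otherwise None."""
    try:
        return jnp.dtype(x)
    except TypeError:
        return None


def cast_floats(params, train_dtype):
    """Hypothetical helper: casts float leaves of `params` per `train_dtype`.

    `train_dtype` is either a single dtype (applied to every float leaf) or a
    callable mapping the parameter tree to a tree of dtypes of the same shape.
    """
    uniform = _as_dtype_or_none(train_dtype)
    if uniform is not None:
        # Case 1: one dtype for all parameters.
        dtypes = jax.tree_util.tree_map(lambda _: uniform, params)
    else:
        # Case 2: a per-parameter function decides the dtype of each leaf.
        dtypes = train_dtype(params)

    def cast(x, dtype):
        # Assumption: only floating-point leaves are cast.
        if jnp.issubdtype(x.dtype, jnp.floating):
            return x.astype(dtype)
        return x

    return jax.tree_util.tree_map(cast, params, dtypes)


def bf16_except_norm(params):
    """Hypothetical per-parameter rule: float32 for norm params, else bfloat16."""

    def pick(path, _):
        name = jax.tree_util.keystr(path)
        if "norm" in name:
            return jnp.dtype(jnp.float32)
        return jnp.dtype(jnp.bfloat16)

    return jax.tree_util.tree_map_with_path(pick, params)


params = {
    "linear": {"weight": jnp.ones((2, 2), jnp.float32)},
    "norm": {"scale": jnp.ones((2,), jnp.float32)},
    "step": jnp.zeros((), jnp.int32),
}
uniform_cast = cast_floats(params, jnp.bfloat16)
mixed_cast = cast_floats(params, bf16_except_norm)
```

Matching on the key path is only one way to select a dtype per parameter; a rule could equally look at shapes or at the existing dtype of each leaf.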

@jialingt jialingt requested review from ruomingp, markblee and a team as code owners March 5, 2025 22:07
@jialingt jialingt changed the title from "Enable assign different train_dtype for parameters during training" to "Enable assign different parameters dtype during training" Mar 5, 2025
@jialingt jialingt changed the title from "Enable assign different parameters dtype during training" to "Enable to assign different parameters dtype during training" Mar 5, 2025
def __call__(self, params: Union[Nested[Tensor], Nested[TensorSpec]]) -> Nested[T]:
"""This protocol requires a callable that accepts either a nested Tensor or
a nested TensorSpec as input and returns a processed value for each parameter.
Args:
Contributor

Looks like some docstring format got lost from copy/paste?

Contributor Author

Oh you are right! Good catch.
Added all spaces back.

@jialingt jialingt added this pull request to the merge queue Mar 7, 2025
Merged via the queue into apple:main with commit b1e7b37 Mar 7, 2025
6 checks passed
@jialingt jialingt deleted the per_param_train_dtype branch March 7, 2025 04:56
Contributor

ds-hwang commented Mar 7, 2025

Hi, after this PR, CI is broken.
ImportError: cannot import name 'PerParamFn' from 'axlearn.common.utils'
