From 326c5fa630bf3294a4554c092702752ccf367ec0 Mon Sep 17 00:00:00 2001
From: Yifei Teng
Date: Thu, 9 Jan 2025 13:10:50 -0800
Subject: [PATCH 1/3] Add a CPU check and CI (#15)

Similar to https://github.com/AI-Hypercomputer/torchprime/pull/13, we also
add a CPU GitHub action. This action will run `pytest` on the repo.
Currently there is only one test, which is the Llama test in
torch_xla_models.

In order to run the test today, we need an HF_TOKEN. I created a personal
read-only token; https://github.com/AI-Hypercomputer/torchprime/issues/14
tracks avoiding the need for HF_TOKEN, after which I'll need to remember
to invalidate the token.
---
 torchprime/experimental/torchax_models/run.py | 33 ++++++++++++-------
 1 file changed, 22 insertions(+), 11 deletions(-)

diff --git a/torchprime/experimental/torchax_models/run.py b/torchprime/experimental/torchax_models/run.py
index dba9d1b..365b1d6 100644
--- a/torchprime/experimental/torchax_models/run.py
+++ b/torchprime/experimental/torchax_models/run.py
@@ -3,6 +3,7 @@
 
 import custom_mesh
 import jax
+from jax import numpy as np
 import numpy as np
 import splash_attn
 import torch
@@ -142,17 +143,27 @@ def make_weight_shard(weight_meta, slice_index):
 
 
 def create_sharded_weights(model, mesh, sharding_map):
-  res = {}
-  for name, weight_meta in model.state_dict().items():
-    sharding_spec = sharding_map.get(_process_sharding_name(name))
-    if sharding_spec is None:
-      print("Skipping weight:", name)
-      continue
-    sharding = NamedSharding(mesh, P(*sharding_spec))
-    res[name] = jax.make_array_from_callback(
-        weight_meta.shape, sharding,
-        functools.partial(make_weight_shard, weight_meta))
-  return res
+  name_to_sharding = {
+    name: NamedSharding(mesh, P(*sharding_map.get(_process_sharding_name(name))))
+    for name in model.state_dict().keys()
+    if _process_sharding_name(name) in sharding_map
+  }
+
+  kaiming = jax.nn.initializers.he_uniform(dtype=jnp.bfloat16)
+  key = jax.random.PRNGKey(0)
+
+  @functools.partial(
+    jax.jit,
+    out_shardings=name_to_sharding,
+  )
+  def create_weights():
+    res = {}
+    for name, weight_meta in model.state_dict().items():
+      res[name] = kaiming(key, weight_meta.shape, interop.jax_view(weight_meta.dtype))
+    return res
+
+  weights = create_weights()
+  return interop.torch_view(weights)
 
 
 def sharded_device_put(tensor, sharding):

From 140a845da1e893dc17e39613f7ed9b406e8a8035 Mon Sep 17 00:00:00 2001
From: Han Qi
Date: Fri, 10 Jan 2025 20:15:36 +0000
Subject: [PATCH 2/3] different init

---
 torchprime/experimental/torchax_models/run.py | 15 +++++++++++----
 1 file changed, 11 insertions(+), 4 deletions(-)

diff --git a/torchprime/experimental/torchax_models/run.py b/torchprime/experimental/torchax_models/run.py
index 365b1d6..eefc698 100644
--- a/torchprime/experimental/torchax_models/run.py
+++ b/torchprime/experimental/torchax_models/run.py
@@ -3,7 +3,7 @@
 
 import custom_mesh
 import jax
-from jax import numpy as np
+from jax import numpy as jnp
 import numpy as np
 import splash_attn
 import torch
@@ -150,19 +150,26 @@ def create_sharded_weights(model, mesh, sharding_map):
   }
 
   kaiming = jax.nn.initializers.he_uniform(dtype=jnp.bfloat16)
+
   key = jax.random.PRNGKey(0)
+  key = jax.device_put(key, NamedSharding(mesh, P())) # replicate
 
   @functools.partial(
     jax.jit,
     out_shardings=name_to_sharding,
   )
-  def create_weights():
+  def create_weights(rng):
     res = {}
     for name, weight_meta in model.state_dict().items():
-      res[name] = kaiming(key, weight_meta.shape, interop.jax_view(weight_meta.dtype))
+      rng, subkey = jax.random.split(rng)
+      if len(weight_meta.shape) < 2:
+        res[name] = jax.random.normal(subkey, weight_meta.shape,
+                                      interop.jax_view(weight_meta.dtype))
+      else:
+        res[name] = kaiming(subkey, weight_meta.shape, interop.jax_view(weight_meta.dtype))
     return res
 
-  weights = create_weights()
+  weights = create_weights(key)
   return interop.torch_view(weights)
 
 

From 35d7e49797398d55adffaac42bcab2265476b5e7 Mon Sep 17 00:00:00 2001
From: Han Qi
Date: Fri, 10 Jan 2025 21:54:21 +0000
Subject: [PATCH 3/3] skip keys not needed

---
 torchprime/experimental/torchax_models/run.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/torchprime/experimental/torchax_models/run.py b/torchprime/experimental/torchax_models/run.py
index eefc698..990fc5e 100644
--- a/torchprime/experimental/torchax_models/run.py
+++ b/torchprime/experimental/torchax_models/run.py
@@ -161,6 +161,8 @@ def create_sharded_weights(model, mesh, sharding_map):
   def create_weights(rng):
     res = {}
     for name, weight_meta in model.state_dict().items():
+      if _process_sharding_name(name) not in sharding_map:
+        continue
       rng, subkey = jax.random.split(rng)
       if len(weight_meta.shape) < 2:
         res[name] = jax.random.normal(subkey, weight_meta.shape,
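
The net effect of this series on create_sharded_weights is to initialize the
weights directly in their sharded layout instead of loading them through
jax.make_array_from_callback. Below is a minimal, self-contained sketch of the
pattern as it stands after all three patches: jit an init function with
out_shardings so each parameter is materialized already sharded, split the RNG
key once per parameter, use he_uniform for weights with two or more dimensions
and a plain normal draw for 1D parameters, and skip names missing from the
sharding map. The parameter names, shapes, and mesh axes below are
hypothetical, and torchprime's interop helpers (interop.jax_view /
interop.torch_view) are left out so the sketch runs with plain JAX.

import functools

import jax
import numpy as np
from jax import numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Hypothetical 2D device mesh; on a single CPU this is just a 1x1 mesh.
mesh = Mesh(np.array(jax.devices()).reshape(-1, 1), ("fsdp", "tp"))

# Hypothetical parameter shapes and sharding specs, standing in for the model
# state_dict and the real sharding map.
param_shapes = {"w": (1024, 1024), "b": (1024,), "rope_cache": (64,)}
sharding_map = {"w": ("fsdp", "tp"), "b": (None,)}

name_to_sharding = {
  name: NamedSharding(mesh, P(*sharding_map[name]))
  for name in param_shapes
  if name in sharding_map
}

kaiming = jax.nn.initializers.he_uniform(dtype=jnp.bfloat16)
key = jax.random.PRNGKey(0)


@functools.partial(jax.jit, out_shardings=name_to_sharding)
def create_weights(rng):
  res = {}
  for name, shape in param_shapes.items():
    if name not in sharding_map:
      continue  # skip params without a sharding spec, as in PATCH 3/3
    rng, subkey = jax.random.split(rng)
    if len(shape) < 2:
      # he_uniform needs a fan-in/fan-out, so 1D params get a normal draw.
      res[name] = jax.random.normal(subkey, shape, jnp.bfloat16)
    else:
      res[name] = kaiming(subkey, shape, jnp.bfloat16)
  return res


weights = create_weights(key)
for name, w in weights.items():
  print(name, w.shape, w.sharding)

Because the jitted function declares out_shardings for its outputs, each
parameter is produced directly on its target devices in its final layout, so
the full bfloat16 weights never have to be materialized on a single host and
re-sharded afterwards.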