Add a CPU check and CI (#15)
Similar to #13, we also add a CPU GitHub Action. This action runs `pytest` on the repo.

Currently there is only one test: the Llama test in torch_xla_models.

Running that test today requires an HF_TOKEN, so I created a personal read-only token. #14 tracks removing the need for HF_TOKEN; once that lands, I'll need to remember to invalidate the token.
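For illustration only, here is a minimal sketch of how such a test can be gated on the token so that runs without it skip cleanly; the test name and the guard below are hypothetical, not the actual Llama test in torch_xla_models:

```python
# Hypothetical sketch: skip the HF-gated test when HF_TOKEN is not set,
# so the rest of the suite still runs for contributors without a token.
import os

import pytest

requires_hf_token = pytest.mark.skipif(
    not os.environ.get("HF_TOKEN"),
    reason="needs HF_TOKEN to download the Llama checkpoint",
)


@requires_hf_token
def test_llama_forward():
  # Placeholder body; the real test lives in torch_xla_models.
  ...
```

In CI, HF_TOKEN would be exposed to the job as an environment variable (e.g. from a repository secret), so the guard passes there while token-less local runs simply skip.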
tengyifei authored and qihqi committed Jan 13, 2025
1 parent 04cd0a9 commit 326c5fa
Showing 1 changed file with 22 additions and 11 deletions.
33 changes: 22 additions & 11 deletions torchprime/experimental/torchax_models/run.py
@@ -3,6 +3,7 @@

import custom_mesh
import jax
+from jax import numpy as jnp
import numpy as np
import splash_attn
import torch
@@ -142,17 +143,27 @@ def make_weight_shard(weight_meta, slice_index):


def create_sharded_weights(model, mesh, sharding_map):
-  res = {}
-  for name, weight_meta in model.state_dict().items():
-    sharding_spec = sharding_map.get(_process_sharding_name(name))
-    if sharding_spec is None:
-      print("Skipping weight:", name)
-      continue
-    sharding = NamedSharding(mesh, P(*sharding_spec))
-    res[name] = jax.make_array_from_callback(
-        weight_meta.shape, sharding,
-        functools.partial(make_weight_shard, weight_meta))
-  return res
+  name_to_sharding = {
+      name: NamedSharding(mesh, P(*sharding_map.get(_process_sharding_name(name))))
+      for name in model.state_dict().keys()
+      if _process_sharding_name(name) in sharding_map
+  }
+
+  kaiming = jax.nn.initializers.he_uniform(dtype=jnp.bfloat16)
+  key = jax.random.PRNGKey(0)
+
+  @functools.partial(
+      jax.jit,
+      out_shardings=name_to_sharding,
+  )
+  def create_weights():
+    res = {}
+    for name, weight_meta in model.state_dict().items():
+      res[name] = kaiming(key, weight_meta.shape, interop.jax_view(weight_meta.dtype))
+    return res
+
+  weights = create_weights()
+  return interop.torch_view(weights)


def sharded_device_put(tensor, sharding):
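For readers skimming the diff: instead of building each sharded array from a host callback, the new create_sharded_weights jits an initializer and lets `out_shardings` place every weight directly into its target sharding. Below is a minimal, self-contained sketch of that pattern; the toy shapes, the sharding_map, and the 1-D "fsdp" mesh are assumptions for illustration, and the torchax interop calls (`interop.jax_view` / `interop.torch_view`) used in the real code to move between torch and JAX views are omitted:

```python
# Minimal sketch (not the repo's code): jit an initializer with out_shardings
# so each parameter is materialized directly in its target sharding.
import functools

import jax
from jax import numpy as jnp
from jax.experimental import mesh_utils
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Hypothetical 1-D mesh over all available devices, plus toy parameters.
mesh = Mesh(mesh_utils.create_device_mesh((jax.device_count(),)), axis_names=("fsdp",))
shapes = {"w1": (8, 128), "w2": (128, 8)}
sharding_map = {"w1": ("fsdp", None), "w2": (None, "fsdp")}

name_to_sharding = {
    name: NamedSharding(mesh, P(*sharding_map[name])) for name in shapes
}
kaiming = jax.nn.initializers.he_uniform(dtype=jnp.bfloat16)
key = jax.random.PRNGKey(0)


@functools.partial(jax.jit, out_shardings=name_to_sharding)
def create_weights():
  # Outputs are produced directly in the requested shardings; no host-side
  # unsharded copy of any weight is assembled first.
  return {name: kaiming(key, shape) for name, shape in shapes.items()}

weights = create_weights()
print(jax.tree_util.tree_map(lambda x: x.sharding, weights))
```

Seeding with a fixed PRNGKey keeps the random initialization deterministic across runs; in the real code the per-weight shape and dtype come from the torch state_dict metadata via `interop.jax_view`.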
