diff --git a/.github/workflows/cpu_test.yml b/.github/workflows/cpu_test.yml
new file mode 100644
index 0000000..cc015c6
--- /dev/null
+++ b/.github/workflows/cpu_test.yml
@@ -0,0 +1,29 @@
+name: CPU tests
+
+on:
+  push:
+    branches:
+      - main
+  pull_request:
+  schedule:
+    - cron: "0 8 * * *"
+
+jobs:
+  pytest:
+    runs-on: ubuntu-22.04
+    strategy:
+      matrix:
+        python-version: ["3.10", "3.11"]
+    container:
+      image: us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_${{ matrix.python-version }}_tpuvm
+    steps:
+      - uses: actions/checkout@v4
+      - name: Install dev dependencies
+        run: |
+          python -m pip install --upgrade pip
+          pip install -e '.[dev]'
+      - name: Run PyTest
+        run: |
+          # TODO(https://github.com/AI-Hypercomputer/torchprime/issues/14): Remove and burn the token.
+          export HF_TOKEN=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
+          pytest
diff --git a/torchprime/experimental/torchax_models/run.py b/torchprime/experimental/torchax_models/run.py
index fc58bcf..5cb73b2 100644
--- a/torchprime/experimental/torchax_models/run.py
+++ b/torchprime/experimental/torchax_models/run.py
@@ -3,6 +3,9 @@
 import custom_mesh
+import functools
 import jax
+from jax import numpy as jnp
 import numpy as np
 import splash_attn
 import torch
+from torch_xla2 import interop
 
@@ -135,20 +136,33 @@ def register_attention(fn):
 
 
 def create_sharded_weights(model, mesh, sharding_map):
-  res = {}
-  for name, weight_meta in model.state_dict().items():
-    sharding_spec = sharding_map.get(_process_sharding_name(name))
-    if sharding_spec is None:
-      print("Skipping weight:", name)
-      continue
-    sharding = NamedSharding(mesh, P(*sharding_spec))
-    with jax.default_device(jax.devices("cpu")[0]):
-      weight_torch = torch.randn(weight_meta.shape, dtype=weight_meta.dtype)
-    weight_jax = torch_xla2.default_env().to_xla(weight_torch).jax()
-    callback = (lambda weight_jax: lambda a: weight_jax[a])(weight_jax)
-    res[name] = jax.make_array_from_callback(weight_jax.shape, sharding,
-                                             callback)
-  return res
+
+  name_to_sharding = {
+      name: NamedSharding(mesh, P(*sharding_map.get(_process_sharding_name(name))))
+      for name in model.state_dict().keys()
+      if _process_sharding_name(name) in sharding_map
+  }
+
+  kaiming = jax.nn.initializers.he_uniform(dtype=jnp.bfloat16)
+  key = jax.random.PRNGKey(0)
+
+  @functools.partial(
+      jax.jit,
+      out_shardings=name_to_sharding,
+  )
+  def create_weights():
+    res = {}
+    for name, weight_meta in model.state_dict().items():
+      if name not in name_to_sharding:
+        # Skip weights without a sharding spec; out_shardings must match
+        # the output pytree exactly.
+        print("Skipping weight:", name)
+        continue
+      res[name] = kaiming(key, weight_meta.shape, interop.jax_view(weight_meta.dtype))
+    return res
+
+  weights = create_weights()
+  return interop.torch_view(weights)
 
 
 def sharded_device_put(tensor, sharding):
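
A note on the create_sharded_weights rewrite: instead of materializing every weight on the CPU with torch.randn and shipping it to devices through jax.make_array_from_callback, it builds the weights inside a jax.jit-compiled function whose out_shardings pins each output to its NamedSharding, so every weight is created directly on device, already sharded. Below is a minimal, self-contained sketch of that pattern; the parameter names, shapes, and the "fsdp" mesh axis are made up for illustration and do not come from the diff.

import functools

import jax
import numpy as np
from jax import numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Hypothetical stand-ins for model.state_dict() names and shapes.
param_shapes = {
    "tok_embeddings.weight": (32000, 4096),
    "layers.0.attention.wq.weight": (4096, 4096),
}

# 1-D mesh over all local devices; "fsdp" is an illustrative axis name.
mesh = Mesh(np.array(jax.devices()), axis_names=("fsdp",))
shardings = {name: NamedSharding(mesh, P("fsdp", None)) for name in param_shapes}

kaiming = jax.nn.initializers.he_uniform(dtype=jnp.bfloat16)
key = jax.random.PRNGKey(0)

@functools.partial(jax.jit, out_shardings=shardings)
def init_params():
  # out_shardings makes XLA compute each leaf straight into its target
  # sharding: every device fills in only its own shard of each weight.
  return {name: kaiming(key, shape) for name, shape in param_shapes.items()}

params = init_params()
for name, p in params.items():
  print(name, p.shape, p.sharding)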
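
For contrast, the removed code followed the host-callback pattern: create the full weight on the CPU, then let jax.make_array_from_callback slice out each device's shard. A condensed sketch of just that transfer step, with a hypothetical shard_from_host helper (not a name from the diff):

import jax

def shard_from_host(host_array, sharding):
  # make_array_from_callback invokes the callback once per addressable
  # shard with that shard's index (a tuple of slices) and device-puts
  # the slice it returns.
  return jax.make_array_from_callback(
      host_array.shape, sharding, lambda index: host_array[index])

Doing this for every weight round-trips each tensor through host memory, which is exactly what the jitted out_shardings version avoids.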