Llama 3.1 405B on two Trillium pods instructions #25

Merged 1 commit on Jan 27, 2025
@@ -1,9 +1,9 @@
-# Instructions for training Llama 3 405B on Trillium TPU
+# Instructions for training Llama 3.1 405B on Trillium TPU (1 pod)

 This user guide provides a concise overview of the essential steps required to
-run Hugging Face (HF) Llama 3 405B training on Trillium TPUs.
-
-Note: the current docker supports Single Pod v6e. The multipod solution will be available in an upcoming update soon.
+run Hugging Face (HF) Llama 3.1 405B training on Trillium TPUs. Specifically,
+the instructions and docker image referenced here are optimized for a single
+Trillium pod.

 ## Environment Setup
122 changes: 122 additions & 0 deletions training/trillium/Llama3-405B-PyTorch/XPK/README.md
@@ -0,0 +1,122 @@
# Instructions for training Llama 3.1 405B on a Trillium TPU multipod using XPK

The instructions and referenced docker image are optimized for training Llama 3.1
405B on two Trillium pods.

NOTE: the docker image contains a fork of `torch_xla`. We're working on
upstreaming the necessary dependencies. In the meantime, you may use this docker
image to study and reproduce the performance.
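If you'd like to inspect the image locally before launching anything, you can pull it by digest (optional; assumes Docker and the gcloud CLI are installed and authenticated):
```bash
# Authenticate Docker against Artifact Registry, then pull the image
# referenced by BASE_DOCKER_IMAGE in env.sh (see below).
gcloud auth configure-docker us-central1-docker.pkg.dev
docker pull us-central1-docker.pkg.dev/deeplearning-images/reproducibility/pytorch-tpu-llama@sha256:d3a4c09cd13dab2af8129e8438b0acf3f8b5a2370b94b69e2e3aac16530e3664
```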

## Environment Setup
---
### 1. [Optional but suggested] Create a virtual environment
```bash
sudo apt-get update && sudo apt install python3.10-venv
python3.10 -m venv myenv
source myenv/bin/activate
```
---
### 2. Clone XPK repository and install XPK package
```bash
pushd ./
git clone https://github.com/google/xpk.git
cd xpk
pip install .
popd
```
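To sanity-check the install, you can print the CLI's help from the directory where you cloned the repo (optional):
```bash
# Should print the xpk usage text if the clone and install succeeded.
python3 xpk/xpk.py --help
```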

---
### 3. Update and export environment variables
Modify the environment variables in `env.sh` to target your gcloud resources and your experiment's model config, then source the script.
```bash
source env.sh
```
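A quick way to confirm the variables took effect (a minimal check, using the names defined in `env.sh`):
```bash
# Fail fast if any required variable from env.sh is still unset or empty.
for v in ZONE PROJECT TPU_TYPE NUM_SLICE CLUSTER_NAME WORKLOAD_NAME; do
  echo "$v=${!v:?unset -- edit env.sh and re-source it}"
done
```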

---
### 4. [Optional, skip if using existing XPK cluster] Create the XPK clusters
Please follow the corresponding XPK user guide to create the XPK cluster first. If the cluster is already created, skip to Step 5.
```bash
NETWORK_NAME=${CLUSTER_NAME}-mtu9k
NETWORK_FW_NAME=${NETWORK_NAME}-fw

# Use a custom network for better performance and to avoid overloading the default network.
gcloud compute networks create ${NETWORK_NAME} --mtu=8896 --project=${PROJECT} --subnet-mode=auto --bgp-routing-mode=regional
gcloud compute firewall-rules create ${NETWORK_FW_NAME} --network ${NETWORK_NAME} --allow tcp,icmp,udp --project=${PROJECT}
export CLUSTER_ARGUMENTS="--network=${NETWORK_NAME} --subnetwork=${NETWORK_NAME}"

python3 xpk.py cluster create --cluster $CLUSTER_NAME --cluster-cpu-machine-type=n1-standard-8 --num-slices=$NUM_SLICES --tpu-type=$TPU_TYPE --zone=$ZONE --project=$PROJECT --on-demand --custom-cluster-arguments="${CLUSTER_ARGUMENTS}" --create-vertex-tensorboard --gke-version=1.31.1-gke.1678000
```
Note that if the specified `gke-version` is no longer available, pick an available version from the error message in the terminal output.
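One way to list the versions currently offered in your zone (a sketch; the exact field names may vary across gcloud releases):
```bash
# Print GKE versions available in this zone, to pick a valid --gke-version.
gcloud container get-server-config --zone=$ZONE --project=$PROJECT \
  --format="value(validMasterVersions)"
```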

---
### 5. Launch the Llama 3.1 training workload on the XPK cluster.
```bash
bash benchmark.sh
```

Below is part of the sample output from `benchmark.sh`:

```
...
[XPK] Waiting for `Upload Docker Image`, for 7 seconds
sqpu-2024-11-01-01-15-40: digest: sha256:3fe8b828bc6f96b1c74220d90273147ee188601781330d3592bbffc4fa0897af size: 4951
[XPK] Task: `Upload Docker Image` terminated with code `0`
[XPK] Task: `Creating Workload` is implemented by `kubectl apply -f /tmp/tmpc65ikqh3`, streaming output live.
[XPK] Waiting for `Creating Workload`, for 0 seconds
jobset.jobset.x-k8s.io/piz-xpk-v6e-256 created
[XPK] Task: `Creating Workload` terminated with code `0`
[XPK] Task: `GKE Dashboard List` is implemented by `gcloud monitoring dashboards list --project=tpu-prod-env-automated --filter="displayName:'GKE - TPU Monitoring Dashboard'" --format="value(name)" --verbosity=error`, hiding output unless there is an error.
[XPK] No dashboard with displayName:'GKE - TPU Monitoring Dashboard' found in the project:tpu-prod-env-automated.
[XPK] Follow https://github.com/google/cloud-tpu-monitoring-debugging to deploy monitoring dashboard to view statistics and outlier mode of GKE metrics.
[XPK] Follow your workload here: https://console.cloud.google.com/kubernetes/service/us-east5/bodaborg-v6e-256/default/piz-xpk-v6e-256/details?project=tpu-prod-env-automated
[XPK] Exiting XPK cleanly
```

This will point you to a workload link `https://console.cloud.google.com/kubernetes/service/...`. Follow the link and check the logs. If training is running correctly, you should see information like the following in the log explorer:

```
...
INFO {'train_runtime': 3240.2466, 'train_samples_per_second': 1.58, 'train_steps_per_second': 0.003, 'train_loss': 117.80369873046875, 'epoch': 0.37}
INFO ***** train metrics *****
INFO epoch = 0.3704
INFO total_flos = 94629384960GF
INFO train_loss = 117.8037
INFO train_runtime = 0:54:00.24
INFO train_samples = 13983
INFO train_samples_per_second = 1.58
INFO train_steps_per_second = 0.003
...
EXIT_CODE=0
XPK End: Thu Oct 31 02:03:01 UTC 2024
```
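You can also check the workload from the terminal with XPK's `workload list` command (optional; a sketch assuming `env.sh` is still sourced):
```bash
# List workloads on the cluster along with their current status.
python3 xpk/xpk.py workload list \
  --cluster ${CLUSTER_NAME} \
  --zone=$ZONE \
  --project=$PROJECT
```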

---
### 6. [Optional] Metric processing

You can post-process the captured profile to extract step times:
```bash
# Location of the profile processing script.
export PROFILE_SCRIPT_PATH=../../../../utils/

# Download the profile from the GCS bucket to local disk.
gsutil cp -r $PROFILE_LOG_DIR ./

# Locate the profile output ending with ".pb",
# rename it to xplane.pb, and process it.
PYTHONPATH=$PROFILE_SCRIPT_PATH:$PYTHONPATH python3 $PROFILE_SCRIPT_PATH/profile_convert.py xplane.pb
```

You will see output like the following, which reports the average step time in seconds:
```
Parsing xplane.pb
Plane ID: 2, Name: /device:TPU:0
Line ID: 2, Name: XLA Modules
Event Metadata Name: SyncTensorsGraph.157979.161292(17070309993204983656), ID: 33756, Duration: 82.676126708172 s
Event Metadata Name: SyncTensorsGraph.157979.161292(17070309993204983656), ID: 33756, Duration: 79.991382263094 s
Event Metadata Name: SyncTensorsGraph.157979.161292(17070309993204983656), ID: 33756, Duration: 92.256847100156 s
Event Metadata Name: SyncTensorsGraph.157979.161292(17070309993204983656), ID: 33756, Duration: 86.394679781422 s
Event Metadata Name: SyncTensorsGraph.157979.161292(17070309993204983656), ID: 33756, Duration: 79.542469470578 s
Event Metadata Name: SyncTensorsGraph.157979.161292(17070309993204983656), ID: 33756, Duration: 48.444764038344 s
Got 6 iterations
81.3338
```
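From the averaged step time you can estimate token throughput (a back-of-the-envelope sketch; the batch and sequence values assume the defaults in `env.sh`):
```bash
# tokens/sec = global batch size * sequence length / average step time
# 512 = 256 chips per v6e-256 slice * BATCH_PER_DEVICE=1 * NUM_SLICE=2
awk -v batch=512 -v seq=8192 -v step=81.3338 \
  'BEGIN { printf "~%.0f tokens/sec\n", batch * seq / step }'
```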
15 changes: 15 additions & 0 deletions training/trillium/Llama3-405B-PyTorch/XPK/benchmark.sh
@@ -0,0 +1,15 @@
#!/bin/bash

source env.sh

python3 xpk/xpk.py workload create \
--cluster ${CLUSTER_NAME} \
--base-docker-image=${BASE_DOCKER_IMAGE} \
--workload=${WORKLOAD_NAME} \
--tpu-type=${TPU_TYPE} \
--num-slices=${NUM_SLICE} \
--on-demand \
--zone=$ZONE \
--project=$PROJECT \
--enable-debug-logs \
--command="bash /app/train.sh"
34 changes: 34 additions & 0 deletions training/trillium/Llama3-405B-PyTorch/XPK/config_405b.json
@@ -0,0 +1,34 @@
{
"architectures": [
"LlamaForCausalLM"
],
"attention_bias": false,
"attention_dropout": 0.0,
"bos_token_id": 128000,
"eos_token_id": 128001,
"hidden_act": "silu",
"hidden_size": 16384,
"initializer_range": 0.02,
"intermediate_size": 53248,
"max_position_embeddings": 131072,
"mlp_bias": false,
"model_type": "llama",
"num_attention_heads": 128,
"num_hidden_layers": 126,
"num_key_value_heads": 8,
"pretraining_tp": 1,
"rms_norm_eps": 1e-05,
"rope_scaling": {
"factor": 8.0,
"low_freq_factor": 1.0,
"high_freq_factor": 4.0,
"original_max_position_embeddings": 8192,
"rope_type": "llama3"
},
"rope_theta": 500000.0,
"tie_word_embeddings": false,
"torch_dtype": "bfloat16",
"transformers_version": "4.42.3",
"use_cache": false,
"vocab_size": 128256
}
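As a sanity check, the shapes above do account for roughly 405B parameters (a back-of-the-envelope estimate assuming standard Llama attention/MLP layouts and untied embeddings; norm parameters are negligible and omitted):
```bash
# hidden=16384, intermediate=53248, layers=126, vocab=128256,
# kv width = num_key_value_heads (8) * head_dim (16384/128 = 128) = 1024.
awk 'BEGIN {
  h = 16384; inter = 53248; layers = 126; vocab = 128256; kv = 1024
  per_layer = 2*h*h + 2*h*kv + 3*h*inter      # q/o proj, k/v proj, gate/up/down
  total = layers * per_layer + 2 * vocab * h  # plus untied in/out embeddings
  printf "%.1fB parameters\n", total / 1e9    # ~405.8B
}'
```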
17 changes: 17 additions & 0 deletions training/trillium/Llama3-405B-PyTorch/XPK/env.sh
@@ -0,0 +1,17 @@
#!/bin/bash

# Environment variables associated with XPK on GCP.
export ZONE=...
export PROJECT=...
export TPU_TYPE=v6e-256
export NUM_SLICE=2
export CLUSTER_NAME=xpk-$USER-... # or use an existing cluster if you have one

# Environment variables associated with training config.
export BATCH_PER_DEVICE=1
export SEQUENCE_LENGTH=8192
export MAX_STEP=50
export WORKLOAD_NAME=${USER}-xpk-${TPU_TYPE}-... # Your workload name; update it for each new run.
export BASE_DOCKER_IMAGE=us-central1-docker.pkg.dev/deeplearning-images/reproducibility/pytorch-tpu-llama@sha256:d3a4c09cd13dab2af8129e8438b0acf3f8b5a2370b94b69e2e3aac16530e3664
export PROFILE_LOG_DIR=... # GCS bucket for storing profiles, in the form gs://...
export HF_TOKEN=... # Your Hugging Face token, used to download the model
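For illustration, a filled-in `env.sh` might look like the following (every value is a hypothetical placeholder, not a real resource):
```bash
export ZONE=us-east5-b                    # placeholder zone
export PROJECT=my-gcp-project             # placeholder project ID
export TPU_TYPE=v6e-256
export NUM_SLICE=2
export CLUSTER_NAME=xpk-$USER-v6e-demo    # placeholder cluster name
export BATCH_PER_DEVICE=1
export SEQUENCE_LENGTH=8192
export MAX_STEP=50
export WORKLOAD_NAME=${USER}-xpk-${TPU_TYPE}-run1          # placeholder; change per run
export PROFILE_LOG_DIR=gs://my-bucket/llama405b-profiles   # placeholder bucket
export HF_TOKEN=hf_xxxxxxxxxxxxxxxx       # placeholder Hugging Face token
```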
50 changes: 50 additions & 0 deletions training/trillium/Llama3-405B-PyTorch/XPK/train.sh
@@ -0,0 +1,50 @@
#!/bin/bash

# XPK will create a new docker container and copy env.sh into /app/.
source /app/env.sh

# Calculate the global batch size
# Extract the number after '-' in TPU_TYPE
TPU_NUM=$(echo "$TPU_TYPE" | grep -oP '(?<=-)\d+')
# Calculate GLOBAL_BATCH_SIZE
GLOBAL_BATCH_SIZE=$(( TPU_NUM * BATCH_PER_DEVICE * NUM_SLICE ))
export GLOBAL_BATCH_SIZE
echo "GLOBAL_BATCH_SIZE=$GLOBAL_BATCH_SIZE"

# Note --per_device_train_batch_size is the global batch size since we overwrite the dataloader in the HF trainer.
cd /workspace/ && \
export PJRT_DEVICE=TPU && \
export XLA_USE_SPMD=1 && \
export ENABLE_PJRT_COMPATIBILITY=true && \
export XLA_IR_DEBUG=1 && \
export XLA_HLO_DEBUG=1 && \
export PROFILE_EPOCH=0 && \
export PROFILE_STEP=3 && \
export PROFILE_DURATION_MS=450000 && \
export PROFILE_LOGDIR=${PROFILE_LOG_DIR} && \
export XLA_PERSISTENT_CACHE_PATH=/app/xla_cache/ && \
export TPU_LIBRARY_PATH=/workspace/_libtpu.so && \
export NUM_SLICE=${NUM_SLICE} && \

export LIBTPU_INIT_ARGS="--xla_enable_async_all_gather=true --xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_fuse_all_gather=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_scoped_vmem_limit_kib=98304 --xla_tpu_spmd_rng_bit_generator_unsafe=true --xla_tpu_overlap_compute_collective_tc=true --xla_tpu_use_enhanced_launch_barrier=true --xla_tpu_enable_all_experimental_scheduler_features=true --xla_tpu_enable_scheduler_memory_pressure_tracking=true --xla_tpu_host_transfer_overlap_limit=2 --xla_tpu_aggressive_opt_barrier_removal=ENABLED --xla_lhs_prioritize_async_depth_over_stall=ENABLED --xla_tpu_enable_ag_backward_pipelining=true --xla_should_allow_loop_variant_parameter_in_chain=ENABLED --xla_should_add_loop_invariant_op_in_chain=ENABLED --xla_max_concurrent_host_send_recv=100 --xla_tpu_scheduler_percent_shared_memory_limit=100 --xla_latency_hiding_scheduler_rerun=2 --megascale_graph_hang_threshold=30m --megascale_graph_within_launch_hang_threshold=30m --megascale_grpc_enable_xor_tracer=false --megascale_grpc_premap_memory_bytes=68719476736 --megascale_grpc_use_chaotic_good=true --megascale_grpc_use_event_engine_allocator=true --grpc_enable_tcp_recv_zerocopy=false --grpc_enable_rpc_receive_coalescing=true"

huggingface-cli login --token=${HF_TOKEN} && \
python3 transformers/examples/pytorch/language-modeling/run_clm.py \
--dataset_name=wikitext \
--dataset_config_name=wikitext-103-raw-v1 \
--per_device_train_batch_size=${GLOBAL_BATCH_SIZE} \
--do_train \
--output_dir=test-clm \
--overwrite_output_dir \
--config_name=/app/config_405b.json \
--cache_dir=cache \
--tokenizer_name=meta-llama/Meta-Llama-3.1-405B \
--block_size=${SEQUENCE_LENGTH} \
--optim=adafactor \
--save_strategy=no \
--logging_strategy=no \
--torch_dtype=bfloat16 \
--dataloader_drop_last=yes \
--flash_attention \
--spmd_2d_sharding=4 \
--max_steps=${MAX_STEP}