Anisha ckpt2hf1 #1106

Draft
wants to merge 3 commits into base: main
107 changes: 57 additions & 50 deletions MaxText/llama_mistral_mixtral_orbax_to_hf.py
@@ -24,8 +24,8 @@

python3 MaxText/llama_mistral_mixtral_orbax_to_hf.py MaxText/configs/base.yml
base_output_directory=path/to/saving/intermediate_MaxText_files
load_parameters_path=/path/to/MaxText/checkpoint run_name=<your run name> model_name=<llama2 or mistral>
hardware=gpu
load_parameters_path=/path/to/MaxText/checkpoint scan_layers=false run_name=<your run name>
model_name=<llama2 or mistral> hardware=gpu
hf_model_path=/local/path/to/save/HF/model/to

Note that we are saving the converted HuggingFace model to a local path. You can write to a GCS location by mounting
@@ -77,6 +77,8 @@ def load_hf_model(model_size):
model = MistralForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1")
elif model_size == "mixtral-8x7b":
model = AutoModelForCausalLM.from_pretrained("mistralai/Mixtral-8x7B-v0.1", device_map="auto")
elif model_size == "llama3.1-8b":
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.1-8B")
else:
raise NotImplementedError
return model
@@ -121,120 +123,124 @@ def convert_state_to_hf(training_state, model_size):

hf_model_params = {}

converted_dtype = torch.bfloat16 if model_size[:6] == "llama3" else torch.float16

# Port the embedding weights
hf_model_params["model.embed_tokens.weight"] = torch.tensor(
np.asarray(training_state.params["params"]["token_embedder"]["embedding"]), dtype=torch.float16
np.asarray(training_state.params["params"]["token_embedder"]["embedding"]), dtype=converted_dtype
)


for layer_int in tqdm(range(base_num_decoder_layers), desc="Porting parameters layerwise"):
print(f"Converting weights for layer {layer_int}")

# Attention layers
intermediate_query = reverse_scale(
training_state.params["params"]["decoder"][f"layers_{layer_int}"]["self_attention"]["query"]["kernel"][:, :, :],
head_dim,
)
if model_size[:8] != "llama3.1":
intermediate_query = unpermute_from_match_maxtext_rope(intermediate_query)
hf_model_params[f"model.layers.{layer_int}.self_attn.q_proj.weight"] = torch.tensor(
np.asarray(
unpermute_from_match_maxtext_rope(
reverse_scale(
training_state.params["params"]["decoder"]["layers"]["self_attention"]["query"]["kernel"][
:, layer_int, :, :
],
head_dim,
)
)
.reshape(base_num_query_heads * head_dim, base_num_query_heads * head_dim)
.T
),
dtype=torch.float16,
np.asarray(intermediate_query.reshape(base_num_query_heads * head_dim, base_num_query_heads * head_dim).T),
dtype=converted_dtype,
)
hf_model_params[f"model.layers.{layer_int}.self_attn.q_proj.weight"] = hf_model_params[f"model.layers.{layer_int}.self_attn.q_proj.weight"].view(base_num_query_heads * head_dim, base_num_query_heads * head_dim).T.view(base_num_query_heads, head_dim // 2, 2, base_num_query_heads * head_dim).transpose(1, 2).reshape(-1, base_num_query_heads * head_dim)


intermediate_key = training_state.params["params"]["decoder"][f"layers_{layer_int}"]["self_attention"]["key"]["kernel"][
:, :, :
]
if model_size[:8] != "llama3.1":
intermediate_key = unpermute_from_match_maxtext_rope(intermediate_key)
hf_model_params[f"model.layers.{layer_int}.self_attn.k_proj.weight"] = torch.tensor(
np.asarray(
unpermute_from_match_maxtext_rope(
training_state.params["params"]["decoder"]["layers"]["self_attention"]["key"]["kernel"][:, layer_int, :, :]
)
.reshape(base_num_query_heads * head_dim, base_num_kv_heads * head_dim)
.T
),
dtype=torch.float16,
np.asarray(intermediate_key.reshape(base_num_query_heads * head_dim, base_num_kv_heads * head_dim).T),
dtype=converted_dtype,
)

hf_model_params[f"model.layers.{layer_int}.self_attn.k_proj.weight"] = hf_model_params[f"model.layers.{layer_int}.self_attn.k_proj.weight"].view(base_num_query_heads * head_dim, base_num_kv_heads * head_dim).T.reshape(base_num_kv_heads, head_dim // 2, 2, base_num_query_heads * head_dim).transpose(1, 2).reshape(-1 ,base_num_query_heads * head_dim)



hf_model_params[f"model.layers.{layer_int}.self_attn.v_proj.weight"] = torch.tensor(
np.asarray(
training_state.params["params"]["decoder"]["layers"]["self_attention"]["value"]["kernel"][:, layer_int, :, :]
training_state.params["params"]["decoder"][f"layers_{layer_int}"]["self_attention"]["value"]["kernel"][:, :, :]
.reshape(base_num_query_heads * head_dim, base_num_kv_heads * head_dim)
.T
),
dtype=torch.float16,
dtype=converted_dtype,
)
hf_model_params[f"model.layers.{layer_int}.self_attn.o_proj.weight"] = torch.tensor(
np.asarray(
training_state.params["params"]["decoder"]["layers"]["self_attention"]["out"]["kernel"][:, layer_int, :, :]
training_state.params["params"]["decoder"][f"layers_{layer_int}"]["self_attention"]["out"]["kernel"][:, :, :]
.reshape(base_num_query_heads * head_dim, base_num_query_heads * head_dim)
.T
),
dtype=torch.float16,
dtype=converted_dtype,
)

# MLP Layers
if num_experts is None:
hf_model_params[f"model.layers.{layer_int}.mlp.gate_proj.weight"] = torch.tensor(
np.asarray(training_state.params["params"]["decoder"]["layers"]["mlp"]["wi_0"]["kernel"][:, layer_int, :].T),
dtype=torch.float16,
np.asarray(training_state.params["params"]["decoder"][f"layers_{layer_int}"]["mlp"]["wi_0"]["kernel"][:, :].T),
dtype=converted_dtype,
)
hf_model_params[f"model.layers.{layer_int}.mlp.up_proj.weight"] = torch.tensor(
np.asarray(training_state.params["params"]["decoder"]["layers"]["mlp"]["wi_1"]["kernel"][:, layer_int, :].T),
dtype=torch.float16,
np.asarray(training_state.params["params"]["decoder"][f"layers_{layer_int}"]["mlp"]["wi_1"]["kernel"][:, :].T),
dtype=converted_dtype,
)
hf_model_params[f"model.layers.{layer_int}.mlp.down_proj.weight"] = torch.tensor(
np.asarray(training_state.params["params"]["decoder"]["layers"]["mlp"]["wo"]["kernel"][:, layer_int, :].T),
dtype=torch.float16,
np.asarray(training_state.params["params"]["decoder"][f"layers_{layer_int}"]["mlp"]["wo"]["kernel"][:, :].T),
dtype=converted_dtype,
)
else:
hf_model_params[f"model.layers.{layer_int}.block_sparse_moe.gate.weight"] = torch.tensor(
np.asarray(
training_state.params["params"]["decoder"]["layers"]["MoeBlock_0"]["gate"]["kernel"][:, layer_int, :].T
training_state.params["params"]["decoder"][f"layers_{layer_int}"]["MoeBlock_0"]["gate"]["kernel"][:, :].T
),
dtype=torch.float16,
dtype=converted_dtype,
)
for k in range(num_experts):
hf_model_params[f"model.layers.{layer_int}.block_sparse_moe.experts.{k}.w1.weight"] = torch.tensor(
np.asarray(training_state.params["params"]["decoder"]["layers"]["MoeBlock_0"]["wi_0"][k, layer_int, :, :].T),
dtype=torch.float16,
np.asarray(training_state.params["params"]["decoder"][f"layers_{layer_int}"]["MoeBlock_0"]["wi_0"][k, :, :].T),
dtype=converted_dtype,
)
hf_model_params[f"model.layers.{layer_int}.block_sparse_moe.experts.{k}.w2.weight"] = torch.tensor(
np.asarray(training_state.params["params"]["decoder"]["layers"]["MoeBlock_0"]["wo"][k, layer_int, :, :].T),
dtype=torch.float16,
np.asarray(training_state.params["params"]["decoder"][f"layers_{layer_int}"]["MoeBlock_0"]["wo"][k, :, :].T),
dtype=converted_dtype,
)
hf_model_params[f"model.layers.{layer_int}.block_sparse_moe.experts.{k}.w3.weight"] = torch.tensor(
np.asarray(training_state.params["params"]["decoder"]["layers"]["MoeBlock_0"]["wi_1"][k, layer_int, :, :].T),
dtype=torch.float16,
np.asarray(training_state.params["params"]["decoder"][f"layers_{layer_int}"]["MoeBlock_0"]["wi_1"][k, :, :].T),
dtype=converted_dtype,
)

# Pre/post attention layer norm
hf_model_params[f"model.layers.{layer_int}.input_layernorm.weight"] = torch.tensor(
np.asarray(
training_state.params["params"]["decoder"]["layers"]["pre_self_attention_layer_norm"]["scale"][
:, layer_int
training_state.params["params"]["decoder"][f"layers_{layer_int}"]["pre_self_attention_layer_norm"]["scale"][
:
].reshape(base_num_query_heads * head_dim)
),
dtype=torch.float16,
dtype=converted_dtype,
)
hf_model_params[f"model.layers.{layer_int}.post_attention_layernorm.weight"] = torch.tensor(
np.asarray(
training_state.params["params"]["decoder"]["layers"]["post_self_attention_layer_norm"]["scale"][
:, layer_int
training_state.params["params"]["decoder"][f"layers_{layer_int}"]["post_self_attention_layer_norm"]["scale"][
:
].reshape(base_num_query_heads * head_dim)
),
dtype=torch.float16,
dtype=converted_dtype,
)

# LM head and layernorm
hf_model_params["lm_head.weight"] = torch.tensor(
np.asarray(training_state.params["params"]["decoder"]["logits_dense"]["kernel"].T), dtype=torch.float16
np.asarray(training_state.params["params"]["decoder"]["logits_dense"]["kernel"].T), dtype=converted_dtype
)
hf_model_params["model.norm.weight"] = torch.tensor(
np.asarray(
training_state.params["params"]["decoder"]["decoder_norm"]["scale"].reshape(base_num_query_heads * head_dim)
),
dtype=torch.float16,
dtype=converted_dtype,
)

return hf_model_params
@@ -253,6 +259,7 @@ def convert_orbax_hf(hf_model_path, config):

def main(argv: Sequence[str]):
pyconfig.initialize(argv[:-1])
# Assuming the last argument is the path to save the converted checkpoint in HuggingFace format
hf_model_path = argv[-1].split("=")[1]
print(f"Will save converted HuggingFace checkpoint to path = {hf_model_path}")

156 changes: 156 additions & 0 deletions MaxText/tests/orbax_to_hf_logit_checker.py
@@ -0,0 +1,156 @@
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


# This orbax_to_hf_logit_checker.py file compares the logits generated for some input prompts by a HuggingFace checkpoint
# that was converted from a MaxText Orbax checkpoint, against the golden logits for those same prompts for a particular model.
# The script is generic in that it can work with different models; it expects an input file called
# golden_data_<model_name>.jsonl to be present under MaxText/test_assets,
# e.g., MaxText/test_assets/golden_data_llama2-7b.jsonl
# The golden jsonl file is a simple jsonlines file in which each line is a dictionary containing the following
# required keys:
# 1. prompt: a string representing the prompt, e.g., "I love to",
# 2. tokens: the token ids obtained by tokenizing the prompt,
# 3. logits: the golden logits, i.e., the ideal logits generated by the model in question when fed the prompt in #1
# There can be multiple such test cases in the jsonl file; each test case is a new line in the jsonl file
# The script runs a forward pass with the input tokens through the converted HuggingFace checkpoint and asserts that the
# logits it produces closely match the golden logits
# Users can use a script similar to MaxText/scratch_code/golden_llama2-7b_export.ipynb to create this jsonl file
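# A minimal, hypothetical example of one line in golden_data_<model_name>.jsonl (token ids and logits below are
# illustrative placeholders, not real values):
# {"prompt": "I love to", "tokens": [1, 306, 5360, 304], "logits": [[-7.4, 0.3, ...], [-6.1, 1.2, ...], ...]}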

"""Check if the logits generated by a model's MaxText orbax to HF ckpt matches golden logits for the same inputs"""
import argparse
import sys
import os

current_dir = os.path.dirname(os.path.abspath(__file__))
maxtext_parent_dir = os.path.dirname(current_dir)
sys.path.append(maxtext_parent_dir)

import max_logging

max_logging.log(f"Added parent directory = {maxtext_parent_dir}")

import common_types
import jax
import jax.numpy as jnp
import numpy as np
import pyconfig
import jsonlines
import max_utils
from layers import models
from layers import quantizations

from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

hf_model_name_map = {"llama2-7b": "meta-llama/Llama-2-7b-hf",
"llama3.1-8b": "meta-llama/Meta-Llama-3.1-8B"
}


def get_data(golden_data, golden_data_index, config):
"""Get the golden data for the test indexed at golden_data_index"""

max_logging.log(f"Comparing forward pass for golden data index = {golden_data_index} ")
max_logging.log(f"config.global_batch_size_to_train_on={config.global_batch_size_to_train_on}")
s = (config.global_batch_size_to_train_on, config.max_target_length)
ids = np.asarray(golden_data[golden_data_index]["tokens"], dtype=np.int32)

logits = np.asarray(golden_data[golden_data_index]["logits"], dtype=np.float32)
max_logging.log(f" prompt=\"{golden_data[golden_data_index]['prompt']}\" raw ids={ids}, logits.shape = {logits.shape}")

decoder_segment_ids = jax.numpy.zeros(s) + common_types.DECODING_ACTIVE_SEQUENCE_INDICATOR
decoder_positions = jnp.stack(
[jnp.arange(config.max_target_length, dtype=jnp.int32) for _ in range(config.global_batch_size_to_train_on)]
)

ids = jnp.stack([ids for _ in range(config.global_batch_size_to_train_on)])
max_logging.log(f"ids={ids}, decoder_segment_ids = {decoder_segment_ids}, decoder_positions= {decoder_positions}")

return ids, decoder_segment_ids, decoder_positions, logits


def main(config, test_args):
"""Test the Whole Model of model_name"""

input_golden_data_path = "MaxText/test_assets/golden_data_" + config.model_name + ".jsonl"
with jsonlines.open(input_golden_data_path, "r") as f:
golden_data = list(f)

model_name_hf = hf_model_name_map[config.model_name]
tokenizer = AutoTokenizer.from_pretrained(model_name_hf)

for golden_data_index in range(len(golden_data)):
ids, decoder_segment_ids, decoder_positions, golden_logits = get_data(golden_data, golden_data_index, config)

maxtext_model = AutoModelForCausalLM.from_pretrained(test_args.maxtext_hf_model_path)

full_train_logits = (
maxtext_model(torch.tensor(ids.tolist(), requires_grad=False), output_hidden_states=True).logits.detach().numpy()
)

full_train_logits = jax.experimental.multihost_utils.process_allgather(full_train_logits)
max_logging.log(f"{golden_logits[0]=}")
max_logging.log(f"{full_train_logits[0, 0, :]=}")
token_size = int(test_args.token_size) if test_args.token_size else golden_logits.shape[0]
max_logging.log(
f"Max Numerical Difference {np.max(np.subtract(full_train_logits[0, :token_size, :], golden_logits[:token_size, :]))}"
)

model_probabilities = jax.nn.softmax(full_train_logits[0, :token_size, :], axis=-1)
golden_probabilities = jax.nn.softmax(golden_logits[:token_size, :], axis=-1)

max_logging.log(f"{golden_probabilities[0]=}")
max_logging.log(f"{model_probabilities[0]=}")

kl_div = jax.numpy.sum(jax.scipy.special.kl_div(golden_probabilities, model_probabilities), axis=-1)
max_logging.log(f"KL divergence = {kl_div}, max KL divergence = {jax.numpy.max(kl_div)}")

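# kl_div holds one value per token position; the check below requires every per-token divergence to stay
# below --max_kl_div, otherwise we fall back to an element-wise allclose comparison on the raw logits.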
if test_args.max_kl_div is not None:
max_logging.log("Checking KL Divergence between train distribution and golden distribution")
assert jax.numpy.all(
kl_div < test_args.max_kl_div
), f"KL divergence values exceed the specified threshold of {test_args.max_kl_div}. Max divergence: {jax.numpy.max(kl_div)}"
else:
max_logging.log("Checking Numerical Differences between train logits and golden logits")
assert jax.numpy.allclose(
full_train_logits[0, :token_size, :],
golden_logits[:token_size, :],
rtol=float(test_args.rtol),
atol=float(test_args.atol),
equal_nan=False,
), f"Logits do not match closely enough. Required rtol={test_args.rtol}, atol={test_args.atol}."


if __name__ == "__main__":
jax.config.update("jax_default_prng_impl", "unsafe_rbg")
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0"

parser = argparse.ArgumentParser()
parser.add_argument("--atol", type=float, required=False, default=0.1)
parser.add_argument("--rtol", type=float, required=False, default=0.1)
parser.add_argument("--token_size", type=int, required=False)
parser.add_argument("--max_kl_div", type=float, required=False, default=None)
parser.add_argument("--maxtext-hf-model-path", type=str, required=True, default=None)
test_args, _ = parser.parse_known_args()

# Remove args defined in this test file to avoid error from pyconfig
model_args = sys.argv
to_remove_args = ["--atol", "--rtol", "--token_size", "--max_kl_div", "--maxtext-hf-model-path"]
for arg in to_remove_args:
model_args = [s for s in model_args if not s.startswith(arg)]

pyconfig.initialize(model_args)
cfg = pyconfig.config
main(cfg, test_args)
7 changes: 7 additions & 0 deletions end_to_end/tpu/llama2/7b/test_llama2_7b.sh
@@ -72,3 +72,10 @@ python3 MaxText/decode.py MaxText/configs/base.yml load_parameters_path=${NEW_CK

# We also test whether the forward pass logits match the golden logits for Llama2-7b
python3 MaxText/tests/forward_pass_logit_checker.py MaxText/configs/base.yml load_parameters_path=${UNSCANNED_CKPT_PATH} run_name=forward_pass_test per_device_batch_size=1 model_name=llama2-7b ici_tensor_parallelism=4 max_prefill_predict_length=4 max_target_length=4 dataset_type=synthetic dtype=float32 scan_layers=false

# We convert an Orbax checkpoint generated by MaxText back to HuggingFace format
export ORBAX_TO_HF_CHECKPOINT_CONVERSION=orbax2hf_checkpoint_conversion_${idx}
python3 MaxText/llama_mistral_mixtral_orbax_to_hf.py MaxText/configs/base.yml base_output_directory=${BASE_OUTPUT_DIRECTORY} load_parameters_path=${NEW_CKPT_PATH} scan_layers=false run_name=${ORBAX_TO_HF_CHECKPOINT_CONVERSION} model_name='llama2-7b' hf_model_path=/tmp/llama2-7b-ckpt/${ORBAX_TO_HF_CHECKPOINT_CONVERSION}

# We test that the logits generated by a forward pass from the new HuggingFace-format checkpoint match the golden logits for Llama2-7b
python3 MaxText/tests/orbax_to_hf_logit_checker.py MaxText/configs/base.yml run_name=orbax_to_hf_logit_checker per_device_batch_size=1 model_name=llama2-7b ici_tensor_parallelism=4 max_prefill_predict_length=4 max_target_length=4 dataset_type=synthetic dtype=float32 scan_layers=false --maxtext-hf-model-path=/tmp/llama2-7b-ckpt/${ORBAX_TO_HF_CHECKPOINT_CONVERSION}
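For a quick manual sanity check of the converted checkpoint outside the pyconfig-driven checker above, the HF directory
written by the conversion step can be loaded directly with transformers. A minimal sketch, assuming a hypothetical
converted-checkpoint path and local access to the Llama-2 tokenizer:

# Sanity-check a converted checkpoint; the path and prompt are illustrative assumptions.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

ckpt_dir = "/tmp/llama2-7b-ckpt/orbax2hf_checkpoint_conversion_0"  # hypothetical ${ORBAX_TO_HF_CHECKPOINT_CONVERSION} value
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
model = AutoModelForCausalLM.from_pretrained(ckpt_dir, torch_dtype=torch.float16)

inputs = tokenizer("I love to", return_tensors="pt")
with torch.no_grad():
    logits = model(**inputs).logits  # shape: (1, sequence_length, vocab_size)
next_token = tokenizer.decode([int(logits[0, -1].argmax())])
print(f"Greedy next token: {next_token}")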