From 9f17a0413872c5b810df2b629689fce585f119e6 Mon Sep 17 00:00:00 2001 From: A9isha Date: Thu, 19 Dec 2024 07:45:36 +0000 Subject: [PATCH 1/3] fix orbax to hf ckpt converter --- MaxText/llama_mistral_mixtral_orbax_to_hf.py | 65 ++++---- MaxText/tests/orbax_to_hf_logit_checker.py | 154 +++++++++++++++++++ end_to_end/tpu/llama2/7b/test_llama2_7b.sh | 7 + 3 files changed, 192 insertions(+), 34 deletions(-) create mode 100644 MaxText/tests/orbax_to_hf_logit_checker.py diff --git a/MaxText/llama_mistral_mixtral_orbax_to_hf.py b/MaxText/llama_mistral_mixtral_orbax_to_hf.py index ec126bd4a..6d4ed92c1 100644 --- a/MaxText/llama_mistral_mixtral_orbax_to_hf.py +++ b/MaxText/llama_mistral_mixtral_orbax_to_hf.py @@ -24,8 +24,8 @@ python3 MaxText/llama_mistral_mixtral_orbax_to_hf.py MaxText/configs/base.yml base_output_directory=path/to/saving/intermediate_MaxText_files - load_parameters_path=/path/to/MaxText/checkpoint run_name= model_name= - hardware=gpu + load_parameters_path=/path/to/MaxText/checkpoint scan_layers=false run_name= + model_name= hardware=gpu hf_model_path=/local/path/to/save/HF/model/to Note that we are saving the converted HuggingFace model to a local path. You can write to a GCS location by mounting @@ -77,6 +77,8 @@ def load_hf_model(model_size): model = MistralForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1") elif model_size == "mixtral-8x7b": model = AutoModelForCausalLM.from_pretrained("mistralai/Mixtral-8x7B-v0.1", device_map="auto") + elif model_size == "llama3.1-8b": + model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.1-8B") else: raise NotImplementedError return model @@ -130,35 +132,29 @@ def convert_state_to_hf(training_state, model_size): print(f"Converting weights for layer {layer_int}") # Attention layers + intermediate_query = reverse_scale( + training_state.params["params"]["decoder"][f"layers_{layer_int}"]["self_attention"]["query"]["kernel"][:, :, :], + head_dim, + ) + if model_size[:8] != "llama3.1": + intermediate_query = unpermute_from_match_maxtext_rope(intermediate_query) hf_model_params[f"model.layers.{layer_int}.self_attn.q_proj.weight"] = torch.tensor( - np.asarray( - unpermute_from_match_maxtext_rope( - reverse_scale( - training_state.params["params"]["decoder"]["layers"]["self_attention"]["query"]["kernel"][ - :, layer_int, :, : - ], - head_dim, - ) - ) - .reshape(base_num_query_heads * head_dim, base_num_query_heads * head_dim) - .T - ), + np.asarray(intermediate_query.reshape(base_num_query_heads * head_dim, base_num_query_heads * head_dim).T), dtype=torch.float16, ) + intermediate_key = training_state.params["params"]["decoder"][f"layers_{layer_int}"]["self_attention"]["key"]["kernel"][ + :, :, : + ] + if model_size[:8] != "llama3.1": + intermediate_key = unpermute_from_match_maxtext_rope(intermediate_key) hf_model_params[f"model.layers.{layer_int}.self_attn.k_proj.weight"] = torch.tensor( - np.asarray( - unpermute_from_match_maxtext_rope( - training_state.params["params"]["decoder"]["layers"]["self_attention"]["key"]["kernel"][:, layer_int, :, :] - ) - .reshape(base_num_query_heads * head_dim, base_num_kv_heads * head_dim) - .T - ), + np.asarray(intermediate_key.reshape(base_num_query_heads * head_dim, base_num_kv_heads * head_dim).T), dtype=torch.float16, ) hf_model_params[f"model.layers.{layer_int}.self_attn.v_proj.weight"] = torch.tensor( np.asarray( - training_state.params["params"]["decoder"]["layers"]["self_attention"]["value"]["kernel"][:, layer_int, :, :] + 
training_state.params["params"]["decoder"][f"layers_{layer_int}"]["self_attention"]["value"]["kernel"][:, :, :] .reshape(base_num_query_heads * head_dim, base_num_kv_heads * head_dim) .T ), @@ -166,7 +162,7 @@ def convert_state_to_hf(training_state, model_size): ) hf_model_params[f"model.layers.{layer_int}.self_attn.o_proj.weight"] = torch.tensor( np.asarray( - training_state.params["params"]["decoder"]["layers"]["self_attention"]["out"]["kernel"][:, layer_int, :, :] + training_state.params["params"]["decoder"][f"layers_{layer_int}"]["self_attention"]["out"]["kernel"][:, :, :] .reshape(base_num_query_heads * head_dim, base_num_query_heads * head_dim) .T ), @@ -176,51 +172,51 @@ def convert_state_to_hf(training_state, model_size): # MLP Layers if num_experts is None: hf_model_params[f"model.layers.{layer_int}.mlp.gate_proj.weight"] = torch.tensor( - np.asarray(training_state.params["params"]["decoder"]["layers"]["mlp"]["wi_0"]["kernel"][:, layer_int, :].T), + np.asarray(training_state.params["params"]["decoder"][f"layers_{layer_int}"]["mlp"]["wi_0"]["kernel"][:, :].T), dtype=torch.float16, ) hf_model_params[f"model.layers.{layer_int}.mlp.up_proj.weight"] = torch.tensor( - np.asarray(training_state.params["params"]["decoder"]["layers"]["mlp"]["wi_1"]["kernel"][:, layer_int, :].T), + np.asarray(training_state.params["params"]["decoder"][f"layers_{layer_int}"]["mlp"]["wi_1"]["kernel"][:, :].T), dtype=torch.float16, ) hf_model_params[f"model.layers.{layer_int}.mlp.down_proj.weight"] = torch.tensor( - np.asarray(training_state.params["params"]["decoder"]["layers"]["mlp"]["wo"]["kernel"][:, layer_int, :].T), + np.asarray(training_state.params["params"]["decoder"][f"layers_{layer_int}"]["mlp"]["wo"]["kernel"][:, :].T), dtype=torch.float16, ) else: hf_model_params[f"model.layers.{layer_int}.block_sparse_moe.gate.weight"] = torch.tensor( np.asarray( - training_state.params["params"]["decoder"]["layers"]["MoeBlock_0"]["gate"]["kernel"][:, layer_int, :].T + training_state.params["params"]["decoder"][f"layers_{layer_int}"]["MoeBlock_0"]["gate"]["kernel"][:, :].T ), dtype=torch.float16, ) for k in range(num_experts): hf_model_params[f"model.layers.{layer_int}.block_sparse_moe.experts.{k}.w1.weight"] = torch.tensor( - np.asarray(training_state.params["params"]["decoder"]["layers"]["MoeBlock_0"]["wi_0"][k, layer_int, :, :].T), + np.asarray(training_state.params["params"]["decoder"][f"layers_{layer_int}"]["MoeBlock_0"]["wi_0"][k, :, :].T), dtype=torch.float16, ) hf_model_params[f"model.layers.{layer_int}.block_sparse_moe.experts.{k}.w2.weight"] = torch.tensor( - np.asarray(training_state.params["params"]["decoder"]["layers"]["MoeBlock_0"]["wo"][k, layer_int, :, :].T), + np.asarray(training_state.params["params"]["decoder"][f"layers_{layer_int}"]["MoeBlock_0"]["wo"][k, :, :].T), dtype=torch.float16, ) hf_model_params[f"model.layers.{layer_int}.block_sparse_moe.experts.{k}.w3.weight"] = torch.tensor( - np.asarray(training_state.params["params"]["decoder"]["layers"]["MoeBlock_0"]["wi_1"][k, layer_int, :, :].T), + np.asarray(training_state.params["params"]["decoder"][f"layers_{layer_int}"]["MoeBlock_0"]["wi_1"][k, :, :].T), dtype=torch.float16, ) # Pre/post attention layer norm hf_model_params[f"model.layers.{layer_int}.input_layernorm.weight"] = torch.tensor( np.asarray( - training_state.params["params"]["decoder"]["layers"]["pre_self_attention_layer_norm"]["scale"][ - :, layer_int + training_state.params["params"]["decoder"][f"layers_{layer_int}"]["pre_self_attention_layer_norm"]["scale"][ + : 
].reshape(base_num_query_heads * head_dim) ), dtype=torch.float16, ) hf_model_params[f"model.layers.{layer_int}.post_attention_layernorm.weight"] = torch.tensor( np.asarray( - training_state.params["params"]["decoder"]["layers"]["post_self_attention_layer_norm"]["scale"][ - :, layer_int + training_state.params["params"]["decoder"][f"layers_{layer_int}"]["post_self_attention_layer_norm"]["scale"][ + : ].reshape(base_num_query_heads * head_dim) ), dtype=torch.float16, ) @@ -253,6 +249,7 @@ def convert_orbax_hf(hf_model_path, config): def main(argv: Sequence[str]): pyconfig.initialize(argv[:-1]) + # Assuming the last argument is the path to save the converted checkpoint in HuggingFace format hf_model_path = argv[-1].split("=")[1] print(f"Will save converted HuggingFace checkpoint to path = {hf_model_path}") diff --git a/MaxText/tests/orbax_to_hf_logit_checker.py b/MaxText/tests/orbax_to_hf_logit_checker.py new file mode 100644 index 000000000..9fd0cff5d --- /dev/null +++ b/MaxText/tests/orbax_to_hf_logit_checker.py @@ -0,0 +1,154 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +# This orbax_to_hf_logit_checker.py file compares the logits generated by the HuggingFace checkpoint that was converted from a +# MaxText Orbax checkpoint with the golden logits for some input prompts for a particular model. This orbax_to_hf_logit_checker.py is generic in that +# it can work with different models and expects an input file called golden_data_{model_name}.jsonl to be present +# under MaxText/test_assets +# For e.g., MaxText/test_assets/golden_data_llama2-7b.jsonl +# The golden jsonl file is a simple jsonlines file where each line is a dictionary containing the following +# required keys: +# 1. prompt: A string representing the prompt, for e.g., "I love to", +# 2. tokens: token ids after tokenizing the prompt, +# 3.
logits: golden logits, i.e., the reference logits generated by the model in question when fed the prompt in #1 +# There can be multiple such test cases in the jsonl file, each test case is a new line in the jsonl file +# This orbax_to_hf_logit_checker.py runs the forward pass with the input tokens through the converted HuggingFace checkpoint +# and asserts that the logits it generates match the golden logits closely +# Users could use a script similar to MaxText/scratch_code/golden_llama2-7b_export.ipynb to create this jsonl file + +"""Check if the logits generated by a model's MaxText Orbax-to-HF converted checkpoint match the golden logits for the same inputs""" +import argparse +import sys +import os + +current_dir = os.path.dirname(os.path.abspath(__file__)) +maxtext_parent_dir = os.path.dirname(current_dir) +sys.path.append(maxtext_parent_dir) + +import max_logging + +max_logging.log(f"Added parent directory = {maxtext_parent_dir}") + +import common_types +import jax +import jax.numpy as jnp +import numpy as np +import pyconfig +import jsonlines +import max_utils +from layers import models +from layers import quantizations + +from transformers import AutoTokenizer, AutoModelForCausalLM +import torch + +hf_model_name_map = {"llama2-7b": "meta-llama/Llama-2-7b-hf"} + + +def get_data(golden_data, golden_data_index, config): + """Get the golden data for the test indexed at golden_data_index""" + + max_logging.log(f"Comparing forward pass for golden data index = {golden_data_index} ") + max_logging.log(f"config.global_batch_size_to_train_on={config.global_batch_size_to_train_on}") + s = (config.global_batch_size_to_train_on, config.max_target_length) + ids = np.asarray(golden_data[golden_data_index]["tokens"], dtype=np.int32) + + logits = np.asarray(golden_data[golden_data_index]["logits"], dtype=np.float32) + max_logging.log(f" prompt=\"{golden_data[golden_data_index]['prompt']}\" raw ids={ids}, logits.shape = {logits.shape}") + + decoder_segment_ids = jax.numpy.zeros(s) + common_types.DECODING_ACTIVE_SEQUENCE_INDICATOR + decoder_positions = jnp.stack( + [jnp.arange(config.max_target_length, dtype=jnp.int32) for _ in range(config.global_batch_size_to_train_on)] + ) + + ids = jnp.stack([ids for _ in range(config.global_batch_size_to_train_on)]) + max_logging.log(f"ids={ids}, decoder_segment_ids = {decoder_segment_ids}, decoder_positions= {decoder_positions}") + + return ids, decoder_segment_ids, decoder_positions, logits + + +def main(config, test_args): + """Run a forward pass of the converted HuggingFace checkpoint and compare its logits with the golden logits""" + + input_golden_data_path = "MaxText/test_assets/golden_data_" + config.model_name + ".jsonl" + with jsonlines.open(input_golden_data_path, "r") as f: + golden_data = list(f) + + model_name_hf = hf_model_name_map[config.model_name] + tokenizer = AutoTokenizer.from_pretrained(model_name_hf) + + for golden_data_index in range(len(golden_data)): + ids, decoder_segment_ids, decoder_positions, golden_logits = get_data(golden_data, golden_data_index, config) + + maxtext_model = AutoModelForCausalLM.from_pretrained(test_args.maxtext_hf_model_path) + + full_train_logits = ( + maxtext_model(torch.tensor(ids.tolist(), requires_grad=False), output_hidden_states=True).logits.detach().numpy() + ) + + full_train_logits = jax.experimental.multihost_utils.process_allgather(full_train_logits) + max_logging.log(f"{golden_logits[0]=}") + max_logging.log(f"{full_train_logits[0, 0, :]=}") + token_size = int(test_args.token_size) if test_args.token_size else golden_logits.shape[0] + max_logging.log( + f"Max Numerical Difference
{np.max(np.subtract(full_train_logits[0, :token_size, :], golden_logits[:token_size, :]))}" + ) + + model_probabilities = jax.nn.softmax(full_train_logits[0, :token_size, :], axis=-1) + golden_probabilities = jax.nn.softmax(golden_logits[:token_size, :], axis=-1) + + max_logging.log(f"{golden_probabilities[0]=}") + max_logging.log(f"{model_probabilities[0]=}") + + kl_div = jax.numpy.sum(jax.scipy.special.kl_div(golden_probabilities, model_probabilities), axis=-1) + max_logging.log(f"KL divergence = {kl_div}, max KL divergence = {jax.numpy.max(kl_div)}") + + if test_args.max_kl_div is not None: + max_logging.log("Checking KL Divergence between train distribution and golden distribution") + assert jax.numpy.all( + kl_div < test_args.max_kl_div + ), f"KL divergence values exceed the specified threshold of {test_args.max_kl_div}. Max divergence: {jax.numpy.max(kl_div)}" + else: + max_logging.log("Checking Numerical Differences between train logits and golden logits") + assert jax.numpy.allclose( + full_train_logits[0, :token_size, :], + golden_logits[:token_size, :], + rtol=float(test_args.rtol), + atol=float(test_args.atol), + equal_nan=False, + ), f"Logits do not match closely enough. Required rtol={test_args.rtol}, atol={test_args.atol}." + + +if __name__ == "__main__": + jax.config.update("jax_default_prng_impl", "unsafe_rbg") + os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0" + + parser = argparse.ArgumentParser() + parser.add_argument("--atol", type=float, required=False, default=0.1) + parser.add_argument("--rtol", type=float, required=False, default=0.1) + parser.add_argument("--token_size", type=int, required=False) + parser.add_argument("--max_kl_div", type=float, required=False, default=None) + parser.add_argument("--maxtext-hf-model-path", type=str, required=True, default=None) + test_args, _ = parser.parse_known_args() + + # Remove args defined in this test file to avoid error from pyconfig + model_args = sys.argv + to_remove_args = ["--atol", "--rtol", "--token_size", "--max_kl_div", "--maxtext-hf-model-path"] + for arg in to_remove_args: + model_args = [s for s in model_args if not s.startswith(arg)] + + pyconfig.initialize(model_args) + cfg = pyconfig.config + main(cfg, test_args) diff --git a/end_to_end/tpu/llama2/7b/test_llama2_7b.sh b/end_to_end/tpu/llama2/7b/test_llama2_7b.sh index de361f99f..c4d1982af 100644 --- a/end_to_end/tpu/llama2/7b/test_llama2_7b.sh +++ b/end_to_end/tpu/llama2/7b/test_llama2_7b.sh @@ -72,3 +72,10 @@ python3 MaxText/decode.py MaxText/configs/base.yml load_parameters_path=${NEW_CK # We also test whether the forward pass logits match the golden logits for Llama2-7b python3 MaxText/tests/forward_pass_logit_checker.py MaxText/configs/base.yml load_parameters_path=${UNSCANNED_CKPT_PATH} run_name=forward_pass_test per_device_batch_size=1 model_name=llama2-7b ici_tensor_parallelism=4 max_prefill_predict_length=4 max_target_length=4 dataset_type=synthetic dtype=float32 scan_layers=false + +# We convert an Orbax checkpoint generated by MaxText back to Huggingface format +export ORBAX_TO_HF_CHECKPOINT_CONVERSION=orbax2hf_checkpoint_conversion_${idx} +python3 MaxText/llama_mistral_mixtral_orbax_to_hf.py MaxText/configs/base.yml base_output_directory=${BASE_OUTPUT_DIRECTORY} load_parameters_path=${NEW_CKPT_PATH} scan_layers=false run_name=${ORBAX_TO_HF_CHECKPOINT_CONVERSION} model_name='llama2-7b' hf_model_path=/tmp/llama2-7b-ckpt/${ORBAX_TO_HF_CHECKPOINT_CONVERSION} + +# We test that the logits generated by a forward pass from the new huggingface format checkpoint
match the golden logits for Llama2-7b +python3 MaxText/tests/orbax_to_hf_logit_checker.py MaxText/configs/base.yml run_name=orbax_to_hf_logit_checker per_device_batch_size=1 model_name=llama2-7b ici_tensor_parallelism=4 max_prefill_predict_length=4 max_target_length=4 dataset_type=synthetic dtype=float32 scan_layers=false --maxtext-hf-model-path=/tmp/llama2-7b-ckpt/${ORBAX_TO_HF_CHECKPOINT_CONVERSION} \ No newline at end of file From c2b0f911604073248797ccc89c47ae69069ad1d6 Mon Sep 17 00:00:00 2001 From: A9isha Date: Thu, 19 Dec 2024 21:25:44 +0000 Subject: [PATCH 2/3] change dtype --- MaxText/llama_mistral_mixtral_orbax_to_hf.py | 35 +++++++++++--------- MaxText/tests/orbax_to_hf_logit_checker.py | 4 ++- end_to_end/tpu/llama2/7b/test_llama2_7b.sh | 2 +- 3 files changed, 23 insertions(+), 18 deletions(-) diff --git a/MaxText/llama_mistral_mixtral_orbax_to_hf.py b/MaxText/llama_mistral_mixtral_orbax_to_hf.py index 6d4ed92c1..1912c7680 100644 --- a/MaxText/llama_mistral_mixtral_orbax_to_hf.py +++ b/MaxText/llama_mistral_mixtral_orbax_to_hf.py @@ -123,11 +123,14 @@ def convert_state_to_hf(training_state, model_size): hf_model_params = {} + converted_dtype = torch.bfloat16 if model_size[:6] == "llama3" else torch.float16 + # Port the embedding weights hf_model_params["model.embed_tokens.weight"] = torch.tensor( - np.asarray(training_state.params["params"]["token_embedder"]["embedding"]), dtype=torch.float16 + np.asarray(training_state.params["params"]["token_embedder"]["embedding"]), dtype=converted_dtype ) + for layer_int in tqdm(range(base_num_decoder_layers), desc="Porting parameters layerwise"): print(f"Converting weights for layer {layer_int}") @@ -140,7 +143,7 @@ def convert_state_to_hf(training_state, model_size): intermediate_query = unpermute_from_match_maxtext_rope(intermediate_query) hf_model_params[f"model.layers.{layer_int}.self_attn.q_proj.weight"] = torch.tensor( np.asarray(intermediate_query.reshape(base_num_query_heads * head_dim, base_num_query_heads * head_dim).T), - dtype=torch.float16, + dtype=converted_dtype, ) intermediate_key = training_state.params["params"]["decoder"][f"layers_{layer_int}"]["self_attention"]["key"]["kernel"][ @@ -150,7 +153,7 @@ def convert_state_to_hf(training_state, model_size): intermediate_key = unpermute_from_match_maxtext_rope(intermediate_key) hf_model_params[f"model.layers.{layer_int}.self_attn.k_proj.weight"] = torch.tensor( np.asarray(intermediate_key.reshape(base_num_query_heads * head_dim, base_num_kv_heads * head_dim).T), - dtype=torch.float16, + dtype=converted_dtype, ) hf_model_params[f"model.layers.{layer_int}.self_attn.v_proj.weight"] = torch.tensor( np.asarray( @@ -158,7 +161,7 @@ def convert_state_to_hf(training_state, model_size): .reshape(base_num_query_heads * head_dim, base_num_kv_heads * head_dim) .T ), - dtype=torch.float16, + dtype=converted_dtype, ) hf_model_params[f"model.layers.{layer_int}.self_attn.o_proj.weight"] = torch.tensor( np.asarray( @@ -166,42 +169,42 @@ def convert_state_to_hf(training_state, model_size): .reshape(base_num_query_heads * head_dim, base_num_query_heads * head_dim) .T ), - dtype=torch.float16, + dtype=converted_dtype, ) # MLP Layers if num_experts is None: hf_model_params[f"model.layers.{layer_int}.mlp.gate_proj.weight"] = torch.tensor( np.asarray(training_state.params["params"]["decoder"][f"layers_{layer_int}"]["mlp"]["wi_0"]["kernel"][:, :].T), - dtype=torch.float16, + dtype=converted_dtype, ) hf_model_params[f"model.layers.{layer_int}.mlp.up_proj.weight"] = torch.tensor( 
np.asarray(training_state.params["params"]["decoder"][f"layers_{layer_int}"]["mlp"]["wi_1"]["kernel"][:, :].T), - dtype=torch.float16, + dtype=converted_dtype, ) hf_model_params[f"model.layers.{layer_int}.mlp.down_proj.weight"] = torch.tensor( np.asarray(training_state.params["params"]["decoder"][f"layers_{layer_int}"]["mlp"]["wo"]["kernel"][:, :].T), - dtype=torch.float16, + dtype=converted_dtype, ) else: hf_model_params[f"model.layers.{layer_int}.block_sparse_moe.gate.weight"] = torch.tensor( np.asarray( training_state.params["params"]["decoder"][f"layers_{layer_int}"]["MoeBlock_0"]["gate"]["kernel"][:, :].T ), - dtype=torch.float16, + dtype=converted_dtype, ) for k in range(num_experts): hf_model_params[f"model.layers.{layer_int}.block_sparse_moe.experts.{k}.w1.weight"] = torch.tensor( np.asarray(training_state.params["params"]["decoder"][f"layers_{layer_int}"]["MoeBlock_0"]["wi_0"][k, :, :].T), - dtype=torch.float16, + dtype=converted_dtype, ) hf_model_params[f"model.layers.{layer_int}.block_sparse_moe.experts.{k}.w2.weight"] = torch.tensor( np.asarray(training_state.params["params"]["decoder"][f"layers_{layer_int}"]["MoeBlock_0"]["wo"][k, :, :].T), - dtype=torch.float16, + dtype=converted_dtype, ) hf_model_params[f"model.layers.{layer_int}.block_sparse_moe.experts.{k}.w3.weight"] = torch.tensor( np.asarray(training_state.params["params"]["decoder"][f"layers_{layer_int}"]["MoeBlock_0"]["wi_1"][k, :, :].T), - dtype=torch.float16, + dtype=converted_dtype, ) # Pre/post attention layer norm @@ -211,7 +214,7 @@ def convert_state_to_hf(training_state, model_size): : ].reshape(base_num_query_heads * head_dim) ), - dtype=torch.float16, + dtype=converted_dtype, ) hf_model_params[f"model.layers.{layer_int}.post_attention_layernorm.weight"] = torch.tensor( np.asarray( @@ -219,18 +222,18 @@ def convert_state_to_hf(training_state, model_size): : ].reshape(base_num_query_heads * head_dim) ), - dtype=torch.float16, + dtype=converted_dtype, ) # LM head and layernorm hf_model_params["lm_head.weight"] = torch.tensor( - np.asarray(training_state.params["params"]["decoder"]["logits_dense"]["kernel"].T), dtype=torch.float16 + np.asarray(training_state.params["params"]["decoder"]["logits_dense"]["kernel"].T), dtype=converted_dtype ) hf_model_params["model.norm.weight"] = torch.tensor( np.asarray( training_state.params["params"]["decoder"]["decoder_norm"]["scale"].reshape(base_num_query_heads * head_dim) ), - dtype=torch.float16, + dtype=converted_dtype, ) return hf_model_params diff --git a/MaxText/tests/orbax_to_hf_logit_checker.py b/MaxText/tests/orbax_to_hf_logit_checker.py index 9fd0cff5d..c820a3f13 100644 --- a/MaxText/tests/orbax_to_hf_logit_checker.py +++ b/MaxText/tests/orbax_to_hf_logit_checker.py @@ -54,7 +54,9 @@ from transformers import AutoTokenizer, AutoModelForCausalLM import torch -hf_model_name_map = {"llama2-7b": "meta-llama/Llama-2-7b-hf"} +hf_model_name_map = {"llama2-7b": "meta-llama/Llama-2-7b-hf", + "llama3.1-8b": "meta-llama/Meta-Llama-3.1-8B" + } def get_data(golden_data, golden_data_index, config): diff --git a/end_to_end/tpu/llama2/7b/test_llama2_7b.sh b/end_to_end/tpu/llama2/7b/test_llama2_7b.sh index c4d1982af..31c7481eb 100644 --- a/end_to_end/tpu/llama2/7b/test_llama2_7b.sh +++ b/end_to_end/tpu/llama2/7b/test_llama2_7b.sh @@ -78,4 +78,4 @@ export ORBAX_TO_HF_CHECKPOINT_CONVERSION=orbax2hf_checkpoint_conversion_${idx} python3 MaxText/llama_mistral_mixtral_orbax_to_hf.py MaxText/configs/base.yml base_output_directory=${BASE_OUTPUT_DIRECTORY} load_parameters_path=${NEW_CKPT_PATH} 
scan_layers=false run_name=${ORBAX_TO_HF_CHECKPOINT_CONVERSION} model_name='llama2-7b' hf_model_path=/tmp/llama2-7b-ckpt/${ORBAX_TO_HF_CHECKPOINT_CONVERSION} # We test that the logits generated by a forward pass from the new huggingface format checkpoint match the golden logits for Llama2-7b -python3 MaxText/tests/orbax_to_hf_logit_checker.py MaxText/configs/base.yml run_name=orbax_to_hf_logit_checker per_device_batch_size=1 model_name=llama2-7b ici_tensor_parallelism=4 max_prefill_predict_length=4 max_target_length=4 dataset_type=synthetic dtype=float32 scan_layers=false --maxtext-hf-model-path=/tmp/llama2-7b-ckpt/${ORBAX_TO_HF_CHECKPOINT_CONVERSION} \ No newline at end of file +python3 MaxText/tests/orbax_to_hf_logit_checker.py MaxText/configs/base.yml run_name=orbax_to_hf_logit_checker per_device_batch_size=1 model_name=llama2-7b ici_tensor_parallelism=4 max_prefill_predict_length=4 max_target_length=4 dataset_type=synthetic dtype=float32 scan_layers=false --maxtext-hf-model-path=/tmp/llama2-7b-ckpt/${ORBAX_TO_HF_CHECKPOINT_CONVERSION} From fc32a58ee6f0278480659a172e67a2963b599565 Mon Sep 17 00:00:00 2001 From: A9isha Date: Thu, 19 Dec 2024 21:28:56 +0000 Subject: [PATCH 3/3] fix Q K --- MaxText/llama_mistral_mixtral_orbax_to_hf.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/MaxText/llama_mistral_mixtral_orbax_to_hf.py b/MaxText/llama_mistral_mixtral_orbax_to_hf.py index 1912c7680..671dd0b13 100644 --- a/MaxText/llama_mistral_mixtral_orbax_to_hf.py +++ b/MaxText/llama_mistral_mixtral_orbax_to_hf.py @@ -145,6 +145,8 @@ def convert_state_to_hf(training_state, model_size): np.asarray(intermediate_query.reshape(base_num_query_heads * head_dim, base_num_query_heads * head_dim).T), dtype=converted_dtype, ) + hf_model_params[f"model.layers.{layer_int}.self_attn.q_proj.weight"] = hf_model_params[f"model.layers.{layer_int}.self_attn.q_proj.weight"].view(base_num_query_heads * head_dim, base_num_query_heads * head_dim).T.view(base_num_query_heads, head_dim // 2, 2, base_num_query_heads * head_dim).transpose(1, 2).reshape(-1, base_num_query_heads * head_dim) + intermediate_key = training_state.params["params"]["decoder"][f"layers_{layer_int}"]["self_attention"]["key"]["kernel"][ :, :, : @@ -155,6 +157,11 @@ def convert_state_to_hf(training_state, model_size): np.asarray(intermediate_key.reshape(base_num_query_heads * head_dim, base_num_kv_heads * head_dim).T), dtype=converted_dtype, ) + + hf_model_params[f"model.layers.{layer_int}.self_attn.k_proj.weight"] = hf_model_params[f"model.layers.{layer_int}.self_attn.k_proj.weight"].view(base_num_query_heads * head_dim, base_num_kv_heads * head_dim).T.reshape(base_num_kv_heads, head_dim // 2, 2, base_num_query_heads * head_dim).transpose(1, 2).reshape(-1 ,base_num_query_heads * head_dim) + + + hf_model_params[f"model.layers.{layer_int}.self_attn.v_proj.weight"] = torch.tensor( np.asarray( training_state.params["params"]["decoder"][f"layers_{layer_int}"]["self_attention"]["value"]["kernel"][:, :, :]