From 1bfa18a654a51fd85da04fd183dfb590a922173c Mon Sep 17 00:00:00 2001
From: diaoshizhe <654745845@qq.com>
Date: Sat, 1 Apr 2023 10:39:25 +0800
Subject: [PATCH 1/3] only main process requires input

---
 examples/chatbot.py                   | 11 +++++++++--
 scripts/run_chatbot.sh                |  6 +++---
 src/lmflow/models/hf_decoder_model.py | 12 ++++++++----
 3 files changed, 20 insertions(+), 9 deletions(-)

diff --git a/examples/chatbot.py b/examples/chatbot.py
index 7d835acdf..20b205978 100644
--- a/examples/chatbot.py
+++ b/examples/chatbot.py
@@ -14,7 +14,7 @@
 from lmflow.pipeline.auto_pipeline import AutoPipeline
 from lmflow.models.auto_model import AutoModel
 from lmflow.args import ModelArguments, DatasetArguments, AutoArguments
-
+import torch.distributed as dist
 
 logging.disable(logging.ERROR)
 warnings.filterwarnings("ignore")
@@ -80,7 +80,14 @@ def main():
     end_string = "\n\n"
 
     while True:
-        input_text = input("User >>> ")
+        if dist.get_rank() == 0:
+            input_text = input("User >>> ")
+            dist.broadcast_object_list([input_text])
+        else:
+            recev_object = [None] * 1
+            dist.broadcast_object_list([recev_object])
+            input_text = recev_object[0]
+
         if not input_text:
             print("exit...")
             break
diff --git a/scripts/run_chatbot.sh b/scripts/run_chatbot.sh
index 3c6ee4f98..c856e84af 100755
--- a/scripts/run_chatbot.sh
+++ b/scripts/run_chatbot.sh
@@ -1,6 +1,6 @@
 #!/bin/bash
 
-model=gpt2
+model=circulus/llama-13b
 lora_args=""
 if [ $# -ge 1 ]; then
   model=$1
@@ -9,8 +9,8 @@ if [ $# -ge 2 ]; then
   lora_args="--lora_model_path $2"
 fi
 
-CUDA_VISIBLE_DEVICES=0 \
-  deepspeed examples/chatbot.py \
+CUDA_VISIBLE_DEVICES=2,3 \
+  deepspeed --master_port=10000 examples/chatbot.py \
   --deepspeed configs/ds_config_chatbot.json \
   --model_name_or_path ${model} \
   ${lora_args}
diff --git a/src/lmflow/models/hf_decoder_model.py b/src/lmflow/models/hf_decoder_model.py
index 8a68443fd..032e53621 100644
--- a/src/lmflow/models/hf_decoder_model.py
+++ b/src/lmflow/models/hf_decoder_model.py
@@ -48,6 +48,8 @@
 from lmflow.models.decoder_model import DecoderModel
 from lmflow.models.interfaces.tunable import Tunable
 
+from transformers.models.llama.modeling_llama import LlamaDecoderLayer
+
 
 class HFDecoderModel(DecoderModel, Tunable):
     r"""
@@ -179,7 +181,6 @@ def __init__(
 
             self.tune_strategy = tune_strategy
         elif tune_strategy == 'none':
-            dschf = HfDeepSpeedConfig(ds_config)
             self.backend_model = AutoModelForCausalLM.from_pretrained(model_args.model_name_or_path)
             self.tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path)
             peft_model_id = model_args.lora_model_path
@@ -188,9 +189,12 @@ def __init__(
                 self.backend_model,
                 peft_model_id
             )
-            deepspeed.init_distributed()
-            self.ds_engine = deepspeed.initialize(model=self.backend_model, config_params=ds_config)[0]
-            self.ds_engine.module.eval()
+            self.ds_engine = deepspeed.init_inference(
+                self.backend_model,
+                mp_size=2,
+                dtype=torch.half,
+                injection_policy={LlamaDecoderLayer: ('self_attn.o_proj', 'mlp.down_proj')}
+            )
 
         elif tune_strategy == 'adapter':
             raise NotImplementedError('adapter tune strategy not implemented')

From fa77d0cbf702f828ec95adeac790efa84bc1b0db Mon Sep 17 00:00:00 2001
From: diaoshizhe <654745845@qq.com>
Date: Sat, 1 Apr 2023 11:05:41 +0800
Subject: [PATCH 2/3] exit by entering "exit"

---
 examples/chatbot.py    | 10 ++++++----
 scripts/run_chatbot.sh |  2 +-
 2 files changed, 7 insertions(+), 5 deletions(-)

diff --git a/examples/chatbot.py b/examples/chatbot.py
index 20b205978..063293d5d 100644
--- a/examples/chatbot.py
+++ b/examples/chatbot.py
@@ -70,7 +70,8 @@ def main():
f"#############################################################################\n" "\n" ) - print(guide_message, end="") + if dist.get_rank() == 0: + print(guide_message, end="") # context = ( # "You are a helpful assistant who follows the given instructions" @@ -85,10 +86,10 @@ def main(): dist.broadcast_object_list([input_text]) else: recev_object = [None] * 1 - dist.broadcast_object_list([recev_object]) + dist.broadcast_object_list(recev_object) input_text = recev_object[0] - if not input_text: + if input_text == "exit": print("exit...") break @@ -115,7 +116,8 @@ def main(): index = response.index(end_string) response = response[:index + 1] - print("Bot: " + response, end="") + if dist.get_rank() == 0: + print("Bot: " + response, end="") context += response context = context[-model.get_max_length():] # Memory of the bot diff --git a/scripts/run_chatbot.sh b/scripts/run_chatbot.sh index c856e84af..7445f1a40 100755 --- a/scripts/run_chatbot.sh +++ b/scripts/run_chatbot.sh @@ -1,6 +1,6 @@ #!/bin/bash -model=circulus/llama-13b +model=aleksickx/llama-7b-hf lora_args="" if [ $# -ge 1 ]; then model=$1 From 809a67716934210307eba0c93ff649d26823cb22 Mon Sep 17 00:00:00 2001 From: diaoshizhe <654745845@qq.com> Date: Sat, 1 Apr 2023 13:37:47 +0800 Subject: [PATCH 3/3] update --- scripts/run_chatbot.sh | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/scripts/run_chatbot.sh b/scripts/run_chatbot.sh index 7445f1a40..3c6ee4f98 100755 --- a/scripts/run_chatbot.sh +++ b/scripts/run_chatbot.sh @@ -1,6 +1,6 @@ #!/bin/bash -model=aleksickx/llama-7b-hf +model=gpt2 lora_args="" if [ $# -ge 1 ]; then model=$1 @@ -9,8 +9,8 @@ if [ $# -ge 2 ]; then lora_args="--lora_model_path $2" fi -CUDA_VISIBLE_DEVICES=2,3 \ - deepspeed --master_port=10000 examples/chatbot.py \ +CUDA_VISIBLE_DEVICES=0 \ + deepspeed examples/chatbot.py \ --deepspeed configs/ds_config_chatbot.json \ --model_name_or_path ${model} \ ${lora_args}