From ebbbfb0655e5a5a5c37c2ee548b3a1629056f3e1 Mon Sep 17 00:00:00 2001 From: LiuXL Date: Fri, 17 Jan 2025 15:04:49 +0800 Subject: [PATCH 01/52] mcts init --- swift/experimental/sampling/mcts.py | 339 ++++++++++++++++++- swift/experimental/sampling/sampling.py | 6 + swift/experimental/sampling/sampling_args.py | 16 +- swift/plugin/prm.py | 12 +- 4 files changed, 368 insertions(+), 5 deletions(-) diff --git a/swift/experimental/sampling/mcts.py b/swift/experimental/sampling/mcts.py index 464090415c..e8bd14e066 100644 --- a/swift/experimental/sampling/mcts.py +++ b/swift/experimental/sampling/mcts.py @@ -1 +1,338 @@ -# TODO +from copy import deepcopy +import numpy as np + +from swift.llm import InferRequest +from swift.llm.infer.protocol import UsageInfo +from swift.utils import get_logger + +from .base import Sampler +from .utils import get_reward +from .sampling_args import SamplingArguments + +from typing import Union, List + +logger = get_logger('./output/sampler/mcts.log') + + +SYS_PROMPT = """You are a super intelligent AI, you can solve any math problem step by step. + +REMEMBER: Each step should stop with a 'ки'. Final answer should start with '# Answer'. + +Here is an example: + +user +Janet pays $40/hour for 3 hours per week of clarinet lessons and $28/hour for 5 hours a week of piano lessons. How much more does she spend on piano lessons than clarinet lessons in a year? + +assistant +Step 1: Janet spends 3 hours + 5 hours = <<3+5=8>>8 hours per week on music lessons. ки +Step 2: She spends 40 * 3 = <<40*3=120>>120 on clarinet lessons per week. ки +Step 3: She spends 28 * 5 = <<28*5=140>>140 on piano lessons per week. ки +Step 4: Janet spends 120 + 140 = <<120+140=260>>260 on music lessons per week. ки +Step 5: She spends 260 * 52 = <<260*52=13520>>13520 on music lessons in a year. ки +# Answer 13520 ки + +Now answer the question: +""" + +SEP_TOKEN = "ки" + +system_message = { + "role": "system", + "content": SYS_PROMPT, +} + +def check_terminate(answers: Union[str, List[str]]) -> List[bool]: + if isinstance(answers, str): + answers = [answers] + results = [] + for answer in answers: + results.append("# Answer" in answer) + return results + +class LanguageNode: + + def __init__(self, + step: str = None, + parent: "LanguageNode" = None,): + self.parent = parent + if parent: + self.path = parent.path[:] + [step] + self.answer = parent.answer + step + SEP_TOKEN + self.depth = parent.depth + 1 + else: + self.path = [] + self.answer = "" + self.depth = 0 + + self.active_children = [] + self.children = [] + self.visit_count = 0 + self.process_reward = 0.0 + self.outcome_reward = 0.0 + self.terminated = False + + def is_leaf(self): + return len(self.children) == 0 + + def is_root(self): + return self.parent is None + + def visit(self): + self.visit_count += 1 + + def init_and_update_value(self, value): + self.outcome_reward = (self.outcome_reward * self.visit_count + value) / (self.visit_count + 1) + + def add_child(self, child: "LanguageNode"): + self.children.append(child) + if not child.terminated: + self.active_children.append(child) + + def __lt__(self, other): + return self.outcome_reward < other.outcome_reward + + +class MctsSampler(Sampler): + + def __init__(self, input_args: SamplingArguments): + super().__init__(input_args) + self.usage_info = UsageInfo(0,0,0) + + def _prepare_model_tokenizer(self): + args = self.args + self.infer_kwargs = {} + if args.sampler_engine == 'client': + import os + from swift.llm import InferClient + api_key = os.getenv('DASHSCOPE_API_KEY') + base_url = 'https://dashscope.aliyuncs.com/compatible-mode/v1' + self.infer_engine = InferClient(base_url=base_url, api_key=api_key) + self.infer_kwargs['model'] = args.model + else: + _Engine = self.get_infer_engine() + self.infer_engine = _Engine(self.args.model, model_type=self.args.model_type, **self.args.engine_kwargs) + + def get_infer_engine(self): + if self.args.sampler_engine == 'pt': + from swift.llm import PtEngine + _Engine = PtEngine + elif self.args.sampler_engine == 'vllm': + from swift.llm import VllmEngine + _Engine = VllmEngine + elif self.args.sampler_engine == 'lmdeploy': + from swift.llm import LmdeployEngine + _Engine = LmdeployEngine + elif self.args.sampler_engine == 'no': + _Engine = None + else: + raise ValueError(f'Cannot find engine name: {self.args.sampler_engine}') + return _Engine + + def _prepare_template(self) -> None: + pass + + def search_single(self, query, ground_truth): + def _UCT(node: LanguageNode): + alpha = _args.process_reward_rate + value = alpha * node.process_reward + (1 - alpha) * node.outcome_reward + if node.is_root(): + return value + + exploitation_score = value + exploration_score = (_args.exploration_rate + * np.sqrt(np.log(node.parent.visit_count) / (node.visit_count + 1))) + + return exploration_score + exploitation_score + + def _select(node: LanguageNode): + while not node.is_leaf(): + node = max(node.active_children, key=lambda x: _UCT(x)) + return node + + def _expand(node: LanguageNode): + prompt_message = { + "role": "user", + "content": query, + } + if node.is_root(): + infer_request = InferRequest([system_message, prompt_message]) + else: + history_message = { + "role": "assistant", + "content": node.answer, + } + infer_request = InferRequest([system_message, prompt_message, history_message]) + expand_request_config = deepcopy(request_config) + n = _args.num_return_sequences - len(node.children) + while n > 0: + expand_request_config.n = n if n <= 4 else 4 + n -= expand_request_config.n + expand_request_config.num_return_sequences = expand_request_config.n + expand_request_config.num_beams = expand_request_config.n + expand_request_config.seed += 1 + responses = self.infer_engine.infer( + [infer_request], + expand_request_config, + **self.infer_kwargs, + ) + for key, value in self.usage_info.__dict__.items(): + update_value = getattr(responses[0].usage, key, None) + value + setattr(self.usage_info, key, update_value) + for choice in responses[0].choices: + output = choice.message.content.rstrip(SEP_TOKEN + '\n') + output = output.split(SEP_TOKEN)[0] + child = LanguageNode(step=output, parent=node) + if check_terminate(child.answer)[0]: + child._terminated = True + orm_infer_requests = [InferRequest([{"role": "assistant", "content": output}])] + orm_score, _orm_mask = get_reward( + self.orm_model, orm_infer_requests, ground_truths=[ground_truth] * len(orm_infer_requests), + threshold=0.0) + child.init_and_update_value(orm_score[0]) + if child.outcome_reward == 1: + terminate_correct.append(child.answer) + else: + terminate_incorrect.append(child.answer) + node.add_child(child) + if self.prm_model: + prm_infer_requests = [] + for child in node.children: + prm_message = {"role": "assistant", "content": child.answer} + prm_infer_requests.append(InferRequest([prompt_message, prm_message])) + prm_score, _prm_mask = get_reward( + self.prm_model, + prm_infer_requests, + ground_truths=[ground_truth] * len(prm_infer_requests), + threshold=0.0) + for child, score in zip(node.children, prm_score): + child.process_reward = score + + def _rollout(node: LanguageNode): + rollout_iter_index = 0 + prompt_message = { + "role": "user", + "content": query, + } + rollout_request_config = deepcopy(request_config) + rollout_request_config.temperature = 0.0 + rollout_request_config.max_tokens = 500 + rollout_nodes = node.active_children[:] + history_messages = [] + for child in rollout_nodes: + history_message = { + "role": "assistant", + "content": child.answer, + } + history_messages.append(history_message) + while len(rollout_nodes) > 0 and rollout_iter_index < _args.max_rollout_iterations: + infer_requests = [InferRequest([system_message, prompt_message, h]) for h in history_messages] + # Because template will pop out last assistant message, so add an additional one. + responses = self.infer_engine.infer(infer_requests, rollout_request_config, **self.infer_kwargs) + rollout_iter_index += 1 + rollout_node_index = 0 + for index, response in enumerate(responses): + for key, value in self.usage_info.__dict__.items(): + update_value = getattr(response.usage, key, None) + value + setattr(self.usage_info, key, update_value) + output = response.choices[0].message.content.rstrip(SEP_TOKEN + '\n') + output = output.split(SEP_TOKEN)[0] + output += SEP_TOKEN + '\n' + history_messages[rollout_node_index]["content"] += output + end_path = history_messages[rollout_node_index]["content"] + if check_terminate(end_path)[0]: + orm_infer_requests = [InferRequest([history_messages[rollout_node_index]])] + orm_response, _orm_mask = get_reward( + self.orm_model, orm_infer_requests, ground_truths=[ground_truth] * len(orm_infer_requests), + threshold=0.0) + orm_score = float(orm_response[0].choices[0].message.content) + node.active_children[index].outcome_reward = orm_score + if orm_score == 1: + correct_answers.append(end_path) + else: + incorrect_answers.append(end_path) + rollout_nodes.pop(rollout_node_index) + history_messages.pop(rollout_node_index) + rollout_node_index -= 1 + rollout_node_index += 1 + + def _back_propagate(curr_node: LanguageNode): + while curr_node: + best_child_value = max([child.outcome_reward for child in curr_node.children]) + curr_node.init_and_update_value(best_child_value) + curr_node.visit() + curr_node = curr_node.parent + + def _collect(curr_node: LanguageNode): + if curr_node.is_leaf(): + return [] + results = [] + for child in curr_node.children: + results += _collect(child) + curr_node.children = sorted(curr_node.children) + if curr_node.children[-1].outcome_reward - curr_node.children[0].outcome_reward > 0.6: + results.append({ + "query": query, + "path": curr_node.path, + "good": curr_node.children[-1].path[-1], + "good_score": curr_node.children[-1].outcome_reward, + "bad": curr_node.children[0].path[-1], + "bad_score": curr_node.children[0].outcome_reward, + }) + return results + + _args = self.args + request_config = _args.get_request_config() + request_config.stop = [SEP_TOKEN] + request_config.seed = _args.seed + _root = LanguageNode() + + correct_answers, incorrect_answers, prefer_pair = [], [], [] + terminate_correct, terminate_incorrect = [], [] + too_easy, too_hard = False, False + iter_count = 0 + while (not too_easy and not too_hard + and len(terminate_incorrect) + len(terminate_correct) < _args.num_return_sequences + and iter_count < _args.max_iterations): + logger.info(f"iter_count: {iter_count}" + "." * 10) + logger.info("select" + "=" * 10) + curr_node = _select(_root) + logger.info("expand" + "=" * 10) + _expand(curr_node) + if curr_node.depth > 3: + logger.info("rollout" + "=" * 10) + _rollout(curr_node) + logger.info("back propagate" + "=" * 10) + _back_propagate(curr_node) + if len(correct_answers) + len(incorrect_answers) >= _args.num_return_sequences: + if 4 * len(incorrect_answers) < len(correct_answers): + logger.info("too easy" + "!" * 20) + logger.info(f"correct_answers: {correct_answers}") + logger.info(f"incorrect_answers: {incorrect_answers}") + too_easy = True + elif 4 * len(correct_answers) < len(incorrect_answers): + logger.info("too hard" + "!" * 20) + logger.info(f"correct_answers: {correct_answers}") + logger.info(f"incorrect_answers: {incorrect_answers}") + too_hard = True + iter_count += 1 + if iter_count == _args.max_iterations: + logger.info("too hard" + "!" * 20) + logger.info(f"correct_answers: {correct_answers}") + logger.info(f"incorrect_answers: {incorrect_answers}") + too_hard = True + if not too_easy and not too_hard: + prefer_pair = _collect(_root) + logger.info(f"prefer_pair: {prefer_pair}") + return prefer_pair + + def do_sample(self, data): + if not isinstance(data, list): + data = [data] + prefer_pairs = [] + for item in data: + messages = item['messages'][0] + query = messages[0]['content'] + ground_truth = messages[1]['content'] + prefer_pair = self.search_single(query, ground_truth) + prefer_pairs.append(prefer_pair) + return prefer_pairs \ No newline at end of file diff --git a/swift/experimental/sampling/sampling.py b/swift/experimental/sampling/sampling.py index 312396807b..22629cdc54 100644 --- a/swift/experimental/sampling/sampling.py +++ b/swift/experimental/sampling/sampling.py @@ -28,6 +28,9 @@ def __init__(self, args: Union[List[str], SamplingArguments, None] = None) -> No if self.args.sampler_type == 'sample': from swift.experimental.sampling.vanilla_sampler import VanillaSampler self.sampler = VanillaSampler(self.args) + elif self.args.sampler_type == 'mcts': + from swift.experimental.sampling.mcts import MctsSampler + self.sampler = MctsSampler(self.args) def _get_dataset(self): args = self.args @@ -67,3 +70,6 @@ def run(self): def sampling_main(args: Union[List[str], SamplingArguments, None] = None): return SwiftSampling(args).main() + +if __name__ == "__main__": + sampling_main() \ No newline at end of file diff --git a/swift/experimental/sampling/sampling_args.py b/swift/experimental/sampling/sampling_args.py index 839f1ac99e..55e8e25ad7 100644 --- a/swift/experimental/sampling/sampling_args.py +++ b/swift/experimental/sampling/sampling_args.py @@ -20,8 +20,8 @@ class SamplingArguments(BaseArguments): # sampler settings # sample/mcts/dvts/xxx - sampler_type: str = 'sample' - sampler_engine: Literal['pt', 'lmdeploy', 'vllm', 'no'] = 'pt' + sampler_type: Literal['sample', 'mcts'] = 'sample' + sampler_engine: Literal['pt', 'lmdeploy', 'vllm', 'no', 'client'] = 'pt' output_dir: str = 'sample_output' output_file: Optional[str] = None override_exist_file: bool = False @@ -42,6 +42,18 @@ class SamplingArguments(BaseArguments): # Vanilla cache_files: List[str] = dataclasses.field(default_factory=list) + # MCTS + max_rollout_iterations: int = 5 + max_iterations: int = 100 + process_reward_rate: float = 0.0 + exploration_rate: float = 0.5 + + def _init_model_info(self): + if self.sampler_engine != 'client': + return super._init_model_info(self) + self.task_type = 'causal_lm' + return + def __post_init__(self): if self.output_file is None: now = datetime.now() diff --git a/swift/plugin/prm.py b/swift/plugin/prm.py index a2fd255388..d95af1f719 100644 --- a/swift/plugin/prm.py +++ b/swift/plugin/prm.py @@ -5,7 +5,7 @@ import torch from swift.llm import InferRequest -from swift.llm.infer.protocol import ChatCompletionResponse +from swift.llm.infer.protocol import ChatCompletionResponse, ChatCompletionResponseChoice, ChatMessage class PRM: @@ -91,7 +91,15 @@ def infer(self, infer_requests: List[InferRequest], ground_truths: List[str], except Exception: rewards.append(None) - return rewards + return [ + ChatCompletionResponse( + choices=[ + ChatCompletionResponseChoice( + message=ChatMessage(content=1.0 if r else 0.0, role='assistant'), index=0, finish_reason='') + ], + model=None, + usage=None) for r in rewards + ] prms = {'qwen_max': QwenMaxPRM} From 8871d165ff8b339d9bc0520adce6afbfa4a3cb2f Mon Sep 17 00:00:00 2001 From: LiuXL Date: Fri, 17 Jan 2025 18:00:27 +0800 Subject: [PATCH 02/52] step continue & faster prm --- swift/experimental/sampling/mcts.py | 19 +++++-- swift/plugin/prm.py | 79 ++++++++++++++++++++++++++--- 2 files changed, 87 insertions(+), 11 deletions(-) diff --git a/swift/experimental/sampling/mcts.py b/swift/experimental/sampling/mcts.py index e8bd14e066..27a4315511 100644 --- a/swift/experimental/sampling/mcts.py +++ b/swift/experimental/sampling/mcts.py @@ -1,5 +1,6 @@ from copy import deepcopy import numpy as np +import time from swift.llm import InferRequest from swift.llm.infer.protocol import UsageInfo @@ -33,13 +34,19 @@ Now answer the question: """ +NXT_PROMPT = """Please continue. +""" -SEP_TOKEN = "ки" +SEP_TOKEN = "ки\n" system_message = { "role": "system", "content": SYS_PROMPT, } +next_message = { + "role": "user", + "content": NXT_PROMPT, +} def check_terminate(answers: Union[str, List[str]]) -> List[bool]: if isinstance(answers, str): @@ -150,6 +157,7 @@ def _select(node: LanguageNode): return node def _expand(node: LanguageNode): + # s_time = time.time() prompt_message = { "role": "user", "content": query, @@ -161,12 +169,11 @@ def _expand(node: LanguageNode): "role": "assistant", "content": node.answer, } - infer_request = InferRequest([system_message, prompt_message, history_message]) + infer_request = InferRequest([system_message, prompt_message, history_message, next_message]) expand_request_config = deepcopy(request_config) n = _args.num_return_sequences - len(node.children) while n > 0: expand_request_config.n = n if n <= 4 else 4 - n -= expand_request_config.n expand_request_config.num_return_sequences = expand_request_config.n expand_request_config.num_beams = expand_request_config.n expand_request_config.seed += 1 @@ -175,6 +182,7 @@ def _expand(node: LanguageNode): expand_request_config, **self.infer_kwargs, ) + n -= len(responses[0].choices) for key, value in self.usage_info.__dict__.items(): update_value = getattr(responses[0].usage, key, None) + value setattr(self.usage_info, key, update_value) @@ -194,6 +202,8 @@ def _expand(node: LanguageNode): else: terminate_incorrect.append(child.answer) node.add_child(child) + # logger.info(f"expand time: {time.time() - s_time}") + # s_time = time.time() if self.prm_model: prm_infer_requests = [] for child in node.children: @@ -206,6 +216,7 @@ def _expand(node: LanguageNode): threshold=0.0) for child, score in zip(node.children, prm_score): child.process_reward = score + # logger.info(f"prm time: {time.time() - s_time}") def _rollout(node: LanguageNode): rollout_iter_index = 0 @@ -225,7 +236,7 @@ def _rollout(node: LanguageNode): } history_messages.append(history_message) while len(rollout_nodes) > 0 and rollout_iter_index < _args.max_rollout_iterations: - infer_requests = [InferRequest([system_message, prompt_message, h]) for h in history_messages] + infer_requests = [InferRequest([system_message, prompt_message, h, next_message]) for h in history_messages] # Because template will pop out last assistant message, so add an additional one. responses = self.infer_engine.infer(infer_requests, rollout_request_config, **self.infer_kwargs) rollout_iter_index += 1 diff --git a/swift/plugin/prm.py b/swift/plugin/prm.py index d95af1f719..c16ebb280a 100644 --- a/swift/plugin/prm.py +++ b/swift/plugin/prm.py @@ -52,13 +52,15 @@ class QwenMaxPRM(PRM): def infer(self, infer_requests: List[InferRequest], ground_truths: List[str], **kwargs) -> List[ChatCompletionResponse]: rewards = [] - for request, ground_truth in zip(infer_requests, ground_truths): - from openai import OpenAI - client = OpenAI( - api_key=os.getenv('DASHSCOPE_API_KEY'), - base_url='https://dashscope.aliyuncs.com/compatible-mode/v1', - ) + from openai import OpenAI + + client = OpenAI( + api_key=os.getenv('DASHSCOPE_API_KEY'), + base_url='https://dashscope.aliyuncs.com/compatible-mode/v1', + ) + + for request, ground_truth in zip(infer_requests, ground_truths): previous = request.messages[:-1] if previous[0]['role'] == 'system': previous = previous[1:] @@ -102,4 +104,67 @@ def infer(self, infer_requests: List[InferRequest], ground_truths: List[str], ] -prms = {'qwen_max': QwenMaxPRM} +class QwenPlusPRM(PRM): + def __init__(self): + super().__init__() + import os + from swift.llm import InferClient + api_key = os.getenv('DASHSCOPE_API_KEY') + base_url = 'https://dashscope.aliyuncs.com/compatible-mode/v1' + self.infer_engine = InferClient(base_url=base_url, api_key=api_key) + self.infer_kwargs = { + 'model': 'qwen-plus', + } + + def infer(self, infer_requests: List[InferRequest], ground_truths: List[str], + **kwargs) -> List[ChatCompletionResponse]: + prm_infer_requests = [] + for request, ground_truth in zip(infer_requests, ground_truths): + previous = request.messages[:-1] + if previous[0]['role'] == 'system': + previous = previous[1:] + + assert request.messages[-1]['role'] == 'assistant' + query = QUERY.replace('#query#', json.dumps(previous)) + query = query.replace('#ground_truth#', ground_truth) + query = query.replace('#response#', request.messages[-1]['content']) + messages = [ + { + 'role': 'system', + 'content': SYSTEM + }, + { + 'role': 'user', + 'content': query + }, + ] + + prm_infer_requests.append(InferRequest(messages=messages)) + + responses = self.infer_engine.infer(prm_infer_requests, **self.infer_kwargs) + rewards = [] + for response in responses: + content = response.choices[0].message.content + if 'Reward:' not in content: + rewards.append(None) + try: + reward = float(content.split('Reward:')[1].strip().replace('*', '')) + rewards.append(reward) + except Exception: + rewards.append(None) + + return [ + ChatCompletionResponse( + choices=[ + ChatCompletionResponseChoice( + message=ChatMessage(content=1.0 if r else 0.0, role='assistant'), index=0, finish_reason='') + ], + model=None, + usage=None) for r in rewards + ] + + +prms = { + 'qwen_max': QwenMaxPRM, + 'qwen_plus': QwenPlusPRM, +} From acb62b455bccd6a386f9366fb997e9b38e9aff05 Mon Sep 17 00:00:00 2001 From: LiuXL Date: Fri, 17 Jan 2025 18:26:32 +0800 Subject: [PATCH 03/52] fix --- swift/experimental/sampling/mcts.py | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/swift/experimental/sampling/mcts.py b/swift/experimental/sampling/mcts.py index 27a4315511..5941819b38 100644 --- a/swift/experimental/sampling/mcts.py +++ b/swift/experimental/sampling/mcts.py @@ -1,5 +1,6 @@ from copy import deepcopy import numpy as np +import json import time from swift.llm import InferRequest @@ -252,12 +253,11 @@ def _rollout(node: LanguageNode): end_path = history_messages[rollout_node_index]["content"] if check_terminate(end_path)[0]: orm_infer_requests = [InferRequest([history_messages[rollout_node_index]])] - orm_response, _orm_mask = get_reward( - self.orm_model, orm_infer_requests, ground_truths=[ground_truth] * len(orm_infer_requests), + orm_score, _orm_mask = get_reward( + self.orm_model, orm_infer_requests, ground_truths=[ground_truth] * len(infer_requests), threshold=0.0) - orm_score = float(orm_response[0].choices[0].message.content) - node.active_children[index].outcome_reward = orm_score - if orm_score == 1: + node.active_children[index].outcome_reward = orm_score[0] + if orm_score[0] == 1: correct_answers.append(end_path) else: incorrect_answers.append(end_path) @@ -281,14 +281,14 @@ def _collect(curr_node: LanguageNode): results += _collect(child) curr_node.children = sorted(curr_node.children) if curr_node.children[-1].outcome_reward - curr_node.children[0].outcome_reward > 0.6: - results.append({ + results.append(json.dumps({ "query": query, "path": curr_node.path, "good": curr_node.children[-1].path[-1], "good_score": curr_node.children[-1].outcome_reward, "bad": curr_node.children[0].path[-1], "bad_score": curr_node.children[0].outcome_reward, - }) + }, ensure_ascii=False) + '\n') return results _args = self.args @@ -339,11 +339,11 @@ def _collect(curr_node: LanguageNode): def do_sample(self, data): if not isinstance(data, list): data = [data] - prefer_pairs = [] + generated = [] for item in data: messages = item['messages'][0] query = messages[0]['content'] ground_truth = messages[1]['content'] - prefer_pair = self.search_single(query, ground_truth) - prefer_pairs.append(prefer_pair) - return prefer_pairs \ No newline at end of file + prefer_pairs = self.search_single(query, ground_truth) + generated += prefer_pairs + return generated \ No newline at end of file From dee44d3f6fd900348baef30f95c375d2ce763a27 Mon Sep 17 00:00:00 2001 From: LiuXL Date: Sat, 18 Jan 2025 00:04:17 +0800 Subject: [PATCH 04/52] more parallel --- swift/experimental/sampling/mcts.py | 198 ++++++++++++++++----------- swift/experimental/sampling/utils.py | 7 + 2 files changed, 122 insertions(+), 83 deletions(-) diff --git a/swift/experimental/sampling/mcts.py b/swift/experimental/sampling/mcts.py index 5941819b38..b09f82cf7d 100644 --- a/swift/experimental/sampling/mcts.py +++ b/swift/experimental/sampling/mcts.py @@ -2,13 +2,14 @@ import numpy as np import json import time +from concurrent.futures import ThreadPoolExecutor, as_completed from swift.llm import InferRequest from swift.llm.infer.protocol import UsageInfo from swift.utils import get_logger from .base import Sampler -from .utils import get_reward +from .utils import get_reward, perform_infer from .sampling_args import SamplingArguments from typing import Union, List @@ -114,7 +115,7 @@ def _prepare_model_tokenizer(self): from swift.llm import InferClient api_key = os.getenv('DASHSCOPE_API_KEY') base_url = 'https://dashscope.aliyuncs.com/compatible-mode/v1' - self.infer_engine = InferClient(base_url=base_url, api_key=api_key) + self.infer_engines = [InferClient(base_url=base_url, api_key=api_key) for _ in range(args.num_return_sequences)] self.infer_kwargs['model'] = args.model else: _Engine = self.get_infer_engine() @@ -137,7 +138,30 @@ def get_infer_engine(self): return _Engine def _prepare_template(self) -> None: - pass + # Hack from super() + self._prepare_request_configs() + + def _prepare_request_configs(self): + _args = self.args + request_config = _args.get_request_config() + request_config.stop = [SEP_TOKEN] + request_config.seed = _args.seed + self.expand_request_configs = [] + for i in range(_args.num_return_sequences): + expand_request_config = deepcopy(request_config) + expand_request_config.n = 1 + expand_request_config.num_beams = expand_request_config.n + expand_request_config.seed += i + self.expand_request_configs.append(expand_request_config) + self.rollout_request_config = deepcopy(request_config) + self.rollout_request_config.max_tokens = 500 + self.rollout_request_config.temperature = 0.0 + self.rollout_request_config.n = 1 + + def update_usage_info(self, response): + for key, value in self.usage_info.__dict__.items(): + update_value = getattr(response.usage, key, None) + value + setattr(self.usage_info, key, update_value) def search_single(self, query, ground_truth): def _UCT(node: LanguageNode): @@ -159,10 +183,6 @@ def _select(node: LanguageNode): def _expand(node: LanguageNode): # s_time = time.time() - prompt_message = { - "role": "user", - "content": query, - } if node.is_root(): infer_request = InferRequest([system_message, prompt_message]) else: @@ -171,41 +191,55 @@ def _expand(node: LanguageNode): "content": node.answer, } infer_request = InferRequest([system_message, prompt_message, history_message, next_message]) - expand_request_config = deepcopy(request_config) + + # 为了并行进行 Expand 操作,这里暂时不需要考虑顺序,因为 Prompt 是一样的 n = _args.num_return_sequences - len(node.children) - while n > 0: - expand_request_config.n = n if n <= 4 else 4 - expand_request_config.num_return_sequences = expand_request_config.n - expand_request_config.num_beams = expand_request_config.n - expand_request_config.seed += 1 - responses = self.infer_engine.infer( - [infer_request], - expand_request_config, - **self.infer_kwargs, - ) - n -= len(responses[0].choices) - for key, value in self.usage_info.__dict__.items(): - update_value = getattr(responses[0].usage, key, None) + value - setattr(self.usage_info, key, update_value) - for choice in responses[0].choices: - output = choice.message.content.rstrip(SEP_TOKEN + '\n') - output = output.split(SEP_TOKEN)[0] - child = LanguageNode(step=output, parent=node) - if check_terminate(child.answer)[0]: - child._terminated = True - orm_infer_requests = [InferRequest([{"role": "assistant", "content": output}])] - orm_score, _orm_mask = get_reward( - self.orm_model, orm_infer_requests, ground_truths=[ground_truth] * len(orm_infer_requests), - threshold=0.0) - child.init_and_update_value(orm_score[0]) - if child.outcome_reward == 1: - terminate_correct.append(child.answer) - else: - terminate_incorrect.append(child.answer) - node.add_child(child) + with ThreadPoolExecutor(max_workers=n) as executor: + futures = {executor.submit(perform_infer, + self.infer_engines[i], + infer_request, + self.expand_request_configs[i], + **self.infer_kwargs): i for i in range(n)} + responses = [] + for future in as_completed(futures): + task_id = futures[future] + try: + responses.append(future.result()) + except Exception as e: + print(f"任务 {task_id} 执行请求时发生错误: {e}") + + # 为了并行获取 Outcome Reward,这里获得的 OR 是顺序返回的,所以可以直接对应 + orm_infer_requests = [] + all_child_terminated = True + for response in responses: + self.update_usage_info(response) + output = response[0].choice[0].message.content.rstrip(SEP_TOKEN + '\n').split(SEP_TOKEN)[0] + orm_infer_requests.append(InferRequest([{"role": "assistant", "content": output}])) + child = LanguageNode(step=output, parent=node) + if check_terminate(child.answer)[0]: + child._terminated = True + else: + all_child_terminated = False + node.add_child(child) + if all_child_terminated: + node._terminated = True + if not node.is_root(): + node.parent.active_children.remove(node) + + orm_score, _orm_mask = get_reward( + self.orm_model, orm_infer_requests, ground_truths=[ground_truth] * len(orm_infer_requests), + threshold=0.0) + for child in node.children: + child.init_and_update_value(orm_score[0]) + if child.outcome_reward == 1: + terminate_correct.append(child.answer) + else: + terminate_incorrect.append(child.answer) # logger.info(f"expand time: {time.time() - s_time}") # s_time = time.time() + if self.prm_model: + # 为了并行获取 PRM 的 Reward prm_infer_requests = [] for child in node.children: prm_message = {"role": "assistant", "content": child.answer} @@ -221,50 +255,47 @@ def _expand(node: LanguageNode): def _rollout(node: LanguageNode): rollout_iter_index = 0 - prompt_message = { - "role": "user", - "content": query, - } - rollout_request_config = deepcopy(request_config) - rollout_request_config.temperature = 0.0 - rollout_request_config.max_tokens = 500 - rollout_nodes = node.active_children[:] - history_messages = [] - for child in rollout_nodes: - history_message = { - "role": "assistant", - "content": child.answer, + rollout_nodes = {} + for i in range(len(node.active_children)): + rollout_nodes[i] = { + "node": node.active_children[i], + "history_messages": { + "role": "assistant", + "content": node.active_children[i].answer, + }, } - history_messages.append(history_message) - while len(rollout_nodes) > 0 and rollout_iter_index < _args.max_rollout_iterations: - infer_requests = [InferRequest([system_message, prompt_message, h, next_message]) for h in history_messages] - # Because template will pop out last assistant message, so add an additional one. - responses = self.infer_engine.infer(infer_requests, rollout_request_config, **self.infer_kwargs) - rollout_iter_index += 1 - rollout_node_index = 0 - for index, response in enumerate(responses): - for key, value in self.usage_info.__dict__.items(): - update_value = getattr(response.usage, key, None) + value - setattr(self.usage_info, key, update_value) - output = response.choices[0].message.content.rstrip(SEP_TOKEN + '\n') - output = output.split(SEP_TOKEN)[0] - output += SEP_TOKEN + '\n' - history_messages[rollout_node_index]["content"] += output - end_path = history_messages[rollout_node_index]["content"] - if check_terminate(end_path)[0]: - orm_infer_requests = [InferRequest([history_messages[rollout_node_index]])] - orm_score, _orm_mask = get_reward( - self.orm_model, orm_infer_requests, ground_truths=[ground_truth] * len(infer_requests), - threshold=0.0) - node.active_children[index].outcome_reward = orm_score[0] - if orm_score[0] == 1: - correct_answers.append(end_path) + active_rollout_nodes = list(rollout_nodes.keys()) + while len(active_rollout_nodes) > 0 and rollout_iter_index < _args.max_rollout_iterations: + infer_requests = [InferRequest([system_message, + prompt_message, + rollout_nodes[index]['history_messages'], + next_message]) + for index in active_rollout_nodes] + responses = self.infer_engine.infer(infer_requests, + self.rollout_request_config, + **self.infer_kwargs) + orm_infer_requests = [] + end_paths = [] + for index, response in zip(active_rollout_nodes, responses): + self.update_usage_info(response) + output = response.choices[0].message.content.rstrip(SEP_TOKEN + '\n').split(SEP_TOKEN)[0] + SEP_TOKEN + '\n' + rollout_nodes[index]['history_messages']["content"] += output + end_paths.append(rollout_nodes[index]['history_messages']["content"]) + orm_infer_requests.append(InferRequest([rollout_nodes[index]['history_messages']])) + + orm_score, _orm_mask = get_reward( + self.orm_model, orm_infer_requests, ground_truths=[ground_truth] * len(infer_requests), + threshold=0.0) + terminated_state = check_terminate(end_paths) + for index, score, terminated in zip(active_rollout_nodes, orm_score, terminated_state): + if terminated: + node.active_children[index].outcome_reward = score + if score == 1: + correct_answers.append(rollout_nodes[index]['history_messages']["content"]) else: - incorrect_answers.append(end_path) - rollout_nodes.pop(rollout_node_index) - history_messages.pop(rollout_node_index) - rollout_node_index -= 1 - rollout_node_index += 1 + incorrect_answers.append(rollout_nodes[index]['history_messages']["content"]) + rollout_nodes.pop(index) + active_rollout_nodes = list(rollout_nodes.keys()) def _back_propagate(curr_node: LanguageNode): while curr_node: @@ -292,10 +323,11 @@ def _collect(curr_node: LanguageNode): return results _args = self.args - request_config = _args.get_request_config() - request_config.stop = [SEP_TOKEN] - request_config.seed = _args.seed _root = LanguageNode() + prompt_message = { + "role": "user", + "content": query, + } correct_answers, incorrect_answers, prefer_pair = [], [], [] terminate_correct, terminate_incorrect = [], [] diff --git a/swift/experimental/sampling/utils.py b/swift/experimental/sampling/utils.py index c056d9a287..4923f6d30d 100644 --- a/swift/experimental/sampling/utils.py +++ b/swift/experimental/sampling/utils.py @@ -67,3 +67,10 @@ def normalize(arr): return normalized return normalize(arr), _mask + +def perform_infer(infer_engine, infer_request, request_config, **infer_kwargs): + return infer_engine.infer( + [infer_request], + request_config, + **infer_kwargs, + ) \ No newline at end of file From f1fa335537c811ce3267c5c766fd0bbc741cbdef Mon Sep 17 00:00:00 2001 From: LiuXL Date: Sat, 18 Jan 2025 00:04:17 +0800 Subject: [PATCH 05/52] more parallel --- swift/experimental/sampling/mcts.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/swift/experimental/sampling/mcts.py b/swift/experimental/sampling/mcts.py index b09f82cf7d..6e69abccd3 100644 --- a/swift/experimental/sampling/mcts.py +++ b/swift/experimental/sampling/mcts.py @@ -237,9 +237,7 @@ def _expand(node: LanguageNode): terminate_incorrect.append(child.answer) # logger.info(f"expand time: {time.time() - s_time}") # s_time = time.time() - if self.prm_model: - # 为了并行获取 PRM 的 Reward prm_infer_requests = [] for child in node.children: prm_message = {"role": "assistant", "content": child.answer} @@ -296,6 +294,7 @@ def _rollout(node: LanguageNode): incorrect_answers.append(rollout_nodes[index]['history_messages']["content"]) rollout_nodes.pop(index) active_rollout_nodes = list(rollout_nodes.keys()) + rollout_iter_index += 1 def _back_propagate(curr_node: LanguageNode): while curr_node: From 68767d5435af30f6b07848086334e9ce0054b6db Mon Sep 17 00:00:00 2001 From: LiuXL Date: Sat, 18 Jan 2025 02:58:34 +0800 Subject: [PATCH 06/52] fix --- swift/experimental/sampling/mcts.py | 24 +++++++++++++----------- 1 file changed, 13 insertions(+), 11 deletions(-) diff --git a/swift/experimental/sampling/mcts.py b/swift/experimental/sampling/mcts.py index b09f82cf7d..2d86505e67 100644 --- a/swift/experimental/sampling/mcts.py +++ b/swift/experimental/sampling/mcts.py @@ -204,7 +204,7 @@ def _expand(node: LanguageNode): for future in as_completed(futures): task_id = futures[future] try: - responses.append(future.result()) + responses += future.result() except Exception as e: print(f"任务 {task_id} 执行请求时发生错误: {e}") @@ -213,28 +213,29 @@ def _expand(node: LanguageNode): all_child_terminated = True for response in responses: self.update_usage_info(response) - output = response[0].choice[0].message.content.rstrip(SEP_TOKEN + '\n').split(SEP_TOKEN)[0] + output = response.choices[0].message.content.rstrip(SEP_TOKEN + '\n').split(SEP_TOKEN)[0] orm_infer_requests.append(InferRequest([{"role": "assistant", "content": output}])) child = LanguageNode(step=output, parent=node) if check_terminate(child.answer)[0]: - child._terminated = True + child.terminated = True else: all_child_terminated = False node.add_child(child) if all_child_terminated: - node._terminated = True + node.terminated = True if not node.is_root(): node.parent.active_children.remove(node) orm_score, _orm_mask = get_reward( self.orm_model, orm_infer_requests, ground_truths=[ground_truth] * len(orm_infer_requests), threshold=0.0) - for child in node.children: - child.init_and_update_value(orm_score[0]) - if child.outcome_reward == 1: - terminate_correct.append(child.answer) - else: - terminate_incorrect.append(child.answer) + for child, score in zip(node.children, orm_score): + if child.terminated: + child.init_and_update_value(score) + if child.outcome_reward == 1: + terminate_correct.append(child.answer) + else: + terminate_incorrect.append(child.answer) # logger.info(f"expand time: {time.time() - s_time}") # s_time = time.time() @@ -271,7 +272,7 @@ def _rollout(node: LanguageNode): rollout_nodes[index]['history_messages'], next_message]) for index in active_rollout_nodes] - responses = self.infer_engine.infer(infer_requests, + responses = self.infer_engines[0].infer(infer_requests, self.rollout_request_config, **self.infer_kwargs) orm_infer_requests = [] @@ -296,6 +297,7 @@ def _rollout(node: LanguageNode): incorrect_answers.append(rollout_nodes[index]['history_messages']["content"]) rollout_nodes.pop(index) active_rollout_nodes = list(rollout_nodes.keys()) + rollout_iter_index += 1 def _back_propagate(curr_node: LanguageNode): while curr_node: From 890d2fd0e807f13d5fb4a7ce9d0eb43f3d15bffd Mon Sep 17 00:00:00 2001 From: LiuXL Date: Sat, 18 Jan 2025 16:15:57 +0800 Subject: [PATCH 07/52] catch client error --- swift/experimental/sampling/mcts.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/swift/experimental/sampling/mcts.py b/swift/experimental/sampling/mcts.py index c7d736ccc6..3ae723d50c 100644 --- a/swift/experimental/sampling/mcts.py +++ b/swift/experimental/sampling/mcts.py @@ -373,9 +373,13 @@ def do_sample(self, data): data = [data] generated = [] for item in data: - messages = item['messages'][0] - query = messages[0]['content'] - ground_truth = messages[1]['content'] - prefer_pairs = self.search_single(query, ground_truth) - generated += prefer_pairs + logger.info(f"time: {time.time()}") + try: + messages = item['messages'][0] + query = messages[0]['content'] + ground_truth = messages[1]['content'] + prefer_pairs = self.search_single(query, ground_truth) + generated += prefer_pairs + except Exception as e: + logger.error(f"Error: {e}") return generated \ No newline at end of file From eb923cc57f8f0c6320da907df17ee9a0b3e4f576 Mon Sep 17 00:00:00 2001 From: LiuXL Date: Sat, 18 Jan 2025 16:21:39 +0800 Subject: [PATCH 08/52] ctime --- swift/experimental/sampling/mcts.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/swift/experimental/sampling/mcts.py b/swift/experimental/sampling/mcts.py index 3ae723d50c..5849cee26b 100644 --- a/swift/experimental/sampling/mcts.py +++ b/swift/experimental/sampling/mcts.py @@ -373,7 +373,7 @@ def do_sample(self, data): data = [data] generated = [] for item in data: - logger.info(f"time: {time.time()}") + logger.info(f"time: {time.ctime(time.time())}") try: messages = item['messages'][0] query = messages[0]['content'] From 170721bdaf93dcd940ad30859937271f8bf899f5 Mon Sep 17 00:00:00 2001 From: LiuXL Date: Tue, 21 Jan 2025 14:19:48 +0800 Subject: [PATCH 09/52] sample prompt --- swift/experimental/sampling/mcts.py | 30 ++++------------------------- 1 file changed, 4 insertions(+), 26 deletions(-) diff --git a/swift/experimental/sampling/mcts.py b/swift/experimental/sampling/mcts.py index 5849cee26b..993a148e49 100644 --- a/swift/experimental/sampling/mcts.py +++ b/swift/experimental/sampling/mcts.py @@ -19,43 +19,22 @@ SYS_PROMPT = """You are a super intelligent AI, you can solve any math problem step by step. -REMEMBER: Each step should stop with a 'ки'. Final answer should start with '# Answer'. - -Here is an example: - -user -Janet pays $40/hour for 3 hours per week of clarinet lessons and $28/hour for 5 hours a week of piano lessons. How much more does she spend on piano lessons than clarinet lessons in a year? - -assistant -Step 1: Janet spends 3 hours + 5 hours = <<3+5=8>>8 hours per week on music lessons. ки -Step 2: She spends 40 * 3 = <<40*3=120>>120 on clarinet lessons per week. ки -Step 3: She spends 28 * 5 = <<28*5=140>>140 on piano lessons per week. ки -Step 4: Janet spends 120 + 140 = <<120+140=260>>260 on music lessons per week. ки -Step 5: She spends 260 * 52 = <<260*52=13520>>13520 on music lessons in a year. ки -# Answer 13520 ки - Now answer the question: """ -NXT_PROMPT = """Please continue. -""" -SEP_TOKEN = "ки\n" +SEP_TOKEN = "\n\n" system_message = { "role": "system", "content": SYS_PROMPT, } -next_message = { - "role": "user", - "content": NXT_PROMPT, -} def check_terminate(answers: Union[str, List[str]]) -> List[bool]: if isinstance(answers, str): answers = [answers] results = [] for answer in answers: - results.append("# Answer" in answer) + results.append("\\boxed" in answer) return results class LanguageNode: @@ -190,7 +169,7 @@ def _expand(node: LanguageNode): "role": "assistant", "content": node.answer, } - infer_request = InferRequest([system_message, prompt_message, history_message, next_message]) + infer_request = InferRequest([system_message, prompt_message, history_message]) # 为了并行进行 Expand 操作,这里暂时不需要考虑顺序,因为 Prompt 是一样的 n = _args.num_return_sequences - len(node.children) @@ -267,8 +246,7 @@ def _rollout(node: LanguageNode): while len(active_rollout_nodes) > 0 and rollout_iter_index < _args.max_rollout_iterations: infer_requests = [InferRequest([system_message, prompt_message, - rollout_nodes[index]['history_messages'], - next_message]) + rollout_nodes[index]['history_messages']]) for index in active_rollout_nodes] responses = self.infer_engines[0].infer(infer_requests, self.rollout_request_config, From 88e6b402e484bb9348a795cb980ff395abeeb6c2 Mon Sep 17 00:00:00 2001 From: LiuXL Date: Tue, 21 Jan 2025 14:49:58 +0800 Subject: [PATCH 10/52] unique log & expand pruning --- swift/experimental/sampling/mcts.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/swift/experimental/sampling/mcts.py b/swift/experimental/sampling/mcts.py index 993a148e49..ae71afd085 100644 --- a/swift/experimental/sampling/mcts.py +++ b/swift/experimental/sampling/mcts.py @@ -14,7 +14,9 @@ from typing import Union, List -logger = get_logger('./output/sampler/mcts.log') + +log_filename = f"./output/sampler/mcts/mcts_{time.strftime('%Y%m%d_%H%M%S')}.log" +logger = get_logger(log_filename) SYS_PROMPT = """You are a super intelligent AI, you can solve any math problem step by step. @@ -151,7 +153,7 @@ def _UCT(node: LanguageNode): exploitation_score = value exploration_score = (_args.exploration_rate - * np.sqrt(np.log(node.parent.visit_count) / (node.visit_count + 1))) + * np.sqrt(np.log(node.parent.visit_count + 1) / (node.visit_count + 1))) return exploration_score + exploitation_score @@ -189,10 +191,14 @@ def _expand(node: LanguageNode): # 为了并行获取 Outcome Reward,这里获得的 OR 是顺序返回的,所以可以直接对应 orm_infer_requests = [] + unique_output = set() # 用于去重 all_child_terminated = True for response in responses: self.update_usage_info(response) output = response.choices[0].message.content.rstrip(SEP_TOKEN + '\n').split(SEP_TOKEN)[0] + if output in unique_output: + continue + unique_output.add(output) orm_infer_requests.append(InferRequest([{"role": "assistant", "content": output}])) child = LanguageNode(step=output, parent=node) if check_terminate(child.answer)[0]: From 288c11a8a31541790144900f96309e6b385ce2e2 Mon Sep 17 00:00:00 2001 From: LiuXL Date: Tue, 21 Jan 2025 15:53:53 +0800 Subject: [PATCH 11/52] logger time --- swift/experimental/sampling/mcts.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/swift/experimental/sampling/mcts.py b/swift/experimental/sampling/mcts.py index ae71afd085..c2ae15d89f 100644 --- a/swift/experimental/sampling/mcts.py +++ b/swift/experimental/sampling/mcts.py @@ -321,15 +321,19 @@ def _collect(curr_node: LanguageNode): and len(terminate_incorrect) + len(terminate_correct) < _args.num_return_sequences and iter_count < _args.max_iterations): logger.info(f"iter_count: {iter_count}" + "." * 10) - logger.info("select" + "=" * 10) + s_time = time.time() curr_node = _select(_root) - logger.info("expand" + "=" * 10) + logger.info("select" + "=" * 10+ f"time: {time.time() - s_time}") + s_time = time.time() _expand(curr_node) + logger.info("expand" + "=" * 10 + f"time: {time.time() - s_time}") if curr_node.depth > 3: - logger.info("rollout" + "=" * 10) + s_time = time.time() _rollout(curr_node) - logger.info("back propagate" + "=" * 10) + logger.info("rollout" + "=" * 10 + f"time: {time.time() - s_time}") + s_time = time.time() _back_propagate(curr_node) + logger.info("back propagate" + "=" * 10 + f"time: {time.time() - s_time}") if len(correct_answers) + len(incorrect_answers) >= _args.num_return_sequences: if 4 * len(incorrect_answers) < len(correct_answers): logger.info("too easy" + "!" * 20) From 068ba81a101093b5d094df933b97e3efe5255df9 Mon Sep 17 00:00:00 2001 From: LiuXL Date: Tue, 21 Jan 2025 16:57:12 +0800 Subject: [PATCH 12/52] rollout with multi-engines --- swift/experimental/sampling/mcts.py | 40 +++++++++++++++++++++++------ swift/plugin/orm.py | 7 ++--- 2 files changed, 36 insertions(+), 11 deletions(-) diff --git a/swift/experimental/sampling/mcts.py b/swift/experimental/sampling/mcts.py index c2ae15d89f..966108be78 100644 --- a/swift/experimental/sampling/mcts.py +++ b/swift/experimental/sampling/mcts.py @@ -163,7 +163,6 @@ def _select(node: LanguageNode): return node def _expand(node: LanguageNode): - # s_time = time.time() if node.is_root(): infer_request = InferRequest([system_message, prompt_message]) else: @@ -173,6 +172,7 @@ def _expand(node: LanguageNode): } infer_request = InferRequest([system_message, prompt_message, history_message]) + # e_time = time.time() # 为了并行进行 Expand 操作,这里暂时不需要考虑顺序,因为 Prompt 是一样的 n = _args.num_return_sequences - len(node.children) with ThreadPoolExecutor(max_workers=n) as executor: @@ -188,6 +188,7 @@ def _expand(node: LanguageNode): responses += future.result() except Exception as e: print(f"任务 {task_id} 执行请求时发生错误: {e}") + # logger.info(f"expand.expand time: {time.time() - e_time}") # 为了并行获取 Outcome Reward,这里获得的 OR 是顺序返回的,所以可以直接对应 orm_infer_requests = [] @@ -211,9 +212,11 @@ def _expand(node: LanguageNode): if not node.is_root(): node.parent.active_children.remove(node) + # e_time = time.time() orm_score, _orm_mask = get_reward( self.orm_model, orm_infer_requests, ground_truths=[ground_truth] * len(orm_infer_requests), threshold=0.0) + # logger.info(f"expand.orm time: {time.time() - e_time}") for child, score in zip(node.children, orm_score): if child.terminated: child.init_and_update_value(score) @@ -221,8 +224,8 @@ def _expand(node: LanguageNode): terminate_correct.append(child.answer) else: terminate_incorrect.append(child.answer) - # logger.info(f"expand time: {time.time() - s_time}") - # s_time = time.time() + + # e_time = time.time() if self.prm_model: prm_infer_requests = [] for child in node.children: @@ -235,7 +238,7 @@ def _expand(node: LanguageNode): threshold=0.0) for child, score in zip(node.children, prm_score): child.process_reward = score - # logger.info(f"prm time: {time.time() - s_time}") + # logger.info(f"expand.prm time: {time.time() - e_time}") def _rollout(node: LanguageNode): rollout_iter_index = 0 @@ -250,25 +253,46 @@ def _rollout(node: LanguageNode): } active_rollout_nodes = list(rollout_nodes.keys()) while len(active_rollout_nodes) > 0 and rollout_iter_index < _args.max_rollout_iterations: + # r_time = time.time() infer_requests = [InferRequest([system_message, prompt_message, rollout_nodes[index]['history_messages']]) for index in active_rollout_nodes] - responses = self.infer_engines[0].infer(infer_requests, - self.rollout_request_config, - **self.infer_kwargs) + # logger.info(f"rollout.prepare time: {time.time() - r_time}") + # r_time = time.time() + n = len(infer_requests) + with ThreadPoolExecutor(max_workers=n) as executor: + futures = {executor.submit(perform_infer, + self.infer_engines[i], + infer_requests[i], + self.rollout_request_config, + **self.infer_kwargs): i for i in range(n)} + responses = [] + for future in as_completed(futures): + task_id = futures[future] + try: + responses += future.result() + except Exception as e: + print(f"任务 {task_id} 执行请求时发生错误: {e}") + # logger.info(f"rollout.infer time: {time.time() - r_time}") + + # r_time = time.time() orm_infer_requests = [] end_paths = [] for index, response in zip(active_rollout_nodes, responses): self.update_usage_info(response) - output = response.choices[0].message.content.rstrip(SEP_TOKEN + '\n').split(SEP_TOKEN)[0] + SEP_TOKEN + '\n' + output = response.choices[0].message.content.rstrip(SEP_TOKEN + '\n').split(SEP_TOKEN)[ + 0] + SEP_TOKEN + '\n' rollout_nodes[index]['history_messages']["content"] += output end_paths.append(rollout_nodes[index]['history_messages']["content"]) orm_infer_requests.append(InferRequest([rollout_nodes[index]['history_messages']])) + # logger.info(f"rollout.orm_prepare time: {time.time() - r_time}") + # r_time = time.time() orm_score, _orm_mask = get_reward( self.orm_model, orm_infer_requests, ground_truths=[ground_truth] * len(infer_requests), threshold=0.0) + # logger.info(f"rollout.get_orm tiem: {time.time() - r_time}") terminated_state = check_terminate(end_paths) for index, score, terminated in zip(active_rollout_nodes, orm_score, terminated_state): if terminated: diff --git a/swift/plugin/orm.py b/swift/plugin/orm.py index 5d20c9185a..a990126e54 100644 --- a/swift/plugin/orm.py +++ b/swift/plugin/orm.py @@ -175,6 +175,9 @@ def __init__(self): super().__init__() from transformers.utils import strtobool self.use_opencompass = strtobool(os.environ.get('USE_OPENCOMPASS_EVALUATOR')) + if self.use_opencompass: + from opencompass.datasets.math import MATHEvaluator + self.evaluator = MATHEvaluator() @staticmethod def extract_boxed_result(text): @@ -228,9 +231,7 @@ def infer(self, infer_requests: List[InferRequest], ground_truths: List[str], prediction = MathORM.extract_boxed_result(prediction) ground_truth = MathORM.extract_boxed_result(ground_truth) if self.use_opencompass: - from opencompass.datasets.math import MATHEvaluator - evaluator = MATHEvaluator() - rewards.append(evaluator.is_equiv(prediction, ground_truth)) + rewards.append(self.evaluator.is_equiv(prediction, ground_truth)) else: rewards.append(MathORM.compare_consecutive(prediction, ground_truth)) From b2f5c601dd1d74aca2d58e2187ee8d41902493ef Mon Sep 17 00:00:00 2001 From: LiuXL Date: Tue, 21 Jan 2025 17:46:46 +0800 Subject: [PATCH 13/52] add next-prompt --- swift/experimental/sampling/mcts.py | 19 ++++++++++++++----- 1 file changed, 14 insertions(+), 5 deletions(-) diff --git a/swift/experimental/sampling/mcts.py b/swift/experimental/sampling/mcts.py index 966108be78..eee32db5c1 100644 --- a/swift/experimental/sampling/mcts.py +++ b/swift/experimental/sampling/mcts.py @@ -19,18 +19,26 @@ logger = get_logger(log_filename) -SYS_PROMPT = """You are a super intelligent AI, you can solve any math problem step by step. - +SYS_PROMPT = """You are a super intelligent AI, you can solve any math problem step by step. +Each step should end with 'ки'. Now answer the question: """ -SEP_TOKEN = "\n\n" +NXT_PROMPT = """Continue. +""" + +SEP_TOKEN = "ки" system_message = { "role": "system", "content": SYS_PROMPT, } +next_message = { + "role": "user", + "content": NXT_PROMPT, +} + def check_terminate(answers: Union[str, List[str]]) -> List[bool]: if isinstance(answers, str): answers = [answers] @@ -170,7 +178,7 @@ def _expand(node: LanguageNode): "role": "assistant", "content": node.answer, } - infer_request = InferRequest([system_message, prompt_message, history_message]) + infer_request = InferRequest([system_message, prompt_message, history_message, next_message]) # e_time = time.time() # 为了并行进行 Expand 操作,这里暂时不需要考虑顺序,因为 Prompt 是一样的 @@ -256,7 +264,8 @@ def _rollout(node: LanguageNode): # r_time = time.time() infer_requests = [InferRequest([system_message, prompt_message, - rollout_nodes[index]['history_messages']]) + rollout_nodes[index]['history_messages'], + next_message]) for index in active_rollout_nodes] # logger.info(f"rollout.prepare time: {time.time() - r_time}") # r_time = time.time() From 6e76f6d767f3ecc41bc812bf8f52fccffd323ba9 Mon Sep 17 00:00:00 2001 From: LiuXL Date: Wed, 22 Jan 2025 10:13:43 +0800 Subject: [PATCH 14/52] change args --- swift/experimental/sampling/mcts.py | 37 +++++++++----------- swift/experimental/sampling/sampling_args.py | 10 ++++++ 2 files changed, 26 insertions(+), 21 deletions(-) diff --git a/swift/experimental/sampling/mcts.py b/swift/experimental/sampling/mcts.py index eee32db5c1..b8b4bc5664 100644 --- a/swift/experimental/sampling/mcts.py +++ b/swift/experimental/sampling/mcts.py @@ -15,25 +15,11 @@ from typing import Union, List -log_filename = f"./output/sampler/mcts/mcts_{time.strftime('%Y%m%d_%H%M%S')}.log" -logger = get_logger(log_filename) - - -SYS_PROMPT = """You are a super intelligent AI, you can solve any math problem step by step. -Each step should end with 'ки'. -Now answer the question: -""" +logger = get_logger() NXT_PROMPT = """Continue. """ -SEP_TOKEN = "ки" - -system_message = { - "role": "system", - "content": SYS_PROMPT, -} - next_message = { "role": "user", "content": NXT_PROMPT, @@ -51,11 +37,18 @@ class LanguageNode: def __init__(self, step: str = None, + sep_token: str = None, parent: "LanguageNode" = None,): self.parent = parent + + if sep_token: + self.sep_token = sep_token + else: + self.sep_token = parent.sep_token + if parent: self.path = parent.path[:] + [step] - self.answer = parent.answer + step + SEP_TOKEN + self.answer = parent.answer + step + self.sep_token self.depth = parent.depth + 1 else: self.path = [] @@ -133,7 +126,7 @@ def _prepare_template(self) -> None: def _prepare_request_configs(self): _args = self.args request_config = _args.get_request_config() - request_config.stop = [SEP_TOKEN] + request_config.stop = _args.stop_words request_config.seed = _args.seed self.expand_request_configs = [] for i in range(_args.num_return_sequences): @@ -204,7 +197,7 @@ def _expand(node: LanguageNode): all_child_terminated = True for response in responses: self.update_usage_info(response) - output = response.choices[0].message.content.rstrip(SEP_TOKEN + '\n').split(SEP_TOKEN)[0] + output = response.choices[0].message.content.rstrip(sep_token + '\n').split(sep_token)[0] if output in unique_output: continue unique_output.add(output) @@ -290,8 +283,8 @@ def _rollout(node: LanguageNode): end_paths = [] for index, response in zip(active_rollout_nodes, responses): self.update_usage_info(response) - output = response.choices[0].message.content.rstrip(SEP_TOKEN + '\n').split(SEP_TOKEN)[ - 0] + SEP_TOKEN + '\n' + output = response.choices[0].message.content.rstrip(sep_token + '\n').split(sep_token)[ + 0] + sep_token + '\n' rollout_nodes[index]['history_messages']["content"] += output end_paths.append(rollout_nodes[index]['history_messages']["content"]) orm_infer_requests.append(InferRequest([rollout_nodes[index]['history_messages']])) @@ -340,7 +333,9 @@ def _collect(curr_node: LanguageNode): return results _args = self.args - _root = LanguageNode() + system_message = _args.system_message + sep_token = _args.stop_words[0] + _root = LanguageNode(sep_token=sep_token) prompt_message = { "role": "user", "content": query, diff --git a/swift/experimental/sampling/sampling_args.py b/swift/experimental/sampling/sampling_args.py index 55e8e25ad7..654a61d489 100644 --- a/swift/experimental/sampling/sampling_args.py +++ b/swift/experimental/sampling/sampling_args.py @@ -1,4 +1,5 @@ # Copyright (c) Alibaba, Inc. and its affiliates. +import os import dataclasses from dataclasses import dataclass from datetime import datetime @@ -70,4 +71,13 @@ def __post_init__(self): self.engine_kwargs = json.loads(self.engine_kwargs) else: self.engine_kwargs = {} + + if os.path.isfile(self.system): + with open(self.system, 'r') as f: + self.system = f.read() + self.system_message = { + "role": "system", + "content": self.system, + } + super().__post_init__() From 4331c1754d679b091078b359ded7fc827ba081c2 Mon Sep 17 00:00:00 2001 From: LiuXL Date: Wed, 22 Jan 2025 11:35:58 +0800 Subject: [PATCH 15/52] prefer add ground_truth --- swift/experimental/sampling/mcts.py | 1 + 1 file changed, 1 insertion(+) diff --git a/swift/experimental/sampling/mcts.py b/swift/experimental/sampling/mcts.py index b8b4bc5664..e5ceb1ca17 100644 --- a/swift/experimental/sampling/mcts.py +++ b/swift/experimental/sampling/mcts.py @@ -324,6 +324,7 @@ def _collect(curr_node: LanguageNode): if curr_node.children[-1].outcome_reward - curr_node.children[0].outcome_reward > 0.6: results.append(json.dumps({ "query": query, + "ground_truth": ground_truth, "path": curr_node.path, "good": curr_node.children[-1].path[-1], "good_score": curr_node.children[-1].outcome_reward, From 15aac9213ecf8d1336f6e056e15cb6aad7470404 Mon Sep 17 00:00:00 2001 From: LiuXL Date: Wed, 22 Jan 2025 14:06:30 +0800 Subject: [PATCH 16/52] collect_filter_threshold --- swift/experimental/sampling/mcts.py | 3 ++- swift/experimental/sampling/sampling_args.py | 1 + 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/swift/experimental/sampling/mcts.py b/swift/experimental/sampling/mcts.py index e5ceb1ca17..2a4056890a 100644 --- a/swift/experimental/sampling/mcts.py +++ b/swift/experimental/sampling/mcts.py @@ -321,7 +321,7 @@ def _collect(curr_node: LanguageNode): for child in curr_node.children: results += _collect(child) curr_node.children = sorted(curr_node.children) - if curr_node.children[-1].outcome_reward - curr_node.children[0].outcome_reward > 0.6: + if curr_node.children[-1].outcome_reward - curr_node.children[0].outcome_reward > collect_filter_threshold: results.append(json.dumps({ "query": query, "ground_truth": ground_truth, @@ -336,6 +336,7 @@ def _collect(curr_node: LanguageNode): _args = self.args system_message = _args.system_message sep_token = _args.stop_words[0] + collect_filter_threshold = _args.collect_filter_threshold _root = LanguageNode(sep_token=sep_token) prompt_message = { "role": "user", diff --git a/swift/experimental/sampling/sampling_args.py b/swift/experimental/sampling/sampling_args.py index 654a61d489..8f58c4c417 100644 --- a/swift/experimental/sampling/sampling_args.py +++ b/swift/experimental/sampling/sampling_args.py @@ -48,6 +48,7 @@ class SamplingArguments(BaseArguments): max_iterations: int = 100 process_reward_rate: float = 0.0 exploration_rate: float = 0.5 + collect_filter_threshold: float = 0.5 def _init_model_info(self): if self.sampler_engine != 'client': From 44546a4d249920714b94433f5fe74d147ae50bc7 Mon Sep 17 00:00:00 2001 From: LiuXL Date: Wed, 22 Jan 2025 17:15:10 +0800 Subject: [PATCH 17/52] base_url and api_key in args --- swift/experimental/sampling/mcts.py | 5 ++--- swift/experimental/sampling/sampling_args.py | 4 ++++ 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/swift/experimental/sampling/mcts.py b/swift/experimental/sampling/mcts.py index 2a4056890a..1a2a2a26d2 100644 --- a/swift/experimental/sampling/mcts.py +++ b/swift/experimental/sampling/mcts.py @@ -93,10 +93,9 @@ def _prepare_model_tokenizer(self): args = self.args self.infer_kwargs = {} if args.sampler_engine == 'client': - import os from swift.llm import InferClient - api_key = os.getenv('DASHSCOPE_API_KEY') - base_url = 'https://dashscope.aliyuncs.com/compatible-mode/v1' + api_key = args.api_key + base_url = args.base_url self.infer_engines = [InferClient(base_url=base_url, api_key=api_key) for _ in range(args.num_return_sequences)] self.infer_kwargs['model'] = args.model else: diff --git a/swift/experimental/sampling/sampling_args.py b/swift/experimental/sampling/sampling_args.py index 8f58c4c417..fdbbb7e4ef 100644 --- a/swift/experimental/sampling/sampling_args.py +++ b/swift/experimental/sampling/sampling_args.py @@ -49,6 +49,8 @@ class SamplingArguments(BaseArguments): process_reward_rate: float = 0.0 exploration_rate: float = 0.5 collect_filter_threshold: float = 0.5 + api_key: str = "EMPTY" + base_url: str = 'https://dashscope.aliyuncs.com/compatible-mode/v1' def _init_model_info(self): if self.sampler_engine != 'client': @@ -80,5 +82,7 @@ def __post_init__(self): "role": "system", "content": self.system, } + if self.sampler_type == 'mcts' and self.sampler_engine != 'client': + raise ValueError(f'`mcts` sampler only supports `client` engine yet, but now is: {self.sampler_engine}') super().__post_init__() From f7a71f60157565d36b99f618bc7c72e5252dcfb2 Mon Sep 17 00:00:00 2001 From: LiuXL Date: Wed, 22 Jan 2025 17:28:17 +0800 Subject: [PATCH 18/52] client prm --- swift/plugin/prm.py | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/swift/plugin/prm.py b/swift/plugin/prm.py index c16ebb280a..630550d082 100644 --- a/swift/plugin/prm.py +++ b/swift/plugin/prm.py @@ -104,16 +104,20 @@ def infer(self, infer_requests: List[InferRequest], ground_truths: List[str], ] -class QwenPlusPRM(PRM): - def __init__(self): +class ClientPRM(PRM): + def __init__(self, api_key = None, base_url = None, model = None): super().__init__() - import os from swift.llm import InferClient - api_key = os.getenv('DASHSCOPE_API_KEY') - base_url = 'https://dashscope.aliyuncs.com/compatible-mode/v1' + import os + if api_key is None: + api_key = os.getenv('DASHSCOPE_API_KEY') + if base_url is None: + base_url = 'https://dashscope.aliyuncs.com/compatible-mode/v1' + if model is None: + model = 'qwen-plus' self.infer_engine = InferClient(base_url=base_url, api_key=api_key) self.infer_kwargs = { - 'model': 'qwen-plus', + 'model': model, } def infer(self, infer_requests: List[InferRequest], ground_truths: List[str], @@ -166,5 +170,5 @@ def infer(self, infer_requests: List[InferRequest], ground_truths: List[str], prms = { 'qwen_max': QwenMaxPRM, - 'qwen_plus': QwenPlusPRM, + 'client': ClientPRM, } From 78fde7f407fbcb2be16451c6c6c8f1e226a7f999 Mon Sep 17 00:00:00 2001 From: LiuXL Date: Wed, 22 Jan 2025 19:56:48 +0800 Subject: [PATCH 19/52] update generated results --- swift/experimental/sampling/mcts.py | 41 ++++++++++++++++++----------- 1 file changed, 25 insertions(+), 16 deletions(-) diff --git a/swift/experimental/sampling/mcts.py b/swift/experimental/sampling/mcts.py index 1a2a2a26d2..6f974bc773 100644 --- a/swift/experimental/sampling/mcts.py +++ b/swift/experimental/sampling/mcts.py @@ -321,15 +321,13 @@ def _collect(curr_node: LanguageNode): results += _collect(child) curr_node.children = sorted(curr_node.children) if curr_node.children[-1].outcome_reward - curr_node.children[0].outcome_reward > collect_filter_threshold: - results.append(json.dumps({ - "query": query, - "ground_truth": ground_truth, + results.append({ "path": curr_node.path, "good": curr_node.children[-1].path[-1], "good_score": curr_node.children[-1].outcome_reward, "bad": curr_node.children[0].path[-1], "bad_score": curr_node.children[0].outcome_reward, - }, ensure_ascii=False) + '\n') + }) return results _args = self.args @@ -366,24 +364,36 @@ def _collect(curr_node: LanguageNode): if len(correct_answers) + len(incorrect_answers) >= _args.num_return_sequences: if 4 * len(incorrect_answers) < len(correct_answers): logger.info("too easy" + "!" * 20) - logger.info(f"correct_answers: {correct_answers}") - logger.info(f"incorrect_answers: {incorrect_answers}") + #logger.info(f"correct_answers: {correct_answers}") + #logger.info(f"incorrect_answers: {incorrect_answers}") too_easy = True elif 4 * len(correct_answers) < len(incorrect_answers): logger.info("too hard" + "!" * 20) - logger.info(f"correct_answers: {correct_answers}") - logger.info(f"incorrect_answers: {incorrect_answers}") + #logger.info(f"correct_answers: {correct_answers}") + #logger.info(f"incorrect_answers: {incorrect_answers}") too_hard = True iter_count += 1 if iter_count == _args.max_iterations: logger.info("too hard" + "!" * 20) - logger.info(f"correct_answers: {correct_answers}") - logger.info(f"incorrect_answers: {incorrect_answers}") too_hard = True - if not too_easy and not too_hard: - prefer_pair = _collect(_root) - logger.info(f"prefer_pair: {prefer_pair}") - return prefer_pair + #logger.info(f"correct_answers: {correct_answers}") + #logger.info(f"incorrect_answers: {incorrect_answers}") + prefer_pair = _collect(_root) + #logger.info(f"prefer_pair: {prefer_pair}") + + result = { + "query": query, + "ground_truth": ground_truth, + "prefer_pair": prefer_pair, + "correct_answers": correct_answers, + "incorrect_answers": incorrect_answers, + "terminate_correct": terminate_correct, + "terminate_incorrect": terminate_incorrect, + } + results = json.dumps(result, ensure_ascii=False) + logger.info(results) + + return results def do_sample(self, data): if not isinstance(data, list): @@ -395,8 +405,7 @@ def do_sample(self, data): messages = item['messages'][0] query = messages[0]['content'] ground_truth = messages[1]['content'] - prefer_pairs = self.search_single(query, ground_truth) - generated += prefer_pairs + generated.append(self.search_single(query, ground_truth)) except Exception as e: logger.error(f"Error: {e}") return generated \ No newline at end of file From d5808c856ec37cd2f1e9f67c4f8fc4c76b5eb604 Mon Sep 17 00:00:00 2001 From: LiuXL Date: Wed, 22 Jan 2025 20:20:32 +0800 Subject: [PATCH 20/52] check terminated in orm --- swift/experimental/sampling/mcts.py | 13 ++----------- swift/plugin/orm.py | 11 ++++++++++- 2 files changed, 12 insertions(+), 12 deletions(-) diff --git a/swift/experimental/sampling/mcts.py b/swift/experimental/sampling/mcts.py index 6f974bc773..f1189acf2c 100644 --- a/swift/experimental/sampling/mcts.py +++ b/swift/experimental/sampling/mcts.py @@ -12,8 +12,6 @@ from .utils import get_reward, perform_infer from .sampling_args import SamplingArguments -from typing import Union, List - logger = get_logger() @@ -25,13 +23,6 @@ "content": NXT_PROMPT, } -def check_terminate(answers: Union[str, List[str]]) -> List[bool]: - if isinstance(answers, str): - answers = [answers] - results = [] - for answer in answers: - results.append("\\boxed" in answer) - return results class LanguageNode: @@ -202,7 +193,7 @@ def _expand(node: LanguageNode): unique_output.add(output) orm_infer_requests.append(InferRequest([{"role": "assistant", "content": output}])) child = LanguageNode(step=output, parent=node) - if check_terminate(child.answer)[0]: + if self.orm_model.check_terminate(child.answer)[0]: child.terminated = True else: all_child_terminated = False @@ -294,7 +285,7 @@ def _rollout(node: LanguageNode): self.orm_model, orm_infer_requests, ground_truths=[ground_truth] * len(infer_requests), threshold=0.0) # logger.info(f"rollout.get_orm tiem: {time.time() - r_time}") - terminated_state = check_terminate(end_paths) + terminated_state = self.orm_model.check_terminate(end_paths) for index, score, terminated in zip(active_rollout_nodes, orm_score, terminated_state): if terminated: node.active_children[index].outcome_reward = score diff --git a/swift/plugin/orm.py b/swift/plugin/orm.py index a990126e54..ea0e378c24 100644 --- a/swift/plugin/orm.py +++ b/swift/plugin/orm.py @@ -1,6 +1,6 @@ import os import re -from typing import List +from typing import List, Union import json import torch @@ -179,6 +179,15 @@ def __init__(self): from opencompass.datasets.math import MATHEvaluator self.evaluator = MATHEvaluator() + @staticmethod + def check_terminate(answers: Union[str, List[str]]) -> List[bool]: + if isinstance(answers, str): + answers = [answers] + results = [] + for answer in answers: + results.append("\\boxed" in answer) + return results + @staticmethod def extract_boxed_result(text): pattern = r'\\boxed{([^}]*)}' From 1f74f6cdfd64d44084c7fe7f2f0081a9ed1a2156 Mon Sep 17 00:00:00 2001 From: LiuXL Date: Thu, 23 Jan 2025 14:45:43 +0800 Subject: [PATCH 21/52] rollout args change --- swift/experimental/sampling/mcts.py | 4 ++-- swift/experimental/sampling/sampling_args.py | 3 ++- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/swift/experimental/sampling/mcts.py b/swift/experimental/sampling/mcts.py index f1189acf2c..6e22ce80f1 100644 --- a/swift/experimental/sampling/mcts.py +++ b/swift/experimental/sampling/mcts.py @@ -243,7 +243,7 @@ def _rollout(node: LanguageNode): }, } active_rollout_nodes = list(rollout_nodes.keys()) - while len(active_rollout_nodes) > 0 and rollout_iter_index < _args.max_rollout_iterations: + while len(active_rollout_nodes) > 0 and rollout_iter_index < _args.rollout_depth: # r_time = time.time() infer_requests = [InferRequest([system_message, prompt_message, @@ -345,7 +345,7 @@ def _collect(curr_node: LanguageNode): s_time = time.time() _expand(curr_node) logger.info("expand" + "=" * 10 + f"time: {time.time() - s_time}") - if curr_node.depth > 3: + if curr_node.depth > _args.rollout_start_depth: s_time = time.time() _rollout(curr_node) logger.info("rollout" + "=" * 10 + f"time: {time.time() - s_time}") diff --git a/swift/experimental/sampling/sampling_args.py b/swift/experimental/sampling/sampling_args.py index fdbbb7e4ef..f0c6da8113 100644 --- a/swift/experimental/sampling/sampling_args.py +++ b/swift/experimental/sampling/sampling_args.py @@ -44,7 +44,8 @@ class SamplingArguments(BaseArguments): cache_files: List[str] = dataclasses.field(default_factory=list) # MCTS - max_rollout_iterations: int = 5 + rollout_depth: int = 5 + rollout_start_depth: int = 3 max_iterations: int = 100 process_reward_rate: float = 0.0 exploration_rate: float = 0.5 From aacf1b60c7de1f50655a8d7bf4ecca5163357bb5 Mon Sep 17 00:00:00 2001 From: LiuXL Date: Thu, 23 Jan 2025 16:01:52 +0800 Subject: [PATCH 22/52] stop_words \n --- swift/experimental/sampling/mcts.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/swift/experimental/sampling/mcts.py b/swift/experimental/sampling/mcts.py index 6e22ce80f1..622d3dff9f 100644 --- a/swift/experimental/sampling/mcts.py +++ b/swift/experimental/sampling/mcts.py @@ -323,7 +323,7 @@ def _collect(curr_node: LanguageNode): _args = self.args system_message = _args.system_message - sep_token = _args.stop_words[0] + sep_token = _args.stop_words[0] + '\n' collect_filter_threshold = _args.collect_filter_threshold _root = LanguageNode(sep_token=sep_token) prompt_message = { From 48caf737793333791bc92d997e82f8a3fa7c8980 Mon Sep 17 00:00:00 2001 From: LiuXL Date: Fri, 24 Jan 2025 11:33:11 +0800 Subject: [PATCH 23/52] sys_prompt from file --- ...21\275\344\273\244\350\241\214\345\217\202\346\225\260.md" | 2 +- docs/source_en/Instruction/Command-line-parameters.md | 2 +- swift/llm/argument/base_args/template_args.py | 4 ++++ 3 files changed, 6 insertions(+), 2 deletions(-) diff --git "a/docs/source/Instruction/\345\221\275\344\273\244\350\241\214\345\217\202\346\225\260.md" "b/docs/source/Instruction/\345\221\275\344\273\244\350\241\214\345\217\202\346\225\260.md" index ebb96038f8..00dd6dafe6 100644 --- "a/docs/source/Instruction/\345\221\275\344\273\244\350\241\214\345\217\202\346\225\260.md" +++ "b/docs/source/Instruction/\345\221\275\344\273\244\350\241\214\345\217\202\346\225\260.md" @@ -44,7 +44,7 @@ ### 模板参数 - 🔥template: 对话模板类型,默认使用model对应的template类型。`swift pt`会将对话模版转为生成模板使用 -- 🔥system: 自定义system字段,默认为None,使用template的默认system +- 🔥system: 自定义system字段,可以是一个txt文件地址,默认为None,使用template的默认system - 🔥max_length: 单样本的tokens最大长度。默认为None,设置为模型支持的tokens最大长度(max_model_len) - truncation_strategy: 如果超长如何处理,支持`delete`, `left`和`right`,代表删除、左侧裁剪和右侧裁剪,默认为'delete' - 🔥max_pixels: 多模态模型图片前处理的最大像素数(H\*W),默认不缩放。 diff --git a/docs/source_en/Instruction/Command-line-parameters.md b/docs/source_en/Instruction/Command-line-parameters.md index ba59c970b0..e9f28d5bc0 100644 --- a/docs/source_en/Instruction/Command-line-parameters.md +++ b/docs/source_en/Instruction/Command-line-parameters.md @@ -44,7 +44,7 @@ The introduction to command line parameters will cover base arguments, atomic ar ### Template Arguments - 🔥template: Type of dialogue template, which defaults to the template type corresponding to the model. `swift pt` will convert the dialogue template into a generation template for use. -- 🔥system: Custom system field, default is None, uses the default system of the template. +- 🔥system: Custom system field, could be a txt file path, default is None, uses the default system of the template. - 🔥max_length: The maximum length of tokens for a single sample. Defaults to None, set to the maximum length of tokens supported by the model (max_model_len). - truncation_strategy: How to handle overly long tokens, supports `delete`, `left`, `right`, representing deletion, left trimming, and right trimming, default is 'delete'. - 🔥max_pixels: Maximum pixel count for pre-processing images in multimodal models (H*W), default is no scaling. diff --git a/swift/llm/argument/base_args/template_args.py b/swift/llm/argument/base_args/template_args.py index 98c4d80a72..2dd5ff47cf 100644 --- a/swift/llm/argument/base_args/template_args.py +++ b/swift/llm/argument/base_args/template_args.py @@ -1,4 +1,5 @@ # Copyright (c) Alibaba, Inc. and its affiliates. +import os from dataclasses import dataclass, field from typing import Literal, Optional @@ -46,6 +47,9 @@ class TemplateArguments: def __post_init__(self): if self.template is None and hasattr(self, 'model_meta'): self.template = self.model_meta.template + if self.system.endswith('.txt') and os.path.isfile(self.system): + with open(self.system, 'r') as f: + self.system = f.read() def get_template_kwargs(self): truncation_strategy = self.truncation_strategy From 6327e6e9c2fdd8990259866d4865383a9513e034 Mon Sep 17 00:00:00 2001 From: LiuXL Date: Fri, 24 Jan 2025 14:29:31 +0800 Subject: [PATCH 24/52] fix --- swift/llm/argument/base_args/template_args.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/swift/llm/argument/base_args/template_args.py b/swift/llm/argument/base_args/template_args.py index 2dd5ff47cf..23a6c79278 100644 --- a/swift/llm/argument/base_args/template_args.py +++ b/swift/llm/argument/base_args/template_args.py @@ -47,7 +47,7 @@ class TemplateArguments: def __post_init__(self): if self.template is None and hasattr(self, 'model_meta'): self.template = self.model_meta.template - if self.system.endswith('.txt') and os.path.isfile(self.system): + if self.system is not None and self.system.endswith('.txt') and os.path.isfile(self.system): with open(self.system, 'r') as f: self.system = f.read() From 21c576da002e4b5a5fb554de61da8d472e980b1c Mon Sep 17 00:00:00 2001 From: LiuXL Date: Fri, 24 Jan 2025 16:16:01 +0800 Subject: [PATCH 25/52] terminate state back propagate --- swift/experimental/sampling/mcts.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/swift/experimental/sampling/mcts.py b/swift/experimental/sampling/mcts.py index 622d3dff9f..a92d0c2464 100644 --- a/swift/experimental/sampling/mcts.py +++ b/swift/experimental/sampling/mcts.py @@ -184,7 +184,6 @@ def _expand(node: LanguageNode): # 为了并行获取 Outcome Reward,这里获得的 OR 是顺序返回的,所以可以直接对应 orm_infer_requests = [] unique_output = set() # 用于去重 - all_child_terminated = True for response in responses: self.update_usage_info(response) output = response.choices[0].message.content.rstrip(sep_token + '\n').split(sep_token)[0] @@ -198,10 +197,6 @@ def _expand(node: LanguageNode): else: all_child_terminated = False node.add_child(child) - if all_child_terminated: - node.terminated = True - if not node.is_root(): - node.parent.active_children.remove(node) # e_time = time.time() orm_score, _orm_mask = get_reward( @@ -302,6 +297,8 @@ def _back_propagate(curr_node: LanguageNode): best_child_value = max([child.outcome_reward for child in curr_node.children]) curr_node.init_and_update_value(best_child_value) curr_node.visit() + if len(curr_node.active_children) == 0 and not curr_node.is_root(): + curr_node.parent.active_children.remove(curr_node) curr_node = curr_node.parent def _collect(curr_node: LanguageNode): From 7ec1d1aa3674ce7791c966c233519411ef26f3eb Mon Sep 17 00:00:00 2001 From: Leoyzen Date: Thu, 23 Jan 2025 15:55:21 +0800 Subject: [PATCH 26/52] add "enable_prefix_caching" args for vllm engine. (#2939) --- ...44\350\241\214\345\217\202\346\225\260.md" | 1 + .../Instruction/Command-line-parameters.md | 1 + swift/llm/argument/infer_args.py | 3 + swift/llm/infer/infer_engine/vllm_engine.py | 83 ++++++++++--------- 4 files changed, 51 insertions(+), 37 deletions(-) diff --git "a/docs/source/Instruction/\345\221\275\344\273\244\350\241\214\345\217\202\346\225\260.md" "b/docs/source/Instruction/\345\221\275\344\273\244\350\241\214\345\217\202\346\225\260.md" index 0c893a530b..4ba73a7ba1 100644 --- "a/docs/source/Instruction/\345\221\275\344\273\244\350\241\214\345\217\202\346\225\260.md" +++ "b/docs/source/Instruction/\345\221\275\344\273\244\350\241\214\345\217\202\346\225\260.md" @@ -277,6 +277,7 @@ Vera使用`target_modules`, `target_regex`, `modules_to_save`三个参数. - enforce_eager: vllm使用pytorch eager模式还是建立cuda graph. 默认为`False`. 设置为True可以节约显存, 但会影响效率. - 🔥limit_mm_per_prompt: 控制vllm使用多图, 默认为`None`. 例如传入`--limit_mm_per_prompt '{"image": 10, "video": 5}'` - vllm_max_lora_rank: 默认为`16`. vllm对于lora支持的参数 +- enable_prefix_caching: 是否开启 vllm 的 Prefix Caching 能力. 默认为`False`. 设置为 True 可以节约重复请求前缀(例如 System Prompt,长文档或多轮对话)处理时间。 ### 合并参数 diff --git a/docs/source_en/Instruction/Command-line-parameters.md b/docs/source_en/Instruction/Command-line-parameters.md index c58781e7ca..a80d0851e5 100644 --- a/docs/source_en/Instruction/Command-line-parameters.md +++ b/docs/source_en/Instruction/Command-line-parameters.md @@ -282,6 +282,7 @@ Parameter meanings can be found in the [vllm documentation](https://docs.vllm.ai - enforce_eager: Whether vllm uses pytorch eager mode or establishes a cuda graph. Default is `False`. Setting to True can save memory but may affect efficiency. - 🔥limit_mm_per_prompt: Controls vllm using multiple images, default is `None`. For example, use `--limit_mm_per_prompt '{"image": 10, "video": 5}'`. - vllm_max_lora_rank: Default value is `16`. Parameters supported by vllm for LoRA. +- enable_prefix_caching: Whether enable `Automatic Prefix Caching` feature for vllm. Default is `False`. Setting to True can save processing time for repeatable request prefix(such as system prompt, long docs, or multi-turn dialog, etc). ### Merge Arguments diff --git a/swift/llm/argument/infer_args.py b/swift/llm/argument/infer_args.py index 39cce3206c..d2172ddc5f 100644 --- a/swift/llm/argument/infer_args.py +++ b/swift/llm/argument/infer_args.py @@ -61,6 +61,7 @@ class VllmArguments: enforce_eager (bool): Flag to enforce eager execution. Default is False. limit_mm_per_prompt (Optional[str]): Limit multimedia per prompt. Default is None. vllm_max_lora_rank (int): Maximum LoRA rank. Default is 16. + enable_prefix_caching (bool): Flag to enable automatic prefix caching. Default is False. """ # vllm gpu_memory_utilization: float = 0.9 @@ -72,6 +73,7 @@ class VllmArguments: enforce_eager: bool = False limit_mm_per_prompt: Optional[Union[dict, str]] = None # '{"image": 10, "video": 5}' vllm_max_lora_rank: int = 16 + enable_prefix_caching: bool = False def __post_init__(self): self.limit_mm_per_prompt = ModelArguments.parse_to_dict(self.limit_mm_per_prompt) @@ -92,6 +94,7 @@ def get_vllm_engine_kwargs(self): 'max_lora_rank': self.vllm_max_lora_rank, 'enable_lora': len(adapters) > 0, 'max_loras': max(len(adapters), 1), + 'enable_prefix_caching': self.enable_prefix_caching, } diff --git a/swift/llm/infer/infer_engine/vllm_engine.py b/swift/llm/infer/infer_engine/vllm_engine.py index 03ebded147..4b02c60c4c 100644 --- a/swift/llm/infer/infer_engine/vllm_engine.py +++ b/swift/llm/infer/infer_engine/vllm_engine.py @@ -34,28 +34,30 @@ class VllmEngine(InferEngine): def __init__( - self, - model_id_or_path: str, - torch_dtype: Optional[torch.dtype] = None, - *, - model_type: Optional[str] = None, - use_hf: Optional[bool] = None, - hub_token: Optional[str] = None, - revision: Optional[str] = None, - # engine_kwargs - gpu_memory_utilization: float = 0.9, - tensor_parallel_size: int = 1, - pipeline_parallel_size: int = 1, - max_model_len: Optional[int] = None, - max_num_seqs: int = 256, - disable_custom_all_reduce: bool = False, - enforce_eager: bool = False, - limit_mm_per_prompt: Optional[Dict[str, Any]] = None, - # lora - enable_lora: bool = False, - max_loras: int = 1, - max_lora_rank: int = 16, - engine_kwargs: Optional[Dict[str, Any]] = None) -> None: + self, + model_id_or_path: str, + torch_dtype: Optional[torch.dtype] = None, + *, + model_type: Optional[str] = None, + use_hf: Optional[bool] = None, + hub_token: Optional[str] = None, + revision: Optional[str] = None, + # engine_kwargs + gpu_memory_utilization: float = 0.9, + tensor_parallel_size: int = 1, + pipeline_parallel_size: int = 1, + max_model_len: Optional[int] = None, + max_num_seqs: int = 256, + disable_custom_all_reduce: bool = False, + enforce_eager: bool = False, + limit_mm_per_prompt: Optional[Dict[str, Any]] = None, + # lora + enable_lora: bool = False, + max_loras: int = 1, + max_lora_rank: int = 16, + enable_prefix_caching: bool = False, + engine_kwargs: Optional[Dict[str, Any]] = None, + ) -> None: self.processor = get_model_tokenizer( model_id_or_path, torch_dtype, @@ -79,7 +81,9 @@ def __init__( enable_lora=enable_lora, max_loras=max_loras, max_lora_rank=max_lora_rank, - engine_kwargs=engine_kwargs) + enable_prefix_caching=enable_prefix_caching, + engine_kwargs=engine_kwargs, + ) self._prepare_engine() self._load_generation_config() @@ -91,19 +95,22 @@ def _prepare_engine(self) -> None: engine = AsyncLLMEngine.from_engine_args(self.engine_args) self.engine = engine - def _prepare_engine_kwargs(self, - gpu_memory_utilization: float = 0.9, - tensor_parallel_size: int = 1, - pipeline_parallel_size: int = 1, - max_model_len: Optional[int] = None, - max_num_seqs: int = 256, - disable_custom_all_reduce: bool = False, - enforce_eager: bool = False, - limit_mm_per_prompt: Optional[Dict[str, Any]] = None, - enable_lora: bool = False, - max_loras: int = 1, - max_lora_rank: int = 16, - engine_kwargs: Optional[Dict[str, Any]] = None) -> None: + def _prepare_engine_kwargs( + self, + gpu_memory_utilization: float = 0.9, + tensor_parallel_size: int = 1, + pipeline_parallel_size: int = 1, + max_model_len: Optional[int] = None, + max_num_seqs: int = 256, + disable_custom_all_reduce: bool = False, + enforce_eager: bool = False, + limit_mm_per_prompt: Optional[Dict[str, Any]] = None, + enable_lora: bool = False, + max_loras: int = 1, + max_lora_rank: int = 16, + enable_prefix_caching: bool = False, + engine_kwargs: Optional[Dict[str, Any]] = None, + ) -> None: if engine_kwargs is None: engine_kwargs = {} disable_log_stats = engine_kwargs.pop('disable_log_stats', True) @@ -136,7 +143,9 @@ def _prepare_engine_kwargs(self, disable_custom_all_reduce=disable_custom_all_reduce, enforce_eager=enforce_eager, trust_remote_code=True, - **engine_kwargs) + enable_prefix_caching=enable_prefix_caching, + **engine_kwargs, + ) self.engine_args = engine_args self.enable_lora = enable_lora if max_model_len is not None: From c2cebd0a4a69f0887903e720c9b29cbcdc78be1a Mon Sep 17 00:00:00 2001 From: Jintao Date: Thu, 23 Jan 2025 17:06:28 +0800 Subject: [PATCH 27/52] Fix vllm docs link & fix web-ui (#2970) --- ...75\344\273\244\350\241\214\345\217\202\346\225\260.md" | 4 ++-- docs/source_en/Instruction/Command-line-parameters.md | 4 ++-- swift/ui/llm_train/lora.py | 8 ++++---- 3 files changed, 8 insertions(+), 8 deletions(-) diff --git "a/docs/source/Instruction/\345\221\275\344\273\244\350\241\214\345\217\202\346\225\260.md" "b/docs/source/Instruction/\345\221\275\344\273\244\350\241\214\345\217\202\346\225\260.md" index 4ba73a7ba1..ffc4f2afac 100644 --- "a/docs/source/Instruction/\345\221\275\344\273\244\350\241\214\345\217\202\346\225\260.md" +++ "b/docs/source/Instruction/\345\221\275\344\273\244\350\241\214\345\217\202\346\225\260.md" @@ -266,7 +266,7 @@ Vera使用`target_modules`, `target_regex`, `modules_to_save`三个参数. - vision_batch_size: 默认值`1` ### vLLM参数 -参数含义可以查看[vllm文档](https://docs.vllm.ai/en/latest/models/engine_args.html) +参数含义可以查看[vllm文档](https://docs.vllm.ai/en/latest/serving/engine_args.html) - 🔥gpu_memory_utilization: 默认值`0.9` - 🔥tensor_parallel_size: 默认为`1` @@ -277,7 +277,7 @@ Vera使用`target_modules`, `target_regex`, `modules_to_save`三个参数. - enforce_eager: vllm使用pytorch eager模式还是建立cuda graph. 默认为`False`. 设置为True可以节约显存, 但会影响效率. - 🔥limit_mm_per_prompt: 控制vllm使用多图, 默认为`None`. 例如传入`--limit_mm_per_prompt '{"image": 10, "video": 5}'` - vllm_max_lora_rank: 默认为`16`. vllm对于lora支持的参数 -- enable_prefix_caching: 是否开启 vllm 的 Prefix Caching 能力. 默认为`False`. 设置为 True 可以节约重复请求前缀(例如 System Prompt,长文档或多轮对话)处理时间。 +- enable_prefix_caching: 开启vllm的自动前缀缓存,节约重复查询前缀的处理时间。默认为`False` ### 合并参数 diff --git a/docs/source_en/Instruction/Command-line-parameters.md b/docs/source_en/Instruction/Command-line-parameters.md index a80d0851e5..fa5339175c 100644 --- a/docs/source_en/Instruction/Command-line-parameters.md +++ b/docs/source_en/Instruction/Command-line-parameters.md @@ -271,7 +271,7 @@ Parameter meanings can be found in the [lmdeploy documentation](https://lmdeploy ### vLLM Arguments -Parameter meanings can be found in the [vllm documentation](https://docs.vllm.ai/en/latest/models/engine_args.html). +Parameter meanings can be found in the [vllm documentation](https://docs.vllm.ai/en/latest/serving/engine_args.html). - 🔥gpu_memory_utilization: Default value is `0.9`. - 🔥tensor_parallel_size: Default is `1`. @@ -282,7 +282,7 @@ Parameter meanings can be found in the [vllm documentation](https://docs.vllm.ai - enforce_eager: Whether vllm uses pytorch eager mode or establishes a cuda graph. Default is `False`. Setting to True can save memory but may affect efficiency. - 🔥limit_mm_per_prompt: Controls vllm using multiple images, default is `None`. For example, use `--limit_mm_per_prompt '{"image": 10, "video": 5}'`. - vllm_max_lora_rank: Default value is `16`. Parameters supported by vllm for LoRA. -- enable_prefix_caching: Whether enable `Automatic Prefix Caching` feature for vllm. Default is `False`. Setting to True can save processing time for repeatable request prefix(such as system prompt, long docs, or multi-turn dialog, etc). +- enable_prefix_caching: Enable the automatic prefix caching of vllm to save processing time for querying repeated prefixes. The default is `False`. ### Merge Arguments diff --git a/swift/ui/llm_train/lora.py b/swift/ui/llm_train/lora.py index eeba2c340b..e5fbe5d10f 100644 --- a/swift/ui/llm_train/lora.py +++ b/swift/ui/llm_train/lora.py @@ -23,8 +23,8 @@ class LoRA(BaseUI): 'en': 'LoRA target modules' }, 'info': { - 'zh': '设置LoRA目标模块,如训练所有Linear请改为ALL', - 'en': 'Set the LoRA target modules, fill in ALL if train all Linears' + 'zh': '设置LoRA目标模块,如训练所有Linear请改为`all-linear`', + 'en': 'Set the LoRA target modules, fill in `all-linear` if train all Linears' } }, 'lora_rank': { @@ -91,8 +91,8 @@ def do_build_ui(cls, base_tab: Type['BaseUI']): with gr.Blocks(): with gr.Row(): gr.Textbox(elem_id='target_modules', lines=1, scale=5, value='all-linear', is_list=True) - gr.Slider(elem_id='lora_rank', value=32, minimum=1, maximum=512, step=8, scale=2) - gr.Slider(elem_id='lora_alpha', value=8, minimum=1, maximum=512, step=8, scale=2) + gr.Slider(elem_id='lora_rank', value=8, minimum=1, maximum=512, step=8, scale=2) + gr.Slider(elem_id='lora_alpha', value=32, minimum=1, maximum=512, step=8, scale=2) gr.Textbox(elem_id='lora_dropout', scale=2) with gr.Row(): gr.Dropdown(elem_id='lora_dtype', scale=2, value=None) From d363d84d8797942116d0ed679ed4b3ed63f62fe7 Mon Sep 17 00:00:00 2001 From: tastelikefeet <58414341+tastelikefeet@users.noreply.github.com> Date: Thu, 23 Jan 2025 17:24:42 +0800 Subject: [PATCH 28/52] Fix sample (#2971) --- scripts/rft/rft.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/scripts/rft/rft.py b/scripts/rft/rft.py index c931e2ec45..4162d501b4 100644 --- a/scripts/rft/rft.py +++ b/scripts/rft/rft.py @@ -28,10 +28,11 @@ def do_sample(model: str, model_type: str, dataset: List[str], iter: int): f'--sampler_engine lmdeploy ' f'--max_new_tokens 768 ' f'--override_exist_file true ' - f'--num_sampling_per_gpu_batch_size 2 ' + f'--num_sampling_per_gpu_batch_size 1 ' f'--num_return_sequences 64 ' f'--cache_files sample_output/iter_{iter}_proc_{device}_cache.jsonl ' f'--output_file iter_{iter}_proc_{device}_cache.jsonl ' + f'--top_p 1.0 ' f'--temperature 1.0 ') print(f'Sampling caches of iter {iter}, part {device}.', flush=True) env = os.environ.copy() @@ -60,10 +61,11 @@ def do_sample(model: str, model_type: str, dataset: List[str], iter: int): f'--load_args false ' f'--sampler_engine no ' f'--orm_model math ' - f'--prm_model AI-ModelScope/GRM-llama3.2-3B-rewardmodel-ft ' + f'--prm_model Qwen/Qwen2.5-Math-PRM-7B ' + f'--prm_threshold 0.7 ' f'--max_new_tokens 768 ' f'--override_exist_file true ' - f'--num_sampling_per_gpu_batch_size 2 ' + f'--num_sampling_per_gpu_batch_size 1 ' f'--num_return_sequences 64 ' f'--output_file iter_{iter}_proc_{device}_sampling.jsonl ' f'--cache_files sample_output/iter_{iter}_proc_{device}_cache.jsonl ') From edbf4b6760435074cec8a8af90ea437b472614f0 Mon Sep 17 00:00:00 2001 From: Jintao Date: Thu, 23 Jan 2025 17:54:22 +0800 Subject: [PATCH 29/52] support merge-lora & quant (#2973) --- ...75\344\273\244\350\241\214\345\217\202\346\225\260.md" | 2 ++ docs/source_en/Instruction/Command-line-parameters.md | 2 ++ swift/llm/argument/export_args.py | 8 +++++--- swift/llm/export/export.py | 8 ++++++-- swift/llm/export/merge_lora.py | 7 ++++--- 5 files changed, 19 insertions(+), 8 deletions(-) diff --git "a/docs/source/Instruction/\345\221\275\344\273\244\350\241\214\345\217\202\346\225\260.md" "b/docs/source/Instruction/\345\221\275\344\273\244\350\241\214\345\217\202\346\225\260.md" index ffc4f2afac..d62a8f1e78 100644 --- "a/docs/source/Instruction/\345\221\275\344\273\244\350\241\214\345\217\202\346\225\260.md" +++ "b/docs/source/Instruction/\345\221\275\344\273\244\350\241\214\345\217\202\346\225\260.md" @@ -414,6 +414,8 @@ App参数继承于[部署参数](#部署参数), [Web-UI参数](#Web-UI参数) - quant_batch_size: 量化batch_size,默认为1 - group_size: 量化group大小,默认为128 +- exist_ok: 如果存在,不抛出异常。默认为False + - 🔥push_to_hub: 是否推送hub,默认为False - hub_model_id: 推送的model_id,默认为None - hub_private_repo: 是否是private repo,默认为False diff --git a/docs/source_en/Instruction/Command-line-parameters.md b/docs/source_en/Instruction/Command-line-parameters.md index fa5339175c..8f33bedb62 100644 --- a/docs/source_en/Instruction/Command-line-parameters.md +++ b/docs/source_en/Instruction/Command-line-parameters.md @@ -417,6 +417,8 @@ Export Arguments include the [basic arguments](#base-arguments) and [merge argum - quant_batch_size: Quantization batch size, default is 1. - group_size: Group size for quantization, default is 128. +- exist_ok: If it exists, no exception is raised. Defaults to False. + - 🔥push_to_hub: Whether to push to the hub, default is False. - hub_model_id: Model ID for pushing, default is None. - hub_private_repo: Whether it is a private repo, default is False. diff --git a/swift/llm/argument/export_args.py b/swift/llm/argument/export_args.py index c7275c8cd4..aaf85ca8b0 100644 --- a/swift/llm/argument/export_args.py +++ b/swift/llm/argument/export_args.py @@ -52,6 +52,7 @@ class ExportArguments(MergeArguments, BaseArguments): commit_message: str = 'update files' # compat to_peft_format: bool = False + exist_ok: bool = False def _init_output_dir(self): if self.output_dir is None: @@ -59,12 +60,12 @@ def _init_output_dir(self): ckpt_dir, ckpt_name = os.path.split(ckpt_dir) if self.to_peft_format: suffix = 'peft' - elif self.merge_lora: - suffix = 'merged' elif self.quant_method: suffix = f'{self.quant_method}-int{self.quant_bits}' elif self.to_ollama: suffix = 'ollama' + elif self.merge_lora: + suffix = 'merged' else: return @@ -72,7 +73,8 @@ def _init_output_dir(self): logger.info(f'Setting args.output_dir: {self.output_dir}') self.output_dir = to_abspath(self.output_dir) - assert not os.path.exists(self.output_dir), f'args.output_dir: {self.output_dir} already exists.' + if not self.exist_ok and os.path.exists(self.output_dir): + raise FileExistsError(f'args.output_dir: {self.output_dir} already exists.') def __post_init__(self): if self.quant_batch_size == -1: diff --git a/swift/llm/export/export.py b/swift/llm/export/export.py index 1adb901dee..c14c88a1d7 100644 --- a/swift/llm/export/export.py +++ b/swift/llm/export/export.py @@ -19,9 +19,13 @@ def run(self): args = self.args if args.to_peft_format: args.adapters[0] = swift_to_peft_format(args.adapters[0], args.output_dir) - elif args.merge_lora: + if args.merge_lora: + output_dir = args.output_dir + if args.to_peft_format or args.quant_method or args.to_ollama or args.push_to_hub: + args.output_dir = None merge_lora(args) - elif args.quant_method: + args.output_dir = output_dir # recover + if args.quant_method: quantize_model(args) elif args.to_ollama: export_to_ollama(args) diff --git a/swift/llm/export/merge_lora.py b/swift/llm/export/merge_lora.py index ff265886af..f767739c7f 100644 --- a/swift/llm/export/merge_lora.py +++ b/swift/llm/export/merge_lora.py @@ -11,9 +11,6 @@ def merge_lora(args: ExportArguments, device_map=None, replace_if_exists=False) -> None: if replace_if_exists: logger.info(f'replace_if_exists: {replace_if_exists}') - assert args.quant_method is None, (f'args.quant_method: {args.quant_method}, ' - 'quantized model and does not support merge-lora.') - output_dir = getattr(args, 'output_dir', None) or f'{args.adapters[0]}-merged' if os.path.exists(output_dir) and not replace_if_exists: logger.info(f'The weight directory for the merged LoRA already exists in {output_dir}, ' @@ -24,6 +21,9 @@ def merge_lora(args: ExportArguments, device_map=None, replace_if_exists=False) args.device_map = device_map or args.device_map logger.info(f'merge_device_map: {device_map}') model, template = prepare_model_template(args) + quant_method = model.model_info.quant_method + assert quant_method is None, (f'quant_method: {quant_method}, ' + 'quantized model and does not support merge-lora.') logger.info('Merge LoRA...') Swift.merge_and_unload(model) model = model.model @@ -41,4 +41,5 @@ def merge_lora(args: ExportArguments, device_map=None, replace_if_exists=False) args.device_map = origin_device_map args.model = output_dir + args.model_dir = output_dir args.adapters = [] From 11131823fe7b8241de6e117891140885a5b3d67d Mon Sep 17 00:00:00 2001 From: Jintao Date: Thu, 23 Jan 2025 18:39:42 +0800 Subject: [PATCH 30/52] support create_checkpoint_symlink (#2975) --- ...275\344\273\244\350\241\214\345\217\202\346\225\260.md" | 1 + docs/source_en/Instruction/Command-line-parameters.md | 1 + swift/llm/argument/train_args.py | 1 + swift/llm/dataset/utils.py | 3 +++ swift/llm/export/merge_lora.py | 3 +-- swift/llm/train/sft.py | 7 +++++++ 6 files changed, 14 insertions(+), 2 deletions(-) diff --git "a/docs/source/Instruction/\345\221\275\344\273\244\350\241\214\345\217\202\346\225\260.md" "b/docs/source/Instruction/\345\221\275\344\273\244\350\241\214\345\217\202\346\225\260.md" index d62a8f1e78..ebb96038f8 100644 --- "a/docs/source/Instruction/\345\221\275\344\273\244\350\241\214\345\217\202\346\225\260.md" +++ "b/docs/source/Instruction/\345\221\275\344\273\244\350\241\214\345\217\202\346\225\260.md" @@ -295,6 +295,7 @@ Vera使用`target_modules`, `target_regex`, `modules_to_save`三个参数. - add_version: 在output_dir上额外增加目录`'<版本号>-<时间戳>'`防止权重覆盖,默认为True - resume_only_model: 如果resume_from_checkpoint,仅resume模型权重,默认为False - check_model: 检查本地模型文件有损坏或修改并给出提示,默认为True。如果是断网环境,请设置为False +- create_checkpoint_symlink: 额外创建checkpoint软链接。best_model和last_model分别为f'{output_dir}/best'和f'{output_dir}/last' - loss_type: loss类型,默认使用模型自带损失函数 - packing: 是否使用packing,默认为False diff --git a/docs/source_en/Instruction/Command-line-parameters.md b/docs/source_en/Instruction/Command-line-parameters.md index 8f33bedb62..ba59c970b0 100644 --- a/docs/source_en/Instruction/Command-line-parameters.md +++ b/docs/source_en/Instruction/Command-line-parameters.md @@ -299,6 +299,7 @@ Training arguments include the [base arguments](#base-arguments), [Seq2SeqTraine - add_version: Add directory to output_dir with `'-'` to prevent weight overwrite, default is True. - resume_only_model: If resume_from_checkpoint, only resume model weights, default is False. - check_model: Check local model files for corruption or modification and give a prompt, default is True. If in an offline environment, please set to False. +- create_checkpoint_symlink: Create additional checkpoint symlinks. best_model and last_model are f'{output_dir}/best' and f'{output_dir}/last', respectively. - loss_type: Type of loss, default uses the model's built-in loss function. - packing: Whether to use packing, default is False. - 🔥lazy_tokenize: Whether to use lazy_tokenize, default is False during LLM training, default is True during MLLM training. diff --git a/swift/llm/argument/train_args.py b/swift/llm/argument/train_args.py index 79ac554cb9..98fdcdad3b 100644 --- a/swift/llm/argument/train_args.py +++ b/swift/llm/argument/train_args.py @@ -109,6 +109,7 @@ class TrainArguments(TorchAccArguments, TunerArguments, Seq2SeqTrainingOverrideA add_version: bool = True resume_only_model: bool = False check_model: bool = True + create_checkpoint_symlink: bool = False # dataset packing: bool = False diff --git a/swift/llm/dataset/utils.py b/swift/llm/dataset/utils.py index 1b286d1b3c..f625a740f8 100644 --- a/swift/llm/dataset/utils.py +++ b/swift/llm/dataset/utils.py @@ -188,6 +188,9 @@ def __getitem__(self, idx: int) -> Dict[str, Any]: 'and another piece of data will be randomly selected.') self._traceback_counter += 1 + raise ValueError('Failed to retrieve the dataset. You can avoid this issue by increasing `max_length` or ' + 'modifying the `truncation_strategy`.') + def __len__(self) -> int: return len(self.dataset) diff --git a/swift/llm/export/merge_lora.py b/swift/llm/export/merge_lora.py index f767739c7f..80ae5f5657 100644 --- a/swift/llm/export/merge_lora.py +++ b/swift/llm/export/merge_lora.py @@ -14,8 +14,7 @@ def merge_lora(args: ExportArguments, device_map=None, replace_if_exists=False) output_dir = getattr(args, 'output_dir', None) or f'{args.adapters[0]}-merged' if os.path.exists(output_dir) and not replace_if_exists: logger.info(f'The weight directory for the merged LoRA already exists in {output_dir}, ' - 'skipping the saving process. ' - 'you can pass `replace_if_exists=True` to overwrite it.') + 'skipping the saving process.') else: origin_device_map = args.device_map args.device_map = device_map or args.device_map diff --git a/swift/llm/train/sft.py b/swift/llm/train/sft.py index 92ebfdaf94..be8d678a4e 100644 --- a/swift/llm/train/sft.py +++ b/swift/llm/train/sft.py @@ -158,6 +158,13 @@ def _save_trainer_state(self, trainer): training_args = trainer.args state = trainer.state + if self.args.create_checkpoint_symlink: + last_checkpoint = os.path.join(self.args.output_dir, 'last') + best_checkpoint = os.path.join(self.args.output_dir, 'best') + os.symlink(state.last_model_checkpoint, last_checkpoint) + os.symlink(state.best_model_checkpoint, best_checkpoint) + state.last_model_checkpoint = last_checkpoint + state.best_model_checkpoint = best_checkpoint logger.info(f'last_model_checkpoint: {state.last_model_checkpoint}') logger.info(f'best_model_checkpoint: {state.best_model_checkpoint}') From 26429781dce030b2366458b58ed0d681e15524ac Mon Sep 17 00:00:00 2001 From: tastelikefeet <58414341+tastelikefeet@users.noreply.github.com> Date: Thu, 23 Jan 2025 23:35:10 +0800 Subject: [PATCH 31/52] Sampling and RFT (#2977) --- README.md | 10 ++ README_CN.md | 10 ++ ...72\345\214\226\345\276\256\350\260\203.md" | 103 ++++++++++++++++++ .../Instruction/\351\207\207\346\240\267.md" | 13 ++- docs/source/index.rst | 2 + .../Instruction/Reinforced_Fine_tuning.md | 103 ++++++++++++++++++ docs/source_en/Instruction/Sample.md | 13 ++- docs/source_en/index.rst | 2 + {scripts => examples/train}/rft/math.json | 0 {scripts => examples/train}/rft/rft.py | 77 +++++++------ swift/cli/sample.py | 2 +- swift/experimental/__init__.py | 0 swift/experimental/sampling/__init__.py | 0 .../argument}/sampling_args.py | 0 swift/llm/sampling/__init__.py | 1 + swift/{experimental => llm}/sampling/base.py | 2 +- swift/{experimental => llm}/sampling/mcts.py | 0 .../sampling/sampling.py | 4 +- swift/{experimental => llm}/sampling/utils.py | 0 .../sampling/vanilla_sampler.py | 2 +- 20 files changed, 304 insertions(+), 40 deletions(-) create mode 100644 "docs/source/Instruction/\345\274\272\345\214\226\345\276\256\350\260\203.md" create mode 100644 docs/source_en/Instruction/Reinforced_Fine_tuning.md rename {scripts => examples/train}/rft/math.json (100%) rename {scripts => examples/train}/rft/rft.py (70%) delete mode 100644 swift/experimental/__init__.py delete mode 100644 swift/experimental/sampling/__init__.py rename swift/{experimental/sampling => llm/argument}/sampling_args.py (100%) create mode 100644 swift/llm/sampling/__init__.py rename swift/{experimental => llm}/sampling/base.py (96%) rename swift/{experimental => llm}/sampling/mcts.py (100%) rename swift/{experimental => llm}/sampling/sampling.py (95%) rename swift/{experimental => llm}/sampling/utils.py (100%) rename swift/{experimental => llm}/sampling/vanilla_sampler.py (99%) diff --git a/README.md b/README.md index a6dfa4766b..fa28aa1fca 100644 --- a/README.md +++ b/README.md @@ -79,6 +79,7 @@ You can contact us and communicate with us by adding our group: ## 🎉 News +- 🎁 2024.01.23: SWIFT support the `sample` command, this is a very important feature for complex CoT and RFT. Meanwhile, we support an [Reinforced Fine-tuning script](docs/source_en/Instruction/Reinforced_Fine_tuning.md). - 🎁 2024.12.04: **SWIFT3.0** major version update. Please check the [Release Notes and Changes](https://swift.readthedocs.io/en/latest/Instruction/ReleaseNote3.0.html). - 🎉 2024.08.12: The SWIFT paper has been published on arXiv, and you can read it [here](https://arxiv.org/abs/2408.05517). - 🔥 2024.08.05: Support for using [evalscope](https://github.com/modelscope/evalscope/) as a backend for evaluating large models and multimodal models. @@ -295,6 +296,15 @@ CUDA_VISIBLE_DEVICES=0 swift deploy \ --infer_backend vllm ``` +### Sampling +```shell +CUDA_VISIBLE_DEVICES=0 swift sample \ + --model LLM-Research/Meta-Llama-3.1-8B-Instruct \ + --sampler_engine pt \ + --num_return_sequences 5 \ + --dataset AI-ModelScope/alpaca-gpt4-data-zh#5 +``` + ### Evaluation ```shell CUDA_VISIBLE_DEVICES=0 swift eval \ diff --git a/README_CN.md b/README_CN.md index 52bb3640ef..cb23522da2 100644 --- a/README_CN.md +++ b/README_CN.md @@ -74,6 +74,7 @@ - **模型量化**:支持AWQ、GPTQ和BNB的量化导出,导出的模型支持使用vLLM/LmDeploy推理加速,并支持继续训练。 ## 🎉 新闻 +- 🎁 2024.01.23: SWIFT支持了`sample`命令, 这是一个对CoT和RFT非常重要的命令. 同时, 我们支持了一个[强化微调脚本](docs/source/Instruction/强化微调.md)。 - 🎁 2024.12.04: **SWIFT3.0**大版本更新. 请查看[发布说明和更改](https://swift.readthedocs.io/zh-cn/latest/Instruction/ReleaseNote3.0.html)。 - 🎉 2024.08.12: SWIFT论文已经发布到arXiv上,可以点击[这里](https://arxiv.org/abs/2408.05517)阅读。 - 🔥 2024.08.05: 支持使用[evalscope](https://github.com/modelscope/evalscope/)作为后端进行大模型和多模态模型的评测。 @@ -288,6 +289,15 @@ CUDA_VISIBLE_DEVICES=0 swift deploy \ --infer_backend vllm ``` +### 采样 +```shell +CUDA_VISIBLE_DEVICES=0 swift sample \ + --model LLM-Research/Meta-Llama-3.1-8B-Instruct \ + --sampler_engine pt \ + --num_return_sequences 5 \ + --dataset AI-ModelScope/alpaca-gpt4-data-zh#5 +``` + ### 评测 ```shell CUDA_VISIBLE_DEVICES=0 swift eval \ diff --git "a/docs/source/Instruction/\345\274\272\345\214\226\345\276\256\350\260\203.md" "b/docs/source/Instruction/\345\274\272\345\214\226\345\276\256\350\260\203.md" new file mode 100644 index 0000000000..002770cea1 --- /dev/null +++ "b/docs/source/Instruction/\345\274\272\345\214\226\345\276\256\350\260\203.md" @@ -0,0 +1,103 @@ +# 强化微调 + +强化微调是目前模型训练非常重要的功能之一,它本身的实现是多种多样的,SWIFT目前已经支持了强化微调所需要的原子能力,如采样、强化学习和微调。目前我们提供了拒绝采样微调的一个具体示例,可以查看[这里](https://github.com/modelscope/ms-swift/tree/main/examples/train/rft/rft.py)。 + +## 强化微调的概念 + +强化微调是从2022年开始(甚至更早)就被提出的概念。其方式一般有下列流程: + +1. 使用某个模型生成数据,或进行原始数据扩充 +2. 使用数据训练目标模型 +3. 如果有必要,重复上述过程 + +步骤1: + +- 如果生成数据的模型是更大的模型,如GPT、Qwen-Max、DeepSeek-V3/R1等,则该强化微调可以理解为蒸馏 +- 如果生成数据的模型是本模型,则可以理解为自我提升(self-improvement)微调 +- 如果采样过程是采样一个batch,然后通过KL散度和reward进行拟合训练并不断循环,则可以理解为PPO、GRPO等on-policy算法 +- 采样数据的算法包含蒙特卡洛采样、do_sample采样、group beam search、dvts等 +- 采样过程可以引入ORM(结果判断),PRM(过程打分),多样性过滤,语种过滤等 + +步骤2: + +- 如果使用SFT,则称为拒绝采样微调 +- 如果是强化学习,则称为强化学习微调 + +步骤3: + +- 如果使用更大的模型蒸馏,例如更大模型的蒙特卡洛采样蒸馏,一般不会有循环 +- 如果使用本模型进行采样,或者PPO等算法,则会有循环 + +泛泛来说,常见强化微调的方式有下面几种: + +1. 蒸馏:使用蒙特卡洛、do_sample等方式从超大模型中采样大量优质数据,训练小模型 +2. 自我提升:从本模型中采样部分优质数据,筛选后训练本模型,循环执行 +3. on-policy RL:使用PPO、GRPO等方式循环训练 + +采样过程一般很漫长,比训练过程漫长的多。如果使用GPT等模型蒸馏数据,则需要购买token。因此,强化微调的时间成本和花费成本比较高,所以一般作为微调的补充机制出现,当然也有特例,例如最近的DeepSeek-R1。 + +DeepSeek-R1使用了GRPO算法从零使base模型涌现CoT能力,该方法需要大规模集群支持,且模型需要足够大才能发生能力涌现,在本文中不详细讨论。如果需要了解该过程,请查看[论文解析](https://zhuanlan.zhihu.com/p/19714987272)。 + +有关强化微调的一些论文: + +- 拒绝采样微调:https://arxiv.org/pdf/2308.01825 +- ReST:https://arxiv.org/pdf/2308.08998 +- B-STAR:https://arxiv.org/pdf/2412.17256 +- DeepSeekMath:https://arxiv.org/pdf/2402.03300 +- Qwen-math-PRM:https://arxiv.org/pdf/2501.07301 +- DeepSeek-R1:https://github.com/deepseek-ai/DeepSeek-R1/tree/main + +## 什么时候使用强化微调 + +在LLaMA3之后,我们发现一个非常明显但却是不常被提及的特点:使用某个含有CoT的train数据集训练Instruct模型,再通过对应的test集进行评测,会发现test集评测效果变差。例如,使用gsm8k训练集训练llama3.1-8b-instruct,对生成的ckpt使用test集进行评测,会发现掉点。 + +这个特性主要来源于模型的知识遗忘问题。在模型厂商的微调中,会加入非常多的CoT数据集,模型在解决数学任务的时候,用到的能力很有可能不是来自于math数据集,而是来自arc数据集,这个推论有[一些工作可以证明](https://zhuanlan.zhihu.com/p/19269451950)。在继续训练通用任务后,知识遗忘破坏了模型原有能力,导致了掉点。 + +然而,优先使用微调方式训练模型总是正确的。微调可以使模型快速适应数据集的分布,并且微调的成本很低。当有如下条件之一时使用强化微调: + +1. 已经微调过模型,能力不满足需求 +2. 需要更强的CoT能力 +3. 对基模型训练通用能力,而原始数据集已经导致模型效果无法提升 +4. 对应query的输出结果可以相对准确地评估好坏,例如结果清晰(数学,代码),过程清晰(翻译,风格)等 + +强化微调非常依赖于reward评估是否准确。如果评估结果不准确,可能导致模型训练原地震荡,甚至越训越差。 + +## SWIFT的实现 + +SWIFT支持sample命令,该命令就是用于模型采样。目前支持的采样方式有: + +- do_sample:sample方式对模型进行采样,该方式支持对开源模型进行采样,后续会支持模型蒸馏 + - sample方式后续会支持URL采样,用于大模型蒸馏 + +- mcts:蒙特卡洛采样,该方式在PR中,后续会支持 +- dvts:调研中 + +目前我们给出了一个较为通用的[RFT脚本](https://github.com/modelscope/ms-swift/tree/main/examples/train/rft/rft.py)。该脚本适用于自我提升方式的训练,且支持动态调整采样温度值、PRM阈值等超参数,并且训练方式灵活可变(微调、DPO等;或者每次迭代重新训练原模型或继续训练上个迭代的模型,甚至加载上个迭代的所有训练状态等)。开发者可以在该脚本中增加其他数据过滤(生成的数据集中,id相同的行来自同一个query),例如多样性判断、语种判断等。 + +## 实验结果 + +我们对该RFT脚本针对数学领域使用competition_math数据集进行了训练和评测,结果如下: + +| 模型 | MATH指标 | 训练方式 | 迭代次数 | 训练后MATH指标 | +| ------------------------ | -------- | -------- | -------- | --------------------- | +| LLaMA3.1_8b | 12.0 | SFT | 3 | 25.2(LLaMA3.1_8b_sft) | +| LLaMA3.1_8b_sft | 25.2 | RFT | 2 | 32.4 | +| LLaMA3.1_8b_instruct | 52.2 | SFT | 2 | 39.0 | +| LLaMA3.1_8b_instruct | 52.2 | RFT | 3 | 58 | +| Qwen2.5_math_7b_instruct | 79.6 | RFT | 2 | 83.2 | + +可以看到,使用competition_math直接SFT后,instruct模型的掉点十分严重。而RFT后模型能力有提升,即使对Qwen2.5_math_7b_instruct这个SOTA的math模型也同样有一定提升空间。 + +特别地,针对Qwen2.5_math_7b_instruct我们测试了gsm8k的指标: + +| 模型 | gsm8k指标 | RFT后gsm8k指标 | +| ------------------------ | --------- | -------------- | +| Qwen2.5_math_7b_instruct | 92.8 | 91.6 | + +可以看到,RFT训练后gsm8k指标变化不大,并没有出现前述的掉点现象。 + +## 未来计划 + +1. 更多的采样方式,如MCTS +2. 超大模型蒸馏训练 +3. 以PPO为主的on-policy训练 diff --git "a/docs/source/Instruction/\351\207\207\346\240\267.md" "b/docs/source/Instruction/\351\207\207\346\240\267.md" index 0b5badc2f9..445dbe2000 100644 --- "a/docs/source/Instruction/\351\207\207\346\240\267.md" +++ "b/docs/source/Instruction/\351\207\207\346\240\267.md" @@ -50,7 +50,7 @@ class CustomPRM: pass @torch.inference_mode() - def infer(self, infer_requests: List[InferRequest], **kwargs) -> List[ChatCompletionResponse]: + def infer(self, infer_requests: List[InferRequest], ground_truths: List[str], **kwargs) -> List[ChatCompletionResponse]: ... @@ -59,8 +59,17 @@ prms = {'custom': CustomPRM} 之后在命令行中使用`--prm_model custom`即可。 +## 显存控制 + +如果被采样模型和PRM共同加载进显存,则可能出现OOM的问题。因此采样可以分为两段进行: + +- 第一段指定`--model`和``--sampler_engine`,同时不指定`--orm_model`和`--prm_model`,仅进行采样,并存储为文件 +- 第二段指定`--sampler_engine no`,指定`--orm_model`和`--prm_model`,并同时指定`--cache_files`,仅进行RM数据过滤,不重新采样 + +通过两段方式可以每次仅加载一个模型,防止OOM。 + ## 实际例子 -请参考[强化微调脚本](https://github.com/modelscope/ms-swift/tree/main/scripts/rft.py)。该脚本给出了使用采样进行强化微调的实际例子。 +请参考[强化微调脚本](https://github.com/modelscope/ms-swift/tree/main/examples/train/rft/rft.py)。该脚本给出了使用采样进行强化微调的实际例子。 > 注意:该脚本的实际效果和模型、数据、RM的质量强相关,因此仅作为样例出现,用户请自行修改该脚本并训练自己的RM和generator模型。 diff --git a/docs/source/index.rst b/docs/source/index.rst index 3c929f8d41..19a2af310a 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -21,8 +21,10 @@ Swift DOCUMENTATION Instruction/预训练及微调.md Instruction/人类对齐.md Instruction/推理和部署.md + Instruction/采样.md Instruction/评测.md Instruction/导出.md + Instruction/强化微调.md Instruction/支持的模型和数据集.md Instruction/使用tuners.md Instruction/智能体的支持.md diff --git a/docs/source_en/Instruction/Reinforced_Fine_tuning.md b/docs/source_en/Instruction/Reinforced_Fine_tuning.md new file mode 100644 index 0000000000..77ccf3d3d7 --- /dev/null +++ b/docs/source_en/Instruction/Reinforced_Fine_tuning.md @@ -0,0 +1,103 @@ +# Reinforced Fine-Tuning + +Reinforced fine-tuning is one of the most important functionalities in current model training, with various implementations. SWIFT has already supported the atomic capabilities required for reinforced fine-tuning, such as sampling, reinforcement learning, and fine-tuning. Currently, we provide a specific example of rejection sampling fine-tuning, which can be found [here](https://github.com/modelscope/ms-swift/tree/main/examples/train/rft/rft.py). + +## Concept of Reinforced Fine-Tuning + +The concept of reinforced fine-tuning has been proposed since 2022 (or even earlier). Its general workflow typically includes the following steps: + +1. Generate data using a specific model or augment the original dataset. +2. Train the target model using the generated data. +3. Repeat the above process if necessary. + +**Step 1:** + +- If the data-generating model is a larger model, such as GPT, Qwen-Max, DeepSeek-V3/R1, etc., this process can be understood as distillation. +- If the data-generating model is the same model being trained, this can be considered self-improvement fine-tuning. +- If the sampling process involves sampling a batch, fitting the data with KL divergence and rewards, and iterating continuously, it can be classified as on-policy algorithms like PPO or GRPO. +- Sampling algorithms include Monte Carlo sampling, do_sample, group beam search, DVTS, etc. +- The sampling process can incorporate ORM (Outcome Reward Model), PRM (Process Reward Model), diversity filtering, language filtering, etc. + +**Step 2:** + +- If SFT (Supervised Fine-Tuning) is used, it is referred to as rejection sampling fine-tuning. +- If reinforcement learning is used, it is called reinforcement learning fine-tuning. + +**Step 3:** + +- If distillation is performed using a larger model (e.g., Monte Carlo sampling distillation with a larger model), the process usually does not involve iterations. +- If the same model is used for sampling or algorithms like PPO are applied, iterations are typically included. + +In general, the common approaches to reinforced fine-tuning include: + +1. **Distillation**: Sampling high-quality data in bulk from a larger model using methods like Monte Carlo or do_sample, and training a smaller model on this data. +2. **Self-improvement**: Sampling a portion of high-quality data from the same model, filtering it, and training the model iteratively. +3. **On-policy RL**: Using methods like PPO or GRPO for iterative training. + +The sampling process is usually much more time-consuming than the training process. If data is distilled using GPT or other large models, token costs must be considered. Thus, reinforced fine-tuning is generally a supplementary mechanism for fine-tuning, except for special cases like DeepSeek-R1. + +DeepSeek-R1 uses the GRPO algorithm to enable the emergence of CoT (Chain-of-Thought) capabilities from scratch in a base model. This method requires large-scale cluster support and sufficiently large models for capability emergence. This is not discussed in detail here, but more information can be found in the [paper analysis](https://zhuanlan.zhihu.com/p/19714987272). + +Some related papers on reinforced fine-tuning: + +- Rejection Sampling Fine-Tuning: https://arxiv.org/pdf/2308.01825 +- ReST: https://arxiv.org/pdf/2308.08998 +- B-STAR: https://arxiv.org/pdf/2412.17256 +- DeepSeekMath: https://arxiv.org/pdf/2402.03300 +- Qwen-Math-PRM: https://arxiv.org/pdf/2501.07301 +- DeepSeek-R1: https://github.com/deepseek-ai/DeepSeek-R1/tree/main + +## When to Use Reinforced Fine-Tuning + +Since LLaMA3, we have observed a very noticeable yet rarely mentioned phenomenon: when training an Instruct model using a CoT-enabled training dataset and evaluating it on the corresponding test set, the test set performance tends to degrade. For example, training `llama3.1-8b-instruct` on the GSM8K training set and evaluating the generated checkpoint on the test set reveals performance degradation. + +This phenomenon mainly arises from the issue of knowledge forgetting disaster in models. During fine-tuning by model manufacturers, a significant amount of CoT data is often included. When solving mathematical tasks, the model's capability often originates not from the math dataset itself but potentially from datasets like ARC. This inference is supported by [some works](https://zhuanlan.zhihu.com/p/19269451950). Continued training on general tasks disrupts the model's existing capabilities, leading to performance degradation. + +However, it is always correct to prioritize fine-tuning. Fine-tuning allows the model to quickly adapt to the dataset distribution at a low cost. Reinforced fine-tuning should be used under the following conditions: + +1. The model has already been fine-tuned but does not meet the requirements. +2. Stronger CoT capabilities are needed. +3. Base model training for general capabilities is necessary, and the original dataset no longer improves performance. +4. The output results for corresponding queries can be relatively accurately evaluated, such as tasks with clear results (math, code) or clear processes (translation, style fitting). + +Reinforced fine-tuning heavily depends on the accuracy of reward evaluations. If the evaluations are inaccurate, the training may oscillate without progress or even degrade the model performance. + +## SWIFT Implementation + +SWIFT supports the `sample` command, which is used for model sampling. Currently supported sampling methods include: + +- **do_sample**: A sampling method for open-source models; future updates will include support for model distillation. + - URL sampling will also be supported in the future for large-model distillation. + +- **mcts**: Monte Carlo sampling, currently under review, with future support planned. +- **dvts**: Currently under investigation. + +We have provided a general [RFT script](https://github.com/modelscope/ms-swift/tree/main/examples/train/rft/rft.py). This script supports self-improvement training and allows dynamic adjustments of sampling temperature, PRM thresholds, and other hyperparameters. The training method is flexible (e.g., fine-tuning, DPO) and supports iterative retraining of the original model or continued training from the previous iteration, even loading all training states from the previous iteration. Developers can incorporate additional data filtering (e.g., ensuring rows with the same ID come from the same query), including diversity checks, language filtering, etc. + +## Experimental Results + +We used the RFT script to train and evaluate the `competition_math` dataset in the math domain. The results are as follows: + +| Model | MATH Score | Training Method | Iterations | Post-Training MATH Score | +|----------------------------|------------|-----------------|------------|---------------------------| +| LLaMA3.1_8b | 12.0 | SFT | 3 | 25.2 (LLaMA3.1_8b_sft) | +| LLaMA3.1_8b_sft | 25.2 | RFT | 2 | 32.4 | +| LLaMA3.1_8b_instruct | 52.2 | SFT | 2 | 39.0 | +| LLaMA3.1_8b_instruct | 52.2 | RFT | 3 | 58 | +| Qwen2.5_math_7b_instruct | 79.6 | RFT | 2 | 83.2 | + +As shown, applying SFT to the `competition_math` dataset resulted in significant performance degradation for the instruct model. However, RFT improved the model's capabilities, even for the state-of-the-art `Qwen2.5_math_7b_instruct` math model. + +Specifically, we tested the GSM8K metric for `Qwen2.5_math_7b_instruct`: + +| Model | GSM8K Score | Post-RFT GSM8K Score | +|----------------------------|-------------|-----------------------| +| Qwen2.5_math_7b_instruct | 92.8 | 91.6 | + +As shown, RFT training did not significantly change the GSM8K score, avoiding the previously mentioned performance degradation phenomenon. + +## Future Roadmap + +1. More sampling methods,MCTS for example +2. Distill from super huge model +3. On policy RFT like PPO diff --git a/docs/source_en/Instruction/Sample.md b/docs/source_en/Instruction/Sample.md index e3b69ae562..6a4fd87a88 100644 --- a/docs/source_en/Instruction/Sample.md +++ b/docs/source_en/Instruction/Sample.md @@ -53,7 +53,7 @@ class CustomPRM: pass @torch.inference_mode() - def infer(self, infer_requests: List[InferRequest], **kwargs) -> List[ChatCompletionResponse]: + def infer(self, infer_requests: List[InferRequest], ground_truths: List[str], **kwargs) -> List[ChatCompletionResponse]: ... @@ -62,8 +62,17 @@ prms = {'custom': CustomPRM} Afterward, use `--prm_model custom` in the command line. +## Memory Control + +If the sampled model and PRM are loaded into memory simultaneously, it may lead to an OOM (Out of Memory) issue. To address this, sampling can be divided into two stages: + +- **Stage 1**: Specify `--model` and `--sampler_engine` without specifying `--orm_model` and `--prm_model`. Perform sampling only and save the results to a file. +- **Stage 2**: Specify `--sampler_engine no`, along with `--orm_model` and `--prm_model`, and also specify `--cache_files`. Perform only RM data filtering without re-sampling. + +By dividing the process into two stages, only one model is loaded at a time, avoiding OOM issues. + ## Practical Example -Please refer to the [Reinforcement Fine-Tuning Script](https://github.com/modelscope/ms-swift/tree/main/scripts/rft.py). This script provides a practical example of using sampling for reinforcement fine-tuning. +Please refer to the [Reinforcement Fine-Tuning Script](https://github.com/modelscope/ms-swift/tree/main/examples/train/rft/rft.py). This script provides a practical example of using sampling for reinforcement fine-tuning. > **Note:** The actual effectiveness of this script is strongly related to the quality of the model, data, and RM. Therefore, it is presented only as an example. Users should modify this script and train their own RM and generator models accordingly. diff --git a/docs/source_en/index.rst b/docs/source_en/index.rst index 3dde079d6c..22ea9b6ce7 100644 --- a/docs/source_en/index.rst +++ b/docs/source_en/index.rst @@ -21,8 +21,10 @@ Swift DOCUMENTATION Instruction/Pre-training-and-Fine-tuning.md Instruction/RLHF.md Instruction/Inference-and-deployment.md + Instruction/Sample.md Instruction/Evaluation.md Instruction/Export.md + Instruction/Reinforced_fine_tuning.md Instruction/Supported-models-and-datasets.md Instruction/Use-tuners.md Instruction/Agent-support.md diff --git a/scripts/rft/math.json b/examples/train/rft/math.json similarity index 100% rename from scripts/rft/math.json rename to examples/train/rft/math.json diff --git a/scripts/rft/rft.py b/examples/train/rft/rft.py similarity index 70% rename from scripts/rft/rft.py rename to examples/train/rft/rft.py index 4162d501b4..7b1ee9e429 100644 --- a/scripts/rft/rft.py +++ b/examples/train/rft/rft.py @@ -6,6 +6,9 @@ import torch.cuda +# NOTE: this script supports at most 8 GPUS in a node, if using multi node, please use custom logic. + +# Paste conda env # conda_prefix = 'source /root/miniconda3/etc/profile.d/conda.sh && conda activate py311 && ' conda_prefix = '' @@ -14,7 +17,8 @@ def do_sample(model: str, model_type: str, dataset: List[str], iter: int): device_count = torch.cuda.device_count() handlers = [] datasets = [] - # Sampling cache + # Sampling cache, to avoid lmdeploy & PRM run at the same time + # Why lmdeploy not vllm? we found that the responses generated by lmdeploy are more similar than ones of vllm. for device in range(device_count): sample_cmd = (f'{conda_prefix} CUDA_VISIBLE_DEVICES={device} swift sample ' f'--model {model} --model_type {model_type} ' @@ -49,26 +53,31 @@ def do_sample(model: str, model_type: str, dataset: List[str], iter: int): assert os.path.exists(os.path.join('sample_output', f'iter_{iter}_proc_{proc}_cache.jsonl')) handlers = [] + # Sample again, this time to filter with ORM & PRM + # Provide your PRM model or PRM name(add PRM in plugin/prm.py first) + # You can define your custom PRM logic in the plugin + # (like, split your steps, use the worst score/last score/avg score) for device in range(device_count): - sample_cmd = (f'{conda_prefix} CUDA_VISIBLE_DEVICES={device} swift sample ' - f'--model {model} --model_type {model_type} ' - f'--dataset {" ".join(dataset)} ' - f'--data_range {device} {device_count} ' - f'--max_length 2048 ' - f'--system "You are a math model, you should **think step by step** carefully, ' - f'and always consider the basic math principles to avoid making calculating mistakes.' - f'Give the final answer wrapped with \\boxed{{}}" ' - f'--load_args false ' - f'--sampler_engine no ' - f'--orm_model math ' - f'--prm_model Qwen/Qwen2.5-Math-PRM-7B ' - f'--prm_threshold 0.7 ' - f'--max_new_tokens 768 ' - f'--override_exist_file true ' - f'--num_sampling_per_gpu_batch_size 1 ' - f'--num_return_sequences 64 ' - f'--output_file iter_{iter}_proc_{device}_sampling.jsonl ' - f'--cache_files sample_output/iter_{iter}_proc_{device}_cache.jsonl ') + sample_cmd = ( + f'{conda_prefix} CUDA_VISIBLE_DEVICES={device} swift sample ' + f'--model {model} --model_type {model_type} ' # change to --resume_from_checkpoint to use the lastest optimzer state # noqa + f'--dataset {" ".join(dataset)} ' + f'--data_range {device} {device_count} ' + f'--max_length 2048 ' + f'--system "You are a math model, you should **think step by step** carefully, ' + f'and always consider the basic math principles to avoid making calculating mistakes.' + f'Give the final answer wrapped with \\boxed{{}}" ' + f'--load_args false ' + f'--sampler_engine no ' + f'--orm_model math ' # math defines in plugin/orm.py + f'--prm_model Qwen/Qwen2.5-Math-PRM-7B ' + f'--prm_threshold {min(0.7 + 0.1*iter, 0.9)} ' + f'--max_new_tokens 768 ' + f'--override_exist_file true ' # no not override the existing sample files + f'--num_sampling_per_gpu_batch_size 1 ' + f'--num_return_sequences 64 ' + f'--output_file iter_{iter}_proc_{device}_sampling.jsonl ' + f'--cache_files sample_output/iter_{iter}_proc_{device}_cache.jsonl ') print(f'Sampling iter {iter}, part {device}.', flush=True) env = os.environ.copy() env['CUDA_VISIBLE_DEVICES'] = str(device) @@ -95,7 +104,7 @@ def do_train(model: str, model_type: str, datasets: List[str], iter, cmd='sft'): ds_config = '--deepspeed zero3 ' extra_args = '' if cmd == 'rlhf': - extra_args = '--rlhf_type dpo --beta 2.0 ' + extra_args = '--rlhf_type dpo --beta 0.3 ' # use another reinforce learning method supported by swift ga = 128 // torch.cuda.device_count() // 2 train_cmd = (f'{conda_prefix} {gpu_prefix} swift {cmd} ' f'--model {model} --model_type {model_type} ' @@ -133,14 +142,16 @@ def do_train(model: str, model_type: str, datasets: List[str], iter, cmd='sft'): def do_eval(model, model_type: str, iter): - eval_cmd = (f'{conda_prefix} swift eval ' - '--eval_dataset math ' - '--infer_backend lmdeploy --eval_limit 500 ' - f'--model {model} --model_type {model_type} ' - '--model_type llama3_1 --system "You are a math model, you should **think step by step** carefully, ' - 'and always consider the basic math principles to avoid making calculating mistakes. ' - 'Give the final answer wrapped with \\boxed{}"') + eval_cmd = ( + f'{conda_prefix} swift eval ' + '--eval_dataset math ' # eval another dataset + '--infer_backend lmdeploy --eval_limit 500 ' + f'--model {model} --model_type {model_type} ' + '--system "You are a math model, you should **think step by step** carefully, ' + 'and always consider the basic math principles to avoid making calculating mistakes. ' + 'Give the final answer wrapped with \\boxed{}"') print('Evaluating.', flush=True) + # Replace the original dataset to the math.json, this is for test, comment this if not need replace_math_dataset() if iter is None: @@ -173,25 +184,29 @@ def replace_math_dataset(): if os.path.exists(os.path.join(user_dir, '.cache', 'opencompass', 'data', 'math', 'math.json')): os.remove(os.path.join(user_dir, '.cache', 'opencompass', 'data', 'math', 'math.json')) shutil.copy( - os.path.join('scripts', 'rft', 'math.json'), + os.path.join('examples', 'train', 'rft', 'math.json'), os.path.join(user_dir, '.cache', 'opencompass', 'data', 'math', 'math.json')) def main(): os.makedirs('logs', exist_ok=True) max_acc = 0. - first_model = 'LLM-Research/Meta-Llama-3.1-8B-Instruct' - model_type = 'llama3_1' + first_model = 'Qwen/Qwen2.5-Math-7B-Instruct' + model_type = 'qwen2_5_math' if False: + # eval the original model do_eval(first_model, None) model = first_model for i in range(5): ts = time.time() datasets = do_sample(model, model_type, ['tastelikefeet/competition_math'], i) + # add custom data filter here, for example: length or diversity control print(f'do sample cost: {(time.time()-ts) / 60:.1f} minutes.', flush=True) ts = time.time() + # if want to train the original dataset with datasets, add the original dataset here + # if want to train the original model everytime, change to first_model ckpt = do_train(model, model_type, datasets, i) print(f'do train cost: {(time.time() - ts) / 60:.1f} minutes.', flush=True) ts = time.time() diff --git a/swift/cli/sample.py b/swift/cli/sample.py index d763e608cf..b593842163 100644 --- a/swift/cli/sample.py +++ b/swift/cli/sample.py @@ -1,5 +1,5 @@ # Copyright (c) Alibaba, Inc. and its affiliates. -from swift.experimental.sampling.sampling import sampling_main +from swift.llm.sampling import sampling_main if __name__ == '__main__': sampling_main() diff --git a/swift/experimental/__init__.py b/swift/experimental/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/swift/experimental/sampling/__init__.py b/swift/experimental/sampling/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/swift/experimental/sampling/sampling_args.py b/swift/llm/argument/sampling_args.py similarity index 100% rename from swift/experimental/sampling/sampling_args.py rename to swift/llm/argument/sampling_args.py diff --git a/swift/llm/sampling/__init__.py b/swift/llm/sampling/__init__.py new file mode 100644 index 0000000000..58a0ca554d --- /dev/null +++ b/swift/llm/sampling/__init__.py @@ -0,0 +1 @@ +from .sampling import sampling_main diff --git a/swift/experimental/sampling/base.py b/swift/llm/sampling/base.py similarity index 96% rename from swift/experimental/sampling/base.py rename to swift/llm/sampling/base.py index bf0ec81be2..9be223fe40 100644 --- a/swift/experimental/sampling/base.py +++ b/swift/llm/sampling/base.py @@ -1,7 +1,7 @@ +from swift.llm.argument.sampling_args import SamplingArguments from swift.plugin.orm import orms from swift.plugin.prm import prms from swift.utils import get_logger -from .sampling_args import SamplingArguments logger = get_logger() diff --git a/swift/experimental/sampling/mcts.py b/swift/llm/sampling/mcts.py similarity index 100% rename from swift/experimental/sampling/mcts.py rename to swift/llm/sampling/mcts.py diff --git a/swift/experimental/sampling/sampling.py b/swift/llm/sampling/sampling.py similarity index 95% rename from swift/experimental/sampling/sampling.py rename to swift/llm/sampling/sampling.py index 22629cdc54..27ede10c4f 100644 --- a/swift/experimental/sampling/sampling.py +++ b/swift/llm/sampling/sampling.py @@ -5,8 +5,8 @@ from typing import List, Union from swift.llm import SwiftPipeline, load_dataset +from swift.llm.argument.sampling_args import SamplingArguments from swift.utils import get_logger -from .sampling_args import SamplingArguments logger = get_logger() @@ -26,7 +26,7 @@ def __init__(self, args: Union[List[str], SamplingArguments, None] = None) -> No self.cur_piece, self.total_piece = self.args.data_range if self.args.sampler_type == 'sample': - from swift.experimental.sampling.vanilla_sampler import VanillaSampler + from swift.llm.sampling.vanilla_sampler import VanillaSampler self.sampler = VanillaSampler(self.args) elif self.args.sampler_type == 'mcts': from swift.experimental.sampling.mcts import MctsSampler diff --git a/swift/experimental/sampling/utils.py b/swift/llm/sampling/utils.py similarity index 100% rename from swift/experimental/sampling/utils.py rename to swift/llm/sampling/utils.py diff --git a/swift/experimental/sampling/vanilla_sampler.py b/swift/llm/sampling/vanilla_sampler.py similarity index 99% rename from swift/experimental/sampling/vanilla_sampler.py rename to swift/llm/sampling/vanilla_sampler.py index 31fc8fc370..b37a641f6b 100644 --- a/swift/experimental/sampling/vanilla_sampler.py +++ b/swift/llm/sampling/vanilla_sampler.py @@ -5,8 +5,8 @@ import json import numpy as np -from swift.experimental.sampling.base import Sampler from swift.llm import RequestConfig +from swift.llm.sampling.base import Sampler from swift.llm.template.template_inputs import InferRequest from swift.utils import get_logger from .utils import get_messages_md5, get_reward From 4eb87236aabd5b52fc30517307f99f5a587882b7 Mon Sep 17 00:00:00 2001 From: Jintao Date: Thu, 23 Jan 2025 23:38:17 +0800 Subject: [PATCH 32/52] support auto dataset mapping (#2976) --- swift/llm/dataset/loader.py | 55 +++++++++++++++++++++---------------- 1 file changed, 32 insertions(+), 23 deletions(-) diff --git a/swift/llm/dataset/loader.py b/swift/llm/dataset/loader.py index 55d7eb34ba..44fdc5a773 100644 --- a/swift/llm/dataset/loader.py +++ b/swift/llm/dataset/loader.py @@ -1,5 +1,7 @@ # Copyright (c) Alibaba, Inc. and its affiliates. import os +import platform +import re import shutil from contextlib import nullcontext from dataclasses import dataclass, field @@ -92,17 +94,10 @@ def get_dataset_meta(self, use_hf: Optional[bool] = None): dataset_type = self.dataset_type if dataset_type == 'path': dataset_meta = dataset_meta_mapping.get((dataset_type, self.dataset.lower())) - if dataset_meta is None: - dataset_meta = DatasetMeta(dataset_path=self.dataset) else: dataset_type = {True: 'hf', False: 'ms'}[use_hf] dataset_meta = dataset_meta_mapping.get((dataset_type, self.dataset.lower())) - if dataset_meta is None: - if use_hf: - dataset_meta = DatasetMeta(hf_dataset_id=self.dataset) - else: - dataset_meta = DatasetMeta(ms_dataset_id=self.dataset) - return dataset_meta + return dataset_meta or self._get_matched_dataset_meta(dataset_meta_mapping) or DatasetMeta() @staticmethod def _get_dataset_meta_mapping() -> Dict[Tuple[str, str], DatasetMeta]: @@ -119,6 +114,28 @@ def _get_dataset_meta_mapping() -> Dict[Tuple[str, str], DatasetMeta]: _dataset_meta_mapping[('hf', dataset_meta.hf_dataset_id.lower())] = dataset_meta return _dataset_meta_mapping + @staticmethod + def get_dataset_name(dataset_id: str) -> str: + # compat hf hub + dataset_id = dataset_id.rstrip('/') + match_ = re.search('/datasets--.+?--(.+?)/snapshots/', dataset_id) + if match_ is not None: + return match_.group(1) + + dataset_name = dataset_id.rsplit('/', 1)[-1] + if platform.system().lower() == 'windows': + dataset_name = dataset_name.rsplit('\\', 1)[-1] + return dataset_name + + def _get_matched_dataset_meta(self, dataset_meta_mapping): + suffix_dataset_meta_mapping = {} + for dataset_name, dataset_meta in dataset_meta_mapping.items(): + dataset_name = self.get_dataset_name(dataset_name[1]).lower() + suffix_dataset_meta_mapping[dataset_name] = dataset_meta + dataset_name = self.get_dataset_name(self.dataset).lower() + dataset_meta = suffix_dataset_meta_mapping.get(dataset_name) + return dataset_meta + class DatasetLoader: @@ -161,13 +178,12 @@ def _concat_datasets(datasets: List[HfDataset], streaming: bool) -> Optional[HfD return concatenate_datasets(datasets) @staticmethod - def _load_dataset_path(dataset_meta: DatasetMeta, + def _load_dataset_path(dataset_path: str, + dataset_meta: DatasetMeta, *, num_proc: int = 1, strict: bool = False, streaming: bool = False) -> HfDataset: - dataset_path = dataset_meta.dataset_path - ext = os.path.splitext(dataset_path)[1].lstrip('.') file_type = {'jsonl': 'json', 'txt': 'text'}.get(ext) or ext kwargs = {'split': 'train', 'streaming': streaming, 'num_proc': num_proc} @@ -197,7 +213,7 @@ def _load_repo_dataset( retry = 1 load_context = nullcontext use_hf = True - dataset_str = f'Use local folder, dataset_id: {dataset_id}' + dataset_str = f'Use local folder, dataset_dir: {dataset_id}' # The dataset downloaded from modelscope will have an additional dataset_infos.json file. dataset_infos_path = os.path.join(dataset_id, 'dataset_infos.json') if os.path.isfile(dataset_infos_path): @@ -338,9 +354,9 @@ def load( strict: bool = False, download_mode: Literal['force_redownload', 'reuse_dataset_if_exists'] = 'reuse_dataset_if_exists', ) -> HfDataset: - - if dataset_meta.dataset_path: + if dataset_syntax.dataset_type == 'path': dataset = DatasetLoader._load_dataset_path( + dataset_syntax.dataset, dataset_meta=dataset_meta, num_proc=num_proc, strict=strict, @@ -348,18 +364,11 @@ def load( ) else: subsets: List[SubsetDataset] = DatasetLoader._select_subsets(dataset_syntax.subsets, dataset_meta) - if use_hf: - dataset_id = dataset_meta.hf_dataset_id - revision = dataset_meta.hf_revision - else: - dataset_id = dataset_meta.ms_dataset_id - revision = dataset_meta.ms_revision - assert dataset_id is not None, (f'dataset: {dataset_syntax.dataset}, use_hf: {use_hf}, ' - f'dataset_id: {dataset_id}.') + revision = dataset_meta.hf_revision if use_hf else dataset_meta.ms_revision datasets = [] for subset in subsets: dataset = DatasetLoader._load_repo_dataset( - dataset_id, + dataset_syntax.dataset, subset, use_hf=use_hf, hub_token=hub_token, From d8a7ed8b3c5b5071e7281b67073c831b10ed7559 Mon Sep 17 00:00:00 2001 From: LiuXL Date: Fri, 24 Jan 2025 11:33:11 +0800 Subject: [PATCH 33/52] sys_prompt from file --- ...21\275\344\273\244\350\241\214\345\217\202\346\225\260.md" | 2 +- docs/source_en/Instruction/Command-line-parameters.md | 2 +- swift/llm/argument/base_args/template_args.py | 4 ++++ 3 files changed, 6 insertions(+), 2 deletions(-) diff --git "a/docs/source/Instruction/\345\221\275\344\273\244\350\241\214\345\217\202\346\225\260.md" "b/docs/source/Instruction/\345\221\275\344\273\244\350\241\214\345\217\202\346\225\260.md" index ebb96038f8..00dd6dafe6 100644 --- "a/docs/source/Instruction/\345\221\275\344\273\244\350\241\214\345\217\202\346\225\260.md" +++ "b/docs/source/Instruction/\345\221\275\344\273\244\350\241\214\345\217\202\346\225\260.md" @@ -44,7 +44,7 @@ ### 模板参数 - 🔥template: 对话模板类型,默认使用model对应的template类型。`swift pt`会将对话模版转为生成模板使用 -- 🔥system: 自定义system字段,默认为None,使用template的默认system +- 🔥system: 自定义system字段,可以是一个txt文件地址,默认为None,使用template的默认system - 🔥max_length: 单样本的tokens最大长度。默认为None,设置为模型支持的tokens最大长度(max_model_len) - truncation_strategy: 如果超长如何处理,支持`delete`, `left`和`right`,代表删除、左侧裁剪和右侧裁剪,默认为'delete' - 🔥max_pixels: 多模态模型图片前处理的最大像素数(H\*W),默认不缩放。 diff --git a/docs/source_en/Instruction/Command-line-parameters.md b/docs/source_en/Instruction/Command-line-parameters.md index ba59c970b0..e9f28d5bc0 100644 --- a/docs/source_en/Instruction/Command-line-parameters.md +++ b/docs/source_en/Instruction/Command-line-parameters.md @@ -44,7 +44,7 @@ The introduction to command line parameters will cover base arguments, atomic ar ### Template Arguments - 🔥template: Type of dialogue template, which defaults to the template type corresponding to the model. `swift pt` will convert the dialogue template into a generation template for use. -- 🔥system: Custom system field, default is None, uses the default system of the template. +- 🔥system: Custom system field, could be a txt file path, default is None, uses the default system of the template. - 🔥max_length: The maximum length of tokens for a single sample. Defaults to None, set to the maximum length of tokens supported by the model (max_model_len). - truncation_strategy: How to handle overly long tokens, supports `delete`, `left`, `right`, representing deletion, left trimming, and right trimming, default is 'delete'. - 🔥max_pixels: Maximum pixel count for pre-processing images in multimodal models (H*W), default is no scaling. diff --git a/swift/llm/argument/base_args/template_args.py b/swift/llm/argument/base_args/template_args.py index 98c4d80a72..2dd5ff47cf 100644 --- a/swift/llm/argument/base_args/template_args.py +++ b/swift/llm/argument/base_args/template_args.py @@ -1,4 +1,5 @@ # Copyright (c) Alibaba, Inc. and its affiliates. +import os from dataclasses import dataclass, field from typing import Literal, Optional @@ -46,6 +47,9 @@ class TemplateArguments: def __post_init__(self): if self.template is None and hasattr(self, 'model_meta'): self.template = self.model_meta.template + if self.system.endswith('.txt') and os.path.isfile(self.system): + with open(self.system, 'r') as f: + self.system = f.read() def get_template_kwargs(self): truncation_strategy = self.truncation_strategy From 54288d18d250c1c3e6f64c99df0403c1701ce215 Mon Sep 17 00:00:00 2001 From: LiuXL Date: Fri, 24 Jan 2025 14:29:31 +0800 Subject: [PATCH 34/52] fix --- swift/llm/argument/base_args/template_args.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/swift/llm/argument/base_args/template_args.py b/swift/llm/argument/base_args/template_args.py index 2dd5ff47cf..23a6c79278 100644 --- a/swift/llm/argument/base_args/template_args.py +++ b/swift/llm/argument/base_args/template_args.py @@ -47,7 +47,7 @@ class TemplateArguments: def __post_init__(self): if self.template is None and hasattr(self, 'model_meta'): self.template = self.model_meta.template - if self.system.endswith('.txt') and os.path.isfile(self.system): + if self.system is not None and self.system.endswith('.txt') and os.path.isfile(self.system): with open(self.system, 'r') as f: self.system = f.read() From e02ae6b8995c25212fcdb0f4e4ca76ce233e4ae2 Mon Sep 17 00:00:00 2001 From: Jintao Date: Fri, 24 Jan 2025 15:54:06 +0800 Subject: [PATCH 35/52] support qwen2_5 long (#2982) --- ...14\346\225\260\346\215\256\351\233\206.md" | 3 + .../Supported-models-and-datasets.md | 3 + swift/llm/infer/protocol.py | 2 +- swift/llm/model/model/qwen.py | 7 ++- tests/test_align/test_template/test_llm.py | 58 +++++++++++++++++-- 5 files changed, 67 insertions(+), 6 deletions(-) diff --git "a/docs/source/Instruction/\346\224\257\346\214\201\347\232\204\346\250\241\345\236\213\345\222\214\346\225\260\346\215\256\351\233\206.md" "b/docs/source/Instruction/\346\224\257\346\214\201\347\232\204\346\250\241\345\236\213\345\222\214\346\225\260\346\215\256\351\233\206.md" index 31ad0b049e..7058d12930 100644 --- "a/docs/source/Instruction/\346\224\257\346\214\201\347\232\204\346\250\241\345\236\213\345\222\214\346\225\260\346\215\256\351\233\206.md" +++ "b/docs/source/Instruction/\346\224\257\346\214\201\347\232\204\346\250\241\345\236\213\345\222\214\346\225\260\346\215\256\351\233\206.md" @@ -100,6 +100,8 @@ |[Qwen/Qwen2-Math-7B](https://modelscope.cn/models/Qwen/Qwen2-Math-7B)|qwen2|qwen|transformers>=4.37|math|[Qwen/Qwen2-Math-7B](https://huggingface.co/Qwen/Qwen2-Math-7B)| |[Qwen/Qwen2-Math-72B](https://modelscope.cn/models/Qwen/Qwen2-Math-72B)|qwen2|qwen|transformers>=4.37|math|[Qwen/Qwen2-Math-72B](https://huggingface.co/Qwen/Qwen2-Math-72B)| |[PowerInfer/SmallThinker-3B-Preview](https://modelscope.cn/models/PowerInfer/SmallThinker-3B-Preview)|qwen2|qwen|transformers>=4.37|-|[PowerInfer/SmallThinker-3B-Preview](https://huggingface.co/PowerInfer/SmallThinker-3B-Preview)| +|[Qwen/Qwen2.5-7B-Instruct-1M](https://modelscope.cn/models/Qwen/Qwen2.5-7B-Instruct-1M)|qwen2|qwen|transformers>=4.37|-|[Qwen/Qwen2.5-7B-Instruct-1M](https://huggingface.co/Qwen/Qwen2.5-7B-Instruct-1M)| +|[Qwen/Qwen2.5-14B-Instruct-1M](https://modelscope.cn/models/Qwen/Qwen2.5-14B-Instruct-1M)|qwen2|qwen|transformers>=4.37|-|[Qwen/Qwen2.5-14B-Instruct-1M](https://huggingface.co/Qwen/Qwen2.5-14B-Instruct-1M)| |[Qwen/Qwen2.5-0.5B-Instruct](https://modelscope.cn/models/Qwen/Qwen2.5-0.5B-Instruct)|qwen2_5|qwen2_5|transformers>=4.37|-|[Qwen/Qwen2.5-0.5B-Instruct](https://huggingface.co/Qwen/Qwen2.5-0.5B-Instruct)| |[Qwen/Qwen2.5-1.5B-Instruct](https://modelscope.cn/models/Qwen/Qwen2.5-1.5B-Instruct)|qwen2_5|qwen2_5|transformers>=4.37|-|[Qwen/Qwen2.5-1.5B-Instruct](https://huggingface.co/Qwen/Qwen2.5-1.5B-Instruct)| |[Qwen/Qwen2.5-3B-Instruct](https://modelscope.cn/models/Qwen/Qwen2.5-3B-Instruct)|qwen2_5|qwen2_5|transformers>=4.37|-|[Qwen/Qwen2.5-3B-Instruct](https://huggingface.co/Qwen/Qwen2.5-3B-Instruct)| @@ -480,6 +482,7 @@ |[Shanghai_AI_Laboratory/internlm2-20b-reward](https://modelscope.cn/models/Shanghai_AI_Laboratory/internlm2-20b-reward)|internlm2_reward|internlm2_reward|transformers>=4.38|-|[internlm/internlm2-20b-reward](https://huggingface.co/internlm/internlm2-20b-reward)| |[Qwen/Qwen2-Math-RM-72B](https://modelscope.cn/models/Qwen/Qwen2-Math-RM-72B)|qwen2_reward|qwen|transformers>=4.37|-|[Qwen/Qwen2-Math-RM-72B](https://huggingface.co/Qwen/Qwen2-Math-RM-72B)| |[Qwen/Qwen2.5-Math-PRM-7B](https://modelscope.cn/models/Qwen/Qwen2.5-Math-PRM-7B)|qwen2_5_prm|qwen2_5_math_prm|transformers>=4.37|-|[Qwen/Qwen2.5-Math-PRM-7B](https://huggingface.co/Qwen/Qwen2.5-Math-PRM-7B)| +|[Qwen/Qwen2.5-Math-7B-PRM800K](https://modelscope.cn/models/Qwen/Qwen2.5-Math-7B-PRM800K)|qwen2_5_prm|qwen2_5_math_prm|transformers>=4.37|-|[Qwen/Qwen2.5-Math-7B-PRM800K](https://huggingface.co/Qwen/Qwen2.5-Math-7B-PRM800K)| |[Qwen/Qwen2.5-Math-PRM-72B](https://modelscope.cn/models/Qwen/Qwen2.5-Math-PRM-72B)|qwen2_5_prm|qwen2_5_math_prm|transformers>=4.37|-|[Qwen/Qwen2.5-Math-PRM-72B](https://huggingface.co/Qwen/Qwen2.5-Math-PRM-72B)| |[Qwen/Qwen2.5-Math-RM-72B](https://modelscope.cn/models/Qwen/Qwen2.5-Math-RM-72B)|qwen2_5_math_reward|qwen2_5_math|transformers>=4.37|-|[Qwen/Qwen2.5-Math-RM-72B](https://huggingface.co/Qwen/Qwen2.5-Math-RM-72B)| |[AI-ModelScope/Skywork-Reward-Llama-3.1-8B](https://modelscope.cn/models/AI-ModelScope/Skywork-Reward-Llama-3.1-8B)|llama3_2_reward|llama3_2|transformers>=4.43|-|[Skywork/Skywork-Reward-Llama-3.1-8B](https://huggingface.co/Skywork/Skywork-Reward-Llama-3.1-8B)| diff --git a/docs/source_en/Instruction/Supported-models-and-datasets.md b/docs/source_en/Instruction/Supported-models-and-datasets.md index 8b6f91759f..781e31f7d9 100644 --- a/docs/source_en/Instruction/Supported-models-and-datasets.md +++ b/docs/source_en/Instruction/Supported-models-and-datasets.md @@ -100,6 +100,8 @@ The table below introduces the models integrated with ms-swift: |[Qwen/Qwen2-Math-7B](https://modelscope.cn/models/Qwen/Qwen2-Math-7B)|qwen2|qwen|transformers>=4.37|math|[Qwen/Qwen2-Math-7B](https://huggingface.co/Qwen/Qwen2-Math-7B)| |[Qwen/Qwen2-Math-72B](https://modelscope.cn/models/Qwen/Qwen2-Math-72B)|qwen2|qwen|transformers>=4.37|math|[Qwen/Qwen2-Math-72B](https://huggingface.co/Qwen/Qwen2-Math-72B)| |[PowerInfer/SmallThinker-3B-Preview](https://modelscope.cn/models/PowerInfer/SmallThinker-3B-Preview)|qwen2|qwen|transformers>=4.37|-|[PowerInfer/SmallThinker-3B-Preview](https://huggingface.co/PowerInfer/SmallThinker-3B-Preview)| +|[Qwen/Qwen2.5-7B-Instruct-1M](https://modelscope.cn/models/Qwen/Qwen2.5-7B-Instruct-1M)|qwen2|qwen|transformers>=4.37|-|[Qwen/Qwen2.5-7B-Instruct-1M](https://huggingface.co/Qwen/Qwen2.5-7B-Instruct-1M)| +|[Qwen/Qwen2.5-14B-Instruct-1M](https://modelscope.cn/models/Qwen/Qwen2.5-14B-Instruct-1M)|qwen2|qwen|transformers>=4.37|-|[Qwen/Qwen2.5-14B-Instruct-1M](https://huggingface.co/Qwen/Qwen2.5-14B-Instruct-1M)| |[Qwen/Qwen2.5-0.5B-Instruct](https://modelscope.cn/models/Qwen/Qwen2.5-0.5B-Instruct)|qwen2_5|qwen2_5|transformers>=4.37|-|[Qwen/Qwen2.5-0.5B-Instruct](https://huggingface.co/Qwen/Qwen2.5-0.5B-Instruct)| |[Qwen/Qwen2.5-1.5B-Instruct](https://modelscope.cn/models/Qwen/Qwen2.5-1.5B-Instruct)|qwen2_5|qwen2_5|transformers>=4.37|-|[Qwen/Qwen2.5-1.5B-Instruct](https://huggingface.co/Qwen/Qwen2.5-1.5B-Instruct)| |[Qwen/Qwen2.5-3B-Instruct](https://modelscope.cn/models/Qwen/Qwen2.5-3B-Instruct)|qwen2_5|qwen2_5|transformers>=4.37|-|[Qwen/Qwen2.5-3B-Instruct](https://huggingface.co/Qwen/Qwen2.5-3B-Instruct)| @@ -480,6 +482,7 @@ The table below introduces the models integrated with ms-swift: |[Shanghai_AI_Laboratory/internlm2-20b-reward](https://modelscope.cn/models/Shanghai_AI_Laboratory/internlm2-20b-reward)|internlm2_reward|internlm2_reward|transformers>=4.38|-|[internlm/internlm2-20b-reward](https://huggingface.co/internlm/internlm2-20b-reward)| |[Qwen/Qwen2-Math-RM-72B](https://modelscope.cn/models/Qwen/Qwen2-Math-RM-72B)|qwen2_reward|qwen|transformers>=4.37|-|[Qwen/Qwen2-Math-RM-72B](https://huggingface.co/Qwen/Qwen2-Math-RM-72B)| |[Qwen/Qwen2.5-Math-PRM-7B](https://modelscope.cn/models/Qwen/Qwen2.5-Math-PRM-7B)|qwen2_5_prm|qwen2_5_math_prm|transformers>=4.37|-|[Qwen/Qwen2.5-Math-PRM-7B](https://huggingface.co/Qwen/Qwen2.5-Math-PRM-7B)| +|[Qwen/Qwen2.5-Math-7B-PRM800K](https://modelscope.cn/models/Qwen/Qwen2.5-Math-7B-PRM800K)|qwen2_5_prm|qwen2_5_math_prm|transformers>=4.37|-|[Qwen/Qwen2.5-Math-7B-PRM800K](https://huggingface.co/Qwen/Qwen2.5-Math-7B-PRM800K)| |[Qwen/Qwen2.5-Math-PRM-72B](https://modelscope.cn/models/Qwen/Qwen2.5-Math-PRM-72B)|qwen2_5_prm|qwen2_5_math_prm|transformers>=4.37|-|[Qwen/Qwen2.5-Math-PRM-72B](https://huggingface.co/Qwen/Qwen2.5-Math-PRM-72B)| |[Qwen/Qwen2.5-Math-RM-72B](https://modelscope.cn/models/Qwen/Qwen2.5-Math-RM-72B)|qwen2_5_math_reward|qwen2_5_math|transformers>=4.37|-|[Qwen/Qwen2.5-Math-RM-72B](https://huggingface.co/Qwen/Qwen2.5-Math-RM-72B)| |[AI-ModelScope/Skywork-Reward-Llama-3.1-8B](https://modelscope.cn/models/AI-ModelScope/Skywork-Reward-Llama-3.1-8B)|llama3_2_reward|llama3_2|transformers>=4.43|-|[Skywork/Skywork-Reward-Llama-3.1-8B](https://huggingface.co/Skywork/Skywork-Reward-Llama-3.1-8B)| diff --git a/swift/llm/infer/protocol.py b/swift/llm/infer/protocol.py index 73f219b720..688f567840 100644 --- a/swift/llm/infer/protocol.py +++ b/swift/llm/infer/protocol.py @@ -49,7 +49,7 @@ class RequestConfig: top_p: Optional[float] = None repetition_penalty: Optional[float] = None num_beams: int = 1 - stop: List[str] = field(default_factory=list) + stop: Optional[List[str]] = field(default_factory=list) seed: Optional[int] = None stream: bool = False diff --git a/swift/llm/model/model/qwen.py b/swift/llm/model/model/qwen.py index 2110a7a85c..a6f28ee8e4 100644 --- a/swift/llm/model/model/qwen.py +++ b/swift/llm/model/model/qwen.py @@ -336,7 +336,11 @@ def _get_cast_dtype(self) -> torch.dtype: Model('Qwen/Qwen2-Math-72B', 'Qwen/Qwen2-Math-72B'), ], tags=['math']), - ModelGroup([Model('PowerInfer/SmallThinker-3B-Preview', 'PowerInfer/SmallThinker-3B-Preview')]) + ModelGroup([Model('PowerInfer/SmallThinker-3B-Preview', 'PowerInfer/SmallThinker-3B-Preview')]), + ModelGroup([ + Model('Qwen/Qwen2.5-7B-Instruct-1M', 'Qwen/Qwen2.5-7B-Instruct-1M'), + Model('Qwen/Qwen2.5-14B-Instruct-1M', 'Qwen/Qwen2.5-14B-Instruct-1M'), + ]) ], TemplateType.qwen, get_model_tokenizer_with_flash_attn, @@ -693,6 +697,7 @@ def update(self, key_states: torch.Tensor, value_states: torch.Tensor, layer_idx [ ModelGroup([ Model('Qwen/Qwen2.5-Math-PRM-7B', 'Qwen/Qwen2.5-Math-PRM-7B'), + Model('Qwen/Qwen2.5-Math-7B-PRM800K', 'Qwen/Qwen2.5-Math-7B-PRM800K'), Model('Qwen/Qwen2.5-Math-PRM-72B', 'Qwen/Qwen2.5-Math-PRM-72B'), ]), ], diff --git a/tests/test_align/test_template/test_llm.py b/tests/test_align/test_template/test_llm.py index 109092e57d..bd1a095ff3 100644 --- a/tests/test_align/test_template/test_llm.py +++ b/tests/test_align/test_template/test_llm.py @@ -1,5 +1,6 @@ import os +import json import torch os.environ['CUDA_VISIBLE_DEVICES'] = '0,1,2,3' @@ -29,10 +30,11 @@ def _infer_model(pt_engine, system=None, messages=None): def test_qwen2_5(): - pt_engine = PtEngine('Qwen/Qwen2.5-3B') - _infer_model(pt_engine) + pt_engine = PtEngine('Qwen/Qwen2.5-7B-Instruct-1M') + response = _infer_model(pt_engine) pt_engine.default_template.template_backend = 'jinja' - _infer_model(pt_engine) + response2 = _infer_model(pt_engine) + assert response == response2 def test_phi4(): @@ -265,6 +267,53 @@ def test_deepseek_r1_distill(): assert res == res2, f'res: {res}, res2: {res2}' +def test_qwen2_5_prm(): + pt_engine = PtEngine('Qwen/Qwen2.5-Math-7B-PRM800K') + data = { + 'system': + 'Please reason step by step, and put your final answer within \\boxed{}.', + 'query': ('Sue lives in a fun neighborhood. One weekend, the neighbors decided to play a prank on Sue. ' + "On Friday morning, the neighbors placed 18 pink plastic flamingos out on Sue's front yard. " + 'On Saturday morning, the neighbors took back one third of the flamingos, painted them white, and ' + "put these newly painted white flamingos back out on Sue's front yard. Then, on Sunday morning, " + 'they added another 18 pink plastic flamingos to the collection. At noon on Sunday, how many more ' + 'pink plastic flamingos were out than white plastic flamingos?'), + 'response': + [('To find out how many more pink plastic flamingos were out than white plastic flamingos at noon on Sunday, ' + 'we can break down the problem into steps. First, on Friday, the neighbors start with 18 pink ' + 'plastic flamingos.'), + ('On Saturday, they take back one third of the flamingos. Since there were 18 flamingos, (1/3 \\times 18 = 6) ' + 'flamingos are taken back. So, they have (18 - 6 = 12) flamingos left in their possession. Then, they paint ' + "these 6 flamingos white and put them back out on Sue's front yard. Now, Sue has the original 12 pink " + 'flamingos plus the 6 new white ones. Thus, by the end of Saturday, Sue has (12 + 6 = 18) pink flamingos ' + 'and 6 white flamingos.'), + ("On Sunday, the neighbors add another 18 pink plastic flamingos to Sue's front yard. By the end of Sunday " + 'morning, Sue has (18 + 18 = 36) pink flamingos and still 6 white flamingos.'), + ('To find the difference, subtract the number of white flamingos from the number of pink ' + 'flamingos: (36 - 6 = 30). Therefore, at noon on Sunday, there were 30 more pink plastic flamingos out ' + 'than white plastic flamingos. The answer is (\\boxed{30}).')] + } + + messages = [ + { + 'role': 'system', + 'content': data['system'] + }, + { + 'role': 'user', + 'content': data['query'] + }, + { + 'role': 'assistant', + 'content': ''.join(data['response']) + '' + }, + ] + res = _infer_model(pt_engine, messages=messages) + pt_engine.default_template.template_backend = 'jinja' + res2 = _infer_model(pt_engine, messages=messages) + assert res == res2 == json.dumps([0.9921875, 0.2490234375, 0.70703125, 0.9375]), f'res: {res}, res2: {res2}' + + if __name__ == '__main__': from swift.llm import PtEngine, RequestConfig, get_template, get_model_tokenizer from swift.utils import get_logger, seed_everything @@ -292,4 +341,5 @@ def test_deepseek_r1_distill(): # test_skywork_reward() # test_phi4() # test_internlm3() - test_deepseek_r1_distill() + # test_deepseek_r1_distill() + test_qwen2_5_prm() From 94135ed4816f66cba911ee043f7a0a58d5505e62 Mon Sep 17 00:00:00 2001 From: LiuXL Date: Fri, 24 Jan 2025 17:28:14 +0800 Subject: [PATCH 36/52] merge main into sampler --- swift/llm/argument/sampling_args.py | 3 - swift/llm/sampling/mcts.py | 151 +++++++++++++++------------- swift/llm/sampling/sampling.py | 5 +- swift/plugin/prm.py | 2 +- 4 files changed, 84 insertions(+), 77 deletions(-) diff --git a/swift/llm/argument/sampling_args.py b/swift/llm/argument/sampling_args.py index f0c6da8113..e75f948962 100644 --- a/swift/llm/argument/sampling_args.py +++ b/swift/llm/argument/sampling_args.py @@ -76,9 +76,6 @@ def __post_init__(self): else: self.engine_kwargs = {} - if os.path.isfile(self.system): - with open(self.system, 'r') as f: - self.system = f.read() self.system_message = { "role": "system", "content": self.system, diff --git a/swift/llm/sampling/mcts.py b/swift/llm/sampling/mcts.py index a92d0c2464..76ea888047 100644 --- a/swift/llm/sampling/mcts.py +++ b/swift/llm/sampling/mcts.py @@ -7,10 +7,10 @@ from swift.llm import InferRequest from swift.llm.infer.protocol import UsageInfo from swift.utils import get_logger +from swift.llm.argument.sampling_args import SamplingArguments from .base import Sampler from .utils import get_reward, perform_infer -from .sampling_args import SamplingArguments logger = get_logger() @@ -136,36 +136,36 @@ def update_usage_info(self, response): setattr(self.usage_info, key, update_value) def search_single(self, query, ground_truth): - def _UCT(node: LanguageNode): + def _uct(uct_curr_node: LanguageNode): alpha = _args.process_reward_rate - value = alpha * node.process_reward + (1 - alpha) * node.outcome_reward - if node.is_root(): + value = alpha * uct_curr_node.process_reward + (1 - alpha) * uct_curr_node.outcome_reward + if uct_curr_node.is_root(): return value exploitation_score = value exploration_score = (_args.exploration_rate - * np.sqrt(np.log(node.parent.visit_count + 1) / (node.visit_count + 1))) + * np.sqrt(np.log(uct_curr_node.parent.visit_count + 1) / (uct_curr_node.visit_count + 1))) return exploration_score + exploitation_score - def _select(node: LanguageNode): - while not node.is_leaf(): - node = max(node.active_children, key=lambda x: _UCT(x)) - return node + def _select(select_curr_node: LanguageNode): + while not select_curr_node.is_leaf(): + select_curr_node = max(select_curr_node.active_children, key=lambda x: _uct(x)) + return select_curr_node - def _expand(node: LanguageNode): - if node.is_root(): + def _expand(expand_curr_node: LanguageNode): + if expand_curr_node.is_root(): infer_request = InferRequest([system_message, prompt_message]) else: history_message = { "role": "assistant", - "content": node.answer, + "content": expand_curr_node.answer, } infer_request = InferRequest([system_message, prompt_message, history_message, next_message]) # e_time = time.time() # 为了并行进行 Expand 操作,这里暂时不需要考虑顺序,因为 Prompt 是一样的 - n = _args.num_return_sequences - len(node.children) + n = _args.num_return_sequences - len(expand_curr_node.children) with ThreadPoolExecutor(max_workers=n) as executor: futures = {executor.submit(perform_infer, self.infer_engines[i], @@ -191,30 +191,27 @@ def _expand(node: LanguageNode): continue unique_output.add(output) orm_infer_requests.append(InferRequest([{"role": "assistant", "content": output}])) - child = LanguageNode(step=output, parent=node) + child = LanguageNode(step=output, parent=expand_curr_node) if self.orm_model.check_terminate(child.answer)[0]: child.terminated = True else: all_child_terminated = False - node.add_child(child) + expand_curr_node.add_child(child) # e_time = time.time() orm_score, _orm_mask = get_reward( self.orm_model, orm_infer_requests, ground_truths=[ground_truth] * len(orm_infer_requests), threshold=0.0) # logger.info(f"expand.orm time: {time.time() - e_time}") - for child, score in zip(node.children, orm_score): + for child, score in zip(expand_curr_node.children, orm_score): if child.terminated: child.init_and_update_value(score) - if child.outcome_reward == 1: - terminate_correct.append(child.answer) - else: - terminate_incorrect.append(child.answer) + terminated_nodes.append(child) # e_time = time.time() if self.prm_model: prm_infer_requests = [] - for child in node.children: + for child in expand_curr_node.children: prm_message = {"role": "assistant", "content": child.answer} prm_infer_requests.append(InferRequest([prompt_message, prm_message])) prm_score, _prm_mask = get_reward( @@ -222,19 +219,19 @@ def _expand(node: LanguageNode): prm_infer_requests, ground_truths=[ground_truth] * len(prm_infer_requests), threshold=0.0) - for child, score in zip(node.children, prm_score): + for child, score in zip(expand_curr_node.children, prm_score): child.process_reward = score # logger.info(f"expand.prm time: {time.time() - e_time}") - def _rollout(node: LanguageNode): + def _rollout(rollout_curr_node: LanguageNode): rollout_iter_index = 0 rollout_nodes = {} - for i in range(len(node.active_children)): + for i in range(len(rollout_curr_node.active_children)): rollout_nodes[i] = { - "node": node.active_children[i], + "node": rollout_curr_node.active_children[i], "history_messages": { "role": "assistant", - "content": node.active_children[i].answer, + "content": rollout_curr_node.active_children[i].answer, }, } active_rollout_nodes = list(rollout_nodes.keys()) @@ -283,40 +280,57 @@ def _rollout(node: LanguageNode): terminated_state = self.orm_model.check_terminate(end_paths) for index, score, terminated in zip(active_rollout_nodes, orm_score, terminated_state): if terminated: - node.active_children[index].outcome_reward = score + rollout_curr_node.active_children[index].outcome_reward = score if score == 1: - correct_answers.append(rollout_nodes[index]['history_messages']["content"]) + rollout_correct_answers.append(rollout_nodes[index]['history_messages']["content"]) else: - incorrect_answers.append(rollout_nodes[index]['history_messages']["content"]) + rollout_incorrect_answers.append(rollout_nodes[index]['history_messages']["content"]) rollout_nodes.pop(index) active_rollout_nodes = list(rollout_nodes.keys()) rollout_iter_index += 1 - def _back_propagate(curr_node: LanguageNode): - while curr_node: - best_child_value = max([child.outcome_reward for child in curr_node.children]) - curr_node.init_and_update_value(best_child_value) - curr_node.visit() - if len(curr_node.active_children) == 0 and not curr_node.is_root(): - curr_node.parent.active_children.remove(curr_node) - curr_node = curr_node.parent - - def _collect(curr_node: LanguageNode): - if curr_node.is_leaf(): - return [] - results = [] - for child in curr_node.children: - results += _collect(child) - curr_node.children = sorted(curr_node.children) - if curr_node.children[-1].outcome_reward - curr_node.children[0].outcome_reward > collect_filter_threshold: - results.append({ - "path": curr_node.path, - "good": curr_node.children[-1].path[-1], - "good_score": curr_node.children[-1].outcome_reward, - "bad": curr_node.children[0].path[-1], - "bad_score": curr_node.children[0].outcome_reward, - }) - return results + def _back_propagate(back_curr_node: LanguageNode): + while back_curr_node: + best_child_value = max([child.outcome_reward for child in back_curr_node.children]) + back_curr_node.init_and_update_value(best_child_value) + back_curr_node.visit() + if len(back_curr_node.active_children) == 0 and not back_curr_node.is_root(): + back_curr_node.parent.active_children.remove(back_curr_node) + back_curr_node = back_curr_node.parent + + def _collect(collect_curr_node: LanguageNode, _outcome_rewards: list[float], _process_rewards: list[float]): + _prefer_pairs, _correct_answers, _incorrect_answers = [], [], [] + _outcome_rewards = _outcome_rewards[:] + [collect_curr_node.outcome_reward] + _process_rewards = _process_rewards[:] + [collect_curr_node.process_reward] + if not collect_curr_node.is_leaf(): + for child in collect_curr_node.children: + p, c, i = _collect(child, _outcome_rewards, _process_rewards) + _prefer_pairs += p + _correct_answers += c + _incorrect_answers += i + collect_curr_node.children = sorted(collect_curr_node.children) + if collect_curr_node.children[-1].outcome_reward - collect_curr_node.children[0].outcome_reward > collect_filter_threshold: + prefer_pair = { + "path": collect_curr_node.answer, + "good": collect_curr_node.children[-1].path[-1], + "good_score": collect_curr_node.children[-1].outcome_reward, + "bad": collect_curr_node.children[0].path[-1], + "bad_score": collect_curr_node.children[0].outcome_reward, + } + prefer_pairs.append(prefer_pair) + if collect_curr_node.terminated: + _answer = { + "answer": collect_curr_node.answer, + "mean_outcome_reward": np.mean(_outcome_rewards), + "min_outcome_reward": np.min(_outcome_rewards), + "mean_process_reward": np.mean(_process_rewards), + "min_process_reward": np.min(_process_rewards), + } + if collect_curr_node.outcome_reward == 1: + _correct_answers.append(_answer) + else: + _incorrect_answers.append(_answer) + return _prefer_pairs, _correct_answers, _incorrect_answers _args = self.args system_message = _args.system_message @@ -328,12 +342,11 @@ def _collect(curr_node: LanguageNode): "content": query, } - correct_answers, incorrect_answers, prefer_pair = [], [], [] - terminate_correct, terminate_incorrect = [], [] + rollout_correct_answers, rollout_incorrect_answers, prefer_pairs, terminated_nodes = [], [], [], [] too_easy, too_hard = False, False iter_count = 0 while (not too_easy and not too_hard - and len(terminate_incorrect) + len(terminate_correct) < _args.num_return_sequences + and len(terminated_nodes) < _args.num_return_sequences and iter_count < _args.max_iterations): logger.info(f"iter_count: {iter_count}" + "." * 10) s_time = time.time() @@ -349,13 +362,13 @@ def _collect(curr_node: LanguageNode): s_time = time.time() _back_propagate(curr_node) logger.info("back propagate" + "=" * 10 + f"time: {time.time() - s_time}") - if len(correct_answers) + len(incorrect_answers) >= _args.num_return_sequences: - if 4 * len(incorrect_answers) < len(correct_answers): + if len(rollout_correct_answers) + len(rollout_incorrect_answers) >= _args.num_return_sequences: + if 4 * len(rollout_incorrect_answers) < len(rollout_correct_answers): logger.info("too easy" + "!" * 20) - #logger.info(f"correct_answers: {correct_answers}") - #logger.info(f"incorrect_answers: {incorrect_answers}") + #logger.info(f"rollout_correct_answers: {rollout_correct_answers}") + #logger.info(f"rollout_incorrect_answers: {rollout_incorrect_answers}") too_easy = True - elif 4 * len(correct_answers) < len(incorrect_answers): + elif 4 * len(rollout_correct_answers) < len(rollout_incorrect_answers): logger.info("too hard" + "!" * 20) #logger.info(f"correct_answers: {correct_answers}") #logger.info(f"incorrect_answers: {incorrect_answers}") @@ -364,19 +377,19 @@ def _collect(curr_node: LanguageNode): if iter_count == _args.max_iterations: logger.info("too hard" + "!" * 20) too_hard = True - #logger.info(f"correct_answers: {correct_answers}") - #logger.info(f"incorrect_answers: {incorrect_answers}") - prefer_pair = _collect(_root) - #logger.info(f"prefer_pair: {prefer_pair}") + #logger.info(f"rollout_correct_answers: {rollout_correct_answers}") + #logger.info(f"rollout_incorrect_answers: {rollout_incorrect_answers}") + prefer_pairs, correct_answers, incorrect_answers = _collect(_root, [], []) + #logger.info(f"prefer_pairs: {prefer_pairs}") result = { "query": query, "ground_truth": ground_truth, - "prefer_pair": prefer_pair, + "prefer_pairs": prefer_pairs, + "rollout_correct_answers": rollout_correct_answers, + "rollout_incorrect_answers": rollout_incorrect_answers, "correct_answers": correct_answers, "incorrect_answers": incorrect_answers, - "terminate_correct": terminate_correct, - "terminate_incorrect": terminate_incorrect, } results = json.dumps(result, ensure_ascii=False) logger.info(results) diff --git a/swift/llm/sampling/sampling.py b/swift/llm/sampling/sampling.py index 27ede10c4f..2ec3410f05 100644 --- a/swift/llm/sampling/sampling.py +++ b/swift/llm/sampling/sampling.py @@ -29,7 +29,7 @@ def __init__(self, args: Union[List[str], SamplingArguments, None] = None) -> No from swift.llm.sampling.vanilla_sampler import VanillaSampler self.sampler = VanillaSampler(self.args) elif self.args.sampler_type == 'mcts': - from swift.experimental.sampling.mcts import MctsSampler + from swift.llm.sampling.mcts import MctsSampler self.sampler = MctsSampler(self.args) def _get_dataset(self): @@ -70,6 +70,3 @@ def run(self): def sampling_main(args: Union[List[str], SamplingArguments, None] = None): return SwiftSampling(args).main() - -if __name__ == "__main__": - sampling_main() \ No newline at end of file diff --git a/swift/plugin/prm.py b/swift/plugin/prm.py index 630550d082..91fc872b5e 100644 --- a/swift/plugin/prm.py +++ b/swift/plugin/prm.py @@ -80,7 +80,7 @@ def infer(self, infer_requests: List[InferRequest], ground_truths: List[str], }, ] completion = client.chat.completions.create( - model='qwen-plus', + model='qwen-max', messages=messages, ) From b61b7f70c77dd01577d29403bd4f2d3cc1813f39 Mon Sep 17 00:00:00 2001 From: LiuXL Date: Sun, 26 Jan 2025 10:32:47 +0800 Subject: [PATCH 37/52] fix node.correct & fix prefer_pairs --- swift/llm/sampling/mcts.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/swift/llm/sampling/mcts.py b/swift/llm/sampling/mcts.py index 76ea888047..3fde0cae23 100644 --- a/swift/llm/sampling/mcts.py +++ b/swift/llm/sampling/mcts.py @@ -52,6 +52,7 @@ def __init__(self, self.process_reward = 0.0 self.outcome_reward = 0.0 self.terminated = False + self.correct = False def is_leaf(self): return len(self.children) == 0 @@ -206,6 +207,7 @@ def _expand(expand_curr_node: LanguageNode): for child, score in zip(expand_curr_node.children, orm_score): if child.terminated: child.init_and_update_value(score) + child.correct = score > 0.9 terminated_nodes.append(child) # e_time = time.time() @@ -280,8 +282,8 @@ def _rollout(rollout_curr_node: LanguageNode): terminated_state = self.orm_model.check_terminate(end_paths) for index, score, terminated in zip(active_rollout_nodes, orm_score, terminated_state): if terminated: - rollout_curr_node.active_children[index].outcome_reward = score - if score == 1: + rollout_curr_node.active_children[index].init_and_update_value(score) + if score > 0.9: rollout_correct_answers.append(rollout_nodes[index]['history_messages']["content"]) else: rollout_incorrect_answers.append(rollout_nodes[index]['history_messages']["content"]) @@ -317,7 +319,7 @@ def _collect(collect_curr_node: LanguageNode, _outcome_rewards: list[float], _pr "bad": collect_curr_node.children[0].path[-1], "bad_score": collect_curr_node.children[0].outcome_reward, } - prefer_pairs.append(prefer_pair) + _prefer_pairs.append(prefer_pair) if collect_curr_node.terminated: _answer = { "answer": collect_curr_node.answer, @@ -326,7 +328,7 @@ def _collect(collect_curr_node: LanguageNode, _outcome_rewards: list[float], _pr "mean_process_reward": np.mean(_process_rewards), "min_process_reward": np.min(_process_rewards), } - if collect_curr_node.outcome_reward == 1: + if collect_curr_node.correct: _correct_answers.append(_answer) else: _incorrect_answers.append(_answer) @@ -362,7 +364,7 @@ def _collect(collect_curr_node: LanguageNode, _outcome_rewards: list[float], _pr s_time = time.time() _back_propagate(curr_node) logger.info("back propagate" + "=" * 10 + f"time: {time.time() - s_time}") - if len(rollout_correct_answers) + len(rollout_incorrect_answers) >= _args.num_return_sequences: + if len(rollout_correct_answers) + len(rollout_incorrect_answers) >= 2 * _args.num_return_sequences: if 4 * len(rollout_incorrect_answers) < len(rollout_correct_answers): logger.info("too easy" + "!" * 20) #logger.info(f"rollout_correct_answers: {rollout_correct_answers}") From 322fe9c8a24bf103422ffeb689d7579bd29d2ede Mon Sep 17 00:00:00 2001 From: LiuXL Date: Mon, 27 Jan 2025 12:06:19 +0800 Subject: [PATCH 38/52] fix file save & no system_prompt --- swift/llm/argument/sampling_args.py | 12 ++++++++---- swift/llm/sampling/mcts.py | 11 +++++------ 2 files changed, 13 insertions(+), 10 deletions(-) diff --git a/swift/llm/argument/sampling_args.py b/swift/llm/argument/sampling_args.py index e75f948962..0747f8b452 100644 --- a/swift/llm/argument/sampling_args.py +++ b/swift/llm/argument/sampling_args.py @@ -76,11 +76,15 @@ def __post_init__(self): else: self.engine_kwargs = {} - self.system_message = { - "role": "system", - "content": self.system, - } if self.sampler_type == 'mcts' and self.sampler_engine != 'client': raise ValueError(f'`mcts` sampler only supports `client` engine yet, but now is: {self.sampler_engine}') super().__post_init__() + + if self.system is not None: + self.system_message = [{ + "role": "system", + "content": self.system, + }] + else: + self.system_message = [] diff --git a/swift/llm/sampling/mcts.py b/swift/llm/sampling/mcts.py index 3fde0cae23..fce4216dac 100644 --- a/swift/llm/sampling/mcts.py +++ b/swift/llm/sampling/mcts.py @@ -156,13 +156,13 @@ def _select(select_curr_node: LanguageNode): def _expand(expand_curr_node: LanguageNode): if expand_curr_node.is_root(): - infer_request = InferRequest([system_message, prompt_message]) + infer_request = InferRequest(system_message + [prompt_message]) else: history_message = { "role": "assistant", "content": expand_curr_node.answer, } - infer_request = InferRequest([system_message, prompt_message, history_message, next_message]) + infer_request = InferRequest(system_message + [prompt_message, history_message, next_message]) # e_time = time.time() # 为了并行进行 Expand 操作,这里暂时不需要考虑顺序,因为 Prompt 是一样的 @@ -239,8 +239,7 @@ def _rollout(rollout_curr_node: LanguageNode): active_rollout_nodes = list(rollout_nodes.keys()) while len(active_rollout_nodes) > 0 and rollout_iter_index < _args.rollout_depth: # r_time = time.time() - infer_requests = [InferRequest([system_message, - prompt_message, + infer_requests = [InferRequest(system_message + [prompt_message, rollout_nodes[index]['history_messages'], next_message]) for index in active_rollout_nodes] @@ -335,7 +334,7 @@ def _collect(collect_curr_node: LanguageNode, _outcome_rewards: list[float], _pr return _prefer_pairs, _correct_answers, _incorrect_answers _args = self.args - system_message = _args.system_message + system_message = [] + _args.system_message sep_token = _args.stop_words[0] + '\n' collect_filter_threshold = _args.collect_filter_threshold _root = LanguageNode(sep_token=sep_token) @@ -408,7 +407,7 @@ def do_sample(self, data): messages = item['messages'][0] query = messages[0]['content'] ground_truth = messages[1]['content'] - generated.append(self.search_single(query, ground_truth)) + generated.append(self.search_single(query, ground_truth) + '\n') except Exception as e: logger.error(f"Error: {e}") return generated \ No newline at end of file From e01ba81fdb6cd24a0a2faf5618052b34608cb61e Mon Sep 17 00:00:00 2001 From: LiuXL Date: Thu, 30 Jan 2025 19:42:49 +0800 Subject: [PATCH 39/52] perform_infer & collect tree --- swift/llm/sampling/mcts.py | 113 +++++++++--------------------------- swift/llm/sampling/utils.py | 73 +++++++++++++++++++++-- 2 files changed, 97 insertions(+), 89 deletions(-) diff --git a/swift/llm/sampling/mcts.py b/swift/llm/sampling/mcts.py index fce4216dac..6c99d945d7 100644 --- a/swift/llm/sampling/mcts.py +++ b/swift/llm/sampling/mcts.py @@ -71,6 +71,19 @@ def add_child(self, child: "LanguageNode"): if not child.terminated: self.active_children.append(child) + def collect(self): + result = { + "path": self.path, + "depth": self.depth, + "visit_count": self.visit_count, + "process_reward": self.process_reward, + "outcome_reward": self.outcome_reward, + "terminated": self.terminated, + "correct": self.correct, + "children": [child.collect() for child in self.children], + } + return result + def __lt__(self, other): return self.outcome_reward < other.outcome_reward @@ -120,16 +133,18 @@ def _prepare_request_configs(self): request_config.stop = _args.stop_words request_config.seed = _args.seed self.expand_request_configs = [] + self.rollout_request_configs = [] for i in range(_args.num_return_sequences): expand_request_config = deepcopy(request_config) expand_request_config.n = 1 expand_request_config.num_beams = expand_request_config.n expand_request_config.seed += i self.expand_request_configs.append(expand_request_config) - self.rollout_request_config = deepcopy(request_config) - self.rollout_request_config.max_tokens = 500 - self.rollout_request_config.temperature = 0.0 - self.rollout_request_config.n = 1 + rollout_request_config = deepcopy(request_config) + rollout_request_config.max_tokens = 500 + rollout_request_config.temperature = 0.0 + rollout_request_config.n = 1 + self.rollout_request_configs.append(rollout_request_config) def update_usage_info(self, response): for key, value in self.usage_info.__dict__.items(): @@ -155,31 +170,20 @@ def _select(select_curr_node: LanguageNode): return select_curr_node def _expand(expand_curr_node: LanguageNode): + n = _args.num_return_sequences - len(expand_curr_node.children) if expand_curr_node.is_root(): - infer_request = InferRequest(system_message + [prompt_message]) + infer_requests = [InferRequest(system_message + [prompt_message]) for _ in n] else: history_message = { "role": "assistant", "content": expand_curr_node.answer, } infer_request = InferRequest(system_message + [prompt_message, history_message, next_message]) + infer_requests = [infer_request for _ in range(n)] # e_time = time.time() # 为了并行进行 Expand 操作,这里暂时不需要考虑顺序,因为 Prompt 是一样的 - n = _args.num_return_sequences - len(expand_curr_node.children) - with ThreadPoolExecutor(max_workers=n) as executor: - futures = {executor.submit(perform_infer, - self.infer_engines[i], - infer_request, - self.expand_request_configs[i], - **self.infer_kwargs): i for i in range(n)} - responses = [] - for future in as_completed(futures): - task_id = futures[future] - try: - responses += future.result() - except Exception as e: - print(f"任务 {task_id} 执行请求时发生错误: {e}") + responses = perform_infer(self.infer_engine, infer_requests, self.expand_request_configs, **self.infer_kwargs) # logger.info(f"expand.expand time: {time.time() - e_time}") # 为了并行获取 Outcome Reward,这里获得的 OR 是顺序返回的,所以可以直接对应 @@ -195,8 +199,6 @@ def _expand(expand_curr_node: LanguageNode): child = LanguageNode(step=output, parent=expand_curr_node) if self.orm_model.check_terminate(child.answer)[0]: child.terminated = True - else: - all_child_terminated = False expand_curr_node.add_child(child) # e_time = time.time() @@ -246,19 +248,7 @@ def _rollout(rollout_curr_node: LanguageNode): # logger.info(f"rollout.prepare time: {time.time() - r_time}") # r_time = time.time() n = len(infer_requests) - with ThreadPoolExecutor(max_workers=n) as executor: - futures = {executor.submit(perform_infer, - self.infer_engines[i], - infer_requests[i], - self.rollout_request_config, - **self.infer_kwargs): i for i in range(n)} - responses = [] - for future in as_completed(futures): - task_id = futures[future] - try: - responses += future.result() - except Exception as e: - print(f"任务 {task_id} 执行请求时发生错误: {e}") + responses = perform_infer(self.infer_engine, infer_requests, self.rollout_request_configs, **self.infer_kwargs) # logger.info(f"rollout.infer time: {time.time() - r_time}") # r_time = time.time() @@ -277,7 +267,7 @@ def _rollout(rollout_curr_node: LanguageNode): orm_score, _orm_mask = get_reward( self.orm_model, orm_infer_requests, ground_truths=[ground_truth] * len(infer_requests), threshold=0.0) - # logger.info(f"rollout.get_orm tiem: {time.time() - r_time}") + # logger.info(f"rollout.get_orm time: {time.time() - r_time}") terminated_state = self.orm_model.check_terminate(end_paths) for index, score, terminated in zip(active_rollout_nodes, orm_score, terminated_state): if terminated: @@ -299,40 +289,6 @@ def _back_propagate(back_curr_node: LanguageNode): back_curr_node.parent.active_children.remove(back_curr_node) back_curr_node = back_curr_node.parent - def _collect(collect_curr_node: LanguageNode, _outcome_rewards: list[float], _process_rewards: list[float]): - _prefer_pairs, _correct_answers, _incorrect_answers = [], [], [] - _outcome_rewards = _outcome_rewards[:] + [collect_curr_node.outcome_reward] - _process_rewards = _process_rewards[:] + [collect_curr_node.process_reward] - if not collect_curr_node.is_leaf(): - for child in collect_curr_node.children: - p, c, i = _collect(child, _outcome_rewards, _process_rewards) - _prefer_pairs += p - _correct_answers += c - _incorrect_answers += i - collect_curr_node.children = sorted(collect_curr_node.children) - if collect_curr_node.children[-1].outcome_reward - collect_curr_node.children[0].outcome_reward > collect_filter_threshold: - prefer_pair = { - "path": collect_curr_node.answer, - "good": collect_curr_node.children[-1].path[-1], - "good_score": collect_curr_node.children[-1].outcome_reward, - "bad": collect_curr_node.children[0].path[-1], - "bad_score": collect_curr_node.children[0].outcome_reward, - } - _prefer_pairs.append(prefer_pair) - if collect_curr_node.terminated: - _answer = { - "answer": collect_curr_node.answer, - "mean_outcome_reward": np.mean(_outcome_rewards), - "min_outcome_reward": np.min(_outcome_rewards), - "mean_process_reward": np.mean(_process_rewards), - "min_process_reward": np.min(_process_rewards), - } - if collect_curr_node.correct: - _correct_answers.append(_answer) - else: - _incorrect_answers.append(_answer) - return _prefer_pairs, _correct_answers, _incorrect_answers - _args = self.args system_message = [] + _args.system_message sep_token = _args.stop_words[0] + '\n' @@ -380,22 +336,11 @@ def _collect(collect_curr_node: LanguageNode, _outcome_rewards: list[float], _pr too_hard = True #logger.info(f"rollout_correct_answers: {rollout_correct_answers}") #logger.info(f"rollout_incorrect_answers: {rollout_incorrect_answers}") - prefer_pairs, correct_answers, incorrect_answers = _collect(_root, [], []) - #logger.info(f"prefer_pairs: {prefer_pairs}") - - result = { - "query": query, - "ground_truth": ground_truth, - "prefer_pairs": prefer_pairs, - "rollout_correct_answers": rollout_correct_answers, - "rollout_incorrect_answers": rollout_incorrect_answers, - "correct_answers": correct_answers, - "incorrect_answers": incorrect_answers, - } - results = json.dumps(result, ensure_ascii=False) - logger.info(results) - return results + result = _root.collect() + result_json = json.dumps(result, ensure_ascii=False) + logger.info(result_json) + return result_json def do_sample(self, data): if not isinstance(data, list): diff --git a/swift/llm/sampling/utils.py b/swift/llm/sampling/utils.py index 4923f6d30d..ce52278476 100644 --- a/swift/llm/sampling/utils.py +++ b/swift/llm/sampling/utils.py @@ -5,6 +5,7 @@ import json import numpy as np +from swift.llm.argument.sampling_args import SamplingArguments from swift.llm import InferRequest, Messages, RequestConfig @@ -68,9 +69,71 @@ def normalize(arr): return normalize(arr), _mask -def perform_infer(infer_engine, infer_request, request_config, **infer_kwargs): - return infer_engine.infer( - [infer_request], - request_config, +def perform_infer(infer_engines, infer_requests, request_configs, **infer_kwargs): + if isinstance(infer_engines, list): + assert len(infer_engines) >= len(infer_requests) == len(request_configs) + from concurrent.futures import ThreadPoolExecutor, as_completed + n = len(infer_requests) + with ThreadPoolExecutor(max_workers=n) as executor: + futures = {executor.submit(perform_infer, + infer_engines[i], + infer_requests[i], + request_configs[i], + **infer_kwargs): i for i in range(n)} + responses = [] + for future in as_completed(futures): + task_id = futures[future] + try: + responses += future.result() + except Exception as e: + print(f"任务 {task_id} 执行请求时发生错误: {e}") + return responses + + return infer_engines.infer( + [infer_requests], + request_configs, **infer_kwargs, - ) \ No newline at end of file + ) + +def collect_from_mct(monte_carlo_tree, args: SamplingArguments): + if isinstance(monte_carlo_tree, str): + monte_carlo_tree = json.loads(monte_carlo_tree) + def _collect(collect_curr_node, _outcome_rewards: list[float], _process_rewards: list[float]): + _prefer_pairs, _correct_answers, _incorrect_answers = [], [], [] + _outcome_rewards = _outcome_rewards[:] + [collect_curr_node["outcome_reward"]] + _process_rewards = _process_rewards[:] + [collect_curr_node["process_reward"]] + if not len(collect_curr_node["children"]) > 0: + for child in collect_curr_node["children"]: + p, c, i = _collect(child, _outcome_rewards, _process_rewards) + _prefer_pairs += p + _correct_answers += c + _incorrect_answers += i + sorted_children = sorted(collect_curr_node["children"], key=lambda x: x["outcome_reward"]) + if sorted_children[-1].outcome_reward - sorted_children[0].outcome_reward > collect_filter_threshold: + # TODO: filter with visit count + prefer_pair = { + "path": collect_curr_node["answer"], + "good": sorted_children[-1]["path"][-1], + "good_score": sorted_children[-1]["outcome_reward"], + "bad": sorted_children[0]["path"][-1], + "bad_score": sorted_children[0]["outcome_reward"], + } + _prefer_pairs.append(prefer_pair) + if collect_curr_node["terminated"]: + _answer = { + "answer": collect_curr_node["answer"], + "mean_outcome_reward": np.mean(_outcome_rewards), + "min_outcome_reward": np.min(_outcome_rewards), + "mean_process_reward": np.mean(_process_rewards), + "min_process_reward": np.min(_process_rewards), + } + if collect_curr_node["correct"]: + _correct_answers.append(_answer) + else: + _incorrect_answers.append(_answer) + return _prefer_pairs, _correct_answers, _incorrect_answers + + collect_filter_threshold = args.collect_filter_threshold + _root = monte_carlo_tree + prefer_pairs, correct_answers, incorrect_answers = _collect(_root, [], []) + return prefer_pairs, correct_answers, incorrect_answers \ No newline at end of file From ce2116f564bb1511381b4d7589211739ee75ba72 Mon Sep 17 00:00:00 2001 From: LiuXL Date: Thu, 30 Jan 2025 21:20:09 +0800 Subject: [PATCH 40/52] fix --- swift/llm/sampling/mcts.py | 8 ++++---- swift/llm/sampling/utils.py | 5 +++-- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/swift/llm/sampling/mcts.py b/swift/llm/sampling/mcts.py index 6c99d945d7..575ac0c9f8 100644 --- a/swift/llm/sampling/mcts.py +++ b/swift/llm/sampling/mcts.py @@ -78,8 +78,8 @@ def collect(self): "visit_count": self.visit_count, "process_reward": self.process_reward, "outcome_reward": self.outcome_reward, - "terminated": self.terminated, - "correct": self.correct, + "terminated": str(self.terminated), + "correct": str(self.correct), "children": [child.collect() for child in self.children], } return result @@ -101,7 +101,7 @@ def _prepare_model_tokenizer(self): from swift.llm import InferClient api_key = args.api_key base_url = args.base_url - self.infer_engines = [InferClient(base_url=base_url, api_key=api_key) for _ in range(args.num_return_sequences)] + self.infer_engine = [InferClient(base_url=base_url, api_key=api_key) for _ in range(args.num_return_sequences)] self.infer_kwargs['model'] = args.model else: _Engine = self.get_infer_engine() @@ -172,7 +172,7 @@ def _select(select_curr_node: LanguageNode): def _expand(expand_curr_node: LanguageNode): n = _args.num_return_sequences - len(expand_curr_node.children) if expand_curr_node.is_root(): - infer_requests = [InferRequest(system_message + [prompt_message]) for _ in n] + infer_requests = [InferRequest(system_message + [prompt_message]) for _ in range(n)] else: history_message = { "role": "assistant", diff --git a/swift/llm/sampling/utils.py b/swift/llm/sampling/utils.py index ce52278476..7e0f842116 100644 --- a/swift/llm/sampling/utils.py +++ b/swift/llm/sampling/utils.py @@ -96,6 +96,7 @@ def perform_infer(infer_engines, infer_requests, request_configs, **infer_kwargs ) def collect_from_mct(monte_carlo_tree, args: SamplingArguments): + from transformers.utils import strtobool if isinstance(monte_carlo_tree, str): monte_carlo_tree = json.loads(monte_carlo_tree) def _collect(collect_curr_node, _outcome_rewards: list[float], _process_rewards: list[float]): @@ -119,7 +120,7 @@ def _collect(collect_curr_node, _outcome_rewards: list[float], _process_rewards: "bad_score": sorted_children[0]["outcome_reward"], } _prefer_pairs.append(prefer_pair) - if collect_curr_node["terminated"]: + if strtobool(collect_curr_node["terminated"]): _answer = { "answer": collect_curr_node["answer"], "mean_outcome_reward": np.mean(_outcome_rewards), @@ -127,7 +128,7 @@ def _collect(collect_curr_node, _outcome_rewards: list[float], _process_rewards: "mean_process_reward": np.mean(_process_rewards), "min_process_reward": np.min(_process_rewards), } - if collect_curr_node["correct"]: + if strtobool(collect_curr_node["correct"]): _correct_answers.append(_answer) else: _incorrect_answers.append(_answer) From 1e5a86cc6754093b097b102f2f99aecfb7d25943 Mon Sep 17 00:00:00 2001 From: LiuXL Date: Thu, 30 Jan 2025 22:34:12 +0800 Subject: [PATCH 41/52] stop_reason & fix --- swift/llm/sampling/mcts.py | 46 ++++++++++++++++++++++--------------- swift/llm/sampling/utils.py | 2 +- 2 files changed, 28 insertions(+), 20 deletions(-) diff --git a/swift/llm/sampling/mcts.py b/swift/llm/sampling/mcts.py index 575ac0c9f8..82e9ef8a87 100644 --- a/swift/llm/sampling/mcts.py +++ b/swift/llm/sampling/mcts.py @@ -2,6 +2,7 @@ import numpy as np import json import time +import traceback from concurrent.futures import ThreadPoolExecutor, as_completed from swift.llm import InferRequest @@ -285,8 +286,10 @@ def _back_propagate(back_curr_node: LanguageNode): best_child_value = max([child.outcome_reward for child in back_curr_node.children]) back_curr_node.init_and_update_value(best_child_value) back_curr_node.visit() - if len(back_curr_node.active_children) == 0 and not back_curr_node.is_root(): - back_curr_node.parent.active_children.remove(back_curr_node) + if len(back_curr_node.active_children) == 0: + back_curr_node.terminated = True + if not back_curr_node.is_root(): + back_curr_node.parent.active_children.remove(back_curr_node) back_curr_node = back_curr_node.parent _args = self.args @@ -302,9 +305,8 @@ def _back_propagate(back_curr_node: LanguageNode): rollout_correct_answers, rollout_incorrect_answers, prefer_pairs, terminated_nodes = [], [], [], [] too_easy, too_hard = False, False iter_count = 0 - while (not too_easy and not too_hard - and len(terminated_nodes) < _args.num_return_sequences - and iter_count < _args.max_iterations): + stop_reason = None + while True: logger.info(f"iter_count: {iter_count}" + "." * 10) s_time = time.time() curr_node = _select(_root) @@ -316,24 +318,29 @@ def _back_propagate(back_curr_node: LanguageNode): s_time = time.time() _rollout(curr_node) logger.info("rollout" + "=" * 10 + f"time: {time.time() - s_time}") - s_time = time.time() - _back_propagate(curr_node) - logger.info("back propagate" + "=" * 10 + f"time: {time.time() - s_time}") + s_time = time.time() + _back_propagate(curr_node) + logger.info("back propagate" + "=" * 10 + f"time: {time.time() - s_time}") if len(rollout_correct_answers) + len(rollout_incorrect_answers) >= 2 * _args.num_return_sequences: if 4 * len(rollout_incorrect_answers) < len(rollout_correct_answers): - logger.info("too easy" + "!" * 20) - #logger.info(f"rollout_correct_answers: {rollout_correct_answers}") - #logger.info(f"rollout_incorrect_answers: {rollout_incorrect_answers}") - too_easy = True + stop_reason = "too easy" + break elif 4 * len(rollout_correct_answers) < len(rollout_incorrect_answers): - logger.info("too hard" + "!" * 20) - #logger.info(f"correct_answers: {correct_answers}") - #logger.info(f"incorrect_answers: {incorrect_answers}") - too_hard = True + stop_reason = "too hard" + break + elif _root.terminated: + stop_reason = "root terminated" + break + elif len(terminated_nodes) >= _args.num_return_sequences: + stop_reason = "enough nodes" + break + elif iter_count >= _args.max_iterations: + stop_reason = "max_iterations" + break iter_count += 1 - if iter_count == _args.max_iterations: - logger.info("too hard" + "!" * 20) - too_hard = True + if stop_reason is None and iter_count == _args.max_iterations: + stop_reason = "too hard" + logger.info(f"stop_reason: {stop_reason}") #logger.info(f"rollout_correct_answers: {rollout_correct_answers}") #logger.info(f"rollout_incorrect_answers: {rollout_incorrect_answers}") @@ -355,4 +362,5 @@ def do_sample(self, data): generated.append(self.search_single(query, ground_truth) + '\n') except Exception as e: logger.error(f"Error: {e}") + logger.error(f"Traceback: {traceback.format_exc()}") return generated \ No newline at end of file diff --git a/swift/llm/sampling/utils.py b/swift/llm/sampling/utils.py index 7e0f842116..e445d08350 100644 --- a/swift/llm/sampling/utils.py +++ b/swift/llm/sampling/utils.py @@ -71,7 +71,7 @@ def normalize(arr): def perform_infer(infer_engines, infer_requests, request_configs, **infer_kwargs): if isinstance(infer_engines, list): - assert len(infer_engines) >= len(infer_requests) == len(request_configs) + assert len(infer_engines) >= len(request_configs) >= len(infer_requests) from concurrent.futures import ThreadPoolExecutor, as_completed n = len(infer_requests) with ThreadPoolExecutor(max_workers=n) as executor: From 57b7990fb41111479981dce5bf26f7784b829e0f Mon Sep 17 00:00:00 2001 From: LiuXL Date: Thu, 30 Jan 2025 22:51:40 +0800 Subject: [PATCH 42/52] result add query --- swift/llm/sampling/mcts.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/swift/llm/sampling/mcts.py b/swift/llm/sampling/mcts.py index 82e9ef8a87..bf97a348cf 100644 --- a/swift/llm/sampling/mcts.py +++ b/swift/llm/sampling/mcts.py @@ -345,6 +345,8 @@ def _back_propagate(back_curr_node: LanguageNode): #logger.info(f"rollout_incorrect_answers: {rollout_incorrect_answers}") result = _root.collect() + result['query'] = query + result['ground_truth'] = ground_truth result_json = json.dumps(result, ensure_ascii=False) logger.info(result_json) return result_json From a944ebaaeed870fd87a1cb85e70187c9b6e1a37a Mon Sep 17 00:00:00 2001 From: LiuXL Date: Fri, 31 Jan 2025 00:34:15 +0800 Subject: [PATCH 43/52] fix --- swift/llm/sampling/mcts.py | 44 ++++++++++++++++++++++++++------------ 1 file changed, 30 insertions(+), 14 deletions(-) diff --git a/swift/llm/sampling/mcts.py b/swift/llm/sampling/mcts.py index bf97a348cf..ec2500eb93 100644 --- a/swift/llm/sampling/mcts.py +++ b/swift/llm/sampling/mcts.py @@ -184,7 +184,14 @@ def _expand(expand_curr_node: LanguageNode): # e_time = time.time() # 为了并行进行 Expand 操作,这里暂时不需要考虑顺序,因为 Prompt 是一样的 - responses = perform_infer(self.infer_engine, infer_requests, self.expand_request_configs, **self.infer_kwargs) + expand_iter_index = 0 + while True: + responses = perform_infer(self.infer_engine, infer_requests, self.expand_request_configs, **self.infer_kwargs) + if len(responses) > 0: + break + if expand_iter_index == 5: + raise ValueError("Expand should not return any response") + expand_iter_index += 1 # logger.info(f"expand.expand time: {time.time() - e_time}") # 为了并行获取 Outcome Reward,这里获得的 OR 是顺序返回的,所以可以直接对应 @@ -229,7 +236,7 @@ def _expand(expand_curr_node: LanguageNode): # logger.info(f"expand.prm time: {time.time() - e_time}") def _rollout(rollout_curr_node: LanguageNode): - rollout_iter_index = 0 + rollout_depth = 0 rollout_nodes = {} for i in range(len(rollout_curr_node.active_children)): rollout_nodes[i] = { @@ -240,7 +247,7 @@ def _rollout(rollout_curr_node: LanguageNode): }, } active_rollout_nodes = list(rollout_nodes.keys()) - while len(active_rollout_nodes) > 0 and rollout_iter_index < _args.rollout_depth: + while len(active_rollout_nodes) > 0 and rollout_depth < _args.rollout_depth: # r_time = time.time() infer_requests = [InferRequest(system_message + [prompt_message, rollout_nodes[index]['history_messages'], @@ -248,8 +255,14 @@ def _rollout(rollout_curr_node: LanguageNode): for index in active_rollout_nodes] # logger.info(f"rollout.prepare time: {time.time() - r_time}") # r_time = time.time() - n = len(infer_requests) - responses = perform_infer(self.infer_engine, infer_requests, self.rollout_request_configs, **self.infer_kwargs) + rollout_iter_index = 0 + while True: + responses = perform_infer(self.infer_engine, infer_requests, self.rollout_request_configs, **self.infer_kwargs) + if len(responses) > 0: + break + if rollout_iter_index == 5: + raise ValueError("Rollout should not return any response") + rollout_iter_index += 1 # logger.info(f"rollout.infer time: {time.time() - r_time}") # r_time = time.time() @@ -279,7 +292,7 @@ def _rollout(rollout_curr_node: LanguageNode): rollout_incorrect_answers.append(rollout_nodes[index]['history_messages']["content"]) rollout_nodes.pop(index) active_rollout_nodes = list(rollout_nodes.keys()) - rollout_iter_index += 1 + rollout_depth += 1 def _back_propagate(back_curr_node: LanguageNode): while back_curr_node: @@ -328,25 +341,28 @@ def _back_propagate(back_curr_node: LanguageNode): elif 4 * len(rollout_correct_answers) < len(rollout_incorrect_answers): stop_reason = "too hard" break - elif _root.terminated: + if _root.terminated: stop_reason = "root terminated" break - elif len(terminated_nodes) >= _args.num_return_sequences: + if len(terminated_nodes) >= _args.num_return_sequences: stop_reason = "enough nodes" break - elif iter_count >= _args.max_iterations: + if iter_count >= _args.max_iterations: stop_reason = "max_iterations" break iter_count += 1 - if stop_reason is None and iter_count == _args.max_iterations: - stop_reason = "too hard" logger.info(f"stop_reason: {stop_reason}") #logger.info(f"rollout_correct_answers: {rollout_correct_answers}") #logger.info(f"rollout_incorrect_answers: {rollout_incorrect_answers}") - result = _root.collect() - result['query'] = query - result['ground_truth'] = ground_truth + monte_carlo_tree = _root.collect() + result = { + "query": query, + "ground_truth": ground_truth, + "rollout_correct_answers": rollout_correct_answers, + "rollout_incorrect_answers": rollout_incorrect_answers, + "monte_carlo_tree": monte_carlo_tree, + } result_json = json.dumps(result, ensure_ascii=False) logger.info(result_json) return result_json From cd5a83a7d734780fab6d966658a8de77dfe7c899 Mon Sep 17 00:00:00 2001 From: LiuXL Date: Fri, 31 Jan 2025 11:18:13 +0800 Subject: [PATCH 44/52] fix collect_from_mct --- swift/llm/sampling/utils.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/swift/llm/sampling/utils.py b/swift/llm/sampling/utils.py index e445d08350..9f0e60c967 100644 --- a/swift/llm/sampling/utils.py +++ b/swift/llm/sampling/utils.py @@ -95,7 +95,7 @@ def perform_infer(infer_engines, infer_requests, request_configs, **infer_kwargs **infer_kwargs, ) -def collect_from_mct(monte_carlo_tree, args: SamplingArguments): +def collect_from_mct(monte_carlo_tree, collect_filter_threshold): from transformers.utils import strtobool if isinstance(monte_carlo_tree, str): monte_carlo_tree = json.loads(monte_carlo_tree) @@ -103,17 +103,17 @@ def _collect(collect_curr_node, _outcome_rewards: list[float], _process_rewards: _prefer_pairs, _correct_answers, _incorrect_answers = [], [], [] _outcome_rewards = _outcome_rewards[:] + [collect_curr_node["outcome_reward"]] _process_rewards = _process_rewards[:] + [collect_curr_node["process_reward"]] - if not len(collect_curr_node["children"]) > 0: + if len(collect_curr_node["children"]) > 0: for child in collect_curr_node["children"]: p, c, i = _collect(child, _outcome_rewards, _process_rewards) _prefer_pairs += p _correct_answers += c _incorrect_answers += i sorted_children = sorted(collect_curr_node["children"], key=lambda x: x["outcome_reward"]) - if sorted_children[-1].outcome_reward - sorted_children[0].outcome_reward > collect_filter_threshold: + if sorted_children[-1]["outcome_reward"] - sorted_children[0]["outcome_reward"] > collect_filter_threshold: # TODO: filter with visit count prefer_pair = { - "path": collect_curr_node["answer"], + "path": "ки\n".join(collect_curr_node["path"]), "good": sorted_children[-1]["path"][-1], "good_score": sorted_children[-1]["outcome_reward"], "bad": sorted_children[0]["path"][-1], @@ -122,7 +122,7 @@ def _collect(collect_curr_node, _outcome_rewards: list[float], _process_rewards: _prefer_pairs.append(prefer_pair) if strtobool(collect_curr_node["terminated"]): _answer = { - "answer": collect_curr_node["answer"], + "answer": "ки\n".join(collect_curr_node["path"]), "mean_outcome_reward": np.mean(_outcome_rewards), "min_outcome_reward": np.min(_outcome_rewards), "mean_process_reward": np.mean(_process_rewards), @@ -134,7 +134,6 @@ def _collect(collect_curr_node, _outcome_rewards: list[float], _process_rewards: _incorrect_answers.append(_answer) return _prefer_pairs, _correct_answers, _incorrect_answers - collect_filter_threshold = args.collect_filter_threshold _root = monte_carlo_tree prefer_pairs, correct_answers, incorrect_answers = _collect(_root, [], []) return prefer_pairs, correct_answers, incorrect_answers \ No newline at end of file From fbbbdb217421c78c025a19cc86bb2ba3a9647e9f Mon Sep 17 00:00:00 2001 From: LiuXL Date: Fri, 31 Jan 2025 18:27:58 +0800 Subject: [PATCH 45/52] fix vllm_engine --- swift/llm/argument/sampling_args.py | 5 +---- swift/llm/sampling/utils.py | 21 +++++++++++++++++++-- 2 files changed, 20 insertions(+), 6 deletions(-) diff --git a/swift/llm/argument/sampling_args.py b/swift/llm/argument/sampling_args.py index 0747f8b452..bce42b9056 100644 --- a/swift/llm/argument/sampling_args.py +++ b/swift/llm/argument/sampling_args.py @@ -55,7 +55,7 @@ class SamplingArguments(BaseArguments): def _init_model_info(self): if self.sampler_engine != 'client': - return super._init_model_info(self) + return super()._init_model_info() self.task_type = 'causal_lm' return @@ -76,9 +76,6 @@ def __post_init__(self): else: self.engine_kwargs = {} - if self.sampler_type == 'mcts' and self.sampler_engine != 'client': - raise ValueError(f'`mcts` sampler only supports `client` engine yet, but now is: {self.sampler_engine}') - super().__post_init__() if self.system is not None: diff --git a/swift/llm/sampling/utils.py b/swift/llm/sampling/utils.py index 9f0e60c967..449611f6ed 100644 --- a/swift/llm/sampling/utils.py +++ b/swift/llm/sampling/utils.py @@ -88,9 +88,26 @@ def perform_infer(infer_engines, infer_requests, request_configs, **infer_kwargs except Exception as e: print(f"任务 {task_id} 执行请求时发生错误: {e}") return responses - + elif isinstance(infer_requests, list): + responses = [] + if isinstance(request_configs, list): + assert len(infer_requests) <= len(request_configs) + for i in range(len(infer_requests)): + responses += infer_engines.infer( + [infer_requests[i]], + request_configs[i], + **infer_kwargs, + ) + elif isinstance(request_configs, RequestConfig): + for infer_request in infer_requests: + responses += infer_engines.infer( + infer_request, + request_configs, + **infer_kwargs, + ) + return responses return infer_engines.infer( - [infer_requests], + infer_requests, request_configs, **infer_kwargs, ) From 9712cc68d7aecb59483784aace6a608eff826830 Mon Sep 17 00:00:00 2001 From: LiuXL Date: Sat, 1 Feb 2025 01:27:03 +0800 Subject: [PATCH 46/52] fix perform_infer --- swift/llm/sampling/utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/swift/llm/sampling/utils.py b/swift/llm/sampling/utils.py index 449611f6ed..0f0517ce72 100644 --- a/swift/llm/sampling/utils.py +++ b/swift/llm/sampling/utils.py @@ -101,13 +101,13 @@ def perform_infer(infer_engines, infer_requests, request_configs, **infer_kwargs elif isinstance(request_configs, RequestConfig): for infer_request in infer_requests: responses += infer_engines.infer( - infer_request, + [infer_request], request_configs, **infer_kwargs, ) return responses return infer_engines.infer( - infer_requests, + [infer_requests], request_configs, **infer_kwargs, ) From bb39b6c645a35c9c7ff3a8f90744bbd935cb4d6e Mon Sep 17 00:00:00 2001 From: LiuXL Date: Sat, 1 Feb 2025 01:30:00 +0800 Subject: [PATCH 47/52] fix perform_infer & pre-commit --- swift/llm/argument/sampling_args.py | 8 +- swift/llm/sampling/mcts.py | 158 +++++++++++++++------------- swift/llm/sampling/utils.py | 55 +++++----- swift/plugin/orm.py | 2 +- swift/plugin/prm.py | 3 +- 5 files changed, 121 insertions(+), 105 deletions(-) diff --git a/swift/llm/argument/sampling_args.py b/swift/llm/argument/sampling_args.py index bce42b9056..29977ba24e 100644 --- a/swift/llm/argument/sampling_args.py +++ b/swift/llm/argument/sampling_args.py @@ -1,6 +1,6 @@ # Copyright (c) Alibaba, Inc. and its affiliates. -import os import dataclasses +import os from dataclasses import dataclass from datetime import datetime from typing import List, Literal, Optional @@ -50,7 +50,7 @@ class SamplingArguments(BaseArguments): process_reward_rate: float = 0.0 exploration_rate: float = 0.5 collect_filter_threshold: float = 0.5 - api_key: str = "EMPTY" + api_key: str = 'EMPTY' base_url: str = 'https://dashscope.aliyuncs.com/compatible-mode/v1' def _init_model_info(self): @@ -80,8 +80,8 @@ def __post_init__(self): if self.system is not None: self.system_message = [{ - "role": "system", - "content": self.system, + 'role': 'system', + 'content': self.system, }] else: self.system_message = [] diff --git a/swift/llm/sampling/mcts.py b/swift/llm/sampling/mcts.py index ec2500eb93..42844027bd 100644 --- a/swift/llm/sampling/mcts.py +++ b/swift/llm/sampling/mcts.py @@ -1,36 +1,37 @@ -from copy import deepcopy -import numpy as np -import json import time import traceback from concurrent.futures import ThreadPoolExecutor, as_completed +from copy import deepcopy + +import json +import numpy as np from swift.llm import InferRequest +from swift.llm.argument.sampling_args import SamplingArguments from swift.llm.infer.protocol import UsageInfo from swift.utils import get_logger -from swift.llm.argument.sampling_args import SamplingArguments - from .base import Sampler from .utils import get_reward, perform_infer - logger = get_logger() NXT_PROMPT = """Continue. """ next_message = { - "role": "user", - "content": NXT_PROMPT, + 'role': 'user', + 'content': NXT_PROMPT, } class LanguageNode: - def __init__(self, - step: str = None, - sep_token: str = None, - parent: "LanguageNode" = None,): + def __init__( + self, + step: str = None, + sep_token: str = None, + parent: 'LanguageNode' = None, + ): self.parent = parent if sep_token: @@ -44,7 +45,7 @@ def __init__(self, self.depth = parent.depth + 1 else: self.path = [] - self.answer = "" + self.answer = '' self.depth = 0 self.active_children = [] @@ -67,21 +68,21 @@ def visit(self): def init_and_update_value(self, value): self.outcome_reward = (self.outcome_reward * self.visit_count + value) / (self.visit_count + 1) - def add_child(self, child: "LanguageNode"): + def add_child(self, child: 'LanguageNode'): self.children.append(child) if not child.terminated: self.active_children.append(child) def collect(self): result = { - "path": self.path, - "depth": self.depth, - "visit_count": self.visit_count, - "process_reward": self.process_reward, - "outcome_reward": self.outcome_reward, - "terminated": str(self.terminated), - "correct": str(self.correct), - "children": [child.collect() for child in self.children], + 'path': self.path, + 'depth': self.depth, + 'visit_count': self.visit_count, + 'process_reward': self.process_reward, + 'outcome_reward': self.outcome_reward, + 'terminated': str(self.terminated), + 'correct': str(self.correct), + 'children': [child.collect() for child in self.children], } return result @@ -93,7 +94,7 @@ class MctsSampler(Sampler): def __init__(self, input_args: SamplingArguments): super().__init__(input_args) - self.usage_info = UsageInfo(0,0,0) + self.usage_info = UsageInfo(0, 0, 0) def _prepare_model_tokenizer(self): args = self.args @@ -102,7 +103,9 @@ def _prepare_model_tokenizer(self): from swift.llm import InferClient api_key = args.api_key base_url = args.base_url - self.infer_engine = [InferClient(base_url=base_url, api_key=api_key) for _ in range(args.num_return_sequences)] + self.infer_engine = [ + InferClient(base_url=base_url, api_key=api_key) for _ in range(args.num_return_sequences) + ] self.infer_kwargs['model'] = args.model else: _Engine = self.get_infer_engine() @@ -153,6 +156,7 @@ def update_usage_info(self, response): setattr(self.usage_info, key, update_value) def search_single(self, query, ground_truth): + def _uct(uct_curr_node: LanguageNode): alpha = _args.process_reward_rate value = alpha * uct_curr_node.process_reward + (1 - alpha) * uct_curr_node.outcome_reward @@ -160,8 +164,9 @@ def _uct(uct_curr_node: LanguageNode): return value exploitation_score = value - exploration_score = (_args.exploration_rate - * np.sqrt(np.log(uct_curr_node.parent.visit_count + 1) / (uct_curr_node.visit_count + 1))) + exploration_score = ( + _args.exploration_rate + * np.sqrt(np.log(uct_curr_node.parent.visit_count + 1) / (uct_curr_node.visit_count + 1))) return exploration_score + exploitation_score @@ -176,8 +181,8 @@ def _expand(expand_curr_node: LanguageNode): infer_requests = [InferRequest(system_message + [prompt_message]) for _ in range(n)] else: history_message = { - "role": "assistant", - "content": expand_curr_node.answer, + 'role': 'assistant', + 'content': expand_curr_node.answer, } infer_request = InferRequest(system_message + [prompt_message, history_message, next_message]) infer_requests = [infer_request for _ in range(n)] @@ -186,24 +191,25 @@ def _expand(expand_curr_node: LanguageNode): # 为了并行进行 Expand 操作,这里暂时不需要考虑顺序,因为 Prompt 是一样的 expand_iter_index = 0 while True: - responses = perform_infer(self.infer_engine, infer_requests, self.expand_request_configs, **self.infer_kwargs) + responses = perform_infer(self.infer_engine, infer_requests, self.expand_request_configs, + **self.infer_kwargs) if len(responses) > 0: break if expand_iter_index == 5: - raise ValueError("Expand should not return any response") + raise ValueError('Expand should not return any response') expand_iter_index += 1 # logger.info(f"expand.expand time: {time.time() - e_time}") # 为了并行获取 Outcome Reward,这里获得的 OR 是顺序返回的,所以可以直接对应 orm_infer_requests = [] - unique_output = set() # 用于去重 + unique_output = set() # 用于去重 for response in responses: self.update_usage_info(response) output = response.choices[0].message.content.rstrip(sep_token + '\n').split(sep_token)[0] if output in unique_output: continue unique_output.add(output) - orm_infer_requests.append(InferRequest([{"role": "assistant", "content": output}])) + orm_infer_requests.append(InferRequest([{'role': 'assistant', 'content': output}])) child = LanguageNode(step=output, parent=expand_curr_node) if self.orm_model.check_terminate(child.answer)[0]: child.terminated = True @@ -211,7 +217,9 @@ def _expand(expand_curr_node: LanguageNode): # e_time = time.time() orm_score, _orm_mask = get_reward( - self.orm_model, orm_infer_requests, ground_truths=[ground_truth] * len(orm_infer_requests), + self.orm_model, + orm_infer_requests, + ground_truths=[ground_truth] * len(orm_infer_requests), threshold=0.0) # logger.info(f"expand.orm time: {time.time() - e_time}") for child, score in zip(expand_curr_node.children, orm_score): @@ -224,7 +232,7 @@ def _expand(expand_curr_node: LanguageNode): if self.prm_model: prm_infer_requests = [] for child in expand_curr_node.children: - prm_message = {"role": "assistant", "content": child.answer} + prm_message = {'role': 'assistant', 'content': child.answer} prm_infer_requests.append(InferRequest([prompt_message, prm_message])) prm_score, _prm_mask = get_reward( self.prm_model, @@ -240,28 +248,30 @@ def _rollout(rollout_curr_node: LanguageNode): rollout_nodes = {} for i in range(len(rollout_curr_node.active_children)): rollout_nodes[i] = { - "node": rollout_curr_node.active_children[i], - "history_messages": { - "role": "assistant", - "content": rollout_curr_node.active_children[i].answer, + 'node': rollout_curr_node.active_children[i], + 'history_messages': { + 'role': 'assistant', + 'content': rollout_curr_node.active_children[i].answer, }, } active_rollout_nodes = list(rollout_nodes.keys()) while len(active_rollout_nodes) > 0 and rollout_depth < _args.rollout_depth: # r_time = time.time() - infer_requests = [InferRequest(system_message + [prompt_message, - rollout_nodes[index]['history_messages'], - next_message]) - for index in active_rollout_nodes] + infer_requests = [ + InferRequest(system_message + + [prompt_message, rollout_nodes[index]['history_messages'], next_message]) + for index in active_rollout_nodes + ] # logger.info(f"rollout.prepare time: {time.time() - r_time}") # r_time = time.time() rollout_iter_index = 0 while True: - responses = perform_infer(self.infer_engine, infer_requests, self.rollout_request_configs, **self.infer_kwargs) + responses = perform_infer(self.infer_engine, infer_requests, self.rollout_request_configs, + **self.infer_kwargs) if len(responses) > 0: break if rollout_iter_index == 5: - raise ValueError("Rollout should not return any response") + raise ValueError('Rollout should not return any response') rollout_iter_index += 1 # logger.info(f"rollout.infer time: {time.time() - r_time}") @@ -270,16 +280,18 @@ def _rollout(rollout_curr_node: LanguageNode): end_paths = [] for index, response in zip(active_rollout_nodes, responses): self.update_usage_info(response) - output = response.choices[0].message.content.rstrip(sep_token + '\n').split(sep_token)[ - 0] + sep_token + '\n' - rollout_nodes[index]['history_messages']["content"] += output - end_paths.append(rollout_nodes[index]['history_messages']["content"]) + output = response.choices[0].message.content.rstrip(sep_token + + '\n').split(sep_token)[0] + sep_token + '\n' + rollout_nodes[index]['history_messages']['content'] += output + end_paths.append(rollout_nodes[index]['history_messages']['content']) orm_infer_requests.append(InferRequest([rollout_nodes[index]['history_messages']])) # logger.info(f"rollout.orm_prepare time: {time.time() - r_time}") # r_time = time.time() orm_score, _orm_mask = get_reward( - self.orm_model, orm_infer_requests, ground_truths=[ground_truth] * len(infer_requests), + self.orm_model, + orm_infer_requests, + ground_truths=[ground_truth] * len(infer_requests), threshold=0.0) # logger.info(f"rollout.get_orm time: {time.time() - r_time}") terminated_state = self.orm_model.check_terminate(end_paths) @@ -287,9 +299,9 @@ def _rollout(rollout_curr_node: LanguageNode): if terminated: rollout_curr_node.active_children[index].init_and_update_value(score) if score > 0.9: - rollout_correct_answers.append(rollout_nodes[index]['history_messages']["content"]) + rollout_correct_answers.append(rollout_nodes[index]['history_messages']['content']) else: - rollout_incorrect_answers.append(rollout_nodes[index]['history_messages']["content"]) + rollout_incorrect_answers.append(rollout_nodes[index]['history_messages']['content']) rollout_nodes.pop(index) active_rollout_nodes = list(rollout_nodes.keys()) rollout_depth += 1 @@ -311,8 +323,8 @@ def _back_propagate(back_curr_node: LanguageNode): collect_filter_threshold = _args.collect_filter_threshold _root = LanguageNode(sep_token=sep_token) prompt_message = { - "role": "user", - "content": query, + 'role': 'user', + 'content': query, } rollout_correct_answers, rollout_incorrect_answers, prefer_pairs, terminated_nodes = [], [], [], [] @@ -320,48 +332,48 @@ def _back_propagate(back_curr_node: LanguageNode): iter_count = 0 stop_reason = None while True: - logger.info(f"iter_count: {iter_count}" + "." * 10) + logger.info(f'iter_count: {iter_count}' + '.' * 10) s_time = time.time() curr_node = _select(_root) - logger.info("select" + "=" * 10+ f"time: {time.time() - s_time}") + logger.info('select' + '=' * 10 + f'time: {time.time() - s_time}') s_time = time.time() _expand(curr_node) - logger.info("expand" + "=" * 10 + f"time: {time.time() - s_time}") + logger.info('expand' + '=' * 10 + f'time: {time.time() - s_time}') if curr_node.depth > _args.rollout_start_depth: s_time = time.time() _rollout(curr_node) - logger.info("rollout" + "=" * 10 + f"time: {time.time() - s_time}") + logger.info('rollout' + '=' * 10 + f'time: {time.time() - s_time}') s_time = time.time() _back_propagate(curr_node) - logger.info("back propagate" + "=" * 10 + f"time: {time.time() - s_time}") + logger.info('back propagate' + '=' * 10 + f'time: {time.time() - s_time}') if len(rollout_correct_answers) + len(rollout_incorrect_answers) >= 2 * _args.num_return_sequences: if 4 * len(rollout_incorrect_answers) < len(rollout_correct_answers): - stop_reason = "too easy" + stop_reason = 'too easy' break elif 4 * len(rollout_correct_answers) < len(rollout_incorrect_answers): - stop_reason = "too hard" + stop_reason = 'too hard' break if _root.terminated: - stop_reason = "root terminated" + stop_reason = 'root terminated' break if len(terminated_nodes) >= _args.num_return_sequences: - stop_reason = "enough nodes" + stop_reason = 'enough nodes' break if iter_count >= _args.max_iterations: - stop_reason = "max_iterations" + stop_reason = 'max_iterations' break iter_count += 1 - logger.info(f"stop_reason: {stop_reason}") + logger.info(f'stop_reason: {stop_reason}') #logger.info(f"rollout_correct_answers: {rollout_correct_answers}") #logger.info(f"rollout_incorrect_answers: {rollout_incorrect_answers}") monte_carlo_tree = _root.collect() result = { - "query": query, - "ground_truth": ground_truth, - "rollout_correct_answers": rollout_correct_answers, - "rollout_incorrect_answers": rollout_incorrect_answers, - "monte_carlo_tree": monte_carlo_tree, + 'query': query, + 'ground_truth': ground_truth, + 'rollout_correct_answers': rollout_correct_answers, + 'rollout_incorrect_answers': rollout_incorrect_answers, + 'monte_carlo_tree': monte_carlo_tree, } result_json = json.dumps(result, ensure_ascii=False) logger.info(result_json) @@ -372,13 +384,13 @@ def do_sample(self, data): data = [data] generated = [] for item in data: - logger.info(f"time: {time.ctime(time.time())}") + logger.info(f'time: {time.ctime(time.time())}') try: messages = item['messages'][0] query = messages[0]['content'] ground_truth = messages[1]['content'] generated.append(self.search_single(query, ground_truth) + '\n') except Exception as e: - logger.error(f"Error: {e}") - logger.error(f"Traceback: {traceback.format_exc()}") - return generated \ No newline at end of file + logger.error(f'Error: {e}') + logger.error(f'Traceback: {traceback.format_exc()}') + return generated diff --git a/swift/llm/sampling/utils.py b/swift/llm/sampling/utils.py index 0f0517ce72..4ad9a07ef4 100644 --- a/swift/llm/sampling/utils.py +++ b/swift/llm/sampling/utils.py @@ -5,8 +5,8 @@ import json import numpy as np -from swift.llm.argument.sampling_args import SamplingArguments from swift.llm import InferRequest, Messages, RequestConfig +from swift.llm.argument.sampling_args import SamplingArguments def get_messages_md5(messages: Messages): @@ -69,24 +69,25 @@ def normalize(arr): return normalize(arr), _mask + def perform_infer(infer_engines, infer_requests, request_configs, **infer_kwargs): if isinstance(infer_engines, list): assert len(infer_engines) >= len(request_configs) >= len(infer_requests) from concurrent.futures import ThreadPoolExecutor, as_completed n = len(infer_requests) with ThreadPoolExecutor(max_workers=n) as executor: - futures = {executor.submit(perform_infer, - infer_engines[i], - infer_requests[i], - request_configs[i], - **infer_kwargs): i for i in range(n)} + futures = { + executor.submit(perform_infer, infer_engines[i], infer_requests[i], request_configs[i], **infer_kwargs): + i + for i in range(n) + } responses = [] for future in as_completed(futures): task_id = futures[future] try: responses += future.result() except Exception as e: - print(f"任务 {task_id} 执行请求时发生错误: {e}") + print(f'任务 {task_id} 执行请求时发生错误: {e}') return responses elif isinstance(infer_requests, list): responses = [] @@ -112,40 +113,42 @@ def perform_infer(infer_engines, infer_requests, request_configs, **infer_kwargs **infer_kwargs, ) + def collect_from_mct(monte_carlo_tree, collect_filter_threshold): from transformers.utils import strtobool if isinstance(monte_carlo_tree, str): monte_carlo_tree = json.loads(monte_carlo_tree) + def _collect(collect_curr_node, _outcome_rewards: list[float], _process_rewards: list[float]): _prefer_pairs, _correct_answers, _incorrect_answers = [], [], [] - _outcome_rewards = _outcome_rewards[:] + [collect_curr_node["outcome_reward"]] - _process_rewards = _process_rewards[:] + [collect_curr_node["process_reward"]] - if len(collect_curr_node["children"]) > 0: - for child in collect_curr_node["children"]: + _outcome_rewards = _outcome_rewards[:] + [collect_curr_node['outcome_reward']] + _process_rewards = _process_rewards[:] + [collect_curr_node['process_reward']] + if len(collect_curr_node['children']) > 0: + for child in collect_curr_node['children']: p, c, i = _collect(child, _outcome_rewards, _process_rewards) _prefer_pairs += p _correct_answers += c _incorrect_answers += i - sorted_children = sorted(collect_curr_node["children"], key=lambda x: x["outcome_reward"]) - if sorted_children[-1]["outcome_reward"] - sorted_children[0]["outcome_reward"] > collect_filter_threshold: + sorted_children = sorted(collect_curr_node['children'], key=lambda x: x['outcome_reward']) + if sorted_children[-1]['outcome_reward'] - sorted_children[0]['outcome_reward'] > collect_filter_threshold: # TODO: filter with visit count prefer_pair = { - "path": "ки\n".join(collect_curr_node["path"]), - "good": sorted_children[-1]["path"][-1], - "good_score": sorted_children[-1]["outcome_reward"], - "bad": sorted_children[0]["path"][-1], - "bad_score": sorted_children[0]["outcome_reward"], + 'path': 'ки\n'.join(collect_curr_node['path']), + 'good': sorted_children[-1]['path'][-1], + 'good_score': sorted_children[-1]['outcome_reward'], + 'bad': sorted_children[0]['path'][-1], + 'bad_score': sorted_children[0]['outcome_reward'], } _prefer_pairs.append(prefer_pair) - if strtobool(collect_curr_node["terminated"]): + if strtobool(collect_curr_node['terminated']): _answer = { - "answer": "ки\n".join(collect_curr_node["path"]), - "mean_outcome_reward": np.mean(_outcome_rewards), - "min_outcome_reward": np.min(_outcome_rewards), - "mean_process_reward": np.mean(_process_rewards), - "min_process_reward": np.min(_process_rewards), + 'answer': 'ки\n'.join(collect_curr_node['path']), + 'mean_outcome_reward': np.mean(_outcome_rewards), + 'min_outcome_reward': np.min(_outcome_rewards), + 'mean_process_reward': np.mean(_process_rewards), + 'min_process_reward': np.min(_process_rewards), } - if strtobool(collect_curr_node["correct"]): + if strtobool(collect_curr_node['correct']): _correct_answers.append(_answer) else: _incorrect_answers.append(_answer) @@ -153,4 +156,4 @@ def _collect(collect_curr_node, _outcome_rewards: list[float], _process_rewards: _root = monte_carlo_tree prefer_pairs, correct_answers, incorrect_answers = _collect(_root, [], []) - return prefer_pairs, correct_answers, incorrect_answers \ No newline at end of file + return prefer_pairs, correct_answers, incorrect_answers diff --git a/swift/plugin/orm.py b/swift/plugin/orm.py index ea0e378c24..fbd5b9f7fb 100644 --- a/swift/plugin/orm.py +++ b/swift/plugin/orm.py @@ -185,7 +185,7 @@ def check_terminate(answers: Union[str, List[str]]) -> List[bool]: answers = [answers] results = [] for answer in answers: - results.append("\\boxed" in answer) + results.append('\\boxed' in answer) return results @staticmethod diff --git a/swift/plugin/prm.py b/swift/plugin/prm.py index 91fc872b5e..8847f9af26 100644 --- a/swift/plugin/prm.py +++ b/swift/plugin/prm.py @@ -105,7 +105,8 @@ def infer(self, infer_requests: List[InferRequest], ground_truths: List[str], class ClientPRM(PRM): - def __init__(self, api_key = None, base_url = None, model = None): + + def __init__(self, api_key=None, base_url=None, model=None): super().__init__() from swift.llm import InferClient import os From 626b23f71e5f08555df6b9a4824b675cbef72a46 Mon Sep 17 00:00:00 2001 From: LiuXL Date: Sat, 1 Feb 2025 13:29:20 +0800 Subject: [PATCH 48/52] fix --- swift/llm/sampling/mcts.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/swift/llm/sampling/mcts.py b/swift/llm/sampling/mcts.py index 42844027bd..a92b2dff7b 100644 --- a/swift/llm/sampling/mcts.py +++ b/swift/llm/sampling/mcts.py @@ -320,15 +320,13 @@ def _back_propagate(back_curr_node: LanguageNode): _args = self.args system_message = [] + _args.system_message sep_token = _args.stop_words[0] + '\n' - collect_filter_threshold = _args.collect_filter_threshold _root = LanguageNode(sep_token=sep_token) prompt_message = { 'role': 'user', 'content': query, } - rollout_correct_answers, rollout_incorrect_answers, prefer_pairs, terminated_nodes = [], [], [], [] - too_easy, too_hard = False, False + rollout_correct_answers, rollout_incorrect_answers, terminated_nodes = [], [], [] iter_count = 0 stop_reason = None while True: @@ -364,8 +362,8 @@ def _back_propagate(back_curr_node: LanguageNode): break iter_count += 1 logger.info(f'stop_reason: {stop_reason}') - #logger.info(f"rollout_correct_answers: {rollout_correct_answers}") - #logger.info(f"rollout_incorrect_answers: {rollout_incorrect_answers}") + # logger.info(f"rollout_correct_answers: {rollout_correct_answers}") + # logger.info(f"rollout_incorrect_answers: {rollout_incorrect_answers}") monte_carlo_tree = _root.collect() result = { From 55c061ef7c155eeba61d779da990b9c8189c2080 Mon Sep 17 00:00:00 2001 From: LiuXL Date: Sat, 1 Feb 2025 16:41:19 +0800 Subject: [PATCH 49/52] examples --- examples/sampler/mcts.py | 115 +++++++++++++++++++++++++++++ examples/sampler/mcts.sh | 15 ++++ examples/sampler/system_prompt.txt | 7 ++ 3 files changed, 137 insertions(+) create mode 100644 examples/sampler/mcts.py create mode 100644 examples/sampler/mcts.sh create mode 100644 examples/sampler/system_prompt.txt diff --git a/examples/sampler/mcts.py b/examples/sampler/mcts.py new file mode 100644 index 0000000000..3eeeebf04a --- /dev/null +++ b/examples/sampler/mcts.py @@ -0,0 +1,115 @@ +import os +import json +from modelscope.msdatasets import MsDataset +import subprocess +import time +from typing import List + +from pyarrow.dataset import dataset + +conda_prefix = '' + + +def client_sample(model: str, orm: str, dataset_path: str, iter: int, device_count: int, output_dir: str): + handlers = [] + # Sampling cache + api_key = os.getenv('DASHSCOPE_API_KEY') + + for device in range(device_count): + + output_file = f"iter_{iter}_proc_{device}.jsonl" + cache_file = f"iter_{iter}_proc_{device}_cache.jsonl" + dataset = f"train_{device:02}.jsonl" + + # output_file_path = os.path.join(output_dir, output_file) + cache_file_path = os.path.join(output_dir, cache_file) + single_dataset_path = os.path.join(dataset_path, dataset) + + if not os.path.exists(cache_file_path): + open(cache_file_path, 'w').close() + sample_cmd = (f'USE_OPENCOMPASS_EVALUATOR=True ' + f'swift sample ' + f'--model {model} ' + f'--orm_model {orm} ' + f'--sampler_type mcts ' + f'--process_reward_rate 0 ' + f'--stop_words ки ' + f'--seed 42 ' + f'--api_key {api_key} ' + f'--dataset {single_dataset_path} ' + f'--max_length 2048 ' + f'--system ./scripts/sampler/system_prompt.txt ' + f'--load_args false ' + f'--sampler_engine client ' + f'--max_new_tokens 768 ' + f'--override_exist_file true ' + f'--num_sampling_per_gpu_batch_size 1 ' + f'--num_return_sequences 16 ' + f'--max_iterations 100 ' + f'--output_dir {output_dir} ' + f'--cache_files {cache_file} ' + f'--output_file {output_file} ' + f'--temperature 1.0 ') + print(f'Sampling caches of iter {iter}, part {device}.', flush=True) + env = os.environ.copy() + # env['CUDA_VISIBLE_DEVICES'] = str(device) + handler = subprocess.Popen( + f'{sample_cmd}' + f' > mcts_logs/sample_iter_{iter}_proc_{device}_cache.log 2>&1', + env=os.environ.copy(), + shell=True, + executable='/bin/bash') + handlers.append(handler) + + datasets = [] + for proc, handler in enumerate(handlers): + handler.wait() + assert os.path.exists(os.path.join(output_dir, f'iter_{iter}_proc_{proc}.jsonl')) + datasets.append(os.path.join('sample_output', f'iter_{iter}_proc_{proc}.jsonl')) + print(f'Sampling done, files:{datasets}', flush=True) + return datasets + + +def split_dataset(ds, split_size, out_path): + data_size = int(len(ds) / split_size) + 1 + + for i in range(split_size): + file_name = f'train_{i:02}.jsonl' + file_path = os.path.join(out_path, file_name) + print(file_path) + ds_split = ds[data_size * i: min(data_size * (i + 1), len(ds))] + print(f"split_size: {len(ds_split['problem'])}") + with open(file_path, 'w', encoding='utf-8') as file: + for problem, solution in zip(ds_split['problem'], ds_split['solution']): + message = { + 'messages': [{ + 'role': 'user', + 'content': problem, + }, { + 'role': 'assistant', + 'content': solution, + }, ] + } + file.write(json.dumps(message, ensure_ascii=False) + '\n') + + +def main(): + server_model = 'qwen-max' + orm = 'math' + device_count = 20 + output_dir = 'output/sampler/client_mcts/' + dataset_dir = 'datasets/competition_math/' + log_dir = 'mcts_logs/' + + os.makedirs(output_dir, exist_ok=True) + os.makedirs(dataset_dir, exist_ok=True) + os.makedirs(log_dir, exist_ok=True) + ds = MsDataset.load('tastelikefeet/competition_math', subset_name='default', split='train') + split_dataset(ds, device_count, dataset_dir) + + ts = time.time() + client_data = client_sample(server_model, orm, dataset_dir, 0, device_count, output_dir) + print(f'do sample cost: {(time.time() - ts) / 60:.1f} minutes.', flush=True) + + +if __name__ == '__main__': + main() diff --git a/examples/sampler/mcts.sh b/examples/sampler/mcts.sh new file mode 100644 index 0000000000..c9a0638ccc --- /dev/null +++ b/examples/sampler/mcts.sh @@ -0,0 +1,15 @@ +export CUDA_VISIBLE_DEVICES=2 +export USE_OPENCOMPASS_EVALUATOR=True + +swift sample \ + --model ./output/Qwen2.5-Math-7B-Instruct/v40-20250126-161112/checkpoint-20 \ + --orm_model math \ + --sampler_type mcts \ + --sampler_engine vllm \ + --output_dir ./output/sampler/vllm_mcts \ + --system ./examples/sampler/system_prompt.txt \ + --stop_words ки \ + --dataset ./datasets/competition_math/small_test.jsonl \ + --num_return_sequences 2 \ + --process_reward_rate 0 \ + --max_new_tokens 2048 \ No newline at end of file diff --git a/examples/sampler/system_prompt.txt b/examples/sampler/system_prompt.txt new file mode 100644 index 0000000000..5bdec31ef5 --- /dev/null +++ b/examples/sampler/system_prompt.txt @@ -0,0 +1,7 @@ +You are a math model, you should **think step by step** carefully. Each step should **end with \"ки\”**. Final answer should be in a ‘\boxed()’. + +## Example: +Step1: XXX. ки\n +Step2: XXX. ки\n +Step3: XXX. ки\n +Answer: \boxed(answer). ки\n \ No newline at end of file From e354c3213c5610e52cbeeae5840dea4ec827f513 Mon Sep 17 00:00:00 2001 From: LiuXL Date: Sat, 1 Feb 2025 16:43:26 +0800 Subject: [PATCH 50/52] pre commit --- examples/sampler/mcts.py | 34 +++++++++++++++--------------- examples/sampler/mcts.sh | 2 +- examples/sampler/system_prompt.txt | 2 +- 3 files changed, 19 insertions(+), 19 deletions(-) diff --git a/examples/sampler/mcts.py b/examples/sampler/mcts.py index 3eeeebf04a..69809fb993 100644 --- a/examples/sampler/mcts.py +++ b/examples/sampler/mcts.py @@ -1,11 +1,10 @@ import os -import json -from modelscope.msdatasets import MsDataset import subprocess import time from typing import List -from pyarrow.dataset import dataset +import json +from modelscope.msdatasets import MsDataset conda_prefix = '' @@ -17,9 +16,9 @@ def client_sample(model: str, orm: str, dataset_path: str, iter: int, device_cou for device in range(device_count): - output_file = f"iter_{iter}_proc_{device}.jsonl" - cache_file = f"iter_{iter}_proc_{device}_cache.jsonl" - dataset = f"train_{device:02}.jsonl" + output_file = f'iter_{iter}_proc_{device}.jsonl' + cache_file = f'iter_{iter}_proc_{device}_cache.jsonl' + dataset = f'train_{device:02}.jsonl' # output_file_path = os.path.join(output_dir, output_file) cache_file_path = os.path.join(output_dir, cache_file) @@ -51,7 +50,6 @@ def client_sample(model: str, orm: str, dataset_path: str, iter: int, device_cou f'--output_file {output_file} ' f'--temperature 1.0 ') print(f'Sampling caches of iter {iter}, part {device}.', flush=True) - env = os.environ.copy() # env['CUDA_VISIBLE_DEVICES'] = str(device) handler = subprocess.Popen( f'{sample_cmd}' + f' > mcts_logs/sample_iter_{iter}_proc_{device}_cache.log 2>&1', @@ -66,7 +64,6 @@ def client_sample(model: str, orm: str, dataset_path: str, iter: int, device_cou assert os.path.exists(os.path.join(output_dir, f'iter_{iter}_proc_{proc}.jsonl')) datasets.append(os.path.join('sample_output', f'iter_{iter}_proc_{proc}.jsonl')) print(f'Sampling done, files:{datasets}', flush=True) - return datasets def split_dataset(ds, split_size, out_path): @@ -76,18 +73,21 @@ def split_dataset(ds, split_size, out_path): file_name = f'train_{i:02}.jsonl' file_path = os.path.join(out_path, file_name) print(file_path) - ds_split = ds[data_size * i: min(data_size * (i + 1), len(ds))] + ds_split = ds[data_size * i:min(data_size * (i + 1), len(ds))] print(f"split_size: {len(ds_split['problem'])}") with open(file_path, 'w', encoding='utf-8') as file: for problem, solution in zip(ds_split['problem'], ds_split['solution']): message = { - 'messages': [{ - 'role': 'user', - 'content': problem, - }, { - 'role': 'assistant', - 'content': solution, - }, ] + 'messages': [ + { + 'role': 'user', + 'content': problem, + }, + { + 'role': 'assistant', + 'content': solution, + }, + ] } file.write(json.dumps(message, ensure_ascii=False) + '\n') @@ -107,7 +107,7 @@ def main(): split_dataset(ds, device_count, dataset_dir) ts = time.time() - client_data = client_sample(server_model, orm, dataset_dir, 0, device_count, output_dir) + client_sample(server_model, orm, dataset_dir, 0, device_count, output_dir) print(f'do sample cost: {(time.time() - ts) / 60:.1f} minutes.', flush=True) diff --git a/examples/sampler/mcts.sh b/examples/sampler/mcts.sh index c9a0638ccc..01bc031730 100644 --- a/examples/sampler/mcts.sh +++ b/examples/sampler/mcts.sh @@ -12,4 +12,4 @@ swift sample \ --dataset ./datasets/competition_math/small_test.jsonl \ --num_return_sequences 2 \ --process_reward_rate 0 \ - --max_new_tokens 2048 \ No newline at end of file + --max_new_tokens 2048 diff --git a/examples/sampler/system_prompt.txt b/examples/sampler/system_prompt.txt index 5bdec31ef5..a52891c4d5 100644 --- a/examples/sampler/system_prompt.txt +++ b/examples/sampler/system_prompt.txt @@ -4,4 +4,4 @@ You are a math model, you should **think step by step** carefully. Each step sho Step1: XXX. ки\n Step2: XXX. ки\n Step3: XXX. ки\n -Answer: \boxed(answer). ки\n \ No newline at end of file +Answer: \boxed(answer). ки\n From 958d3e376c1254c0477406ff1628b577a7381491 Mon Sep 17 00:00:00 2001 From: LiuXL Date: Sat, 1 Feb 2025 19:09:14 +0800 Subject: [PATCH 51/52] less log & change example arg --- examples/sampler/mcts.py | 5 +++-- swift/llm/sampling/mcts.py | 8 ++++---- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/examples/sampler/mcts.py b/examples/sampler/mcts.py index 69809fb993..0fc9ac0958 100644 --- a/examples/sampler/mcts.py +++ b/examples/sampler/mcts.py @@ -43,8 +43,9 @@ def client_sample(model: str, orm: str, dataset_path: str, iter: int, device_cou f'--max_new_tokens 768 ' f'--override_exist_file true ' f'--num_sampling_per_gpu_batch_size 1 ' - f'--num_return_sequences 16 ' - f'--max_iterations 100 ' + f'--num_return_sequences 8 ' + f'--exploration_rate 0.2 ' + f'--max_iterations 200 ' f'--output_dir {output_dir} ' f'--cache_files {cache_file} ' f'--output_file {output_file} ' diff --git a/swift/llm/sampling/mcts.py b/swift/llm/sampling/mcts.py index a92b2dff7b..b2614ab09d 100644 --- a/swift/llm/sampling/mcts.py +++ b/swift/llm/sampling/mcts.py @@ -333,17 +333,17 @@ def _back_propagate(back_curr_node: LanguageNode): logger.info(f'iter_count: {iter_count}' + '.' * 10) s_time = time.time() curr_node = _select(_root) - logger.info('select' + '=' * 10 + f'time: {time.time() - s_time}') + logger.debug('select' + '=' * 10 + f'time: {time.time() - s_time}') s_time = time.time() _expand(curr_node) - logger.info('expand' + '=' * 10 + f'time: {time.time() - s_time}') + logger.debug('expand' + '=' * 10 + f'time: {time.time() - s_time}') if curr_node.depth > _args.rollout_start_depth: s_time = time.time() _rollout(curr_node) - logger.info('rollout' + '=' * 10 + f'time: {time.time() - s_time}') + logger.debug('rollout' + '=' * 10 + f'time: {time.time() - s_time}') s_time = time.time() _back_propagate(curr_node) - logger.info('back propagate' + '=' * 10 + f'time: {time.time() - s_time}') + logger.debug('back propagate' + '=' * 10 + f'time: {time.time() - s_time}') if len(rollout_correct_answers) + len(rollout_incorrect_answers) >= 2 * _args.num_return_sequences: if 4 * len(rollout_incorrect_answers) < len(rollout_correct_answers): stop_reason = 'too easy' From 54452395cbd009f165332864f687b904f5b8ea93 Mon Sep 17 00:00:00 2001 From: LiuXL Date: Sat, 8 Feb 2025 18:43:08 +0800 Subject: [PATCH 52/52] fix --- ...44\350\241\214\345\217\202\346\225\260.md" | 11 +++++++++ .../Instruction/Command-line-parameters.md | 9 +++++++ examples/sampler/mcts.sh | 24 +++++++++++++++++-- swift/llm/argument/sampling_args.py | 1 - swift/llm/sampling/mcts.py | 17 +++++++++---- swift/llm/sampling/utils.py | 5 +++- 6 files changed, 58 insertions(+), 9 deletions(-) diff --git "a/docs/source/Instruction/\345\221\275\344\273\244\350\241\214\345\217\202\346\225\260.md" "b/docs/source/Instruction/\345\221\275\344\273\244\350\241\214\345\217\202\346\225\260.md" index 9df1dabc48..bf5f4910d5 100644 --- "a/docs/source/Instruction/\345\221\275\344\273\244\350\241\214\345\217\202\346\225\260.md" +++ "b/docs/source/Instruction/\345\221\275\344\273\244\350\241\214\345\217\202\346\225\260.md" @@ -431,6 +431,8 @@ App参数继承于[部署参数](#部署参数), [Web-UI参数](#Web-UI参数) - prm_model: 过程奖励模型的类型,可以是模型id(以pt方式拉起),或者plugin中定义的prm key(自定义推理过程) - orm_model: 结果奖励模型的类型,通常是通配符或测试用例等,一般定义在plugin中 +- sampler_type:采样类型,目前支持 sample, mcts,未来会支持 dvts +- sampler_engine:支持`pt`, `lmdeploy`, `vllm`, `client`, `no`,默认为`pt`,采样模型的推理引擎 - sampler_type:采样类型,目前支持sample(do_sample方式),未来会支持mcts和dvts - sampler_engine:支持`pt`, `lmdeploy`, `vllm`, `no`,默认为`pt`,采样模型的推理引擎 - output_dir:输出目录,默认为`sample_output` @@ -448,6 +450,15 @@ App参数继承于[部署参数](#部署参数), [Web-UI参数](#Web-UI参数) - cache_files:为避免同时加载prm和generator造成显存OOM,可以分两步进行采样,第一步将prm和orm置为`None`,则所有结果都会输出到文件中,第二次运行采样将sampler_engine置为`no`并传入`--cache_files`为上次采样的输出文件,则会使用上次输出的结果进行prm和orm评估并输出最终结果。 - 注意:使用cache_files时,`--dataset`仍然需要传入,这是因为cache_files的id是由原始数据计算的md5,需要把两部分信息结合使用。 +#### MCTS +- rollout_depth:rollout 时的最大深度,默认为 `5` +- rollout_start_depth:开始 rollout 时的深度,低于此深度的节点只会进行 expand 操作,默认为 `3` +- max_iterations:mcts 的最大迭代次数,默认为 `100` +- process_reward_rate:select 中计算 value 时 process reward 占的比例,默认为 `0.0`,即不使用 PRM +- exploration_rate:UCT 算法中的探索参数,值越大越照顾探索次数较小的节点,默认为 `0.5` +- api_key:使用 client 作为推理引擎时需要,默认为 `EMPTY` +- base_url:使用 client 作为推理引擎时需要,默认为 'https://dashscope.aliyuncs.com/compatible-mode/v1' + ## 特定模型参数 特定模型参数可以通过`--model_kwargs`或者环境变量进行设置,例如: `--model_kwargs '{"fps_max_frames": 12}'`或者`FPS_MAX_FRAMES=12` diff --git a/docs/source_en/Instruction/Command-line-parameters.md b/docs/source_en/Instruction/Command-line-parameters.md index 8507b9798a..685c5e3f27 100644 --- a/docs/source_en/Instruction/Command-line-parameters.md +++ b/docs/source_en/Instruction/Command-line-parameters.md @@ -458,6 +458,15 @@ Export Arguments include the [basic arguments](#base-arguments) and [merge argum - cache_files: To avoid loading both `prm` and `generator` simultaneously and causing GPU memory OOM, sampling can be done in two steps. In the first step, set `prm` and `orm` to `None`, and all results will be output to a file. In the second run, set `sampler_engine` to `no` and pass `--cache_files` with the output file from the first sampling. This will use the results from the first run for `prm` and `orm` evaluation and output the final results. - Note: When using `cache_files`, the `--dataset` still needs to be provided because the ID for `cache_files` is calculated using the MD5 of the original data. Both pieces of information need to be used together. +#### MCTS +- rollout_depth: The maximum depth during rollouts, default is `5`. +- rollout_start_depth: The depth at which rollouts begin; nodes below this depth will only undergo expand operations, default is `3`. +- max_iterations: The maximum number of iterations for MCTS, default is `100`. +- process_reward_rate: The proportion of process reward used in calculating value during selection, default is `0.0`, meaning PRM is not used. +- exploration_rate: A parameter in the UCT algorithm that balances exploration; a higher value gives more weight to nodes with fewer explorations, default is `0.5`. +- api_key: Required when using the client as an inference engine, default is `EMPTY`. +- base_url: Required when using the client as an inference engine, default is 'https://dashscope.aliyuncs.com/compatible-mode/v1'. + ## Specific Model Arguments Specific model arguments can be set using `--model_kwargs` or environment variables, for example: `--model_kwargs '{"fps_max_frames": 12}'` or `FPS_MAX_FRAMES=12`. diff --git a/examples/sampler/mcts.sh b/examples/sampler/mcts.sh index 01bc031730..6b91ab10b9 100644 --- a/examples/sampler/mcts.sh +++ b/examples/sampler/mcts.sh @@ -1,4 +1,4 @@ -export CUDA_VISIBLE_DEVICES=2 +export CUDA_VISIBLE_DEVICES=0 export USE_OPENCOMPASS_EVALUATOR=True swift sample \ @@ -6,10 +6,30 @@ swift sample \ --orm_model math \ --sampler_type mcts \ --sampler_engine vllm \ - --output_dir ./output/sampler/vllm_mcts \ + --output_dir ./output/sampler/mcts \ --system ./examples/sampler/system_prompt.txt \ --stop_words ки \ --dataset ./datasets/competition_math/small_test.jsonl \ --num_return_sequences 2 \ --process_reward_rate 0 \ --max_new_tokens 2048 + +## Train +# nproc_per_node=8 +# NPROC_PER_NODE=$nproc_per_node \ +# swift sft \ +# --model Qwen/Qwen2.5-Math-7B-Instruct \ +# --train_type full \ +# --torch_dtype bfloat16 \ +# --dataset 'datasets/gen_V5.jsonl' \ +# --num_train_epochs 1 \ +# --per_device_train_batch_size 1 \ +# --learning_rate 1e-5 \ +# --gradient_accumulation_steps $(expr 128 / $nproc_per_node) \ +# --eval_steps 1000 \ +# --save_steps 10 \ +# --save_total_limit 100 \ +# --max_length 10000 \ +# --logging_steps 5 \ +# --gradient_checkpointing_kwargs '{"use_reentrant": false}' \ +# --deepspeed zero3 diff --git a/swift/llm/argument/sampling_args.py b/swift/llm/argument/sampling_args.py index 29977ba24e..e269f21ba9 100644 --- a/swift/llm/argument/sampling_args.py +++ b/swift/llm/argument/sampling_args.py @@ -49,7 +49,6 @@ class SamplingArguments(BaseArguments): max_iterations: int = 100 process_reward_rate: float = 0.0 exploration_rate: float = 0.5 - collect_filter_threshold: float = 0.5 api_key: str = 'EMPTY' base_url: str = 'https://dashscope.aliyuncs.com/compatible-mode/v1' diff --git a/swift/llm/sampling/mcts.py b/swift/llm/sampling/mcts.py index b2614ab09d..6dc94b2670 100644 --- a/swift/llm/sampling/mcts.py +++ b/swift/llm/sampling/mcts.py @@ -188,7 +188,8 @@ def _expand(expand_curr_node: LanguageNode): infer_requests = [infer_request for _ in range(n)] # e_time = time.time() - # 为了并行进行 Expand 操作,这里暂时不需要考虑顺序,因为 Prompt 是一样的 + # To perform the Expand operation in parallel, + # there's no need to consider the order for now, since the Prompt is the same. expand_iter_index = 0 while True: responses = perform_infer(self.infer_engine, infer_requests, self.expand_request_configs, @@ -200,9 +201,10 @@ def _expand(expand_curr_node: LanguageNode): expand_iter_index += 1 # logger.info(f"expand.expand time: {time.time() - e_time}") - # 为了并行获取 Outcome Reward,这里获得的 OR 是顺序返回的,所以可以直接对应 + # To fetch Outcome Reward in parallel, + # the Outcome-Reward obtained is returned in order, so they can be directly matched accordingly. orm_infer_requests = [] - unique_output = set() # 用于去重 + unique_output = set() for response in responses: self.update_usage_info(response) output = response.choices[0].message.content.rstrip(sep_token + '\n').split(sep_token)[0] @@ -308,8 +310,13 @@ def _rollout(rollout_curr_node: LanguageNode): def _back_propagate(back_curr_node: LanguageNode): while back_curr_node: - best_child_value = max([child.outcome_reward for child in back_curr_node.children]) - back_curr_node.init_and_update_value(best_child_value) + if back_curr_node == curr_node: + best_child_value = max([child.outcome_reward for child in back_curr_node.children]) + back_curr_node.init_and_update_value(best_child_value) + last_child_value = back_curr_node.outcome_reward + else: + back_curr_node.init_and_update_value(last_child_value) + last_child_value = back_curr_node.outcome_reward back_curr_node.visit() if len(back_curr_node.active_children) == 0: back_curr_node.terminated = True diff --git a/swift/llm/sampling/utils.py b/swift/llm/sampling/utils.py index ce10f38704..5b9b25ab1c 100644 --- a/swift/llm/sampling/utils.py +++ b/swift/llm/sampling/utils.py @@ -6,6 +6,9 @@ import numpy as np from swift.llm import InferRequest, Messages, RequestConfig +from swift.utils import get_logger + +logger = get_logger() def get_messages_md5(messages: Messages): @@ -86,7 +89,7 @@ def perform_infer(infer_engines, infer_requests, request_configs, **infer_kwargs try: responses += future.result() except Exception as e: - print(f'任务 {task_id} 执行请求时发生错误: {e}') + logger.info(f'Perform infer task: {task_id} get an error: {e}') return responses elif isinstance(infer_requests, list): responses = []