diff --git "a/docs/source/Instruction/\345\221\275\344\273\244\350\241\214\345\217\202\346\225\260.md" "b/docs/source/Instruction/\345\221\275\344\273\244\350\241\214\345\217\202\346\225\260.md" index 9df1dabc48..bf5f4910d5 100644 --- "a/docs/source/Instruction/\345\221\275\344\273\244\350\241\214\345\217\202\346\225\260.md" +++ "b/docs/source/Instruction/\345\221\275\344\273\244\350\241\214\345\217\202\346\225\260.md" @@ -431,6 +431,8 @@ App参数继承于[部署参数](#部署参数), [Web-UI参数](#Web-UI参数) - prm_model: 过程奖励模型的类型,可以是模型id(以pt方式拉起),或者plugin中定义的prm key(自定义推理过程) - orm_model: 结果奖励模型的类型,通常是通配符或测试用例等,一般定义在plugin中 +- sampler_type:采样类型,目前支持 sample, mcts,未来会支持 dvts +- sampler_engine:支持`pt`, `lmdeploy`, `vllm`, `client`, `no`,默认为`pt`,采样模型的推理引擎 - sampler_type:采样类型,目前支持sample(do_sample方式),未来会支持mcts和dvts - sampler_engine:支持`pt`, `lmdeploy`, `vllm`, `no`,默认为`pt`,采样模型的推理引擎 - output_dir:输出目录,默认为`sample_output` @@ -448,6 +450,15 @@ App参数继承于[部署参数](#部署参数), [Web-UI参数](#Web-UI参数) - cache_files:为避免同时加载prm和generator造成显存OOM,可以分两步进行采样,第一步将prm和orm置为`None`,则所有结果都会输出到文件中,第二次运行采样将sampler_engine置为`no`并传入`--cache_files`为上次采样的输出文件,则会使用上次输出的结果进行prm和orm评估并输出最终结果。 - 注意:使用cache_files时,`--dataset`仍然需要传入,这是因为cache_files的id是由原始数据计算的md5,需要把两部分信息结合使用。 +#### MCTS +- rollout_depth:rollout 时的最大深度,默认为 `5` +- rollout_start_depth:开始 rollout 时的深度,低于此深度的节点只会进行 expand 操作,默认为 `3` +- max_iterations:mcts 的最大迭代次数,默认为 `100` +- process_reward_rate:select 中计算 value 时 process reward 占的比例,默认为 `0.0`,即不使用 PRM +- exploration_rate:UCT 算法中的探索参数,值越大越照顾探索次数较小的节点,默认为 `0.5` +- api_key:使用 client 作为推理引擎时需要,默认为 `EMPTY` +- base_url:使用 client 作为推理引擎时需要,默认为 'https://dashscope.aliyuncs.com/compatible-mode/v1' + ## 特定模型参数 特定模型参数可以通过`--model_kwargs`或者环境变量进行设置,例如: `--model_kwargs '{"fps_max_frames": 12}'`或者`FPS_MAX_FRAMES=12` diff --git a/docs/source_en/Instruction/Command-line-parameters.md b/docs/source_en/Instruction/Command-line-parameters.md index 8507b9798a..685c5e3f27 100644 --- a/docs/source_en/Instruction/Command-line-parameters.md +++ b/docs/source_en/Instruction/Command-line-parameters.md @@ -458,6 +458,15 @@ Export Arguments include the [basic arguments](#base-arguments) and [merge argum - cache_files: To avoid loading both `prm` and `generator` simultaneously and causing GPU memory OOM, sampling can be done in two steps. In the first step, set `prm` and `orm` to `None`, and all results will be output to a file. In the second run, set `sampler_engine` to `no` and pass `--cache_files` with the output file from the first sampling. This will use the results from the first run for `prm` and `orm` evaluation and output the final results. - Note: When using `cache_files`, the `--dataset` still needs to be provided because the ID for `cache_files` is calculated using the MD5 of the original data. Both pieces of information need to be used together. +#### MCTS +- rollout_depth: The maximum depth during rollouts, default is `5`. +- rollout_start_depth: The depth at which rollouts begin; nodes below this depth will only undergo expand operations, default is `3`. +- max_iterations: The maximum number of iterations for MCTS, default is `100`. +- process_reward_rate: The proportion of process reward used in calculating value during selection, default is `0.0`, meaning PRM is not used. +- exploration_rate: A parameter in the UCT algorithm that balances exploration; a higher value gives more weight to nodes with fewer explorations, default is `0.5`. +- api_key: Required when using the client as an inference engine, default is `EMPTY`. 
diff --git a/docs/source_en/Instruction/Command-line-parameters.md b/docs/source_en/Instruction/Command-line-parameters.md
index 8507b9798a..685c5e3f27 100644
--- a/docs/source_en/Instruction/Command-line-parameters.md
+++ b/docs/source_en/Instruction/Command-line-parameters.md
@@ -458,6 +458,15 @@ Export Arguments include the [basic arguments](#base-arguments) and [merge arguments](#merge-arguments)
 - cache_files: To avoid loading both `prm` and `generator` simultaneously and causing GPU memory OOM, sampling can be done in two steps. In the first step, set `prm` and `orm` to `None`, and all results will be output to a file. In the second run, set `sampler_engine` to `no` and pass `--cache_files` with the output file from the first sampling. This will use the results from the first run for `prm` and `orm` evaluation and output the final results.
 - Note: When using `cache_files`, `--dataset` still needs to be provided because the ID for `cache_files` is calculated using the MD5 of the original data. Both pieces of information need to be used together.
+#### MCTS
+- rollout_depth: The maximum depth during rollouts, default is `5`.
+- rollout_start_depth: The depth at which rollouts begin; nodes below this depth will only undergo expand operations, default is `3`.
+- max_iterations: The maximum number of iterations for MCTS, default is `100`.
+- process_reward_rate: The proportion of the process reward used when calculating the value during selection, default is `0.0`, meaning the PRM is not used.
+- exploration_rate: The exploration parameter of the UCT algorithm; a higher value gives more weight to nodes with fewer visits, default is `0.5`.
+- api_key: Required when using the client as the inference engine, default is `EMPTY`.
+- base_url: Required when using the client as the inference engine, default is `https://dashscope.aliyuncs.com/compatible-mode/v1`.
+

 ## Specific Model Arguments
 Specific model arguments can be set using `--model_kwargs` or environment variables, for example: `--model_kwargs '{"fps_max_frames": 12}'` or `FPS_MAX_FRAMES=12`.
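The `exploration_rate` bullet above describes UCT selection. The sketch below restates the score that the `_uct` helper in `swift/llm/sampling/mcts.py` (later in this diff) computes; the visit counts and value in the calls are invented for illustration:

```python
import numpy as np

def uct_score(node_value: float, node_visits: int, parent_visits: int,
              exploration_rate: float = 0.5) -> float:
    # Exploitation: the node's blended value, i.e.
    #   process_reward_rate * process_reward + (1 - process_reward_rate) * outcome_reward.
    # Exploration: grows for rarely visited children of well-visited parents.
    exploration = exploration_rate * np.sqrt(np.log(parent_visits + 1) / (node_visits + 1))
    return node_value + exploration

# A child visited once under a parent visited ten times receives a larger
# exploration bonus than a sibling already visited five times:
print(uct_score(0.6, node_visits=1, parent_visits=10))  # 0.6 + 0.5*sqrt(ln(11)/2)
print(uct_score(0.6, node_visits=5, parent_visits=10))  # smaller exploration term
```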
diff --git a/examples/sampler/mcts.py b/examples/sampler/mcts.py
new file mode 100644
index 0000000000..0fc9ac0958
--- /dev/null
+++ b/examples/sampler/mcts.py
@@ -0,0 +1,113 @@
+import os
+import subprocess
+import time
+
+import json
+from modelscope.msdatasets import MsDataset
+
+
+def client_sample(model: str, orm: str, dataset_path: str, iter: int, device_count: int, output_dir: str):
+    handlers = []
+    # Sampling cache
+    api_key = os.getenv('DASHSCOPE_API_KEY')
+
+    for device in range(device_count):
+
+        output_file = f'iter_{iter}_proc_{device}.jsonl'
+        cache_file = f'iter_{iter}_proc_{device}_cache.jsonl'
+        dataset = f'train_{device:02}.jsonl'
+
+        # output_file_path = os.path.join(output_dir, output_file)
+        cache_file_path = os.path.join(output_dir, cache_file)
+        single_dataset_path = os.path.join(dataset_path, dataset)
+
+        if not os.path.exists(cache_file_path):
+            open(cache_file_path, 'w').close()
+        sample_cmd = (f'USE_OPENCOMPASS_EVALUATOR=True '
+                      f'swift sample '
+                      f'--model {model} '
+                      f'--orm_model {orm} '
+                      f'--sampler_type mcts '
+                      f'--process_reward_rate 0 '
+                      f'--stop_words ки '
+                      f'--seed 42 '
+                      f'--api_key {api_key} '
+                      f'--dataset {single_dataset_path} '
+                      f'--max_length 2048 '
+                      f'--system ./examples/sampler/system_prompt.txt '
+                      f'--load_args false '
+                      f'--sampler_engine client '
+                      f'--max_new_tokens 768 '
+                      f'--override_exist_file true '
+                      f'--num_sampling_per_gpu_batch_size 1 '
+                      f'--num_return_sequences 8 '
+                      f'--exploration_rate 0.2 '
+                      f'--max_iterations 200 '
+                      f'--output_dir {output_dir} '
+                      f'--cache_files {cache_file} '
+                      f'--output_file {output_file} '
+                      f'--temperature 1.0 ')
+        print(f'Sampling caches of iter {iter}, part {device}.', flush=True)
+        # env['CUDA_VISIBLE_DEVICES'] = str(device)
+        handler = subprocess.Popen(
+            f'{sample_cmd}' + f' > mcts_logs/sample_iter_{iter}_proc_{device}_cache.log 2>&1',
+            env=os.environ.copy(),
+            shell=True,
+            executable='/bin/bash')
+        handlers.append(handler)
+
+    datasets = []
+    for proc, handler in enumerate(handlers):
+        handler.wait()
+        assert os.path.exists(os.path.join(output_dir, f'iter_{iter}_proc_{proc}.jsonl'))
+        datasets.append(os.path.join(output_dir, f'iter_{iter}_proc_{proc}.jsonl'))
+    print(f'Sampling done, files: {datasets}', flush=True)
+
+
+def split_dataset(ds, split_size, out_path):
+    data_size = int(len(ds) / split_size) + 1
+
+    for i in range(split_size):
+        file_name = f'train_{i:02}.jsonl'
+        file_path = os.path.join(out_path, file_name)
+        print(file_path)
+        ds_split = ds[data_size * i:min(data_size * (i + 1), len(ds))]
+        print(f"split_size: {len(ds_split['problem'])}")
+        with open(file_path, 'w', encoding='utf-8') as file:
+            for problem, solution in zip(ds_split['problem'], ds_split['solution']):
+                message = {
+                    'messages': [
+                        {
+                            'role': 'user',
+                            'content': problem,
+                        },
+                        {
+                            'role': 'assistant',
+                            'content': solution,
+                        },
+                    ]
+                }
+                file.write(json.dumps(message, ensure_ascii=False) + '\n')
+
+
+def main():
+    server_model = 'qwen-max'
+    orm = 'math'
+    device_count = 20
+    output_dir = 'output/sampler/client_mcts/'
+    dataset_dir = 'datasets/competition_math/'
+    log_dir = 'mcts_logs/'
+
+    os.makedirs(output_dir, exist_ok=True)
+    os.makedirs(dataset_dir, exist_ok=True)
+    os.makedirs(log_dir, exist_ok=True)
+    ds = MsDataset.load('tastelikefeet/competition_math', subset_name='default', split='train')
+    split_dataset(ds, device_count, dataset_dir)
+
+    ts = time.time()
+    client_sample(server_model, orm, dataset_dir, 0, device_count, output_dir)
+    print(f'do sample cost: {(time.time() - ts) / 60:.1f} minutes.', flush=True)
+
+
+if __name__ == '__main__':
+    main()
diff --git a/examples/sampler/mcts.sh b/examples/sampler/mcts.sh
new file mode 100644
index 0000000000..6b91ab10b9
--- /dev/null
+++ b/examples/sampler/mcts.sh
@@ -0,0 +1,35 @@
+export CUDA_VISIBLE_DEVICES=0
+export USE_OPENCOMPASS_EVALUATOR=True
+
+swift sample \
+    --model ./output/Qwen2.5-Math-7B-Instruct/v40-20250126-161112/checkpoint-20 \
+    --orm_model math \
+    --sampler_type mcts \
+    --sampler_engine vllm \
+    --output_dir ./output/sampler/mcts \
+    --system ./examples/sampler/system_prompt.txt \
+    --stop_words ки \
+    --dataset ./datasets/competition_math/small_test.jsonl \
+    --num_return_sequences 2 \
+    --process_reward_rate 0 \
+    --max_new_tokens 2048
+
+## Train
+# nproc_per_node=8
+# NPROC_PER_NODE=$nproc_per_node \
+# swift sft \
+#     --model Qwen/Qwen2.5-Math-7B-Instruct \
+#     --train_type full \
+#     --torch_dtype bfloat16 \
+#     --dataset 'datasets/gen_V5.jsonl' \
+#     --num_train_epochs 1 \
+#     --per_device_train_batch_size 1 \
+#     --learning_rate 1e-5 \
+#     --gradient_accumulation_steps $(expr 128 / $nproc_per_node) \
+#     --eval_steps 1000 \
+#     --save_steps 10 \
+#     --save_total_limit 100 \
+#     --max_length 10000 \
+#     --logging_steps 5 \
+#     --gradient_checkpointing_kwargs '{"use_reentrant": false}' \
+#     --deepspeed zero3
diff --git a/examples/sampler/system_prompt.txt b/examples/sampler/system_prompt.txt
new file mode 100644
index 0000000000..a52891c4d5
--- /dev/null
+++ b/examples/sampler/system_prompt.txt
@@ -0,0 +1,7 @@
+You are a math model; you should **think step by step** carefully. Each step should **end with "ки"**. The final answer should be in a '\boxed()'.
+
+## Example:
+Step1: XXX. ки\n
+Step2: XXX. ки\n
+Step3: XXX. ки\n
+Answer: \boxed(answer). ки\n
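The prompt above trains the model to close every step with "ки", which the sampler also registers as a stop word and uses as the node separator (`sep_token` in `mcts.py` below). A toy illustration of how one completion decomposes into tree steps; the completion text is made up:

```python
# The "ки" marker delimits reasoning steps: generation stops on the marker and
# everything before it becomes one tree node.
completion = 'Step1: Expand (x+1)^2 = x^2 + 2x + 1. ки\nStep2: Answer: \\boxed(x^2 + 2x + 1). ки\n'
sep_token = 'ки'

steps = [s.strip() for s in completion.split(sep_token) if s.strip()]
print(steps)
# ['Step1: Expand (x+1)^2 = x^2 + 2x + 1.', 'Step2: Answer: \\boxed(x^2 + 2x + 1).']
```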
diff --git a/swift/llm/argument/sampling_args.py b/swift/llm/argument/sampling_args.py
index 839f1ac99e..e269f21ba9 100644
--- a/swift/llm/argument/sampling_args.py
+++ b/swift/llm/argument/sampling_args.py
@@ -1,5 +1,6 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
 import dataclasses
+import os
 from dataclasses import dataclass
 from datetime import datetime
 from typing import List, Literal, Optional
@@ -20,8 +21,8 @@ class SamplingArguments(BaseArguments):
     # sampler settings
     # sample/mcts/dvts/xxx
-    sampler_type: str = 'sample'
-    sampler_engine: Literal['pt', 'lmdeploy', 'vllm', 'no'] = 'pt'
+    sampler_type: Literal['sample', 'mcts'] = 'sample'
+    sampler_engine: Literal['pt', 'lmdeploy', 'vllm', 'no', 'client'] = 'pt'
     output_dir: str = 'sample_output'
     output_file: Optional[str] = None
     override_exist_file: bool = False
@@ -42,6 +43,21 @@ class SamplingArguments(BaseArguments):
     # Vanilla
     cache_files: List[str] = dataclasses.field(default_factory=list)
+    # MCTS
+    rollout_depth: int = 5
+    rollout_start_depth: int = 3
+    max_iterations: int = 100
+    process_reward_rate: float = 0.0
+    exploration_rate: float = 0.5
+    api_key: str = 'EMPTY'
+    base_url: str = 'https://dashscope.aliyuncs.com/compatible-mode/v1'
+
+    def _init_model_info(self):
+        if self.sampler_engine != 'client':
+            return super()._init_model_info()
+        self.task_type = 'causal_lm'
+        return
+
     def __post_init__(self):
         if self.output_file is None:
             now = datetime.now()
@@ -58,4 +74,13 @@ def __post_init__(self):
             self.engine_kwargs = json.loads(self.engine_kwargs)
         else:
             self.engine_kwargs = {}
+        super().__post_init__()
+
+        if self.system is not None:
+            self.system_message = [{
+                'role': 'system',
+                'content': self.system,
+            }]
+        else:
+            self.system_message = []
diff --git a/swift/llm/sampling/mcts.py b/swift/llm/sampling/mcts.py
index 464090415c..6dc94b2670 100644
--- a/swift/llm/sampling/mcts.py
+++ b/swift/llm/sampling/mcts.py
@@ -1 +1,401 @@
-# TODO
+import time
+import traceback
+from copy import deepcopy
+from typing import Optional
+
+import json
+import numpy as np
+
+from swift.llm import InferRequest
+from swift.llm.argument.sampling_args import SamplingArguments
+from swift.llm.infer.protocol import UsageInfo
+from swift.utils import get_logger
+from .base import Sampler
+from .utils import get_reward, perform_infer
+
+logger = get_logger()
+
+NXT_PROMPT = """Continue.
+""" + +next_message = { + 'role': 'user', + 'content': NXT_PROMPT, +} + + +class LanguageNode: + + def __init__( + self, + step: str = None, + sep_token: str = None, + parent: 'LanguageNode' = None, + ): + self.parent = parent + + if sep_token: + self.sep_token = sep_token + else: + self.sep_token = parent.sep_token + + if parent: + self.path = parent.path[:] + [step] + self.answer = parent.answer + step + self.sep_token + self.depth = parent.depth + 1 + else: + self.path = [] + self.answer = '' + self.depth = 0 + + self.active_children = [] + self.children = [] + self.visit_count = 0 + self.process_reward = 0.0 + self.outcome_reward = 0.0 + self.terminated = False + self.correct = False + + def is_leaf(self): + return len(self.children) == 0 + + def is_root(self): + return self.parent is None + + def visit(self): + self.visit_count += 1 + + def init_and_update_value(self, value): + self.outcome_reward = (self.outcome_reward * self.visit_count + value) / (self.visit_count + 1) + + def add_child(self, child: 'LanguageNode'): + self.children.append(child) + if not child.terminated: + self.active_children.append(child) + + def collect(self): + result = { + 'path': self.path, + 'depth': self.depth, + 'visit_count': self.visit_count, + 'process_reward': self.process_reward, + 'outcome_reward': self.outcome_reward, + 'terminated': str(self.terminated), + 'correct': str(self.correct), + 'children': [child.collect() for child in self.children], + } + return result + + def __lt__(self, other): + return self.outcome_reward < other.outcome_reward + + +class MctsSampler(Sampler): + + def __init__(self, input_args: SamplingArguments): + super().__init__(input_args) + self.usage_info = UsageInfo(0, 0, 0) + + def _prepare_model_tokenizer(self): + args = self.args + self.infer_kwargs = {} + if args.sampler_engine == 'client': + from swift.llm import InferClient + api_key = args.api_key + base_url = args.base_url + self.infer_engine = [ + InferClient(base_url=base_url, api_key=api_key) for _ in range(args.num_return_sequences) + ] + self.infer_kwargs['model'] = args.model + else: + _Engine = self.get_infer_engine() + self.infer_engine = _Engine(self.args.model, model_type=self.args.model_type, **self.args.engine_kwargs) + + def get_infer_engine(self): + if self.args.sampler_engine == 'pt': + from swift.llm import PtEngine + _Engine = PtEngine + elif self.args.sampler_engine == 'vllm': + from swift.llm import VllmEngine + _Engine = VllmEngine + elif self.args.sampler_engine == 'lmdeploy': + from swift.llm import LmdeployEngine + _Engine = LmdeployEngine + elif self.args.sampler_engine == 'no': + _Engine = None + else: + raise ValueError(f'Cannot find engine name: {self.args.sampler_engine}') + return _Engine + + def _prepare_template(self) -> None: + # Hack from super() + self._prepare_request_configs() + + def _prepare_request_configs(self): + _args = self.args + request_config = _args.get_request_config() + request_config.stop = _args.stop_words + request_config.seed = _args.seed + self.expand_request_configs = [] + self.rollout_request_configs = [] + for i in range(_args.num_return_sequences): + expand_request_config = deepcopy(request_config) + expand_request_config.n = 1 + expand_request_config.num_beams = expand_request_config.n + expand_request_config.seed += i + self.expand_request_configs.append(expand_request_config) + rollout_request_config = deepcopy(request_config) + rollout_request_config.max_tokens = 500 + rollout_request_config.temperature = 0.0 + rollout_request_config.n = 1 + 
+            self.rollout_request_configs.append(rollout_request_config)
+
+    def update_usage_info(self, response):
+        for key, value in self.usage_info.__dict__.items():
+            update_value = getattr(response.usage, key, 0) + value
+            setattr(self.usage_info, key, update_value)
+
+    def search_single(self, query, ground_truth):
+
+        def _uct(uct_curr_node: LanguageNode):
+            alpha = _args.process_reward_rate
+            value = alpha * uct_curr_node.process_reward + (1 - alpha) * uct_curr_node.outcome_reward
+            if uct_curr_node.is_root():
+                return value
+
+            exploitation_score = value
+            exploration_score = (
+                _args.exploration_rate
+                * np.sqrt(np.log(uct_curr_node.parent.visit_count + 1) / (uct_curr_node.visit_count + 1)))
+
+            return exploration_score + exploitation_score
+
+        def _select(select_curr_node: LanguageNode):
+            while not select_curr_node.is_leaf():
+                select_curr_node = max(select_curr_node.active_children, key=_uct)
+            return select_curr_node
+
+        def _expand(expand_curr_node: LanguageNode):
+            n = _args.num_return_sequences - len(expand_curr_node.children)
+            if expand_curr_node.is_root():
+                infer_requests = [InferRequest(system_message + [prompt_message]) for _ in range(n)]
+            else:
+                history_message = {
+                    'role': 'assistant',
+                    'content': expand_curr_node.answer,
+                }
+                infer_request = InferRequest(system_message + [prompt_message, history_message, next_message])
+                infer_requests = [infer_request for _ in range(n)]
+
+            # e_time = time.time()
+            # Perform the expand operation in parallel; ordering does not matter
+            # here, since the prompt is identical for all requests.
+            expand_iter_index = 0
+            while True:
+                responses = perform_infer(self.infer_engine, infer_requests, self.expand_request_configs,
+                                          **self.infer_kwargs)
+                if len(responses) > 0:
+                    break
+                if expand_iter_index == 5:
+                    raise ValueError('Expand returned no responses after retries')
+                expand_iter_index += 1
+            # logger.info(f"expand.expand time: {time.time() - e_time}")
+
+            # Fetch outcome rewards in parallel; they are returned in order,
+            # so they can be matched to the children directly.
+            orm_infer_requests = []
+            unique_output = set()
+            for response in responses:
+                self.update_usage_info(response)
+                output = response.choices[0].message.content.rstrip(sep_token + '\n').split(sep_token)[0]
+                if output in unique_output:
+                    continue
+                unique_output.add(output)
+                orm_infer_requests.append(InferRequest([{'role': 'assistant', 'content': output}]))
+                child = LanguageNode(step=output, parent=expand_curr_node)
+                if self.orm_model.check_terminate(child.answer)[0]:
+                    child.terminated = True
+                expand_curr_node.add_child(child)
+
+            # e_time = time.time()
+            orm_score, _orm_mask = get_reward(
+                self.orm_model,
+                orm_infer_requests,
+                ground_truths=[ground_truth] * len(orm_infer_requests),
+                threshold=0.0)
+            # logger.info(f"expand.orm time: {time.time() - e_time}")
+            for child, score in zip(expand_curr_node.children, orm_score):
+                if child.terminated:
+                    child.init_and_update_value(score)
+                    child.correct = score > 0.9
+                    terminated_nodes.append(child)
+
+            # e_time = time.time()
+            if self.prm_model:
+                prm_infer_requests = []
+                for child in expand_curr_node.children:
+                    prm_message = {'role': 'assistant', 'content': child.answer}
+                    prm_infer_requests.append(InferRequest([prompt_message, prm_message]))
+                prm_score, _prm_mask = get_reward(
+                    self.prm_model,
+                    prm_infer_requests,
+                    ground_truths=[ground_truth] * len(prm_infer_requests),
+                    threshold=0.0)
+                for child, score in zip(expand_curr_node.children, prm_score):
+                    child.process_reward = score
+            # logger.info(f"expand.prm time: {time.time() - e_time}")
+
+        def _rollout(rollout_curr_node: LanguageNode):
+            rollout_depth = 0
+            rollout_nodes = {}
+            for i in range(len(rollout_curr_node.active_children)):
+                rollout_nodes[i] = {
+                    'node': rollout_curr_node.active_children[i],
+                    'history_messages': {
+                        'role': 'assistant',
+                        'content': rollout_curr_node.active_children[i].answer,
+                    },
+                }
+            active_rollout_nodes = list(rollout_nodes.keys())
+            while len(active_rollout_nodes) > 0 and rollout_depth < _args.rollout_depth:
+                # r_time = time.time()
+                infer_requests = [
+                    InferRequest(system_message +
+                                 [prompt_message, rollout_nodes[index]['history_messages'], next_message])
+                    for index in active_rollout_nodes
+                ]
+                # logger.info(f"rollout.prepare time: {time.time() - r_time}")
+                # r_time = time.time()
+                rollout_iter_index = 0
+                while True:
+                    responses = perform_infer(self.infer_engine, infer_requests, self.rollout_request_configs,
+                                              **self.infer_kwargs)
+                    if len(responses) > 0:
+                        break
+                    if rollout_iter_index == 5:
+                        raise ValueError('Rollout returned no responses after retries')
+                    rollout_iter_index += 1
+                # logger.info(f"rollout.infer time: {time.time() - r_time}")
+
+                # r_time = time.time()
+                orm_infer_requests = []
+                end_paths = []
+                for index, response in zip(active_rollout_nodes, responses):
+                    self.update_usage_info(response)
+                    output = response.choices[0].message.content.rstrip(sep_token
+                                                                        + '\n').split(sep_token)[0] + sep_token + '\n'
+                    rollout_nodes[index]['history_messages']['content'] += output
+                    end_paths.append(rollout_nodes[index]['history_messages']['content'])
+                    orm_infer_requests.append(InferRequest([rollout_nodes[index]['history_messages']]))
+                # logger.info(f"rollout.orm_prepare time: {time.time() - r_time}")
+
+                # r_time = time.time()
+                orm_score, _orm_mask = get_reward(
+                    self.orm_model,
+                    orm_infer_requests,
+                    ground_truths=[ground_truth] * len(infer_requests),
+                    threshold=0.0)
+                # logger.info(f"rollout.get_orm time: {time.time() - r_time}")
+                terminated_state = self.orm_model.check_terminate(end_paths)
+                for index, score, terminated in zip(active_rollout_nodes, orm_score, terminated_state):
+                    if terminated:
+                        rollout_curr_node.active_children[index].init_and_update_value(score)
+                        if score > 0.9:
+                            rollout_correct_answers.append(rollout_nodes[index]['history_messages']['content'])
+                        else:
+                            rollout_incorrect_answers.append(rollout_nodes[index]['history_messages']['content'])
+                        rollout_nodes.pop(index)
+                active_rollout_nodes = list(rollout_nodes.keys())
+                rollout_depth += 1
+
+        def _back_propagate(back_curr_node: LanguageNode):
+            while back_curr_node:
+                if back_curr_node == curr_node:
+                    best_child_value = max([child.outcome_reward for child in back_curr_node.children])
+                    back_curr_node.init_and_update_value(best_child_value)
+                    last_child_value = back_curr_node.outcome_reward
+                else:
+                    back_curr_node.init_and_update_value(last_child_value)
+                    last_child_value = back_curr_node.outcome_reward
+                back_curr_node.visit()
+                if len(back_curr_node.active_children) == 0:
+                    back_curr_node.terminated = True
+                    if not back_curr_node.is_root():
+                        back_curr_node.parent.active_children.remove(back_curr_node)
+                back_curr_node = back_curr_node.parent
+
+        _args = self.args
+        system_message = [] + _args.system_message
+        sep_token = _args.stop_words[0] + '\n'
+        _root = LanguageNode(sep_token=sep_token)
+        prompt_message = {
+            'role': 'user',
+            'content': query,
+        }
+
+        rollout_correct_answers, rollout_incorrect_answers, terminated_nodes = [], [], []
+        iter_count = 0
+        stop_reason = None
+        while True:
+            logger.info(f'iter_count: {iter_count}' + '.' * 10)
+            s_time = time.time()
+            curr_node = _select(_root)
+            logger.debug('select' + '=' * 10 + f'time: {time.time() - s_time}')
+            s_time = time.time()
+            _expand(curr_node)
+            logger.debug('expand' + '=' * 10 + f'time: {time.time() - s_time}')
+            if curr_node.depth > _args.rollout_start_depth:
+                s_time = time.time()
+                _rollout(curr_node)
+                logger.debug('rollout' + '=' * 10 + f'time: {time.time() - s_time}')
+            s_time = time.time()
+            _back_propagate(curr_node)
+            logger.debug('back propagate' + '=' * 10 + f'time: {time.time() - s_time}')
+            if len(rollout_correct_answers) + len(rollout_incorrect_answers) >= 2 * _args.num_return_sequences:
+                if 4 * len(rollout_incorrect_answers) < len(rollout_correct_answers):
+                    stop_reason = 'too easy'
+                    break
+                elif 4 * len(rollout_correct_answers) < len(rollout_incorrect_answers):
+                    stop_reason = 'too hard'
+                    break
+            if _root.terminated:
+                stop_reason = 'root terminated'
+                break
+            if len(terminated_nodes) >= _args.num_return_sequences:
+                stop_reason = 'enough nodes'
+                break
+            if iter_count >= _args.max_iterations:
+                stop_reason = 'max_iterations'
+                break
+            iter_count += 1
+        logger.info(f'stop_reason: {stop_reason}')
+        # logger.info(f"rollout_correct_answers: {rollout_correct_answers}")
+        # logger.info(f"rollout_incorrect_answers: {rollout_incorrect_answers}")
+
+        monte_carlo_tree = _root.collect()
+        result = {
+            'query': query,
+            'ground_truth': ground_truth,
+            'rollout_correct_answers': rollout_correct_answers,
+            'rollout_incorrect_answers': rollout_incorrect_answers,
+            'monte_carlo_tree': monte_carlo_tree,
+        }
+        result_json = json.dumps(result, ensure_ascii=False)
+        logger.info(result_json)
+        return result_json
+
+    def do_sample(self, data):
+        if not isinstance(data, list):
+            data = [data]
+        generated = []
+        for item in data:
+            logger.info(f'time: {time.ctime(time.time())}')
+            try:
+                messages = item['messages'][0]
+                query = messages[0]['content']
+                ground_truth = messages[1]['content']
+                generated.append(self.search_single(query, ground_truth) + '\n')
+            except Exception as e:
+                logger.error(f'Error: {e}')
+                logger.error(f'Traceback: {traceback.format_exc()}')
+        return generated
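The 'too easy'/'too hard' early stop in `search_single` above is easiest to check with concrete counts. A standalone restatement of that 4:1 ratio test; the counts in the example calls are invented:

```python
def early_stop_reason(n_correct: int, n_incorrect: int, num_return_sequences: int):
    # Only judge difficulty once enough rollouts have finished.
    if n_correct + n_incorrect >= 2 * num_return_sequences:
        if 4 * n_incorrect < n_correct:
            return 'too easy'   # rollouts almost always reach a correct answer
        if 4 * n_correct < n_incorrect:
            return 'too hard'   # rollouts almost never reach a correct answer
    return None

print(early_stop_reason(16, 1, 8))  # 'too easy'
print(early_stop_reason(2, 15, 8))  # 'too hard'
print(early_stop_reason(8, 9, 8))   # None: keep searching
```

Questions that are trivially solved or essentially unsolvable yield few informative preference pairs, so the search budget is better spent elsewhere.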
diff --git a/swift/llm/sampling/sampling.py b/swift/llm/sampling/sampling.py
index 0ae33baddd..2ec3410f05 100644
--- a/swift/llm/sampling/sampling.py
+++ b/swift/llm/sampling/sampling.py
@@ -28,6 +28,9 @@ def __init__(self, args: Union[List[str], SamplingArguments, None] = None) -> None:
         if self.args.sampler_type == 'sample':
             from swift.llm.sampling.vanilla_sampler import VanillaSampler
             self.sampler = VanillaSampler(self.args)
+        elif self.args.sampler_type == 'mcts':
+            from swift.llm.sampling.mcts import MctsSampler
+            self.sampler = MctsSampler(self.args)

     def _get_dataset(self):
         args = self.args
diff --git a/swift/llm/sampling/utils.py b/swift/llm/sampling/utils.py
index c056d9a287..5b9b25ab1c 100644
--- a/swift/llm/sampling/utils.py
+++ b/swift/llm/sampling/utils.py
@@ -6,6 +6,9 @@
 import numpy as np

 from swift.llm import InferRequest, Messages, RequestConfig
+from swift.utils import get_logger
+
+logger = get_logger()


 def get_messages_md5(messages: Messages):
@@ -67,3 +70,92 @@ def normalize(arr):
             return normalized

     return normalize(arr), _mask
+
+
+def perform_infer(infer_engines, infer_requests, request_configs, **infer_kwargs):
+    if isinstance(infer_engines, list):
+        assert len(infer_engines) >= len(request_configs) >= len(infer_requests)
+        from concurrent.futures import ThreadPoolExecutor, as_completed
+        n = len(infer_requests)
+        with ThreadPoolExecutor(max_workers=n) as executor:
+            futures = {
+                executor.submit(perform_infer, infer_engines[i], infer_requests[i], request_configs[i],
+                                **infer_kwargs): i
+                for i in range(n)
+            }
+            responses = []
+            for future in as_completed(futures):
+                task_id = futures[future]
+                try:
+                    responses += future.result()
+                except Exception as e:
+                    logger.info(f'Perform infer task {task_id} got an error: {e}')
+        return responses
+    elif isinstance(infer_requests, list):
+        responses = []
+        if isinstance(request_configs, list):
+            assert len(infer_requests) <= len(request_configs)
+            for i in range(len(infer_requests)):
+                responses += infer_engines.infer(
+                    [infer_requests[i]],
+                    request_configs[i],
+                    **infer_kwargs,
+                )
+        elif isinstance(request_configs, RequestConfig):
+            for infer_request in infer_requests:
+                responses += infer_engines.infer(
+                    [infer_request],
+                    request_configs,
+                    **infer_kwargs,
+                )
+        return responses
+    return infer_engines.infer(
+        [infer_requests],
+        request_configs,
+        **infer_kwargs,
+    )
+
+
+def collect_from_mct(monte_carlo_tree, collect_filter_threshold):
+    from transformers.utils import strtobool
+    if isinstance(monte_carlo_tree, str):
+        monte_carlo_tree = json.loads(monte_carlo_tree)
+
+    def _collect(collect_curr_node, _outcome_rewards: List[float], _process_rewards: List[float]):
+        _prefer_pairs, _correct_answers, _incorrect_answers = [], [], []
+        _outcome_rewards = _outcome_rewards[:] + [collect_curr_node['outcome_reward']]
+        _process_rewards = _process_rewards[:] + [collect_curr_node['process_reward']]
+        if len(collect_curr_node['children']) > 0:
+            for child in collect_curr_node['children']:
+                p, c, i = _collect(child, _outcome_rewards, _process_rewards)
+                _prefer_pairs += p
+                _correct_answers += c
+                _incorrect_answers += i
+            sorted_children = sorted(collect_curr_node['children'], key=lambda x: x['outcome_reward'])
+            if sorted_children[-1]['outcome_reward'] - sorted_children[0]['outcome_reward'] > collect_filter_threshold:
+                # TODO: filter with visit count
+                prefer_pair = {
+                    'path': 'ки\n'.join(collect_curr_node['path']),
+                    'good': sorted_children[-1]['path'][-1],
+                    'good_score': sorted_children[-1]['outcome_reward'],
+                    'bad': sorted_children[0]['path'][-1],
+                    'bad_score': sorted_children[0]['outcome_reward'],
+                }
+                _prefer_pairs.append(prefer_pair)
+
+        if strtobool(collect_curr_node['terminated']):
+            _answer = {
+                'answer': 'ки\n'.join(collect_curr_node['path']),
+                'mean_outcome_reward': np.mean(_outcome_rewards),
+                'min_outcome_reward': np.min(_outcome_rewards),
+                'mean_process_reward': np.mean(_process_rewards),
+                'min_process_reward': np.min(_process_rewards),
+            }
+            if strtobool(collect_curr_node['correct']):
+                _correct_answers.append(_answer)
+            else:
+                _incorrect_answers.append(_answer)
+
+        return _prefer_pairs, _correct_answers, _incorrect_answers
+
+    _root = monte_carlo_tree
+    prefer_pairs, correct_answers, incorrect_answers = _collect(_root, [], [])
+    return prefer_pairs, correct_answers, incorrect_answers
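`collect_from_mct` is the consumer of the sampler's JSONL output: it walks a serialized tree and yields preference pairs plus correct/incorrect answer sets. A sketch of wiring the two together; the file name and threshold are placeholders:

```python
import json
from swift.llm.sampling.utils import collect_from_mct

# Each line written by the MCTS sampler is a JSON object holding the query,
# ground truth, rollout answers, and the serialized 'monte_carlo_tree'.
with open('sample_output/iter_0_proc_0.jsonl', 'r', encoding='utf-8') as f:
    for line in f:
        result = json.loads(line)
        prefer_pairs, correct, incorrect = collect_from_mct(
            result['monte_carlo_tree'], collect_filter_threshold=0.5)
        for pair in prefer_pairs:
            # 'good'/'bad' hold two sibling steps whose outcome rewards
            # differ by more than the threshold.
            print(pair['good_score'], pair['bad_score'])
```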
diff --git a/swift/plugin/orm.py b/swift/plugin/orm.py
index 5d20c9185a..fbd5b9f7fb 100644
--- a/swift/plugin/orm.py
+++ b/swift/plugin/orm.py
@@ -1,6 +1,6 @@
 import os
 import re
-from typing import List
+from typing import List, Union

 import json
 import torch
@@ -175,6 +175,18 @@ def __init__(self):
         super().__init__()
         from transformers.utils import strtobool
         self.use_opencompass = strtobool(os.environ.get('USE_OPENCOMPASS_EVALUATOR'))
+        if self.use_opencompass:
+            from opencompass.datasets.math import MATHEvaluator
+            self.evaluator = MATHEvaluator()
+
+    @staticmethod
+    def check_terminate(answers: Union[str, List[str]]) -> List[bool]:
+        if isinstance(answers, str):
+            answers = [answers]
+        results = []
+        for answer in answers:
+            results.append('\\boxed' in answer)
+        return results

     @staticmethod
     def extract_boxed_result(text):
@@ -228,9 +240,7 @@ def infer(self, infer_requests: List[InferRequest], ground_truths: List[str],
             prediction = MathORM.extract_boxed_result(prediction)
             ground_truth = MathORM.extract_boxed_result(ground_truth)
             if self.use_opencompass:
-                from opencompass.datasets.math import MATHEvaluator
-                evaluator = MATHEvaluator()
-                rewards.append(evaluator.is_equiv(prediction, ground_truth))
+                rewards.append(self.evaluator.is_equiv(prediction, ground_truth))
             else:
                 rewards.append(MathORM.compare_consecutive(prediction, ground_truth))
diff --git a/swift/plugin/prm.py b/swift/plugin/prm.py
index a2fd255388..8847f9af26 100644
--- a/swift/plugin/prm.py
+++ b/swift/plugin/prm.py
@@ -5,7 +5,7 @@
 import torch

 from swift.llm import InferRequest
-from swift.llm.infer.protocol import ChatCompletionResponse
+from swift.llm.infer.protocol import ChatCompletionResponse, ChatCompletionResponseChoice, ChatMessage


 class PRM:
@@ -52,13 +52,15 @@ class QwenMaxPRM(PRM):
     def infer(self, infer_requests: List[InferRequest], ground_truths: List[str],
               **kwargs) -> List[ChatCompletionResponse]:
         rewards = []
-        for request, ground_truth in zip(infer_requests, ground_truths):
-            from openai import OpenAI
-            client = OpenAI(
-                api_key=os.getenv('DASHSCOPE_API_KEY'),
-                base_url='https://dashscope.aliyuncs.com/compatible-mode/v1',
-            )
+        from openai import OpenAI
+
+        client = OpenAI(
+            api_key=os.getenv('DASHSCOPE_API_KEY'),
+            base_url='https://dashscope.aliyuncs.com/compatible-mode/v1',
+        )
+
+        for request, ground_truth in zip(infer_requests, ground_truths):
             previous = request.messages[:-1]
             if previous[0]['role'] == 'system':
                 previous = previous[1:]
@@ -78,7 +80,7 @@ def infer(self, infer_requests: List[InferRequest], ground_truths: List[str],
                 },
             ]
             completion = client.chat.completions.create(
-                model='qwen-plus',
+                model='qwen-max',
                 messages=messages,
             )
@@ -91,7 +93,83 @@ def infer(self, infer_requests: List[InferRequest], ground_truths: List[str],
             except Exception:
                 rewards.append(None)

-        return rewards
+        return [
+            ChatCompletionResponse(
+                choices=[
+                    ChatCompletionResponseChoice(
+                        message=ChatMessage(content=1.0 if r else 0.0, role='assistant'), index=0, finish_reason='')
+                ],
+                model=None,
+                usage=None) for r in rewards
+        ]
+
+
+class ClientPRM(PRM):
+
+    def __init__(self, api_key=None, base_url=None, model=None):
+        super().__init__()
+        from swift.llm import InferClient
+        if api_key is None:
+            api_key = os.getenv('DASHSCOPE_API_KEY')
+        if base_url is None:
+            base_url = 'https://dashscope.aliyuncs.com/compatible-mode/v1'
+        if model is None:
+            model = 'qwen-plus'
+        self.infer_engine = InferClient(base_url=base_url, api_key=api_key)
+        self.infer_kwargs = {
+            'model': model,
+        }
+
+    def infer(self, infer_requests: List[InferRequest], ground_truths: List[str],
+              **kwargs) -> List[ChatCompletionResponse]:
+        prm_infer_requests = []
+        for request, ground_truth in zip(infer_requests, ground_truths):
+            previous = request.messages[:-1]
+            if previous[0]['role'] == 'system':
+                previous = previous[1:]
+
+            assert request.messages[-1]['role'] == 'assistant'
+            query = QUERY.replace('#query#', json.dumps(previous))
+            query = query.replace('#ground_truth#', ground_truth)
+            query = query.replace('#response#', request.messages[-1]['content'])
+            messages = [
+                {
+                    'role': 'system',
+                    'content': SYSTEM
+                },
+                {
+                    'role': 'user',
+                    'content': query
+                },
+            ]
+
+            prm_infer_requests.append(InferRequest(messages=messages))

-prms = {'qwen_max': QwenMaxPRM}
+        responses = self.infer_engine.infer(prm_infer_requests, **self.infer_kwargs)
+        rewards = []
+        for response in responses:
+            content = response.choices[0].message.content
+            if 'Reward:' not in content:
+                rewards.append(None)
+                continue
+            try:
+                reward = float(content.split('Reward:')[1].strip().replace('*', ''))
+                rewards.append(reward)
+            except Exception:
+                rewards.append(None)
+
+        return [
+            ChatCompletionResponse(
+                choices=[
+                    ChatCompletionResponseChoice(
+                        message=ChatMessage(content=r if r is not None else 0.0, role='assistant'), index=0,
+                        finish_reason='')
+                ],
+                model=None,
+                usage=None) for r in rewards
+        ]
+
+
+prms = {
+    'qwen_max': QwenMaxPRM,
+    'client': ClientPRM,
+}
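With the new `client` key registered in `prms`, the process reward model can be selected by name and queried directly. A sketch under the assumption that `DASHSCOPE_API_KEY` is set and the defaults from `ClientPRM` above are acceptable; the math content in the request is made up:

```python
from swift.llm import InferRequest
from swift.plugin.prm import prms

# ClientPRM defaults to the DashScope-compatible endpoint and 'qwen-plus'.
prm = prms['client']()
request = InferRequest(messages=[
    {'role': 'user', 'content': 'What is 2 + 2?'},
    {'role': 'assistant', 'content': 'Step1: 2 + 2 = 4. ки\nAnswer: \\boxed(4). ки'},
])
responses = prm.infer([request], ground_truths=['\\boxed(4)'])
# Each response carries the parsed 'Reward:' value as the message content.
print(responses[0].choices[0].message.content)
```

This is the same interface the MCTS sampler uses when `process_reward_rate` is greater than zero and a `--prm_model` is supplied.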