Commit

Merge pull request #1 from Jiayi-Pan/countdown-wip

Countdown wip

Jiayi-Pan authored Jan 23, 2025
2 parents e189fa2 + 32b9cd7 commit e721899
Showing 7 changed files with 348 additions and 11 deletions.
108 changes: 101 additions & 7 deletions README.md
@@ -18,21 +18,31 @@ pip3 install flash-attn --no-build-isolation
pip install wandb IPython matplotlib
```

## The Task

The Countdown task: given a target number and a list of source numbers, produce an arithmetic equation that uses each number at most once and evaluates to the target.

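For example (an illustrative instance, not one drawn from the generated dataset): given the numbers [55, 36, 19, 7] and target 65, one valid solution is 55 + 36 - 19 - 7 = 65.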
## Generate Data
```
conda activate zero
-python verl/examples/data_preprocess/arth.py
+python examples/data_preprocess/countdown.py
```
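By default this writes `train.parquet` and `test.parquet` under `~/data/countdown` (the `--local_dir` default in `examples/data_preprocess/countdown.py`, shown below).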

## Run Training
```
conda activate zero
```
**Single GPU dry run**
Works for models <= 1.5B.

For the Qwen2.5-0.5B base model, we know it fails to learn reasoning.

```
-export CUDA_VISIBLE_DEVICES=6
+export CUDA_VISIBLE_DEVICES=7
export N_GPUS=1
-export BASE_MODEL=Qwen/Qwen2.5-0.5B
-export DATA_DIR=$HOME/data/arithmetic-3_digit
+export BASE_MODEL=Qwen/Qwen2.5-1.5B
+export DATA_DIR=$HOME/data/countdown
+export WANDB_API_KEY=<your-wandb-api-key>
+export EXPERIMENT_NAME=countdown-qwen2.5-1.5b
PYTHONUNBUFFERED=1 python3 -m verl.trainer.main_ppo \
data.train_files=$DATA_DIR/train.parquet \
data.val_files=$DATA_DIR/test.parquet \
@@ -57,9 +67,93 @@ PYTHONUNBUFFERED=1 python3 -m verl.trainer.main_ppo \
trainer.default_hdfs_dir=null \
trainer.n_gpus_per_node=$N_GPUS \
trainer.nnodes=1 \
-trainer.save_freq=10 \
+trainer.save_freq=30 \
trainer.test_freq=10 \
trainer.project_name=TinyZero \
trainer.experiment_name=$EXPERIMENT_NAME \
trainer.total_epochs=15 2>&1 | tee verl_demo.log
```
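(A note on the trainer flags, as we read verl's config: `trainer.save_freq` and `trainer.test_freq` are measured in training steps, and with `trainer.default_hdfs_dir=null` checkpoints stay on local disk.)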

**3B model dry run**
In this case, the base model is able to develop sophisticated reasoning skills.
```
export CUDA_VISIBLE_DEVICES=4,5
export N_GPUS=2
export BASE_MODEL=Qwen/Qwen2.5-3B
export DATA_DIR=$HOME/data/countdown
export ROLLOUT_TP_SIZE=2
export WANDB_API_KEY=<your-wandb-api-key>
export EXPERIMENT_NAME=countdown-qwen2.5-3b
python3 -m verl.trainer.main_ppo \
data.train_files=$DATA_DIR/train.parquet \
data.val_files=$DATA_DIR/test.parquet \
data.train_batch_size=256 \
data.val_batch_size=1312 \
data.max_prompt_length=256 \
data.max_response_length=1024 \
actor_rollout_ref.model.path=$BASE_MODEL \
actor_rollout_ref.actor.optim.lr=1e-6 \
actor_rollout_ref.actor.ppo_mini_batch_size=128 \
actor_rollout_ref.actor.ppo_micro_batch_size=8 \
actor_rollout_ref.rollout.log_prob_micro_batch_size=8 \
actor_rollout_ref.rollout.tensor_model_parallel_size=$ROLLOUT_TP_SIZE \
actor_rollout_ref.rollout.gpu_memory_utilization=0.4 \
actor_rollout_ref.ref.log_prob_micro_batch_size=4 \
critic.optim.lr=1e-5 \
critic.model.path=$BASE_MODEL \
critic.ppo_micro_batch_size=8 \
algorithm.kl_ctrl.kl_coef=0.001 \
trainer.logger=['wandb'] \
+trainer.val_before_train=False \
trainer.default_hdfs_dir=null \
trainer.n_gpus_per_node=$N_GPUS \
trainer.nnodes=1 \
trainer.save_freq=30 \
trainer.test_freq=10 \
trainer.project_name=TinyZero \
trainer.experiment_name=$EXPERIMENT_NAME \
trainer.total_epochs=15 2>&1 | tee verl_demo.log
```
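Two notes on the knobs above, as we read verl's PPO config: `ROLLOUT_TP_SIZE` feeds `actor_rollout_ref.rollout.tensor_model_parallel_size`, the tensor-parallel degree used for rollout generation, and it simply matches `N_GPUS` in these multi-GPU recipes. The `*_micro_batch_size` values are the per-forward-pass chunks used for gradient accumulation within each `ppo_mini_batch_size` update, so they trade memory for speed without changing the optimization.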

**OpenLlama 7B model dry run**
In this case, the base model is able to develop sophisticated reasoning skills.
```
export CUDA_VISIBLE_DEVICES=4,5,6,7
export N_GPUS=4
export EXPERIMENT_NAME=countdown-open_llama_7b
export BASE_MODEL=openlm-research/open_llama_7b_v2
export DATA_DIR=$HOME/data/countdown
export ROLLOUT_TP_SIZE=4
export WANDB_API_KEY=<your-wandb-api-key>
python3 -m verl.trainer.main_ppo \
data.train_files=$DATA_DIR/train.parquet \
data.val_files=$DATA_DIR/test.parquet \
data.train_batch_size=256 \
data.val_batch_size=1312 \
data.max_prompt_length=256 \
data.max_response_length=1024 \
actor_rollout_ref.model.path=$BASE_MODEL \
actor_rollout_ref.actor.optim.lr=1e-6 \
actor_rollout_ref.actor.ppo_mini_batch_size=128 \
actor_rollout_ref.actor.ppo_micro_batch_size=8 \
actor_rollout_ref.rollout.log_prob_micro_batch_size=8 \
actor_rollout_ref.rollout.tensor_model_parallel_size=$ROLLOUT_TP_SIZE \
actor_rollout_ref.rollout.gpu_memory_utilization=0.4 \
actor_rollout_ref.ref.log_prob_micro_batch_size=4 \
critic.optim.lr=1e-5 \
critic.model.path=$BASE_MODEL \
critic.ppo_micro_batch_size=8 \
algorithm.kl_ctrl.kl_coef=0.001 \
trainer.logger=['wandb'] \
+trainer.val_before_train=False \
trainer.default_hdfs_dir=null \
trainer.n_gpus_per_node=$N_GPUS \
trainer.nnodes=1 \
trainer.save_freq=30 \
trainer.test_freq=10 \
-trainer.project_name=zero \
-trainer.experiment_name=multi-hard \
+trainer.project_name=TinyZero \
+trainer.experiment_name=$EXPERIMENT_NAME \
trainer.total_epochs=15 2>&1 | tee verl_demo.log
```
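Note that the 3B and 7B recipes share identical data and PPO hyperparameters; only the base model, the GPU count, and the rollout tensor-parallel size change.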
125 changes: 125 additions & 0 deletions examples/data_preprocess/countdown.py
@@ -0,0 +1,125 @@
"""
Preprocess dataset for countdown task - given a target number and N numbers, generate equations to reach target
"""

import re
import os
from datasets import Dataset, load_dataset
from random import randint, seed, choice
from typing import List, Tuple
from tqdm import tqdm
from verl.utils.hdfs_io import copy, makedirs
import argparse


def gen_dataset(
num_samples: int,
num_operands: int = 6,
max_target: int = 1000,
min_number: int = 1,
max_number: int = 100,
operations: List[str] = ['+', '-', '*', '/'],
seed_value: int = 42,
) -> List[Tuple]:
"""Generate dataset for countdown task.
Args:
num_samples: Number of samples to generate
num_operands: Number of numbers provided in each sample
max_target: Maximum value for target number
min_number: Minimum value for provided numbers
max_number: Maximum value for provided numbers
operations: List of allowed operations
seed_value: Random seed for reproducibility
Returns:
        List of tuples containing (target, numbers)
"""
seed(seed_value)
samples = []

for _ in tqdm(range(num_samples)):
# Generate random target
target = randint(1, max_target)

# Generate random numbers
        numbers = [randint(min_number, max_number) for _ in range(num_operands)]

        samples.append((target, numbers))

return samples
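# NOTE: gen_dataset is a synthetic-data helper; it is not called in the
# __main__ block below, which instead loads the pre-built Hugging Face
# dataset Jiayi-Pan/Countdown-Tasks-3to4.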

def make_prefix(dp):
target = dp['target']
numbers = dp['nums']

prefix = f"""A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer.
User: Using the numbers {numbers}, create an equation that equals {target}. You can use basic arithmetic operations (+, -, *, /) and each number can only be used once. Show your work in <think> </think> tags. And return the final answer in <answer> </answer> tags, for example <answer> (1 + 2) / 3 </answer>.
Assistant: Let me solve this step by step.
<think>"""
return prefix
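# Note: the prompt above deliberately ends mid-turn, right after
# "Assistant: Let me solve this step by step.\n<think>", so generation
# continues the reasoning trace from inside the <think> block.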


if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--local_dir', default='~/data/countdown')
parser.add_argument('--hdfs_dir', default=None)
parser.add_argument('--num_samples', type=int, default=100000)
parser.add_argument('--num_operands', type=int, default=6)
parser.add_argument('--max_target', type=int, default=1000)
parser.add_argument('--min_number', type=int, default=1)
parser.add_argument('--max_number', type=int, default=100)
parser.add_argument('--train_size', type=int, default=327680)
parser.add_argument('--test_size', type=int, default=4096)

args = parser.parse_args()

data_source = 'countdown'
TRAIN_SIZE = args.train_size
TEST_SIZE = args.test_size

raw_dataset = load_dataset('Jiayi-Pan/Countdown-Tasks-3to4', split='train')

assert len(raw_dataset) > TRAIN_SIZE + TEST_SIZE
train_dataset = raw_dataset.select(range(TRAIN_SIZE))
test_dataset = raw_dataset.select(range(TRAIN_SIZE, TRAIN_SIZE + TEST_SIZE))

def make_map_fn(split):
def process_fn(example, idx):
question = make_prefix(example)
solution = {
"target": example['target'],
"numbers": example['nums']
}
data = {
"data_source": data_source,
"prompt": [{
"role": "user",
"content": question,
}],
"ability": "math",
"reward_model": {
"style": "rule",
"ground_truth": solution
},
"extra_info": {
'split': split,
'index': idx,
}
}
return data
return process_fn

train_dataset = train_dataset.map(function=make_map_fn('train'), with_indices=True)
test_dataset = test_dataset.map(function=make_map_fn('test'), with_indices=True)

    local_dir = os.path.expanduser(args.local_dir)  # expand '~' in the default path before writing
hdfs_dir = args.hdfs_dir

train_dataset.to_parquet(os.path.join(local_dir, 'train.parquet'))
test_dataset.to_parquet(os.path.join(local_dir, 'test.parquet'))

if hdfs_dir is not None:
makedirs(hdfs_dir)
copy(src=local_dir, dst=hdfs_dir)
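
A quick way to sanity-check the processed output (an illustrative snippet, not part of this commit; it assumes the default `--local_dir` and that pandas/pyarrow are installed):

```
import os
import pandas as pd

# Read one processed split and inspect a record (hypothetical check).
df = pd.read_parquet(os.path.expanduser('~/data/countdown/train.parquet'))
row = df.iloc[0]
print(row['prompt'][0]['content'][:200])    # rendered Countdown prompt prefix
print(row['reward_model']['ground_truth'])  # {'target': ..., 'numbers': [...]}
```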
4 changes: 3 additions & 1 deletion verl/trainer/main_ppo.py
@@ -17,7 +17,7 @@

from verl import DataProto
import torch
-from verl.utils.reward_score import gsm8k, math, multiply
+from verl.utils.reward_score import gsm8k, math, multiply, countdown
from verl.trainer.ppo.ray_trainer import RayPPOTrainer


@@ -28,6 +28,8 @@ def _select_rm_score_fn(data_source):
return math.compute_score
elif "multiply" in data_source or "arithmetic" in data_source:
return multiply.compute_score
elif "countdown" in data_source:
return countdown.compute_score
else:
raise NotImplementedError
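
`countdown.compute_score` is added elsewhere in this commit and is not shown in this diff. As a reading aid, here is a minimal sketch of what such a rule-based scorer can look like; the module and function name come from the import above, but the signature, reward values, and regexes are illustrative assumptions, not the committed code:

```
import re

def compute_score(solution_str, ground_truth, format_score=0.1, score=1.0):
    """Illustrative rule-based Countdown scorer (a sketch, not the committed file).

    `ground_truth` is the dict emitted by countdown.py:
    {'target': int, 'numbers': list[int]}.
    """
    target, numbers = ground_truth['target'], ground_truth['numbers']

    # Pull the equation out of <answer> ... </answer>.
    match = re.search(r'<answer>(.*?)</answer>', solution_str, re.DOTALL)
    if match is None:
        return 0.0  # no answer tags: no reward
    equation = match.group(1).strip()

    # Allow only digits, basic operators, parentheses, and spaces.
    if not re.fullmatch(r'[\d+\-*/() ]+', equation):
        return format_score
    # Every provided number must be used exactly once.
    if sorted(int(n) for n in re.findall(r'\d+', equation)) != sorted(numbers):
        return format_score
    # Evaluate and compare; eval is tolerable here only because the
    # character set was validated above.
    try:
        if abs(eval(equation) - target) < 1e-5:
            return score
    except Exception:
        pass
    return format_score
```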

1 change: 1 addition & 0 deletions verl/trainer/ppo/ray_trainer.py
@@ -541,6 +541,7 @@ def fit(self):

for epoch in range(self.config.trainer.total_epochs):
for batch_dict in self.train_dataloader:
+                print(f'epoch {epoch}, step {self.global_steps}')
metrics = {}
timing_raw = {}

8 changes: 5 additions & 3 deletions verl/utils/dataset/rl_dataset.py
@@ -106,9 +106,11 @@ def _read_files_and_tokenize(self):
# filter out too long prompts
tokenizer = self.tokenizer
prompt_key = self.prompt_key
-        self.dataframe = self.dataframe[self.dataframe.apply(lambda doc: len(
-            tokenizer.apply_chat_template(doc[prompt_key], add_generation_prompt=True)) <= self.max_prompt_length,
-                                                              axis=1)]
+
+        # nvm if prompt is too long
+        # self.dataframe = self.dataframe[self.dataframe.apply(lambda doc: len(
+        #     tokenizer.apply_chat_template(doc[prompt_key], add_generation_prompt=True)) <= self.max_prompt_length,
+        #                                                       axis=1)]

print(f'filter dataset len: {len(self.dataframe)}')
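
With the length filter commented out, over-length prompts are no longer dropped up front, and the `filter dataset len` print now just reports the full dataset size. For the Countdown data this is presumably benign, since the fixed prompt template sits well within the `max_prompt_length=256` used in the README recipes.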
