Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

get_steps_values & steps #1

Open
wants to merge 2 commits into
base: ricliu-ray-maxtext
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions MaxText/configs/base.yml
Original file line number Diff line number Diff line change
@@ -318,6 +318,9 @@ grain_worker_count: 1
steps: 150_001 # If set to -1 then will inherit value from learning_rate_schedule_steps
log_period: 100 # Flushes Tensorboard

# Training steps per loop
steps_per_loop: 100

# We take inspiration from Llama2's learning rate (LR) schedule, see https://arxiv.org/pdf/2307.09288.pdf section 2.2
# Learning rate schedule has either two or three parts:
# 1) Linear warmup from 0 to [learning_rate] over steps 0 to [learning_rate_schedule_steps * warmup_steps_fraction]
44 changes: 30 additions & 14 deletions MaxText/ray_trainer.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,11 @@
import logging
import ray

from absl import app
from ray_tpu import RayTpuManager
from ray.job_submission import JobSubmissionClient
from trainer import MaxTextTrainer

import logging
import os
import argparse
import pyconfig
from typing import Sequence, Optional
from absl import app

from typing import Sequence


#### Configurations
@@ -46,12 +42,32 @@ def get_job_submission_id() -> str:
return [job.submission_id for job in jobs if job.job_id == current_job_id][0]


def get_steps_values(args):
  """Extracts the integer values of 'steps' and 'steps_per_loop' from args.

  Args:
    args: A list of "key=value" strings, e.g. ["steps=100", "steps_per_loop=10"].
      Entries without an '=' and entries for other keys are ignored.

  Returns:
    A (steps, steps_per_loop) tuple of ints. Either element is None if the
    corresponding key was not present in args. If a key appears multiple
    times, the last occurrence wins.

  Raises:
    ValueError: If a matching key's value is not a valid integer.
  """
  steps = None
  steps_per_loop = None
  for item in args:
    # Split on the first '=' only, so a value containing '=' is kept intact
    # (the original split('=')[1] silently truncated such values).
    key, sep, value = item.partition('=')
    if not sep:
      continue
    # Exact key match instead of startswith: avoids accidentally matching
    # any future key that merely begins with "steps=".
    if key == 'steps':
      steps = int(value)
    elif key == 'steps_per_loop':
      steps_per_loop = int(value)
  return steps, steps_per_loop


def main(argv: Sequence[str]):
ray.init(runtime_env=dict(worker_process_setup_hook=setup_loggers))
run_name = get_job_submission_id()
logging.info("Got args: %s", argv)
logging.info("This run name: %s", run_name)

tpu_resources = RayTpuManager.get_available_resources()
num_detected_tpu_types = len(tpu_resources.keys())
if num_detected_tpu_types == 0:
@@ -84,16 +100,16 @@ def main(argv: Sequence[str]):
raise e

logging.info("Initialization complete. Starting MaxText training...")
total_steps = 50 #int(args.total_steps)
steps_per_loop = 100 #int(args.steps_per_loop)
steps = 0
steps, steps_per_loop = get_steps_values(argv)
logging.info(f"KubeRay training running for total steps: {steps}, steps per loop: {steps_per_loop}")
steps_counter = 0

while steps < total_steps:
while steps_counter < steps:
logging.info("Training from step %d to %d.", steps, steps_per_loop)

try:
r = ray.get([actor.train.remote(num_steps=steps_per_loop) for actor in actors])
steps = r[0]
steps_counter = r[0]
except Exception as e:
logging.error("Caught error during training: %s", e)
logging.error("Shutting down...")