diff --git a/MaxText/configs/base.yml b/MaxText/configs/base.yml
index 34e3a2c4a..e31d2dda6 100644
--- a/MaxText/configs/base.yml
+++ b/MaxText/configs/base.yml
@@ -318,6 +318,9 @@ grain_worker_count: 1
 steps: 150_001 # If set to -1 then will inherit value from learning_rate_schedule_steps
 log_period: 100 # Flushes Tensorboard
 
+# Training steps per loop
+steps_per_loop: 100
+
 # We take inspiration from Llama2's learning rate (LR) schedule, see https://arxiv.org/pdf/2307.09288.pdf section 2.2
 # Learning rate schedule has either two or three parts:
 # 1) Linear warmup from 0 to [learning_rate] over steps 0 to [learning_rate_schedule_steps * warmup_steps_fraction]
diff --git a/MaxText/ray_trainer.py b/MaxText/ray_trainer.py
index 55e60e659..a0f83ae8f 100644
--- a/MaxText/ray_trainer.py
+++ b/MaxText/ray_trainer.py
@@ -1,15 +1,11 @@
+import logging
 import ray
+
+from absl import app
 from ray_tpu import RayTpuManager
 from ray.job_submission import JobSubmissionClient
 from trainer import MaxTextTrainer
-
-import logging
-import os
-import argparse
-import pyconfig
-from typing import Sequence, Optional
-from absl import app
-
+from typing import Sequence


 #### Configurations
@@ -46,12 +42,32 @@ def get_job_submission_id() -> str:
   return [job.submission_id for job in jobs if job.job_id == current_job_id][0]


+def get_steps_values(args):
+  """
+  Extracts the values of 'steps' and 'steps_per_loop' from args.
+
+  Args:
+    args: A list of key=value strings.
+
+  Returns:
+    A tuple containing the values of 'steps' and 'steps_per_loop' as integers.
+    Returns (None, None) if not found.
+  """
+  steps = None
+  steps_per_loop = None
+  for item in args:
+    if item.startswith('steps='):
+      steps = int(item.split('=')[1])
+    elif item.startswith('steps_per_loop='):
+      steps_per_loop = int(item.split('=')[1])
+  return steps, steps_per_loop
+
+
 def main(argv: Sequence[str]):
   ray.init(runtime_env=dict(worker_process_setup_hook=setup_loggers))
   run_name = get_job_submission_id()
   logging.info("Got args: %s", argv)
   logging.info("This run name: %s", run_name)
-
   tpu_resources = RayTpuManager.get_available_resources()
   num_detected_tpu_types = len(tpu_resources.keys())
   if num_detected_tpu_types == 0:
@@ -84,16 +100,16 @@ def main(argv: Sequence[str]):
       raise e

   logging.info("Initialization complete. Starting MaxText training...")
-  total_steps = 50 #int(args.total_steps)
-  steps_per_loop = 100 #int(args.steps_per_loop)
-  steps = 0
+  steps, steps_per_loop = get_steps_values(argv)
+  logging.info("KubeRay training running for total steps: %s, steps per loop: %s", steps, steps_per_loop)
+  steps_counter = 0

-  while steps < total_steps:
-    logging.info("Training from step %d to %d.", steps, steps_per_loop)
+  while steps_counter < steps:
+    logging.info("Training from step %d to %d.", steps_counter, steps_counter + steps_per_loop)
     try:
       r = ray.get([actor.train.remote(num_steps=steps_per_loop) for actor in actors])
-      steps = r[0]
+      steps_counter = r[0]
     except Exception as e:
       logging.error("Caught error during training: %s", e)
       logging.error("Shutting down...")
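
For reference, a minimal sanity check of the new `get_steps_values` parsing, assuming `argv` is forwarded unchanged from the entrypoint; the config path and override values below are hypothetical:

```python
from ray_trainer import get_steps_values  # assumes MaxText/ is on PYTHONPATH

# Overrides present: both values are parsed as integers.
argv = ["MaxText/configs/base.yml", "steps=200", "steps_per_loop=50"]
assert get_steps_values(argv) == (200, 50)

# Overrides absent: both come back as None, so main() should validate
# them before entering the training loop.
assert get_steps_values(["MaxText/configs/base.yml"]) == (None, None)
```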
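
And a sketch of how a run with these overrides might be submitted through the Ray Jobs API that `ray_trainer.py` already imports; the dashboard address, working directory, and entrypoint arguments are assumptions to adapt to your cluster:

```python
from ray.job_submission import JobSubmissionClient

# Address of the Ray dashboard (assumption: default local port).
client = JobSubmissionClient("http://127.0.0.1:8265")

# steps= and steps_per_loop= are picked up by get_steps_values() in main().
submission_id = client.submit_job(
    entrypoint=(
        "python MaxText/ray_trainer.py MaxText/configs/base.yml "
        "steps=200 steps_per_loop=50"
    ),
    runtime_env={"working_dir": "."},
)
print("Submitted job:", submission_id)
```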