From cdc81bb002c9ba82f68b9350333ef5f5b099251c Mon Sep 17 00:00:00 2001 From: Tony DiGangi <38962493+tdigangi@users.noreply.github.com> Date: Mon, 25 Nov 2024 16:08:44 -0500 Subject: [PATCH 1/2] Adding steps_per_loop to base.yml, adding func to attain steps & steps per loop from argv --- MaxText/configs/base.yml | 3 +++ MaxText/ray_trainer.py | 46 ++++++++++++++++++++++++++++------------ 2 files changed, 35 insertions(+), 14 deletions(-) diff --git a/MaxText/configs/base.yml b/MaxText/configs/base.yml index 34e3a2c4a..e31d2dda6 100644 --- a/MaxText/configs/base.yml +++ b/MaxText/configs/base.yml @@ -318,6 +318,9 @@ grain_worker_count: 1 steps: 150_001 # If set to -1 then will inherit value from learning_rate_schedule_steps log_period: 100 # Flushes Tensorboard +# Training steps per loop +steps_per_loop: 100 + # We take inspiration from Llama2's learning rate (LR) schedule, see https://arxiv.org/pdf/2307.09288.pdf section 2.2 # Learning rate schedule has either two or three parts: # 1) Linear warmup from 0 to [learning_rate] over steps 0 to [learning_rate_schedule_steps * warmup_steps_fraction] diff --git a/MaxText/ray_trainer.py b/MaxText/ray_trainer.py index 55e60e659..1433aff0a 100644 --- a/MaxText/ray_trainer.py +++ b/MaxText/ray_trainer.py @@ -1,15 +1,11 @@ +import logging import ray + +from absl import app from ray_tpu import RayTpuManager from ray.job_submission import JobSubmissionClient from trainer import MaxTextTrainer - -import logging -import os -import argparse -import pyconfig -from typing import Sequence, Optional -from absl import app - +from typing import Sequence #### Configurations @@ -46,12 +42,32 @@ def get_job_submission_id() -> str: return [job.submission_id for job in jobs if job.job_id == current_job_id][0] +def get_steps_values(args): + """ + Extracts the values of 'steps' and 'steps_per_loop' from args. + + Args: + args: A list of key=value as strings. + + Returns: + A tuple containing the values of 'steps' and 'steps_per_loop' as integers. + Returns (None, None) if not found. + """ + steps = None + steps_per_loop = None + for item in args: + if item.startswith('steps='): + steps = int(item.split('=')[1]) + elif item.startswith('steps_per_loop='): + steps_per_loop = int(item.split('=')[1]) + return steps, steps_per_loop + + def main(argv: Sequence[str]): ray.init(runtime_env=dict(worker_process_setup_hook=setup_loggers)) run_name = get_job_submission_id() logging.info("Got args: %s", argv) logging.info("This run name: %s", run_name) - tpu_resources = RayTpuManager.get_available_resources() num_detected_tpu_types = len(tpu_resources.keys()) if num_detected_tpu_types == 0: @@ -84,16 +100,18 @@ def main(argv: Sequence[str]): raise e logging.info("Initialization complete. Starting MaxText training...") - total_steps = 50 #int(args.total_steps) - steps_per_loop = 100 #int(args.steps_per_loop) - steps = 0 + steps, steps_per_loop = get_steps_values(argv) + logging.info(f"KubeRay training running for total steps: {steps}, steps per loop: {steps_per_loop}") + # steps = argv.steps #50 #int(args.steps) + # steps_per_loop = argv.steps_per_loop #100 #int(args.steps_per_loop) + steps_counter = 0 - while steps < total_steps: + while steps_counter < steps: logging.info("Training from step %d to %d.", steps, steps_per_loop) try: r = ray.get([actor.train.remote(num_steps=steps_per_loop) for actor in actors]) - steps = r[0] + steps_counter = r[0] except Exception as e: logging.error("Caught error during training: %s", e) logging.error("Shutting down...") From 15351d3d98da4769fbef8a6c8c81e7bc752aed8d Mon Sep 17 00:00:00 2001 From: Tony DiGangi <38962493+tdigangi@users.noreply.github.com> Date: Mon, 25 Nov 2024 16:12:03 -0500 Subject: [PATCH 2/2] Cleaning up commented lines --- MaxText/ray_trainer.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/MaxText/ray_trainer.py b/MaxText/ray_trainer.py index 1433aff0a..a0f83ae8f 100644 --- a/MaxText/ray_trainer.py +++ b/MaxText/ray_trainer.py @@ -102,8 +102,6 @@ def main(argv: Sequence[str]): logging.info("Initialization complete. Starting MaxText training...") steps, steps_per_loop = get_steps_values(argv) logging.info(f"KubeRay training running for total steps: {steps}, steps per loop: {steps_per_loop}") - # steps = argv.steps #50 #int(args.steps) - # steps_per_loop = argv.steps_per_loop #100 #int(args.steps_per_loop) steps_counter = 0 while steps_counter < steps: