Slurm agent fn task (#3150)
* Add slurm plugin blank components

Signed-off-by: jiangjiawei1103 <[email protected]>

* feat: Add naive slurm agent create and get with rest api

Signed-off-by: jiangjiawei1103 <[email protected]>

* Use asyncssh instead of REST API

Signed-off-by: jiangjiawei1103 <[email protected]>

* Test ssh communication and run sbatch

Signed-off-by: JiaWei Jiang <[email protected]>

* Add delete method and support slurm job state

Signed-off-by: JiaWei Jiang <[email protected]>

* feat: Submit and run SlurmTask on a remote Slurm cluster

Successfully submit and run the user-defined task as a normal python
function on a remote Slurm cluster.

1. Inherit from PythonFunctionTask instead of PythonTask
2. Transfer the task module through sftp
3. Interact with amazon s3 bucket on both localhost and Slurm cluster

Signed-off-by: JiaWei Jiang <[email protected]>

* refactor: Remove redundant task_module transfer

Specifying `--raw-output-data-prefix` option handles task_module download.

Signed-off-by: JiaWei Jiang <[email protected]>

* refactor: Remove redundant env var

Signed-off-by: JiaWei Jiang <[email protected]>

* docs: Add env setup guide for local test

Signed-off-by: JiaWei Jiang <[email protected]>

* docs: Add links and figures

Signed-off-by: JiaWei Jiang <[email protected]>

* docs: Fix commit sha

Signed-off-by: JiaWei Jiang <[email protected]>

* docs: Fix commit sha for demo guide

Signed-off-by: JiaWei Jiang <[email protected]>

* docs: Fix links

Signed-off-by: JiaWei Jiang <[email protected]>

* feat: Support SSH config in task config

Add `ssh_conf` field to let users specify connection secret

Note that reconnection is done in both `get` and `delete`. This is just
a temporary workaround.

Signed-off-by: JiaWei Jiang <[email protected]>

* docs: Include ssh config in demo example

Signed-off-by: JiaWei Jiang <[email protected]>

* fix: Retain user-specified file format info

Signed-off-by: JiaWei Jiang <[email protected]>

* fix: Set sdt format based on user-specified file_format

Signed-off-by: JiaWei Jiang <[email protected]>

* Remove redundant modification

Signed-off-by: JiaWei Jiang <[email protected]>

* test: Test file_format attribute alignment in dc.sd

Signed-off-by: JiaWei Jiang <[email protected]>

* refactor: Reduce ssh_conf option to slurm_host only

For data scientists and MLEs developing flyte wf with Slurm agent,
they don't actually need to know ssh connection details. We assume
they only need to specify which Slurm cluster to use by hostname.

Signed-off-by: JiaWei Jiang <[email protected]>

* feat: Support Slurm agent with ShellTask

1. Write user-defined batch script to a tmp file
2. Transfer the batch script through sftp
3. Construct sbatch command to run on Slurm cluster

Signed-off-by: JiaWei Jiang <[email protected]>

* feat: Simplify Slurm job submission logic

1. Remove SFTP for batch script transfer
    * Assume Slurm batch script is present on Slurm cluster
2. Support directly specifying a remote batch script path

Signed-off-by: JiaWei Jiang <[email protected]>

* Added script args to agent and task

Signed-off-by: pryce-turner <[email protected]>

* Add asyncssh to dependencies

Signed-off-by: JiaWei Jiang <[email protected]>

* docs: Update setup and demo for a basic use case

Signed-off-by: JiaWei Jiang <[email protected]>

* docs: Update basic arch figure path

Signed-off-by: JiaWei Jiang <[email protected]>

* docs: Fix typo and hyperlink

Signed-off-by: JiaWei Jiang <[email protected]>

* fix: A tmp workaround to test agent locally without container_image

Signed-off-by: JiaWei Jiang <[email protected]>

* feat: Support user-defined batch script content with SlurmShellTask

`SlurmTask` and `SlurmShellTask` now share the same agent.

Signed-off-by: JiaWei Jiang <[email protected]>

* feat: Fall back to PythonTask for naive use cases

1. Inherited from `PythonTask` for cases in which the batch script is
    already on the Slurm cluster
2. Use a dummy `Interface` as a tmp workaround

Signed-off-by: JiaWei Jiang <[email protected]>

* refactor: Define Slurm as a base task config and extend for remote script

Signed-off-by: JiaWei Jiang <[email protected]>

* feat: Support PythonFunctionTask and reorganize agent structure

1. Add back `PythonFunctionTask` to support running user-defined functions
    on Slurm
2. Categorize task types into `script/` and `function/`

Signed-off-by: JiaWei Jiang <[email protected]>

* Use poetry virtual env to avoid contamination

Signed-off-by: JiangJiaWei1103 <[email protected]>

* docs: Complete local test env setup process

Signed-off-by: JiangJiaWei1103 <[email protected]>

* docs: Add use cases ranging from basic to advanced

Signed-off-by: JiangJiaWei1103 <[email protected]>

* feat: Add a script option for the Slurm function task

Signed-off-by: JiangJiaWei1103 <[email protected]>

* fix: Avoid attaching async resource to different event loops

Signed-off-by: JiangJiaWei1103 <[email protected]>

* use await self._connect(slurm_host) in slurm agent

Signed-off-by: Future-Outlier <[email protected]>

* change

Signed-off-by: Future-Outlier <[email protected]>

* print more info

Signed-off-by: Future-Outlier <[email protected]>

* use logger

Signed-off-by: Future-Outlier <[email protected]>

* print more info

Signed-off-by: Future-Outlier <[email protected]>

* print

Signed-off-by: Future-Outlier <[email protected]>

* Use sbatch for running Slurm function task

Signed-off-by: JiangJiaWei1103 <[email protected]>

* update

Signed-off-by: Future-Outlier <[email protected]>

* push

Signed-off-by: Future-Outlier <[email protected]>

* feat: Show stdout and stderr msg of the Slurm cluster

Signed-off-by: JiangJiaWei1103 <[email protected]>

* feat: Show stdout and stderr msg of the Slurm cluster for SlurmFunctionTask

Signed-off-by: JiangJiaWei1103 <[email protected]>

* feat: Make an SSH connection based on client config file or ssh_config

1. Make SSH `host` and `username` required fields
2. Support SSH connection based on the default OpenSSH client config
    file `~/.ssh/config`
3. Support SSH connection via public key auth either by user-specified
    `client_keys` or the secret for key `FLYTE_SLURM_PRIVATE_KEY`

Signed-off-by: JiangJiaWei1103 <[email protected]>

* Clarify SSH connection logic

Signed-off-by: JiangJiaWei1103 <[email protected]>

* feat: Interpolate the script with dynamic input values

Signed-off-by: JiangJiaWei1103 <[email protected]>

* feat: Interpolate the script with dynamic output values

Support passing files across multiple `SlurmShellTask`

Signed-off-by: JiangJiaWei1103 <[email protected]>

* add assertion

Signed-off-by: Future-Outlier <[email protected]>

* update

Signed-off-by: Future-Outlier <[email protected]>

* update

Signed-off-by: Future-Outlier <[email protected]>

* Fix Script agent bug

Signed-off-by: Future-Outlier <[email protected]>

* agent service for shell task

Signed-off-by: Future-Outlier <[email protected]>

* Remove remote path to avoid race condition

Signed-off-by: Future-Outlier <[email protected]>

* Revert agent server change

Signed-off-by: Future-Outlier <[email protected]>

* use key val to run ssh config

Signed-off-by: Future-Outlier <[email protected]>

* update

Signed-off-by: Future-Outlier <[email protected]>

* use _get_or_create_ssh_connection

Signed-off-by: Future-Outlier <[email protected]>

* update

Signed-off-by: Future-Outlier <[email protected]>

* use SlurmCluster and hash

Signed-off-by: Future-Outlier <[email protected]>

* update

Signed-off-by: Future-Outlier <[email protected]>

* update

Signed-off-by: Future-Outlier <[email protected]>

* update

Signed-off-by: Future-Outlier <[email protected]>

* refactor: Simplify validation process and clean up legacy code

1. Ensure `"host"` must be provided in `__post_init__`
2. Explicitly set `known_hosts` to `None`
3. Make `username` optional
4. Remove legacy code snippets
5. Make docstring clear

Signed-off-by: JiangJiaWei1103 <[email protected]>

* Add Slurm agent function task

Signed-off-by: JiangJiaWei1103 <[email protected]>

* Revert ShellTask behavior

Signed-off-by: JiangJiaWei1103 <[email protected]>

* Remove fix for SlurmShellTask

Signed-off-by: JiangJiaWei1103 <[email protected]>

* Remove blank line

Signed-off-by: JiangJiaWei1103 <[email protected]>

* fix doc string and remove logs

Signed-off-by: Future-Outlier <[email protected]>

* build plugins

Signed-off-by: Future-Outlier <[email protected]>

* merge master

Signed-off-by: Future-Outlier <[email protected]>

* fix-sage-maker-test

Signed-off-by: Future-Outlier <[email protected]>

* add test_slurm_fn_task

Signed-off-by: Future-Outlier <[email protected]>

* fix

Signed-off-by: Future-Outlier <[email protected]>

* update flytebot

Signed-off-by: Future-Outlier <[email protected]>

* add known_hosts = None

Signed-off-by: Future-Outlier <[email protected]>

---------

Signed-off-by: jiangjiawei1103 <[email protected]>
Signed-off-by: JiaWei Jiang <[email protected]>
Signed-off-by: pryce-turner <[email protected]>
Signed-off-by: JiangJiaWei1103 <[email protected]>
Signed-off-by: Future-Outlier <[email protected]>
Co-authored-by: pryce-turner <[email protected]>
Co-authored-by: Future-Outlier <[email protected]>
3 people authored Feb 22, 2025
1 parent d7a3cef commit acc09a5
Showing 11 changed files with 512 additions and 5 deletions.
1 change: 1 addition & 0 deletions .github/workflows/pythonbuild.yml
@@ -337,6 +337,7 @@ jobs:
   - flytekit-papermill
   - flytekit-polars
   - flytekit-ray
+  - flytekit-slurm
   - flytekit-snowflake
   - flytekit-spark
   - flytekit-sqlalchemy
4 changes: 2 additions & 2 deletions flytekit/extend/backend/utils.py
@@ -20,9 +20,9 @@ def convert_to_flyte_phase(state: str) -> TaskExecution.Phase:
     Convert the state from the agent to the phase in flyte.
     """
     state = state.lower()
-    if state in ["failed", "timeout", "timedout", "canceled", "skipped", "internal_error"]:
+    if state in ["failed", "timeout", "timedout", "canceled", "cancelled", "skipped", "internal_error"]:
         return TaskExecution.FAILED
-    elif state in ["done", "succeeded", "success"]:
+    elif state in ["done", "succeeded", "success", "completed"]:
         return TaskExecution.SUCCEEDED
     elif state in ["running", "terminating"]:
         return TaskExecution.RUNNING
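The two added strings matter because Slurm reports job states such as `COMPLETED` and `CANCELLED` (double "l"), which the previous lists did not cover. The following is a minimal, self-contained sketch of the branches visible in this hunk only. `Phase` and `to_phase` are hypothetical stand-ins for `TaskExecution.Phase` and `convert_to_flyte_phase`, and the `UNKNOWN` fallback is an assumption, since the remaining branches of the real function are not shown in the diff.

```python
from enum import Enum


class Phase(Enum):
    """Hypothetical stand-in for TaskExecution.Phase (illustration only)."""

    FAILED = "FAILED"
    SUCCEEDED = "SUCCEEDED"
    RUNNING = "RUNNING"
    UNKNOWN = "UNKNOWN"  # assumption: branches outside the hunk are not modelled


# State lists copied from the "+" lines and context lines of the hunk above.
FAILED_STATES = {"failed", "timeout", "timedout", "canceled", "cancelled", "skipped", "internal_error"}
SUCCEEDED_STATES = {"done", "succeeded", "success", "completed"}
RUNNING_STATES = {"running", "terminating"}


def to_phase(state: str) -> Phase:
    """Map an agent-reported state string to a phase, case-insensitively."""
    state = state.lower()
    if state in FAILED_STATES:
        return Phase.FAILED
    if state in SUCCEEDED_STATES:
        return Phase.SUCCEEDED
    if state in RUNNING_STATES:
        return Phase.RUNNING
    return Phase.UNKNOWN


for slurm_state in ("COMPLETED", "CANCELLED", "RUNNING"):
    print(slurm_state, "->", to_phase(slurm_state).name)
```

Because the input is lower-cased first, the upper-case state names that Slurm prints match the lower-case entries in the lists.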
6 changes: 3 additions & 3 deletions plugins/flytekit-aws-sagemaker/tests/test_boto3_mixin.py
@@ -5,7 +5,7 @@
 from flytekitplugins.awssagemaker_inference import triton_image_uri
 from flytekitplugins.awssagemaker_inference.boto3_mixin import (
     Boto3AgentMixin,
-    update_dict_fn,
+    format_dict,
 )

 from flytekit import FlyteContext, StructuredDataset
@@ -50,7 +50,7 @@ def test_inputs():
         },
     )

-    result = update_dict_fn(
+    result = format_dict(
         service="s3",
         original_dict=original_dict,
         update_dict={"inputs": literal_map_string_repr(inputs)},
@@ -75,7 +75,7 @@ def test_container():
     original_dict = {"a": "{images.primary_container_image}"}
     images = {"primary_container_image": "cr.flyte.org/flyteorg/flytekit:py3.11-1.10.3"}

-    result = update_dict_fn(
+    result = format_dict(
         service="sagemaker", original_dict=original_dict, update_dict={"images": images}
     )

5 changes: 5 additions & 0 deletions plugins/flytekit-slurm/README.md
@@ -0,0 +1,5 @@
# Flytekit Slurm Plugin

The Slurm agent integrates Flyte workflows with Slurm-managed high-performance computing (HPC) clusters, so that users can take advantage of Slurm's compute resource allocation, scheduling, and monitoring capabilities.

This [guide](https://github.com/JiangJiaWei1103/flytekit/blob/slurm-agent-dev/plugins/flytekit-slurm/demo.md) provides a concise overview of the design philosophy behind the Slurm agent and explains how to set up a local environment for testing the agent.
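On the monitoring side, the agent added in this commit polls a job with `scontrol show job <job_id>` and reads the `JobState` and `StdOut` fields from the output. The sketch below shows one way to parse such output. It is an illustration, not the agent's code: `parse_scontrol_fields` is a hypothetical helper, and the sample text is an abbreviated, made-up example of the `Key=Value` layout that `scontrol` prints.

```python
from typing import Dict

# Abbreviated, hypothetical `scontrol show job` output; real output has many more fields.
SAMPLE_OUTPUT = """JobId=123 JobName=demo
   JobState=COMPLETED Reason=None Dependency=(null)
   StdErr=/home/alice/slurm-123.out
   StdIn=/dev/null
   StdOut=/home/alice/slurm-123.out
"""


def parse_scontrol_fields(output: str) -> Dict[str, str]:
    """Split whitespace-separated Key=Value tokens into a dict (hypothetical helper)."""
    fields: Dict[str, str] = {}
    for token in output.split():
        key, sep, value = token.partition("=")
        if sep:  # ignore tokens that carry no "="
            fields[key] = value
    return fields


fields = parse_scontrol_fields(SAMPLE_OUTPUT)
job_state = fields.get("JobState", "running").lower()  # same default as the agent's get()
stdout_path = fields.get("StdOut")
print(job_state, stdout_path)
```

Matching on exact keys avoids accidental substring hits, but this simple tokenizer still assumes that values contain no whitespace.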
2 changes: 2 additions & 0 deletions plugins/flytekit-slurm/flytekitplugins/slurm/__init__.py
@@ -0,0 +1,2 @@
from .function.agent import SlurmFunctionAgent
from .function.task import SlurmFunction, SlurmFunctionTask
190 changes: 190 additions & 0 deletions plugins/flytekit-slurm/flytekitplugins/slurm/function/agent.py
@@ -0,0 +1,190 @@
import tempfile
import uuid
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple, Union

from asyncssh import SSHClientConnection

from flytekit import logger
from flytekit.extend.backend.base_agent import AgentRegistry, AsyncAgentBase, Resource, ResourceMeta
from flytekit.extend.backend.utils import convert_to_flyte_phase
from flytekit.models.literals import LiteralMap
from flytekit.models.task import TaskTemplate

from ..ssh_utils import ssh_connect


@dataclass
class SlurmJobMetadata(ResourceMeta):
    """Slurm job metadata.

    Args:
        job_id: Slurm job id.
        ssh_config: Options of SSH client connection. For available options, please refer to
            the ssh_utils module.

    Attributes:
        job_id (str): Slurm job id.
        ssh_config (Dict[str, Union[str, List[str], Tuple[str, ...]]]): SSH configuration options
            for establishing client connections.
    """

    job_id: str
    ssh_config: Dict[str, Union[str, List[str], Tuple[str, ...]]]


@dataclass
class SlurmCluster:
    host: str
    username: Optional[str] = None

    def __hash__(self):
        return hash((self.host, self.username))


class SlurmFunctionAgent(AsyncAgentBase):
    name = "Slurm Function Agent"

    # SSH connection pool for multi-host environment
    ssh_config_to_ssh_conn: Dict[SlurmCluster, SSHClientConnection] = {}

    def __init__(self) -> None:
        super(SlurmFunctionAgent, self).__init__(task_type_name="slurm_fn", metadata_type=SlurmJobMetadata)

    async def create(
        self,
        task_template: TaskTemplate,
        inputs: Optional[LiteralMap] = None,
        **kwargs,
    ) -> SlurmJobMetadata:
        unique_script_name = f"/tmp/task_{uuid.uuid4().hex}.slurm"

        # Retrieve task config
        ssh_config = task_template.custom["ssh_config"]
        sbatch_conf = task_template.custom["sbatch_conf"]
        script = task_template.custom["script"]

        # Construct command for Slurm cluster
        cmd, script = _get_sbatch_cmd_and_script(
            sbatch_conf=sbatch_conf,
            entrypoint=" ".join(task_template.container.args),
            script=script,
            batch_script_path=unique_script_name,
        )

        # Run Slurm job
        conn = await self._get_or_create_ssh_connection(ssh_config)
        with tempfile.NamedTemporaryFile("w") as f:
            f.write(script)
            f.flush()
            async with conn.start_sftp_client() as sftp:
                await sftp.put(f.name, unique_script_name)
        res = await conn.run(cmd, check=True)

        # Retrieve Slurm job id
        job_id = res.stdout.split()[-1]

        return SlurmJobMetadata(job_id=job_id, ssh_config=ssh_config)

    async def get(self, resource_meta: SlurmJobMetadata, **kwargs) -> Resource:
        ssh_config = resource_meta.ssh_config
        conn = await self._get_or_create_ssh_connection(ssh_config)
        job_res = await conn.run(f"scontrol show job {resource_meta.job_id}", check=True)

        # Determine the current flyte phase from Slurm job state
        job_state = "running"
        msg = "No stdout available"
        for o in job_res.stdout.split(" "):
            if "JobState" in o:
                job_state = o.split("=")[1].strip().lower()
            elif "StdOut" in o:
                stdout_path = o.split("=")[1].strip()
                msg_res = await conn.run(f"cat {stdout_path}", check=True)
                msg = msg_res.stdout

        cur_phase = convert_to_flyte_phase(job_state)

        return Resource(phase=cur_phase, message=msg)

    async def delete(self, resource_meta: SlurmJobMetadata, **kwargs) -> None:
        conn = await self._get_or_create_ssh_connection(resource_meta.ssh_config)
        _ = await conn.run(f"scancel {resource_meta.job_id}", check=True)

    async def _get_or_create_ssh_connection(
        self, ssh_config: Dict[str, Union[str, List[str], Tuple[str, ...]]]
    ) -> SSHClientConnection:
        """Get an existing SSH connection or create a new one if needed.

        Args:
            ssh_config (Dict[str, Union[str, List[str], Tuple[str, ...]]]): SSH configuration dictionary.

        Returns:
            SSHClientConnection: An active SSH connection, either pre-existing or newly established.
        """
        host = ssh_config.get("host")
        username = ssh_config.get("username")

        ssh_cluster_config = SlurmCluster(host=host, username=username)
        if self.ssh_config_to_ssh_conn.get(ssh_cluster_config) is None:
            logger.info("ssh connection key not found, creating new connection")
            conn = await ssh_connect(ssh_config=ssh_config)
            self.ssh_config_to_ssh_conn[ssh_cluster_config] = conn
        else:
            conn = self.ssh_config_to_ssh_conn[ssh_cluster_config]
            try:
                # Probe the cached connection before handing it out
                await conn.run("echo [TEST] SSH connection", check=True)
                logger.info("re-using existing connection")
            except Exception as e:
                logger.info(f"Re-establishing SSH connection due to error: {e}")
                conn = await ssh_connect(ssh_config=ssh_config)
                self.ssh_config_to_ssh_conn[ssh_cluster_config] = conn

        return conn


def _get_sbatch_cmd_and_script(
    sbatch_conf: Dict[str, str],
    entrypoint: str,
    script: Optional[str] = None,
    batch_script_path: str = "/tmp/task.slurm",
) -> Tuple[str, str]:
    """Construct the Slurm sbatch command and the batch script content.

    Flyte entrypoint, pyflyte-execute, is run within a bash shell process.

    Args:
        sbatch_conf (Dict[str, str]): Options of sbatch command.
        entrypoint (str): Flyte entrypoint.
        script (Optional[str], optional): User-defined script where "{task.fn}" serves as a placeholder for the
            task function execution. Users should insert "{task.fn}" at the desired
            execution point within the script. If the script is not provided, the task
            function will be executed directly. Defaults to None.
        batch_script_path (str, optional): Absolute path of the batch script on Slurm cluster.
            Defaults to "/tmp/task.slurm".

    Returns:
        Tuple[str, str]: A tuple containing:
            - cmd: Slurm sbatch command
            - script: The batch script content
    """
    # Setup sbatch options
    cmd = ["sbatch"]
    for opt, val in sbatch_conf.items():
        cmd.extend([f"--{opt}", str(val)])

    # Assign the batch script to run
    cmd.append(batch_script_path)

    if script is None:
        script = f"""#!/bin/bash -i
{entrypoint}
"""
    else:
        script = script.replace("{task.fn}", entrypoint)

    cmd = " ".join(cmd)

    return cmd, script


AgentRegistry.register(SlurmFunctionAgent())
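To make the command construction concrete, here is a small standalone sketch that mirrors the logic of `_get_sbatch_cmd_and_script` and the job id extraction in `create`. The function name `build_sbatch`, the sbatch option values, the script template, and the `pyflyte-execute <args>` entrypoint string are all placeholders chosen for illustration. The `Submitted batch job 123` line is an example of the message `sbatch` prints on success.

```python
from typing import Dict, Optional, Tuple


def build_sbatch(
    sbatch_conf: Dict[str, str],
    entrypoint: str,
    script: Optional[str] = None,
    batch_script_path: str = "/tmp/task.slurm",
) -> Tuple[str, str]:
    """Standalone mirror of the helper in the diff (illustration only)."""
    cmd = ["sbatch"]
    for opt, val in sbatch_conf.items():
        cmd.extend([f"--{opt}", str(val)])  # each option becomes "--opt value"
    cmd.append(batch_script_path)

    if script is None:
        # No template: the batch script just runs the entrypoint
        script = f"#!/bin/bash -i\n{entrypoint}\n"
    else:
        # Template: substitute the entrypoint where the user put "{task.fn}"
        script = script.replace("{task.fn}", entrypoint)
    return " ".join(cmd), script


ENTRYPOINT = "pyflyte-execute <args>"  # placeholder, not a real argument list

TEMPLATE = """#!/bin/bash -i
echo "before the task"
{task.fn}
echo "after the task"
"""

cmd, script = build_sbatch(
    {"partition": "debug", "job-name": "demo"},  # hypothetical sbatch options
    ENTRYPOINT,
    script=TEMPLATE,
    batch_script_path="/tmp/task_demo.slurm",
)
print(cmd)
print(script)

# sbatch reports the new job on stdout; the agent keeps the last token as the job id.
job_id = "Submitted batch job 123".split()[-1]
print(job_id)
```

Note that option values are joined into the command string without shell quoting, so values containing spaces or shell metacharacters would need extra care (for example with `shlex.quote`) in a hardened version.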
83 changes: 83 additions & 0 deletions plugins/flytekit-slurm/flytekitplugins/slurm/function/task.py
@@ -0,0 +1,83 @@
"""
Slurm task.
"""

from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

from flytekit import FlyteContextManager, PythonFunctionTask
from flytekit.configuration import SerializationSettings
from flytekit.extend import TaskPlugins
from flytekit.extend.backend.base_agent import AsyncAgentExecutorMixin
from flytekit.image_spec import ImageSpec


@dataclass
class SlurmFunction(object):
    """Configure Slurm settings. Note that we focus on sbatch command now.

    Args:
        ssh_config: Options of SSH client connection. For available options, please refer to
            the ssh_utils module.
        sbatch_conf: Options of sbatch command. If not provided, defaults to an empty dict.
        script: User-defined script where "{task.fn}" serves as a placeholder for the
            task function execution. Users should insert "{task.fn}" at the desired
            execution point within the script. If the script is not provided, the task
            function will be executed directly.

    Attributes:
        ssh_config (Dict[str, Union[str, List[str], Tuple[str, ...]]]): SSH client configuration options.
        sbatch_conf (Optional[Dict[str, str]]): Slurm sbatch command options.
        script (Optional[str]): Custom script template for task execution.
    """

    ssh_config: Dict[str, Union[str, List[str], Tuple[str, ...]]]
    sbatch_conf: Optional[Dict[str, str]] = None
    script: Optional[str] = None

    def __post_init__(self):
        assert self.ssh_config["host"] is not None, "'host' must be specified in ssh_config."
        if self.sbatch_conf is None:
            self.sbatch_conf = {}


class SlurmFunctionTask(AsyncAgentExecutorMixin, PythonFunctionTask[SlurmFunction]):
    """
    Actual Plugin that transforms the local python code for execution within a slurm context...
    """

    _TASK_TYPE = "slurm_fn"

    def __init__(
        self,
        task_config: SlurmFunction,
        task_function: Callable,
        container_image: Optional[Union[str, ImageSpec]] = None,
        **kwargs,
    ):
        super(SlurmFunctionTask, self).__init__(
            task_config=task_config,
            task_type=self._TASK_TYPE,
            task_function=task_function,
            container_image=container_image,
            **kwargs,
        )

    def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]:
        return {
            "ssh_config": self.task_config.ssh_config,
            "sbatch_conf": self.task_config.sbatch_conf,
            "script": self.task_config.script,
        }

    def execute(self, **kwargs) -> Any:
        ctx = FlyteContextManager.current_context()
        if ctx.execution_state and ctx.execution_state.is_local_execution():
            # Mimic the propeller's behavior in local agent test
            return AsyncAgentExecutorMixin.execute(self, **kwargs)
        else:
            # Execute the task with a direct python function call
            return PythonFunctionTask.execute(self, **kwargs)


TaskPlugins.register_pythontask_plugin(SlurmFunction, SlurmFunctionTask)
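One detail of the `SlurmFunction` config is worth spelling out: `self.ssh_config["host"]` raises a `KeyError` when the key is absent, before the assertion message can be shown, and `assert` statements are skipped when Python runs with `-O`. The sketch below is a standalone variant, not the plugin's code: `SlurmFunctionSketch` and `to_custom` are hypothetical names, and the host `slurm.example.org` is a placeholder. It uses `.get` plus an explicit `ValueError`, and shows the dictionary shape that `get_custom` passes to the agent.

```python
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Union

SSHConfig = Dict[str, Union[str, List[str], Tuple[str, ...]]]


@dataclass
class SlurmFunctionSketch:
    """Standalone variant of the task config (illustration only)."""

    ssh_config: SSHConfig
    sbatch_conf: Optional[Dict[str, str]] = None
    script: Optional[str] = None

    def __post_init__(self) -> None:
        # .get() covers both a missing key and an explicit None value
        if self.ssh_config.get("host") is None:
            raise ValueError("'host' must be specified in ssh_config.")
        if self.sbatch_conf is None:
            self.sbatch_conf = {}

    def to_custom(self) -> Dict[str, Any]:
        """Same shape as the dict returned by get_custom in the diff."""
        return {
            "ssh_config": self.ssh_config,
            "sbatch_conf": self.sbatch_conf,
            "script": self.script,
        }


cfg = SlurmFunctionSketch(ssh_config={"host": "slurm.example.org"})  # placeholder host
print(cfg.to_custom())

try:
    SlurmFunctionSketch(ssh_config={"username": "alice"})
except ValueError as err:
    print(f"rejected: {err}")
```

Whether the plugin should adopt this stricter check is a design choice for its maintainers; the sketch only illustrates the difference in failure modes.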