-
Notifications
You must be signed in to change notification settings - Fork 310
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Slurm agent fn task #3150
Slurm agent fn task #3150
Changes from 67 commits
421d1b8
1d1f806
5d97126
2e7f0f2
9644b99
e41b181
6db24dc
122c7f1
e9760a7
e68fda9
470637c
1579ab4
0e538f0
8229418
9e6d8a6
cdb56c0
36ab6e9
58751fb
4af1f00
46b7b62
e07b09a
3a7eb6d
a815fd9
a3ea014
a109bd8
e5da665
0a3d9f1
1b0f6df
26cc201
16d953e
c743917
e365dee
c1064d4
361fbd1
fc3e34e
5cd58ec
9985305
e00a5db
805548c
34c7c51
4e85b0a
27cdc8d
1f30e04
079d101
1ce5cf8
7b9b38f
b23174a
3727f4d
f291e0b
a465607
7b75a51
d0967bd
f80ccd4
ac25446
aa6e3ae
5fcf6c5
64e9e06
9d06ed1
6cda6ae
3b64ddd
b107b07
37e4ee1
77f4d61
68ab8a6
a968c01
13cd31e
52d73d2
6b27668
5aa4e87
3cc29fe
19de56a
31fc564
80f3693
96756df
074cf91
2d0e0d7
acfaa76
6227801
860b703
45aae83
285a9e1
6b70901
cf8eafa
f0f7930
d0d59d3
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
# Flytekit Slurm Plugin | ||
|
||
The Slurm agent is designed to integrate Flyte workflows with Slurm-managed high-performance computing (HPC) clusters, enabling users to leverage Slurm's capability of compute resource allocation, scheduling, and monitoring. | ||
|
||
This [guide](https://github.com/JiangJiaWei1103/flytekit/blob/slurm-agent-dev/plugins/flytekit-slurm/demo.md) provides a concise overview of the design philosophy behind the Slurm agent and explains how to set up a local environment for testing the agent. |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
from .function.agent import SlurmFunctionAgent | ||
from .function.task import SlurmFunction, SlurmFunctionTask |
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
@@ -0,0 +1,191 @@ | ||||||||||||||||||||||||||||||||||||
import tempfile | ||||||||||||||||||||||||||||||||||||
import uuid | ||||||||||||||||||||||||||||||||||||
from dataclasses import dataclass | ||||||||||||||||||||||||||||||||||||
from typing import Dict, List, Optional, Tuple, Union | ||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
from asyncssh import SSHClientConnection | ||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
from flytekit import logger | ||||||||||||||||||||||||||||||||||||
from flytekit.extend.backend.base_agent import AgentRegistry, AsyncAgentBase, Resource, ResourceMeta | ||||||||||||||||||||||||||||||||||||
from flytekit.extend.backend.utils import convert_to_flyte_phase | ||||||||||||||||||||||||||||||||||||
from flytekit.models.literals import LiteralMap | ||||||||||||||||||||||||||||||||||||
from flytekit.models.task import TaskTemplate | ||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
from ..ssh_utils import ssh_connect | ||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
@dataclass | ||||||||||||||||||||||||||||||||||||
class SlurmJobMetadata(ResourceMeta): | ||||||||||||||||||||||||||||||||||||
"""Slurm job metadata. | ||||||||||||||||||||||||||||||||||||
Args: | ||||||||||||||||||||||||||||||||||||
job_id: Slurm job id. | ||||||||||||||||||||||||||||||||||||
ssh_config: Options of SSH client connection. For available options, please refer to | ||||||||||||||||||||||||||||||||||||
<newly-added-ssh-utils-file> | ||||||||||||||||||||||||||||||||||||
""" | ||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
job_id: str | ||||||||||||||||||||||||||||||||||||
ssh_config: Dict[str, Union[str, List[str], Tuple[str, ...]]] | ||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
@dataclass | ||||||||||||||||||||||||||||||||||||
class SlurmCluster: | ||||||||||||||||||||||||||||||||||||
host: str | ||||||||||||||||||||||||||||||||||||
username: Optional[str] = None | ||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
def __hash__(self): | ||||||||||||||||||||||||||||||||||||
return hash((self.host, self.username)) | ||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
class SlurmFunctionAgent(AsyncAgentBase): | ||||||||||||||||||||||||||||||||||||
name = "Slurm Function Agent" | ||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
# SSH connection pool for multi-host environment | ||||||||||||||||||||||||||||||||||||
ssh_config_to_ssh_conn: Dict[SlurmCluster, SSHClientConnection] = {} | ||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
def __init__(self) -> None: | ||||||||||||||||||||||||||||||||||||
super(SlurmFunctionAgent, self).__init__(task_type_name="slurm_fn", metadata_type=SlurmJobMetadata) | ||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
async def create( | ||||||||||||||||||||||||||||||||||||
self, | ||||||||||||||||||||||||||||||||||||
task_template: TaskTemplate, | ||||||||||||||||||||||||||||||||||||
inputs: Optional[LiteralMap] = None, | ||||||||||||||||||||||||||||||||||||
**kwargs, | ||||||||||||||||||||||||||||||||||||
) -> SlurmJobMetadata: | ||||||||||||||||||||||||||||||||||||
unique_script_name = f"/tmp/task_{uuid.uuid4().hex}.slurm" | ||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
# Retrieve task config | ||||||||||||||||||||||||||||||||||||
ssh_config = task_template.custom["ssh_config"] | ||||||||||||||||||||||||||||||||||||
sbatch_conf = task_template.custom["sbatch_conf"] | ||||||||||||||||||||||||||||||||||||
script = task_template.custom["script"] | ||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
# Construct command for Slurm cluster | ||||||||||||||||||||||||||||||||||||
cmd, script = _get_sbatch_cmd_and_script( | ||||||||||||||||||||||||||||||||||||
sbatch_conf=sbatch_conf, | ||||||||||||||||||||||||||||||||||||
entrypoint=" ".join(task_template.container.args), | ||||||||||||||||||||||||||||||||||||
script=script, | ||||||||||||||||||||||||||||||||||||
batch_script_path=unique_script_name, | ||||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
logger.info("@@@ task_template.container.args:") | ||||||||||||||||||||||||||||||||||||
logger.info(task_template.container.args) | ||||||||||||||||||||||||||||||||||||
logger.info("@@@ Slurm Command: ") | ||||||||||||||||||||||||||||||||||||
logger.info(cmd) | ||||||||||||||||||||||||||||||||||||
logger.info("@@@ Batch script: ") | ||||||||||||||||||||||||||||||||||||
logger.info(script) | ||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
# Run Slurm job | ||||||||||||||||||||||||||||||||||||
conn = await self._get_or_create_ssh_connection(ssh_config) | ||||||||||||||||||||||||||||||||||||
with tempfile.NamedTemporaryFile("w") as f: | ||||||||||||||||||||||||||||||||||||
f.write(script) | ||||||||||||||||||||||||||||||||||||
f.flush() | ||||||||||||||||||||||||||||||||||||
async with conn.start_sftp_client() as sftp: | ||||||||||||||||||||||||||||||||||||
await sftp.put(f.name, unique_script_name) | ||||||||||||||||||||||||||||||||||||
res = await conn.run(cmd, check=True) | ||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
# Retrieve Slurm job id | ||||||||||||||||||||||||||||||||||||
job_id = res.stdout.split()[-1] | ||||||||||||||||||||||||||||||||||||
logger.info("@@@ create slurm job id: " + job_id) | ||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
return SlurmJobMetadata(job_id=job_id, ssh_config=ssh_config) | ||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
async def get(self, resource_meta: SlurmJobMetadata, **kwargs) -> Resource: | ||||||||||||||||||||||||||||||||||||
ssh_config = resource_meta.ssh_config | ||||||||||||||||||||||||||||||||||||
conn = await self._get_or_create_ssh_connection(ssh_config) | ||||||||||||||||||||||||||||||||||||
job_res = await conn.run(f"scontrol show job {resource_meta.job_id}", check=True) | ||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
# Determine the current flyte phase from Slurm job state | ||||||||||||||||||||||||||||||||||||
job_state = "running" | ||||||||||||||||||||||||||||||||||||
for o in job_res.stdout.split(" "): | ||||||||||||||||||||||||||||||||||||
if "JobState" in o: | ||||||||||||||||||||||||||||||||||||
job_state = o.split("=")[1].strip().lower() | ||||||||||||||||||||||||||||||||||||
elif "StdOut" in o: | ||||||||||||||||||||||||||||||||||||
stdout_path = o.split("=")[1].strip() | ||||||||||||||||||||||||||||||||||||
msg_res = await conn.run(f"cat {stdout_path}", check=True) | ||||||||||||||||||||||||||||||||||||
msg = msg_res.stdout | ||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Consider adding error handling for stdout
Consider adding error handling for the case when Code suggestionCheck the AI-generated fix before applying
Suggested change
Code Review Run #2afb6d Should Bito avoid suggestions like this for future reviews? (Manage Rules)
|
||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
logger.info("@@@ GET PHASE: ") | ||||||||||||||||||||||||||||||||||||
logger.info(str(job_state)) | ||||||||||||||||||||||||||||||||||||
cur_phase = convert_to_flyte_phase(job_state) | ||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
return Resource(phase=cur_phase, message=msg) | ||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
async def delete(self, resource_meta: SlurmJobMetadata, **kwargs) -> None: | ||||||||||||||||||||||||||||||||||||
conn = await self._get_or_create_ssh_connection(resource_meta.ssh_config) | ||||||||||||||||||||||||||||||||||||
_ = await conn.run(f"scancel {resource_meta.job_id}", check=True) | ||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
async def _get_or_create_ssh_connection( | ||||||||||||||||||||||||||||||||||||
self, ssh_config: Dict[str, Union[str, List[str], Tuple[str, ...]]] | ||||||||||||||||||||||||||||||||||||
) -> SSHClientConnection: | ||||||||||||||||||||||||||||||||||||
"""Get an existing SSH connection or create a new one if needed. | ||||||||||||||||||||||||||||||||||||
Args: | ||||||||||||||||||||||||||||||||||||
ssh_config: SSH configuration dictionary. | ||||||||||||||||||||||||||||||||||||
Returns: | ||||||||||||||||||||||||||||||||||||
An active SSH connection, either pre-existing or newly established. | ||||||||||||||||||||||||||||||||||||
""" | ||||||||||||||||||||||||||||||||||||
host = ssh_config.get("host") | ||||||||||||||||||||||||||||||||||||
username = ssh_config.get("username") | ||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
ssh_cluster_config = SlurmCluster(host=host, username=username) | ||||||||||||||||||||||||||||||||||||
if self.ssh_config_to_ssh_conn.get(ssh_cluster_config) is None: | ||||||||||||||||||||||||||||||||||||
logger.info("ssh connection key not found, creating new connection") | ||||||||||||||||||||||||||||||||||||
conn = await ssh_connect(ssh_config=ssh_config) | ||||||||||||||||||||||||||||||||||||
self.ssh_config_to_ssh_conn[ssh_cluster_config] = conn | ||||||||||||||||||||||||||||||||||||
else: | ||||||||||||||||||||||||||||||||||||
conn = self.ssh_config_to_ssh_conn[ssh_cluster_config] | ||||||||||||||||||||||||||||||||||||
try: | ||||||||||||||||||||||||||||||||||||
await conn.run("echo [TEST] SSH connection", check=True) | ||||||||||||||||||||||||||||||||||||
logger.info("re-using new connection") | ||||||||||||||||||||||||||||||||||||
except Exception as e: | ||||||||||||||||||||||||||||||||||||
logger.info(f"Re-establishing SSH connection due to error: {e}") | ||||||||||||||||||||||||||||||||||||
conn = await ssh_connect(ssh_config=ssh_config) | ||||||||||||||||||||||||||||||||||||
self.ssh_config_to_ssh_conn[ssh_cluster_config] = conn | ||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
return conn | ||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
def _get_sbatch_cmd_and_script( | ||||||||||||||||||||||||||||||||||||
sbatch_conf: Dict[str, str], | ||||||||||||||||||||||||||||||||||||
entrypoint: str, | ||||||||||||||||||||||||||||||||||||
script: Optional[str] = None, | ||||||||||||||||||||||||||||||||||||
batch_script_path: str = "/tmp/task.slurm", | ||||||||||||||||||||||||||||||||||||
) -> str: | ||||||||||||||||||||||||||||||||||||
"""Construct the Slurm sbatch command and the batch script content. | ||||||||||||||||||||||||||||||||||||
Flyte entrypoint, pyflyte-execute, is run within a bash shell process. | ||||||||||||||||||||||||||||||||||||
Args: | ||||||||||||||||||||||||||||||||||||
sbatch_conf: Options of sbatch command. | ||||||||||||||||||||||||||||||||||||
entrypoint: Flyte entrypoint. | ||||||||||||||||||||||||||||||||||||
script: User-defined script where "{task.fn}" serves as a placeholder for the | ||||||||||||||||||||||||||||||||||||
task function execution. Users should insert "{task.fn}" at the desired | ||||||||||||||||||||||||||||||||||||
execution point within the script. If the script is not provided, the task | ||||||||||||||||||||||||||||||||||||
function will be executed directly. | ||||||||||||||||||||||||||||||||||||
batch_script_path: Absolute path of the batch script on Slurm cluster. | ||||||||||||||||||||||||||||||||||||
Returns: | ||||||||||||||||||||||||||||||||||||
cmd: Slurm sbatch command. | ||||||||||||||||||||||||||||||||||||
""" | ||||||||||||||||||||||||||||||||||||
# Setup sbatch options | ||||||||||||||||||||||||||||||||||||
cmd = ["sbatch"] | ||||||||||||||||||||||||||||||||||||
for opt, val in sbatch_conf.items(): | ||||||||||||||||||||||||||||||||||||
cmd.extend([f"--{opt}", str(val)]) | ||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
# Assign the batch script to run | ||||||||||||||||||||||||||||||||||||
cmd.append(batch_script_path) | ||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
if script is None: | ||||||||||||||||||||||||||||||||||||
script = f"""#!/bin/bash -i | ||||||||||||||||||||||||||||||||||||
{entrypoint} | ||||||||||||||||||||||||||||||||||||
""" | ||||||||||||||||||||||||||||||||||||
Comment on lines
+179
to
+181
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Consider more portable shebang line
Consider using a more robust shebang line Code suggestionCheck the AI-generated fix before applying
Suggested change
Code Review Run #2afb6d Should Bito avoid suggestions like this for future reviews? (Manage Rules)
|
||||||||||||||||||||||||||||||||||||
else: | ||||||||||||||||||||||||||||||||||||
script = script.replace("{task.fn}", entrypoint) | ||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
cmd = " ".join(cmd) | ||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
return cmd, script | ||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
AgentRegistry.register(SlurmFunctionAgent()) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,78 @@ | ||
""" | ||
Slurm task. | ||
""" | ||
|
||
from dataclasses import dataclass | ||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union | ||
|
||
from flytekit import FlyteContextManager, PythonFunctionTask | ||
from flytekit.configuration import SerializationSettings | ||
from flytekit.extend import TaskPlugins | ||
from flytekit.extend.backend.base_agent import AsyncAgentExecutorMixin | ||
from flytekit.image_spec import ImageSpec | ||
|
||
|
||
@dataclass | ||
class SlurmFunction(object): | ||
"""Configure Slurm settings. Note that we focus on sbatch command now. | ||
|
||
Args: | ||
ssh_config: Options of SSH client connection. For available options, please refer to | ||
<newly-added-ssh-utils-file> | ||
sbatch_conf: Options of sbatch command. | ||
script: User-defined script where "{task.fn}" serves as a placeholder for the | ||
task function execution. Users should insert "{task.fn}" at the desired | ||
execution point within the script. If the script is not provided, the task | ||
function will be executed directly. | ||
""" | ||
|
||
ssh_config: Dict[str, Union[str, List[str], Tuple[str, ...]]] | ||
sbatch_conf: Optional[Dict[str, str]] = None | ||
script: Optional[str] = None | ||
|
||
def __post_init__(self): | ||
assert self.ssh_config["host"] is not None, "'host' must be specified in ssh_config." | ||
if self.sbatch_conf is None: | ||
self.sbatch_conf = {} | ||
|
||
|
||
class SlurmFunctionTask(AsyncAgentExecutorMixin, PythonFunctionTask[SlurmFunction]): | ||
""" | ||
Actual Plugin that transforms the local python code for execution within a slurm context... | ||
""" | ||
|
||
_TASK_TYPE = "slurm_fn" | ||
|
||
def __init__( | ||
self, | ||
task_config: SlurmFunction, | ||
task_function: Callable, | ||
container_image: Optional[Union[str, ImageSpec]] = None, | ||
**kwargs, | ||
): | ||
super(SlurmFunctionTask, self).__init__( | ||
task_config=task_config, | ||
task_type=self._TASK_TYPE, | ||
task_function=task_function, | ||
container_image=container_image, | ||
**kwargs, | ||
) | ||
|
||
def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]: | ||
return { | ||
"ssh_config": self.task_config.ssh_config, | ||
"sbatch_conf": self.task_config.sbatch_conf, | ||
"script": self.task_config.script, | ||
} | ||
|
||
def execute(self, **kwargs) -> Any: | ||
ctx = FlyteContextManager.current_context() | ||
if ctx.execution_state and ctx.execution_state.is_local_execution(): | ||
# Mimic the propeller's behavior in local agent test | ||
return AsyncAgentExecutorMixin.execute(self, **kwargs) | ||
else: | ||
# Execute the task with a direct python function call | ||
return PythonFunctionTask.execute(self, **kwargs) | ||
|
||
|
||
TaskPlugins.register_pythontask_plugin(SlurmFunction, SlurmFunctionTask) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Using hardcoded temporary file paths can be insecure. Consider using
tempfile.mkstemp()
for secure temporary file creationCode suggestion
Code Review Run #2afb6d
Should Bito avoid suggestions like this for future reviews? (Manage Rules)