Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Slurm agent fn task #3150

Merged
merged 85 commits into from
Feb 22, 2025
Merged
Show file tree
Hide file tree
Changes from 68 commits
Commits
Show all changes
85 commits
Select commit Hold shift + click to select a range
421d1b8
Add slurm plugin blank components
JiangJiaWei1103 Dec 14, 2024
1d1f806
feat: Add naive slurm agent create and get with rest api
JiangJiaWei1103 Dec 16, 2024
5d97126
Use asyncssh instead of REST API
JiangJiaWei1103 Dec 17, 2024
2e7f0f2
Test ssh communication and run sbatch
JiangJiaWei1103 Dec 18, 2024
9644b99
Add delete method and support slurm job state
JiangJiaWei1103 Dec 19, 2024
e41b181
feat: Submit and run SlurmTask on a remote Slurm cluster
JiangJiaWei1103 Dec 27, 2024
6db24dc
refactor: Remove redundant task_module transfer
JiangJiaWei1103 Dec 28, 2024
122c7f1
refactor: Remove redundant env var
JiangJiaWei1103 Dec 28, 2024
e9760a7
docs: Add env setup guide for local test
JiangJiaWei1103 Dec 30, 2024
e68fda9
docs: Add links and figures
JiangJiaWei1103 Dec 30, 2024
470637c
docs: Fix commit sha
JiangJiaWei1103 Dec 30, 2024
1579ab4
docs: Fix commit sha for demo guide
JiangJiaWei1103 Dec 30, 2024
0e538f0
docs: Fix links
JiangJiaWei1103 Dec 30, 2024
8229418
feat: Support SSH config in task config
JiangJiaWei1103 Dec 31, 2024
9e6d8a6
docs: Include ssh config in demo example
JiangJiaWei1103 Dec 31, 2024
cdb56c0
fix: Retain user-specified file format info
JiangJiaWei1103 Jan 1, 2025
36ab6e9
fix: Set sdt format based on user-specified file_format
JiangJiaWei1103 Jan 3, 2025
58751fb
Merge branch 'master' into fix-sd-empty-str-file-format
JiangJiaWei1103 Jan 3, 2025
4af1f00
Remove redundant modification
JiangJiaWei1103 Jan 4, 2025
46b7b62
test: Test file_format attribute alignment in dc.sd
JiangJiaWei1103 Jan 4, 2025
e07b09a
refactor: Reduce ssh_conf option to slurm_host only
JiangJiaWei1103 Jan 7, 2025
3a7eb6d
feat: Support Slurm agent with ShellTask
JiangJiaWei1103 Jan 7, 2025
a815fd9
feat: Simplify Slurm job submission logic
JiangJiaWei1103 Jan 9, 2025
a3ea014
Added script args to agent and task
pryce-turner Jan 10, 2025
a109bd8
Add asyncssh to dependencies
JiangJiaWei1103 Jan 11, 2025
e5da665
docs: Update setup and demo for a basic use case
JiangJiaWei1103 Jan 11, 2025
0a3d9f1
docs: Update basic arch figure path
JiangJiaWei1103 Jan 11, 2025
1b0f6df
docs: Fix typo and hyperlink
JiangJiaWei1103 Jan 11, 2025
26cc201
fix: A tmp workaround to test agent locally without container_image
JiangJiaWei1103 Jan 11, 2025
16d953e
feat: Support user-defined batch script content with SlurmShellTask
JiangJiaWei1103 Jan 14, 2025
c743917
feat: Fall back to PythonTask for naive use cases
JiangJiaWei1103 Jan 15, 2025
e365dee
refactor: Define Slurm as a base task config and extend for remote sc…
JiangJiaWei1103 Jan 15, 2025
c1064d4
feat: Support PythonFunctionTask and reorganize agent structure
JiangJiaWei1103 Jan 16, 2025
361fbd1
Use poetry virtual env to avoid contamination
JiangJiaWei1103 Jan 22, 2025
fc3e34e
docs: Complete local test env setup process
JiangJiaWei1103 Jan 23, 2025
5cd58ec
docs: Add use cases ranging from basic to advanced
JiangJiaWei1103 Jan 23, 2025
9985305
feat: Add a script option for the Slurm function task
JiangJiaWei1103 Feb 6, 2025
e00a5db
fix: Avoid attaching async resource to different event loops
JiangJiaWei1103 Feb 6, 2025
805548c
Merge branch 'master' into slurm-agent-dev
Future-Outlier Feb 6, 2025
34c7c51
use await self._connect(slurm_host) in slurm agent
Future-Outlier Feb 12, 2025
4e85b0a
change
Future-Outlier Feb 12, 2025
27cdc8d
print more info
Future-Outlier Feb 13, 2025
1f30e04
use logger
Future-Outlier Feb 13, 2025
079d101
print more infor
Future-Outlier Feb 13, 2025
1ce5cf8
print
Future-Outlier Feb 13, 2025
7b9b38f
Use sbatch for running Slurm function task
JiangJiaWei1103 Feb 13, 2025
b23174a
update
Future-Outlier Feb 14, 2025
3727f4d
push
Future-Outlier Feb 14, 2025
f291e0b
feat: Show stdout and stderr msg of the Slurm cluster
JiangJiaWei1103 Feb 14, 2025
a465607
feat: Show stdout and stderr msg of the Slurm cluster for SlurmFuncti…
JiangJiaWei1103 Feb 14, 2025
7b75a51
feat: Make an SSH connetion based on client config file or ssh_config
JiangJiaWei1103 Feb 16, 2025
d0967bd
Clarify SSH connection logic
JiangJiaWei1103 Feb 17, 2025
f80ccd4
feat: Interpolate the script with dynamic input values
JiangJiaWei1103 Feb 18, 2025
ac25446
feat: Interpolate the script with dynamic output values
JiangJiaWei1103 Feb 18, 2025
aa6e3ae
Merge branch 'master' into slurm-agent-dev
Future-Outlier Feb 19, 2025
5fcf6c5
add assertion
Future-Outlier Feb 19, 2025
64e9e06
update
Future-Outlier Feb 19, 2025
9d06ed1
update
Future-Outlier Feb 19, 2025
6cda6ae
Fix Script agent bug
Future-Outlier Feb 19, 2025
3b64ddd
agent service for shell task
Future-Outlier Feb 20, 2025
b107b07
Remove remote path to avoid race condition
Future-Outlier Feb 20, 2025
37e4ee1
Revert agent server change
Future-Outlier Feb 20, 2025
77f4d61
use key val to run ssh config
Future-Outlier Feb 20, 2025
68ab8a6
update
Future-Outlier Feb 20, 2025
a968c01
use _get_or_create_ssh_connection
Future-Outlier Feb 20, 2025
13cd31e
update
Future-Outlier Feb 20, 2025
52d73d2
use SlurmCluster and hash
Future-Outlier Feb 20, 2025
6b27668
updagte
Future-Outlier Feb 20, 2025
5aa4e87
update
Future-Outlier Feb 20, 2025
3cc29fe
update
Future-Outlier Feb 20, 2025
19de56a
refactor: Simplify validation process and clean up legacy code
JiangJiaWei1103 Feb 20, 2025
31fc564
Add Slurm agent function task
JiangJiaWei1103 Feb 21, 2025
80f3693
Revert ShellTask behavior
JiangJiaWei1103 Feb 21, 2025
96756df
Remove fix for SlurmShellTask
JiangJiaWei1103 Feb 21, 2025
074cf91
Remove blank line
JiangJiaWei1103 Feb 21, 2025
2d0e0d7
fix doc string and remove logs
Future-Outlier Feb 21, 2025
acfaa76
build plugins
Future-Outlier Feb 21, 2025
6227801
merge
JiangJiaWei1103 Feb 22, 2025
860b703
Merge branch 'master' into slurm-agent-fn-task
Future-Outlier Feb 22, 2025
45aae83
merge master
Future-Outlier Feb 22, 2025
285a9e1
fix-sage-maker-test
Future-Outlier Feb 22, 2025
6b70901
add test_slurm_fn_task
Future-Outlier Feb 22, 2025
cf8eafa
fix
Future-Outlier Feb 22, 2025
f0f7930
update flytebot
Future-Outlier Feb 22, 2025
d0d59d3
add know host = None
Future-Outlier Feb 22, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion flytekit/extend/backend/base_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,7 +326,7 @@ class AsyncAgentExecutorMixin:

def execute(self: PythonTask, **kwargs) -> LiteralMap:
ctx = FlyteContext.current_context()
ss = ctx.serialization_settings or SerializationSettings(ImageConfig())
ss = ctx.serialization_settings or SerializationSettings(ImageConfig.auto_default_image())
output_prefix = ctx.file_access.get_random_remote_directory()
self.resource_meta = None

Expand Down
4 changes: 2 additions & 2 deletions flytekit/extend/backend/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,9 @@ def convert_to_flyte_phase(state: str) -> TaskExecution.Phase:
Convert the state from the agent to the phase in flyte.
"""
state = state.lower()
if state in ["failed", "timeout", "timedout", "canceled", "skipped", "internal_error"]:
if state in ["failed", "timeout", "timedout", "canceled", "cancelled", "skipped", "internal_error"]:
return TaskExecution.FAILED
elif state in ["done", "succeeded", "success"]:
elif state in ["done", "succeeded", "success", "completed"]:
return TaskExecution.SUCCEEDED
elif state in ["running", "terminating"]:
return TaskExecution.RUNNING
Expand Down
2 changes: 1 addition & 1 deletion flytekit/extras/tasks/shell.py
Original file line number Diff line number Diff line change
Expand Up @@ -367,7 +367,7 @@ def execute(self, **kwargs) -> typing.Any:
return None

def post_execute(self, user_params: ExecutionParameters, rval: typing.Any) -> typing.Any:
return self._config_task_instance.post_execute(user_params, rval)
return self._config_task_instance.pre_execute(user_params)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Incorrect execution lifecycle method call

The post_execute method appears to be incorrectly calling pre_execute instead of post_execute on the config task instance. This could lead to incorrect execution flow since pre and post execute serve different purposes.

Code suggestion
Check the AI-generated fix before applying
Suggested change
return self._config_task_instance.pre_execute(user_params)
return self._config_task_instance.post_execute(user_params, rval)

Code Review Run #0c867b


Should Bito avoid suggestions like this for future reviews? (Manage Rules)

  • Yes, avoid them



class RawShellTask(ShellTask):
Expand Down
5 changes: 5 additions & 0 deletions plugins/flytekit-slurm/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# Flytekit Slurm Plugin

The Slurm agent is designed to integrate Flyte workflows with Slurm-managed high-performance computing (HPC) clusters, enabling users to leverage Slurm's capability of compute resource allocation, scheduling, and monitoring.

This [guide](https://github.com/JiangJiaWei1103/flytekit/blob/slurm-agent-dev/plugins/flytekit-slurm/demo.md) provides a concise overview of the design philosophy behind the Slurm agent and explains how to set up a local environment for testing the agent.
2 changes: 2 additions & 0 deletions plugins/flytekit-slurm/flytekitplugins/slurm/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from .function.agent import SlurmFunctionAgent
from .function.task import SlurmFunction, SlurmFunctionTask
191 changes: 191 additions & 0 deletions plugins/flytekit-slurm/flytekitplugins/slurm/function/agent.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,191 @@
import tempfile
import uuid
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple, Union

from asyncssh import SSHClientConnection

from flytekit import logger
from flytekit.extend.backend.base_agent import AgentRegistry, AsyncAgentBase, Resource, ResourceMeta
from flytekit.extend.backend.utils import convert_to_flyte_phase
from flytekit.models.literals import LiteralMap
from flytekit.models.task import TaskTemplate

from ..ssh_utils import ssh_connect


@dataclass
class SlurmJobMetadata(ResourceMeta):
"""Slurm job metadata.

Args:
job_id: Slurm job id.
ssh_config: Options of SSH client connection. For available options, please refer to
<newly-added-ssh-utils-file>
"""

job_id: str
ssh_config: Dict[str, Union[str, List[str], Tuple[str, ...]]]


@dataclass
class SlurmCluster:
host: str
username: Optional[str] = None

def __hash__(self):
return hash((self.host, self.username))


class SlurmFunctionAgent(AsyncAgentBase):
name = "Slurm Function Agent"

# SSH connection pool for multi-host environment
ssh_config_to_ssh_conn: Dict[SlurmCluster, SSHClientConnection] = {}

def __init__(self) -> None:
super(SlurmFunctionAgent, self).__init__(task_type_name="slurm_fn", metadata_type=SlurmJobMetadata)

async def create(
self,
task_template: TaskTemplate,
inputs: Optional[LiteralMap] = None,
**kwargs,
) -> SlurmJobMetadata:
unique_script_name = f"/tmp/task_{uuid.uuid4().hex}.slurm"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use secure temporary file creation

Using hardcoded temporary file paths can be insecure. Consider using tempfile.mkstemp() for secure temporary file creation

Code suggestion
Check the AI-generated fix before applying
Suggested change
unique_script_name = f"/tmp/task_{uuid.uuid4().hex}.slurm"
_, unique_script_name = tempfile.mkstemp(suffix='.slurm', prefix='task_')

Code Review Run #2afb6d


Should Bito avoid suggestions like this for future reviews? (Manage Rules)

  • Yes, avoid them


# Retrieve task config
ssh_config = task_template.custom["ssh_config"]
sbatch_conf = task_template.custom["sbatch_conf"]
script = task_template.custom["script"]

# Construct command for Slurm cluster
cmd, script = _get_sbatch_cmd_and_script(
sbatch_conf=sbatch_conf,
entrypoint=" ".join(task_template.container.args),
script=script,
batch_script_path=unique_script_name,
)

logger.info("@@@ task_template.container.args:")
logger.info(task_template.container.args)
logger.info("@@@ Slurm Command: ")
logger.info(cmd)
logger.info("@@@ Batch script: ")
logger.info(script)

# Run Slurm job
conn = await self._get_or_create_ssh_connection(ssh_config)
with tempfile.NamedTemporaryFile("w") as f:
f.write(script)
f.flush()
async with conn.start_sftp_client() as sftp:
await sftp.put(f.name, unique_script_name)
res = await conn.run(cmd, check=True)

# Retrieve Slurm job id
job_id = res.stdout.split()[-1]
logger.info("@@@ create slurm job id: " + job_id)

return SlurmJobMetadata(job_id=job_id, ssh_config=ssh_config)

async def get(self, resource_meta: SlurmJobMetadata, **kwargs) -> Resource:
ssh_config = resource_meta.ssh_config
conn = await self._get_or_create_ssh_connection(ssh_config)
job_res = await conn.run(f"scontrol show job {resource_meta.job_id}", check=True)

# Determine the current flyte phase from Slurm job state
job_state = "running"
for o in job_res.stdout.split(" "):
if "JobState" in o:
job_state = o.split("=")[1].strip().lower()
elif "StdOut" in o:
stdout_path = o.split("=")[1].strip()
msg_res = await conn.run(f"cat {stdout_path}", check=True)
msg = msg_res.stdout
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Consider adding error handling for stdout

Consider adding error handling for the case when stdout_path is not found in the job output. Currently, the code assumes msg will always be initialized which could lead to potential UnboundLocalError if StdOut is not found in the job output.

Code suggestion
Check the AI-generated fix before applying
Suggested change
job_state = "running"
for o in job_res.stdout.split(" "):
if "JobState" in o:
job_state = o.split("=")[1].strip().lower()
elif "StdOut" in o:
stdout_path = o.split("=")[1].strip()
msg_res = await conn.run(f"cat {stdout_path}", check=True)
msg = msg_res.stdout
job_state = "running"
msg = "No stdout available"
for o in job_res.stdout.split(" "):
if "JobState" in o:
job_state = o.split("=")[1].strip().lower()
elif "StdOut" in o:
stdout_path = o.split("=")[1].strip()
msg_res = await conn.run(f"cat {stdout_path}", check=True)
msg = msg_res.stdout

Code Review Run #2afb6d


Should Bito avoid suggestions like this for future reviews? (Manage Rules)

  • Yes, avoid them


logger.info("@@@ GET PHASE: ")
logger.info(str(job_state))
cur_phase = convert_to_flyte_phase(job_state)

return Resource(phase=cur_phase, message=msg)

async def delete(self, resource_meta: SlurmJobMetadata, **kwargs) -> None:
conn = await self._get_or_create_ssh_connection(resource_meta.ssh_config)
_ = await conn.run(f"scancel {resource_meta.job_id}", check=True)

async def _get_or_create_ssh_connection(
self, ssh_config: Dict[str, Union[str, List[str], Tuple[str, ...]]]
) -> SSHClientConnection:
"""Get an existing SSH connection or create a new one if needed.

Args:
ssh_config: SSH configuration dictionary.

Returns:
An active SSH connection, either pre-existing or newly established.
"""
host = ssh_config.get("host")
username = ssh_config.get("username")

ssh_cluster_config = SlurmCluster(host=host, username=username)
if self.ssh_config_to_ssh_conn.get(ssh_cluster_config) is None:
logger.info("ssh connection key not found, creating new connection")
conn = await ssh_connect(ssh_config=ssh_config)
self.ssh_config_to_ssh_conn[ssh_cluster_config] = conn
else:
conn = self.ssh_config_to_ssh_conn[ssh_cluster_config]
try:
await conn.run("echo [TEST] SSH connection", check=True)
logger.info("re-using new connection")
except Exception as e:
logger.info(f"Re-establishing SSH connection due to error: {e}")
conn = await ssh_connect(ssh_config=ssh_config)
self.ssh_config_to_ssh_conn[ssh_cluster_config] = conn

return conn


def _get_sbatch_cmd_and_script(
sbatch_conf: Dict[str, str],
entrypoint: str,
script: Optional[str] = None,
batch_script_path: str = "/tmp/task.slurm",
) -> str:
"""Construct the Slurm sbatch command and the batch script content.

Flyte entrypoint, pyflyte-execute, is run within a bash shell process.

Args:
sbatch_conf: Options of sbatch command.
entrypoint: Flyte entrypoint.
script: User-defined script where "{task.fn}" serves as a placeholder for the
task function execution. Users should insert "{task.fn}" at the desired
execution point within the script. If the script is not provided, the task
function will be executed directly.
batch_script_path: Absolute path of the batch script on Slurm cluster.

Returns:
cmd: Slurm sbatch command.
"""
# Setup sbatch options
cmd = ["sbatch"]
for opt, val in sbatch_conf.items():
cmd.extend([f"--{opt}", str(val)])

# Assign the batch script to run
cmd.append(batch_script_path)

if script is None:
script = f"""#!/bin/bash -i
{entrypoint}
"""
Comment on lines +179 to +181
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Consider more portable shebang line

Consider using a more robust shebang line #!/usr/bin/env bash instead of #!/bin/bash -i as it provides better portability across different systems where bash may not be in the standard location. Additionally, the -i flag for interactive mode may not be necessary for a batch script.

Code suggestion
Check the AI-generated fix before applying
Suggested change
script = f"""#!/bin/bash -i
{entrypoint}
"""
script = f"""#!/usr/bin/env bash
{entrypoint}
"""

Code Review Run #2afb6d


Should Bito avoid suggestions like this for future reviews? (Manage Rules)

  • Yes, avoid them

else:
script = script.replace("{task.fn}", entrypoint)

cmd = " ".join(cmd)

return cmd, script


AgentRegistry.register(SlurmFunctionAgent())
78 changes: 78 additions & 0 deletions plugins/flytekit-slurm/flytekitplugins/slurm/function/task.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
"""
Slurm task.
"""

from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

from flytekit import FlyteContextManager, PythonFunctionTask
from flytekit.configuration import SerializationSettings
from flytekit.extend import TaskPlugins
from flytekit.extend.backend.base_agent import AsyncAgentExecutorMixin
from flytekit.image_spec import ImageSpec


@dataclass
class SlurmFunction(object):
"""Configure Slurm settings. Note that we focus on sbatch command now.

Args:
ssh_config: Options of SSH client connection. For available options, please refer to
<newly-added-ssh-utils-file>
sbatch_conf: Options of sbatch command.
script: User-defined script where "{task.fn}" serves as a placeholder for the
task function execution. Users should insert "{task.fn}" at the desired
execution point within the script. If the script is not provided, the task
function will be executed directly.
"""

ssh_config: Dict[str, Union[str, List[str], Tuple[str, ...]]]
sbatch_conf: Optional[Dict[str, str]] = None
script: Optional[str] = None

def __post_init__(self):
assert self.ssh_config["host"] is not None, "'host' must be specified in ssh_config."
if self.sbatch_conf is None:
self.sbatch_conf = {}


class SlurmFunctionTask(AsyncAgentExecutorMixin, PythonFunctionTask[SlurmFunction]):
"""
Actual Plugin that transforms the local python code for execution within a slurm context...
"""

_TASK_TYPE = "slurm_fn"

def __init__(
self,
task_config: SlurmFunction,
task_function: Callable,
container_image: Optional[Union[str, ImageSpec]] = None,
**kwargs,
):
super(SlurmFunctionTask, self).__init__(
task_config=task_config,
task_type=self._TASK_TYPE,
task_function=task_function,
container_image=container_image,
**kwargs,
)

def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]:
return {
"ssh_config": self.task_config.ssh_config,
"sbatch_conf": self.task_config.sbatch_conf,
"script": self.task_config.script,
}

def execute(self, **kwargs) -> Any:
ctx = FlyteContextManager.current_context()
if ctx.execution_state and ctx.execution_state.is_local_execution():
# Mimic the propeller's behavior in local agent test
return AsyncAgentExecutorMixin.execute(self, **kwargs)
else:
# Execute the task with a direct python function call
return PythonFunctionTask.execute(self, **kwargs)


TaskPlugins.register_pythontask_plugin(SlurmFunction, SlurmFunctionTask)
Loading
Loading