Skip to content

Commit

Permalink
add test_slurm_fn_task
Browse files Browse the repository at this point in the history
Signed-off-by: Future-Outlier <[email protected]>
  • Loading branch information
Future-Outlier committed Feb 22, 2025
1 parent 285a9e1 commit 6b70901
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 0 deletions.
Empty file.
57 changes: 57 additions & 0 deletions plugins/flytekit-slurm/tests/test_slurm_fn_task.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
import os.path
from unittest import mock
from flytekit.core import context_manager
import flytekit
from flytekit import StructuredDataset, StructuredDatasetTransformerEngine, task, ImageSpec
from flytekit.configuration import Image, ImageConfig, SerializationSettings, FastSerializationSettings, DefaultImages
from flytekit.core.context_manager import ExecutionParameters, FlyteContextManager, ExecutionState
from flytekitplugins.slurm import SlurmFunction


def test_slurm_task():
script_file = """#!/bin/bash -i
echo Run function with sbatch...
# Run the user-defined task function
{task.fn}
"""

@task(
# container_image=image,
task_config=SlurmFunction(
ssh_config={
"host": "your-slurm-host",
"username": "ubuntu",
},
sbatch_conf={
"partition": "debug",
"job-name": "tiny-slurm",
"output": "/home/ubuntu/fn_task.log"
},
script=script_file
)
)
def plus_one(x: int) -> int:
return x + 1

assert plus_one.task_config is not None
assert plus_one.task_config.ssh_config == {"host": "your-slurm-host", "username": "ubuntu"}
assert plus_one.task_config.sbatch_conf == {"partition": "debug", "job-name": "tiny-slurm", "output": "/home/ubuntu/fn_task.log"}
assert plus_one.task_config.script == script_file

default_img = Image(name="default", fqn="test", tag="tag")
settings = SerializationSettings(
project="project",
domain="domain",
version="version",
env={"FOO": "baz"},
image_config=ImageConfig(default_image=default_img, images=[default_img]),
)

retrieved_settings = plus_one.get_custom(settings)
assert retrieved_settings["ssh_config"] == {"host": "your-slurm-host", "username": "ubuntu"}
assert retrieved_settings["sbatch_conf"] == {"partition": "debug", "job-name": "tiny-slurm", "output": "/home/ubuntu/fn_task.log"}
assert retrieved_settings["script"] == script_file


0 comments on commit 6b70901

Please sign in to comment.