Skip to content

Commit

Permalink
✅ add test for PipesRayJobClient with a KubeRay cluster
Browse files Browse the repository at this point in the history
  • Loading branch information
danielgafni committed Oct 21, 2024
1 parent 3b9d9fb commit 8f1de50
Show file tree
Hide file tree
Showing 8 changed files with 166 additions and 55 deletions.
8 changes: 4 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -379,9 +379,9 @@ from dagster_ray.kuberay import PipesKubeRayJobClient
@asset
def my_asset(
context: AssetExecutionContext, pipes_rayjob_client: PipesKubeRayJobClient
context: AssetExecutionContext, pipes_kube_rayjob_client: PipesKubeRayJobClient
):
pipes_rayjob_client.run(
pipes_kube_rayjob_client.run(
context=context,
ray_job={
# RayJob manifest goes here
Expand All @@ -404,7 +404,7 @@ def my_asset(


definitions = Definitions(
resources={"pipes_rayjob_client": PipesKubeRayJobClient()}, assets=[my_asset]
resources={"pipes_kube_rayjob_client": PipesKubeRayJobClient()}, assets=[my_asset]
)
```

Expand Down Expand Up @@ -440,7 +440,7 @@ When running locally, the `port_forward` option has to be set to `True` in the `
```python
from dagster_ray.kuberay.configs import in_k8s

pipes_rayjob_client = PipesKubeRayJobClient(..., port_forward=not in_k8s)
pipes_kube_rayjob_client = PipesKubeRayJobClient(..., port_forward=not in_k8s)
```

## Resources
Expand Down
15 changes: 8 additions & 7 deletions dagster_ray/kuberay/client/raycluster/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
logger = logging.getLogger(__name__)

if TYPE_CHECKING:
from kubernetes.client import ApiClient
from ray.job_submission import JobSubmissionClient


Expand Down Expand Up @@ -77,13 +78,13 @@ class RayClusterStatus(TypedDict):


class RayClusterClient(BaseKubeRayClient):
def __init__(self, config_file: Optional[str] = None, context: Optional[str] = None) -> None:
super().__init__(
group=GROUP,
version=VERSION,
kind=KIND,
plural=PLURAL,
)
def __init__(
self,
config_file: Optional[str] = None,
context: Optional[str] = None,
api_client: Optional["ApiClient"] = None,
) -> None:
super().__init__(group=GROUP, version=VERSION, kind=KIND, plural=PLURAL, api_client=api_client)

# these are only used because of kubectl port-forward CLI command
# TODO: remove kubectl usage and remove these attributes
Expand Down
19 changes: 11 additions & 8 deletions dagster_ray/kuberay/client/rayjob/client.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
import logging
import time
from typing import Iterator, Literal, Optional, TypedDict, cast
from typing import TYPE_CHECKING, Iterator, Literal, Optional, TypedDict, cast

from typing_extensions import NotRequired

from dagster_ray.kuberay.client.base import BaseKubeRayClient, load_kubeconfig
from dagster_ray.kuberay.client.raycluster import RayClusterClient, RayClusterStatus

if TYPE_CHECKING:
from kubernetes.client import ApiClient

GROUP = "ray.io"
VERSION = "v1"
PLURAL = "rayjobs"
Expand All @@ -29,19 +32,19 @@ class RayJobStatus(TypedDict):


class RayJobClient(BaseKubeRayClient):
def __init__(self, config_file: Optional[str] = None, context: Optional[str] = None) -> None:
def __init__(
self,
config_file: Optional[str] = None,
context: Optional[str] = None,
api_client: Optional["ApiClient"] = None,
) -> None:
# this call must happen BEFORE creating K8s apis
load_kubeconfig(config_file=config_file, context=context)

self.config_file = config_file
self.context = context

super().__init__(
group=GROUP,
version=VERSION,
kind=KIND,
plural=PLURAL,
)
super().__init__(group=GROUP, version=VERSION, kind=KIND, plural=PLURAL, api_client=api_client)

def get_status(self, name: str, namespace: str, timeout: int = 60, poll_interval: int = 5) -> RayJobStatus: # type: ignore
return cast(
Expand Down
3 changes: 0 additions & 3 deletions dagster_ray/pipes.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,6 @@ class PipesRayJobClient(PipesClient, TreatAsResourceParam):
forward_termination (bool): Whether to cancel the `RayJob` job run when the Dagster process receives a termination signal.
timeout (int): Timeout for various internal interactions with the Kubernetes RayJob.
poll_interval (int): Interval at which to poll the Kubernetes for status updates.
port_forward (bool): Whether to use Kubernetes port-forwarding to connect to the KubeRay cluster.
Is useful when running in a local environment.
"""

Expand All @@ -185,7 +184,6 @@ def __init__(
forward_termination: bool = True,
timeout: int = 600,
poll_interval: int = 5,
port_forward: bool = False,
):
self.client = client
self._context_injector = context_injector or PipesEnvContextInjector()
Expand All @@ -194,7 +192,6 @@ def __init__(
self.forward_termination = check.bool_param(forward_termination, "forward_termination")
self.timeout = check.int_param(timeout, "timeout")
self.poll_interval = check.int_param(poll_interval, "poll_interval")
self.port_forward = check.bool_param(port_forward, "port_forward")

self._job_submission_client: Optional["JobSubmissionClient"] = None

Expand Down
81 changes: 71 additions & 10 deletions tests/kuberay/conftest.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,20 @@
import os
import socket
import subprocess
import sys
import tempfile
from pathlib import Path
from typing import Generator
from typing import Any, Dict, Generator, Iterator, List, Tuple

import pytest
import pytest_cases
from kubernetes import config # noqa: TID253
from pytest_kubernetes.options import ClusterOptions
from pytest_kubernetes.providers import AClusterManager, select_provider_manager

from dagster_ray.kuberay.client import RayClusterClient
from dagster_ray.kuberay.configs import DEFAULT_HEAD_GROUP_SPEC, DEFAULT_WORKER_GROUP_SPECS
from tests import ROOT_DIR


def get_random_free_port():
sock = socket.socket()
sock.bind(("", 0))
return sock.getsockname()[1]
from tests.kuberay.utils import NAMESPACE, get_random_free_port


@pytest.fixture(scope="session")
Expand Down Expand Up @@ -72,8 +69,6 @@ def dagster_ray_image():

KUBERAY_VERSIONS = os.getenv("PYTEST_KUBERAY_VERSIONS", "1.2.2").split(",")

NAMESPACE = "ray"


@pytest_cases.fixture(scope="session")
@pytest.mark.parametrize("kuberay_version", KUBERAY_VERSIONS)
Expand Down Expand Up @@ -123,3 +118,69 @@ def k8s_with_kuberay(

yield k8s
k8s.delete()


@pytest.fixture(scope="session")
def head_group_spec(dagster_ray_image: str) -> Dict[str, Any]:
head_group_spec = DEFAULT_HEAD_GROUP_SPEC.copy()
head_group_spec["serviceType"] = "LoadBalancer"
head_group_spec["template"]["spec"]["containers"][0]["image"] = dagster_ray_image
head_group_spec["template"]["spec"]["containers"][0]["imagePullPolicy"] = "IfNotPresent"
return head_group_spec


@pytest.fixture(scope="session")
def worker_group_specs(dagster_ray_image: str) -> List[Dict[str, Any]]:
worker_group_specs = DEFAULT_WORKER_GROUP_SPECS.copy()
worker_group_specs[0]["template"]["spec"]["containers"][0]["image"] = dagster_ray_image
worker_group_specs[0]["template"]["spec"]["containers"][0]["imagePullPolicy"] = "IfNotPresent"
return worker_group_specs


PERSISTENT_RAY_CLUSTER_NAME = "persistent-ray-cluster"


@pytest.fixture(scope="session")
def k8s_with_raycluster(
k8s_with_kuberay: AClusterManager,
head_group_spec: Dict[str, Any],
worker_group_specs: List[Dict[str, Any]],
) -> Iterator[Tuple[dict[str, int], AClusterManager]]:
# create a RayCluster
config.load_kube_config(str(k8s_with_kuberay.kubeconfig))

client = RayClusterClient(
config_file=str(k8s_with_kuberay.kubeconfig),
)

client.create(
body={
"metadata": {"name": PERSISTENT_RAY_CLUSTER_NAME},
"spec": {
"headGroupSpec": head_group_spec,
"workerGroupSpecs": worker_group_specs,
},
},
namespace=NAMESPACE,
)

redis_port = get_random_free_port()
dashboard_port = get_random_free_port()

with k8s_with_kuberay.port_forwarding(
target=f"svc/{PERSISTENT_RAY_CLUSTER_NAME}-head-svc",
source_port=redis_port,
target_port=10001,
namespace=NAMESPACE,
), k8s_with_kuberay.port_forwarding(
target=f"svc/{PERSISTENT_RAY_CLUSTER_NAME}-head-svc",
source_port=dashboard_port,
target_port=8265,
namespace=NAMESPACE,
):
yield {"redis": redis_port, "dashboard": dashboard_port}, k8s_with_kuberay

client.delete(
name=PERSISTENT_RAY_CLUSTER_NAME,
namespace=NAMESPACE,
)
67 changes: 62 additions & 5 deletions tests/kuberay/test_pipes.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import sys
from typing import Tuple

import pytest
import ray # noqa: TID253
Expand All @@ -9,9 +10,12 @@
)
from dagster._core.instance_for_test import instance_for_test
from pytest_kubernetes.providers import AClusterManager
from ray.job_submission import JobSubmissionClient # noqa: TID253

from dagster_ray import PipesRayJobClient
from dagster_ray.kuberay.client import RayJobClient
from dagster_ray.kuberay.pipes import PipesKubeRayJobClient
from tests.test_pipes import LOCAL_SCRIPT_PATH

RAY_JOB = {
"apiVersion": "ray.io/v1",
Expand Down Expand Up @@ -60,7 +64,7 @@


@pytest.fixture(scope="session")
def pipes_rayjob_client(k8s_with_kuberay: AClusterManager):
def pipes_kube_rayjob_client(k8s_with_kuberay: AClusterManager):
return PipesKubeRayJobClient(
client=RayJobClient(
config_file=str(k8s_with_kuberay.kubeconfig),
Expand All @@ -69,10 +73,10 @@ def pipes_rayjob_client(k8s_with_kuberay: AClusterManager):
)


def test_rayjob_pipes(pipes_rayjob_client: PipesKubeRayJobClient, dagster_ray_image: str, capsys):
def test_rayjob_pipes(pipes_kube_rayjob_client: PipesKubeRayJobClient, dagster_ray_image: str, capsys):
@asset
def my_asset(context: AssetExecutionContext, pipes_rayjob_client: PipesKubeRayJobClient):
result = pipes_rayjob_client.run(
def my_asset(context: AssetExecutionContext, pipes_kube_rayjob_client: PipesKubeRayJobClient):
result = pipes_kube_rayjob_client.run(
context=context,
ray_job=RAY_JOB,
extras={"foo": "bar"},
Expand All @@ -83,7 +87,7 @@ def my_asset(context: AssetExecutionContext, pipes_rayjob_client: PipesKubeRayJo
with instance_for_test() as instance:
result = materialize(
[my_asset],
resources={"pipes_rayjob_client": pipes_rayjob_client},
resources={"pipes_kube_rayjob_client": pipes_kube_rayjob_client},
instance=instance,
tags={"dagster/image": dagster_ray_image},
)
Expand Down Expand Up @@ -111,3 +115,56 @@ def my_asset(context: AssetExecutionContext, pipes_rayjob_client: PipesKubeRayJo
assert "Hello from stdout!" in captured.out
assert "Hello from stderr!" in captured.out
assert "Hello from Ray Pipes!" in captured.err


@pytest.fixture(scope="session")
def pipes_ray_job_client(k8s_with_raycluster: Tuple[dict[str, int], AClusterManager]):
ports, k8s = k8s_with_raycluster
return PipesRayJobClient(
client=JobSubmissionClient(
address=f"https://localhost:{ports['dashboard']}",
)
)


def test_ray_job_pipes(pipes_ray_job_client: PipesRayJobClient, capsys):
@asset
def my_asset(context: AssetExecutionContext, pipes_ray_job_client: PipesRayJobClient):
result = pipes_ray_job_client.run(
context=context,
submit_job_params={"entrypoint": f"{sys.executable} {LOCAL_SCRIPT_PATH}"},
extras={"foo": "bar"},
).get_materialize_result()

return result

with instance_for_test() as instance:
result = materialize(
[my_asset],
resources={"pipes_ray_job_client": pipes_ray_job_client},
instance=instance,
)

captured = capsys.readouterr()

print(captured.out)
print(captured.err, file=sys.stderr)

mat_evts = result.get_asset_materialization_events()

mat = instance.get_latest_materialization_event(my_asset.key)
instance.get_event_records(event_records_filter=EventRecordsFilter(event_type=DagsterEventType.LOGS_CAPTURED))

assert len(mat_evts) == 1

assert result.success
assert mat
assert mat and mat.asset_materialization
assert mat.asset_materialization.metadata["some_metric"].value == 0
assert mat.asset_materialization.tags
assert mat.asset_materialization.tags[DATA_VERSION_TAG] == "alpha"
assert mat.asset_materialization.tags[DATA_VERSION_IS_USER_PROVIDED_TAG]

assert "Hello from stdout!" in captured.out
assert "Hello from stderr!" in captured.out
assert "Hello from Ray Pipes!" in captured.err
18 changes: 0 additions & 18 deletions tests/kuberay/test_raycluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,28 +9,10 @@
from dagster_ray import RayResource
from dagster_ray.kuberay import KubeRayCluster, RayClusterClientResource, RayClusterConfig, cleanup_kuberay_clusters
from dagster_ray.kuberay.client import RayClusterClient
from dagster_ray.kuberay.configs import DEFAULT_HEAD_GROUP_SPEC, DEFAULT_WORKER_GROUP_SPECS
from dagster_ray.kuberay.ops import CleanupKuberayClustersConfig
from tests.kuberay.conftest import NAMESPACE, get_random_free_port


@pytest.fixture(scope="session")
def head_group_spec(dagster_ray_image: str) -> Dict[str, Any]:
head_group_spec = DEFAULT_HEAD_GROUP_SPEC.copy()
head_group_spec["serviceType"] = "LoadBalancer"
head_group_spec["template"]["spec"]["containers"][0]["image"] = dagster_ray_image
head_group_spec["template"]["spec"]["containers"][0]["imagePullPolicy"] = "IfNotPresent"
return head_group_spec


@pytest.fixture(scope="session")
def worker_group_specs(dagster_ray_image: str) -> List[Dict[str, Any]]:
worker_group_specs = DEFAULT_WORKER_GROUP_SPECS.copy()
worker_group_specs[0]["template"]["spec"]["containers"][0]["image"] = dagster_ray_image
worker_group_specs[0]["template"]["spec"]["containers"][0]["imagePullPolicy"] = "IfNotPresent"
return worker_group_specs


@pytest.fixture(scope="session")
def ray_cluster_resource(
k8s_with_kuberay: AClusterManager,
Expand Down
10 changes: 10 additions & 0 deletions tests/kuberay/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
import socket


def get_random_free_port():
sock = socket.socket()
sock.bind(("", 0))
return sock.getsockname()[1]


NAMESPACE = "ray"

0 comments on commit 8f1de50

Please sign in to comment.