Skip to content

Commit

Permalink
✅ add test for PipesRayJobClient with a KubeRay cluster (#46)
Browse files Browse the repository at this point in the history
* ✅ add test for PipesRayJobClient with a KubeRay cluster

* matrix for ray and dagster versions
  • Loading branch information
danielgafni authored Oct 22, 2024
1 parent 3b9d9fb commit b2c185d
Show file tree
Hide file tree
Showing 14 changed files with 283 additions and 119 deletions.
43 changes: 27 additions & 16 deletions .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ env:

jobs:
test:
name: Test Python ${{ matrix.py }} - KubeRay ${{ matrix.kuberay }}
name: Test Python ${{ matrix.py }} - Ray ${{ matrix.ray }} - Dagster ${{ matrix.dagster }} - KubeRay ${{ matrix.kuberay }}
runs-on: ${{ matrix.os }}-latest
strategy:
fail-fast: false
Expand All @@ -31,20 +31,16 @@ jobs:
- "3.11"
- "3.10"
- "3.9"
ray:
- "2.37.0"
- "2.24.0"
dagster:
- "1.8.12"
kuberay:
- "1.1.0"
- "1.2.2"
steps:
- uses: actions/checkout@v4
- name: Install uv
uses: astral-sh/setup-uv@v3
with:
version: 0.4.18
enable-cache: true
- name: Set up Python ${{ matrix.py }}
run: uv python install ${{ matrix.py }}
- name: Install dependencies
run: uv sync --all-extras --dev
- uses: azure/[email protected]
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
Expand All @@ -53,14 +49,27 @@ jobs:
with:
start: false
driver: docker
#- uses: mxschmitt/action-tmate@v3
- name: Install uv
uses: astral-sh/setup-uv@v3
with:
version: 0.4.25
enable-cache: true
- name: Set up Python ${{ matrix.py }}
run: uv python install ${{ matrix.py }} && uv python pin ${{ matrix.py }} && uv venv --python ${{ matrix.py }}
- name: Override ray==${{ matrix.ray }} dagster==${{ matrix.dagster }}
id: override
run: uv add --no-sync "ray[all]==${{ matrix.ray }}" "dagster==${{ matrix.dagster }}" || echo SKIP=1 >> $GITHUB_OUTPUT
- name: Install dependencies
run: uv sync --all-extras --dev
if: ${{ steps.override.outputs.SKIP != '1' }}
- name: Run tests
env:
PYTEST_KUBERAY_VERSIONS: "${{ matrix.kuberay }}"
run: uv run pytest -v .
if: ${{ steps.override.outputs.SKIP != '1' }}

lint:
name: lint ${{ matrix.py }} - ${{ matrix.os }}
name: Lint ${{ matrix.py }}
runs-on: ${{ matrix.os }}-latest
strategy:
fail-fast: false
Expand All @@ -76,10 +85,10 @@ jobs:
- name: Install uv
uses: astral-sh/setup-uv@v3
with:
version: 0.4.18
version: 0.4.25
enable-cache: true
- name: Set up Python ${{ matrix.py }}
run: uv python install ${{ matrix.py }}
run: uv python install ${{ matrix.py }} && uv python pin ${{ matrix.py }}
- name: Install dependencies
run: uv sync --all-extras --dev
- name: Run pre-commit hooks
Expand Down Expand Up @@ -108,10 +117,12 @@ jobs:
- name: Install uv
uses: astral-sh/setup-uv@v3
with:
version: 0.4.18
version: 0.4.25
enable-cache: true
- name: Set up Python
run: uv python install 3.11.9
run: uv python install $PYTHON_VERSION && uv python pin $PYTHON_VERSION
env:
PYTHON_VERSION: 3.11.9
- name: Generate Version
run: export VERSION=$(uv run dunamai from any --style pep440) && echo "Version is $VERSION" && echo "VERSION=$VERSION" >> $GITHUB_ENV
- name: Replace version in code
Expand Down
18 changes: 14 additions & 4 deletions Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ COPY --from=bitnami/kubectl:1.30.3 /opt/bitnami/kubectl/bin/kubectl /usr/local/b

# install uv (https://github.com/astral-sh/uv)
# docs for using uv with Docker: https://docs.astral.sh/uv/guides/integration/docker/
COPY --from=ghcr.io/astral-sh/uv:0.4.18 /uv /bin/uv
COPY --from=ghcr.io/astral-sh/uv:0.4.25 /uv /bin/uv

ENV UV_PROJECT_ENVIRONMENT=/usr/local/
ENV DAGSTER_HOME=/opt/dagster/dagster_home
Expand All @@ -27,12 +27,12 @@ WORKDIR /src
COPY pyproject.toml uv.lock ./

RUN --mount=type=cache,target=/root/.cache/uv \
uv sync --frozen --all-extras --no-dev --no-install-project
uv sync --frozen --all-extras --no-dev --no-install-project --inexact

FROM base-prod AS base-dev

# Node.js is needed for pyright in CI
ARG NODE_VERSION=20.7.0
ARG NODE_VERSION=23.0.0
ARG NODE_PACKAGE=node-v$NODE_VERSION-linux-x64
ARG NODE_HOME=/opt/$NODE_PACKAGE
ENV NODE_PATH $NODE_HOME/lib/node_modules
Expand All @@ -41,6 +41,16 @@ RUN --mount=type=cache,target=/cache/downloads \
curl https://nodejs.org/dist/v$NODE_VERSION/$NODE_PACKAGE.tar.gz -o /cache/downloads/$NODE_PACKAGE.tar.gz \
&& tar -xzC /opt/ -f /cache/downloads/$NODE_PACKAGE.tar.gz


RUN mkdir dagster_ray && touch dagster_ray/__init__.py && touch README.md
COPY dagster_ray/_version.py dagster_ray/_version.py

# Install specific Dagster and Ray versions (for integration tests)
ARG RAY_VERSION=2.35.0
ARG DAGSTER_VERSION=1.8.12
RUN --mount=type=cache,target=/root/.cache/uv \
uv add --no-sync "ray[all]==$RAY_VERSION" "dagster==$DAGSTER_VERSION"

RUN --mount=type=cache,target=/root/.cache/uv \
uv sync --frozen --all-extras --no-install-project

Expand All @@ -51,4 +61,4 @@ FROM base-${BUILD_DEPENDENCIES} AS final
COPY . .

# finally install all our code
RUN uv sync --frozen --all-extras
RUN uv sync --frozen --all-extras --inexact
8 changes: 4 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -379,9 +379,9 @@ from dagster_ray.kuberay import PipesKubeRayJobClient
@asset
def my_asset(
context: AssetExecutionContext, pipes_rayjob_client: PipesKubeRayJobClient
context: AssetExecutionContext, pipes_kube_rayjob_client: PipesKubeRayJobClient
):
pipes_rayjob_client.run(
pipes_kube_rayjob_client.run(
context=context,
ray_job={
# RayJob manifest goes here
Expand All @@ -404,7 +404,7 @@ def my_asset(


definitions = Definitions(
resources={"pipes_rayjob_client": PipesKubeRayJobClient()}, assets=[my_asset]
resources={"pipes_kube_rayjob_client": PipesKubeRayJobClient()}, assets=[my_asset]
)
```

Expand Down Expand Up @@ -440,7 +440,7 @@ When running locally, the `port_forward` option has to be set to `True` in the `
```python
from dagster_ray.kuberay.configs import in_k8s

pipes_rayjob_client = PipesKubeRayJobClient(..., port_forward=not in_k8s)
pipes_kube_rayjob_client = PipesKubeRayJobClient(..., port_forward=not in_k8s)
```

## Resources
Expand Down
17 changes: 13 additions & 4 deletions dagster_ray/executor.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Optional, cast

import dagster
from dagster import (
_check as check,
)
Expand All @@ -19,7 +20,9 @@
StepHandler,
StepHandlerContext,
)
from dagster._core.remote_representation.origin import RemoteJobOrigin
from dagster._utils.merger import merge_dicts
from packaging.version import Version
from pydantic import Field

from dagster_ray.config import RayExecutionConfig, RayJobSubmissionClientConfig
Expand Down Expand Up @@ -149,10 +152,16 @@ def launch_step(self, step_handler_context: StepHandlerContext) -> Iterator[Dags
"dagster/op": step_key,
"dagster/run-id": step_handler_context.execute_step_args.run_id,
}
if run.external_job_origin:
labels["dagster/code-location"] = (
run.external_job_origin.repository_origin.code_location_origin.location_name
)

if Version(dagster.__version__) >= Version("1.8.12"):
remote_job_origin = run.remote_job_origin # type: ignore
else:
remote_job_origin = run.external_job_origin # type: ignore

remote_job_origin = cast(Optional[RemoteJobOrigin], remote_job_origin)

if remote_job_origin:
labels["dagster/code-location"] = remote_job_origin.repository_origin.code_location_origin.location_name

user_provided_config = RayExecutionConfig.from_tags({**step_handler_context.step_tags[step_key]})

Expand Down
15 changes: 8 additions & 7 deletions dagster_ray/kuberay/client/raycluster/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
logger = logging.getLogger(__name__)

if TYPE_CHECKING:
from kubernetes.client import ApiClient
from ray.job_submission import JobSubmissionClient


Expand Down Expand Up @@ -77,13 +78,13 @@ class RayClusterStatus(TypedDict):


class RayClusterClient(BaseKubeRayClient):
def __init__(self, config_file: Optional[str] = None, context: Optional[str] = None) -> None:
super().__init__(
group=GROUP,
version=VERSION,
kind=KIND,
plural=PLURAL,
)
def __init__(
self,
config_file: Optional[str] = None,
context: Optional[str] = None,
api_client: Optional["ApiClient"] = None,
) -> None:
super().__init__(group=GROUP, version=VERSION, kind=KIND, plural=PLURAL, api_client=api_client)

# these are only used because of kubectl port-forward CLI command
# TODO: remove kubectl usage and remove these attributes
Expand Down
19 changes: 11 additions & 8 deletions dagster_ray/kuberay/client/rayjob/client.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
import logging
import time
from typing import Iterator, Literal, Optional, TypedDict, cast
from typing import TYPE_CHECKING, Iterator, Literal, Optional, TypedDict, cast

from typing_extensions import NotRequired

from dagster_ray.kuberay.client.base import BaseKubeRayClient, load_kubeconfig
from dagster_ray.kuberay.client.raycluster import RayClusterClient, RayClusterStatus

if TYPE_CHECKING:
from kubernetes.client import ApiClient

GROUP = "ray.io"
VERSION = "v1"
PLURAL = "rayjobs"
Expand All @@ -29,19 +32,19 @@ class RayJobStatus(TypedDict):


class RayJobClient(BaseKubeRayClient):
def __init__(self, config_file: Optional[str] = None, context: Optional[str] = None) -> None:
def __init__(
self,
config_file: Optional[str] = None,
context: Optional[str] = None,
api_client: Optional["ApiClient"] = None,
) -> None:
# this call must happen BEFORE creating K8s apis
load_kubeconfig(config_file=config_file, context=context)

self.config_file = config_file
self.context = context

super().__init__(
group=GROUP,
version=VERSION,
kind=KIND,
plural=PLURAL,
)
super().__init__(group=GROUP, version=VERSION, kind=KIND, plural=PLURAL, api_client=api_client)

def get_status(self, name: str, namespace: str, timeout: int = 60, poll_interval: int = 5) -> RayJobStatus: # type: ignore
return cast(
Expand Down
10 changes: 4 additions & 6 deletions dagster_ray/kuberay/configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,14 +30,14 @@
"containers": [
{
"volumeMounts": [
# {"mountPath": "/tmp/ray", "name": "log-volume"},
{"mountPath": "/tmp/ray", "name": "ray-logs"},
],
"name": "head",
"imagePullPolicy": "Always",
},
],
"volumes": [
{"name": "log-volume", "emptyDir": {}},
{"name": "ray-logs", "emptyDir": {}},
],
"affinity": {},
"tolerations": [],
Expand All @@ -58,15 +58,13 @@
"imagePullSecrets": [],
"containers": [
{
"volumeMounts": [
# {"mountPath": "/tmp/ray", "name": "log-volume"}
],
"volumeMounts": [{"mountPath": "/tmp/ray", "name": "ray-logs"}],
"name": "worker",
"imagePullPolicy": "Always",
}
],
"volumes": [
{"name": "log-volume", "emptyDir": {}},
{"name": "ray-logs", "emptyDir": {}},
],
"affinity": {},
"tolerations": [],
Expand Down
18 changes: 9 additions & 9 deletions dagster_ray/pipes.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,6 @@ class PipesRayJobClient(PipesClient, TreatAsResourceParam):
forward_termination (bool): Whether to cancel the `RayJob` job run when the Dagster process receives a termination signal.
timeout (int): Timeout for various internal interactions with the Kubernetes RayJob.
poll_interval (int): Interval (in seconds) at which to poll Kubernetes for status updates.
port_forward (bool): Whether to use Kubernetes port-forwarding to connect to the KubeRay cluster.
Is useful when running in a local environment.
"""

Expand All @@ -185,7 +184,6 @@ def __init__(
forward_termination: bool = True,
timeout: int = 600,
poll_interval: int = 5,
port_forward: bool = False,
):
self.client = client
self._context_injector = context_injector or PipesEnvContextInjector()
Expand All @@ -194,7 +192,6 @@ def __init__(
self.forward_termination = check.bool_param(forward_termination, "forward_termination")
self.timeout = check.int_param(timeout, "timeout")
self.poll_interval = check.int_param(poll_interval, "poll_interval")
self.port_forward = check.bool_param(port_forward, "port_forward")

self._job_submission_client: Optional["JobSubmissionClient"] = None

Expand All @@ -213,7 +210,6 @@ def run( # type: ignore
ray_job (Dict[str, Any]): RayJob specification. `API reference <https://ray-project.github.io/kuberay/reference/api/#rayjob>`_.
extras (Optional[Dict[str, Any]]): Additional information to pass to the Pipes session.
"""
from ray.job_submission import JobStatus

with open_pipes_session(
context=context,
Expand All @@ -227,11 +223,7 @@ def run( # type: ignore

try:
self._read_messages(context, job_id)
status = self._wait_for_completion(context, job_id)

if status in {JobStatus.FAILED, JobStatus.STOPPED}:
raise RuntimeError(f"RayJob {job_id} failed with status {status}")

self._wait_for_completion(context, job_id)
return PipesClientCompletedInvocation(session)

except DagsterExecutionInterruptedError:
Expand Down Expand Up @@ -284,12 +276,20 @@ def _read_messages(self, context: OpExecutionContext, job_id: str) -> None:
)

def _wait_for_completion(self, context: OpExecutionContext, job_id: str) -> "JobStatus":
from ray.job_submission import JobStatus

context.log.info(f"[pipes] Waiting for RayJob {job_id} to complete...")

while True:
status = self.client.get_job_status(job_id)

if status.is_terminal():
if status in {JobStatus.FAILED, JobStatus.STOPPED}:
job_details = self.client.get_job_info(job_id)
raise RuntimeError(
f"[pipes] RayJob {job_id} failed with status {status}. Message:\n{job_details.message}"
)

return status

time.sleep(self.poll_interval)
Expand Down
Loading

0 comments on commit b2c185d

Please sign in to comment.