Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

✨ terminate/delete RayJob on deployment timeout #53

Merged
merged 2 commits into from
Nov 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 7 additions & 4 deletions dagster_ray/kuberay/client/base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import time
from typing import TYPE_CHECKING, Any, Dict, Optional
from typing import TYPE_CHECKING, Any, Dict, Generic, Optional, TypeVar

if TYPE_CHECKING:
from kubernetes import client
Expand All @@ -18,7 +18,10 @@ def load_kubeconfig(context: Optional[str] = None, config_file: Optional[str] =
pass


class BaseKubeRayClient:
T_Status = TypeVar("T_Status")


class BaseKubeRayClient(Generic[T_Status]):
def __init__(
self,
group: str,
Expand All @@ -37,7 +40,7 @@ def __init__(
self._api = client.CustomObjectsApi(api_client=api_client)
self._core_v1_api = client.CoreV1Api(api_client=api_client)

def wait_for_service_endpoints(self, service_name: str, namespace: str, poll_interval: int = 5, timeout: int = 60):
def wait_for_service_endpoints(self, service_name: str, namespace: str, poll_interval: int = 5, timeout: int = 600):
from kubernetes.client import ApiException

start_time = time.time()
Expand All @@ -63,7 +66,7 @@ def wait_for_service_endpoints(self, service_name: str, namespace: str, poll_int

time.sleep(poll_interval)

def get_status(self, name: str, namespace: str, timeout: int = 60, poll_interval: int = 5) -> Dict[str, Any]:
def get_status(self, name: str, namespace: str, timeout: int = 60, poll_interval: int = 5) -> T_Status:
from kubernetes.client import ApiException

while timeout > 0:
Expand Down
8 changes: 1 addition & 7 deletions dagster_ray/kuberay/client/raycluster/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ class RayClusterStatus(TypedDict):
state: NotRequired[str]


class RayClusterClient(BaseKubeRayClient):
class RayClusterClient(BaseKubeRayClient[RayClusterStatus]):
def __init__(
self,
config_file: Optional[str] = None,
Expand All @@ -91,12 +91,6 @@ def __init__(
self.config_file = config_file
self.context = context

def get_status(self, name: str, namespace: str, timeout: int = 60, poll_interval: int = 5) -> RayClusterStatus: # type: ignore
return cast(
RayClusterStatus,
super().get_status(name=name, namespace=namespace, timeout=timeout, poll_interval=poll_interval),
)

def wait_until_ready(
self,
name: str,
Expand Down
30 changes: 17 additions & 13 deletions dagster_ray/kuberay/client/rayjob/client.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import logging
import time
from typing import TYPE_CHECKING, Iterator, Literal, Optional, TypedDict, cast
from typing import TYPE_CHECKING, Iterator, Literal, Optional, TypedDict

from typing_extensions import NotRequired

Expand Down Expand Up @@ -31,7 +31,7 @@ class RayJobStatus(TypedDict):
message: NotRequired[str]


class RayJobClient(BaseKubeRayClient):
class RayJobClient(BaseKubeRayClient[RayJobStatus]):
def __init__(
self,
config_file: Optional[str] = None,
Expand All @@ -46,12 +46,6 @@ def __init__(

super().__init__(group=GROUP, version=VERSION, kind=KIND, plural=PLURAL, api_client=api_client)

def get_status(self, name: str, namespace: str, timeout: int = 60, poll_interval: int = 5) -> RayJobStatus: # type: ignore
return cast(
RayJobStatus,
super().get_status(name=name, namespace=namespace, timeout=timeout, poll_interval=poll_interval),
)

def get_ray_cluster_name(self, name: str, namespace: str) -> str:
return self.get_status(name, namespace)["rayClusterName"]

Expand All @@ -66,8 +60,10 @@ def wait_until_running(
self,
name: str,
namespace: str,
timeout: int = 300,
timeout: int = 600,
poll_interval: int = 5,
terminate_on_timeout: bool = True,
port_forward: bool = False,
) -> bool:
start_time = time.time()

Expand All @@ -80,9 +76,17 @@ def wait_until_running(
raise RuntimeError(f"RayJob {namespace}/{name} deployment failed. Status:\n{status}")

if time.time() - start_time > timeout:
raise TimeoutError(
f"Timed out waiting for RayJob {namespace}/{name} deployment to become available. Status:\n{status}"
)
if terminate_on_timeout:
logger.warning(f"Terminating RayJob {namespace}/{name} because of timeout {timeout}s")
try:
self.terminate(name, namespace, port_forward=port_forward)
except Exception as e:
logger.warning(
f"Failed to gracefully terminate RayJob {namespace}/{name}: {e}, will delete it instead."
)
self.delete(name, namespace)

raise TimeoutError(f"Timed out waiting for RayJob {namespace}/{name} to start. Status:\n{status}")

time.sleep(poll_interval)

Expand All @@ -103,7 +107,7 @@ def _wait_for_job_submission(
self,
name: str,
namespace: str,
timeout: int = 300,
timeout: int = 600,
poll_interval: int = 10,
):
start_time = time.time()
Expand Down
5 changes: 4 additions & 1 deletion dagster_ray/kuberay/pipes.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,8 @@ class PipesKubeRayJobClient(PipesClient, TreatAsResourceParam):
message_reader (Optional[PipesMessageReader]): A message reader to use to read messages
from the glue job run. Defaults to :py:class:`PipesRayJobMessageReader`.
client (Optional[boto3.client]): The Kubernetes API client.
forward_termination (bool): Whether to cancel the `RayJob` job run when the Dagster process receives a termination signal.
forward_termination (bool): Whether to terminate the Ray job when the Dagster process receives a termination signal,
or if the startup timeout is reached. Defaults to ``True``.
timeout (int): Timeout for various internal interactions with the Kubernetes RayJob.
poll_interval (int): Interval at which to poll the Kubernetes for status updates.
port_forward (bool): Whether to use Kubernetes port-forwarding to connect to the KubeRay cluster.
Expand Down Expand Up @@ -169,6 +170,8 @@ def _start(self, context: OpExecutionContext, ray_job: Dict[str, Any]) -> Dict[s
namespace=namespace,
timeout=self.timeout,
poll_interval=self.poll_interval,
terminate_on_timeout=self.forward_termination,
port_forward=self.port_forward,
)

return self.client.get(
Expand Down
Loading