Skip to content

Commit

Permalink
Support FlyteRemote.execute interruptible flag override (#2885)
Browse files Browse the repository at this point in the history
Signed-off-by: redartera <[email protected]>
  • Loading branch information
redartera authored Feb 11, 2025
1 parent 7f58efe commit 74bde49
Show file tree
Hide file tree
Showing 4 changed files with 84 additions and 16 deletions.
10 changes: 10 additions & 0 deletions flytekit/models/execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import flyteidl.admin.execution_pb2 as _execution_pb2
import flyteidl.admin.node_execution_pb2 as _node_execution_pb2
import flyteidl.admin.task_execution_pb2 as _task_execution_pb2
from google.protobuf.wrappers_pb2 import BoolValue

import flytekit
from flytekit.models import common as _common_models
Expand Down Expand Up @@ -179,6 +180,7 @@ def __init__(
max_parallelism: Optional[int] = None,
security_context: Optional[security.SecurityContext] = None,
overwrite_cache: Optional[bool] = None,
interruptible: Optional[bool] = None,
envs: Optional[_common_models.Envs] = None,
tags: Optional[typing.List[str]] = None,
cluster_assignment: Optional[ClusterAssignment] = None,
Expand All @@ -198,6 +200,7 @@ def __init__(
parallelism/concurrency of MapTasks is independent from this.
:param security_context: Optional security context to use for this execution.
:param overwrite_cache: Optional flag to overwrite the cache for this execution.
:param interruptible: Optional flag to override the default interruptible flag of the executed entity.
:param envs: flytekit.models.common.Envs environment variables to set for this execution.
:param tags: Optional list of tags to apply to the execution.
:param execution_cluster_label: Optional execution cluster label to use for this execution.
Expand All @@ -213,6 +216,7 @@ def __init__(
self._max_parallelism = max_parallelism
self._security_context = security_context
self._overwrite_cache = overwrite_cache
self._interruptible = interruptible
self._envs = envs
self._tags = tags
self._cluster_assignment = cluster_assignment
Expand Down Expand Up @@ -287,6 +291,10 @@ def security_context(self) -> typing.Optional[security.SecurityContext]:
def overwrite_cache(self) -> Optional[bool]:
return self._overwrite_cache

@property
def interruptible(self) -> Optional[bool]:
return self._interruptible

@property
def envs(self) -> Optional[_common_models.Envs]:
return self._envs
Expand Down Expand Up @@ -321,6 +329,7 @@ def to_flyte_idl(self):
max_parallelism=self.max_parallelism,
security_context=self.security_context.to_flyte_idl() if self.security_context else None,
overwrite_cache=self.overwrite_cache,
interruptible=BoolValue(value=self.interruptible) if self.interruptible is not None else None,
envs=self.envs.to_flyte_idl() if self.envs else None,
tags=self.tags,
cluster_assignment=self._cluster_assignment.to_flyte_idl() if self._cluster_assignment else None,
Expand Down Expand Up @@ -351,6 +360,7 @@ def from_flyte_idl(cls, p):
if p.security_context
else None,
overwrite_cache=p.overwrite_cache,
interruptible=p.interruptible.value if p.HasField("interruptible") else None,
envs=_common_models.Envs.from_flyte_idl(p.envs) if p.HasField("envs") else None,
tags=p.tags,
cluster_assignment=ClusterAssignment.from_flyte_idl(p.cluster_assignment)
Expand Down
66 changes: 50 additions & 16 deletions flytekit/remote/remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -1457,6 +1457,7 @@ def _execute(
wait: bool = False,
type_hints: typing.Optional[typing.Dict[str, typing.Type]] = None,
overwrite_cache: typing.Optional[bool] = None,
interruptible: typing.Optional[bool] = None,
envs: typing.Optional[typing.Dict[str, str]] = None,
tags: typing.Optional[typing.List[str]] = None,
cluster_pool: typing.Optional[str] = None,
Expand All @@ -1475,6 +1476,7 @@ def _execute(
:param overwrite_cache: Allows for all cached values of a workflow and its tasks to be overwritten
for a single execution. If enabled, all calculations are performed even if cached results would
be available, overwriting the stored data once execution finishes successfully.
:param interruptible: Optional flag to override the default interruptible flag of the executed entity.
:param envs: Environment variables to set for the execution.
:param tags: Tags to set for the execution.
:param cluster_pool: Specify cluster pool on which newly created execution should be placed.
Expand Down Expand Up @@ -1548,6 +1550,7 @@ def _execute(
0,
),
overwrite_cache=overwrite_cache,
interruptible=interruptible,
notifications=notifications,
disable_all=options.disable_notifications,
labels=options.labels,
Expand Down Expand Up @@ -1626,6 +1629,7 @@ def execute(
wait: bool = False,
type_hints: typing.Optional[typing.Dict[str, typing.Type]] = None,
overwrite_cache: typing.Optional[bool] = None,
interruptible: typing.Optional[bool] = None,
envs: typing.Optional[typing.Dict[str, str]] = None,
tags: typing.Optional[typing.List[str]] = None,
cluster_pool: typing.Optional[str] = None,
Expand Down Expand Up @@ -1666,6 +1670,7 @@ def execute(
:param overwrite_cache: Allows for all cached values of a workflow and its tasks to be overwritten
for a single execution. If enabled, all calculations are performed even if cached results would
be available, overwriting the stored data once execution finishes successfully.
:param interruptible: Optional flag to override the default interruptible flag of the executed entity.
:param envs: Environment variables to be set for the execution.
:param tags: Tags to be set for the execution.
:param cluster_pool: Specify cluster pool on which newly created execution should be placed.
Expand All @@ -1690,6 +1695,7 @@ def execute(
wait=wait,
type_hints=type_hints,
overwrite_cache=overwrite_cache,
interruptible=interruptible,
envs=envs,
tags=tags,
cluster_pool=cluster_pool,
Expand All @@ -1707,6 +1713,7 @@ def execute(
wait=wait,
type_hints=type_hints,
overwrite_cache=overwrite_cache,
interruptible=interruptible,
envs=envs,
tags=tags,
cluster_pool=cluster_pool,
Expand All @@ -1722,6 +1729,7 @@ def execute(
wait=wait,
type_hints=type_hints,
overwrite_cache=overwrite_cache,
interruptible=interruptible,
envs=envs,
tags=tags,
cluster_pool=cluster_pool,
Expand All @@ -1737,6 +1745,7 @@ def execute(
wait=wait,
type_hints=type_hints,
overwrite_cache=overwrite_cache,
interruptible=interruptible,
envs=envs,
tags=tags,
cluster_pool=cluster_pool,
Expand All @@ -1752,6 +1761,7 @@ def execute(
wait=wait,
type_hints=type_hints,
overwrite_cache=overwrite_cache,
interruptible=interruptible,
envs=envs,
tags=tags,
cluster_pool=cluster_pool,
Expand All @@ -1770,6 +1780,7 @@ def execute(
image_config=image_config,
wait=wait,
overwrite_cache=overwrite_cache,
interruptible=interruptible,
envs=envs,
tags=tags,
cluster_pool=cluster_pool,
Expand All @@ -1790,6 +1801,7 @@ def execute(
options=options,
wait=wait,
overwrite_cache=overwrite_cache,
interruptible=interruptible,
envs=envs,
tags=tags,
cluster_pool=cluster_pool,
Expand All @@ -1808,6 +1820,7 @@ def execute(
options=options,
wait=wait,
overwrite_cache=overwrite_cache,
interruptible=interruptible,
envs=envs,
tags=tags,
cluster_pool=cluster_pool,
Expand All @@ -1830,6 +1843,7 @@ def execute_remote_task_lp(
wait: bool = False,
type_hints: typing.Optional[typing.Dict[str, typing.Type]] = None,
overwrite_cache: typing.Optional[bool] = None,
interruptible: typing.Optional[bool] = None,
envs: typing.Optional[typing.Dict[str, str]] = None,
tags: typing.Optional[typing.List[str]] = None,
cluster_pool: typing.Optional[str] = None,
Expand All @@ -1850,6 +1864,7 @@ def execute_remote_task_lp(
options=options,
type_hints=type_hints,
overwrite_cache=overwrite_cache,
interruptible=interruptible,
envs=envs,
tags=tags,
cluster_pool=cluster_pool,
Expand All @@ -1868,6 +1883,7 @@ def execute_remote_wf(
wait: bool = False,
type_hints: typing.Optional[typing.Dict[str, typing.Type]] = None,
overwrite_cache: typing.Optional[bool] = None,
interruptible: typing.Optional[bool] = None,
envs: typing.Optional[typing.Dict[str, str]] = None,
tags: typing.Optional[typing.List[str]] = None,
cluster_pool: typing.Optional[str] = None,
Expand All @@ -1889,6 +1905,7 @@ def execute_remote_wf(
wait=wait,
type_hints=type_hints,
overwrite_cache=overwrite_cache,
interruptible=interruptible,
envs=envs,
tags=tags,
cluster_pool=cluster_pool,
Expand All @@ -1907,6 +1924,7 @@ def execute_reference_task(
wait: bool = False,
type_hints: typing.Optional[typing.Dict[str, typing.Type]] = None,
overwrite_cache: typing.Optional[bool] = None,
interruptible: typing.Optional[bool] = None,
envs: typing.Optional[typing.Dict[str, str]] = None,
tags: typing.Optional[typing.List[str]] = None,
cluster_pool: typing.Optional[str] = None,
Expand Down Expand Up @@ -1938,6 +1956,7 @@ def execute_reference_task(
wait=wait,
type_hints=type_hints,
overwrite_cache=overwrite_cache,
interruptible=interruptible,
envs=envs,
tags=tags,
cluster_pool=cluster_pool,
Expand All @@ -1954,6 +1973,7 @@ def execute_reference_workflow(
wait: bool = False,
type_hints: typing.Optional[typing.Dict[str, typing.Type]] = None,
overwrite_cache: typing.Optional[bool] = None,
interruptible: typing.Optional[bool] = None,
envs: typing.Optional[typing.Dict[str, str]] = None,
tags: typing.Optional[typing.List[str]] = None,
cluster_pool: typing.Optional[str] = None,
Expand Down Expand Up @@ -1999,6 +2019,7 @@ def execute_reference_workflow(
options=options,
type_hints=type_hints,
overwrite_cache=overwrite_cache,
interruptible=interruptible,
envs=envs,
tags=tags,
cluster_pool=cluster_pool,
Expand All @@ -2015,6 +2036,7 @@ def execute_reference_launch_plan(
wait: bool = False,
type_hints: typing.Optional[typing.Dict[str, typing.Type]] = None,
overwrite_cache: typing.Optional[bool] = None,
interruptible: typing.Optional[bool] = None,
envs: typing.Optional[typing.Dict[str, str]] = None,
tags: typing.Optional[typing.List[str]] = None,
cluster_pool: typing.Optional[str] = None,
Expand Down Expand Up @@ -2046,6 +2068,7 @@ def execute_reference_launch_plan(
wait=wait,
type_hints=type_hints,
overwrite_cache=overwrite_cache,
interruptible=interruptible,
envs=envs,
tags=tags,
cluster_pool=cluster_pool,
Expand All @@ -2068,6 +2091,7 @@ def execute_local_task(
image_config: typing.Optional[ImageConfig] = None,
wait: bool = False,
overwrite_cache: typing.Optional[bool] = None,
interruptible: typing.Optional[bool] = None,
envs: typing.Optional[typing.Dict[str, str]] = None,
tags: typing.Optional[typing.List[str]] = None,
cluster_pool: typing.Optional[str] = None,
Expand All @@ -2087,6 +2111,7 @@ def execute_local_task(
:param image_config: If provided, will use this image config in the pod.
:param wait: If True, will wait for the execution to complete before returning.
:param overwrite_cache: If True, will overwrite the cache.
:param interruptible: Optional flag to override the default interruptible flag of the executed entity.
:param envs: Environment variables to set for the execution.
:param tags: Tags to set for the execution.
:param cluster_pool: Specify cluster pool on which newly created execution should be placed.
Expand Down Expand Up @@ -2131,6 +2156,7 @@ def execute_local_task(
wait=wait,
type_hints=entity.python_interface.inputs,
overwrite_cache=overwrite_cache,
interruptible=interruptible,
options=options,
envs=envs,
tags=tags,
Expand All @@ -2152,29 +2178,32 @@ def execute_local_workflow(
options: typing.Optional[Options] = None,
wait: bool = False,
overwrite_cache: typing.Optional[bool] = None,
interruptible: typing.Optional[bool] = None,
envs: typing.Optional[typing.Dict[str, str]] = None,
tags: typing.Optional[typing.List[str]] = None,
cluster_pool: typing.Optional[str] = None,
execution_cluster_label: typing.Optional[str] = None,
) -> FlyteWorkflowExecution:
"""
Execute an @workflow decorated function.
:param entity:
:param inputs:
:param project:
:param domain:
:param name:
:param version:
:param execution_name:
:param image_config:
:param options:
:param wait:
:param overwrite_cache:
:param envs:
:param tags:
:param cluster_pool:
:param execution_cluster_label:
:return:
:param entity: The workflow to execute
:param inputs: Input dictionary
:param project: Project to execute in
:param domain: Domain to execute in
:param name: Optional name override for the workflow
:param version: Optional version for the workflow
:param execution_name: Optional name for the execution
:param image_config: Optional image config override
:param options: Optional Options object
:param wait: Whether to wait for execution completion
:param overwrite_cache: If True, will overwrite the cache
:param interruptible: Optional flag to override the default interruptible flag of the executed entity
:param envs: Environment variables to set for the execution
:param tags: Tags to set for the execution
:param cluster_pool: Specify cluster pool on which newly created execution should be placed
:param execution_cluster_label: Specify label of cluster(s) on which newly created execution should be placed
:return: FlyteWorkflowExecution object
"""
if not image_config:
image_config = ImageConfig.auto_default_image()
Expand Down Expand Up @@ -2230,6 +2259,7 @@ def execute_local_workflow(
options=options,
type_hints=entity.python_interface.inputs,
overwrite_cache=overwrite_cache,
interruptible=interruptible,
envs=envs,
tags=tags,
cluster_pool=cluster_pool,
Expand All @@ -2249,12 +2279,14 @@ def execute_local_launch_plan(
options: typing.Optional[Options] = None,
wait: bool = False,
overwrite_cache: typing.Optional[bool] = None,
interruptible: typing.Optional[bool] = None,
envs: typing.Optional[typing.Dict[str, str]] = None,
tags: typing.Optional[typing.List[str]] = None,
cluster_pool: typing.Optional[str] = None,
execution_cluster_label: typing.Optional[str] = None,
) -> FlyteWorkflowExecution:
"""
Execute a locally defined `LaunchPlan`.
:param entity: The locally defined launch plan object
:param inputs: Inputs to be passed into the execution as a dict with Python native values.
Expand All @@ -2266,6 +2298,7 @@ def execute_local_launch_plan(
:param options: Options to be passed into the execution.
:param wait: If True, will wait for the execution to complete before returning.
:param overwrite_cache: If True, will overwrite the cache.
:param interruptible: Optional flag to override the default interruptible flag of the executed entity.
:param envs: Environment variables to be passed into the execution.
:param tags: Tags to be passed into the execution.
:param cluster_pool: Specify cluster pool on which newly created execution should be placed.
Expand Down Expand Up @@ -2297,6 +2330,7 @@ def execute_local_launch_plan(
wait=wait,
type_hints=entity.python_interface.inputs,
overwrite_cache=overwrite_cache,
interruptible=interruptible,
envs=envs,
tags=tags,
cluster_pool=cluster_pool,
Expand Down
19 changes: 19 additions & 0 deletions tests/flytekit/integration/remote/test_remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -614,6 +614,25 @@ def test_execute_workflow_with_maptask(register):
)
assert execution.outputs["o0"] == [4, 5, 6]

def test_executes_nested_workflow_dictating_interruptible(register):
remote = FlyteRemote(Config.auto(config_file=CONFIG), PROJECT, DOMAIN)
flyte_launch_plan = remote.fetch_launch_plan(name="basic.child_workflow.parent_wf", version=VERSION)
# The values we want to test for
interruptible_values = [True, False, None]
executions = []
for creation_interruptible in interruptible_values:
execution = remote.execute(flyte_launch_plan, inputs={"a": 10}, wait=False, interruptible=creation_interruptible)
executions.append(execution)
# Wait for all executions to complete
for execution, expected_interruptible in zip(executions, interruptible_values):
execution = remote.wait(execution, timeout=300)
# Check that the parent workflow is interruptible as expected
assert execution.spec.interruptible == expected_interruptible
# Check that the child workflow is interruptible as expected
subwf_execution_id = execution.node_executions["n1"].closure.workflow_node_metadata.execution_id.name
subwf_execution = remote.fetch_execution(project=PROJECT, domain=DOMAIN, name=subwf_execution_id)
assert subwf_execution.spec.interruptible == expected_interruptible


@pytest.mark.lftransfers
class TestLargeFileTransfers:
Expand Down
Loading

0 comments on commit 74bde49

Please sign in to comment.