diff --git a/flytekit/models/execution.py b/flytekit/models/execution.py
index 7e4ff02645..0019e4d79b 100644
--- a/flytekit/models/execution.py
+++ b/flytekit/models/execution.py
@@ -10,6 +10,7 @@
 import flyteidl.admin.execution_pb2 as _execution_pb2
 import flyteidl.admin.node_execution_pb2 as _node_execution_pb2
 import flyteidl.admin.task_execution_pb2 as _task_execution_pb2
+from google.protobuf.wrappers_pb2 import BoolValue
 
 import flytekit
 from flytekit.models import common as _common_models
@@ -179,6 +180,7 @@ def __init__(
         max_parallelism: Optional[int] = None,
         security_context: Optional[security.SecurityContext] = None,
         overwrite_cache: Optional[bool] = None,
+        interruptible: Optional[bool] = None,
         envs: Optional[_common_models.Envs] = None,
         tags: Optional[typing.List[str]] = None,
         cluster_assignment: Optional[ClusterAssignment] = None,
@@ -198,6 +200,7 @@ def __init__(
             parallelism/concurrency of MapTasks is independent from this.
         :param security_context: Optional security context to use for this execution.
         :param overwrite_cache: Optional flag to overwrite the cache for this execution.
+        :param interruptible: Optional flag to override the default interruptible flag of the executed entity.
         :param envs: flytekit.models.common.Envs environment variables to set for this execution.
         :param tags: Optional list of tags to apply to the execution.
         :param execution_cluster_label: Optional execution cluster label to use for this execution.
@@ -213,6 +216,7 @@ def __init__(
         self._max_parallelism = max_parallelism
         self._security_context = security_context
         self._overwrite_cache = overwrite_cache
+        self._interruptible = interruptible
         self._envs = envs
         self._tags = tags
         self._cluster_assignment = cluster_assignment
@@ -287,6 +291,10 @@ def security_context(self) -> typing.Optional[security.SecurityContext]:
     def overwrite_cache(self) -> Optional[bool]:
         return self._overwrite_cache
 
+    @property
+    def interruptible(self) -> Optional[bool]:
+        return self._interruptible
+
     @property
     def envs(self) -> Optional[_common_models.Envs]:
         return self._envs
@@ -321,6 +329,7 @@ def to_flyte_idl(self):
             max_parallelism=self.max_parallelism,
             security_context=self.security_context.to_flyte_idl() if self.security_context else None,
             overwrite_cache=self.overwrite_cache,
+            interruptible=BoolValue(value=self.interruptible) if self.interruptible is not None else None,
             envs=self.envs.to_flyte_idl() if self.envs else None,
             tags=self.tags,
             cluster_assignment=self._cluster_assignment.to_flyte_idl() if self._cluster_assignment else None,
@@ -351,6 +360,7 @@ def from_flyte_idl(cls, p):
             if p.security_context
             else None,
             overwrite_cache=p.overwrite_cache,
+            interruptible=p.interruptible.value if p.HasField("interruptible") else None,
             envs=_common_models.Envs.from_flyte_idl(p.envs) if p.HasField("envs") else None,
             tags=p.tags,
             cluster_assignment=ClusterAssignment.from_flyte_idl(p.cluster_assignment)
diff --git a/flytekit/remote/remote.py b/flytekit/remote/remote.py
index c0a2a8508e..fa1eb196df 100644
--- a/flytekit/remote/remote.py
+++ b/flytekit/remote/remote.py
@@ -1457,6 +1457,7 @@ def _execute(
         wait: bool = False,
         type_hints: typing.Optional[typing.Dict[str, typing.Type]] = None,
         overwrite_cache: typing.Optional[bool] = None,
+        interruptible: typing.Optional[bool] = None,
         envs: typing.Optional[typing.Dict[str, str]] = None,
         tags: typing.Optional[typing.List[str]] = None,
         cluster_pool: typing.Optional[str] = None,
@@ -1475,6 +1476,7 @@ def _execute(
         :param overwrite_cache: Allows for all cached values of a workflow and its tasks to
             be overwritten for a single execution. If enabled, all calculations are performed even if cached
             results would be available, overwriting the stored data once execution finishes successfully.
+        :param interruptible: Optional flag to override the default interruptible flag of the executed entity.
         :param envs: Environment variables to set for the execution.
         :param tags: Tags to set for the execution.
         :param cluster_pool: Specify cluster pool on which newly created execution should be placed.
@@ -1548,6 +1550,7 @@ def _execute(
                         0,
                     ),
                     overwrite_cache=overwrite_cache,
+                    interruptible=interruptible,
                     notifications=notifications,
                     disable_all=options.disable_notifications,
                     labels=options.labels,
@@ -1626,6 +1629,7 @@ def execute(
         wait: bool = False,
         type_hints: typing.Optional[typing.Dict[str, typing.Type]] = None,
         overwrite_cache: typing.Optional[bool] = None,
+        interruptible: typing.Optional[bool] = None,
         envs: typing.Optional[typing.Dict[str, str]] = None,
         tags: typing.Optional[typing.List[str]] = None,
         cluster_pool: typing.Optional[str] = None,
@@ -1666,6 +1670,7 @@ def execute(
         :param overwrite_cache: Allows for all cached values of a workflow and its tasks to be overwritten
             for a single execution. If enabled, all calculations are performed even if cached results would
             be available, overwriting the stored data once execution finishes successfully.
+        :param interruptible: Optional flag to override the default interruptible flag of the executed entity.
         :param envs: Environment variables to be set for the execution.
         :param tags: Tags to be set for the execution.
         :param cluster_pool: Specify cluster pool on which newly created execution should be placed.
@@ -1690,6 +1695,7 @@ def execute(
                 wait=wait,
                 type_hints=type_hints,
                 overwrite_cache=overwrite_cache,
+                interruptible=interruptible,
                 envs=envs,
                 tags=tags,
                 cluster_pool=cluster_pool,
@@ -1707,6 +1713,7 @@ def execute(
                 wait=wait,
                 type_hints=type_hints,
                 overwrite_cache=overwrite_cache,
+                interruptible=interruptible,
                 envs=envs,
                 tags=tags,
                 cluster_pool=cluster_pool,
@@ -1722,6 +1729,7 @@ def execute(
                 wait=wait,
                 type_hints=type_hints,
                 overwrite_cache=overwrite_cache,
+                interruptible=interruptible,
                 envs=envs,
                 tags=tags,
                 cluster_pool=cluster_pool,
@@ -1737,6 +1745,7 @@ def execute(
                 wait=wait,
                 type_hints=type_hints,
                 overwrite_cache=overwrite_cache,
+                interruptible=interruptible,
                 envs=envs,
                 tags=tags,
                 cluster_pool=cluster_pool,
@@ -1752,6 +1761,7 @@ def execute(
                 wait=wait,
                 type_hints=type_hints,
                 overwrite_cache=overwrite_cache,
+                interruptible=interruptible,
                 envs=envs,
                 tags=tags,
                 cluster_pool=cluster_pool,
@@ -1770,6 +1780,7 @@ def execute(
                 image_config=image_config,
                 wait=wait,
                 overwrite_cache=overwrite_cache,
+                interruptible=interruptible,
                 envs=envs,
                 tags=tags,
                 cluster_pool=cluster_pool,
@@ -1790,6 +1801,7 @@ def execute(
                 options=options,
                 wait=wait,
                 overwrite_cache=overwrite_cache,
+                interruptible=interruptible,
                 envs=envs,
                 tags=tags,
                 cluster_pool=cluster_pool,
@@ -1808,6 +1820,7 @@ def execute(
                 options=options,
                 wait=wait,
                 overwrite_cache=overwrite_cache,
+                interruptible=interruptible,
                 envs=envs,
                 tags=tags,
                 cluster_pool=cluster_pool,
@@ -1830,6 +1843,7 @@ def execute_remote_task_lp(
         wait: bool = False,
         type_hints: typing.Optional[typing.Dict[str, typing.Type]] = None,
         overwrite_cache: typing.Optional[bool] = None,
+        interruptible: typing.Optional[bool] = None,
         envs: typing.Optional[typing.Dict[str, str]] = None,
         tags: typing.Optional[typing.List[str]] = None,
         cluster_pool: typing.Optional[str] = None,
@@ -1850,6 +1864,7 @@ def execute_remote_task_lp(
             options=options,
             type_hints=type_hints,
             overwrite_cache=overwrite_cache,
+            interruptible=interruptible,
             envs=envs,
             tags=tags,
             cluster_pool=cluster_pool,
@@ -1868,6 +1883,7 @@ def execute_remote_wf(
         wait: bool = False,
         type_hints: typing.Optional[typing.Dict[str, typing.Type]] = None,
         overwrite_cache: typing.Optional[bool] = None,
+        interruptible: typing.Optional[bool] = None,
         envs: typing.Optional[typing.Dict[str, str]] = None,
         tags: typing.Optional[typing.List[str]] = None,
         cluster_pool: typing.Optional[str] = None,
@@ -1889,6 +1905,7 @@ def execute_remote_wf(
             wait=wait,
             type_hints=type_hints,
             overwrite_cache=overwrite_cache,
+            interruptible=interruptible,
             envs=envs,
             tags=tags,
             cluster_pool=cluster_pool,
@@ -1907,6 +1924,7 @@ def execute_reference_task(
         wait: bool = False,
         type_hints: typing.Optional[typing.Dict[str, typing.Type]] = None,
         overwrite_cache: typing.Optional[bool] = None,
+        interruptible: typing.Optional[bool] = None,
         envs: typing.Optional[typing.Dict[str, str]] = None,
         tags: typing.Optional[typing.List[str]] = None,
         cluster_pool: typing.Optional[str] = None,
@@ -1938,6 +1956,7 @@ def execute_reference_task(
             wait=wait,
             type_hints=type_hints,
             overwrite_cache=overwrite_cache,
+            interruptible=interruptible,
             envs=envs,
             tags=tags,
             cluster_pool=cluster_pool,
@@ -1954,6 +1973,7 @@ def execute_reference_workflow(
         wait: bool = False,
         type_hints: typing.Optional[typing.Dict[str, typing.Type]] = None,
         overwrite_cache: typing.Optional[bool] = None,
+        interruptible: typing.Optional[bool] = None,
         envs: typing.Optional[typing.Dict[str, str]] = None,
         tags: typing.Optional[typing.List[str]] = None,
         cluster_pool: typing.Optional[str] = None,
@@ -1999,6 +2019,7 @@ def execute_reference_workflow(
             options=options,
             type_hints=type_hints,
             overwrite_cache=overwrite_cache,
+            interruptible=interruptible,
             envs=envs,
             tags=tags,
             cluster_pool=cluster_pool,
@@ -2015,6 +2036,7 @@ def execute_reference_launch_plan(
         wait: bool = False,
         type_hints: typing.Optional[typing.Dict[str, typing.Type]] = None,
         overwrite_cache: typing.Optional[bool] = None,
+        interruptible: typing.Optional[bool] = None,
         envs: typing.Optional[typing.Dict[str, str]] = None,
         tags: typing.Optional[typing.List[str]] = None,
         cluster_pool: typing.Optional[str] = None,
@@ -2046,6 +2068,7 @@ def execute_reference_launch_plan(
             wait=wait,
             type_hints=type_hints,
             overwrite_cache=overwrite_cache,
+            interruptible=interruptible,
             envs=envs,
             tags=tags,
             cluster_pool=cluster_pool,
@@ -2068,6 +2091,7 @@ def execute_local_task(
         image_config: typing.Optional[ImageConfig] = None,
         wait: bool = False,
         overwrite_cache: typing.Optional[bool] = None,
+        interruptible: typing.Optional[bool] = None,
         envs: typing.Optional[typing.Dict[str, str]] = None,
         tags: typing.Optional[typing.List[str]] = None,
         cluster_pool: typing.Optional[str] = None,
@@ -2087,6 +2111,7 @@ def execute_local_task(
         :param image_config: If provided, will use this image config in the pod.
         :param wait: If True, will wait for the execution to complete before returning.
         :param overwrite_cache: If True, will overwrite the cache.
+        :param interruptible: Optional flag to override the default interruptible flag of the executed entity.
         :param envs: Environment variables to set for the execution.
         :param tags: Tags to set for the execution.
         :param cluster_pool: Specify cluster pool on which newly created execution should be placed.
@@ -2131,6 +2156,7 @@ def execute_local_task(
             wait=wait,
             type_hints=entity.python_interface.inputs,
             overwrite_cache=overwrite_cache,
+            interruptible=interruptible,
             options=options,
             envs=envs,
             tags=tags,
@@ -2152,6 +2178,7 @@ def execute_local_workflow(
         options: typing.Optional[Options] = None,
         wait: bool = False,
         overwrite_cache: typing.Optional[bool] = None,
+        interruptible: typing.Optional[bool] = None,
         envs: typing.Optional[typing.Dict[str, str]] = None,
         tags: typing.Optional[typing.List[str]] = None,
         cluster_pool: typing.Optional[str] = None,
@@ -2159,22 +2186,24 @@ def execute_local_workflow(
     ) -> FlyteWorkflowExecution:
         """
         Execute an @workflow decorated function.
-        :param entity:
-        :param inputs:
-        :param project:
-        :param domain:
-        :param name:
-        :param version:
-        :param execution_name:
-        :param image_config:
-        :param options:
-        :param wait:
-        :param overwrite_cache:
-        :param envs:
-        :param tags:
-        :param cluster_pool:
-        :param execution_cluster_label:
-        :return:
+
+        :param entity: The workflow to execute
+        :param inputs: Input dictionary
+        :param project: Project to execute in
+        :param domain: Domain to execute in
+        :param name: Optional name override for the workflow
+        :param version: Optional version for the workflow
+        :param execution_name: Optional name for the execution
+        :param image_config: Optional image config override
+        :param options: Optional Options object
+        :param wait: Whether to wait for execution completion
+        :param overwrite_cache: If True, will overwrite the cache
+        :param interruptible: Optional flag to override the default interruptible flag of the executed entity
+        :param envs: Environment variables to set for the execution
+        :param tags: Tags to set for the execution
+        :param cluster_pool: Specify cluster pool on which newly created execution should be placed
+        :param execution_cluster_label: Specify label of cluster(s) on which newly created execution should be placed
+        :return: FlyteWorkflowExecution object
         """
         if not image_config:
             image_config = ImageConfig.auto_default_image()
@@ -2230,6 +2259,7 @@ def execute_local_workflow(
             options=options,
             type_hints=entity.python_interface.inputs,
             overwrite_cache=overwrite_cache,
+            interruptible=interruptible,
             envs=envs,
             tags=tags,
             cluster_pool=cluster_pool,
@@ -2249,12 +2279,14 @@ def execute_local_launch_plan(
         options: typing.Optional[Options] = None,
         wait: bool = False,
         overwrite_cache: typing.Optional[bool] = None,
+        interruptible: typing.Optional[bool] = None,
         envs: typing.Optional[typing.Dict[str, str]] = None,
         tags: typing.Optional[typing.List[str]] = None,
         cluster_pool: typing.Optional[str] = None,
         execution_cluster_label: typing.Optional[str] = None,
     ) -> FlyteWorkflowExecution:
         """
+        Execute a locally defined `LaunchPlan`.
 
         :param entity: The locally defined launch plan object
         :param inputs: Inputs to be passed into the execution as a dict with Python native values.
@@ -2266,6 +2298,7 @@ def execute_local_launch_plan(
         :param options: Options to be passed into the execution.
         :param wait: If True, will wait for the execution to complete before returning.
         :param overwrite_cache: If True, will overwrite the cache.
+        :param interruptible: Optional flag to override the default interruptible flag of the executed entity.
         :param envs: Environment variables to be passed into the execution.
         :param tags: Tags to be passed into the execution.
         :param cluster_pool: Specify cluster pool on which newly created execution should be placed.
@@ -2297,6 +2330,7 @@ def execute_local_launch_plan(
             wait=wait,
             type_hints=entity.python_interface.inputs,
             overwrite_cache=overwrite_cache,
+            interruptible=interruptible,
             envs=envs,
             tags=tags,
             cluster_pool=cluster_pool,
diff --git a/tests/flytekit/integration/remote/test_remote.py b/tests/flytekit/integration/remote/test_remote.py
index 6e2898e433..8b733fcf03 100644
--- a/tests/flytekit/integration/remote/test_remote.py
+++ b/tests/flytekit/integration/remote/test_remote.py
@@ -614,6 +614,25 @@ def test_execute_workflow_with_maptask(register):
     )
     assert execution.outputs["o0"] == [4, 5, 6]
 
 
+def test_executes_nested_workflow_dictating_interruptible(register):
+    remote = FlyteRemote(Config.auto(config_file=CONFIG), PROJECT, DOMAIN)
+    flyte_launch_plan = remote.fetch_launch_plan(name="basic.child_workflow.parent_wf", version=VERSION)
+    # The values we want to test for
+    interruptible_values = [True, False, None]
+    executions = []
+    for creation_interruptible in interruptible_values:
+        execution = remote.execute(flyte_launch_plan, inputs={"a": 10}, wait=False, interruptible=creation_interruptible)
+        executions.append(execution)
+    # Wait for all executions to complete
+    for execution, expected_interruptible in zip(executions, interruptible_values):
+        execution = remote.wait(execution, timeout=300)
+        # Check that the parent workflow is interruptible as expected
+        assert execution.spec.interruptible == expected_interruptible
+        # Check that the child workflow is interruptible as expected
+        subwf_execution_id = execution.node_executions["n1"].closure.workflow_node_metadata.execution_id.name
+        subwf_execution = remote.fetch_execution(project=PROJECT, domain=DOMAIN, name=subwf_execution_id)
+        assert subwf_execution.spec.interruptible == expected_interruptible
+
 @pytest.mark.lftransfers
 class TestLargeFileTransfers:
diff --git a/tests/flytekit/unit/models/test_execution.py b/tests/flytekit/unit/models/test_execution.py
index fec2b5cfbb..8e1dfa749a 100644
--- a/tests/flytekit/unit/models/test_execution.py
+++ b/tests/flytekit/unit/models/test_execution.py
@@ -166,6 +166,7 @@ def test_execution_spec(literal_value_pair):
         ),
         raw_output_data_config=_common_models.RawOutputDataConfig(output_location_prefix="raw_output"),
         max_parallelism=100,
+        interruptible=True
     )
     assert obj.launch_plan.resource_type == _identifier.ResourceType.LAUNCH_PLAN
     assert obj.launch_plan.domain == "domain"
@@ -183,6 +184,7 @@ def test_execution_spec(literal_value_pair):
     ]
     assert obj.disable_all is None
     assert obj.max_parallelism == 100
+    assert obj.interruptible == True
     assert obj.raw_output_data_config.output_location_prefix == "raw_output"
 
     obj2 = _execution.ExecutionSpec.from_flyte_idl(obj.to_flyte_idl())
@@ -203,6 +205,7 @@ def test_execution_spec(literal_value_pair):
     ]
     assert obj2.disable_all is None
     assert obj2.max_parallelism == 100
+    assert obj2.interruptible == True
     assert obj2.raw_output_data_config.output_location_prefix == "raw_output"
 
     obj = _execution.ExecutionSpec(
@@ -220,6 +223,7 @@ def test_execution_spec(literal_value_pair):
     assert obj.metadata.principal == "tester"
     assert obj.notifications is None
     assert obj.disable_all is True
+    assert obj.interruptible is None
 
     obj2 = _execution.ExecutionSpec.from_flyte_idl(obj.to_flyte_idl())
     assert obj == obj2
@@ -233,6 +237,7 @@ def test_execution_spec(literal_value_pair):
     assert obj2.metadata.principal == "tester"
     assert obj2.notifications is None
     assert obj2.disable_all is True
+    assert obj2.interruptible is None
 
 
 def test_workflow_execution_data_response():
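
Note for reviewers: below is a minimal usage sketch of the new flag, assuming placeholder config path, project, domain, and launch plan values (they are illustrative, not taken from this change). Passing interruptible=None keeps whatever interruptible setting the entity was registered with, while True or False overrides it for that single execution; the BoolValue wrapper in ExecutionSpec.to_flyte_idl keeps an explicit False distinguishable from an unset field on the wire.

    from flytekit.configuration import Config
    from flytekit.remote import FlyteRemote

    # Placeholder config/project/domain/launch-plan values, for illustration only.
    remote = FlyteRemote(Config.auto(config_file="~/.flyte/config.yaml"), "flytesnacks", "development")
    lp = remote.fetch_launch_plan(name="basic.child_workflow.parent_wf", version="v1")

    # None  -> keep the interruptible setting the entity was registered with
    # True  -> run this execution (including nested workflows) as interruptible
    # False -> force non-interruptible for this execution only
    execution = remote.execute(lp, inputs={"a": 10}, interruptible=True, wait=True)
    assert execution.spec.interruptible is True

The tri-state round trip mirrors what test_execution_spec asserts above: an ExecutionSpec built without interruptible serializes with the field unset and deserializes back to None.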