Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Eager cleanup #3148

Open
wants to merge 17 commits into
base: master
Choose a base branch
from
2 changes: 2 additions & 0 deletions flytekit/core/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
START_NODE_ID = "start-node"
END_NODE_ID = "end-node"

# Reserved node ID assigned to a workflow's failure (cleanup) node; user-defined
# nodes must never use this ID (the translator raises if one does).
DEFAULT_FAILURE_NODE_ID = "nfail"

# If set this environment variable overrides the default container image and the default base image in ImageSpec.
FLYTE_INTERNAL_IMAGE_ENV_VAR = "FLYTE_INTERNAL_IMAGE"

Expand Down
7 changes: 6 additions & 1 deletion flytekit/core/node_creation.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import inspect
from typing import TYPE_CHECKING, Union

from flytekit.core.base_task import PythonTask
Expand All @@ -10,6 +11,7 @@
from flytekit.core.workflow import WorkflowBase
from flytekit.exceptions import user as _user_exceptions
from flytekit.loggers import logger
from flytekit.utils.asyn import run_sync

if TYPE_CHECKING:
from flytekit.remote.remote_callable import RemoteEntity
Expand Down Expand Up @@ -77,7 +79,10 @@ def create_node(
# When compiling, calling the entity will create a node.
ctx = FlyteContext.current_context()
if ctx.compilation_state is not None and ctx.compilation_state.mode == 1:
outputs = entity(**kwargs)
if inspect.iscoroutinefunction(entity.__call__):
outputs = run_sync(entity, **kwargs)
else:
outputs = entity(**kwargs)
Comment on lines +82 to +85
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Consider adding error handling for async execution

Consider handling potential exceptions from run_sync when executing coroutine functions. The current implementation may silently fail if the async execution encounters issues.

Code suggestion
Check the AI-generated fix before applying
Suggested change
if inspect.iscoroutinefunction(entity.__call__):
outputs = run_sync(entity, **kwargs)
else:
outputs = entity(**kwargs)
if inspect.iscoroutinefunction(entity.__call__):
try:
outputs = run_sync(entity, **kwargs)
except Exception as e:
raise RuntimeError(f"Async execution failed for {entity.name}: {str(e)}") from e
else:
outputs = entity(**kwargs)

Code Review Run #ce446d


Should Bito avoid suggestions like this for future reviews? (Manage Rules)

  • Yes, avoid them

# This is always the output of create_and_link_node which returns create_task_output, which can be
# VoidPromise, Promise, or our custom namedtuple of Promises.
node = ctx.compilation_state.nodes[-1]
Expand Down
121 changes: 120 additions & 1 deletion flytekit/core/python_function_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
import inspect
import os
import signal
import time
import typing
from abc import ABC
from collections import OrderedDict
from contextlib import suppress
Expand All @@ -32,7 +34,7 @@
from flytekit.core.constants import EAGER_ROOT_ENV_NAME
from flytekit.core.context_manager import ExecutionState, FlyteContext, FlyteContextManager
from flytekit.core.docstring import Docstring
from flytekit.core.interface import transform_function_to_interface
from flytekit.core.interface import Interface, transform_function_to_interface
from flytekit.core.promise import (
Promise,
VoidPromise,
Expand All @@ -59,7 +61,11 @@
from flytekit.models import dynamic_job as _dynamic_job
from flytekit.models import literals as _literal_models
from flytekit.models import task as task_models
from flytekit.models.admin import common as admin_common_models
from flytekit.models.admin import workflow as admin_workflow_models
from flytekit.models.filters import ValueIn
from flytekit.models.literals import LiteralMap
from flytekit.models.security import Secret
from flytekit.utils.asyn import loop_manager

T = TypeVar("T")
Expand Down Expand Up @@ -636,3 +642,116 @@ def run(self, remote: "FlyteRemote", ss: SerializationSettings, **kwargs): # ty

with FlyteContextManager.with_context(builder):
return loop_manager.run_sync(self.async_execute, self, **kwargs)

def get_as_workflow(self):
    """Wrap this eager task in an ``ImperativeWorkflow`` with a failure-node cleanup task.

    The generated workflow mirrors this task's Python interface: every input is promoted
    to a workflow input and bound to this task, and every task output is surfaced as a
    workflow output. An ``EagerFailureHandlerTask`` with the same inputs is attached as
    the workflow's failure handler so that child executions launched by the eager task
    are terminated if the run fails.

    :return: An ``ImperativeWorkflow`` wrapping this eager task.
    """
    from flytekit.core.workflow import ImperativeWorkflow

    cleanup = EagerFailureHandlerTask(name=f"{self.name}-cleanup", inputs=self.python_interface.inputs)
    # NOTE: the development-only override of cleanup._container_image with this task's
    # image was removed, as its own TODO instructed; the cleanup task resolves its
    # image through the normal serialization settings.

    wb = ImperativeWorkflow(name=self.name)

    input_kwargs = {}
    for input_name, input_python_type in self.python_interface.inputs.items():
        wb.add_workflow_input(input_name, input_python_type)
        input_kwargs[input_name] = wb.inputs[input_name]

    node = wb.add_entity(self, **input_kwargs)
    # Only the output names matter here; the node's promise carries the type information.
    for output_name in self.python_interface.outputs:
        wb.add_workflow_output(output_name, node.outputs[output_name])

    wb.add_on_failure_handler(cleanup)
    return wb


class EagerFailureTaskResolver(TaskResolverMixin):
    """Resolver that rehydrates the default eager-failure cleanup task at execution time.

    Because the cleanup task has no meaningful code of its own, the loader args are a
    fixed token list and ``load_task`` always returns a fresh, input-less
    ``EagerFailureHandlerTask``.
    """

    @property
    def location(self) -> str:
        # Points at the module-level singleton below, not at this class.
        return "{}.eager_failure_task_resolver".format(EagerFailureTaskResolver.__module__)

    def name(self) -> str:
        return "eager_failure_task_resolver"

    def load_task(self, loader_args: List[str]) -> Task:
        """
        Given the set of identifier keys, should return one Python Task or raise an error if not found
        """
        # The loader args are ignored: there is only one cleanup task to rehydrate.
        return EagerFailureHandlerTask(name="no_input_default_cleanup_task", inputs={})

    def loader_args(self, settings: SerializationSettings, t: Task) -> List[str]:
        """
        Return a list of strings that can help identify the parameter Task
        """
        return "eager failure handler".split()

    def get_all_tasks(self) -> List[Task]:
        """
        Future proof method. Just making it easy to access all tasks (Not required today as we auto register them)
        """
        return list()


# Module-level singleton referenced by EagerFailureHandlerTask and by `location` above.
eager_failure_task_resolver = EagerFailureTaskResolver()


class EagerFailureHandlerTask(PythonAutoContainerTask, metaclass=FlyteTrackedABC):
    """A task that terminates still-running child executions of a failed eager execution.

    Attached as the failure node of the workflow produced by ``get_as_workflow``. It
    carries the parent eager task's inputs in its interface (so the failure node can
    bind them) but never reads them at execution time.
    """

    _TASK_TYPE = "eager_failure_handler_task"

    def __init__(self, name: str, inputs: typing.Optional[typing.Dict[str, typing.Type]] = None, **kwargs):
        """
        :param name: Name of the cleanup task.
        :param inputs: Input name -> type map mirroring the parent eager task's interface;
            the inputs are accepted but ignored when the task runs.
        """
        inputs = inputs or {}
        super().__init__(
            task_type=self._TASK_TYPE,
            name=name,
            interface=Interface(inputs=inputs, outputs=None),
            task_config=None,
            task_resolver=eager_failure_task_resolver,
            # The development-only Secret(group="", key="EAGER_API_KEY") request was
            # removed, as its own TODO instructed.
            **kwargs,
        )

    def dispatch_execute(self, ctx: FlyteContext, input_literal_map: LiteralMap) -> LiteralMap:
        """
        This task should only be called during remote execution. Because when rehydrating this task at
        execution time we don't have access to the python interface of the corresponding eager
        task/workflow, we don't have the Python types to convert the input literal map — nor do we
        need them. This task is responsible only for ensuring that all executions are terminated.
        """
        # Recursive imports
        from flytekit import current_context
        from flytekit.configuration.plugin import get_plugin

        most_recent = admin_common_models.Sort("created_at", admin_common_models.Sort.Direction.DESCENDING)
        current_exec_id = current_context().execution_id
        project = current_exec_id.project
        domain = current_exec_id.domain
        name = current_exec_id.name
        logger.warning(f"Cleaning up potentially still running tasks for execution {name} in {project}/{domain}")
        remote = get_plugin().get_remote(config=None, project=project, domain=domain)
        # Child executions launched by the eager parent are tagged eager-exec=<parent exec name>.
        key_filter = ValueIn("execution_tag.key", ["eager-exec"])
        value_filter = ValueIn("execution_tag.value", [name])
        # Only executions that may still be running need termination.
        phase_filter = ValueIn("phase", ["UNDEFINED", "QUEUED", "RUNNING"])
        # This should be made more robust, currently lacking retries and exception handling
        while True:
            exec_models, _ = remote.client.list_executions_paginated(
                project,
                domain,
                limit=100,
                filters=[key_filter, value_filter, phase_filter],
                sort_by=most_recent,
            )
            logger.warning(f"Found {len(exec_models)} executions this round for termination")
            if not exec_models:
                break
            logger.warning(exec_models)
            for exec_model in exec_models:
                logger.warning(f"Terminating execution {exec_model.id}, phase {exec_model.closure.phase}")
                remote.client.terminate_execution(exec_model.id, f"clean up by parent eager execution {name}")
                # Pace terminate calls to avoid hammering admin.
                time.sleep(0.5)

        # Just echo back
        return input_literal_map

    def execute(self, **kwargs) -> Any:
        # All real work happens in dispatch_execute; reaching execute means the task
        # was dispatched through the normal typed path, which should be impossible.
        raise AssertionError("this task shouldn't need to call execute")
34 changes: 34 additions & 0 deletions flytekit/core/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -634,6 +634,40 @@ def add_launch_plan(self, launch_plan: _annotated_launch_plan.LaunchPlan, **kwar
def add_subwf(self, sub_wf: WorkflowBase, **kwargs) -> Node:
return self.add_entity(sub_wf, **kwargs)

def add_on_failure_handler(self, entity):
    """
    This is a special function that mimics the add_entity function, but this is only used
    to add the failure node. Failure nodes are special because we don't want
    them to be part of the main workflow.

    :param entity: The task/workflow to run when this workflow fails. Its inputs must be a
        superset of the workflow's inputs, and any extra inputs must be optional.
    :return: The created failure node (detached from the main node list).
    """
    from flytekit.core.node_creation import create_node

    ctx = FlyteContext.current_context()
    if ctx.compilation_state is not None:
        raise RuntimeError("Can't already be compiling")
    with FlyteContextManager.with_context(ctx.with_compilation_state(self.compilation_state)) as ctx:
        if entity.python_interface and self.python_interface:
            workflow_inputs = self.python_interface.inputs
            failure_node_inputs = entity.python_interface.inputs

            # Workflow inputs should be a subset of failure node inputs.
            # Report `entity` (the handler being attached), not self.on_failure, which is
            # not set for imperative workflows and previously produced a confusing error.
            if (failure_node_inputs | workflow_inputs) != failure_node_inputs:
                raise FlyteFailureNodeInputMismatchException(entity, self)
            additional_keys = failure_node_inputs.keys() - workflow_inputs.keys()
            # Raising an error if the additional inputs in the failure node are not optional.
            for k in additional_keys:
                if not is_optional_type(failure_node_inputs[k]):
                    raise FlyteFailureNodeInputMismatchException(entity, self)

        n = create_node(entity=entity, **self._inputs)
        # create_node appends the node to the compilation state; pop it off because the
        # failure node must not appear in the main workflow's node list.
        ctx.compilation_state.nodes.pop(-1)
        self._failure_node = n
        n._id = _common_constants.DEFAULT_FAILURE_NODE_ID
        return n  # type: ignore

def ready(self) -> bool:
"""
This function returns whether or not the workflow is in a ready state, which means
Expand Down
6 changes: 6 additions & 0 deletions flytekit/models/filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,8 @@ def from_python_std(cls, string):
return Contains._parse_from_string(string)
elif string.startswith("value_in("):
return ValueIn._parse_from_string(string)
elif string.startswith("value_not_in("):
return ValueNotIn._parse_from_string(string)
else:
raise ValueError("'{}' could not be parsed into a filter.".format(string))

Expand Down Expand Up @@ -133,3 +135,7 @@ class Contains(SetFilter):

class ValueIn(SetFilter):
    """Filter that matches when the field's value is one of the provided values."""

    _comparator = "value_in"


class ValueNotIn(SetFilter):
    """Filter that matches when the field's value is not one of the provided values."""

    _comparator = "value_not_in"
6 changes: 6 additions & 0 deletions flytekit/tools/serialize_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from flytekit import LaunchPlan
from flytekit.core import context_manager as flyte_context
from flytekit.core.base_task import PythonTask
from flytekit.core.python_function_task import EagerAsyncPythonFunctionTask
from flytekit.core.workflow import WorkflowBase
from flytekit.models import launch_plan as _launch_plan_models
from flytekit.models import task as task_models
Expand Down Expand Up @@ -60,6 +61,11 @@ def get_registrable_entities(
lp = LaunchPlan.get_default_launch_plan(ctx, entity)
get_serializable(new_api_serializable_entities, ctx.serialization_settings, lp, options)

if isinstance(entity, EagerAsyncPythonFunctionTask):
wf = entity.get_as_workflow()
lp = LaunchPlan.get_default_launch_plan(ctx, wf)
get_serializable(new_api_serializable_entities, ctx.serialization_settings, lp, options)

new_api_model_values = list(new_api_serializable_entities.values())
entities_to_be_serialized = list(filter(_should_register_with_admin, new_api_model_values))

Expand Down
4 changes: 4 additions & 0 deletions flytekit/tools/translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,10 @@ def get_serializable_workflow(
if n.id == _common_constants.GLOBAL_INPUT_NODE_ID:
continue

# Ensure no node is named the failure node id
if n.id == _common_constants.DEFAULT_FAILURE_NODE_ID:
raise ValueError(f"Node {n.id} is reserved for the failure node")

# Recursively serialize the node
serialized_nodes.append(get_serializable(entity_mapping, settings, n, options))

Expand Down
33 changes: 33 additions & 0 deletions tests/flytekit/unit/core/test_eager_cleanup.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
from collections import OrderedDict

import flytekit.configuration
from flytekit.configuration import Image, ImageConfig
from flytekit.core.python_function_task import EagerFailureHandlerTask
from flytekit.tools.translator import get_serializable

# Minimal image and serialization settings shared by all tests in this module.
default_img = Image(name="default", fqn="test", tag="tag")
serialization_settings = flytekit.configuration.SerializationSettings(
    project="project",
    domain="domain",
    version="version",
    env=None,
    image_config=ImageConfig(default_image=default_img, images=[default_img]),
)


def test_failure():
    """The cleanup task serializes with the eager failure resolver and its fixed loader args."""
    t = EagerFailureHandlerTask(name="tester", inputs={"a": int})

    spec = get_serializable(OrderedDict(), serialization_settings, t)

    # Leftover debug print removed; the assertion alone documents the expected container args.
    assert spec.template.container.args == ['pyflyte-execute', '--inputs', '{{.input}}', '--output-prefix', '{{.outputPrefix}}', '--raw-output-data-prefix', '{{.rawOutputDataPrefix}}', '--checkpoint-path', '{{.checkpointOutputPrefix}}', '--prev-checkpoint', '{{.prevCheckpointPrefix}}', '--resolver', 'flytekit.core.python_function_task.eager_failure_task_resolver', '--', 'eager', 'failure', 'handler']


def test_loading():
    """The resolver can be loaded by dotted path and rehydrates an EagerFailureHandlerTask."""
    from flytekit.tools.module_loader import load_object_from_module

    resolver = load_object_from_module("flytekit.core.python_function_task.eager_failure_task_resolver")
    # Leftover debug print removed.
    t = resolver.load_task([])
    assert isinstance(t, EagerFailureHandlerTask)
57 changes: 56 additions & 1 deletion tests/flytekit/unit/core/test_imperative.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,21 @@
from collections import OrderedDict

import pytest

from dataclasses import dataclass, fields, field
import flytekit.configuration
from flytekit.configuration import Image, ImageConfig
from flytekit.core.base_task import kwtypes
from flytekit.core.launch_plan import LaunchPlan
from flytekit.core.task import reference_task, task
from flytekit.core.workflow import ImperativeWorkflow, get_promise, workflow
from flytekit.core.python_function_task import EagerFailureHandlerTask
from flytekit.exceptions.user import FlyteValidationException
from flytekit.extras.sqlite3.task import SQLite3Config, SQLite3Task
from flytekit.models import literals as literal_models
from flytekit.tools.translator import get_serializable
from flytekit.types.file import FlyteFile
from flytekit.types.schema import FlyteSchema
from flytekit.models.admin.workflow import WorkflowSpec

default_img = Image(name="default", fqn="test", tag="tag")
serialization_settings = flytekit.configuration.SerializationSettings(
Expand Down Expand Up @@ -137,6 +139,59 @@ def t1(a: typing.Dict[str, typing.List[int]]) -> typing.Dict[str, int]:
assert wb(in1=3, in2=4, in3=5) == {"a": 7, "b": 9}


def test_imperative_with_failure():
    """An imperative workflow with a failure handler runs locally and serializes the
    handler as a failure node with the reserved id, outside the main node list."""

    @dataclass
    class DC:
        string: typing.Optional[str] = None

    @task
    def t1(a: typing.Dict[str, typing.List[int]]) -> typing.Dict[str, int]:
        return {k: sum(v) for k, v in a.items()}

    @task
    def t2():
        print("side effect")

    @task
    def t3(dc: DC) -> DC:
        # Fall back to a default when no string was provided. Previously the default
        # instance was constructed and discarded, so a None input crashed below.
        if dc.string is None:
            dc = DC(string="default")
        return DC(string=dc.string + " world")  # type: ignore[operator]

    wb = ImperativeWorkflow(name="my.workflow.a")

    # mapped inputs
    in1 = wb.add_workflow_input("in1", int)
    wb.add_workflow_input("in2", int)
    in3 = wb.add_workflow_input("in3", int)
    node = wb.add_entity(t1, a={"a": [in1, wb.inputs["in2"]], "b": [wb.inputs["in2"], in3]})
    wb.add_workflow_output("from_n0t1", node.outputs["o0"])

    # pure side effect task
    wb.add_entity(t2)

    failure_task = EagerFailureHandlerTask(name="sample-failure-task", inputs=wb.python_interface.inputs)
    wb.add_on_failure_handler(failure_task)

    # Add a dataclass input/output to exercise a complex type through the workflow.
    dc_input = wb.add_workflow_input("dc_in", DC)
    node_dc = wb.add_entity(t3, dc=dc_input)
    wb.add_workflow_output("updated_dc", node_dc.outputs["o0"])

    r = wb(in1=3, in2=4, in3=5, dc_in=DC(string="hello"))
    assert r.from_n0t1 == {"a": 7, "b": 9}
    assert r.updated_dc.string == "hello world"

    wf_spec: WorkflowSpec = get_serializable(OrderedDict(), serialization_settings, wb)
    # The failure node must not be counted among the three main nodes.
    assert len(wf_spec.template.nodes) == 3
    assert len(wf_spec.template.interface.inputs) == 4

    assert wf_spec.template.failure_node is not None
    assert wf_spec.template.failure_node.id == "nfail"


def test_imperative_with_list_io():
@task
def t1(a: int) -> typing.List[int]:
Expand Down
Loading
Loading