Skip to content

Commit

Permalink
don't override timeout on with_overrides if not specified (#3097)
Browse files Browse the repository at this point in the history
* don't override timeout on with_overrides if not specified

Signed-off-by: Paul Dittamo <[email protected]>

* lint

Signed-off-by: Paul Dittamo <[email protected]>

* clean up

Signed-off-by: Paul Dittamo <[email protected]>

* clean up

Signed-off-by: Paul Dittamo <[email protected]>

* clean up

Signed-off-by: Paul Dittamo <[email protected]>

* set default to timedelta instead of 0

Signed-off-by: Paul Dittamo <[email protected]>

---------

Signed-off-by: Paul Dittamo <[email protected]>
  • Loading branch information
pvditt authored Feb 24, 2025
1 parent 93c87c3 commit 55c15ea
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 17 deletions.
24 changes: 14 additions & 10 deletions flytekit/core/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@ class Node(object):
ID, which from the registration step
"""

TIMEOUT_OVERRIDE_SENTINEL = object()

def __init__(
self,
id: str,
Expand Down Expand Up @@ -130,7 +132,7 @@ def metadata(self) -> _workflow_model.NodeMetadata:
def _override_node_metadata(
self,
name,
timeout: Optional[Union[int, datetime.timedelta]] = None,
timeout: Optional[Union[int, datetime.timedelta, object]] = TIMEOUT_OVERRIDE_SENTINEL,
retries: Optional[int] = None,
interruptible: typing.Optional[bool] = None,
cache: typing.Optional[bool] = None,
Expand All @@ -145,14 +147,16 @@ def _override_node_metadata(
else:
node_metadata = self._metadata

if timeout is None:
node_metadata._timeout = datetime.timedelta()
elif isinstance(timeout, int):
node_metadata._timeout = datetime.timedelta(seconds=timeout)
elif isinstance(timeout, datetime.timedelta):
node_metadata._timeout = timeout
else:
raise ValueError("timeout should be duration represented as either a datetime.timedelta or int seconds")
if timeout is not Node.TIMEOUT_OVERRIDE_SENTINEL:
if timeout is None:
node_metadata._timeout = datetime.timedelta()
elif isinstance(timeout, int):
node_metadata._timeout = datetime.timedelta(seconds=timeout)
elif isinstance(timeout, datetime.timedelta):
node_metadata._timeout = timeout
else:
raise ValueError("timeout should be duration represented as either a datetime.timedelta or int seconds")

if retries is not None:
assert_not_promise(retries, "retries")
node_metadata._retries = (
Expand Down Expand Up @@ -184,7 +188,7 @@ def with_overrides(
aliases: Optional[Dict[str, str]] = None,
requests: Optional[Resources] = None,
limits: Optional[Resources] = None,
timeout: Optional[Union[int, datetime.timedelta]] = None,
timeout: Optional[Union[int, datetime.timedelta, object]] = TIMEOUT_OVERRIDE_SENTINEL,
retries: Optional[int] = None,
interruptible: Optional[bool] = None,
name: Optional[str] = None,
Expand Down
2 changes: 1 addition & 1 deletion flytekit/core/promise.py
Original file line number Diff line number Diff line change
Expand Up @@ -586,7 +586,7 @@ def with_overrides(
aliases: Optional[Dict[str, str]] = None,
requests: Optional[Resources] = None,
limits: Optional[Resources] = None,
timeout: Optional[Union[int, datetime.timedelta]] = None,
timeout: Optional[Union[int, datetime.timedelta, object]] = Node.TIMEOUT_OVERRIDE_SENTINEL,
retries: Optional[int] = None,
interruptible: Optional[bool] = None,
name: Optional[str] = None,
Expand Down
39 changes: 33 additions & 6 deletions tests/flytekit/unit/core/test_node_creation.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,18 +303,42 @@ def my_wf(a: typing.List[str]) -> typing.List[str]:
]


preset_timeout = datetime.timedelta(seconds=100)


@pytest.mark.parametrize(
"timeout,expected",
[(None, datetime.timedelta()), (10, datetime.timedelta(seconds=10))],
"timeout,t1_expected_timeout_overridden, t1_expected_timeout_unset, t2_expected_timeout_overridden, "
"t2_expected_timeout_unset",
[
(None, datetime.timedelta(0), 0, datetime.timedelta(0), preset_timeout),
(10, datetime.timedelta(seconds=10), 0,
datetime.timedelta(seconds=10), preset_timeout)
],
)
def test_timeout_override(timeout, expected):
def test_timeout_override(
timeout,
t1_expected_timeout_overridden,
t1_expected_timeout_unset,
t2_expected_timeout_overridden,
t2_expected_timeout_unset,
):
@task
def t1(a: str) -> str:
return f"*~*~*~{a}*~*~*~"

@task(
timeout=preset_timeout
)
def t2(a: str) -> str:
return f"*~*~*~{a}*~*~*~"

@workflow
def my_wf(a: str) -> str:
return t1(a=a).with_overrides(timeout=timeout)
s = t1(a=a).with_overrides(timeout=timeout)
s1 = t1(a=s).with_overrides()
s2 = t2(a=s1).with_overrides(timeout=timeout)
s3 = t2(a=s2).with_overrides()
return s3

serialization_settings = flytekit.configuration.SerializationSettings(
project="test_proj",
Expand All @@ -324,8 +348,11 @@ def my_wf(a: str) -> str:
env={},
)
wf_spec = get_serializable(OrderedDict(), serialization_settings, my_wf)
assert len(wf_spec.template.nodes) == 1
assert wf_spec.template.nodes[0].metadata.timeout == expected
assert len(wf_spec.template.nodes) == 4
assert wf_spec.template.nodes[0].metadata.timeout == t1_expected_timeout_overridden
assert wf_spec.template.nodes[1].metadata.timeout == t1_expected_timeout_unset
assert wf_spec.template.nodes[2].metadata.timeout == t2_expected_timeout_overridden
assert wf_spec.template.nodes[3].metadata.timeout == t2_expected_timeout_unset


def test_timeout_override_invalid_value():
Expand Down

0 comments on commit 55c15ea

Please sign in to comment.