Skip to content

Commit

Permalink
fix a bug where sync_execution is not calling itself recursively (#3132)
Browse files Browse the repository at this point in the history
Signed-off-by: Mark Zhao <[email protected]>
Co-authored-by: Mark Zhao <[email protected]>
  • Loading branch information
bz38 and Mark Zhao authored Feb 19, 2025
1 parent e35bdd0 commit 58ea1a5
Show file tree
Hide file tree
Showing 3 changed files with 93 additions and 2 deletions.
4 changes: 2 additions & 2 deletions flytekit/remote/remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -2405,7 +2405,7 @@ def sync(
:param execution:
:param entity_definition:
:param sync_nodes: By default sync will fetch data on all underlying node executions (recursively,
so subworkflows will also get picked up). Set this to False in order to prevent that (which
so subworkflows and launch plans will also get picked up). Set this to False in order to prevent that (which
will make this call faster).
:return: Returns the same execution object, but with additional information pulled in.
"""
Expand Down Expand Up @@ -2542,7 +2542,7 @@ def sync_node_execution(
launched_exec = self.fetch_execution(
project=launched_exec_id.project, domain=launched_exec_id.domain, name=launched_exec_id.name
)
self.sync_execution(launched_exec)
self.sync_execution(launched_exec, sync_nodes=True)
if launched_exec.is_done:
# The synced underlying execution should've had these populated.
execution._inputs = launched_exec.inputs
Expand Down
49 changes: 49 additions & 0 deletions tests/flytekit/integration/remote/test_remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,55 @@ def test_monitor_workflow_execution(register):
assert execution.outputs["o0"] == "hello world"


def test_sync_execution_sync_nodes_get_all_executions(register):
remote = FlyteRemote(Config.auto(config_file=CONFIG), PROJECT, DOMAIN)
flyte_launch_plan = remote.fetch_launch_plan(name="basic.deep_child_workflow.parent_wf", version=VERSION)
execution = remote.execute(
flyte_launch_plan,
inputs={"a": 3},
)

poll_interval = datetime.timedelta(seconds=1)
time_to_give_up = datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(seconds=600)

execution = remote.sync_execution(execution, sync_nodes=True)
while datetime.datetime.now(datetime.timezone.utc) < time_to_give_up:
if execution.is_done:
break

with pytest.raises(
FlyteAssertion, match="Please wait until the execution has completed before requesting the outputs.",
):
execution.outputs

time.sleep(poll_interval.total_seconds())
execution = remote.sync_execution(execution, sync_nodes=True)

if execution.node_executions:
assert execution.node_executions["start-node"].closure.phase == 3 # SUCCEEDED

for key in execution.node_executions:
assert execution.node_executions[key].closure.phase == 3

# check node execution getting correct number of nested workflows and executions
assert len(execution.node_executions) == 5
execution_n0 = execution.node_executions["n0"]
execution_n1 = execution.node_executions["n1"]
assert len(execution_n1.workflow_executions[0].node_executions) == 4
execution_n1_n0 = execution_n1.workflow_executions[0].node_executions["n0"]
assert len(execution_n1_n0.workflow_executions[0].node_executions) == 3
execution_n1_n0_n0 = execution_n1_n0.workflow_executions[0].node_executions["n0"]

# check inputs and outputs each node execution
assert execution_n0.inputs == {"a": 3}
assert execution_n0.outputs["o0"] == 6
assert execution_n1.inputs == {"a": 6}
assert execution_n1_n0.inputs == {"a": 6}
assert execution_n1_n0_n0.inputs == {"a": 6}
assert execution_n1_n0_n0.outputs["o0"] == 12



def test_fetch_execute_launch_plan_with_subworkflows(register):
remote = FlyteRemote(Config.auto(config_file=CONFIG), PROJECT, DOMAIN)

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
from flytekit import LaunchPlan, task, workflow
from flytekit.models.common import Labels


@task
def double(a: int) -> int:
return a * 2


@task
def add(a: int, b: int) -> int:
return a + b


@workflow
def my_deep_childwf(a: int = 42) -> int:
b = double(a=a)
return b

deep_child_lp = LaunchPlan.get_or_create(my_deep_childwf, name="my_fixed_deep_child_lp", labels=Labels({"l1": "v1"}))


@workflow
def my_childwf(a: int = 42) -> int:
b = deep_child_lp(a=a)
c = double(a=b)
return c


shallow_child_lp = LaunchPlan.get_or_create(my_childwf, name="my_shallow_fixed_child_lp", labels=Labels({"l1": "v1"}))


@workflow
def parent_wf(a: int) -> int:
x = double(a=a)
y = shallow_child_lp(a=x)
z = add(a=x, b=y)
return z


if __name__ == "__main__":
print(f"Running parent_wf(a=3) {parent_wf(a=3)}")

0 comments on commit 58ea1a5

Please sign in to comment.