Skip to content

Commit

Permalink
fix sync for map over lp in dynamic (#3155)
Browse files Browse the repository at this point in the history
Signed-off-by: Troy Chiu <[email protected]>
  • Loading branch information
troychiu authored Feb 25, 2025
1 parent f03cec8 commit 87fb3c6
Showing 1 changed file with 62 additions and 64 deletions.
126 changes: 62 additions & 64 deletions flytekit/remote/remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -483,63 +483,10 @@ def fetch_workflow(
wf_templates.extend([swf.template for swf in compiled_wf.sub_workflows])

node_launch_plans = {}

def find_launch_plan(
lp_ref: id_models, node_launch_plans: Dict[id_models, launch_plan_models.LaunchPlanSpec]
) -> None:
if lp_ref not in node_launch_plans:
admin_launch_plan = self.client.get_launch_plan(lp_ref)
node_launch_plans[lp_ref] = admin_launch_plan.spec

for wf_template in wf_templates:
for node in FlyteWorkflow.get_non_system_nodes(wf_template.nodes):
if node.workflow_node is not None and node.workflow_node.launchplan_ref is not None:
lp_ref = node.workflow_node.launchplan_ref
find_launch_plan(lp_ref, node_launch_plans)

# Inspect array nodes for launch plans
if (
node.array_node is not None
and node.array_node.node.workflow_node is not None
and node.array_node.node.workflow_node.launchplan_ref is not None
):
lp_ref = node.array_node.node.workflow_node.launchplan_ref
find_launch_plan(lp_ref, node_launch_plans)

# Inspect conditional branch nodes for launch plans
def get_launch_plan_from_branch(
branch_node: BranchNode, node_launch_plans: Dict[id_models, launch_plan_models.LaunchPlanSpec]
) -> None:
def get_launch_plan_from_then_node(
child_then_node: Node, node_launch_plans: Dict[id_models, launch_plan_models.LaunchPlanSpec]
) -> None:
# then_node could have nested branch_node or be a normal then_node
if child_then_node.branch_node:
get_launch_plan_from_branch(child_then_node.branch_node, node_launch_plans)
elif child_then_node.workflow_node and child_then_node.workflow_node.launchplan_ref:
lp_ref = child_then_node.workflow_node.launchplan_ref
find_launch_plan(lp_ref, node_launch_plans)

if branch_node and branch_node.if_else:
branch = branch_node.if_else
if branch.case and branch.case.then_node:
child_then_node = branch.case.then_node
get_launch_plan_from_then_node(child_then_node, node_launch_plans)
if branch.other:
for o in branch.other:
if o.then_node:
child_then_node = o.then_node
get_launch_plan_from_then_node(child_then_node, node_launch_plans)
if branch.else_node:
# else_node could have nested conditional branch_node
if branch.else_node.branch_node:
get_launch_plan_from_branch(branch.else_node.branch_node, node_launch_plans)
elif branch.else_node.workflow_node and branch.else_node.workflow_node.launchplan_ref:
lp_ref = branch.else_node.workflow_node.launchplan_ref
find_launch_plan(lp_ref, node_launch_plans)

if node.branch_node:
get_launch_plan_from_branch(node.branch_node, node_launch_plans)
self.find_launch_plan_for_node(node, node_launch_plans)

flyte_workflow = FlyteWorkflow.promote_from_closure(compiled_wf, node_launch_plans)
flyte_workflow.template._id = workflow_id
return flyte_workflow
Expand All @@ -556,6 +503,65 @@ def _upgrade_launchplan(self, lp: launch_plan_models.LaunchPlan) -> FlyteLaunchP
flyte_lp._flyte_workflow = workflow
return flyte_lp

def find_launch_plan(
self, lp_ref: id_models, node_launch_plans: Dict[id_models, launch_plan_models.LaunchPlanSpec]
) -> None:
if lp_ref not in node_launch_plans:
admin_launch_plan = self.client.get_launch_plan(lp_ref)
node_launch_plans[lp_ref] = admin_launch_plan.spec

def find_launch_plan_for_node(
self, node: Node, node_launch_plans: Dict[id_models, launch_plan_models.LaunchPlanSpec]
):
# Case 1: workflow node
if node.workflow_node is not None and node.workflow_node.launchplan_ref is not None:
lp_ref = node.workflow_node.launchplan_ref
self.find_launch_plan(lp_ref, node_launch_plans)

# Case 2: array node
if (
node.array_node is not None
and node.array_node.node.workflow_node is not None
and node.array_node.node.workflow_node.launchplan_ref is not None
):
lp_ref = node.array_node.node.workflow_node.launchplan_ref
self.find_launch_plan(lp_ref, node_launch_plans)

# Case 3: branch node
def get_launch_plan_from_branch(
branch_node: BranchNode, node_launch_plans: Dict[id_models, launch_plan_models.LaunchPlanSpec]
) -> None:
def get_launch_plan_from_then_node(
child_then_node: Node, node_launch_plans: Dict[id_models, launch_plan_models.LaunchPlanSpec]
) -> None:
# then_node could have nested branch_node or be a normal then_node
if child_then_node.branch_node:
get_launch_plan_from_branch(child_then_node.branch_node, node_launch_plans)
elif child_then_node.workflow_node and child_then_node.workflow_node.launchplan_ref:
lp_ref = child_then_node.workflow_node.launchplan_ref
self.find_launch_plan(lp_ref, node_launch_plans)

if branch_node and branch_node.if_else:
branch = branch_node.if_else
if branch.case and branch.case.then_node:
child_then_node = branch.case.then_node
get_launch_plan_from_then_node(child_then_node, node_launch_plans)
if branch.other:
for o in branch.other:
if o.then_node:
child_then_node = o.then_node
get_launch_plan_from_then_node(child_then_node, node_launch_plans)
if branch.else_node:
# else_node could have nested conditional branch_node
if branch.else_node.branch_node:
get_launch_plan_from_branch(branch.else_node.branch_node, node_launch_plans)
elif branch.else_node.workflow_node and branch.else_node.workflow_node.launchplan_ref:
lp_ref = branch.else_node.workflow_node.launchplan_ref
self.find_launch_plan(lp_ref, node_launch_plans)

if node.branch_node:
get_launch_plan_from_branch(node.branch_node, node_launch_plans)

def fetch_active_launchplan(
self, project: str = None, domain: str = None, name: str = None
) -> typing.Optional[FlyteLaunchPlan]:
Expand Down Expand Up @@ -2575,17 +2581,9 @@ def sync_node_execution(
if node_execution_get_data_response.dynamic_workflow is not None:
compiled_wf = node_execution_get_data_response.dynamic_workflow.compiled_workflow
node_launch_plans = {}
# TODO: Inspect branch nodes for launch plans
for template in [compiled_wf.primary.template] + [swf.template for swf in compiled_wf.sub_workflows]:
for node in FlyteWorkflow.get_non_system_nodes(template.nodes):
if (
node.workflow_node is not None
and node.workflow_node.launchplan_ref is not None
and node.workflow_node.launchplan_ref not in node_launch_plans
):
node_launch_plans[node.workflow_node.launchplan_ref] = self.client.get_launch_plan(
node.workflow_node.launchplan_ref
).spec
self.find_launch_plan_for_node(node, node_launch_plans)

dynamic_flyte_wf = FlyteWorkflow.promote_from_closure(compiled_wf, node_launch_plans)
execution._underlying_node_executions = [
Expand Down

0 comments on commit 87fb3c6

Please sign in to comment.