diff --git a/flytekit/remote/remote.py b/flytekit/remote/remote.py index 067ec5b779..fe549957e9 100644 --- a/flytekit/remote/remote.py +++ b/flytekit/remote/remote.py @@ -483,63 +483,10 @@ def fetch_workflow( wf_templates.extend([swf.template for swf in compiled_wf.sub_workflows]) node_launch_plans = {} - - def find_launch_plan( - lp_ref: id_models, node_launch_plans: Dict[id_models, launch_plan_models.LaunchPlanSpec] - ) -> None: - if lp_ref not in node_launch_plans: - admin_launch_plan = self.client.get_launch_plan(lp_ref) - node_launch_plans[lp_ref] = admin_launch_plan.spec - for wf_template in wf_templates: for node in FlyteWorkflow.get_non_system_nodes(wf_template.nodes): - if node.workflow_node is not None and node.workflow_node.launchplan_ref is not None: - lp_ref = node.workflow_node.launchplan_ref - find_launch_plan(lp_ref, node_launch_plans) - - # Inspect array nodes for launch plans - if ( - node.array_node is not None - and node.array_node.node.workflow_node is not None - and node.array_node.node.workflow_node.launchplan_ref is not None - ): - lp_ref = node.array_node.node.workflow_node.launchplan_ref - find_launch_plan(lp_ref, node_launch_plans) - - # Inspect conditional branch nodes for launch plans - def get_launch_plan_from_branch( - branch_node: BranchNode, node_launch_plans: Dict[id_models, launch_plan_models.LaunchPlanSpec] - ) -> None: - def get_launch_plan_from_then_node( - child_then_node: Node, node_launch_plans: Dict[id_models, launch_plan_models.LaunchPlanSpec] - ) -> None: - # then_node could have nested branch_node or be a normal then_node - if child_then_node.branch_node: - get_launch_plan_from_branch(child_then_node.branch_node, node_launch_plans) - elif child_then_node.workflow_node and child_then_node.workflow_node.launchplan_ref: - lp_ref = child_then_node.workflow_node.launchplan_ref - find_launch_plan(lp_ref, node_launch_plans) - - if branch_node and branch_node.if_else: - branch = branch_node.if_else - if branch.case and branch.case.then_node: - child_then_node = branch.case.then_node - get_launch_plan_from_then_node(child_then_node, node_launch_plans) - if branch.other: - for o in branch.other: - if o.then_node: - child_then_node = o.then_node - get_launch_plan_from_then_node(child_then_node, node_launch_plans) - if branch.else_node: - # else_node could have nested conditional branch_node - if branch.else_node.branch_node: - get_launch_plan_from_branch(branch.else_node.branch_node, node_launch_plans) - elif branch.else_node.workflow_node and branch.else_node.workflow_node.launchplan_ref: - lp_ref = branch.else_node.workflow_node.launchplan_ref - find_launch_plan(lp_ref, node_launch_plans) - - if node.branch_node: - get_launch_plan_from_branch(node.branch_node, node_launch_plans) + self.find_launch_plan_for_node(node, node_launch_plans) + flyte_workflow = FlyteWorkflow.promote_from_closure(compiled_wf, node_launch_plans) flyte_workflow.template._id = workflow_id return flyte_workflow @@ -556,6 +503,65 @@ def _upgrade_launchplan(self, lp: launch_plan_models.LaunchPlan) -> FlyteLaunchP flyte_lp._flyte_workflow = workflow return flyte_lp + def find_launch_plan( + self, lp_ref: id_models, node_launch_plans: Dict[id_models, launch_plan_models.LaunchPlanSpec] + ) -> None: + if lp_ref not in node_launch_plans: + admin_launch_plan = self.client.get_launch_plan(lp_ref) + node_launch_plans[lp_ref] = admin_launch_plan.spec + + def find_launch_plan_for_node( + self, node: Node, node_launch_plans: Dict[id_models, launch_plan_models.LaunchPlanSpec] + ): + # Case 1: workflow node + if node.workflow_node is not None and node.workflow_node.launchplan_ref is not None: + lp_ref = node.workflow_node.launchplan_ref + self.find_launch_plan(lp_ref, node_launch_plans) + + # Case 2: array node + if ( + node.array_node is not None + and node.array_node.node.workflow_node is not None + and node.array_node.node.workflow_node.launchplan_ref is not None + ): + lp_ref = node.array_node.node.workflow_node.launchplan_ref + self.find_launch_plan(lp_ref, node_launch_plans) + + # Case 3: branch node + def get_launch_plan_from_branch( + branch_node: BranchNode, node_launch_plans: Dict[id_models, launch_plan_models.LaunchPlanSpec] + ) -> None: + def get_launch_plan_from_then_node( + child_then_node: Node, node_launch_plans: Dict[id_models, launch_plan_models.LaunchPlanSpec] + ) -> None: + # then_node could have nested branch_node or be a normal then_node + if child_then_node.branch_node: + get_launch_plan_from_branch(child_then_node.branch_node, node_launch_plans) + elif child_then_node.workflow_node and child_then_node.workflow_node.launchplan_ref: + lp_ref = child_then_node.workflow_node.launchplan_ref + self.find_launch_plan(lp_ref, node_launch_plans) + + if branch_node and branch_node.if_else: + branch = branch_node.if_else + if branch.case and branch.case.then_node: + child_then_node = branch.case.then_node + get_launch_plan_from_then_node(child_then_node, node_launch_plans) + if branch.other: + for o in branch.other: + if o.then_node: + child_then_node = o.then_node + get_launch_plan_from_then_node(child_then_node, node_launch_plans) + if branch.else_node: + # else_node could have nested conditional branch_node + if branch.else_node.branch_node: + get_launch_plan_from_branch(branch.else_node.branch_node, node_launch_plans) + elif branch.else_node.workflow_node and branch.else_node.workflow_node.launchplan_ref: + lp_ref = branch.else_node.workflow_node.launchplan_ref + self.find_launch_plan(lp_ref, node_launch_plans) + + if node.branch_node: + get_launch_plan_from_branch(node.branch_node, node_launch_plans) + def fetch_active_launchplan( self, project: str = None, domain: str = None, name: str = None ) -> typing.Optional[FlyteLaunchPlan]: @@ -2575,17 +2581,9 @@ def sync_node_execution( if node_execution_get_data_response.dynamic_workflow is not None: compiled_wf = node_execution_get_data_response.dynamic_workflow.compiled_workflow node_launch_plans = {} - # TODO: Inspect branch nodes for launch plans for template in [compiled_wf.primary.template] + [swf.template for swf in compiled_wf.sub_workflows]: for node in FlyteWorkflow.get_non_system_nodes(template.nodes): - if ( - node.workflow_node is not None - and node.workflow_node.launchplan_ref is not None - and node.workflow_node.launchplan_ref not in node_launch_plans - ): - node_launch_plans[node.workflow_node.launchplan_ref] = self.client.get_launch_plan( - node.workflow_node.launchplan_ref - ).spec + self.find_launch_plan_for_node(node, node_launch_plans) dynamic_flyte_wf = FlyteWorkflow.promote_from_closure(compiled_wf, node_launch_plans) execution._underlying_node_executions = [