Skip to content

Commit

Permalink
expose external resources
Browse files Browse the repository at this point in the history
Signed-off-by: Troy Chiu <[email protected]>
  • Loading branch information
troychiu committed Feb 24, 2025
1 parent 5f8f8ca commit af9829e
Show file tree
Hide file tree
Showing 4 changed files with 8 additions and 33 deletions.
3 changes: 1 addition & 2 deletions flytekit/models/core/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
from flytekit.models import interface as _interface
from flytekit.models import types as type_models
from flytekit.models.core import condition as _condition
from flytekit.models.core import identifier
from flytekit.models.core import identifier as _identifier
from flytekit.models.literals import Binding as _Binding
from flytekit.models.literals import RetryStrategy as _RetryStrategy
Expand Down Expand Up @@ -745,7 +744,7 @@ def __init__(self, launchplan_ref=None, sub_workflow_ref=None):
self._sub_workflow_ref = sub_workflow_ref

@property
def launchplan_ref(self) -> identifier.Identifier:
def launchplan_ref(self):
"""
[Optional] A globally unique identifier for the launch plan. Should map to Admin.
Expand Down
2 changes: 1 addition & 1 deletion flytekit/models/event.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def to_flyte_idl(self):
@classmethod
def from_flyte_idl(cls, proto):
"""
:param flyteidl.event.TaskExecutionMetadata proto:
:param flyteidl.event.event_pb2.TaskExecutionMetadata proto:
:rtype: TaskExecutionMetadata
"""
return cls(
Expand Down
2 changes: 1 addition & 1 deletion flytekit/remote/entities.py
Original file line number Diff line number Diff line change
Expand Up @@ -367,7 +367,7 @@ def promote_from_model(
cls,
model: _workflow_model.ArrayNode,
flyte_node: FlyteNode,
) -> FlyteArrayNode:
):
return cls(
flyte_node=flyte_node,
parallelism=model._parallelism,
Expand Down
34 changes: 5 additions & 29 deletions flytekit/remote/remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -497,15 +497,6 @@ def find_launch_plan(
lp_ref = node.workflow_node.launchplan_ref
find_launch_plan(lp_ref, node_launch_plans)

# Inspect array nodes for launch plans
if (
node.array_node is not None
and node.array_node.node.workflow_node is not None
and node.array_node.node.workflow_node.launchplan_ref is not None
):
lp_ref = node.array_node.node.workflow_node.launchplan_ref
find_launch_plan(lp_ref, node_launch_plans)

# Inspect conditional branch nodes for launch plans
def get_launch_plan_from_branch(
branch_node: BranchNode, node_launch_plans: Dict[id_models, launch_plan_models.LaunchPlanSpec]
Expand Down Expand Up @@ -2628,28 +2619,14 @@ def sync_node_execution(
if execution._node.array_node.node.task_node is not None:
t = execution._node.flyte_entity.flyte_node.task_node.flyte_task
execution._task_executions = [
self.sync_task_execution(FlyteTaskExecution.promote_from_model(task_execution), t.interface)
self.sync_task_execution(FlyteTaskExecution.promote_from_model(task_execution), t)
for task_execution in iterate_task_executions(self.client, execution.id)
]
if t.interface:
execution._interface = t.interface
else:
logger.error(f"Fetched map task does not have an interface, skipping i/o {t}")
return execution
elif execution._node.array_node.node.workflow_node is not None:
launch_plan_id = execution._node.array_node.node.workflow_node.launchplan_ref
launch_plan = self.fetch_launch_plan(
launch_plan_id.project, launch_plan_id.domain, launch_plan_id.name, launch_plan_id.version
)
task_execution_interface = launch_plan.interface.transform_interface_to_list()
execution._task_executions = [
self.sync_task_execution(
FlyteTaskExecution.promote_from_model(task_execution), task_execution_interface
)
for task_execution in iterate_task_executions(self.client, execution.id)
]
execution._interface = task_execution_interface
return execution
else:
logger.error("Array node not over task, skipping i/o")
return execution
Expand All @@ -2663,7 +2640,7 @@ def sync_node_execution(
else:
execution._task_executions = [
self.sync_task_execution(
FlyteTaskExecution.promote_from_model(t), node_mapping[node_id].task_node.flyte_task.interface
FlyteTaskExecution.promote_from_model(t), node_mapping[node_id].task_node.flyte_task
)
for t in iterate_task_executions(self.client, execution.id)
]
Expand All @@ -2678,16 +2655,15 @@ def sync_node_execution(
return execution

def sync_task_execution(
self, execution: FlyteTaskExecution, entity_interface: typing.Optional[TypedInterface] = None
self, execution: FlyteTaskExecution, entity_definition: typing.Optional[FlyteTask] = None
) -> FlyteTaskExecution:
"""Sync a FlyteTaskExecution object with its corresponding remote state."""
execution._closure = self.client.get_task_execution(execution.id).closure
execution_data = self.client.get_task_execution_data(execution.id)
task_id = execution.id.task_id
if entity_interface is None:
if entity_definition is None:
entity_definition = self.fetch_task(task_id.project, task_id.domain, task_id.name, task_id.version)
entity_interface = entity_definition.interface
return self._assign_inputs_and_outputs(execution, execution_data, entity_interface)
return self._assign_inputs_and_outputs(execution, execution_data, entity_definition.interface)

#############################
# Terminate Execution State #
Expand Down

0 comments on commit af9829e

Please sign in to comment.