Skip to content

Commit

Permalink
Adding missing fields to FlyteTask remote entity (#3093)
Browse files Browse the repository at this point in the history
* adding missing fields to flytetask remote entity

Signed-off-by: Umer Ahmad <[email protected]>
Signed-off-by: Umer Ahmad <[email protected]>

* Patch fetch task remote unit test

Signed-off-by: Umer Ahmad <[email protected]>
Signed-off-by: Umer Ahmad <[email protected]>

* Change patch to be global using fixture

Signed-off-by: Umer Ahmad <[email protected]>

* patch spark plugin remote register

Signed-off-by: Umer Ahmad <[email protected]>

---------

Signed-off-by: Umer Ahmad <[email protected]>
Signed-off-by: Umer Ahmad <[email protected]>
Signed-off-by: Umer Ahmad <[email protected]>
Co-authored-by: Umer Ahmad <[email protected]>
  • Loading branch information
UmerAhmad and Umer Ahmad authored Feb 26, 2025
1 parent 87fb3c6 commit 2e12f43
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 18 deletions.
17 changes: 17 additions & 0 deletions flytekit/remote/entities.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,11 @@ def __init__(
custom,
container=None,
task_type_version: int = 0,
security_context=None,
config=None,
k8s_pod=None,
sql=None,
extended_resources=None,
should_register: bool = False,
):
super(FlyteTask, self).__init__(
Expand All @@ -61,7 +65,11 @@ def __init__(
custom,
container=container,
task_type_version=task_type_version,
security_context=security_context,
config=config,
k8s_pod=k8s_pod,
sql=sql,
extended_resources=extended_resources,
)
)
self._should_register = should_register
Expand Down Expand Up @@ -146,6 +154,10 @@ def k8s_pod(self):
def sql(self):
return self.template.sql

@property
def extended_resources(self):
return self.template.extended_resources

@property
def should_register(self) -> bool:
return self._should_register
Expand All @@ -172,6 +184,11 @@ def promote_from_model(cls, base_model: _task_model.TaskTemplate) -> FlyteTask:
custom=base_model.custom,
container=base_model.container,
task_type_version=base_model.task_type_version,
security_context=base_model.security_context,
config=base_model.config,
k8s_pod=base_model.k8s_pod,
sql=base_model.sql,
extended_resources=base_model.extended_resources,
)
# Override the newly generated name if one exists in the base model
if not base_model.id.is_empty:
Expand Down
2 changes: 2 additions & 0 deletions plugins/flytekit-spark/tests/test_remote_register.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ def my_python_task(a: str) -> int:
mock_client = MagicMock()
remote._client = mock_client
remote._client_initialized = True
remote._client.get_task.return_value.closure.compiled_task.template.sql = None
remote._client.get_task.return_value.closure.compiled_task.template.k8s_pod = None

mock_image_config = MagicMock(default_image=MagicMock(full="fake-cr.io/image-name:tag"))
remote.register_task(
Expand Down
38 changes: 20 additions & 18 deletions tests/flytekit/unit/remote/test_remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -547,14 +547,21 @@ def wf1(name: str = "union") -> float:
flyte_remote.register_script(wf1)


@mock.patch("flytekit.remote.remote.FlyteRemote.client")
def test_local_server(mock_client):
@pytest.fixture()
def mock_flyte_remote_client():
with patch("flytekit.remote.remote.FlyteRemote.client") as mock_flyte_remote_client:
mock_flyte_remote_client.get_task.return_value.closure.compiled_task.template.sql = None
mock_flyte_remote_client.get_task.return_value.closure.compiled_task.template.k8s_pod = None
yield mock_flyte_remote_client


def test_local_server(mock_flyte_remote_client):
ctx = FlyteContextManager.current_context()
lt = TypeEngine.to_literal_type(typing.Dict[str, int])
lm = TypeEngine.to_literal(ctx, {"hello": 55}, typing.Dict[str, int], lt)
lm = lm.map.to_flyte_idl()

mock_client.get_data.return_value = dataproxy_pb2.GetDataResponse(literal_map=lm)
mock_flyte_remote_client.get_data.return_value = dataproxy_pb2.GetDataResponse(literal_map=lm)

rr = FlyteRemote(
Config.for_sandbox(),
Expand All @@ -566,8 +573,7 @@ def test_local_server(mock_client):


@mock.patch("flytekit.remote.remote.uuid")
@mock.patch("flytekit.remote.remote.FlyteRemote.client")
def test_execution_name(mock_client, mock_uuid):
def test_execution_name(mock_uuid, mock_flyte_remote_client):
test_uuid = uuid.UUID("16fd2706-8baf-433b-82eb-8c7fada847da")
mock_uuid.uuid4.return_value = test_uuid
remote = FlyteRemote(config=Config.auto(), default_project="project", default_domain="domain")
Expand Down Expand Up @@ -597,7 +603,7 @@ def test_execution_name(mock_client, mock_uuid):
entity=ft,
inputs={"t": datetime.now(), "v": 0},
)
mock_client.create_execution.assert_has_calls(
mock_flyte_remote_client.create_execution.assert_has_calls(
[
mock.call(ANY, ANY, "execution-test", ANY, ANY),
mock.call(ANY, ANY, "execution-test-" + test_uuid.hex[:19], ANY, ANY),
Expand Down Expand Up @@ -688,9 +694,8 @@ def test_register_wf_script_mode(compress_scripts_mock, upload_file_mock, regist
)


@mock.patch("flytekit.remote.remote.FlyteRemote.client")
def test_fetch_active_launchplan_not_found(mock_client, remote):
mock_client.get_active_launch_plan.side_effect = FlyteEntityNotExistException("not found")
def test_fetch_active_launchplan_not_found(mock_flyte_remote_client, remote):
mock_flyte_remote_client.get_active_launch_plan.side_effect = FlyteEntityNotExistException("not found")
assert remote.fetch_active_launchplan(name="basic.list_float_wf.fake_wf") is None


Expand Down Expand Up @@ -785,8 +790,7 @@ async def eager_wf(a: int) -> int:
_get_pickled_target_dict(eager_wf)


@mock.patch("flytekit.remote.remote.FlyteRemote.client")
def test_launchplan_auto_activate(mock_client):
def test_launchplan_auto_activate(mock_flyte_remote_client):
@workflow
def wf() -> int:
return 1
Expand All @@ -804,15 +808,14 @@ def wf() -> int:

# The first one should not update the launchplan
rr.register_launch_plan(lp1, version="1", serialization_settings=ss)
mock_client.update_launch_plan.assert_not_called()
mock_flyte_remote_client.update_launch_plan.assert_not_called()

# the second one should
rr.register_launch_plan(lp2, version="1", serialization_settings=ss)
mock_client.update_launch_plan.assert_called()
mock_flyte_remote_client.update_launch_plan.assert_called()


@mock.patch("flytekit.remote.remote.FlyteRemote.client")
def test_register_task_with_node_dependency_hints(mock_client):
def test_register_task_with_node_dependency_hints(mock_flyte_remote_client):
@task
def task0():
return None
Expand Down Expand Up @@ -858,8 +861,7 @@ def workflow1():
@mock.patch("flytekit.remote.remote.FlyteRemote.fetch_launch_plan")
@mock.patch("flytekit.remote.remote.FlyteRemote.raw_register")
@mock.patch("flytekit.remote.remote.FlyteRemote._serialize_and_register")
@mock.patch("flytekit.remote.remote.FlyteRemote.client")
def test_register_launch_plan(mock_client, mock_serialize_and_register, mock_raw_register,mock_fetch_launch_plan, mock_get_serializable):
def test_register_launch_plan(mock_serialize_and_register, mock_raw_register,mock_fetch_launch_plan, mock_get_serializable, mock_flyte_remote_client):
serialization_settings = SerializationSettings(
image_config=ImageConfig.auto_default_image(),
version="dummy_version",
Expand All @@ -883,7 +885,7 @@ def hello_world_wf() -> str:
lp = LaunchPlan.get_or_create(workflow=hello_world_wf, name="additional_lp_for_hello_world", default_inputs={})

mock_get_serializable.return_value = MagicMock()
mock_client.get_workflow.return_value = MagicMock()
mock_flyte_remote_client.get_workflow.return_value = MagicMock()

mock_remote_lp = MagicMock()
mock_fetch_launch_plan.return_value = mock_remote_lp
Expand Down

0 comments on commit 2e12f43

Please sign in to comment.