From 2e12f434069555ee8f598a6ee03e16e94dbcad10 Mon Sep 17 00:00:00 2001 From: UmerAhmad <48142234+UmerAhmad@users.noreply.github.com> Date: Wed, 26 Feb 2025 13:21:02 -0800 Subject: [PATCH] Adding missing fields to FlyteTask remote entity (#3093) * adding missing fields to flytetask remote entity Signed-off-by: Umer Ahmad Signed-off-by: Umer Ahmad * Patch fetch task remote unit test Signed-off-by: Umer Ahmad Signed-off-by: Umer Ahmad * Change patch to be global using fixture Signed-off-by: Umer Ahmad * patch spark plugin remote register Signed-off-by: Umer Ahmad --------- Signed-off-by: Umer Ahmad Signed-off-by: Umer Ahmad Signed-off-by: Umer Ahmad Co-authored-by: Umer Ahmad --- flytekit/remote/entities.py | 17 +++++++++ .../tests/test_remote_register.py | 2 + tests/flytekit/unit/remote/test_remote.py | 38 ++++++++++--------- 3 files changed, 39 insertions(+), 18 deletions(-) diff --git a/flytekit/remote/entities.py b/flytekit/remote/entities.py index 16c16eedd0..73bc26b360 100644 --- a/flytekit/remote/entities.py +++ b/flytekit/remote/entities.py @@ -49,7 +49,11 @@ def __init__( custom, container=None, task_type_version: int = 0, + security_context=None, config=None, + k8s_pod=None, + sql=None, + extended_resources=None, should_register: bool = False, ): super(FlyteTask, self).__init__( @@ -61,7 +65,11 @@ def __init__( custom, container=container, task_type_version=task_type_version, + security_context=security_context, config=config, + k8s_pod=k8s_pod, + sql=sql, + extended_resources=extended_resources, ) ) self._should_register = should_register @@ -146,6 +154,10 @@ def k8s_pod(self): def sql(self): return self.template.sql + @property + def extended_resources(self): + return self.template.extended_resources + @property def should_register(self) -> bool: return self._should_register @@ -172,6 +184,11 @@ def promote_from_model(cls, base_model: _task_model.TaskTemplate) -> FlyteTask: custom=base_model.custom, container=base_model.container, task_type_version=base_model.task_type_version, + security_context=base_model.security_context, + config=base_model.config, + k8s_pod=base_model.k8s_pod, + sql=base_model.sql, + extended_resources=base_model.extended_resources, ) # Override the newly generated name if one exists in the base model if not base_model.id.is_empty: diff --git a/plugins/flytekit-spark/tests/test_remote_register.py b/plugins/flytekit-spark/tests/test_remote_register.py index fd44aba4ba..384b8f87fc 100644 --- a/plugins/flytekit-spark/tests/test_remote_register.py +++ b/plugins/flytekit-spark/tests/test_remote_register.py @@ -22,6 +22,8 @@ def my_python_task(a: str) -> int: mock_client = MagicMock() remote._client = mock_client remote._client_initialized = True + remote._client.get_task.return_value.closure.compiled_task.template.sql = None + remote._client.get_task.return_value.closure.compiled_task.template.k8s_pod = None mock_image_config = MagicMock(default_image=MagicMock(full="fake-cr.io/image-name:tag")) remote.register_task( diff --git a/tests/flytekit/unit/remote/test_remote.py b/tests/flytekit/unit/remote/test_remote.py index 9911cad02f..2e2dcdc22b 100644 --- a/tests/flytekit/unit/remote/test_remote.py +++ b/tests/flytekit/unit/remote/test_remote.py @@ -547,14 +547,21 @@ def wf1(name: str = "union") -> float: flyte_remote.register_script(wf1) -@mock.patch("flytekit.remote.remote.FlyteRemote.client") -def test_local_server(mock_client): +@pytest.fixture() +def mock_flyte_remote_client(): + with patch("flytekit.remote.remote.FlyteRemote.client") as mock_flyte_remote_client: + mock_flyte_remote_client.get_task.return_value.closure.compiled_task.template.sql = None + mock_flyte_remote_client.get_task.return_value.closure.compiled_task.template.k8s_pod = None + yield mock_flyte_remote_client + + +def test_local_server(mock_flyte_remote_client): ctx = FlyteContextManager.current_context() lt = TypeEngine.to_literal_type(typing.Dict[str, int]) lm = TypeEngine.to_literal(ctx, {"hello": 55}, typing.Dict[str, int], lt) lm = lm.map.to_flyte_idl() - mock_client.get_data.return_value = dataproxy_pb2.GetDataResponse(literal_map=lm) + mock_flyte_remote_client.get_data.return_value = dataproxy_pb2.GetDataResponse(literal_map=lm) rr = FlyteRemote( Config.for_sandbox(), @@ -566,8 +573,7 @@ def test_local_server(mock_client): @mock.patch("flytekit.remote.remote.uuid") -@mock.patch("flytekit.remote.remote.FlyteRemote.client") -def test_execution_name(mock_client, mock_uuid): +def test_execution_name(mock_uuid, mock_flyte_remote_client): test_uuid = uuid.UUID("16fd2706-8baf-433b-82eb-8c7fada847da") mock_uuid.uuid4.return_value = test_uuid remote = FlyteRemote(config=Config.auto(), default_project="project", default_domain="domain") @@ -597,7 +603,7 @@ def test_execution_name(mock_client, mock_uuid): entity=ft, inputs={"t": datetime.now(), "v": 0}, ) - mock_client.create_execution.assert_has_calls( + mock_flyte_remote_client.create_execution.assert_has_calls( [ mock.call(ANY, ANY, "execution-test", ANY, ANY), mock.call(ANY, ANY, "execution-test-" + test_uuid.hex[:19], ANY, ANY), @@ -688,9 +694,8 @@ def test_register_wf_script_mode(compress_scripts_mock, upload_file_mock, regist ) -@mock.patch("flytekit.remote.remote.FlyteRemote.client") -def test_fetch_active_launchplan_not_found(mock_client, remote): - mock_client.get_active_launch_plan.side_effect = FlyteEntityNotExistException("not found") +def test_fetch_active_launchplan_not_found(mock_flyte_remote_client, remote): + mock_flyte_remote_client.get_active_launch_plan.side_effect = FlyteEntityNotExistException("not found") assert remote.fetch_active_launchplan(name="basic.list_float_wf.fake_wf") is None @@ -785,8 +790,7 @@ async def eager_wf(a: int) -> int: _get_pickled_target_dict(eager_wf) -@mock.patch("flytekit.remote.remote.FlyteRemote.client") -def test_launchplan_auto_activate(mock_client): +def test_launchplan_auto_activate(mock_flyte_remote_client): @workflow def wf() -> int: return 1 @@ -804,15 +808,14 @@ def wf() -> int: # The first one should not update the launchplan rr.register_launch_plan(lp1, version="1", serialization_settings=ss) - mock_client.update_launch_plan.assert_not_called() + mock_flyte_remote_client.update_launch_plan.assert_not_called() # the second one should rr.register_launch_plan(lp2, version="1", serialization_settings=ss) - mock_client.update_launch_plan.assert_called() + mock_flyte_remote_client.update_launch_plan.assert_called() -@mock.patch("flytekit.remote.remote.FlyteRemote.client") -def test_register_task_with_node_dependency_hints(mock_client): +def test_register_task_with_node_dependency_hints(mock_flyte_remote_client): @task def task0(): return None @@ -858,8 +861,7 @@ def workflow1(): @mock.patch("flytekit.remote.remote.FlyteRemote.fetch_launch_plan") @mock.patch("flytekit.remote.remote.FlyteRemote.raw_register") @mock.patch("flytekit.remote.remote.FlyteRemote._serialize_and_register") -@mock.patch("flytekit.remote.remote.FlyteRemote.client") -def test_register_launch_plan(mock_client, mock_serialize_and_register, mock_raw_register,mock_fetch_launch_plan, mock_get_serializable): +def test_register_launch_plan(mock_serialize_and_register, mock_raw_register,mock_fetch_launch_plan, mock_get_serializable, mock_flyte_remote_client): serialization_settings = SerializationSettings( image_config=ImageConfig.auto_default_image(), version="dummy_version", @@ -883,7 +885,7 @@ def hello_world_wf() -> str: lp = LaunchPlan.get_or_create(workflow=hello_world_wf, name="additional_lp_for_hello_world", default_inputs={}) mock_get_serializable.return_value = MagicMock() - mock_client.get_workflow.return_value = MagicMock() + mock_flyte_remote_client.get_workflow.return_value = MagicMock() mock_remote_lp = MagicMock() mock_fetch_launch_plan.return_value = mock_remote_lp