diff --git a/.github/workflows/build_image.yml b/.github/workflows/build_image.yml index e3b894a4e2..f4bf85f18b 100644 --- a/.github/workflows/build_image.yml +++ b/.github/workflows/build_image.yml @@ -1,4 +1,4 @@ -name: Publish Python Package +name: Publish Official flytekit Images on: workflow_dispatch: @@ -8,18 +8,18 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: ["3.9", "3.10", "3.11"] + python-version: ["3.9", "3.10", "3.11", "3.12"] steps: - uses: actions/checkout@v4 with: fetch-depth: "0" - name: Set up QEMU - uses: docker/setup-qemu-action@v1 + uses: docker/setup-qemu-action@v3 - name: Set up Docker Buildx id: buildx - uses: docker/setup-buildx-action@v1 + uses: docker/setup-buildx-action@v3 - name: Login to GitHub Container Registry - uses: docker/login-action@v1 + uses: docker/login-action@v3 with: registry: ghcr.io username: "${{ secrets.FLYTE_BOT_USERNAME }}" @@ -54,12 +54,12 @@ jobs: with: fetch-depth: "0" - name: Set up QEMU - uses: docker/setup-qemu-action@v1 + uses: docker/setup-qemu-action@v3 - name: Set up Docker Buildx id: buildx - uses: docker/setup-buildx-action@v1 + uses: docker/setup-buildx-action@v3 - name: Login to GitHub Container Registry - uses: docker/login-action@v1 + uses: docker/login-action@v3 with: registry: ghcr.io username: "${{ secrets.FLYTE_BOT_USERNAME }}" diff --git a/.github/workflows/pythonbuild.yml b/.github/workflows/pythonbuild.yml index 303362336f..912a0a01c6 100644 --- a/.github/workflows/pythonbuild.yml +++ b/.github/workflows/pythonbuild.yml @@ -54,7 +54,7 @@ jobs: with: python-version: ${{ matrix.python-version }} - name: Cache pip - uses: actions/cache@v3 + uses: actions/cache@v4 with: # This path is specific to Ubuntu path: ~/.cache/pip @@ -97,7 +97,7 @@ jobs: with: python-version: ${{ matrix.python-version }} - name: Cache pip - uses: actions/cache@v3 + uses: actions/cache@v4 with: # This path is specific to Ubuntu path: ~/.cache/pip @@ -149,7 +149,7 @@ jobs: with: python-version: ${{ matrix.python-version }} - name: Cache pip - uses: actions/cache@v3 + uses: actions/cache@v4 with: # This path is specific to Ubuntu path: ~/.cache/pip @@ -186,7 +186,7 @@ jobs: with: python-version: ${{ matrix.python-version }} - name: Cache pip - uses: actions/cache@v3 + uses: actions/cache@v4 with: # This path is specific to Ubuntu path: ~/.cache/pip @@ -230,7 +230,7 @@ jobs: with: python-version: ${{ matrix.python-version }} - name: Cache pip - uses: actions/cache@v3 + uses: actions/cache@v4 with: # This path is specific to Ubuntu path: ~/.cache/pip @@ -274,6 +274,10 @@ jobs: AWS_SECRET_ACCESS_KEY: miniostorage run: | make ${{ matrix.makefile-cmd }} + - name: Setup tmate session + if: ${{ failure() }} + uses: mxschmitt/action-tmate@v3 + timeout-minutes: 60 - name: Codecov uses: codecov/codecov-action@v3.1.0 with: @@ -394,7 +398,7 @@ jobs: with: python-version: ${{ matrix.python-version }} - name: Cache pip - uses: actions/cache@v3 + uses: actions/cache@v4 with: # This path is specific to Ubuntu path: ~/.cache/pip @@ -439,7 +443,7 @@ jobs: uses: actions/setup-python@v4 with: python-version: 3.12 - - uses: actions/cache@v3 + - uses: actions/cache@v4 with: path: ~/.cache/pip key: ${{ runner.os }}-pip-${{ hashFiles('**/dev-requirements.in') }} diff --git a/.github/workflows/pythonpublish.yml b/.github/workflows/pythonpublish.yml index 4aceee472e..ae8adf48e0 100644 --- a/.github/workflows/pythonpublish.yml +++ b/.github/workflows/pythonpublish.yml @@ -1,4 +1,4 @@ -name: Publish Python Package +name: Publish Python Packages and Official 
Images on: release: @@ -88,13 +88,13 @@ jobs: with: fetch-depth: "0" - name: Set up QEMU - uses: docker/setup-qemu-action@v1 + uses: docker/setup-qemu-action@v3 - name: Set up Docker Buildx id: buildx - uses: docker/setup-buildx-action@v1 + uses: docker/setup-buildx-action@v3 - name: Login to GitHub Container Registry if: ${{ github.event_name == 'release' }} - uses: docker/login-action@v1 + uses: docker/login-action@v3 with: registry: ghcr.io username: "${{ secrets.FLYTE_BOT_USERNAME }}" @@ -178,13 +178,13 @@ jobs: with: fetch-depth: "0" - name: Set up QEMU - uses: docker/setup-qemu-action@v1 + uses: docker/setup-qemu-action@v3 - name: Set up Docker Buildx id: buildx - uses: docker/setup-buildx-action@v1 + uses: docker/setup-buildx-action@v3 - name: Login to GitHub Container Registry if: ${{ github.event_name == 'release' }} - uses: docker/login-action@v1 + uses: docker/login-action@v3 with: registry: ghcr.io username: "${{ secrets.FLYTE_BOT_USERNAME }}" @@ -247,13 +247,13 @@ jobs: with: fetch-depth: "0" - name: Set up QEMU - uses: docker/setup-qemu-action@v1 + uses: docker/setup-qemu-action@v3 - name: Set up Docker Buildx id: buildx - uses: docker/setup-buildx-action@v1 + uses: docker/setup-buildx-action@v3 - name: Login to GitHub Container Registry if: ${{ github.event_name == 'release' }} - uses: docker/login-action@v1 + uses: docker/login-action@v3 with: registry: ghcr.io username: "${{ secrets.FLYTE_BOT_USERNAME }}" @@ -289,13 +289,13 @@ jobs: with: fetch-depth: "0" - name: Set up QEMU - uses: docker/setup-qemu-action@v1 + uses: docker/setup-qemu-action@v3 - name: Set up Docker Buildx id: buildx - uses: docker/setup-buildx-action@v1 + uses: docker/setup-buildx-action@v3 - name: Login to GitHub Container Registry if: ${{ github.event_name == 'release' }} - uses: docker/login-action@v1 + uses: docker/login-action@v3 with: registry: ghcr.io username: "${{ secrets.FLYTE_BOT_USERNAME }}" diff --git a/.gitignore b/.gitignore index ac4cf37b06..0db8768ef2 100644 --- a/.gitignore +++ b/.gitignore @@ -39,3 +39,4 @@ coverage.xml # Version file is auto-generated by setuptools_scm flytekit/_version.py +testing diff --git a/flytekit/__init__.py b/flytekit/__init__.py index 5afb9d4bfb..74a50aff7f 100644 --- a/flytekit/__init__.py +++ b/flytekit/__init__.py @@ -178,6 +178,9 @@ :toctree: generated/ HashMethod + Cache + CachePolicy + VersionParameters Artifacts ========= @@ -223,6 +226,7 @@ from flytekit.core.artifact import Artifact from flytekit.core.base_sql_task import SQLTask from flytekit.core.base_task import SecurityContext, TaskMetadata, kwtypes +from flytekit.core.cache import Cache, CachePolicy, VersionParameters from flytekit.core.checkpointer import Checkpoint from flytekit.core.condition import conditional from flytekit.core.container_task import ContainerTask diff --git a/flytekit/clis/sdk_in_container/init.py b/flytekit/clis/sdk_in_container/init.py index 4ea470d0a8..5bfa5d3d4e 100644 --- a/flytekit/clis/sdk_in_container/init.py +++ b/flytekit/clis/sdk_in_container/init.py @@ -57,6 +57,4 @@ def init(template, project_name): processed_contents = project_template_regex.sub(project_name_bytes, zip_contents) dest_file.write(processed_contents) - click.echo( - f"Visit the {project_name} directory and follow the next steps in the Getting started guide (https://docs.flyte.org/en/latest/user_guide/getting_started_with_workflow_development/index.html) to proceed." 
- ) + click.echo(f"Project initialized in directory {project_name}.") diff --git a/flytekit/clis/sdk_in_container/run.py b/flytekit/clis/sdk_in_container/run.py index 44141a5cc1..7a08ef31af 100644 --- a/flytekit/clis/sdk_in_container/run.py +++ b/flytekit/clis/sdk_in_container/run.py @@ -1052,9 +1052,25 @@ def _create_command( r = run_level_params.remote_instance() flyte_ctx = r.context + final_inputs_with_defaults = loaded_entity.python_interface.inputs_with_defaults + if isinstance(loaded_entity, LaunchPlan): + # For LaunchPlans it is essential to handle fixed inputs and default inputs in a special way + # Fixed inputs are inputs that are always passed to the launch plan and cannot be overridden + # Default inputs are inputs that are optional and have a default value + # The final inputs to the launch plan are a combination of the fixed inputs and the default inputs + all_inputs = loaded_entity.python_interface.inputs_with_defaults + default_inputs = loaded_entity.saved_inputs + pmap = loaded_entity.parameters + final_inputs_with_defaults = {} + for name, _ in pmap.parameters.items(): + _type, v = all_inputs[name] + if name in default_inputs: + v = default_inputs[name] + final_inputs_with_defaults[name] = _type, v + # Add options for each of the workflow inputs params = [] - for input_name, input_type_val in loaded_entity.python_interface.inputs_with_defaults.items(): + for input_name, input_type_val in final_inputs_with_defaults.items(): literal_var = loaded_entity.interface.inputs.get(input_name) python_type, default_val = input_type_val required = type(None) not in get_args(python_type) and default_val is None diff --git a/flytekit/configuration/plugin.py b/flytekit/configuration/plugin.py index 0c0e36c280..25622f4da7 100644 --- a/flytekit/configuration/plugin.py +++ b/flytekit/configuration/plugin.py @@ -19,11 +19,12 @@ """ import os -from typing import Optional, Protocol, runtime_checkable +from typing import List, Optional, Protocol, runtime_checkable from click import Group from importlib_metadata import entry_points +from flytekit import CachePolicy from flytekit.configuration import Config, get_config_file from flytekit.loggers import logger from flytekit.remote import FlyteRemote @@ -53,6 +54,10 @@ def get_default_image() -> Optional[str]: def get_auth_success_html(endpoint: str) -> Optional[str]: """Get default success html for auth. Return None to use flytekit's default success html.""" + @staticmethod + def get_default_cache_policies() -> List[CachePolicy]: + """Get default cache policies for tasks.""" + class FlytekitPlugin: @staticmethod @@ -103,6 +108,11 @@ def get_auth_success_html(endpoint: str) -> Optional[str]: """Get default success html. 
Return None to use flytekit's default success html.""" return None + @staticmethod + def get_default_cache_policies() -> List[CachePolicy]: + """Get default cache policies for tasks.""" + return [] + def _get_plugin_from_entrypoint(): """Get plugin from entrypoint.""" diff --git a/flytekit/core/cache.py b/flytekit/core/cache.py new file mode 100644 index 0000000000..a74a386b18 --- /dev/null +++ b/flytekit/core/cache.py @@ -0,0 +1,97 @@ +import hashlib +from dataclasses import dataclass +from typing import Callable, Generic, List, Optional, Protocol, Tuple, Union, runtime_checkable + +from typing_extensions import ParamSpec, TypeVar + +from flytekit.core.pod_template import PodTemplate +from flytekit.image_spec.image_spec import ImageSpec + +P = ParamSpec("P") +FuncOut = TypeVar("FuncOut") + + +@dataclass +class VersionParameters(Generic[P, FuncOut]): + """ + Parameters used for version hash generation. + + param func: The function to generate a version for. This is an optional parameter and can be any callable + that matches the specified parameter and return types. + :type func: Optional[Callable[P, FuncOut]] + :param container_image: The container image to generate a version for. This can be a string representing the + image name or an ImageSpec object. + :type container_image: Optional[Union[str, ImageSpec]] + """ + + func: Callable[P, FuncOut] + container_image: Optional[Union[str, ImageSpec]] = None + pod_template: Optional[PodTemplate] = None + pod_template_name: Optional[str] = None + + +@runtime_checkable +class CachePolicy(Protocol): + def get_version(self, salt: str, params: VersionParameters) -> str: ... + + +@dataclass +class Cache: + """ + Cache configuration for a task. + + :param version: The version of the task. If not provided, the version will be generated based on the cache policies. + :type version: Optional[str] + :param serialize: Boolean that indicates if identical (ie. same inputs) instances of this task should be executed in + serial when caching is enabled. This means that given multiple concurrent executions over identical inputs, + only a single instance executes and the rest wait to reuse the cached results. + :type serialize: bool + :param ignored_inputs: A tuple of input names to ignore when generating the version hash. + :type ignored_inputs: Union[Tuple[str, ...], str] + :param salt: A salt used in the hash generation. + :type salt: str + :param policies: A list of cache policies to generate the version hash. 
+ :type policies: Optional[Union[List[CachePolicy], CachePolicy]] + """ + + version: Optional[str] = None + serialize: bool = False + ignored_inputs: Union[Tuple[str, ...], str] = () + salt: str = "" + policies: Optional[Union[List[CachePolicy], CachePolicy]] = None + + def __post_init__(self): + if isinstance(self.ignored_inputs, str): + self._ignored_inputs = (self.ignored_inputs,) + else: + self._ignored_inputs = self.ignored_inputs + + # Normalize policies so that self._policies is always a list + if self.policies is None: + from flytekit.configuration.plugin import get_plugin + + self._policies = get_plugin().get_default_cache_policies() + elif isinstance(self.policies, CachePolicy): + self._policies = [self.policies] + + if self.version is None and not self._policies: + raise ValueError("If version is not defined then at least one cache policy needs to be set") + + def get_ignored_inputs(self) -> Tuple[str, ...]: + return self._ignored_inputs + + def get_version(self, params: VersionParameters) -> str: + if self.version is not None: + return self.version + + task_hash = "" + for policy in self._policies: + try: + task_hash += policy.get_version(self.salt, params) + except Exception as e: + raise ValueError( + f"Failed to generate version for cache policy {policy}. Please consider setting the version in the Cache definition, e.g. Cache(version='v1.2.3')" + ) from e + + hash_obj = hashlib.sha256(task_hash.encode()) + return hash_obj.hexdigest() diff --git a/flytekit/core/constants.py b/flytekit/core/constants.py index 903e5d5ced..a80ed0f9e4 100644 --- a/flytekit/core/constants.py +++ b/flytekit/core/constants.py @@ -38,3 +38,7 @@ CACHE_KEY_METADATA = "cache-key-metadata" SERIALIZATION_FORMAT = "serialization-format" + +# Shared memory mount name and path +SHARED_MEMORY_MOUNT_NAME = "flyte-shared-memory" +SHARED_MEMORY_MOUNT_PATH = "/dev/shm" diff --git a/flytekit/core/node.py b/flytekit/core/node.py index f579d391ad..b192a6223e 100644 --- a/flytekit/core/node.py +++ b/flytekit/core/node.py @@ -3,11 +3,12 @@ import datetime import typing from typing import Any, Dict, List, Optional, Union +from typing import Literal as L from flyteidl.core import tasks_pb2 from flytekit.core.pod_template import PodTemplate -from flytekit.core.resources import Resources, convert_resources_to_resource_model +from flytekit.core.resources import Resources, construct_extended_resources, convert_resources_to_resource_model from flytekit.core.utils import _dnsify from flytekit.extras.accelerators import BaseAccelerator from flytekit.loggers import logger @@ -193,6 +194,7 @@ def with_overrides( cache: Optional[bool] = None, cache_version: Optional[str] = None, cache_serialize: Optional[bool] = None, + shared_memory: Optional[Union[L[True], str]] = None, pod_template: Optional[PodTemplate] = None, *args, **kwargs, @@ -240,7 +242,11 @@ def with_overrides( if accelerator is not None: assert_not_promise(accelerator, "accelerator") - self._extended_resources = tasks_pb2.ExtendedResources(gpu_accelerator=accelerator.to_flyte_idl()) + + if shared_memory is not None: + assert_not_promise(shared_memory, "shared_memory") + + self._extended_resources = construct_extended_resources(accelerator=accelerator, shared_memory=shared_memory) self._override_node_metadata(name, timeout, retries, interruptible, cache, cache_version, cache_serialize) diff --git a/flytekit/core/promise.py b/flytekit/core/promise.py index 74a2f87b62..8f51472e8a 100644 --- a/flytekit/core/promise.py +++ b/flytekit/core/promise.py @@ -1482,7 +1482,7 @@ 
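For reference, a minimal sketch of how the `Cache`/`CachePolicy` API added above is meant to be used (illustrative only — `FunctionSourcePolicy` and `add_one` are hypothetical and not part of this change):

```python
import hashlib
import inspect

from flytekit import Cache, CachePolicy, VersionParameters


class FunctionSourcePolicy:
    """Hypothetical policy that versions a task by hashing its function source plus the salt."""

    def get_version(self, salt: str, params: VersionParameters) -> str:
        return hashlib.sha256((salt + inspect.getsource(params.func)).encode()).hexdigest()


# CachePolicy is a runtime-checkable Protocol, so structural conformance is enough.
assert isinstance(FunctionSourcePolicy(), CachePolicy)


def add_one(x: int) -> int:
    return x + 1


# With no explicit version, Cache.get_version concatenates the policy outputs and sha256-hashes them.
cache_cfg = Cache(salt="s1", policies=FunctionSourcePolicy())
version = cache_cfg.get_version(VersionParameters(func=add_one))
```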
def flyte_entity_call_handler( # call the blocking version of the async call handler # This is a recursive call, the async handler also calls this function, so this conditional must match # the one in the async function perfectly, otherwise you'll get infinite recursion. - loop_manager.run_sync(async_flyte_entity_call_handler, entity, **kwargs) + return loop_manager.run_sync(async_flyte_entity_call_handler, entity, **kwargs) if ctx.execution_state and ( ctx.execution_state.mode == ExecutionState.Mode.TASK_EXECUTION diff --git a/flytekit/core/python_auto_container.py b/flytekit/core/python_auto_container.py index dfbd678fb6..0584d88168 100644 --- a/flytekit/core/python_auto_container.py +++ b/flytekit/core/python_auto_container.py @@ -5,6 +5,7 @@ from abc import ABC from dataclasses import dataclass from typing import Callable, Dict, List, Optional, TypeVar, Union +from typing import Literal as L from flyteidl.core import tasks_pb2 @@ -13,7 +14,7 @@ from flytekit.core.base_task import PythonTask, TaskMetadata, TaskResolverMixin from flytekit.core.context_manager import FlyteContextManager from flytekit.core.pod_template import PodTemplate -from flytekit.core.resources import Resources, ResourceSpec +from flytekit.core.resources import Resources, ResourceSpec, construct_extended_resources from flytekit.core.tracked_abc import FlyteTrackedABC from flytekit.core.tracker import TrackedInstance, extract_task_module from flytekit.core.utils import _get_container_definition, _serialize_pod_spec, timeit @@ -51,6 +52,7 @@ def __init__( pod_template: Optional[PodTemplate] = None, pod_template_name: Optional[str] = None, accelerator: Optional[BaseAccelerator] = None, + shared_memory: Optional[Union[L[True], str]] = None, **kwargs, ): """ @@ -78,6 +80,8 @@ def __init__( :param pod_template: Custom PodTemplate for this task. :param pod_template_name: The name of the existing PodTemplate resource which will be used in this task. :param accelerator: The accelerator to use for this task. + :param shared_memory: If True, then shared memory will be attached to the container where the size is equal + to the allocated memory. If str, then the shared memory is set to that size. """ sec_ctx = None if secret_requests: @@ -128,6 +132,7 @@ def __init__( self.pod_template = pod_template self.accelerator = accelerator + self.shared_memory = shared_memory @property def task_resolver(self) -> TaskResolverMixin: @@ -250,10 +255,7 @@ def get_extended_resources(self, settings: SerializationSettings) -> Optional[ta """ Returns the extended resources to allocate to the task on hosted Flyte. """ - if self.accelerator is None: - return None - - return tasks_pb2.ExtendedResources(gpu_accelerator=self.accelerator.to_flyte_idl()) + return construct_extended_resources(accelerator=self.accelerator, shared_memory=self.shared_memory) class DefaultTaskResolver(TrackedInstance, TaskResolverMixin): diff --git a/flytekit/core/python_function_task.py b/flytekit/core/python_function_task.py index 419da6e719..4544c435b7 100644 --- a/flytekit/core/python_function_task.py +++ b/flytekit/core/python_function_task.py @@ -607,3 +607,32 @@ async def run_with_backend(self, **kwargs): # now have to fail this eager task, because we don't want it to show up as succeeded. raise FlyteNonRecoverableSystemException(base_error) return result + + def run(self, remote: "FlyteRemote", ss: SerializationSettings, **kwargs): # type: ignore[name-defined] + """ + This is a helper function to help run eager parent tasks locally, pointing to a remote cluster. 
This is used + only for local testing for now. + """ + ctx = FlyteContextManager.current_context() + # tag is the current execution id + # root tag is read from the environment variable if it exists, if not, it's the current execution id + if not ctx.user_space_params or not ctx.user_space_params.execution_id: + raise AssertionError("User facing context and execution ID should be present when not running locally") + tag = ctx.user_space_params.execution_id.name + root_tag = os.environ.get(EAGER_ROOT_ENV_NAME, tag) + + # Prefix is a combination of the name of this eager workflow, and the current execution id. + prefix = self.name.split(".")[-1][:8] + prefix = f"e-{prefix}-{tag[:5]}" + prefix = _dnsify(prefix) + # Note: The construction of this object is in this function because this function should be on the + # main thread of pyflyte-execute. It needs to be on the main thread because signal handlers can only + # be installed on the main thread. + c = Controller(remote=remote, ss=ss, tag=tag, root_tag=root_tag, exec_prefix=prefix) + handler = c.get_signal_handler() + signal.signal(signal.SIGINT, handler) + signal.signal(signal.SIGTERM, handler) + builder = ctx.with_worker_queue(c) + + with FlyteContextManager.with_context(builder): + return loop_manager.run_sync(self.async_execute, self, **kwargs) diff --git a/flytekit/core/resources.py b/flytekit/core/resources.py index f64b7d23dc..c911bdb161 100644 --- a/flytekit/core/resources.py +++ b/flytekit/core/resources.py @@ -1,9 +1,15 @@ from dataclasses import dataclass, fields -from typing import List, Optional, Union +from typing import TYPE_CHECKING, List, Optional, Union +from typing import Literal as L -from kubernetes.client import V1Container, V1PodSpec, V1ResourceRequirements +from flyteidl.core import tasks_pb2 + +if TYPE_CHECKING: + from kubernetes.client import V1PodSpec from mashumaro.mixins.json import DataClassJSONMixin +from flytekit.core.constants import SHARED_MEMORY_MOUNT_NAME, SHARED_MEMORY_MOUNT_PATH +from flytekit.extras.accelerators import BaseAccelerator from flytekit.models import task as task_models @@ -102,12 +108,43 @@ def convert_resources_to_resource_model( return task_models.Resources(requests=request_entries, limits=limit_entries) +def construct_extended_resources( + *, + accelerator: Optional[BaseAccelerator] = None, + shared_memory: Optional[Union[L[True], str]] = None, +) -> Optional[tasks_pb2.ExtendedResources]: + """Convert public extended resources to idl. + + :param accelerator: The accelerator to use for this task. + :param shared_memory: If True, then shared memory will be attached to the container where the size is equal + to the allocated memory. If str, then the shared memory is set to that size. 
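In user code the new parameter is surfaced through `@task(shared_memory=...)` later in this diff; a rough sketch of the behavior documented above (the `"2Gi"` size and task body are illustrative, and this assumes a flyteidl version that includes the `SharedMemory` message):

```python
from flytekit import task
from flytekit.core.resources import construct_extended_resources


@task(shared_memory="2Gi")  # shared_memory=True instead sizes /dev/shm to the task's memory allocation
def train_model() -> None:
    ...


# Roughly what ends up in the task's extended resources:
ext = construct_extended_resources(shared_memory="2Gi")
assert ext.shared_memory.mount_name == "flyte-shared-memory"
assert ext.shared_memory.mount_path == "/dev/shm"
assert ext.shared_memory.size_limit == "2Gi"

# With neither an accelerator nor shared memory, nothing extra is serialized.
assert construct_extended_resources() is None
```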
+ """ + kwargs = {} + if accelerator is not None: + kwargs["gpu_accelerator"] = accelerator.to_flyte_idl() + if isinstance(shared_memory, str) or shared_memory is True: + if shared_memory is True: + shared_memory = None + kwargs["shared_memory"] = tasks_pb2.SharedMemory( + mount_name=SHARED_MEMORY_MOUNT_NAME, + mount_path=SHARED_MEMORY_MOUNT_PATH, + size_limit=shared_memory, + ) + + if not kwargs: + return None + + return tasks_pb2.ExtendedResources(**kwargs) + + def pod_spec_from_resources( primary_container_name: Optional[str] = None, requests: Optional[Resources] = None, limits: Optional[Resources] = None, k8s_gpu_resource_key: str = "nvidia.com/gpu", -) -> V1PodSpec: +) -> "V1PodSpec": + from kubernetes.client import V1Container, V1PodSpec, V1ResourceRequirements + def _construct_k8s_pods_resources(resources: Optional[Resources], k8s_gpu_resource_key: str): if resources is None: return None diff --git a/flytekit/core/task.py b/flytekit/core/task.py index 6451e742c5..f39a133877 100644 --- a/flytekit/core/task.py +++ b/flytekit/core/task.py @@ -5,12 +5,14 @@ import os from functools import partial, update_wrapper from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Type, TypeVar, Union, overload +from typing import Literal as L from typing_extensions import ParamSpec # type: ignore from flytekit.core import launch_plan as _annotated_launchplan from flytekit.core import workflow as _annotated_workflow from flytekit.core.base_task import PythonTask, TaskMetadata, TaskResolverMixin +from flytekit.core.cache import Cache, VersionParameters from flytekit.core.interface import Interface, output_name_generator, transform_function_to_interface from flytekit.core.pod_template import PodTemplate from flytekit.core.python_function_task import AsyncPythonFunctionTask, EagerAsyncPythonFunctionTask, PythonFunctionTask @@ -96,10 +98,7 @@ def find_pythontask_plugin(cls, plugin_config_type: type) -> Type[PythonFunction def task( _task_function: None = ..., task_config: Optional[T] = ..., - cache: bool = ..., - cache_serialize: bool = ..., - cache_version: str = ..., - cache_ignore_input_vars: Tuple[str, ...] = ..., + cache: Union[bool, Cache] = ..., retries: int = ..., interruptible: Optional[bool] = ..., deprecated: str = ..., @@ -128,6 +127,8 @@ def task( pod_template_name: Optional[str] = ..., accelerator: Optional[BaseAccelerator] = ..., pickle_untyped: bool = ..., + shared_memory: Optional[Union[L[True], str]] = None, + **kwargs, ) -> Callable[[Callable[..., FuncOut]], PythonFunctionTask[T]]: ... @@ -135,10 +136,7 @@ def task( def task( _task_function: Callable[P, FuncOut], task_config: Optional[T] = ..., - cache: bool = ..., - cache_serialize: bool = ..., - cache_version: str = ..., - cache_ignore_input_vars: Tuple[str, ...] = ..., + cache: Union[bool, Cache] = ..., retries: int = ..., interruptible: Optional[bool] = ..., deprecated: str = ..., @@ -167,16 +165,15 @@ def task( pod_template_name: Optional[str] = ..., accelerator: Optional[BaseAccelerator] = ..., pickle_untyped: bool = ..., + shared_memory: Optional[Union[L[True], str]] = ..., + **kwargs, ) -> Union[Callable[P, FuncOut], PythonFunctionTask[T]]: ... def task( _task_function: Optional[Callable[P, FuncOut]] = None, task_config: Optional[T] = None, - cache: bool = False, - cache_serialize: bool = False, - cache_version: str = "", - cache_ignore_input_vars: Tuple[str, ...] 
= (), + cache: Union[bool, Cache] = False, retries: int = 0, interruptible: Optional[bool] = None, deprecated: str = "", @@ -211,6 +208,8 @@ def task( pod_template_name: Optional[str] = None, accelerator: Optional[BaseAccelerator] = None, pickle_untyped: bool = False, + shared_memory: Optional[Union[L[True], str]] = None, + **kwargs, ) -> Union[ Callable[P, FuncOut], Callable[[Callable[P, FuncOut]], PythonFunctionTask[T]], @@ -248,15 +247,17 @@ def my_task(x: int, y: typing.Dict[str, str]) -> str: :param _task_function: This argument is implicitly passed and represents the decorated function :param task_config: This argument provides configuration for a specific task types. Please refer to the plugins documentation for the right object to use. - :param cache: Boolean that indicates if caching should be enabled - :param cache_serialize: Boolean that indicates if identical (ie. same inputs) instances of this task should be - executed in serial when caching is enabled. This means that given multiple concurrent executions over - identical inputs, only a single instance executes and the rest wait to reuse the cached results. This - parameter does nothing without also setting the cache parameter. - :param cache_version: Cache version to use. Changes to the task signature will automatically trigger a cache miss, - but you can always manually update this field as well to force a cache miss. You should also manually bump - this version if the function body/business logic has changed, but the signature hasn't. - :param cache_ignore_input_vars: Input variables that should not be included when calculating hash for cache. + :param cache: Boolean or Cache that indicates how caching is configured. + :deprecated param cache_serialize: (deprecated - please use Cache) Boolean that indicates if identical (ie. same inputs) + instances of this task should be executed in serial when caching is enabled. This means that given multiple + concurrent executions over identical inputs, only a single instance executes and the rest wait to reuse the + cached results. This parameter does nothing without also setting the cache parameter. + :deprecated param cache_version: (deprecated - please use Cache) Cache version to use. Changes to the task signature will + automatically trigger a cache miss, but you can always manually update this field as well to force a cache + miss. You should also manually bump this version if the function body/business logic has changed, but the + signature hasn't. + :deprecated param cache_ignore_input_vars: (deprecated - please use Cache) Input variables that should not be included when + calculating hash for cache. :param retries: Number of times to retry this task during a workflow execution. :param interruptible: [Optional] Boolean that indicates that this task can be interrupted and/or scheduled on nodes with lower QoS guarantees. This will directly reduce the `$`/`execution cost` associated, @@ -341,9 +342,52 @@ def launch_dynamically(): :param pod_template_name: The name of the existing PodTemplate resource which will be used in this task. :param accelerator: The accelerator to use for this task. :param pickle_untyped: Boolean that indicates if the task allows unspecified data types. + :param shared_memory: If True, then shared memory will be attached to the container where the size is equal + to the allocated memory. If int, then the shared memory is set to that size. """ + # Maintain backwards compatibility with the old cache parameters, while cleaning up the task function definition. 
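A sketch of the migration path implied by these docstrings (function bodies are illustrative): the deprecated keyword arguments keep working through `**kwargs`, while new code passes a single `Cache` object.

```python
from flytekit import Cache, task


# Deprecated style, still accepted for backwards compatibility:
@task(cache=True, cache_version="1.0", cache_serialize=True, cache_ignore_input_vars=("seed",))
def old_style(x: int, seed: int) -> int:
    return x * 2


# Equivalent configuration with the new Cache object:
@task(cache=Cache(version="1.0", serialize=True, ignored_inputs=("seed",)))
def new_style(x: int, seed: int) -> int:
    return x * 2
```

Mixing the two styles (a `Cache` object plus any of the deprecated keywords) raises a `ValueError`, as enforced in the wrapper below.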
+ cache_serialize = kwargs.get("cache_serialize") + cache_version = kwargs.get("cache_version") + cache_ignore_input_vars = kwargs.get("cache_ignore_input_vars") + + def wrapper(fn: Callable[P, FuncOut]) -> PythonFunctionTask[T]: + nonlocal cache, cache_serialize, cache_version, cache_ignore_input_vars + + # If the cache is of type bool but cache_version is not set, then assume that we want to use the + # default cache policies in Cache + if isinstance(cache, bool) and cache is True and cache_version is None: + cache = Cache( + serialize=cache_serialize if cache_serialize is not None else False, + ignored_inputs=cache_ignore_input_vars if cache_ignore_input_vars is not None else tuple(), + ) + + if isinstance(cache, Cache): + # Validate that none of the deprecated cache-related parameters are set. + if cache_serialize is not None or cache_version is not None or cache_ignore_input_vars is not None: + raise ValueError( + "cache_serialize, cache_version, and cache_ignore_input_vars are deprecated. Please use Cache object" + ) + cache_version = cache.get_version( + VersionParameters( + func=fn, + container_image=container_image, + pod_template=pod_template, + pod_template_name=pod_template_name, + ) + ) + cache_serialize = cache.serialize + cache_ignore_input_vars = cache.get_ignored_inputs() + cache = True + + # Set default values to each of the cache-related variables. Notice how this only applies if the values are not + # set explicitly, which only happens if they are not set at all in the invocation of the task. + if cache_serialize is None: + cache_serialize = False + if cache_version is None: + cache_version = "" + if cache_ignore_input_vars is None: + cache_ignore_input_vars = tuple() - def wrapper(fn: Callable[P, Any]) -> PythonFunctionTask[T]: _metadata = TaskMetadata( cache=cache, cache_serialize=cache_serialize, @@ -390,6 +434,7 @@ def wrapper(fn: Callable[P, Any]) -> PythonFunctionTask[T]: pod_template_name=pod_template_name, accelerator=accelerator, pickle_untyped=pickle_untyped, + shared_memory=shared_memory, ) update_wrapper(task_instance, decorated_fn) return task_instance diff --git a/flytekit/core/type_engine.py b/flytekit/core/type_engine.py index 759a0b06ca..bcd34cf6f3 100644 --- a/flytekit/core/type_engine.py +++ b/flytekit/core/type_engine.py @@ -40,6 +40,7 @@ from flytekit.core.context_manager import FlyteContext from flytekit.core.hash import HashMethod from flytekit.core.type_helpers import load_type_from_tag +from flytekit.core.type_match_checking import literal_types_match from flytekit.core.utils import load_proto_from_file, str2bool, timeit from flytekit.exceptions import user as user_exceptions from flytekit.interaction.string_literals import literal_map_string_repr @@ -2019,25 +2020,43 @@ def __init__(self): @staticmethod def extract_types(t: Optional[Type[dict]]) -> typing.Tuple: + if t is None: + return None, None + + # Get the origin and type arguments. _origin = get_origin(t) _args = get_args(t) - if _origin is not None: - if _origin is Annotated and _args: - # _args holds the type arguments to the dictionary, in other words: - # >>> get_args(Annotated[dict[int, str], FlyteAnnotation("abc")]) - # (dict[int, str], ) - for x in _args[1:]: - if isinstance(x, FlyteAnnotation): - raise ValueError( - f"Flytekit does not currently have support for FlyteAnnotations applied to dicts. {t} cannot be parsed." - ) - if _origin is dict and _args is not None: + + # If not annotated or dict, return None, None. 
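The refactor of `extract_types` continues below; for reference, the behavior it preserves looks roughly like this (a sketch using plain asserts):

```python
from typing import Annotated, Dict

from flytekit.core.type_engine import DictTransformer

# Parameterized dicts yield their key/value type arguments.
assert DictTransformer.extract_types(Dict[int, str]) == (int, str)

# Annotated wrappers are unwrapped recursively; non-FlyteAnnotation metadata is ignored,
# while a FlyteAnnotation applied to a dict still raises a ValueError.
assert DictTransformer.extract_types(Annotated[Dict[int, str], "metadata"]) == (int, str)

# Bare `dict` and None carry no key/value information.
assert DictTransformer.extract_types(dict) == (None, None)
assert DictTransformer.extract_types(None) == (None, None)
```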
+ if _origin is None: + return None, None + + # If this is something like Annotated[dict[int, str], FlyteAnnotation("abc")], + # we need to check if there's a FlyteAnnotation in the metadata. + if _origin is Annotated: + # This case should never happen since Python's typing system requires at least two arguments + # for Annotated[...] - a type and an annotation. Including this check for completeness. + if not _args: + return None, None + + first_arg = _args[0] + # Check the rest of the metadata (after the dict type itself). + for x in _args[1:]: + if isinstance(x, FlyteAnnotation): + raise ValueError( + f"Flytekit does not currently have support for FlyteAnnotations applied to dicts. {t} cannot be parsed." + ) + # Recursively process the first argument if it's Annotated (or dict). + return DictTransformer.extract_types(first_arg) + + # If the origin is dict, return the type arguments if they exist. + if _origin is dict: + # _args can be (). + if _args is not None: return _args # type: ignore - elif _origin is Annotated: - return DictTransformer.extract_types(_args[0]) - else: - raise ValueError(f"Trying to extract dictionary type information from a non-dict type {t}") - return None, None + + # Otherwise, we do not support this type in extract_types. + raise ValueError(f"Trying to extract dictionary type information from a non-dict type {t}") @staticmethod async def dict_to_generic_literal( @@ -2415,6 +2434,27 @@ def dataclass_from_dict(cls: type, src: typing.Dict[str, typing.Any]) -> typing. return cls(**constructor_inputs) +def strict_type_hint_matching(input_val: typing.Any, target_literal_type: LiteralType) -> typing.Type: + """ + Try to be smarter about guessing the type of the input (and hence the transformer). + If the literal type from the transformer for type(v), matches the literal type of the interface, then we + can use type(). Otherwise, fall back to guess python type from the literal type. + Raises ValueError, like in case of [1,2,3] type() will just give `list`, which won't work. + Raises ValueError also if the transformer found for the raw type doesn't have a literal type match. + """ + native_type = type(input_val) + transformer: TypeTransformer = TypeEngine.get_transformer(native_type) + inferred_literal_type = transformer.get_literal_type(native_type) + # note: if no good match, transformer will be the pickle transformer, but type will not match unless it's the + # pickle type so will fall back to normal guessing + if literal_types_match(inferred_literal_type, target_literal_type): + return type(input_val) + + raise ValueError( + f"Transformer for {native_type} returned literal type {inferred_literal_type} which doesn't match {target_literal_type}" + ) + + def _check_and_covert_float(lv: Literal) -> float: if lv.scalar.primitive.float_value is not None: return lv.scalar.primitive.float_value diff --git a/flytekit/core/type_match_checking.py b/flytekit/core/type_match_checking.py new file mode 100644 index 0000000000..2292a35a8a --- /dev/null +++ b/flytekit/core/type_match_checking.py @@ -0,0 +1,54 @@ +from __future__ import annotations + +from flytekit.models.core.types import EnumType +from flytekit.models.types import LiteralType, UnionType + + +def literal_types_match(downstream: LiteralType, upstream: LiteralType) -> bool: + """ + Returns if two LiteralTypes are the same. + Takes into account arbitrary ordering of enums and unions, otherwise just an equivalence check. 
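A small sketch of how `strict_type_hint_matching` is intended to behave (it leans on the `literal_types_match` helper whose implementation continues below):

```python
from flytekit.core.type_engine import strict_type_hint_matching
from flytekit.models.types import LiteralType, SimpleType

# type("hello") is str, and str's transformer produces a STRING literal type that matches the
# interface type, so the concrete Python type is returned directly.
assert strict_type_hint_matching("hello", LiteralType(simple=SimpleType.STRING)) is str

# For a value like [1, 2, 3], type() only yields `list`, so a ValueError is raised and callers
# fall back to TypeEngine.guess_python_type (see the remote.py change later in this diff).
try:
    strict_type_hint_matching([1, 2, 3], LiteralType(simple=SimpleType.STRING))
except ValueError:
    pass
```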
+ """ + + # If the types are exactly the same, return True + if downstream == upstream: + return True + + if downstream.collection_type: + if not upstream.collection_type: + return False + return literal_types_match(downstream.collection_type, upstream.collection_type) + + if downstream.map_value_type: + if not upstream.map_value_type: + return False + return literal_types_match(downstream.map_value_type, upstream.map_value_type) + + # Handle enum types + if downstream.enum_type and upstream.enum_type: + return _enum_types_match(downstream.enum_type, upstream.enum_type) + + # Handle union types + if downstream.union_type and upstream.union_type: + return _union_types_match(downstream.union_type, upstream.union_type) + + # If none of the above conditions are met, the types are not castable + return False + + +def _enum_types_match(downstream: EnumType, upstream: EnumType) -> bool: + return set(upstream.values) == set(downstream.values) + + +def _union_types_match(downstream: UnionType, upstream: UnionType) -> bool: + if len(downstream.variants) != len(upstream.variants): + return False + + down_sorted = sorted(downstream.variants, key=lambda x: str(x)) + up_sorted = sorted(upstream.variants, key=lambda x: str(x)) + + for downstream_variant, upstream_variant in zip(down_sorted, up_sorted): + if not literal_types_match(downstream_variant, upstream_variant): + return False + + return True diff --git a/flytekit/core/worker_queue.py b/flytekit/core/worker_queue.py index 6b95a748b4..9c996a610a 100644 --- a/flytekit/core/worker_queue.py +++ b/flytekit/core/worker_queue.py @@ -23,6 +23,7 @@ from flytekit.loggers import developer_logger, logger from flytekit.models.common import Labels from flytekit.models.core.execution import WorkflowExecutionPhase +from flytekit.utils.rate_limiter import RateLimiter if typing.TYPE_CHECKING: from flytekit.remote.remote_callable import RemoteEntity @@ -185,6 +186,7 @@ def __init__(self, remote: FlyteRemote, ss: SerializationSettings, tag: str, roo ) self.__runner_thread.start() atexit.register(self._close, stopping_condition=self.stopping_condition, runner=self.__runner_thread) + self.rate_limiter = RateLimiter(rpm=60) # Executions should be tracked in the following way: # a) you should be able to list by label, all executions generated by the current eager task, @@ -219,21 +221,27 @@ def reconcile_one(self, update: Update): try: item = update.work_item if item.wf_exec is None: - logger.warning(f"reconcile should launch for {id(item)} entity name: {item.entity.name}") + logger.info(f"reconcile should launch for {id(item)} entity name: {item.entity.name}") wf_exec = self.launch_execution(update.work_item, update.idx) update.wf_exec = wf_exec + # Set this to running even if the launched execution was a re-run and already succeeded. + # This forces the bottom half of the conditional to run, properly fetching all outputs. 
update.status = ItemStatus.RUNNING else: - if not item.wf_exec.is_done: - update.status = ItemStatus.RUNNING - # Technically a mutating operation, but let's pretend it's not - update.wf_exec = self.remote.sync_execution(item.wf_exec) - if update.wf_exec.closure.phase == WorkflowExecutionPhase.SUCCEEDED: - update.status = ItemStatus.SUCCESS - elif update.wf_exec.closure.phase == WorkflowExecutionPhase.FAILED: - update.status = ItemStatus.FAILED + # Technically a mutating operation, but let's pretend it's not + update.wf_exec = self.remote.sync_execution(item.wf_exec) + + # Fill in update status + if update.wf_exec.closure.phase == WorkflowExecutionPhase.SUCCEEDED: + update.status = ItemStatus.SUCCESS + elif update.wf_exec.closure.phase == WorkflowExecutionPhase.FAILED: + update.status = ItemStatus.FAILED + elif update.wf_exec.closure.phase == WorkflowExecutionPhase.ABORTED: + update.status = ItemStatus.FAILED + elif update.wf_exec.closure.phase == WorkflowExecutionPhase.TIMED_OUT: + update.status = ItemStatus.FAILED else: - developer_logger.debug(f"Execution {item.wf_exec.id.name} is done, item is {item.status}") + update.status = ItemStatus.RUNNING except Exception as e: logger.error( @@ -261,10 +269,12 @@ def _apply_updates(self, update_items: typing.Dict[uuid.UUID, Update]) -> None: if item.uuid in update_items: update = update_items[typing.cast(uuid.UUID, item.uuid)] item.wf_exec = update.wf_exec - assert update.status is not None + if update.status is None: + raise AssertionError(f"update's status missing for {item.entity.name}") item.status = update.status if update.status == ItemStatus.SUCCESS: - assert update.wf_exec is not None + if update.wf_exec is None: + raise AssertionError(f"update's wf_exec missing for {item.entity.name}") item.result = update.wf_exec.outputs.as_python_native(item.python_interface) elif update.status == ItemStatus.FAILED: # If update object already has an error, then use that, otherwise look for one in the @@ -274,7 +284,10 @@ def _apply_updates(self, update_items: typing.Dict[uuid.UUID, Update]) -> None: else: from flytekit.exceptions.eager import EagerException - assert update.wf_exec is not None + if update.wf_exec is None: + raise AssertionError( + f"update's wf_exec missing in error case for {item.entity.name}" + ) exc = EagerException( f"Error executing {update.work_item.entity.name} with error:" @@ -344,7 +357,7 @@ def get_execution_name(self, entity: RunnableEntity, idx: int, input_kwargs: dic def launch_execution(self, wi: WorkItem, idx: int) -> FlyteWorkflowExecution: """This function launches executions.""" - logger.warning(f"Launching execution for {wi.entity.name} {idx=} with {wi.input_kwargs}") + logger.info(f"Launching execution for {wi.entity.name} {idx=} with {wi.input_kwargs}") if wi.result is None and wi.error is None: l = self.get_labels() e = self.get_env() @@ -359,6 +372,7 @@ def launch_execution(self, wi: WorkItem, idx: int) -> FlyteWorkflowExecution: assert self.ss.version version = self.ss.version + self.rate_limiter.sync_acquire() # todo: if the execution already exists, remote.execute will return that execution. in the future # we can add input checking to make sure the inputs are indeed a match. 
wf_exec = self.remote.execute( @@ -394,7 +408,8 @@ async def add(self, entity: RunnableEntity, input_kwargs: dict[str, typing.Any]) if i.status == ItemStatus.SUCCESS: return i.result elif i.status == ItemStatus.FAILED: - assert i.error is not None + if i.error is None: + raise AssertionError(f"Error should not be None if status is failed for {entity.name}") raise i.error else: await asyncio.sleep(2) # Small delay to avoid busy-waiting diff --git a/flytekit/core/workflow.py b/flytekit/core/workflow.py index 0567e701d5..aa213502bb 100644 --- a/flytekit/core/workflow.py +++ b/flytekit/core/workflow.py @@ -296,7 +296,10 @@ def __call__(self, *args, **kwargs) -> Union[Tuple[Promise], Promise, VoidPromis # Get default arguments and override with kwargs passed in input_kwargs = self.python_interface.default_inputs_as_kwargs input_kwargs.update(kwargs) - self.compile() + ctx = FlyteContext.current_context() + # todo: remove this conditional once context manager is thread safe + if not (ctx.execution_state and ctx.execution_state.mode == ExecutionState.Mode.EAGER_EXECUTION): + self.compile() try: return flyte_entity_call_handler(self, *args, **input_kwargs) except Exception as exc: diff --git a/flytekit/interaction/click_types.py b/flytekit/interaction/click_types.py index 8cc2cc21cf..f415367ebc 100644 --- a/flytekit/interaction/click_types.py +++ b/flytekit/interaction/click_types.py @@ -526,12 +526,12 @@ def convert( # If the input matches the default value in the launch plan, serialization can be skipped. if param and value == param.default: return None - lit = TypeEngine.to_literal(self._flyte_ctx, value, self._python_type, self._literal_type) + # If this is used for remote execution then we need to convert it back to a python native type if not self._is_remote: - # If this is used for remote execution then we need to convert it back to a python native type - # for FlyteRemote to use it. This maybe a double conversion penalty! 
- return TypeEngine.to_python_value(self._flyte_ctx, lit, self._python_type) + return value + + lit = TypeEngine.to_literal(self._flyte_ctx, value, self._python_type, self._literal_type) return lit except click.BadParameter: raise diff --git a/flytekit/interaction/string_literals.py b/flytekit/interaction/string_literals.py index 6f70488981..3e3ef82822 100644 --- a/flytekit/interaction/string_literals.py +++ b/flytekit/interaction/string_literals.py @@ -1,6 +1,8 @@ import base64 +import json import typing +import msgpack from google.protobuf.json_format import MessageToDict from flytekit.models.literals import Literal, LiteralMap, Primitive, Scalar @@ -42,6 +44,8 @@ def scalar_to_string(scalar: Scalar) -> typing.Any: if scalar.blob: return scalar.blob.uri if scalar.binary: + if scalar.binary.tag == "msgpack": + return json.dumps(msgpack.unpackb(scalar.binary.value)) return base64.b64encode(scalar.binary.value) if scalar.generic: return MessageToDict(scalar.generic) diff --git a/flytekit/loggers.py b/flytekit/loggers.py index 9a9281cad3..3b5613347e 100644 --- a/flytekit/loggers.py +++ b/flytekit/loggers.py @@ -1,8 +1,13 @@ +import importlib.util import logging import os import typing -from pythonjsonlogger import jsonlogger +if importlib.util.find_spec("pythonjsonlogger.json"): + # Module was renamed: https://github.com/nhairs/python-json-logger/releases/tag/v3.1.0 + from pythonjsonlogger import json as jsonlogger +else: + from pythonjsonlogger import jsonlogger from .tools import interactive diff --git a/flytekit/models/core/workflow.py b/flytekit/models/core/workflow.py index 7eef68cdbe..4b9fd8d856 100644 --- a/flytekit/models/core/workflow.py +++ b/flytekit/models/core/workflow.py @@ -642,22 +642,21 @@ def pod_template(self) -> typing.Optional[PodTemplate]: return self._pod_template def to_flyte_idl(self): + pod_template_override = None + if self.pod_template is not None: + pod_template_override = tasks_pb2.K8sPod( + metadata=K8sObjectMetadata( + labels=self.pod_template.labels, + annotations=self.pod_template.annotations, + ).to_flyte_idl(), + pod_spec=json_format.Parse(json.dumps(self.pod_template.pod_spec), struct_pb2.Struct()), + primary_container_name=self.pod_template.primary_container_name, + ) return _core_workflow.TaskNodeOverrides( resources=self.resources.to_flyte_idl() if self.resources is not None else None, extended_resources=self.extended_resources, container_image=self.container_image, - pod_template=tasks_pb2.K8sPod( - metadata=K8sObjectMetadata( - labels=self.pod_template.labels if self.pod_template else None, - annotations=self.pod_template.annotations if self.pod_template else None, - ).to_flyte_idl() - if self.pod_template is not None - else None, - pod_spec=json_format.Parse(json.dumps(self.pod_template.pod_spec), struct_pb2.Struct()) - if self.pod_template - else None, - primary_container_name=self.pod_template.primary_container_name if self.pod_template else None, - ), + pod_template=pod_template_override, ) @classmethod @@ -665,9 +664,20 @@ def from_flyte_idl(cls, pb2_object): resources = Resources.from_flyte_idl(pb2_object.resources) extended_resources = pb2_object.extended_resources if pb2_object.HasField("extended_resources") else None container_image = pb2_object.container_image if len(pb2_object.container_image) > 0 else None + pod_template = pb2_object.pod_template if pb2_object.HasField("pod_template") else None if bool(resources.requests) or bool(resources.limits): - return cls(resources=resources, extended_resources=extended_resources, 
container_image=container_image) - return cls(resources=None, extended_resources=extended_resources, container_image=container_image) + return cls( + resources=resources, + extended_resources=extended_resources, + container_image=container_image, + pod_template=pod_template, + ) + return cls( + resources=None, + extended_resources=extended_resources, + container_image=container_image, + pod_template=pod_template, + ) class TaskNode(_common.FlyteIdlEntity): diff --git a/flytekit/models/execution.py b/flytekit/models/execution.py index 7e4ff02645..0019e4d79b 100644 --- a/flytekit/models/execution.py +++ b/flytekit/models/execution.py @@ -10,6 +10,7 @@ import flyteidl.admin.execution_pb2 as _execution_pb2 import flyteidl.admin.node_execution_pb2 as _node_execution_pb2 import flyteidl.admin.task_execution_pb2 as _task_execution_pb2 +from google.protobuf.wrappers_pb2 import BoolValue import flytekit from flytekit.models import common as _common_models @@ -179,6 +180,7 @@ def __init__( max_parallelism: Optional[int] = None, security_context: Optional[security.SecurityContext] = None, overwrite_cache: Optional[bool] = None, + interruptible: Optional[bool] = None, envs: Optional[_common_models.Envs] = None, tags: Optional[typing.List[str]] = None, cluster_assignment: Optional[ClusterAssignment] = None, @@ -198,6 +200,7 @@ def __init__( parallelism/concurrency of MapTasks is independent from this. :param security_context: Optional security context to use for this execution. :param overwrite_cache: Optional flag to overwrite the cache for this execution. + :param interruptible: Optional flag to override the default interruptible flag of the executed entity. :param envs: flytekit.models.common.Envs environment variables to set for this execution. :param tags: Optional list of tags to apply to the execution. :param execution_cluster_label: Optional execution cluster label to use for this execution. 
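A minimal sketch of the interruptible round-trip this adds (the identifier and principal values are placeholders, and the optional fields left unset are assumed to be tolerated by the existing guards in this model):

```python
from flytekit.models.core import identifier as id_models
from flytekit.models.execution import ExecutionMetadata, ExecutionSpec

spec = ExecutionSpec(
    id_models.Identifier(id_models.ResourceType.LAUNCH_PLAN, "proj", "dev", "my_lp", "v1"),
    ExecutionMetadata(ExecutionMetadata.ExecutionMode.MANUAL, "demo-user", 0),
    interruptible=True,
)

idl = spec.to_flyte_idl()
# None leaves the field unset (defer to the entity's own setting); True/False are wrapped in a
# BoolValue so that an explicit False still survives serialization.
assert idl.interruptible.value is True
assert ExecutionSpec.from_flyte_idl(idl).interruptible is True
```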
@@ -213,6 +216,7 @@ def __init__( self._max_parallelism = max_parallelism self._security_context = security_context self._overwrite_cache = overwrite_cache + self._interruptible = interruptible self._envs = envs self._tags = tags self._cluster_assignment = cluster_assignment @@ -287,6 +291,10 @@ def security_context(self) -> typing.Optional[security.SecurityContext]: def overwrite_cache(self) -> Optional[bool]: return self._overwrite_cache + @property + def interruptible(self) -> Optional[bool]: + return self._interruptible + @property def envs(self) -> Optional[_common_models.Envs]: return self._envs @@ -321,6 +329,7 @@ def to_flyte_idl(self): max_parallelism=self.max_parallelism, security_context=self.security_context.to_flyte_idl() if self.security_context else None, overwrite_cache=self.overwrite_cache, + interruptible=BoolValue(value=self.interruptible) if self.interruptible is not None else None, envs=self.envs.to_flyte_idl() if self.envs else None, tags=self.tags, cluster_assignment=self._cluster_assignment.to_flyte_idl() if self._cluster_assignment else None, @@ -351,6 +360,7 @@ def from_flyte_idl(cls, p): if p.security_context else None, overwrite_cache=p.overwrite_cache, + interruptible=p.interruptible.value if p.HasField("interruptible") else None, envs=_common_models.Envs.from_flyte_idl(p.envs) if p.HasField("envs") else None, tags=p.tags, cluster_assignment=ClusterAssignment.from_flyte_idl(p.cluster_assignment) diff --git a/flytekit/models/task.py b/flytekit/models/task.py index 52d3b92187..88430aa28a 100644 --- a/flytekit/models/task.py +++ b/flytekit/models/task.py @@ -9,7 +9,6 @@ from google.protobuf import json_format as _json_format from google.protobuf import struct_pb2 as _struct from google.protobuf.wrappers_pb2 import BoolValue -from kubernetes.client import ApiClient from flytekit.models import common as _common from flytekit.models import interface as _interface @@ -1077,6 +1076,8 @@ def to_pod_template(self) -> "PodTemplate": @classmethod def from_pod_template(cls, pod_template: "PodTemplate") -> "K8sPod": + from kubernetes.client import ApiClient + return cls( metadata=K8sObjectMetadata(labels=pod_template.labels, annotations=pod_template.annotations), pod_spec=ApiClient().sanitize_for_serialization(pod_template.pod_spec), diff --git a/flytekit/remote/data.py b/flytekit/remote/data.py index 84fcff1420..40175f6446 100644 --- a/flytekit/remote/data.py +++ b/flytekit/remote/data.py @@ -1,7 +1,9 @@ +import json import os import pathlib import typing +import msgpack from google.protobuf.json_format import MessageToJson from rich import print @@ -39,6 +41,9 @@ def download_literal( elif data.scalar.generic is not None: with open(download_to / f"{var}.json", "w") as f: f.write(MessageToJson(data.scalar.generic)) + elif data.scalar.binary is not None and data.scalar.binary.tag == "msgpack": + with open(download_to / f"{var}.json", "w") as f: + json.dump(msgpack.unpackb(data.scalar.binary.value), f) else: print( f"[dim]Skipping {var} val {literal_string_repr(data)} as it is not a blob, structured dataset," diff --git a/flytekit/remote/remote.py b/flytekit/remote/remote.py index e9ec1d0d0a..ef8b28d866 100644 --- a/flytekit/remote/remote.py +++ b/flytekit/remote/remote.py @@ -58,7 +58,7 @@ from flytekit.core.reference_entity import ReferenceEntity, ReferenceSpec from flytekit.core.task import ReferenceTask from flytekit.core.tracker import extract_task_module -from flytekit.core.type_engine import LiteralsResolver, TypeEngine +from flytekit.core.type_engine import 
LiteralsResolver, TypeEngine, strict_type_hint_matching from flytekit.core.workflow import PythonFunctionWorkflow, ReferenceWorkflow, WorkflowBase, WorkflowFailurePolicy from flytekit.exceptions import user as user_exceptions from flytekit.exceptions.user import ( @@ -125,6 +125,8 @@ MOST_RECENT_FIRST = admin_common_models.Sort("created_at", admin_common_models.Sort.Direction.DESCENDING) +LATEST_VERSION_STR = "latest" + class RegistrationSkipped(Exception): """ @@ -163,12 +165,14 @@ def _get_entity_identifier( name: str, version: typing.Optional[str] = None, ): + if version is None or version == LATEST_VERSION_STR: + version = _get_latest_version(list_entities_method, project, domain, name) return Identifier( resource_type, project, domain, name, - version if version is not None else _get_latest_version(list_entities_method, project, domain, name), + version, ) @@ -804,6 +808,21 @@ def list_tasks_by_version( ##################### # Register Entities # ##################### + def _resolve_version( + self, version: typing.Optional[str], entity: typing.Any, ss: SerializationSettings + ) -> typing.Tuple[str, typing.Optional[PickledEntity]]: + if version is None and self.interactive_mode_enabled: + md5_bytes, pickled_target_dict = _get_pickled_target_dict(entity) + return self._version_from_hash( + md5_bytes, ss, entity.python_interface.default_inputs_as_kwargs, *self._get_image_names(entity) + ), pickled_target_dict + elif version is not None: + return version, None + elif ss.version is not None: + return ss.version, None + raise ValueError( + "Version must be provided when not in interactive mode. If you want to use latest version pass 'latest'" + ) def _resolve_identifier(self, t: int, name: str, version: str, ss: SerializationSettings) -> Identifier: ident = Identifier( @@ -1051,15 +1070,14 @@ def register_workflow( :return: """ if serialization_settings is None: - _, _, _, module_file = extract_task_module(entity) - project_root = _find_project_root(module_file) serialization_settings = SerializationSettings( image_config=ImageConfig.auto_default_image(), - source_root=project_root, project=self.default_project, domain=self.default_domain, ) + version, _ = self._resolve_version(version, entity, serialization_settings) + ident = run_sync( self._serialize_and_register, entity, serialization_settings, version, options, default_launch_plan ) @@ -1382,7 +1400,7 @@ def _wf_exists( def register_launch_plan( self, entity: LaunchPlan, - version: str, + version: typing.Optional[str] = None, project: typing.Optional[str] = None, domain: typing.Optional[str] = None, options: typing.Optional[Options] = None, @@ -1401,16 +1419,15 @@ def register_launch_plan( :param options: """ if serialization_settings is None: - _, _, _, module_file = extract_task_module(entity.workflow) - project_root = _find_project_root(module_file) serialization_settings = SerializationSettings( image_config=ImageConfig.auto_default_image(), - source_root=project_root, project=project or self.default_project, domain=domain or self.default_domain, version=version, ) + version, _ = self._resolve_version(version, entity, serialization_settings) + if self._wf_exists( name=entity.workflow.name, version=version, @@ -1457,6 +1474,7 @@ def _execute( wait: bool = False, type_hints: typing.Optional[typing.Dict[str, typing.Type]] = None, overwrite_cache: typing.Optional[bool] = None, + interruptible: typing.Optional[bool] = None, envs: typing.Optional[typing.Dict[str, str]] = None, tags: typing.Optional[typing.List[str]] = None, 
cluster_pool: typing.Optional[str] = None, @@ -1475,6 +1493,7 @@ def _execute( :param overwrite_cache: Allows for all cached values of a workflow and its tasks to be overwritten for a single execution. If enabled, all calculations are performed even if cached results would be available, overwriting the stored data once execution finishes successfully. + :param interruptible: Optional flag to override the default interruptible flag of the executed entity. :param envs: Environment variables to set for the execution. :param tags: Tags to set for the execution. :param cluster_pool: Specify cluster pool on which newly created execution should be placed. @@ -1514,9 +1533,12 @@ def _execute( else: if k not in type_hints: try: - type_hints[k] = TypeEngine.guess_python_type(input_flyte_type_map[k].type) + type_hints[k] = strict_type_hint_matching(v, input_flyte_type_map[k].type) except ValueError: - logger.debug(f"Could not guess type for {input_flyte_type_map[k].type}, skipping...") + developer_logger.debug( + f"Could not guess type for {input_flyte_type_map[k].type}, skipping..." + ) + type_hints[k] = TypeEngine.guess_python_type(input_flyte_type_map[k].type) variable = entity.interface.inputs.get(k) hint = type_hints[k] self.file_access._get_upload_signed_url_fn = functools.partial( @@ -1545,6 +1567,7 @@ def _execute( 0, ), overwrite_cache=overwrite_cache, + interruptible=interruptible, notifications=notifications, disable_all=options.disable_notifications, labels=options.labels, @@ -1623,6 +1646,7 @@ def execute( wait: bool = False, type_hints: typing.Optional[typing.Dict[str, typing.Type]] = None, overwrite_cache: typing.Optional[bool] = None, + interruptible: typing.Optional[bool] = None, envs: typing.Optional[typing.Dict[str, str]] = None, tags: typing.Optional[typing.List[str]] = None, cluster_pool: typing.Optional[str] = None, @@ -1663,6 +1687,7 @@ def execute( :param overwrite_cache: Allows for all cached values of a workflow and its tasks to be overwritten for a single execution. If enabled, all calculations are performed even if cached results would be available, overwriting the stored data once execution finishes successfully. + :param interruptible: Optional flag to override the default interruptible flag of the executed entity. :param envs: Environment variables to be set for the execution. :param tags: Tags to be set for the execution. :param cluster_pool: Specify cluster pool on which newly created execution should be placed. 
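From the user's side this surfaces as a per-execution override; a sketch (the launch plan name and inputs are hypothetical):

```python
from flytekit.configuration import Config
from flytekit.remote import FlyteRemote

remote = FlyteRemote(Config.auto(), default_project="flytesnacks", default_domain="development")
lp = remote.fetch_launch_plan(name="my_workflows.training_lp")  # hypothetical launch plan

# Run this one execution on interruptible (e.g. spot) capacity regardless of how the launch plan
# was registered; pass interruptible=False to force the opposite, or omit it to keep the default.
execution = remote.execute(lp, inputs={"epochs": 3}, interruptible=True, wait=True)
```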
@@ -1687,6 +1712,7 @@ def execute( wait=wait, type_hints=type_hints, overwrite_cache=overwrite_cache, + interruptible=interruptible, envs=envs, tags=tags, cluster_pool=cluster_pool, @@ -1704,6 +1730,7 @@ def execute( wait=wait, type_hints=type_hints, overwrite_cache=overwrite_cache, + interruptible=interruptible, envs=envs, tags=tags, cluster_pool=cluster_pool, @@ -1719,6 +1746,7 @@ def execute( wait=wait, type_hints=type_hints, overwrite_cache=overwrite_cache, + interruptible=interruptible, envs=envs, tags=tags, cluster_pool=cluster_pool, @@ -1734,6 +1762,7 @@ def execute( wait=wait, type_hints=type_hints, overwrite_cache=overwrite_cache, + interruptible=interruptible, envs=envs, tags=tags, cluster_pool=cluster_pool, @@ -1749,6 +1778,7 @@ def execute( wait=wait, type_hints=type_hints, overwrite_cache=overwrite_cache, + interruptible=interruptible, envs=envs, tags=tags, cluster_pool=cluster_pool, @@ -1767,6 +1797,7 @@ def execute( image_config=image_config, wait=wait, overwrite_cache=overwrite_cache, + interruptible=interruptible, envs=envs, tags=tags, cluster_pool=cluster_pool, @@ -1787,6 +1818,7 @@ def execute( options=options, wait=wait, overwrite_cache=overwrite_cache, + interruptible=interruptible, envs=envs, tags=tags, cluster_pool=cluster_pool, @@ -1805,6 +1837,7 @@ def execute( options=options, wait=wait, overwrite_cache=overwrite_cache, + interruptible=interruptible, envs=envs, tags=tags, cluster_pool=cluster_pool, @@ -1827,6 +1860,7 @@ def execute_remote_task_lp( wait: bool = False, type_hints: typing.Optional[typing.Dict[str, typing.Type]] = None, overwrite_cache: typing.Optional[bool] = None, + interruptible: typing.Optional[bool] = None, envs: typing.Optional[typing.Dict[str, str]] = None, tags: typing.Optional[typing.List[str]] = None, cluster_pool: typing.Optional[str] = None, @@ -1847,6 +1881,7 @@ def execute_remote_task_lp( options=options, type_hints=type_hints, overwrite_cache=overwrite_cache, + interruptible=interruptible, envs=envs, tags=tags, cluster_pool=cluster_pool, @@ -1865,6 +1900,7 @@ def execute_remote_wf( wait: bool = False, type_hints: typing.Optional[typing.Dict[str, typing.Type]] = None, overwrite_cache: typing.Optional[bool] = None, + interruptible: typing.Optional[bool] = None, envs: typing.Optional[typing.Dict[str, str]] = None, tags: typing.Optional[typing.List[str]] = None, cluster_pool: typing.Optional[str] = None, @@ -1886,6 +1922,7 @@ def execute_remote_wf( wait=wait, type_hints=type_hints, overwrite_cache=overwrite_cache, + interruptible=interruptible, envs=envs, tags=tags, cluster_pool=cluster_pool, @@ -1904,6 +1941,7 @@ def execute_reference_task( wait: bool = False, type_hints: typing.Optional[typing.Dict[str, typing.Type]] = None, overwrite_cache: typing.Optional[bool] = None, + interruptible: typing.Optional[bool] = None, envs: typing.Optional[typing.Dict[str, str]] = None, tags: typing.Optional[typing.List[str]] = None, cluster_pool: typing.Optional[str] = None, @@ -1935,6 +1973,7 @@ def execute_reference_task( wait=wait, type_hints=type_hints, overwrite_cache=overwrite_cache, + interruptible=interruptible, envs=envs, tags=tags, cluster_pool=cluster_pool, @@ -1951,6 +1990,7 @@ def execute_reference_workflow( wait: bool = False, type_hints: typing.Optional[typing.Dict[str, typing.Type]] = None, overwrite_cache: typing.Optional[bool] = None, + interruptible: typing.Optional[bool] = None, envs: typing.Optional[typing.Dict[str, str]] = None, tags: typing.Optional[typing.List[str]] = None, cluster_pool: typing.Optional[str] = None, @@ -1996,6 
+2036,7 @@ def execute_reference_workflow( options=options, type_hints=type_hints, overwrite_cache=overwrite_cache, + interruptible=interruptible, envs=envs, tags=tags, cluster_pool=cluster_pool, @@ -2012,6 +2053,7 @@ def execute_reference_launch_plan( wait: bool = False, type_hints: typing.Optional[typing.Dict[str, typing.Type]] = None, overwrite_cache: typing.Optional[bool] = None, + interruptible: typing.Optional[bool] = None, envs: typing.Optional[typing.Dict[str, str]] = None, tags: typing.Optional[typing.List[str]] = None, cluster_pool: typing.Optional[str] = None, @@ -2043,6 +2085,7 @@ def execute_reference_launch_plan( wait=wait, type_hints=type_hints, overwrite_cache=overwrite_cache, + interruptible=interruptible, envs=envs, tags=tags, cluster_pool=cluster_pool, @@ -2059,12 +2102,13 @@ def execute_local_task( project: str = None, domain: str = None, name: str = None, - version: str = None, + version: str = "latest", execution_name: typing.Optional[str] = None, execution_name_prefix: typing.Optional[str] = None, image_config: typing.Optional[ImageConfig] = None, wait: bool = False, overwrite_cache: typing.Optional[bool] = None, + interruptible: typing.Optional[bool] = None, envs: typing.Optional[typing.Dict[str, str]] = None, tags: typing.Optional[typing.List[str]] = None, cluster_pool: typing.Optional[str] = None, @@ -2079,11 +2123,13 @@ def execute_local_task( :param project: The execution project, will default to the Remote's default project. :param domain: The execution domain, will default to the Remote's default domain. :param name: specific name of the task to run. - :param version: specific version of the task to run. + :param version: specific version of the task to run, default is a special string ``latest``, which implies latest version by time :param execution_name: If provided, will use this name for the execution. + :param execution_name_prefix: If provided, will use this prefix for the execution name. :param image_config: If provided, will use this image config in the pod. :param wait: If True, will wait for the execution to complete before returning. :param overwrite_cache: If True, will overwrite the cache. + :param interruptible: Optional flag to override the default interruptible flag of the executed entity. :param envs: Environment variables to set for the execution. :param tags: Tags to set for the execution. :param cluster_pool: Specify cluster pool on which newly created execution should be placed. @@ -2098,12 +2144,7 @@ def execute_local_task( domain=domain or self._default_domain, version=version, ) - pickled_target_dict = None - if version is None and self.interactive_mode_enabled: - md5_bytes, pickled_target_dict = _get_pickled_target_dict(entity) - version = self._version_from_hash( - md5_bytes, ss, entity.python_interface.default_inputs_as_kwargs, *self._get_image_names(entity) - ) + version, pickled_target_dict = self._resolve_version(version, entity, ss) resolved_identifiers = self._resolve_identifier_kwargs(entity, project, domain, name, version) resolved_identifiers_dict = asdict(resolved_identifiers) @@ -2118,6 +2159,7 @@ def execute_local_task( # object (look into the function, the passed in ss is basically ignored). How should it be piped in? 
# https://github.com/flyteorg/flyte/issues/6070 flyte_task: FlyteTask = self.register_task(entity, ss, version) + return self.execute( flyte_task, inputs, @@ -2128,6 +2170,7 @@ def execute_local_task( wait=wait, type_hints=entity.python_interface.inputs, overwrite_cache=overwrite_cache, + interruptible=interruptible, options=options, envs=envs, tags=tags, @@ -2149,6 +2192,7 @@ def execute_local_workflow( options: typing.Optional[Options] = None, wait: bool = False, overwrite_cache: typing.Optional[bool] = None, + interruptible: typing.Optional[bool] = None, envs: typing.Optional[typing.Dict[str, str]] = None, tags: typing.Optional[typing.List[str]] = None, cluster_pool: typing.Optional[str] = None, @@ -2156,22 +2200,24 @@ def execute_local_workflow( ) -> FlyteWorkflowExecution: """ Execute an @workflow decorated function. - :param entity: - :param inputs: - :param project: - :param domain: - :param name: - :param version: - :param execution_name: - :param image_config: - :param options: - :param wait: - :param overwrite_cache: - :param envs: - :param tags: - :param cluster_pool: - :param execution_cluster_label: - :return: + + :param entity: The workflow to execute + :param inputs: Input dictionary + :param project: Project to execute in + :param domain: Domain to execute in + :param name: Optional name override for the workflow + :param version: Optional version for the workflow + :param execution_name: Optional name for the execution + :param image_config: Optional image config override + :param options: Optional Options object + :param wait: Whether to wait for execution completion + :param overwrite_cache: If True, will overwrite the cache + :param interruptible: Optional flag to override the default interruptible flag of the executed entity + :param envs: Environment variables to set for the execution + :param tags: Tags to set for the execution + :param cluster_pool: Specify cluster pool on which newly created execution should be placed + :param execution_cluster_label: Specify label of cluster(s) on which newly created execution should be placed + :return: FlyteWorkflowExecution object """ if not image_config: image_config = ImageConfig.auto_default_image() @@ -2227,6 +2273,7 @@ def execute_local_workflow( options=options, type_hints=entity.python_interface.inputs, overwrite_cache=overwrite_cache, + interruptible=interruptible, envs=envs, tags=tags, cluster_pool=cluster_pool, @@ -2246,12 +2293,14 @@ def execute_local_launch_plan( options: typing.Optional[Options] = None, wait: bool = False, overwrite_cache: typing.Optional[bool] = None, + interruptible: typing.Optional[bool] = None, envs: typing.Optional[typing.Dict[str, str]] = None, tags: typing.Optional[typing.List[str]] = None, cluster_pool: typing.Optional[str] = None, execution_cluster_label: typing.Optional[str] = None, ) -> FlyteWorkflowExecution: """ + Execute a locally defined `LaunchPlan`. :param entity: The locally defined launch plan object :param inputs: Inputs to be passed into the execution as a dict with Python native values. @@ -2263,6 +2312,7 @@ def execute_local_launch_plan( :param options: Options to be passed into the execution. :param wait: If True, will wait for the execution to complete before returning. :param overwrite_cache: If True, will overwrite the cache. + :param interruptible: Optional flag to override the default interruptible flag of the executed entity. :param envs: Environment variables to be passed into the execution. :param tags: Tags to be passed into the execution. 
:param cluster_pool: Specify cluster pool on which newly created execution should be placed. @@ -2277,6 +2327,7 @@ def execute_local_launch_plan( flyte_launchplan: FlyteLaunchPlan = self.fetch_launch_plan(**resolved_identifiers_dict) flyte_launchplan.python_interface = entity.python_interface except FlyteEntityNotExistException: + logger.info("Registering launch plan because it wasn't found in Flyte Admin.") flyte_launchplan: FlyteLaunchPlan = self.register_launch_plan( entity, version=version, @@ -2294,6 +2345,7 @@ def execute_local_launch_plan( wait=wait, type_hints=entity.python_interface.inputs, overwrite_cache=overwrite_cache, + interruptible=interruptible, envs=envs, tags=tags, cluster_pool=cluster_pool, @@ -2353,7 +2405,7 @@ def sync( :param execution: :param entity_definition: :param sync_nodes: By default sync will fetch data on all underlying node executions (recursively, - so subworkflows will also get picked up). Set this to False in order to prevent that (which + so subworkflows and launch plans will also get picked up). Set this to False in order to prevent that (which will make this call faster). :return: Returns the same execution object, but with additional information pulled in. """ @@ -2490,7 +2542,7 @@ def sync_node_execution( launched_exec = self.fetch_execution( project=launched_exec_id.project, domain=launched_exec_id.domain, name=launched_exec_id.name ) - self.sync_execution(launched_exec) + self.sync_execution(launched_exec, sync_nodes=True) if launched_exec.is_done: # The synced underlying execution should've had these populated. execution._inputs = launched_exec.inputs @@ -2500,7 +2552,7 @@ def sync_node_execution( return execution # If a node ran a static subworkflow or a dynamic subworkflow then the parent flag will be set. 
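# A minimal sketch of syncing an execution with sync_nodes=True (assumes `remote` is a
# FlyteRemote and `execution` is a previously fetched FlyteWorkflowExecution); child
# executions launched through launch plans are now followed recursively, so nested node
# executions and their inputs/outputs get populated.
execution = remote.sync_execution(execution, sync_nodes=True)
child_nodes = execution.node_executions["n1"].workflow_executions[0].node_executions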
- if execution.metadata.is_parent_node or execution.metadata.is_array: + if execution.metadata.is_parent_node: # We'll need to query child node executions regardless since this is a parent node child_node_executions = iterate_node_executions( self.client, @@ -2554,23 +2606,32 @@ def sync_node_execution( "not have inputs and outputs filled in" ) return execution - elif execution._node.array_node is not None: - # if there's a task node underneath the array node, let's fetch the interface for it - if execution._node.array_node.node.task_node is not None: - tid = execution._node.array_node.node.task_node.reference_id - t = self.fetch_task(tid.project, tid.domain, tid.name, tid.version) - if t.interface: - execution._interface = t.interface - else: - logger.error(f"Fetched map task does not have an interface, skipping i/o {t}") - return execution - else: - logger.error(f"Array node not over task, skipping i/o {t}") - return execution else: logger.error(f"NE {execution} undeterminable, {type(execution._node)}, {execution._node}") raise ValueError(f"Node execution undeterminable, entity has type {type(execution._node)}") + # Handle the case for array nodes + elif execution.metadata.is_array: + if execution._node.array_node is None: + logger.error("Array node not found") + return execution + # if there's a task node underneath the array node, let's fetch the interface for it + if execution._node.array_node.node.task_node is not None: + tid = execution._node.array_node.node.task_node.reference_id + t = self.fetch_task(tid.project, tid.domain, tid.name, tid.version) + execution._task_executions = [ + self.sync_task_execution(FlyteTaskExecution.promote_from_model(task_execution), t) + for task_execution in iterate_task_executions(self.client, execution.id) + ] + if t.interface: + execution._interface = t.interface + else: + logger.error(f"Fetched map task does not have an interface, skipping i/o {t}") + return execution + else: + logger.error("Array node not over task, skipping i/o") + return execution + # Handle the case for gate nodes elif execution._node.gate_node is not None: logger.info("Skipping gate node execution for now - gate nodes don't have inputs and outputs filled in") diff --git a/flytekit/utils/rate_limiter.py b/flytekit/utils/rate_limiter.py new file mode 100644 index 0000000000..f18c751cd0 --- /dev/null +++ b/flytekit/utils/rate_limiter.py @@ -0,0 +1,54 @@ +import asyncio +from collections import deque +from datetime import datetime, timedelta + +from flytekit.loggers import developer_logger +from flytekit.utils.asyn import run_sync + + +class RateLimiter: + """Rate limiter that allows up to a certain number of requests per minute.""" + + def __init__(self, rpm: int): + if not isinstance(rpm, int) or rpm <= 0 or rpm > 100: + raise ValueError("Rate must be a positive integer between 1 and 100") + self.rpm = rpm + self.queue = deque() + self.sem = asyncio.Semaphore(rpm) + self.delay = timedelta(seconds=60) # always 60 seconds since this we're using a per-minute rate limiter + + def sync_acquire(self): + run_sync(self.acquire) + + async def acquire(self): + async with self.sem: + now = datetime.now() + # Start by clearing out old data + while self.queue and (now - self.queue[0]) > self.delay: + self.queue.popleft() + + # Now that the queue only has valid entries, we'll need to wait if the queue is full. + if len(self.queue) >= self.rpm: + # Compute necessary delay and sleep that amount + # First pop one off, so another coroutine won't try to base its wait time off the same timestamp. 
But + # if you pop it off, the next time this code runs it'll think there's enough spots... so add the + # expected time back onto the queue before awaiting. Once you await, you lose the 'thread' and other + # coroutines can run. + # Basically the invariant is: this block of code leaves the number of items in the queue unchanged: + # it'll pop off a timestamp but immediately add one back. + # Because of the semaphore, we don't have to worry about the one we add to the end being referenced + # because there will never be more than RPM-1 other coroutines running at the same time. + earliest = self.queue.popleft() + delay: timedelta = (earliest + self.delay) - now + if delay.total_seconds() > 0: + next_time = earliest + self.delay + self.queue.append(next_time) + developer_logger.debug( + f"Capacity reached - removed time {earliest} and added back {next_time}, sleeping for {delay.total_seconds()}" + ) + await asyncio.sleep(delay.total_seconds()) + else: + developer_logger.debug(f"No more need to wait, {earliest=} vs {now=}") + self.queue.append(now) + else: + self.queue.append(now) diff --git a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker_inference/boto3_agent.py b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker_inference/boto3_agent.py index d254ec5960..3dd8e16d84 100644 --- a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker_inference/boto3_agent.py +++ b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker_inference/boto3_agent.py @@ -122,20 +122,22 @@ async def do( ) ) with context_manager.FlyteContextManager.with_context(builder) as new_ctx: - outputs = { - "result": TypeEngine.to_literal( - new_ctx, - truncated_result if truncated_result else result, - Annotated[dict, kwtypes(allow_pickle=True)], - TypeEngine.to_literal_type(dict), - ), - "idempotence_token": TypeEngine.to_literal( - new_ctx, - idempotence_token, - str, - TypeEngine.to_literal_type(str), - ), - } + outputs = LiteralMap( + { + "result": TypeEngine.to_literal( + new_ctx, + truncated_result if truncated_result else result, + Annotated[dict, kwtypes(allow_pickle=True)], + TypeEngine.to_literal_type(dict), + ), + "idempotence_token": TypeEngine.to_literal( + new_ctx, + idempotence_token, + str, + TypeEngine.to_literal_type(str), + ), + } + ) return Resource(phase=TaskExecution.SUCCEEDED, outputs=outputs) diff --git a/plugins/flytekit-aws-sagemaker/tests/test_boto3_agent.py b/plugins/flytekit-aws-sagemaker/tests/test_boto3_agent.py index 3046b07dd9..f57194dc7c 100644 --- a/plugins/flytekit-aws-sagemaker/tests/test_boto3_agent.py +++ b/plugins/flytekit-aws-sagemaker/tests/test_boto3_agent.py @@ -2,6 +2,7 @@ from unittest import mock import msgpack import base64 +import json import pytest from flyteidl.core.execution_pb2 import TaskExecution @@ -67,7 +68,7 @@ @mock.patch( "flytekitplugins.awssagemaker_inference.boto3_agent.Boto3AgentMixin._call", ) -async def test_agent(mock_boto_call, mock_return_value): +async def test_agent(mock_boto_call, mock_return_value, request): mock_boto_call.return_value = mock_return_value[0] agent = AgentRegistry.get_agent("boto") @@ -158,12 +159,17 @@ async def test_agent(mock_boto_call, mock_return_value): assert resource.phase == TaskExecution.SUCCEEDED + if request.node.callspec.indices["mock_return_value"] in (0, 1): + assert isinstance(resource.outputs, literals.LiteralMap) + if mock_return_value[0][0]: outputs = literal_map_string_repr(resource.outputs) if "pickle_check" in mock_return_value[0][0]: assert "pickle_file" in outputs["result"] else: - 
outputs["result"] = msgpack.loads(base64.b64decode(outputs["result"])) + raw_result = outputs["result"] + parsed_result = json.loads(raw_result) + outputs["result"] = parsed_result assert ( outputs["result"]["EndpointConfigArn"] == "arn:aws:sagemaker:us-east-2:000000000:endpoint-config/sagemaker-xgboost-endpoint-config" diff --git a/plugins/flytekit-onnx-pytorch/dev-requirements.txt b/plugins/flytekit-onnx-pytorch/dev-requirements.txt index 36b3da649c..dd7b35c48c 100644 --- a/plugins/flytekit-onnx-pytorch/dev-requirements.txt +++ b/plugins/flytekit-onnx-pytorch/dev-requirements.txt @@ -4,7 +4,7 @@ # # pip-compile dev-requirements.in # -certifi==2023.7.22 +certifi==2024.7.4 # via requests charset-normalizer==3.3.2 # via requests diff --git a/plugins/flytekit-openai/flytekitplugins/openai/batch/agent.py b/plugins/flytekit-openai/flytekitplugins/openai/batch/agent.py index 8daf236828..cac9e17b88 100644 --- a/plugins/flytekit-openai/flytekitplugins/openai/batch/agent.py +++ b/plugins/flytekit-openai/flytekitplugins/openai/batch/agent.py @@ -105,7 +105,7 @@ async def get( result = retrieved_result.to_dict() ctx = FlyteContextManager.current_context() - outputs = {"result": TypeEngine.to_literal(ctx, result, Dict, TypeEngine.to_literal_type(Dict))} + outputs = LiteralMap({"result": TypeEngine.to_literal(ctx, result, Dict, TypeEngine.to_literal_type(Dict))}) return Resource(phase=flyte_phase, outputs=outputs, message=message) diff --git a/plugins/flytekit-openai/tests/openai_batch/test_agent.py b/plugins/flytekit-openai/tests/openai_batch/test_agent.py index 476ca5c8ba..3cf2b462eb 100644 --- a/plugins/flytekit-openai/tests/openai_batch/test_agent.py +++ b/plugins/flytekit-openai/tests/openai_batch/test_agent.py @@ -1,8 +1,8 @@ +import json from datetime import timedelta from unittest import mock from unittest.mock import AsyncMock -import msgpack -import base64 + import pytest from flyteidl.core.execution_pb2 import TaskExecution from flytekitplugins.openai.batch.agent import BatchEndpointMetadata @@ -156,11 +156,13 @@ async def test_openai_batch_agent(mock_retrieve, mock_create, mock_context): mock_retrieve.return_value = batch_retrieve_result resource = await agent.get(metadata) assert resource.phase == TaskExecution.SUCCEEDED + assert isinstance(resource.outputs, literals.LiteralMap) outputs = literal_map_string_repr(resource.outputs) - result = outputs["result"] - assert msgpack.loads(base64.b64decode(result)) == batch_retrieve_result.to_dict() + raw_result = outputs["result"] + parsed_result = json.loads(raw_result) + assert parsed_result == batch_retrieve_result.to_dict() # Status: Failed mock_retrieve.return_value = batch_retrieve_result_failure diff --git a/plugins/flytekit-pandera/flytekitplugins/pandera/pandas_transformer.py b/plugins/flytekit-pandera/flytekitplugins/pandera/pandas_transformer.py index b69d676189..326611c584 100644 --- a/plugins/flytekit-pandera/flytekitplugins/pandera/pandas_transformer.py +++ b/plugins/flytekit-pandera/flytekitplugins/pandera/pandas_transformer.py @@ -91,7 +91,6 @@ def to_literal( html = renderer.to_html(python_val, schema, exc) val = python_val if config.on_error == "raise": - # render the deck before raising the error raise exc elif config.on_error == "warn": logger.warning(str(exc)) @@ -100,7 +99,7 @@ def to_literal( else: html = renderer.to_html(val, schema) finally: - Deck(renderer._title, html) + Deck(renderer._title, html).publish() lv = self._sd_transformer.to_literal(ctx, val, pandas.DataFrame, expected) @@ -138,7 +137,7 @@ def 
to_python_value( else: html = renderer.to_html(val, schema) finally: - Deck(renderer._title, html) + Deck(renderer._title, html).publish() return val diff --git a/plugins/flytekit-pandera/setup.py b/plugins/flytekit-pandera/setup.py index c4c0b8161f..1af201f199 100644 --- a/plugins/flytekit-pandera/setup.py +++ b/plugins/flytekit-pandera/setup.py @@ -4,7 +4,7 @@ microlib_name = f"flytekitplugins-{PLUGIN_NAME}" -plugin_requires = ["flytekit>=1.3.0b2,<2.0.0", "pandera>=0.7.1", "pandas", "great_tables"] +plugin_requires = ["flytekit>=1.15.0b2,<2.0.0", "pandera>=0.7.1", "pandas", "great_tables"] __version__ = "0.0.0+develop" diff --git a/plugins/flytekit-ray/flytekitplugins/ray/task.py b/plugins/flytekit-ray/flytekitplugins/ray/task.py index c87c86276d..3f0dacf9f1 100644 --- a/plugins/flytekit-ray/flytekitplugins/ray/task.py +++ b/plugins/flytekit-ray/flytekitplugins/ray/task.py @@ -92,7 +92,7 @@ def pre_execute(self, user_params: ExecutionParameters) -> ExecutionParameters: working_dir = os.getcwd() init_params["runtime_env"] = { "working_dir": working_dir, - "excludes": ["script_mode.tar.gz", "fast*.tar.gz"], + "excludes": ["script_mode.tar.gz", "fast*.tar.gz", ".python_history"], } ray.init(**init_params) @@ -104,16 +104,23 @@ def get_custom(self, settings: SerializationSettings) -> Optional[Dict[str, Any] runtime_env = base64.b64encode(json.dumps(cfg.runtime_env).encode()).decode() if cfg.runtime_env else None runtime_env_yaml = yaml.dump(cfg.runtime_env) if cfg.runtime_env else None - if cfg.head_node_config.requests or cfg.head_node_config.limits: - head_pod_template = PodTemplate( - pod_spec=pod_spec_from_resources( - primary_container_name=_RAY_HEAD_CONTAINER_NAME, - requests=cfg.head_node_config.requests, - limits=cfg.head_node_config.limits, + head_group_spec = None + if cfg.head_node_config: + if cfg.head_node_config.requests or cfg.head_node_config.limits: + head_pod_template = PodTemplate( + pod_spec=pod_spec_from_resources( + primary_container_name=_RAY_HEAD_CONTAINER_NAME, + requests=cfg.head_node_config.requests, + limits=cfg.head_node_config.limits, + ) ) + else: + head_pod_template = cfg.head_node_config.pod_template + + head_group_spec = HeadGroupSpec( + cfg.head_node_config.ray_start_params, + K8sPod.from_pod_template(head_pod_template) if head_pod_template else None, ) - else: - head_pod_template = cfg.head_node_config.pod_template worker_group_spec: typing.List[WorkerGroupSpec] = [] for c in cfg.worker_node_config: @@ -134,14 +141,7 @@ def get_custom(self, settings: SerializationSettings) -> Optional[Dict[str, Any] ray_job = RayJob( ray_cluster=RayCluster( - head_group_spec=( - HeadGroupSpec( - cfg.head_node_config.ray_start_params, - K8sPod.from_pod_template(head_pod_template) if head_pod_template else None, - ) - if cfg.head_node_config - else None - ), + head_group_spec=head_group_spec, worker_group_spec=worker_group_spec, enable_autoscaling=(cfg.enable_autoscaling if cfg.enable_autoscaling else False), ), diff --git a/plugins/flytekit-ray/tests/test_ray.py b/plugins/flytekit-ray/tests/test_ray.py index 8fd8d432a9..c9b00a6dad 100644 --- a/plugins/flytekit-ray/tests/test_ray.py +++ b/plugins/flytekit-ray/tests/test_ray.py @@ -43,6 +43,15 @@ ttl_seconds_after_finished=20, ) +default_img = Image(name="default", fqn="test", tag="tag") +settings = SerializationSettings( + project="proj", + domain="dom", + version="123", + image_config=ImageConfig(default_image=default_img, images=[default_img]), + env={}, +) + def test_ray_task(): @task(task_config=config) @@ -56,14 +65,6 
@@ def t1(a: int) -> str: assert t1.task_type == "ray" assert isinstance(t1, PythonFunctionTask) - default_img = Image(name="default", fqn="test", tag="tag") - settings = SerializationSettings( - project="proj", - domain="dom", - version="123", - image_config=ImageConfig(default_image=default_img, images=[default_img]), - env={}, - ) head_pod_template = PodTemplate( pod_spec=pod_spec_from_resources( primary_container_name="ray-head", @@ -117,3 +118,52 @@ def t1(a: int) -> str: assert t1(a=3) == "5" assert ray.is_initialized() + +existing_cluster_config = RayJobConfig( + worker_node_config=[], + runtime_env={"pip": ["numpy"]}, + address="localhost:8265", +) + +def test_ray_task_existing_cluster(): + @task(task_config=existing_cluster_config) + def t1(a: int) -> str: + assert ray.is_initialized() + inc = a + 2 + return str(inc) + + assert t1.task_config is not None + assert t1.task_config == existing_cluster_config + assert t1.task_type == "ray" + assert isinstance(t1, PythonFunctionTask) + + ray_job_pb = RayJob( + ray_cluster=RayCluster(worker_group_spec=[]), + runtime_env=base64.b64encode(json.dumps({"pip": ["numpy"]}).encode()).decode(), + runtime_env_yaml=yaml.dump({"pip": ["numpy"]}), + ).to_flyte_idl() + + assert t1.get_custom(settings) == MessageToDict(ray_job_pb) + + assert t1.get_command(settings) == [ + "pyflyte-execute", + "--inputs", + "{{.input}}", + "--output-prefix", + "{{.outputPrefix}}", + "--raw-output-data-prefix", + "{{.rawOutputDataPrefix}}", + "--checkpoint-path", + "{{.checkpointOutputPrefix}}", + "--prev-checkpoint", + "{{.prevCheckpointPrefix}}", + "--resolver", + "flytekit.core.python_auto_container.default_task_resolver", + "--", + "task-module", + "tests.test_ray", + "task-name", + "t1", + ] + + # cannot execute this as it will try to hit a non-existent cluster diff --git a/pyproject.toml b/pyproject.toml index 741ff78a17..55184c1c36 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,7 +20,7 @@ dependencies = [ "diskcache>=5.2.1", "docker>=4.0.0", "docstring-parser>=0.9.0", - "flyteidl>=1.14.3", + "flyteidl>=1.15.0", "fsspec>=2023.3.0", "gcsfs>=2023.3.0", "googleapis-common-protos>=1.57", @@ -33,7 +33,6 @@ dependencies = [ "jsonlines", "jsonpickle", "keyring>=18.0.1", - "kubernetes>=12.0.1", "markdown-it-py", "marshmallow-enum", "marshmallow-jsonschema>=0.12.0", diff --git a/tests/flytekit/integration/remote/test_remote.py b/tests/flytekit/integration/remote/test_remote.py index 61dc480a81..74015673bb 100644 --- a/tests/flytekit/integration/remote/test_remote.py +++ b/tests/flytekit/integration/remote/test_remote.py @@ -162,6 +162,7 @@ def test_fetch_execute_launch_plan(register): remote = FlyteRemote(Config.auto(config_file=CONFIG), PROJECT, DOMAIN) flyte_launch_plan = remote.fetch_launch_plan(name="basic.hello_world.my_wf", version=VERSION) execution = remote.execute(flyte_launch_plan, inputs={}, wait=True) + print("Execution Error:", execution.error) assert execution.outputs["o0"] == "hello world" @@ -246,6 +247,55 @@ def test_monitor_workflow_execution(register): assert execution.outputs["o0"] == "hello world" +def test_sync_execution_sync_nodes_get_all_executions(register): + remote = FlyteRemote(Config.auto(config_file=CONFIG), PROJECT, DOMAIN) + flyte_launch_plan = remote.fetch_launch_plan(name="basic.deep_child_workflow.parent_wf", version=VERSION) + execution = remote.execute( + flyte_launch_plan, + inputs={"a": 3}, + ) + + poll_interval = datetime.timedelta(seconds=1) + time_to_give_up = datetime.datetime.now(datetime.timezone.utc) + 
datetime.timedelta(seconds=600) + + execution = remote.sync_execution(execution, sync_nodes=True) + while datetime.datetime.now(datetime.timezone.utc) < time_to_give_up: + if execution.is_done: + break + + with pytest.raises( + FlyteAssertion, match="Please wait until the execution has completed before requesting the outputs.", + ): + execution.outputs + + time.sleep(poll_interval.total_seconds()) + execution = remote.sync_execution(execution, sync_nodes=True) + + if execution.node_executions: + assert execution.node_executions["start-node"].closure.phase == 3 # SUCCEEDED + + for key in execution.node_executions: + assert execution.node_executions[key].closure.phase == 3 + + # check node execution getting correct number of nested workflows and executions + assert len(execution.node_executions) == 5 + execution_n0 = execution.node_executions["n0"] + execution_n1 = execution.node_executions["n1"] + assert len(execution_n1.workflow_executions[0].node_executions) == 4 + execution_n1_n0 = execution_n1.workflow_executions[0].node_executions["n0"] + assert len(execution_n1_n0.workflow_executions[0].node_executions) == 3 + execution_n1_n0_n0 = execution_n1_n0.workflow_executions[0].node_executions["n0"] + + # check inputs and outputs each node execution + assert execution_n0.inputs == {"a": 3} + assert execution_n0.outputs["o0"] == 6 + assert execution_n1.inputs == {"a": 6} + assert execution_n1_n0.inputs == {"a": 6} + assert execution_n1_n0_n0.inputs == {"a": 6} + assert execution_n1_n0_n0.outputs["o0"] == 12 + + + def test_fetch_execute_launch_plan_with_subworkflows(register): remote = FlyteRemote(Config.auto(config_file=CONFIG), PROJECT, DOMAIN) @@ -613,6 +663,26 @@ def test_execute_workflow_with_maptask(register): wait=True, ) assert execution.outputs["o0"] == [4, 5, 6] + assert len(execution.node_executions["n0"].task_executions) == 1 + +def test_executes_nested_workflow_dictating_interruptible(register): + remote = FlyteRemote(Config.auto(config_file=CONFIG), PROJECT, DOMAIN) + flyte_launch_plan = remote.fetch_launch_plan(name="basic.child_workflow.parent_wf", version=VERSION) + # The values we want to test for + interruptible_values = [True, False, None] + executions = [] + for creation_interruptible in interruptible_values: + execution = remote.execute(flyte_launch_plan, inputs={"a": 10}, wait=False, interruptible=creation_interruptible) + executions.append(execution) + # Wait for all executions to complete + for execution, expected_interruptible in zip(executions, interruptible_values): + execution = remote.wait(execution, timeout=300) + # Check that the parent workflow is interruptible as expected + assert execution.spec.interruptible == expected_interruptible + # Check that the child workflow is interruptible as expected + subwf_execution_id = execution.node_executions["n1"].closure.workflow_node_metadata.execution_id.name + subwf_execution = remote.fetch_execution(project=PROJECT, domain=DOMAIN, name=subwf_execution_id) + assert subwf_execution.spec.interruptible == expected_interruptible @pytest.mark.lftransfers @@ -896,6 +966,7 @@ def test_attr_access_sd(): remote = FlyteRemote(Config.auto(config_file=CONFIG), PROJECT, DOMAIN) execution = remote.fetch_execution(name=execution_id) execution = remote.wait(execution=execution, timeout=datetime.timedelta(minutes=15)) + assert execution.error is None, f"Execution failed with error: {execution.error}" assert execution.closure.phase == WorkflowExecutionPhase.SUCCEEDED, f"Execution failed with phase: {execution.closure.phase}" # Delete the 
remote file to free the space @@ -1025,3 +1096,31 @@ def test_check_secret(kubectl_secret, task): f"Execution failed with phase: {execution.closure.phase}" ) assert execution.outputs['o0'] == kubectl_secret + + +def test_execute_workflow_with_dataclass(): + """Test remote execution of a workflow with dataclass input.""" + from tests.flytekit.integration.remote.workflows.basic.dataclass_wf import wf, MyConfig + + remote = FlyteRemote(Config.auto(config_file=CONFIG), PROJECT, DOMAIN, interactive_mode_enabled=True) + + config = MyConfig(op_list=["a", "b", "c"]) + out = remote.execute( + wf, + inputs={"config": config}, + wait=True, + version=VERSION, + image_config=ImageConfig.from_images(IMAGE), + ) + assert out.outputs["o0"] == "a,b,c" + + # Test with None value + config = MyConfig(op_list=None) + out = remote.execute( + wf, + inputs={"config": config}, + wait=True, + version=VERSION + "_none", + image_config=ImageConfig.from_images(IMAGE), + ) + assert out.outputs["o0"] == "" diff --git a/tests/flytekit/integration/remote/workflows/basic/dataclass_wf.py b/tests/flytekit/integration/remote/workflows/basic/dataclass_wf.py new file mode 100644 index 0000000000..c17a8cb48a --- /dev/null +++ b/tests/flytekit/integration/remote/workflows/basic/dataclass_wf.py @@ -0,0 +1,26 @@ +from dataclasses import dataclass +from typing import Optional + +from flytekit import task, workflow +from mashumaro.mixins.json import DataClassJSONMixin + + +@dataclass +class MyConfig(DataClassJSONMixin): + op_list: Optional[list[str]] = None + + +@task +def t1(config: MyConfig) -> str: + if config.op_list: + return ",".join(config.op_list) + return "" + + +@workflow +def wf(config: MyConfig = MyConfig()) -> str: + return t1(config=config) + + +if __name__ == "__main__": + wf() diff --git a/tests/flytekit/integration/remote/workflows/basic/deep_child_workflow.py b/tests/flytekit/integration/remote/workflows/basic/deep_child_workflow.py new file mode 100644 index 0000000000..59274ffec0 --- /dev/null +++ b/tests/flytekit/integration/remote/workflows/basic/deep_child_workflow.py @@ -0,0 +1,42 @@ +from flytekit import LaunchPlan, task, workflow +from flytekit.models.common import Labels + + +@task +def double(a: int) -> int: + return a * 2 + + +@task +def add(a: int, b: int) -> int: + return a + b + + +@workflow +def my_deep_childwf(a: int = 42) -> int: + b = double(a=a) + return b + +deep_child_lp = LaunchPlan.get_or_create(my_deep_childwf, name="my_fixed_deep_child_lp", labels=Labels({"l1": "v1"})) + + +@workflow +def my_childwf(a: int = 42) -> int: + b = deep_child_lp(a=a) + c = double(a=b) + return c + + +shallow_child_lp = LaunchPlan.get_or_create(my_childwf, name="my_shallow_fixed_child_lp", labels=Labels({"l1": "v1"})) + + +@workflow +def parent_wf(a: int) -> int: + x = double(a=a) + y = shallow_child_lp(a=x) + z = add(a=x, b=y) + return z + + +if __name__ == "__main__": + print(f"Running parent_wf(a=3) {parent_wf(a=3)}") diff --git a/tests/flytekit/unit/bin/test_python_entrypoint.py b/tests/flytekit/unit/bin/test_python_entrypoint.py index 8a1709d668..8388bb77c6 100644 --- a/tests/flytekit/unit/bin/test_python_entrypoint.py +++ b/tests/flytekit/unit/bin/test_python_entrypoint.py @@ -2,6 +2,7 @@ from datetime import datetime import os import re +import sys import textwrap import time import typing @@ -519,6 +520,7 @@ def test_get_traceback_str(): assert expected_error_re.match(traceback_str) is not None +@pytest.mark.skipif(sys.platform.startswith("win"), reason="granularity of timestamp is not reliable") def 
test_get_container_error_timestamp(monkeypatch) -> None: # Set the timezone to UTC monkeypatch.setenv("TZ", "UTC") diff --git a/tests/flytekit/unit/cli/pyflyte/test_run.py b/tests/flytekit/unit/cli/pyflyte/test_run.py index 848dbbf6e1..e4ab4145dc 100644 --- a/tests/flytekit/unit/cli/pyflyte/test_run.py +++ b/tests/flytekit/unit/cli/pyflyte/test_run.py @@ -16,7 +16,7 @@ from flytekit.clis.sdk_in_container.run import ( RunLevelParams, get_entities_in_file, - run_command, + run_command, WorkflowCommand, ) from flytekit.configuration import Config, Image, ImageConfig from flytekit.core.task import task @@ -28,8 +28,7 @@ from flytekit.remote import FlyteRemote from typing import Iterator, List from flytekit.types.iterator import JSON -from flytekit import workflow - +from flytekit import workflow, LaunchPlan pytest.importorskip("pandas") @@ -205,7 +204,7 @@ def test_pyflyte_run_cli(workflow_file): "--s", json.dumps({"x": {"i": 1, "a": ["h", "e"]}}), "--t", - json.dumps({"i": [{"i":1,"a":["h","e"]}]}), + json.dumps({"i": [{"i": 1, "a": ["h", "e"]}]}), ], catch_exceptions=False, ) @@ -293,7 +292,8 @@ def test_all_types_with_yaml_input(): result = runner.invoke( pyflyte.main, - ["run", os.path.join(DIR_NAME, "workflow.py"), "my_wf", "--inputs-file", os.path.join(os.path.dirname(os.path.realpath(__file__)), "my_wf_input.yaml")], + ["run", os.path.join(DIR_NAME, "workflow.py"), "my_wf", "--inputs-file", + os.path.join(os.path.dirname(os.path.realpath(__file__)), "my_wf_input.yaml")], catch_exceptions=False, ) assert result.exit_code == 0, result.stdout @@ -301,7 +301,7 @@ def test_all_types_with_yaml_input(): def test_all_types_with_pipe_input(monkeypatch): runner = CliRunner() - input= str(json.load(open(os.path.join(os.path.dirname(os.path.realpath(__file__)), "my_wf_input.json"),"r"))) + input = str(json.load(open(os.path.join(os.path.dirname(os.path.realpath(__file__)), "my_wf_input.json"), "r"))) monkeypatch.setattr("sys.stdin", io.StringIO(input)) result = runner.invoke( pyflyte.main, @@ -321,18 +321,18 @@ def test_all_types_with_pipe_input(monkeypatch): "pipe_input, option_input", [ ( - str( - json.load( - open( - os.path.join( - os.path.dirname(os.path.realpath(__file__)), - "my_wf_input.json", - ), - "r", + str( + json.load( + open( + os.path.join( + os.path.dirname(os.path.realpath(__file__)), + "my_wf_input.json", + ), + "r", + ) ) - ) - ), - "GREEN", + ), + "GREEN", ) ], ) @@ -579,11 +579,11 @@ def test_list_default_arguments(wf_path): reason="Github macos-latest image does not have docker installed as per https://github.com/orgs/community/discussions/25777", ) def test_pyflyte_run_run( - mock_image, - image_string, - leaf_configuration_file_name, - final_image_config, - mock_image_spec_builder, + mock_image, + image_string, + leaf_configuration_file_name, + final_image_config, + mock_image_spec_builder, ): mock_image.return_value = "cr.flyte.org/flyteorg/flytekit:py3.9-latest" ImageBuildEngine.register("test", mock_image_spec_builder) @@ -597,10 +597,10 @@ def tk(): ... 
image_config = ImageConfig.validate_image(None, "", image_tuple) pp = ( - pathlib.Path(__file__).parent.parent.parent - / "configuration" - / "configs" - / leaf_configuration_file_name + pathlib.Path(__file__).parent.parent.parent + / "configuration" + / "configs" + / leaf_configuration_file_name ) obj = RunLevelParams( @@ -641,7 +641,7 @@ def jsons(): @mock.patch("flytekit.configuration.default_images.DefaultImages.default_image") def test_pyflyte_run_with_iterator_json_type( - mock_image, mock_image_spec_builder, caplog + mock_image, mock_image_spec_builder, caplog ): mock_image.return_value = "cr.flyte.org/flyteorg/flytekit:py3.9-latest" ImageBuildEngine.register( @@ -679,10 +679,10 @@ def tk_simple_iterator(x: Iterator[int] = iter([1, 2, 3])) -> Iterator[int]: image_config = ImageConfig.validate_image(None, "", image_tuple) pp = ( - pathlib.Path(__file__).parent.parent.parent - / "configuration" - / "configs" - / "no_images.yaml" + pathlib.Path(__file__).parent.parent.parent + / "configuration" + / "configs" + / "no_images.yaml" ) obj = RunLevelParams( @@ -796,9 +796,9 @@ def test_pyflyte_run_with_none(a_val, workflow_file): [ (["--env", "MY_ENV_VAR=hello"], '["MY_ENV_VAR"]', "hello"), ( - ["--env", "MY_ENV_VAR=hello", "--env", "ABC=42"], - '["MY_ENV_VAR","ABC"]', - "hello,42", + ["--env", "MY_ENV_VAR=hello", "--env", "ABC=42"], + '["MY_ENV_VAR","ABC"]', + "hello,42", ), ], ) @@ -813,16 +813,16 @@ def test_pyflyte_run_with_none(a_val, workflow_file): def test_envvar_local_execution(envs, envs_argument, expected_output, workflow_file): runner = CliRunner() args = ( - [ - "run", - ] - + envs - + [ - workflow_file, - "wf_with_env_vars", - "--env_vars", - ] - + [envs_argument] + [ + "run", + ] + + envs + + [ + workflow_file, + "wf_with_env_vars", + "--env_vars", + ] + + [envs_argument] ) result = runner.invoke( pyflyte.main, diff --git a/tests/flytekit/unit/cli/pyflyte/test_run_lps.py b/tests/flytekit/unit/cli/pyflyte/test_run_lps.py new file mode 100644 index 0000000000..30211a5534 --- /dev/null +++ b/tests/flytekit/unit/cli/pyflyte/test_run_lps.py @@ -0,0 +1,72 @@ +import click + +from flytekit import task, workflow, LaunchPlan +from flytekit.clis.sdk_in_container.run import WorkflowCommand, RunLevelParams + +import mock +import pytest + + +@task +def two_inputs(x: int, y: str) -> str: + return f"{x},{y}" + +@workflow +def two_inputs_wf(x: int, y: str) -> str: + return two_inputs(x, y) + +lp_fixed_y_default_x = LaunchPlan.get_or_create( + workflow=two_inputs_wf, + name="fixed-default-inputs", + fixed_inputs={"y": "hello"}, + default_inputs={"x": 1} +) + +lp_fixed_y = LaunchPlan.get_or_create( + workflow=two_inputs_wf, + name="fixed-y", + fixed_inputs={"y": "hello"}, +) + +lp_fixed_x = LaunchPlan.get_or_create( + workflow=two_inputs_wf, + name="fixed-x", + fixed_inputs={"x": 1}, +) + +lp_fixed_all = LaunchPlan.get_or_create( + workflow=two_inputs_wf, + name="fixed-all", + fixed_inputs={"x": 1, "y": "test"}, +) + +lp_default_x = LaunchPlan.get_or_create( + name="default-inputs", + workflow=two_inputs_wf, + default_inputs={"x": 1} +) + +lp_simple = LaunchPlan.get_or_create( + workflow=two_inputs_wf, + name="no-fixed-default", +) + +@pytest.mark.parametrize("lp_execs", [ + (lp_fixed_y_default_x, {"x": 1}), + (lp_fixed_y, {"x": None}), + (lp_fixed_x, {"y": None}), + (lp_fixed_all, {}), + (lp_default_x, {"y": None, "x": 1}), + (lp_simple, {"x": None, "y": None}), +]) +def test_workflowcommand_create_command(lp_execs): + cmd = WorkflowCommand("testfile.py") + rp = RunLevelParams() + ctx = 
click.Context(cmd, obj=rp) + lp, exp_opts = lp_execs + opts = cmd._create_command(ctx, "test_entity", rp, lp, "launch plan").params + for o in opts: + if "input" in o.name: + continue + assert o.name in exp_opts + assert o.default == exp_opts[o.name] diff --git a/tests/flytekit/unit/core/test_array_node_map_task.py b/tests/flytekit/unit/core/test_array_node_map_task.py index ed1fc7fdd0..b911678a9a 100644 --- a/tests/flytekit/unit/core/test_array_node_map_task.py +++ b/tests/flytekit/unit/core/test_array_node_map_task.py @@ -443,6 +443,26 @@ def wf(x: typing.List[int]): assert task_spec.template.extended_resources.gpu_accelerator.device == "test_gpu" +def test_serialization_extended_resources_shared_memory(serialization_settings): + @task( + shared_memory="2Gi" + ) + def t1(a: int) -> int: + return a + 1 + + arraynode_maptask = map_task(t1) + + @workflow + def wf(x: typing.List[int]): + return arraynode_maptask(a=x) + + od = OrderedDict() + get_serializable(od, serialization_settings, wf) + task_spec = od[arraynode_maptask] + + assert task_spec.template.extended_resources.shared_memory.size_limit == "2Gi" + + def test_supported_node_type(): @task def test_task(): diff --git a/tests/flytekit/unit/core/test_cache.py b/tests/flytekit/unit/core/test_cache.py new file mode 100644 index 0000000000..3c0fae9dfa --- /dev/null +++ b/tests/flytekit/unit/core/test_cache.py @@ -0,0 +1,183 @@ +from typing import OrderedDict +from mock import mock +import pytest +from flytekit.configuration import Image, ImageConfig, SerializationSettings +from flytekit.core.cache import Cache, CachePolicy, VersionParameters +from flytekit.core.task import task +from flytekit.tools.translator import get_serializable_task + + +class SaltCachePolicy(CachePolicy): + def get_version(self, salt: str, params: VersionParameters) -> str: + return salt + + +class ExceptionCachePolicy(CachePolicy): + def get_version(self, salt: str, params: VersionParameters) -> str: + raise Exception("This is an exception") + + +@pytest.fixture +def default_serialization_settings(): + default_image = Image(name="default", fqn="full/name", tag="some-tag") + default_image_config = ImageConfig(default_image=default_image) + return SerializationSettings( + project="p", domain="d", version="v", image_config=default_image_config + ) + +def test_task_arguments(default_serialization_settings): + @task(cache=Cache(policies=SaltCachePolicy())) + def t1(a: int) -> int: + return a + + serialized_t1 = get_serializable_task(OrderedDict(), default_serialization_settings, t1) + assert serialized_t1.template.metadata.discoverable == True + assert serialized_t1.template.metadata.discovery_version == "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855" + + @task(cache=Cache(version="a-version")) + def t2(a: int) -> int: + return a + + serialized_t2 = get_serializable_task(OrderedDict(), default_serialization_settings, t2) + assert serialized_t2.template.metadata.discoverable == True + assert serialized_t2.template.metadata.discovery_version == "a-version" + + @task(cache=Cache(version="a-version", serialize=True)) + def t3(a: int) -> int: + return a + + serialized_t3 = get_serializable_task(OrderedDict(), default_serialization_settings, t3) + assert serialized_t3.template.metadata.discoverable == True + assert serialized_t3.template.metadata.discovery_version == "a-version" + assert serialized_t3.template.metadata.cache_serializable == True + + @task(cache=Cache(version="a-version", ignored_inputs=("a",))) + def t4(a: int) -> int: + return a + + 
serialized_t4 = get_serializable_task(OrderedDict(), default_serialization_settings, t4) + assert serialized_t4.template.metadata.discoverable == True + assert serialized_t4.template.metadata.discovery_version == "a-version" + assert serialized_t4.template.metadata.cache_ignore_input_vars == ("a",) + + @task(cache=Cache(version="version-overrides-policies", policies=SaltCachePolicy())) + def t5(a: int) -> int: + return a + + serialized_t5 = get_serializable_task(OrderedDict(), default_serialization_settings, t5) + assert serialized_t5.template.metadata.discoverable == True + assert serialized_t5.template.metadata.discovery_version == "version-overrides-policies" + + +def test_task_arguments_deprecated(default_serialization_settings): + with pytest.raises(ValueError, match="cache_serialize, cache_version, and cache_ignore_input_vars are deprecated. Please use Cache object"): + @task(cache=Cache(version="a-version"), cache_version="a-conflicting-version") + def t1_fails(a: int) -> int: + return a + + # A more realistic example where someone might set the version in the Cache object but sets the cache_serialize + # using the deprecated cache_serialize argument + with pytest.raises(ValueError, match="cache_serialize, cache_version, and cache_ignore_input_vars are deprecated. Please use Cache object"): + @task(cache=Cache(version="a-version"), cache_serialize=True) + def t2_fails(a: int) -> int: + return a + + with pytest.raises(ValueError, match="cache_serialize, cache_version, and cache_ignore_input_vars are deprecated. Please use Cache object"): + @task(cache=Cache(version="a-version"), cache_ignore_input_vars=("a",)) + def t3_fails(a: int) -> int: + return a + + with pytest.raises(ValueError, match="cache_serialize, cache_version, and cache_ignore_input_vars are deprecated. Please use Cache object"): + @task(cache=Cache(version="a-version"), cache_serialize=True, cache_version="b-version") + def t5_fails(a: int) -> int: + return a + + with pytest.raises(ValueError, match="cache_serialize, cache_version, and cache_ignore_input_vars are deprecated. Please use Cache object"): + @task(cache=Cache(version="a-version"), cache_serialize=True, cache_ignore_input_vars=("a",)) + def t6_fails(a: int) -> int: + return a + + with pytest.raises(ValueError, match="cache_serialize, cache_version, and cache_ignore_input_vars are deprecated. 
Please use Cache object"): + @task(cache=Cache(version="a-version"), cache_serialize=True, cache_version="b-version", cache_ignore_input_vars=("a",)) + def t7_fails(a: int) -> int: + return a + + +def test_basic_salt_cache_policy(default_serialization_settings): + @task + def t_notcached(a: int) -> int: + return a + + serialized_t_notcached = get_serializable_task(OrderedDict(), default_serialization_settings, t_notcached) + assert serialized_t_notcached.template.metadata.discoverable == False + + @task(cache=Cache(version="a-version")) + def t_cached_explicit_version(a: int) -> int: + return a + + serialized_t_cached_explicit_version = get_serializable_task(OrderedDict(), default_serialization_settings, t_cached_explicit_version) + assert serialized_t_cached_explicit_version.template.metadata.discoverable == True + assert serialized_t_cached_explicit_version.template.metadata.discovery_version == "a-version" + + @task(cache=Cache(salt="a-sprinkle-of-salt", policies=SaltCachePolicy())) + def t_cached(a: int) -> int: + return a + 1 + + serialized_t_cached = get_serializable_task(OrderedDict(), default_serialization_settings, t_cached) + assert serialized_t_cached.template.metadata.discoverable == True + assert serialized_t_cached.template.metadata.discovery_version == "348b4b8c52d8868e0c202ce4d26d59906c13716197b611a0a7a215074159df79" + + +@mock.patch("flytekit.configuration.plugin.FlytekitPlugin.get_default_cache_policies") +def test_set_default_policies(mock_get_default_cache_policies, default_serialization_settings): + # Enable SaltCachePolicy as the default cache policy + mock_get_default_cache_policies.return_value = [SaltCachePolicy()] + + @task(cache=True) + def t1(a: int) -> int: + return a + + serialized_t1 = get_serializable_task(OrderedDict(), default_serialization_settings, t1) + assert serialized_t1.template.metadata.discoverable == True + assert serialized_t1.template.metadata.discovery_version == "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855" + + @task(cache=Cache()) + def t2(a: int) -> int: + return a + + serialized_t2 = get_serializable_task(OrderedDict(), default_serialization_settings, t2) + assert serialized_t2.template.metadata.discoverable == True + assert serialized_t2.template.metadata.discovery_version == "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855" + + # Confirm that the default versions match + assert serialized_t1.template.metadata.discovery_version == serialized_t2.template.metadata.discovery_version + + # Reset the default policies + mock_get_default_cache_policies.return_value = [] + + with pytest.raises(ValueError, match="If version is not defined then at least one cache policy needs to be set"): + @task(cache=True) + def t3_fails(a: int) -> int: + return a + + with pytest.raises(ValueError, match="If version is not defined then at least one cache policy needs to be set"): + @task(cache=Cache()) + def t4_fails(a: int) -> int: + return a + 1 + + @task(cache=Cache(version="a-version")) + def t_cached_explicit_version(a: int) -> int: + return a + + serialized_t_cached_explicit_version = get_serializable_task(OrderedDict(), default_serialization_settings, t_cached_explicit_version) + assert serialized_t_cached_explicit_version.template.metadata.discoverable == True + assert serialized_t_cached_explicit_version.template.metadata.discovery_version == "a-version" + + +def test_cache_policy_exception(default_serialization_settings): + # Set the address of the ExceptionCachePolicy in the error message so that the test is robust 
to changes in the + # address of the ExceptionCachePolicy class + with pytest.raises(ValueError, match="Failed to generate version for cache policy"): + @task(cache=Cache(policies=ExceptionCachePolicy())) + def t_cached(a: int) -> int: + return a diff --git a/tests/flytekit/unit/core/test_node_creation.py b/tests/flytekit/unit/core/test_node_creation.py index 5b911c0052..29fa758801 100644 --- a/tests/flytekit/unit/core/test_node_creation.py +++ b/tests/flytekit/unit/core/test_node_creation.py @@ -531,6 +531,29 @@ def my_wf() -> str: assert not accelerator.HasField("unpartitioned") +def test_override_shared_memory(): + @task(shared_memory=True) + def bar() -> str: + return "hello" + + @workflow + def my_wf() -> str: + return bar().with_overrides(shared_memory="128Mi") + + serialization_settings = flytekit.configuration.SerializationSettings( + project="test_proj", + domain="test_domain", + version="abc", + image_config=ImageConfig(Image(name="name", fqn="image", tag="name")), + env={}, + ) + wf_spec = get_serializable(OrderedDict(), serialization_settings, my_wf) + assert len(wf_spec.template.nodes) == 1 + assert wf_spec.template.nodes[0].task_node.overrides is not None + assert wf_spec.template.nodes[0].task_node.overrides.extended_resources is not None + shared_memory = wf_spec.template.nodes[0].task_node.overrides.extended_resources.shared_memory + + def test_cache_override_values(): @task def t1(a: str) -> str: diff --git a/tests/flytekit/unit/core/test_resources.py b/tests/flytekit/unit/core/test_resources.py index 115605b055..d5f8b9be02 100644 --- a/tests/flytekit/unit/core/test_resources.py +++ b/tests/flytekit/unit/core/test_resources.py @@ -8,7 +8,9 @@ from flytekit.core.resources import ( pod_spec_from_resources, convert_resources_to_resource_model, + construct_extended_resources, ) +from flytekit.extras.accelerators import T4 _ResourceName = _task_models.Resources.ResourceName @@ -155,3 +157,18 @@ def test_pod_spec_from_resources_requests_set(): ) pod_spec = pod_spec_from_resources(primary_container_name=primary_container_name, requests=requests, limits=limits) assert expected_pod_spec == pod_spec + + +@pytest.mark.parametrize("shared_memory", [None, False]) +def test_construct_extended_resources_shared_memory_none(shared_memory): + resources = construct_extended_resources(shared_memory=shared_memory) + assert resources is None + + +@pytest.mark.parametrize("shared_memory, expected_size_limit", [ + ("2Gi", "2Gi"), + (True, ""), +]) +def test_construct_extended_resources_shared_memory(shared_memory, expected_size_limit): + resources = construct_extended_resources(shared_memory=shared_memory) + assert resources.shared_memory.size_limit == expected_size_limit diff --git a/tests/flytekit/unit/core/test_type_engine.py b/tests/flytekit/unit/core/test_type_engine.py index b9efc83bae..0dfb0ddbb3 100644 --- a/tests/flytekit/unit/core/test_type_engine.py +++ b/tests/flytekit/unit/core/test_type_engine.py @@ -9,7 +9,7 @@ from dataclasses import asdict, dataclass, field from datetime import timedelta from enum import Enum, auto -from typing import List, Optional, Type +from typing import List, Optional, Type, Dict import mock import msgpack @@ -34,6 +34,10 @@ from flytekit.core.data_persistence import flyte_tmp_dir from flytekit.core.hash import HashMethod from flytekit.core.type_engine import ( + IntTransformer, + FloatTransformer, + BoolTransformer, + StrTransformer, DataclassTransformer, DictTransformer, EnumTransformer, @@ -48,9 +52,9 @@ convert_mashumaro_json_schema_to_python_class, 
     dataclass_from_dict,
     get_underlying_type,
-    is_annotated,
     IntTransformer,
+    is_annotated,
+    strict_type_hint_matching,
 )
-from flytekit.core.type_engine import *
 from flytekit.exceptions import user as user_exceptions
 from flytekit.models import types as model_types
 from flytekit.models.annotation import TypeAnnotation
@@ -2864,21 +2868,22 @@ def test_get_underlying_type(t, expected):
 
 
 @pytest.mark.parametrize(
-    "t,expected",
+    "t,expected,allow_pickle",
     [
-        (None, (None, None)),
-        (typing.Dict, ()),
-        (typing.Dict[str, str], (str, str)),
+        (None, (None, None), False),
+        (typing.Dict, (), False),
+        (typing.Dict[str, str], (str, str), False),
         (
-            Annotated[typing.Dict[str, str], kwtypes(allow_pickle=True)],
-            (str, str),
+            Annotated[typing.Dict[str, str], kwtypes(allow_pickle=True)],
+            (str, str),
+            True,
         ),
-        (typing.Dict[Annotated[str, "a-tag"], int], (Annotated[str, "a-tag"], int)),
+        (typing.Dict[Annotated[str, "a-tag"], int], (Annotated[str, "a-tag"], int), False),
     ],
 )
-def test_dict_get(t, expected):
+def test_dict_get(t, expected, allow_pickle):
     assert DictTransformer.extract_types(t) == expected
-
+    assert DictTransformer.is_pickle(t) == allow_pickle
 
 def test_DataclassTransformer_get_literal_type():
     @dataclass
@@ -3777,3 +3782,72 @@ class RegularDC:
     assert TypeEngine.get_transformer(RegularDC) == TypeEngine._DATACLASS_TRANSFORMER
 
     del TypeEngine._REGISTRY[ParentDC]
+
+
+def test_strict_type_matching():
+    # should correctly return the more specific transformer
+    class MyInt:
+        def __init__(self, x: int):
+            self.val = x
+
+        def __eq__(self, other):
+            if not isinstance(other, MyInt):
+                return False
+            return other.val == self.val
+
+    lt = LiteralType(simple=SimpleType.INTEGER)
+    TypeEngine.register(
+        SimpleTransformer(
+            "MyInt",
+            MyInt,
+            lt,
+            lambda x: Literal(scalar=Scalar(primitive=Primitive(integer=x.val))),
+            lambda x: MyInt(x.scalar.primitive.integer),
+        )
+    )
+
+    pt_guess = IntTransformer.guess_python_type(lt)
+    assert pt_guess is int
+    pt_better_guess = strict_type_hint_matching(MyInt(3), lt)
+    assert pt_better_guess is MyInt
+
+    del TypeEngine._REGISTRY[MyInt]
+
+
+def test_strict_type_matching_error():
+    xs: typing.List[float] = [0.1, 0.2, 0.3, 0.4, -99999.7]
+    lt = TypeEngine.to_literal_type(typing.List[float])
+    with pytest.raises(ValueError):
+        strict_type_hint_matching(xs, lt)
+
+
+@pytest.mark.asyncio
+async def test_dict_transformer_annotated_type():
+    ctx = FlyteContext.current_context()
+
+    # Test case 1: Regular Dict type
+    regular_dict = {"a": 1, "b": 2}
+    regular_dict_type = Dict[str, int]
+    expected_type = TypeEngine.to_literal_type(regular_dict_type)
+
+    # This should work fine
+    literal1 = await TypeEngine.async_to_literal(ctx, regular_dict, regular_dict_type, expected_type)
+    assert literal1.map.literals["a"].scalar.primitive.integer == 1
+    assert literal1.map.literals["b"].scalar.primitive.integer == 2
+
+    # Test case 2: Annotated Dict type
+    annotated_dict = {"x": 10, "y": 20}
+    annotated_dict_type = Annotated[Dict[str, int], "some_metadata"]
+    expected_type = TypeEngine.to_literal_type(annotated_dict_type)
+
+    literal2 = await TypeEngine.async_to_literal(ctx, annotated_dict, annotated_dict_type, expected_type)
+    assert literal2.map.literals["x"].scalar.primitive.integer == 10
+    assert literal2.map.literals["y"].scalar.primitive.integer == 20
+
+    # Test case 3: Nested Annotated Dict type
+    nested_dict = {"outer": {"inner": 42}}
+    nested_dict_type = Dict[str, Annotated[Dict[str, int], "inner_metadata"]]
+    expected_type = TypeEngine.to_literal_type(nested_dict_type)
+
+    literal3 = await TypeEngine.async_to_literal(ctx, nested_dict, nested_dict_type, expected_type)
+    assert literal3.map.literals["outer"].map.literals["inner"].scalar.primitive.integer == 42
diff --git a/tests/flytekit/unit/core/test_type_match_checking.py b/tests/flytekit/unit/core/test_type_match_checking.py
new file mode 100644
index 0000000000..a88bcf6a93
--- /dev/null
+++ b/tests/flytekit/unit/core/test_type_match_checking.py
@@ -0,0 +1,87 @@
+from flytekit.models.core.types import BlobType, EnumType
+from flytekit.models.types import LiteralType, StructuredDatasetType, UnionType, SimpleType
+from flytekit.core.type_match_checking import literal_types_match
+
+
+def test_exact_match():
+    lt = LiteralType(simple=SimpleType.STRING)
+    assert literal_types_match(lt, lt) is True
+
+    lt2 = LiteralType(simple=SimpleType.FLOAT)
+    assert literal_types_match(lt, lt2) is False
+
+
+def test_collection_type_match():
+    lt1 = LiteralType(collection_type=LiteralType(SimpleType.STRING))
+    lt2 = LiteralType(collection_type=LiteralType(SimpleType.STRING))
+    assert literal_types_match(lt1, lt2) is True
+
+
+def test_collection_type_mismatch():
+    lt1 = LiteralType(collection_type=LiteralType(SimpleType.STRING))
+    lt2 = LiteralType(collection_type=LiteralType(SimpleType.INTEGER))
+    assert literal_types_match(lt1, lt2) is False
+
+
+def test_blob_type_match():
+    blob1 = LiteralType(blob=BlobType(format="csv", dimensionality=1))
+    blob2 = LiteralType(blob=BlobType(format="csv", dimensionality=1))
+    assert literal_types_match(blob1, blob2) is True
+
+    from flytekit.types.pickle.pickle import FlytePickleTransformer
+    blob1 = LiteralType(blob=BlobType(format=FlytePickleTransformer.PYTHON_PICKLE_FORMAT, dimensionality=1))
+    blob2 = LiteralType(blob=BlobType(format="", dimensionality=1))
+    assert literal_types_match(blob1, blob2) is False
+
+
+def test_blob_type_mismatch():
+    blob1 = LiteralType(blob=BlobType(format="csv", dimensionality=1))
+    blob2 = LiteralType(blob=BlobType(format="json", dimensionality=1))
+    assert literal_types_match(blob1, blob2) is False
+
+
+def test_enum_type_match():
+    enum1 = LiteralType(enum_type=EnumType(values=["A", "B"]))
+    enum2 = LiteralType(enum_type=EnumType(values=["B", "A"]))
+    assert literal_types_match(enum1, enum2) is True
+
+
+def test_enum_type_mismatch():
+    enum1 = LiteralType(enum_type=EnumType(values=["A", "B"]))
+    enum2 = LiteralType(enum_type=EnumType(values=["A", "C"]))
+    assert literal_types_match(enum1, enum2) is False
+
+
+def test_structured_dataset_match():
+    col1 = StructuredDatasetType.DatasetColumn(name="col1", literal_type=LiteralType(simple=SimpleType.STRING))
+    col2 = StructuredDatasetType.DatasetColumn(name="col2", literal_type=LiteralType(simple=SimpleType.STRUCT))
+
+    dataset1 = LiteralType(structured_dataset_type=StructuredDatasetType(format="parquet", columns=[]))
+    dataset2 = LiteralType(structured_dataset_type=StructuredDatasetType(format="parquet", columns=[]))
+    assert literal_types_match(dataset1, dataset2) is True
+
+    dataset1 = LiteralType(structured_dataset_type=StructuredDatasetType(format="parquet", columns=[col1, col2]))
+    dataset2 = LiteralType(structured_dataset_type=StructuredDatasetType(format="parquet", columns=[]))
+    assert literal_types_match(dataset1, dataset2) is False
+
+    dataset1 = LiteralType(structured_dataset_type=StructuredDatasetType(format="parquet", columns=[col1, col2]))
+    dataset2 = LiteralType(structured_dataset_type=StructuredDatasetType(format="parquet", columns=[col1, col2]))
+    assert literal_types_match(dataset1, dataset2) is True
+
+
+def test_structured_dataset_mismatch():
+    dataset1 = LiteralType(structured_dataset_type=StructuredDatasetType(format="parquet", columns=[]))
+    dataset2 = LiteralType(structured_dataset_type=StructuredDatasetType(format="csv", columns=[]))
+    assert literal_types_match(dataset1, dataset2) is False
+
+
+def test_union_type_match():
+    union1 = LiteralType(union_type=UnionType(variants=[LiteralType(SimpleType.STRING), LiteralType(SimpleType.INTEGER)]))
+    union2 = LiteralType(union_type=UnionType(variants=[LiteralType(SimpleType.INTEGER), LiteralType(SimpleType.STRING)]))
+    assert literal_types_match(union1, union2) is True
+
+
+def test_union_type_mismatch():
+    union1 = LiteralType(union_type=UnionType(variants=[LiteralType(SimpleType.STRING), LiteralType(SimpleType.INTEGER)]))
+    union2 = LiteralType(union_type=UnionType(variants=[LiteralType(SimpleType.STRING), LiteralType(SimpleType.BOOLEAN)]))
+    assert literal_types_match(union1, union2) is False
diff --git a/tests/flytekit/unit/core/test_worker_queue.py b/tests/flytekit/unit/core/test_worker_queue.py
index 0a934fc20a..882054693a 100644
--- a/tests/flytekit/unit/core/test_worker_queue.py
+++ b/tests/flytekit/unit/core/test_worker_queue.py
@@ -4,10 +4,11 @@
 from flytekit.core.task import task
 from flytekit.remote.remote import FlyteRemote
 from flytekit.core.worker_queue import Controller, WorkItem, ItemStatus, Update
-from flytekit.configuration import ImageConfig, LocalConfig, SerializationSettings
+from flytekit.configuration import ImageConfig, LocalConfig, SerializationSettings, Image
 from flytekit.utils.asyn import loop_manager
 from flytekit.models.execution import ExecutionSpec, ExecutionClosure, ExecutionMetadata, NotificationList, Execution, AbortMetadata
 from flytekit.models.core import identifier
+from flytekit.remote.executions import FlyteWorkflowExecution
 from flytekit.models import common as common_models
 from flytekit.models.core import execution
 from flytekit.exceptions.eager import EagerException
@@ -249,3 +250,46 @@ def t1() -> str:
     wi2 = WorkItem(entity=t1, wf_exec=fwex, input_kwargs={})
     wi2.uuid = wi1.uuid
     assert wi1 == wi2
+
+
+default_img = Image(name="default", fqn="test", tag="tag")
+serialization_settings = SerializationSettings(
+    project="project",
+    domain="domain",
+    version="version",
+    env=None,
+    image_config=ImageConfig(default_image=default_img, images=[default_img]),
+)
+
+
+@pytest.mark.parametrize("phase,expected_update_status", [
+    (execution.WorkflowExecutionPhase.SUCCEEDED, ItemStatus.SUCCESS),
+    (execution.WorkflowExecutionPhase.FAILED, ItemStatus.FAILED),
+    (execution.WorkflowExecutionPhase.ABORTED, ItemStatus.FAILED),
+    (execution.WorkflowExecutionPhase.TIMED_OUT, ItemStatus.FAILED),
+])
+def test_reconcile(phase, expected_update_status):
+    mock_remote = mock.MagicMock()
+    wf_exec = FlyteWorkflowExecution(
+        id=identifier.WorkflowExecutionIdentifier("project", "domain", "exec-name"),
+        spec=ExecutionSpec(
+            identifier.Identifier(identifier.ResourceType.LAUNCH_PLAN, "project", "domain", "name", "version"),
+            ExecutionMetadata(ExecutionMetadata.ExecutionMode.MANUAL, "tester", 1),
+        ),
+        closure=ExecutionClosure(
+            phase=phase,
+            started_at=datetime.datetime(year=2024, month=1, day=2, tzinfo=datetime.timezone.utc),
+            duration=datetime.timedelta(seconds=10),
+        )
+    )
+    mock_remote.sync_execution.return_value = wf_exec
+
+    @task
+    def t1():
+        ...
+
+    wi = WorkItem(entity=t1, wf_exec=wf_exec, input_kwargs={})
+    c = Controller(mock_remote, serialization_settings, tag="exec-id", root_tag="exec-id", exec_prefix="e-unit-test")
+    u = Update(wi, 0)
+    c.reconcile_one(u)
+    assert u.status == expected_update_status
diff --git a/tests/flytekit/unit/interaction/test_click_types.py b/tests/flytekit/unit/interaction/test_click_types.py
index 2ab0d8721f..33e0bbf78f 100644
--- a/tests/flytekit/unit/interaction/test_click_types.py
+++ b/tests/flytekit/unit/interaction/test_click_types.py
@@ -6,6 +6,7 @@
 import typing
 from datetime import datetime, timedelta
 from enum import Enum
+from typing import Optional
 
 import click
 import mock
@@ -367,9 +368,9 @@ def test_dataclass_with_default_none():
     @dataclass
     class Datum:
         x: int
-        y: str = None
-        z: typing.Dict[int, str] = None
-        w: typing.List[int] = None
+        y: Optional[str] = None
+        z: Optional[typing.Dict[int, str]] = None
+        w: Optional[typing.List[int]] = None
 
     t = JsonParamType(Datum)
     value = '{ "x": 1 }'
diff --git a/tests/flytekit/unit/interaction/test_string_literals.py b/tests/flytekit/unit/interaction/test_string_literals.py
index 06666feae8..1be8a154b6 100644
--- a/tests/flytekit/unit/interaction/test_string_literals.py
+++ b/tests/flytekit/unit/interaction/test_string_literals.py
@@ -86,6 +86,9 @@ def test_scalar_to_string():
     )
     assert scalar_to_string(scalar) == 1
 
+    scalar = Scalar(binary=Binary(b'\x82\xa7compact\xc3\xa6schema\x00', "msgpack"))
+    assert scalar_to_string(scalar) == '{"compact": true, "schema": 0}'
+
 
 def test_literal_string_repr():
     lit = Literal(scalar=Scalar(primitive=Primitive(integer=1)))
diff --git a/tests/flytekit/unit/models/core/test_workflow.py b/tests/flytekit/unit/models/core/test_workflow.py
index cd36381ea0..24b6704528 100644
--- a/tests/flytekit/unit/models/core/test_workflow.py
+++ b/tests/flytekit/unit/models/core/test_workflow.py
@@ -2,6 +2,7 @@
 
 from flyteidl.core import tasks_pb2
 
+from flytekit import PodTemplate
 from flytekit.extras.accelerators import T4
 from flytekit.models import interface as _interface
 from flytekit.models import literals as _literals
@@ -315,6 +316,7 @@ def test_task_node_overrides():
 
 
 def test_task_node_with_overrides():
+    # without pod template
     task_node = _workflow.TaskNode(
         reference_id=_generic_id,
         overrides=_workflow.TaskNodeOverrides(
@@ -332,3 +334,35 @@ def test_task_node_with_overrides():
 
     obj = _workflow.TaskNode.from_flyte_idl(task_node.to_flyte_idl())
     assert task_node == obj
+    assert obj.overrides.pod_template is None
+
+    # with pod template
+    task_node = _workflow.TaskNode(
+        reference_id=_generic_id,
+        overrides=_workflow.TaskNodeOverrides(
+            Resources(
+                requests=[Resources.ResourceEntry(Resources.ResourceName.CPU, "1")],
+                limits=[Resources.ResourceEntry(Resources.ResourceName.CPU, "2")],
+            ),
+            tasks_pb2.ExtendedResources(gpu_accelerator=T4.to_flyte_idl()),
+            "",
+            PodTemplate(
+                primary_container_name="primary1",
+                labels={"lKeyA": "lValA", "lKeyB": "lValB"},
+                annotations={"aKeyA": "aValA", "aKeyB": "aValB"},
+                pod_spec={
+                    'containers': [
+                        {
+                            'name': 'primary1',
+                            'image': "random:image",
+                            'env': [
+                                {'name': 'eKeyC', 'value': 'eValC'},
+                            ],
+                        }
+                    ]},
+            ),
+        ),
+    )
+
+    obj = _workflow.TaskNode.from_flyte_idl(task_node.to_flyte_idl())
+    assert obj.overrides.pod_template is not None
diff --git a/tests/flytekit/unit/models/test_execution.py b/tests/flytekit/unit/models/test_execution.py
index fec2b5cfbb..8e1dfa749a 100644
--- a/tests/flytekit/unit/models/test_execution.py
+++ b/tests/flytekit/unit/models/test_execution.py
@@ -166,6 +166,7 @@ def test_execution_spec(literal_value_pair):
         ),
         raw_output_data_config=_common_models.RawOutputDataConfig(output_location_prefix="raw_output"),
         max_parallelism=100,
+        interruptible=True
     )
     assert obj.launch_plan.resource_type == _identifier.ResourceType.LAUNCH_PLAN
     assert obj.launch_plan.domain == "domain"
@@ -183,6 +184,7 @@ def test_execution_spec(literal_value_pair):
     ]
     assert obj.disable_all is None
     assert obj.max_parallelism == 100
+    assert obj.interruptible == True
     assert obj.raw_output_data_config.output_location_prefix == "raw_output"
 
     obj2 = _execution.ExecutionSpec.from_flyte_idl(obj.to_flyte_idl())
@@ -203,6 +205,7 @@ def test_execution_spec(literal_value_pair):
     ]
     assert obj2.disable_all is None
     assert obj2.max_parallelism == 100
+    assert obj2.interruptible == True
     assert obj2.raw_output_data_config.output_location_prefix == "raw_output"
 
     obj = _execution.ExecutionSpec(
@@ -220,6 +223,7 @@ def test_execution_spec(literal_value_pair):
     assert obj.metadata.principal == "tester"
     assert obj.notifications is None
     assert obj.disable_all is True
+    assert obj.interruptible is None
 
     obj2 = _execution.ExecutionSpec.from_flyte_idl(obj.to_flyte_idl())
     assert obj == obj2
@@ -233,6 +237,7 @@ def test_execution_spec(literal_value_pair):
     assert obj2.metadata.principal == "tester"
     assert obj2.notifications is None
     assert obj2.disable_all is True
+    assert obj2.interruptible is None
 
 
 def test_workflow_execution_data_response():
diff --git a/tests/flytekit/unit/models/test_tasks.py b/tests/flytekit/unit/models/test_tasks.py
index b9685736b7..048cfb1db9 100644
--- a/tests/flytekit/unit/models/test_tasks.py
+++ b/tests/flytekit/unit/models/test_tasks.py
@@ -5,6 +5,7 @@
 from flyteidl.core.tasks_pb2 import ExtendedResources, TaskMetadata
 from google.protobuf import text_format
 
+from flytekit.core.resources import construct_extended_resources
 import flytekit.models.interface as interface_models
 import flytekit.models.literals as literal_models
 from flytekit import Description, Documentation, SourceCode
@@ -110,7 +111,7 @@ def test_task_template(in_tuple):
             {"d": "e"},
         ),
         config={"a": "b"},
-        extended_resources=ExtendedResources(gpu_accelerator=T4.to_flyte_idl()),
+        extended_resources=construct_extended_resources(accelerator=T4, shared_memory="2Gi"),
     )
     assert obj.id.resource_type == identifier.ResourceType.TASK
     assert obj.id.project == "project"
@@ -130,6 +131,7 @@ def test_task_template(in_tuple):
     assert obj.extended_resources.gpu_accelerator.device == "nvidia-tesla-t4"
     assert not obj.extended_resources.gpu_accelerator.HasField("unpartitioned")
     assert not obj.extended_resources.gpu_accelerator.HasField("partition_size")
+    assert obj.extended_resources.shared_memory.size_limit == "2Gi"
 
 
 def test_task_spec():
diff --git a/tests/flytekit/unit/types/structured_dataset/test_structured_dataset.py b/tests/flytekit/unit/types/structured_dataset/test_structured_dataset.py
index ee535697a3..770d757eba 100644
--- a/tests/flytekit/unit/types/structured_dataset/test_structured_dataset.py
+++ b/tests/flytekit/unit/types/structured_dataset/test_structured_dataset.py
@@ -19,7 +19,7 @@
 from flytekit.core.workflow import workflow
 from flytekit.lazy_import.lazy_module import is_imported
 from flytekit.models import literals
-from flytekit.models.literals import StructuredDatasetMetadata
+from flytekit.models.literals import StructuredDatasetMetadata, Literal
 from flytekit.models.types import LiteralType, SchemaType, SimpleType, StructuredDatasetType
 from flytekit.tools.translator import get_serializable
 from flytekit.types.structured.structured_dataset import (
@@ -713,3 +713,41 @@ def mock_resolve_remote_path(flyte_uri: str) -> typing.Optional[str]:
 
     lit = sdte.encode(ctx, sd, df_type=pd.DataFrame, protocol="bq", format="parquet", structured_literal_type=lt)
     assert lit.scalar.structured_dataset.uri == "bq://blah/blah/blah"
+
+def test_structured_dataset_pickleable():
+    import pickle
+
+    upstream_output = Literal(
+        scalar=literals.Scalar(
+            structured_dataset=StructuredDataset(
+                dataframe=pd.DataFrame({"a": [1, 2], "b": [3, 4]}),
+                uri="bq://test_uri",
+                metadata=StructuredDatasetMetadata(
+                    structured_dataset_type=StructuredDatasetType(
+                        columns=[
+                            StructuredDatasetType.DatasetColumn(
+                                name="a",
+                                literal_type=LiteralType(simple=SimpleType.INTEGER)
+                            ),
+                            StructuredDatasetType.DatasetColumn(
+                                name="b",
+                                literal_type=LiteralType(simple=SimpleType.INTEGER)
+                            )
+                        ],
+                        format="parquet"
+                    )
+                )
+            )
+        )
+    )
+
+    downstream_input = TypeEngine.to_python_value(
+        FlyteContextManager.current_context(),
+        upstream_output,
+        StructuredDataset
+    )
+
+    pickled_input = pickle.dumps(downstream_input)
+    unpickled_input = pickle.loads(pickled_input)
+
+    assert downstream_input == unpickled_input
diff --git a/tests/flytekit/unit/utils/test_rate_limiter.py b/tests/flytekit/unit/utils/test_rate_limiter.py
new file mode 100644
index 0000000000..23bece6415
--- /dev/null
+++ b/tests/flytekit/unit/utils/test_rate_limiter.py
@@ -0,0 +1,55 @@
+import pytest
+import sys
+import timeit
+import asyncio
+
+from datetime import timedelta
+from flytekit.utils.rate_limiter import RateLimiter
+
+
+async def launch_requests(rate_limiter: RateLimiter, total: int):
+    tasks = [asyncio.create_task(rate_limiter.acquire()) for _ in range(total)]
+    await asyncio.gather(*tasks)
+
+
+async def helper_for_async(rpm: int, total: int):
+    rate_limiter = RateLimiter(rpm=rpm)
+    rate_limiter.delay = timedelta(seconds=1)
+    await launch_requests(rate_limiter, total)
+
+
+def runner_for_async(rpm: int, total: int):
+    loop = asyncio.get_event_loop()
+    return loop.run_until_complete(helper_for_async(rpm, total))
+
+
+@pytest.mark.asyncio
+def test_rate_limiter():
+    elapsed_time = timeit.timeit(lambda: runner_for_async(2, 2), number=1)
+    elapsed_time_more = timeit.timeit(lambda: runner_for_async(2, 6), number=1)
+    assert elapsed_time < 0.25
+    assert round(elapsed_time_more) == 2
+
+
+async def sync_wrapper(rate_limiter: RateLimiter):
+    rate_limiter.sync_acquire()
+
+
+async def helper_for_sync(rpm: int, total: int):
+    rate_limiter = RateLimiter(rpm=rpm)
+    rate_limiter.delay = timedelta(seconds=1)
+    tasks = [asyncio.create_task(sync_wrapper(rate_limiter)) for _ in range(total)]
+    await asyncio.gather(*tasks)
+
+
+def runner_for_sync(rpm: int, total: int):
+    loop = asyncio.get_event_loop()
+    return loop.run_until_complete(helper_for_sync(rpm, total))
+
+
+@pytest.mark.asyncio
+def test_rate_limiter_s():
+    elapsed_time = timeit.timeit(lambda: runner_for_sync(2, 2), number=1)
+    elapsed_time_more = timeit.timeit(lambda: runner_for_sync(2, 6), number=1)
+    assert elapsed_time < 0.25
+    assert round(elapsed_time_more) == 2
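The rate limiter tests above pin down the expected behavior: with rpm=2 and the window shrunk to one second, two acquisitions return almost immediately while six take roughly two seconds in total. A minimal sliding-window sketch of that contract follows; it is an illustration only, the actual flytekit.utils.rate_limiter implementation may differ, and the sync_acquire wrapper exercised by the second group of tests is omitted.

import asyncio
from collections import deque
from datetime import datetime, timedelta


class SlidingWindowRateLimiter:
    # Hypothetical stand-in for the interface exercised above, not flytekit's code.
    # Allows at most `rpm` acquisitions per `delay` window.
    def __init__(self, rpm: int):
        self.rpm = rpm
        self.delay = timedelta(minutes=1)  # the tests shrink this to one second
        self._lock = asyncio.Lock()
        self._timestamps: deque = deque()

    async def acquire(self) -> None:
        while True:
            async with self._lock:
                now = datetime.now()
                # Drop acquisitions that have aged out of the window.
                while self._timestamps and now - self._timestamps[0] >= self.delay:
                    self._timestamps.popleft()
                if len(self._timestamps) < self.rpm:
                    self._timestamps.append(now)
                    return
                # Window is full; wait until the oldest entry expires, then retry.
                wait = (self._timestamps[0] + self.delay - now).total_seconds()
            await asyncio.sleep(max(wait, 0))

Under these assumptions, gathering six acquire() tasks with rpm=2 and delay=1s admits two tasks per second, which is the ~2 second total the tests assert.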