Skip to content

Commit

Permalink
Merge branch 'flyteorg:master' into flyteremote-interruptible-override
Browse files Browse the repository at this point in the history
  • Loading branch information
redartera authored Nov 21, 2024
2 parents b27ca6a + 2e40e76 commit 89f8813
Show file tree
Hide file tree
Showing 123 changed files with 8,460 additions and 852 deletions.
5 changes: 2 additions & 3 deletions .github/workflows/pythonbuild.yml
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ jobs:
- name: Set Python versions for run
run: |
if [[ ${{ github.event_name }} == "schedule" ]]; then
echo "python_versions=[\"3.8\",\"3.9\",\"3.10\",\"3.11\",\"3.12\"]" >> $GITHUB_ENV
echo "python_versions=[\"3.9\",\"3.10\",\"3.11\",\"3.12\"]" >> $GITHUB_ENV
else
echo "python_versions=[\"3.9\", \"3.12\"]" >> $GITHUB_ENV
fi
Expand Down Expand Up @@ -342,6 +342,7 @@ jobs:
- flytekit-kf-mpi
- flytekit-kf-pytorch
- flytekit-kf-tensorflow
- flytekit-memray
- flytekit-mlflow
- flytekit-mmcloud
- flytekit-modin
Expand All @@ -363,8 +364,6 @@ jobs:
- flytekit-vaex
- flytekit-whylogs
exclude:
- python-version: 3.8
plugin-names: "flytekit-aws-sagemaker"
- python-version: 3.9
plugin-names: "flytekit-aws-sagemaker"
# flytekit-modin depends on ray which does not have a 3.11 wheel yet.
Expand Down
1 change: 0 additions & 1 deletion .github/workflows/pythonpublish.yml
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,6 @@ jobs:
strategy:
matrix:
python-version:
- "3.8"
- "3.9"
- "3.10"
- "3.11"
Expand Down
1 change: 1 addition & 0 deletions Dockerfile.dev
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ RUN SETUPTOOLS_SCM_PRETEND_VERSION_FOR_FLYTEKIT=$PSEUDO_VERSION \
pandas \
pillow \
plotly \
pyarrow \
pygments \
scikit-learn \
ydata-profiling \
Expand Down
11 changes: 8 additions & 3 deletions dev-requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,6 @@ aiosignal==1.3.1
# via aiohttp
annotated-types==0.7.0
# via pydantic
appnope==0.1.4
# via ipykernel
asn1crypto==1.5.1
# via snowflake-connector-python
asttokens==2.4.1
Expand Down Expand Up @@ -89,6 +87,7 @@ cryptography==43.0.1
# msal
# pyjwt
# pyopenssl
# secretstorage
# snowflake-connector-python
dataclasses-json==0.5.9
# via flytekit
Expand Down Expand Up @@ -217,6 +216,10 @@ jaraco-functools==4.0.1
# via keyring
jedi==0.19.1
# via ipython
jeepney==0.8.0
# via
# keyring
# secretstorage
jmespath==1.0.1
# via botocore
joblib==1.4.2
Expand Down Expand Up @@ -473,6 +476,8 @@ scikit-learn==1.5.0
# via -r dev-requirements.in
scipy==1.13.1
# via scikit-learn
secretstorage==3.3.3
# via keyring
setuptools-scm==8.1.0
# via -r dev-requirements.in
six==1.16.0
Expand All @@ -482,7 +487,7 @@ six==1.16.0
# isodate
# kubernetes
# python-dateutil
snowflake-connector-python==3.12.1
snowflake-connector-python==3.12.3
# via -r dev-requirements.in
sortedcontainers==2.4.0
# via
Expand Down
1 change: 0 additions & 1 deletion docs/source/experimental.rst
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,5 @@ Experimental Features
:nosignatures:
:toctree: generated/

~experimental.map_task
~experimental.eager
~experimental.EagerException
120 changes: 98 additions & 22 deletions flytekit/bin/entrypoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,16 @@
import sys
import tempfile
import textwrap
import time
import traceback
import uuid
import warnings
from sys import exit
from typing import Callable, List, Optional

import click
from flyteidl.core import literals_pb2 as _literals_pb2
from google.protobuf.timestamp_pb2 import Timestamp

from flytekit.configuration import (
SERIALIZED_CONTEXT_ENV_VAR,
Expand All @@ -40,6 +43,7 @@
from flytekit.core.promise import VoidPromise
from flytekit.core.utils import str2bool
from flytekit.deck.deck import _output_deck
from flytekit.exceptions.base import FlyteException
from flytekit.exceptions.system import FlyteNonRecoverableSystemException
from flytekit.exceptions.user import FlyteRecoverableException, FlyteUserRuntimeException
from flytekit.interfaces.stats.taggable import get_stats as _get_stats
Expand Down Expand Up @@ -74,6 +78,41 @@ def _compute_array_job_index():
return offset


def _build_error_file_name() -> str:
"""Get name of error file uploaded to the raw output prefix bucket.
For distributed tasks, all workers upload error files which must not overwrite each other, leading to a race condition.
A uuid is included to prevent this.
Returns
-------
str
Name of the error file.
"""
dist_error_strategy = get_one_of("FLYTE_INTERNAL_DIST_ERROR_STRATEGY", "_F_DES")
if not dist_error_strategy:
return _constants.ERROR_FILE_NAME
error_file_name_base, error_file_name_extension = os.path.splitext(_constants.ERROR_FILE_NAME)
error_file_name_base += f"-{uuid.uuid4().hex}"
return f"{error_file_name_base}{error_file_name_extension}"


def _get_worker_name() -> str:
"""Get the name of the worker
For distributed tasks, the backend plugin can set a worker name to be used for error reporting.
Returns
-------
str
Name of the worker
"""
dist_error_strategy = get_one_of("FLYTE_INTERNAL_DIST_ERROR_STRATEGY", "_F_DES")
if not dist_error_strategy:
return ""
return get_one_of("FLYTE_INTERNAL_WORKER_NAME", "_F_WN")


def _get_working_loop():
"""Returns a running event loop."""
try:
Expand Down Expand Up @@ -106,6 +145,9 @@ def _dispatch_execute(
b: OR if IgnoreOutputs is raised, then ignore uploading outputs
c: OR if an unhandled exception is retrieved - record it as an errors.pb
"""
error_file_name = _build_error_file_name()
worker_name = _get_worker_name()

output_file_dict = {}

task_def = None
Expand Down Expand Up @@ -142,12 +184,14 @@ def _dispatch_execute(
output_file_dict = {_constants.FUTURES_FILE_NAME: outputs}
else:
logger.error(f"SystemError: received unknown outputs from task {outputs}")
output_file_dict[_constants.ERROR_FILE_NAME] = _error_models.ErrorDocument(
output_file_dict[error_file_name] = _error_models.ErrorDocument(
_error_models.ContainerError(
"UNKNOWN_OUTPUT",
f"Type of output received not handled {type(outputs)} outputs: {outputs}",
_error_models.ContainerError.Kind.RECOVERABLE,
_execution_models.ExecutionError.ErrorKind.SYSTEM,
code="UNKNOWN_OUTPUT",
message=f"Type of output received not handled {type(outputs)} outputs: {outputs}",
kind=_error_models.ContainerError.Kind.RECOVERABLE,
origin=_execution_models.ExecutionError.ErrorKind.SYSTEM,
timestamp=get_container_error_timestamp(),
worker=worker_name,
)
)

Expand All @@ -165,12 +209,14 @@ def _dispatch_execute(
kind = _error_models.ContainerError.Kind.NON_RECOVERABLE

exc_str = get_traceback_str(e)
output_file_dict[_constants.ERROR_FILE_NAME] = _error_models.ErrorDocument(
output_file_dict[error_file_name] = _error_models.ErrorDocument(
_error_models.ContainerError(
"USER",
exc_str,
kind,
_execution_models.ExecutionError.ErrorKind.USER,
code="USER",
message=exc_str,
kind=kind,
origin=_execution_models.ExecutionError.ErrorKind.USER,
timestamp=get_container_error_timestamp(e.value),
worker=worker_name,
)
)
if task_def is not None:
Expand All @@ -183,12 +229,14 @@ def _dispatch_execute(

except FlyteNonRecoverableSystemException as e:
exc_str = get_traceback_str(e.value)
output_file_dict[_constants.ERROR_FILE_NAME] = _error_models.ErrorDocument(
output_file_dict[error_file_name] = _error_models.ErrorDocument(
_error_models.ContainerError(
"SYSTEM",
exc_str,
_error_models.ContainerError.Kind.NON_RECOVERABLE,
_execution_models.ExecutionError.ErrorKind.SYSTEM,
code="SYSTEM",
message=exc_str,
kind=_error_models.ContainerError.Kind.NON_RECOVERABLE,
origin=_execution_models.ExecutionError.ErrorKind.SYSTEM,
timestamp=get_container_error_timestamp(e.value),
worker=worker_name,
)
)

Expand All @@ -199,12 +247,14 @@ def _dispatch_execute(
# All other errors are captured here, and are considered system errors
except Exception as e:
exc_str = get_traceback_str(e)
output_file_dict[_constants.ERROR_FILE_NAME] = _error_models.ErrorDocument(
output_file_dict[error_file_name] = _error_models.ErrorDocument(
_error_models.ContainerError(
"SYSTEM",
exc_str,
_error_models.ContainerError.Kind.RECOVERABLE,
_execution_models.ExecutionError.ErrorKind.SYSTEM,
code="SYSTEM",
message=exc_str,
kind=_error_models.ContainerError.Kind.RECOVERABLE,
origin=_execution_models.ExecutionError.ErrorKind.SYSTEM,
timestamp=get_container_error_timestamp(e),
worker=worker_name,
)
)

Expand All @@ -223,7 +273,7 @@ def _dispatch_execute(

logger.debug("Finished _dispatch_execute")

if str2bool(os.getenv(FLYTE_FAIL_ON_ERROR)) and _constants.ERROR_FILE_NAME in output_file_dict:
if str2bool(os.getenv(FLYTE_FAIL_ON_ERROR)) and error_file_name in output_file_dict:
"""
If the environment variable FLYTE_FAIL_ON_ERROR is set to true, the task execution will fail if an error file is
generated. This environment variable is set to true by the plugin author if they want the task to fail on error.
Expand Down Expand Up @@ -255,6 +305,32 @@ def get_traceback_str(e: Exception) -> str:
return format_str.format(exception_str=exception_str, message_str=message_str)


def get_container_error_timestamp(e: Optional[Exception] = None) -> Timestamp:
"""Get timestamp for ContainerError.
If a flyte exception is passed, use its timestamp, otherwise, use the current time.
Parameters
----------
e : Exception, optional
Exception that has occurred.
Returns
-------
Timestamp
Timestamp to be reported in ContainerError
"""
timestamp = None
if isinstance(e, FlyteException):
timestamp = e.timestamp
if timestamp is None:
timestamp = time.time()
timstamp_secs = int(timestamp)
timestamp_fsecs = timestamp - timstamp_secs
timestamp_nanos = int(timestamp_fsecs * 1_000_000_000)
return Timestamp(seconds=timstamp_secs, nanos=timestamp_nanos)


def get_one_of(*args) -> str:
"""
Helper function to iterate through a series of different environment variables. This function exists because for
Expand Down Expand Up @@ -488,7 +564,7 @@ def _execute_map_task(
raise ValueError(f"Resolver args cannot be <1, got {resolver_args}")

with setup_execution(
raw_output_data_prefix, checkpoint_path, prev_checkpoint, dynamic_addl_distro, dynamic_dest_dir
raw_output_data_prefix, output_prefix, checkpoint_path, prev_checkpoint, dynamic_addl_distro, dynamic_dest_dir
) as ctx:
working_dir = os.getcwd()
if all(os.path.realpath(path) != working_dir for path in sys.path):
Expand Down
14 changes: 14 additions & 0 deletions flytekit/clients/friendly.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from flyteidl.admin import project_pb2 as _project_pb2
from flyteidl.admin import task_execution_pb2 as _task_execution_pb2
from flyteidl.admin import task_pb2 as _task_pb2
from flyteidl.admin import version_pb2 as _version_pb2
from flyteidl.admin import workflow_attributes_pb2 as _workflow_attributes_pb2
from flyteidl.admin import workflow_pb2 as _workflow_pb2
from flyteidl.core import identifier_pb2 as _identifier_pb2
Expand Down Expand Up @@ -1087,3 +1088,16 @@ def get_download_artifact_signed_url(
expires_in=expires_in_pb,
)
)

def get_control_plane_version(self) -> str:
"""
Retrieve the Control Plane version from Flyteadmin.
This method calls Flyteadmin's GetVersion API to obtain the current version information of the control plane.
The retrieved version can be used to enable or disable specific features based on the Flyteadmin version.
Returns:
str: The version string of the control plane.
"""
version_response = self._stub.GetVersion(_version_pb2.GetVersionRequest(), metadata=self._metadata)
return version_response.control_plane_version.Version
Loading

0 comments on commit 89f8813

Please sign in to comment.