Skip to content

Commit

Permalink
merge
Browse files Browse the repository at this point in the history
Signed-off-by: JiangJiaWei1103 <[email protected]>
  • Loading branch information
JiangJiaWei1103 committed Feb 22, 2025
2 parents 46b7b62 + acfaa76 commit 6227801
Show file tree
Hide file tree
Showing 175 changed files with 7,959 additions and 1,511 deletions.
15 changes: 13 additions & 2 deletions .github/workflows/pythonbuild.yml
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,11 @@ jobs:
strategy:
fail-fast: false
matrix:
os: [ubuntu-latest, windows-latest, macos-latest]
os:
- ubuntu-24.04-arm
- ubuntu-latest
- windows-latest
- macos-latest
python-version: ${{fromJson(needs.detect-python-versions.outputs.python-versions)}}
steps:
- uses: actions/checkout@v4
Expand Down Expand Up @@ -78,7 +82,11 @@ jobs:
strategy:
fail-fast: false
matrix:
os: [ubuntu-latest, windows-latest, macos-latest]
os:
- ubuntu-24.04-arm
- ubuntu-latest
- windows-latest
- macos-latest
python-version: ${{fromJson(needs.detect-python-versions.outputs.python-versions)}}
steps:
- uses: actions/checkout@v4
Expand Down Expand Up @@ -303,6 +311,7 @@ jobs:
- flytekit-huggingface
- flytekit-identity-aware-proxy
- flytekit-inference
- flytekit-k8sdataservice
- flytekit-k8s-pod
- flytekit-kf-mpi
- flytekit-kf-pytorch
Expand All @@ -319,10 +328,12 @@ jobs:
# flytekit-onnx-tensorflow
- flytekit-omegaconf
- flytekit-openai
- flytekit-optuna
- flytekit-pandera
- flytekit-papermill
- flytekit-polars
- flytekit-ray
- flytekit-slurm
- flytekit-snowflake
- flytekit-spark
- flytekit-sqlalchemy
Expand Down
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -39,3 +39,4 @@ coverage.xml

# Version file is auto-generated by setuptools_scm
flytekit/_version.py
testing
10 changes: 9 additions & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
# Ruff version.
rev: v0.6.9
rev: v0.8.3
hooks:
# Run the linter.
- id: ruff
Expand All @@ -28,3 +28,11 @@ repos:
- id: codespell
additional_dependencies:
- tomli
- repo: https://github.com/jsh9/pydoclint
rev: 0.6.0
hooks:
- id: pydoclint
args:
- --style=google
- --exclude='.git|tests/flytekit/*|tests/'
- --baseline=pydoclint-errors-baseline.txt
1 change: 1 addition & 0 deletions Dockerfile.agent
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ RUN apt-get update && apt-get install build-essential -y \
RUN uv pip install --system --no-cache-dir -U flytekit==$VERSION \
flytekitplugins-airflow==$VERSION \
flytekitplugins-bigquery==$VERSION \
flytekitplugins-k8sdataservice==$VERSION \
flytekitplugins-openai==$VERSION \
flytekitplugins-snowflake==$VERSION \
flytekitplugins-awssagemaker==$VERSION \
Expand Down
4 changes: 2 additions & 2 deletions dev-requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ filelock==3.14.0
# via
# snowflake-connector-python
# virtualenv
flyteidl==1.14.1
flyteidl==1.14.3
# via flytekit
frozenlist==1.4.1
# via
Expand Down Expand Up @@ -491,7 +491,7 @@ six==1.16.0
# isodate
# kubernetes
# python-dateutil
snowflake-connector-python==3.12.3
snowflake-connector-python==3.13.1
# via -r dev-requirements.in
sortedcontainers==2.4.0
# via
Expand Down
12 changes: 12 additions & 0 deletions docs/source/plugins/k8sstatefuldataservice.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
.. k8sstatefuldataservice:
###################################################
Kubernetes StatefulSet Data Service API reference
###################################################

.. tags:: Integration, DeepLearning, MachineLearning, Kubernetes, GNN

.. automodule:: flytekitplugins.k8sdataservice
:no-members:
:no-inherited-members:
:no-special-members:
5 changes: 5 additions & 0 deletions flytekit/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,9 @@
:toctree: generated/
HashMethod
Cache
CachePolicy
VersionParameters
Artifacts
=========
Expand Down Expand Up @@ -223,11 +226,13 @@
from flytekit.core.artifact import Artifact
from flytekit.core.base_sql_task import SQLTask
from flytekit.core.base_task import SecurityContext, TaskMetadata, kwtypes
from flytekit.core.cache import Cache, CachePolicy, VersionParameters
from flytekit.core.checkpointer import Checkpoint
from flytekit.core.condition import conditional
from flytekit.core.container_task import ContainerTask
from flytekit.core.context_manager import ExecutionParameters, FlyteContext, FlyteContextManager
from flytekit.core.dynamic_workflow_task import dynamic
from flytekit.core.environment import Environment
from flytekit.core.gate import approve, sleep, wait_for_input
from flytekit.core.hash import HashMethod
from flytekit.core.launch_plan import LaunchPlan, reference_launch_plan
Expand Down
4 changes: 2 additions & 2 deletions flytekit/bin/entrypoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,7 @@ def _dispatch_execute(
exc_str = get_traceback_str(e)
output_file_dict[error_file_name] = _error_models.ErrorDocument(
_error_models.ContainerError(
code="USER",
code=e.error_code,
message=exc_str,
kind=kind,
origin=_execution_models.ExecutionError.ErrorKind.USER,
Expand Down Expand Up @@ -324,7 +324,7 @@ def _dispatch_execute(
logger.info(f"Engine folder written successfully to the output prefix {output_prefix}")

if task_def is not None and not getattr(task_def, "disable_deck", True):
_output_deck(task_def.name.split(".")[-1], ctx.user_space_params)
_output_deck(task_name=task_def.name.split(".")[-1], new_user_params=ctx.user_space_params)

logger.debug("Finished _dispatch_execute")

Expand Down
14 changes: 10 additions & 4 deletions flytekit/clients/auth_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,9 +122,12 @@ def upgrade_channel_to_proxy_authenticated(cfg: PlatformConfig, in_channel: grpc
:param in_channel: grpc.Channel Precreated channel
:return: grpc.Channel. New composite channel
"""

def authenticator_factory():
return get_proxy_authenticator(cfg)

if cfg.proxy_command:
proxy_authenticator = get_proxy_authenticator(cfg)
return grpc.intercept_channel(in_channel, AuthUnaryInterceptor(proxy_authenticator))
return grpc.intercept_channel(in_channel, AuthUnaryInterceptor(authenticator_factory))
else:
return in_channel

Expand All @@ -137,8 +140,11 @@ def upgrade_channel_to_authenticated(cfg: PlatformConfig, in_channel: grpc.Chann
:param in_channel: grpc.Channel Precreated channel
:return: grpc.Channel. New composite channel
"""
authenticator = get_authenticator(cfg, RemoteClientConfigStore(in_channel))
return grpc.intercept_channel(in_channel, AuthUnaryInterceptor(authenticator))

def authenticator_factory():
return get_authenticator(cfg, RemoteClientConfigStore(in_channel))

return grpc.intercept_channel(in_channel, AuthUnaryInterceptor(authenticator_factory))


def get_authenticated_channel(cfg: PlatformConfig) -> grpc.Channel:
Expand Down
16 changes: 16 additions & 0 deletions flytekit/clients/friendly.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

from flytekit.clients.raw import RawSynchronousFlyteClient as _RawSynchronousFlyteClient
from flytekit.models import common as _common
from flytekit.models import domain as _domain
from flytekit.models import execution as _execution
from flytekit.models import filters as _filters
from flytekit.models import launch_plan as _launch_plan
Expand Down Expand Up @@ -896,6 +897,21 @@ def list_projects_paginated(self, limit=100, token=None, filters=None, sort_by=N
str(projects.token),
)

####################################################################################################################
#
# Domain Endpoints
#
####################################################################################################################

def get_domains(self):
"""
This returns a list of domains.
:rtype: list[flytekit.models.Domain]
"""
domains = super(SynchronousFlyteClient, self).get_domains()
return [_domain.Domain.from_flyte_idl(domain) for domain in domains.domains]

####################################################################################################################
#
# Matching Attributes Endpoints
Expand Down
17 changes: 12 additions & 5 deletions flytekit/clients/grpc_utils/auth_interceptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,15 +25,22 @@ class AuthUnaryInterceptor(grpc.UnaryUnaryClientInterceptor, grpc.UnaryStreamCli
is needed.
"""

def __init__(self, authenticator: Authenticator):
self._authenticator = authenticator
def __init__(self, get_authenticator: typing.Callable[[], Authenticator]):
self._get_authenticator = get_authenticator
self._authenticator = None

@property
def authenticator(self) -> Authenticator:
if self._authenticator is None:
self._authenticator = self._get_authenticator()
return self._authenticator

def _call_details_with_auth_metadata(self, client_call_details: grpc.ClientCallDetails) -> grpc.ClientCallDetails:
"""
Returns new ClientCallDetails with metadata added.
"""
metadata = client_call_details.metadata
auth_metadata = self._authenticator.fetch_grpc_call_auth_metadata()
auth_metadata = self.authenticator.fetch_grpc_call_auth_metadata()
if auth_metadata:
metadata = []
if client_call_details.metadata:
Expand Down Expand Up @@ -64,7 +71,7 @@ def intercept_unary_unary(
if not hasattr(e, "code"):
raise e
if e.code() == grpc.StatusCode.UNAUTHENTICATED or e.code() == grpc.StatusCode.UNKNOWN:
self._authenticator.refresh_credentials()
self.authenticator.refresh_credentials()
updated_call_details = self._call_details_with_auth_metadata(client_call_details)
return continuation(updated_call_details, request)
return fut
Expand All @@ -76,7 +83,7 @@ def intercept_unary_stream(self, continuation, client_call_details, request):
updated_call_details = self._call_details_with_auth_metadata(client_call_details)
c: grpc.Call = continuation(updated_call_details, request)
if c.code() == grpc.StatusCode.UNAUTHENTICATED:
self._authenticator.refresh_credentials()
self.authenticator.refresh_credentials()
updated_call_details = self._call_details_with_auth_metadata(client_call_details)
return continuation(updated_call_details, request)
return c
17 changes: 16 additions & 1 deletion flytekit/clients/raw.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import typing

import grpc
from flyteidl.admin.project_pb2 import ProjectListRequest
from flyteidl.admin.project_pb2 import GetDomainRequest, ProjectListRequest
from flyteidl.admin.signal_pb2 import SignalList, SignalListRequest, SignalSetRequest, SignalSetResponse
from flyteidl.service import admin_pb2_grpc as _admin_service
from flyteidl.service import dataproxy_pb2 as _dataproxy_pb2
Expand Down Expand Up @@ -520,6 +520,21 @@ def update_project(self, project):
"""
return self._stub.UpdateProject(project, metadata=self._metadata)

####################################################################################################################
#
# Domain Endpoints
#
####################################################################################################################

def get_domains(self):
"""
This will return a list of domains registered with the Flyte Admin Service
:param flyteidl.admin.project_pb2.GetDomainRequest get_domain_request:
:rtype: flyteidl.admin.project_pb2.GetDomainsResponse
"""
get_domain_request = GetDomainRequest()
return self._stub.GetDomains(get_domain_request, metadata=self._metadata)

####################################################################################################################
#
# Matching Attributes Endpoints
Expand Down
4 changes: 1 addition & 3 deletions flytekit/clis/sdk_in_container/init.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,4 @@ def init(template, project_name):
processed_contents = project_template_regex.sub(project_name_bytes, zip_contents)
dest_file.write(processed_contents)

click.echo(
f"Visit the {project_name} directory and follow the next steps in the Getting started guide (https://docs.flyte.org/en/latest/user_guide/getting_started_with_workflow_development/index.html) to proceed."
)
click.echo(f"Project initialized in directory {project_name}.")
21 changes: 19 additions & 2 deletions flytekit/clis/sdk_in_container/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -544,8 +544,9 @@ def run_remote(
if run_level_params.wait_execution:
msg += " Waiting to complete..."
p = Progress(TimeElapsedColumn(), TextColumn(msg), transient=True)
t = p.add_task("exec")
t = p.add_task("exec", visible=False)
with p:
p.update(t, visible=True)
p.start_task(t)
execution = remote.execute(
entity,
Expand Down Expand Up @@ -1051,9 +1052,25 @@ def _create_command(
r = run_level_params.remote_instance()
flyte_ctx = r.context

final_inputs_with_defaults = loaded_entity.python_interface.inputs_with_defaults
if isinstance(loaded_entity, LaunchPlan):
# For LaunchPlans it is essential to handle fixed inputs and default inputs in a special way
# Fixed inputs are inputs that are always passed to the launch plan and cannot be overridden
# Default inputs are inputs that are optional and have a default value
# The final inputs to the launch plan are a combination of the fixed inputs and the default inputs
all_inputs = loaded_entity.python_interface.inputs_with_defaults
default_inputs = loaded_entity.saved_inputs
pmap = loaded_entity.parameters
final_inputs_with_defaults = {}
for name, _ in pmap.parameters.items():
_type, v = all_inputs[name]
if name in default_inputs:
v = default_inputs[name]
final_inputs_with_defaults[name] = _type, v

# Add options for each of the workflow inputs
params = []
for input_name, input_type_val in loaded_entity.python_interface.inputs_with_defaults.items():
for input_name, input_type_val in final_inputs_with_defaults.items():
literal_var = loaded_entity.interface.inputs.get(input_name)
python_type, default_val = input_type_val
required = type(None) not in get_args(python_type) and default_val is None
Expand Down
19 changes: 13 additions & 6 deletions flytekit/clis/sdk_in_container/serve.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,13 @@ def serve(ctx: click.Context):
type=int,
help="Grpc port for the agent service",
)
@click.option(
"--prometheus_port",
default="9090",
is_flag=False,
type=int,
help="Prometheus port for the agent service",
)
@click.option(
"--worker",
default="10",
Expand All @@ -45,20 +52,20 @@ def serve(ctx: click.Context):
"for testing.",
)
@click.pass_context
def agent(_: click.Context, port, worker, timeout):
def agent(_: click.Context, port, prometheus_port, worker, timeout):
"""
Start a grpc server for the agent service.
"""
import asyncio

asyncio.run(_start_grpc_server(port, worker, timeout))
asyncio.run(_start_grpc_server(port, prometheus_port, worker, timeout))


async def _start_grpc_server(port: int, worker: int, timeout: int):
async def _start_grpc_server(port: int, prometheus_port: int, worker: int, timeout: int):
from flytekit.extend.backend.agent_service import AgentMetadataService, AsyncAgentService, SyncAgentService

click.secho("🚀 Starting the agent service...")
_start_http_server()
_start_http_server(prometheus_port)
print_agents_metadata()

server = grpc.aio.server(futures.ThreadPoolExecutor(max_workers=worker))
Expand All @@ -73,12 +80,12 @@ async def _start_grpc_server(port: int, worker: int, timeout: int):
await server.wait_for_termination(timeout)


def _start_http_server():
def _start_http_server(prometheus_port: int):
try:
from prometheus_client import start_http_server

click.secho("Starting up the server to expose the prometheus metrics...")
start_http_server(9090)
start_http_server(prometheus_port)
except ImportError as e:
click.secho(f"Failed to start the prometheus server with error {e}", fg="red")

Expand Down
Loading

0 comments on commit 6227801

Please sign in to comment.