diff --git a/makefile b/makefile index 329c2df..71675b2 100644 --- a/makefile +++ b/makefile @@ -21,7 +21,7 @@ test: @echo Running Tox tests @tox -e py -NITRIC_VERSION := 1.6.0 +NITRIC_VERSION := 1.14.0 download-local: @rm -r ./proto/nitric diff --git a/nitric/application.py b/nitric/application.py index 38e2e6a..069d9be 100644 --- a/nitric/application.py +++ b/nitric/application.py @@ -42,6 +42,8 @@ class Nitric: "keyvaluestore": {}, "oidcsecuritydefinition": {}, "sql": {}, + "job": {}, + "jobdefinition": {}, } @classmethod diff --git a/nitric/context.py b/nitric/context.py index e696998..b7ba59a 100644 --- a/nitric/context.py +++ b/nitric/context.py @@ -31,6 +31,11 @@ from nitric.proto.topics.v1 import ClientMessage as TopicClientMessage from nitric.proto.topics.v1 import MessageResponse as TopicResponse from nitric.proto.topics.v1 import ServerMessage as TopicServerMessage +from nitric.proto.batch.v1 import ( + ServerMessage as BatchServerMessage, + ClientMessage as BatchClientMessage, + JobResponse as BatchJobResponse, +) from nitric.utils import dict_from_struct Record = Dict[str, Union[str, List[str]]] @@ -363,6 +368,42 @@ async def chained_middleware(ctx: C, nxt: Optional[Middleware[C]] = None) -> C: return composed +class JobRequest: + """Represents a translated Job, from a Job Definition, forwarded from the Nitric Runtime Server.""" + + data: dict[str, Any] + + def __init__(self, data: dict[str, Any]): + """Construct a new JobRequest.""" + self.data = data + + +class JobResponse: + """Represents the response to a trigger from a Job submission as a result of a SubmitJob call.""" + + def __init__(self, success: bool = True): + """Construct a new EventResponse.""" + self.success = success + + +class JobContext: + """Represents the full request/response context for an Event based trigger.""" + + def __init__(self, request: JobRequest, response: Optional[JobResponse] = None): + """Construct a new EventContext.""" + self.req = request + self.res = response if response else JobResponse() + + @staticmethod + def _from_request(msg: BatchServerMessage) -> "JobContext": + """Construct a new EventContext from a Topic trigger from the Nitric Membrane.""" + return JobContext(request=JobRequest(data=dict_from_struct(msg.job_request.data.struct))) + + def to_response(self) -> BatchClientMessage: + """Construct a EventContext for the Nitric Membrane from this context object.""" + return BatchClientMessage(job_response=BatchJobResponse(success=self.res.success)) + + class FunctionServer(ABC): """Represents a worker that should be started at runtime.""" diff --git a/nitric/proto/batch/__init__.py b/nitric/proto/batch/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/nitric/proto/batch/v1/__init__.py b/nitric/proto/batch/v1/__init__.py new file mode 100644 index 0000000..1f64499 --- /dev/null +++ b/nitric/proto/batch/v1/__init__.py @@ -0,0 +1,211 @@ +# Generated by the protocol buffer compiler. DO NOT EDIT! +# sources: nitric/proto/batch/v1/batch.proto +# plugin: python-betterproto +# This file has been @generated + +from dataclasses import dataclass +from typing import ( + TYPE_CHECKING, + AsyncIterable, + AsyncIterator, + Dict, + Iterable, + Optional, + Union, +) + +import betterproto +import betterproto.lib.google.protobuf as betterproto_lib_google_protobuf +import grpclib +from betterproto.grpc.grpclib_server import ServiceBase + + +if TYPE_CHECKING: + import grpclib.server + from betterproto.grpc.grpclib_client import MetadataLike + from grpclib.metadata import Deadline + + +@dataclass(eq=False, repr=False) +class ClientMessage(betterproto.Message): + id: str = betterproto.string_field(1) + """globally unique ID of the request/response pair""" + + registration_request: "RegistrationRequest" = betterproto.message_field( + 2, group="content" + ) + """Register a handler for a job""" + + job_response: "JobResponse" = betterproto.message_field(3, group="content") + """Handle a job submission""" + + +@dataclass(eq=False, repr=False) +class JobRequest(betterproto.Message): + job_name: str = betterproto.string_field(1) + data: "JobData" = betterproto.message_field(2) + + +@dataclass(eq=False, repr=False) +class JobData(betterproto.Message): + struct: "betterproto_lib_google_protobuf.Struct" = betterproto.message_field( + 1, group="data" + ) + + +@dataclass(eq=False, repr=False) +class JobResponse(betterproto.Message): + success: bool = betterproto.bool_field(1) + """Mark if the job was successfully processed""" + + +@dataclass(eq=False, repr=False) +class RegistrationRequest(betterproto.Message): + job_name: str = betterproto.string_field(1) + requirements: "JobResourceRequirements" = betterproto.message_field(2) + """Register with default requirements""" + + +@dataclass(eq=False, repr=False) +class RegistrationResponse(betterproto.Message): + pass + + +@dataclass(eq=False, repr=False) +class JobResourceRequirements(betterproto.Message): + cpus: float = betterproto.float_field(1) + """The number of CPUs to allocate for the job""" + + memory: int = betterproto.int64_field(2) + """The amount of memory to allocate for the job""" + + gpus: int = betterproto.int64_field(3) + """The number of GPUs to allocate for the job""" + + +@dataclass(eq=False, repr=False) +class ServerMessage(betterproto.Message): + """ + ServerMessage is the message sent from the nitric server to the service + """ + + id: str = betterproto.string_field(1) + """globally unique ID of the request/response pair""" + + registration_response: "RegistrationResponse" = betterproto.message_field( + 2, group="content" + ) + """ + + """ + + job_request: "JobRequest" = betterproto.message_field(3, group="content") + """Request to a job handler""" + + +@dataclass(eq=False, repr=False) +class JobSubmitRequest(betterproto.Message): + job_name: str = betterproto.string_field(1) + """The name of the job that should handle the data""" + + data: "JobData" = betterproto.message_field(2) + """The data to be processed by the job""" + + +@dataclass(eq=False, repr=False) +class JobSubmitResponse(betterproto.Message): + pass + + +class JobStub(betterproto.ServiceStub): + async def handle_job( + self, + client_message_iterator: Union[ + AsyncIterable["ClientMessage"], Iterable["ClientMessage"] + ], + *, + timeout: Optional[float] = None, + deadline: Optional["Deadline"] = None, + metadata: Optional["MetadataLike"] = None + ) -> AsyncIterator["ServerMessage"]: + async for response in self._stream_stream( + "/nitric.proto.batch.v1.Job/HandleJob", + client_message_iterator, + ClientMessage, + ServerMessage, + timeout=timeout, + deadline=deadline, + metadata=metadata, + ): + yield response + + +class BatchStub(betterproto.ServiceStub): + async def submit_job( + self, + job_submit_request: "JobSubmitRequest", + *, + timeout: Optional[float] = None, + deadline: Optional["Deadline"] = None, + metadata: Optional["MetadataLike"] = None + ) -> "JobSubmitResponse": + return await self._unary_unary( + "/nitric.proto.batch.v1.Batch/SubmitJob", + job_submit_request, + JobSubmitResponse, + timeout=timeout, + deadline=deadline, + metadata=metadata, + ) + + +class JobBase(ServiceBase): + async def handle_job( + self, client_message_iterator: AsyncIterator["ClientMessage"] + ) -> AsyncIterator["ServerMessage"]: + raise grpclib.GRPCError(grpclib.const.Status.UNIMPLEMENTED) + yield ServerMessage() + + async def __rpc_handle_job( + self, stream: "grpclib.server.Stream[ClientMessage, ServerMessage]" + ) -> None: + request = stream.__aiter__() + await self._call_rpc_handler_server_stream( + self.handle_job, + stream, + request, + ) + + def __mapping__(self) -> Dict[str, grpclib.const.Handler]: + return { + "/nitric.proto.batch.v1.Job/HandleJob": grpclib.const.Handler( + self.__rpc_handle_job, + grpclib.const.Cardinality.STREAM_STREAM, + ClientMessage, + ServerMessage, + ), + } + + +class BatchBase(ServiceBase): + async def submit_job( + self, job_submit_request: "JobSubmitRequest" + ) -> "JobSubmitResponse": + raise grpclib.GRPCError(grpclib.const.Status.UNIMPLEMENTED) + + async def __rpc_submit_job( + self, stream: "grpclib.server.Stream[JobSubmitRequest, JobSubmitResponse]" + ) -> None: + request = await stream.recv_message() + response = await self.submit_job(request) + await stream.send_message(response) + + def __mapping__(self) -> Dict[str, grpclib.const.Handler]: + return { + "/nitric.proto.batch.v1.Batch/SubmitJob": grpclib.const.Handler( + self.__rpc_submit_job, + grpclib.const.Cardinality.UNARY_UNARY, + JobSubmitRequest, + JobSubmitResponse, + ), + } diff --git a/nitric/proto/deployments/v1/__init__.py b/nitric/proto/deployments/v1/__init__.py index 0b727bf..da2b6f0 100644 --- a/nitric/proto/deployments/v1/__init__.py +++ b/nitric/proto/deployments/v1/__init__.py @@ -17,6 +17,7 @@ import grpclib from betterproto.grpc.grpclib_server import ServiceBase +from ...batch import v1 as __batch_v1__ from ...resources import v1 as __resources_v1__ from ...storage import v1 as __storage_v1__ @@ -201,6 +202,36 @@ def __post_init__(self) -> None: warnings.warn("Service.memory is deprecated", DeprecationWarning) +@dataclass(eq=False, repr=False) +class Job(betterproto.Message): + name: str = betterproto.string_field(1) + """The name of the job to create""" + + requirements: "__batch_v1__.JobResourceRequirements" = betterproto.message_field(2) + """The default resource requirements of the job""" + + +@dataclass(eq=False, repr=False) +class Batch(betterproto.Message): + image: "ImageSource" = betterproto.message_field(1, group="source") + """Image URI for this batch service""" + + type: str = betterproto.string_field(10) + """ + A simple type property describes the requested type of batch that this + should be for this project, a provider can implement how this request is + satisfied in any way + """ + + env: Dict[str, str] = betterproto.map_field( + 11, betterproto.TYPE_STRING, betterproto.TYPE_STRING + ) + """Environment variables for this Batch""" + + jobs: List["Job"] = betterproto.message_field(12) + """Jobs that are defined in this Batch""" + + @dataclass(eq=False, repr=False) class Bucket(betterproto.Message): listeners: List["BucketListener"] = betterproto.message_field(1) @@ -338,6 +369,7 @@ class Resource(betterproto.Message): http: "Http" = betterproto.message_field(19, group="config") queue: "Queue" = betterproto.message_field(20, group="config") sql_database: "SqlDatabase" = betterproto.message_field(21, group="config") + batch: "Batch" = betterproto.message_field(22, group="config") @dataclass(eq=False, repr=False) diff --git a/nitric/proto/resources/v1/__init__.py b/nitric/proto/resources/v1/__init__.py index 75babe0..d03ee6a 100644 --- a/nitric/proto/resources/v1/__init__.py +++ b/nitric/proto/resources/v1/__init__.py @@ -38,6 +38,10 @@ class ResourceType(betterproto.Enum): ApiSecurityDefinition = 12 Queue = 13 SqlDatabase = 14 + Batch = 15 + """Batches represent a collection of jobs""" + + Job = 16 class Action(betterproto.Enum): @@ -66,6 +70,8 @@ class Action(betterproto.Enum): """Queue Permissions: 6XX""" QueueDequeue = 601 + JobSubmit = 700 + """Job Permissions: 7XX""" @dataclass(eq=False, repr=False) @@ -99,6 +105,7 @@ class ResourceDeclareRequest(betterproto.Message): ) queue: "QueueResource" = betterproto.message_field(17, group="config") sql_database: "SqlDatabaseResource" = betterproto.message_field(18, group="config") + job: "JobResource" = betterproto.message_field(19, group="config") @dataclass(eq=False, repr=False) @@ -126,6 +133,11 @@ class SecretResource(betterproto.Message): pass +@dataclass(eq=False, repr=False) +class JobResource(betterproto.Message): + pass + + @dataclass(eq=False, repr=False) class SqlDatabaseMigrations(betterproto.Message): migrations_path: str = betterproto.string_field(1, group="migrations") diff --git a/nitric/resources/__init__.py b/nitric/resources/__init__.py index d70df80..086ef2f 100644 --- a/nitric/resources/__init__.py +++ b/nitric/resources/__init__.py @@ -27,6 +27,7 @@ from nitric.resources.websockets import Websocket, websocket from nitric.resources.queues import Queue, queue from nitric.resources.sql import Sql, sql +from nitric.resources.job import job, Job __all__ = [ "api", @@ -38,6 +39,8 @@ "Bucket", "kv", "KeyValueStoreRef", + "job", + "Job", "oidc_rule", "queue", "Queue", diff --git a/nitric/resources/job.py b/nitric/resources/job.py new file mode 100644 index 0000000..ccaedf8 --- /dev/null +++ b/nitric/resources/job.py @@ -0,0 +1,194 @@ +from nitric.resources.resource import SecureResource +from nitric.application import Nitric +from nitric.proto.resources.v1 import ( + Action, + JobResource, + ResourceDeclareRequest, + ResourceIdentifier, + ResourceType, +) +from nitric.context import JobContext +import logging +import betterproto +from nitric.proto.batch.v1 import ( + BatchStub, + JobSubmitRequest, + JobData, + JobStub, + RegistrationRequest, + ClientMessage, + JobResponse as ProtoJobResponse, + JobResourceRequirements, +) +from nitric.exception import exception_from_grpc_error +from grpclib import GRPCError +from grpclib.client import Channel +from typing import Callable, Any, Optional, Literal, List +from nitric.context import FunctionServer, Handler +from nitric.channel import ChannelManager +from nitric.bidi import AsyncNotifierList +from nitric.utils import struct_from_dict +import grpclib + + +JobPermission = Literal["submit"] +JobHandle = Handler[JobContext] + + +class JobHandler(FunctionServer): + """Function worker for Jobs.""" + + _handler: JobHandle + _registration_request: RegistrationRequest + _responses: AsyncNotifierList[ClientMessage] + + def __init__( + self, + job_name: str, + handler: JobHandle, + cpus: float | None = None, + memory: int | None = None, + gpus: int | None = None, + ): + """Construct a new WebsocketHandler.""" + self._handler = handler + self._responses = AsyncNotifierList() + self._registration_request = RegistrationRequest( + job_name=job_name, + requirements=JobResourceRequirements( + cpus=cpus if cpus is not None else 0, + memory=memory if memory is not None else 0, + gpus=gpus if gpus is not None else 0, + ), + ) + + async def _message_request_iterator(self): + # Register with the server + yield ClientMessage(registration_request=self._registration_request) + # wait for any responses for the server and send them + async for response in self._responses: + yield response + + async def start(self) -> None: + """Register this subscriber and listen for messages.""" + channel = ChannelManager.get_channel() + server = JobStub(channel=channel) + + try: + async for server_msg in server.handle_job(self._message_request_iterator()): + msg_type, _ = betterproto.which_one_of(server_msg, "content") + + if msg_type == "registration_response": + continue + if msg_type == "job_request": + ctx = JobContext._from_request(server_msg) + + response: ClientMessage + try: + resp_ctx = await self._handler(ctx) + if resp_ctx is None: + resp_ctx = ctx + + response = ClientMessage( + id=server_msg.id, + job_response=ProtoJobResponse(success=ctx.res.success), + ) + except Exception as e: # pylint: disable=broad-except + logging.exception("An unhandled error occurred in a job event handler: %s", e) + response = ClientMessage(id=server_msg.id, job_response=ProtoJobResponse(success=False)) + await self._responses.add_item(response) + except grpclib.exceptions.GRPCError as e: + print(f"Stream terminated: {e.message}") + except grpclib.exceptions.StreamTerminatedError: + print("Stream from membrane closed.") + finally: + print("Closing client stream") + channel.close() + + +class JobRef: + """A reference to a deployed job, used to interact with the job at runtime.""" + + _channel: Channel + _stub: BatchStub + name: str + + def __init__(self, name: str) -> None: + """Construct a reference to a deployed Job.""" + self._channel: Channel = ChannelManager.get_channel() + self._stub = BatchStub(channel=self._channel) + self.name = name + + def __del__(self) -> None: + # close the channel when this client is destroyed + if self._channel is not None: + self._channel.close() + + async def submit(self, data: dict[str, Any]) -> None: + """Submit a new execution for this job definition.""" + await self._stub.submit_job( + job_submit_request=JobSubmitRequest(job_name=self.name, data=JobData(struct=struct_from_dict(data))) + ) + + +class Job(SecureResource): + """A Job Definition.""" + + name: str + + def __init__(self, name: str): + """Job definition constructor.""" + super().__init__(name) + self.name = name + + async def _register(self) -> None: + try: + await self._resources_stub.declare( + resource_declare_request=ResourceDeclareRequest( + id=_to_resource_identifier(self), + job=JobResource(), + ) + ) + + except GRPCError as grpc_err: + raise exception_from_grpc_error(grpc_err) from grpc_err + + def _perms_to_actions(self, *args: JobPermission) -> List[Action]: + _permMap: dict[JobPermission, List[Action]] = {"submit": [Action.JobSubmit]} + + return [action for perm in args for action in _permMap[perm]] + + def allow(self, perm: JobPermission, *args: JobPermission) -> JobRef: + """Request the specified permissions to this resource.""" + str_args = [perm] + [str(permission) for permission in args] + self._register_policy(*str_args) + + return JobRef(self.name) + + def _to_resource_id(self) -> ResourceIdentifier: + return ResourceIdentifier(name=self.name, type=ResourceType.Job) + + def __call__( + self, cpus: Optional[float] = None, memory: Optional[int] = None, gpus: Optional[int] = None + ) -> Callable[[JobHandle], None]: + """Define the handler for this job definition.""" + + def decorator(function: JobHandle) -> None: + wrkr = JobHandler(self.name, function, cpus, memory, gpus) + Nitric._register_worker(wrkr) + + return decorator + + +def _to_resource_identifier(b: Job) -> ResourceIdentifier: + return ResourceIdentifier(name=b.name, type=ResourceType.Job) + + +def job(name: str) -> Job: + """ + Create and register a job. + + If a job has already been registered with the same name, the original reference will be reused. + """ + # type ignored because the create call are treated as protected. + return Nitric._create_resource(Job, name) # type: ignore pylint: disable=protected-access