Skip to content

Commit

Permalink
Allow user to configure callbacks server side (#317)
Browse files Browse the repository at this point in the history
This PR makes it possible for a user to configure server side callbacks.

A user can specify a per request modifier to add callbacks into the
config, without them being over-written by the event aggregator
callback.
  • Loading branch information
eyurtsev authored Dec 13, 2023
1 parent 3c6cce8 commit 87b54a3
Showing 1 changed file with 13 additions and 3 deletions.
16 changes: 13 additions & 3 deletions langserve/api_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from fastapi import HTTPException
from fastapi.encoders import jsonable_encoder
from fastapi.exceptions import RequestValidationError
from langchain.callbacks.base import AsyncCallbackHandler
from langchain.callbacks.tracers.log_stream import RunLogPatch
from langchain.load.serializable import Serializable
from langchain.schema.runnable import Runnable, RunnableConfig
Expand Down Expand Up @@ -371,6 +372,15 @@ def _json_encode_response(model: BaseModel) -> JSONResponse:
return JSONResponse(content=obj)


def _add_callbacks(
config: RunnableConfig, callbacks: Sequence[AsyncCallbackHandler]
) -> None:
"""Add the callback aggregator to the config."""
if "callbacks" not in config:
config["callbacks"] = []
config["callbacks"].extend(callbacks)


class _APIHandler:
"""Implementation of the various API endpoints for a runnable server.
Expand Down Expand Up @@ -567,7 +577,7 @@ async def invoke(
)

event_aggregator = AsyncEventAggregatorCallback()
config["callbacks"] = [event_aggregator]
_add_callbacks(config, [event_aggregator])
output = await self.runnable.ainvoke(input_, config=config)

if self.include_callback_events:
Expand Down Expand Up @@ -661,7 +671,7 @@ async def batch(

final_configs = []
for config_, aggregator in zip(configs_, aggregators):
config_["callbacks"] = [aggregator]
_add_callbacks(config_, [aggregator])
final_configs.append(
_update_config_with_defaults(
self.path, config_, request, endpoint="batch"
Expand Down Expand Up @@ -741,7 +751,7 @@ async def _stream() -> AsyncIterator[dict]:
try:
config_w_callbacks = config.copy()
event_aggregator = AsyncEventAggregatorCallback()
config_w_callbacks["callbacks"] = [event_aggregator]
_add_callbacks(config_w_callbacks, [event_aggregator])
has_sent_metadata = False
async for chunk in self.runnable.astream(
input_,
Expand Down

0 comments on commit 87b54a3

Please sign in to comment.