From efe280d9bfc59897275948cb28d2c3182be1aeb0 Mon Sep 17 00:00:00 2001 From: Eugene Yurtsev Date: Wed, 3 Apr 2024 16:01:38 -0400 Subject: [PATCH] x --- langserve/api_handler.py | 38 ++++++++++-- langserve/schema.py | 2 +- tests/unit_tests/test_server_client.py | 84 ++++++++++++++++++++++---- 3 files changed, 104 insertions(+), 20 deletions(-) diff --git a/langserve/api_handler.py b/langserve/api_handler.py index e457b03b..d72f09cb 100644 --- a/langserve/api_handler.py +++ b/langserve/api_handler.py @@ -477,11 +477,6 @@ def _add_callbacks( _SEEN_NAMES = set() -def _is_scoped_feedback_enabled() -> bool: - """Temporary hard-coded as False. Used only to enable during unit tests.""" - return False - - class PerKeyFeedbackConfig(TypedDict): """Per feedback configuration. @@ -650,7 +645,8 @@ def __init__( # remember to make relevant updates in the unit tests. self._langsmith_client = ( ls_client.Client() - if tracing_is_enabled() and enable_feedback_endpoint + if tracing_is_enabled() + and (enable_feedback_endpoint or self._token_feedback_enabled) else None ) @@ -1181,6 +1177,7 @@ async def stream_log( endpoint="stream_log", server_config=server_config, ) + run_id = config["run_id"] except BaseException: # Exceptions will be properly translated by default FastAPI middleware # to either 422 (on input validation) or 500 internal server errors. @@ -1194,8 +1191,25 @@ async def stream_log( except RequestValidationError: raise + feedback_key: Optional[str] + + if self._token_feedback_enabled: + # Create task to create a presigned feedback token + feedback_key: str = self._token_feedback_config["key_configs"][0]["key"] + feedback_coro = run_in_executor( + None, + self._langsmith_client.create_presigned_feedback_token, + run_id, + feedback_key, + ) + task: Optional[asyncio.Task] = asyncio.create_task(feedback_coro) + else: + feedback_key = None + task = None + async def _stream_log() -> AsyncIterator[dict]: """Stream the output of the runnable.""" + has_sent_metadata = False try: async for chunk in self._runnable.astream_log( input_, @@ -1229,6 +1243,18 @@ async def _stream_log() -> AsyncIterator[dict]: "data": self._serializer.dumps(data).decode("utf-8"), "event": "data", } + + # Send a metadata event as soon as possible + if not has_sent_metadata and self._enable_feedback_endpoint: + if task is None: + raise AssertionError("Feedback token task was not created.") + if not task.done(): + continue + feedback_token = task.result() + yield _create_metadata_event( + run_id, feedback_key, feedback_token + ) + has_sent_metadata = True yield {"event": "end"} except BaseException: yield { diff --git a/langserve/schema.py b/langserve/schema.py index 6f729b7c..26392993 100644 --- a/langserve/schema.py +++ b/langserve/schema.py @@ -116,7 +116,7 @@ class BaseFeedback(BaseModel): class FeedbackCreateRequestTokenBased(BaseModel): """Shared information between create requests of feedback and feedback objects.""" - token_or_url: UUID + token_or_url: Union[UUID, str] """The associated run ID this feedback is logged for.""" score: Optional[Union[float, int, bool]] = None diff --git a/tests/unit_tests/test_server_client.py b/tests/unit_tests/test_server_client.py index aad1fc44..927346bb 100644 --- a/tests/unit_tests/test_server_client.py +++ b/tests/unit_tests/test_server_client.py @@ -2796,19 +2796,14 @@ def get_by_session_id(session_id: str) -> BaseChatMessageHistory: async def get_langsmith_client() -> AsyncIterator[MagicMock]: """Get a patched langsmith client.""" with patch("langserve.api_handler.ls_client") as mocked_ls_client_package: - with patch("langserve.api_handler._is_scoped_feedback_enabled") as f: - # Enable scoped feedback for now. - f.return_value = True - with patch( - "langserve.api_handler.tracing_is_enabled" - ) as tracing_is_enabled: - tracing_is_enabled.return_value = True - mocked_client = MagicMock(auto_spec=Client) - mocked_ls_client_package.Client.return_value = mocked_client - yield mocked_client - - -async def test_scoped_feedback() -> None: + with patch("langserve.api_handler.tracing_is_enabled") as tracing_is_enabled: + tracing_is_enabled.return_value = True + mocked_client = MagicMock(auto_spec=Client) + mocked_ls_client_package.Client.return_value = mocked_client + yield mocked_client + + +async def test_token_feedback_included_in_responses() -> None: """Test that information to leave scoped feedback is passed to the client is present in the server response. """ @@ -2954,6 +2949,35 @@ async def test_scoped_feedback() -> None: "type": "metadata", } + # Test astream log + response = await async_client.post( + "/stream_log", + json={"input": "hello"}, + ) + events = _decode_eventstream(response.text) + for event in events: + if "data" in event and "run_id" in event["data"]: + del event["data"]["run_id"] + + # Find the metadata event and pull it out + metadata_event = None + for event in events: + if event["type"] == "metadata": + metadata_event = event + + assert metadata_event == { + "data": { + "feedback_tokens": [ + { + "expires_at": "2023-01-01T00:00:00", + "key": "foo", + "token_url": "feedback_id", + } + ] + }, + "type": "metadata", + } + async def test_passing_run_id_from_client() -> None: """test that the client can set a run id if server allows it.""" @@ -3014,3 +3038,37 @@ async def test_passing_bad_runnable_to_add_routes() -> None: add_routes(FastAPI(), "not a runnable") assert e.match("Expected a Runnable, got ") + + +async def test_token_feedback_endpoint() -> None: + """Tests that the feedback endpoint can accept feedback to langsmith.""" + async with get_langsmith_client() as client: + local_app = FastAPI() + add_routes( + local_app, + RunnableLambda(lambda foo: "hello"), + token_feedback_config={ + "key_configs": [ + { + "key": "silliness", + } + ] + }, + ) + + async with get_async_test_client( + local_app, raise_app_exceptions=True + ) as async_client: + response = await async_client.post( + "/token_feedback", json={"token_or_url": "some_url", "score": 3} + ) + assert response.status_code == 200 + + call = client.create_feedback_from_token.call_args + assert call.args[0] == "some_url" + assert call.kwargs == { + "comment": None, + "metadata": {"from_langserve": True}, + "score": 3, + "value": None, + }