Skip to content

Commit

Permalink
Fixes associated with feedback endpoint (#593)
Browse files Browse the repository at this point in the history
- Fixes langsmith client not being enabled if token feedback was
specified
- Add unit test for token feedback endpoint
- Fix schema for token feedback to accept a str in addition to a UUID
- Add metadata even to astream log with token feedback information
  • Loading branch information
eyurtsev authored Apr 3, 2024
1 parent 1765f63 commit 8439937
Show file tree
Hide file tree
Showing 3 changed files with 104 additions and 20 deletions.
38 changes: 32 additions & 6 deletions langserve/api_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -477,11 +477,6 @@ def _add_callbacks(
_SEEN_NAMES = set()


def _is_scoped_feedback_enabled() -> bool:
"""Temporary hard-coded as False. Used only to enable during unit tests."""
return False


class PerKeyFeedbackConfig(TypedDict):
"""Per feedback configuration.
Expand Down Expand Up @@ -650,7 +645,8 @@ def __init__(
# remember to make relevant updates in the unit tests.
self._langsmith_client = (
ls_client.Client()
if tracing_is_enabled() and enable_feedback_endpoint
if tracing_is_enabled()
and (enable_feedback_endpoint or self._token_feedback_enabled)
else None
)

Expand Down Expand Up @@ -1181,6 +1177,7 @@ async def stream_log(
endpoint="stream_log",
server_config=server_config,
)
run_id = config["run_id"]
except BaseException:
# Exceptions will be properly translated by default FastAPI middleware
# to either 422 (on input validation) or 500 internal server errors.
Expand All @@ -1194,8 +1191,25 @@ async def stream_log(
except RequestValidationError:
raise

feedback_key: Optional[str]

if self._token_feedback_enabled:
# Create task to create a presigned feedback token
feedback_key: str = self._token_feedback_config["key_configs"][0]["key"]
feedback_coro = run_in_executor(
None,
self._langsmith_client.create_presigned_feedback_token,
run_id,
feedback_key,
)
task: Optional[asyncio.Task] = asyncio.create_task(feedback_coro)
else:
feedback_key = None
task = None

async def _stream_log() -> AsyncIterator[dict]:
"""Stream the output of the runnable."""
has_sent_metadata = False
try:
async for chunk in self._runnable.astream_log(
input_,
Expand Down Expand Up @@ -1229,6 +1243,18 @@ async def _stream_log() -> AsyncIterator[dict]:
"data": self._serializer.dumps(data).decode("utf-8"),
"event": "data",
}

# Send a metadata event as soon as possible
if not has_sent_metadata and self._enable_feedback_endpoint:
if task is None:
raise AssertionError("Feedback token task was not created.")
if not task.done():
continue
feedback_token = task.result()
yield _create_metadata_event(
run_id, feedback_key, feedback_token
)
has_sent_metadata = True
yield {"event": "end"}
except BaseException:
yield {
Expand Down
2 changes: 1 addition & 1 deletion langserve/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ class BaseFeedback(BaseModel):
class FeedbackCreateRequestTokenBased(BaseModel):
"""Shared information between create requests of feedback and feedback objects."""

token_or_url: UUID
token_or_url: Union[UUID, str]
"""The associated run ID this feedback is logged for."""

score: Optional[Union[float, int, bool]] = None
Expand Down
84 changes: 71 additions & 13 deletions tests/unit_tests/test_server_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2796,19 +2796,14 @@ def get_by_session_id(session_id: str) -> BaseChatMessageHistory:
async def get_langsmith_client() -> AsyncIterator[MagicMock]:
"""Get a patched langsmith client."""
with patch("langserve.api_handler.ls_client") as mocked_ls_client_package:
with patch("langserve.api_handler._is_scoped_feedback_enabled") as f:
# Enable scoped feedback for now.
f.return_value = True
with patch(
"langserve.api_handler.tracing_is_enabled"
) as tracing_is_enabled:
tracing_is_enabled.return_value = True
mocked_client = MagicMock(auto_spec=Client)
mocked_ls_client_package.Client.return_value = mocked_client
yield mocked_client


async def test_scoped_feedback() -> None:
with patch("langserve.api_handler.tracing_is_enabled") as tracing_is_enabled:
tracing_is_enabled.return_value = True
mocked_client = MagicMock(auto_spec=Client)
mocked_ls_client_package.Client.return_value = mocked_client
yield mocked_client


async def test_token_feedback_included_in_responses() -> None:
"""Test that information to leave scoped feedback is passed to the client
is present in the server response.
"""
Expand Down Expand Up @@ -2954,6 +2949,35 @@ async def test_scoped_feedback() -> None:
"type": "metadata",
}

# Test astream log
response = await async_client.post(
"/stream_log",
json={"input": "hello"},
)
events = _decode_eventstream(response.text)
for event in events:
if "data" in event and "run_id" in event["data"]:
del event["data"]["run_id"]

# Find the metadata event and pull it out
metadata_event = None
for event in events:
if event["type"] == "metadata":
metadata_event = event

assert metadata_event == {
"data": {
"feedback_tokens": [
{
"expires_at": "2023-01-01T00:00:00",
"key": "foo",
"token_url": "feedback_id",
}
]
},
"type": "metadata",
}


async def test_passing_run_id_from_client() -> None:
"""test that the client can set a run id if server allows it."""
Expand Down Expand Up @@ -3014,3 +3038,37 @@ async def test_passing_bad_runnable_to_add_routes() -> None:
add_routes(FastAPI(), "not a runnable")

assert e.match("Expected a Runnable, got <class 'str'>")


async def test_token_feedback_endpoint() -> None:
"""Tests that the feedback endpoint can accept feedback to langsmith."""
async with get_langsmith_client() as client:
local_app = FastAPI()
add_routes(
local_app,
RunnableLambda(lambda foo: "hello"),
token_feedback_config={
"key_configs": [
{
"key": "silliness",
}
]
},
)

async with get_async_test_client(
local_app, raise_app_exceptions=True
) as async_client:
response = await async_client.post(
"/token_feedback", json={"token_or_url": "some_url", "score": 3}
)
assert response.status_code == 200

call = client.create_feedback_from_token.call_args
assert call.args[0] == "some_url"
assert call.kwargs == {
"comment": None,
"metadata": {"from_langserve": True},
"score": 3,
"value": None,
}

0 comments on commit 8439937

Please sign in to comment.