Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixes associated with feedback endpoint #593

Merged
merged 1 commit into from
Apr 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 32 additions & 6 deletions langserve/api_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -477,11 +477,6 @@ def _add_callbacks(
_SEEN_NAMES = set()


def _is_scoped_feedback_enabled() -> bool:
"""Temporary hard-coded as False. Used only to enable during unit tests."""
return False


class PerKeyFeedbackConfig(TypedDict):
"""Per feedback configuration.

Expand Down Expand Up @@ -650,7 +645,8 @@ def __init__(
# remember to make relevant updates in the unit tests.
self._langsmith_client = (
ls_client.Client()
if tracing_is_enabled() and enable_feedback_endpoint
if tracing_is_enabled()
and (enable_feedback_endpoint or self._token_feedback_enabled)
else None
)

Expand Down Expand Up @@ -1181,6 +1177,7 @@ async def stream_log(
endpoint="stream_log",
server_config=server_config,
)
run_id = config["run_id"]
except BaseException:
# Exceptions will be properly translated by default FastAPI middleware
# to either 422 (on input validation) or 500 internal server errors.
Expand All @@ -1194,8 +1191,25 @@ async def stream_log(
except RequestValidationError:
raise

feedback_key: Optional[str]

if self._token_feedback_enabled:
# Create task to create a presigned feedback token
feedback_key: str = self._token_feedback_config["key_configs"][0]["key"]
feedback_coro = run_in_executor(
None,
self._langsmith_client.create_presigned_feedback_token,
run_id,
feedback_key,
)
task: Optional[asyncio.Task] = asyncio.create_task(feedback_coro)
else:
feedback_key = None
task = None

async def _stream_log() -> AsyncIterator[dict]:
"""Stream the output of the runnable."""
has_sent_metadata = False
try:
async for chunk in self._runnable.astream_log(
input_,
Expand Down Expand Up @@ -1229,6 +1243,18 @@ async def _stream_log() -> AsyncIterator[dict]:
"data": self._serializer.dumps(data).decode("utf-8"),
"event": "data",
}

# Send a metadata event as soon as possible
if not has_sent_metadata and self._enable_feedback_endpoint:
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If _token_feedback_enabled is False but _enable_feedback_endpoint is True, task will be None, causing AssertionError("Feedback token task was not created.") to always be raised. Did you mean to check
_token_feedback_enabled here?
Can be reproduced by using the default example in

langserve/README.md

Lines 399 to 406 in 109445b

add_routes(
app,
chain.with_types(input_type=InputChat),
enable_feedback_endpoint=True,
enable_public_trace_link_endpoint=True,
playground_type="chat",
)
```

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed at #620

if task is None:
raise AssertionError("Feedback token task was not created.")
if not task.done():
continue
feedback_token = task.result()
yield _create_metadata_event(
run_id, feedback_key, feedback_token
)
has_sent_metadata = True
yield {"event": "end"}
except BaseException:
yield {
Expand Down
2 changes: 1 addition & 1 deletion langserve/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ class BaseFeedback(BaseModel):
class FeedbackCreateRequestTokenBased(BaseModel):
"""Shared information between create requests of feedback and feedback objects."""

token_or_url: UUID
token_or_url: Union[UUID, str]
"""The associated run ID this feedback is logged for."""

score: Optional[Union[float, int, bool]] = None
Expand Down
84 changes: 71 additions & 13 deletions tests/unit_tests/test_server_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2796,19 +2796,14 @@ def get_by_session_id(session_id: str) -> BaseChatMessageHistory:
async def get_langsmith_client() -> AsyncIterator[MagicMock]:
"""Get a patched langsmith client."""
with patch("langserve.api_handler.ls_client") as mocked_ls_client_package:
with patch("langserve.api_handler._is_scoped_feedback_enabled") as f:
# Enable scoped feedback for now.
f.return_value = True
with patch(
"langserve.api_handler.tracing_is_enabled"
) as tracing_is_enabled:
tracing_is_enabled.return_value = True
mocked_client = MagicMock(auto_spec=Client)
mocked_ls_client_package.Client.return_value = mocked_client
yield mocked_client


async def test_scoped_feedback() -> None:
with patch("langserve.api_handler.tracing_is_enabled") as tracing_is_enabled:
tracing_is_enabled.return_value = True
mocked_client = MagicMock(auto_spec=Client)
mocked_ls_client_package.Client.return_value = mocked_client
yield mocked_client


async def test_token_feedback_included_in_responses() -> None:
"""Test that information to leave scoped feedback is passed to the client
is present in the server response.
"""
Expand Down Expand Up @@ -2954,6 +2949,35 @@ async def test_scoped_feedback() -> None:
"type": "metadata",
}

# Test astream log
response = await async_client.post(
"/stream_log",
json={"input": "hello"},
)
events = _decode_eventstream(response.text)
for event in events:
if "data" in event and "run_id" in event["data"]:
del event["data"]["run_id"]

# Find the metadata event and pull it out
metadata_event = None
for event in events:
if event["type"] == "metadata":
metadata_event = event

assert metadata_event == {
"data": {
"feedback_tokens": [
{
"expires_at": "2023-01-01T00:00:00",
"key": "foo",
"token_url": "feedback_id",
}
]
},
"type": "metadata",
}


async def test_passing_run_id_from_client() -> None:
"""test that the client can set a run id if server allows it."""
Expand Down Expand Up @@ -3014,3 +3038,37 @@ async def test_passing_bad_runnable_to_add_routes() -> None:
add_routes(FastAPI(), "not a runnable")

assert e.match("Expected a Runnable, got <class 'str'>")


async def test_token_feedback_endpoint() -> None:
    """Verify that the /token_feedback endpoint forwards feedback to langsmith.

    Mounts a minimal runnable with a token feedback config, posts a
    token-based feedback payload, and asserts the mocked langsmith client
    received the expected call.
    """
    async with get_langsmith_client() as mocked_client:
        app_under_test = FastAPI()
        feedback_config = {
            "key_configs": [
                {
                    "key": "silliness",
                }
            ]
        }
        add_routes(
            app_under_test,
            RunnableLambda(lambda _: "hello"),
            token_feedback_config=feedback_config,
        )

        async with get_async_test_client(
            app_under_test, raise_app_exceptions=True
        ) as http_client:
            resp = await http_client.post(
                "/token_feedback", json={"token_or_url": "some_url", "score": 3}
            )
            assert resp.status_code == 200

            # The mocked client records how the handler invoked it.
            recorded_call = mocked_client.create_feedback_from_token.call_args
            assert recorded_call.args[0] == "some_url"
            expected_kwargs = {
                "comment": None,
                "metadata": {"from_langserve": True},
                "score": 3,
                "value": None,
            }
            assert recorded_call.kwargs == expected_kwargs
Loading