diff --git a/langserve/server.py b/langserve/server.py index 469b73a3..56a9c23f 100644 --- a/langserve/server.py +++ b/langserve/server.py @@ -221,8 +221,25 @@ def _scrub_exceptions_in_event(event: CallbackEventDict) -> CallbackEventDict: _APP_SEEN = weakref.WeakSet() +_APP_TO_PATHS = weakref.WeakKeyDictionary() +def _register_path_for_app(app: Union[FastAPI, APIRouter], path: str) -> None: + """Register a path when its added to app. Raise if path already seen.""" + if app in _APP_TO_PATHS: + seen_paths = _APP_TO_PATHS.get(app) + if path in seen_paths: + raise ValueError( + f"A runnable already exists at path: {path}. If adding " + f"multiple runnables make sure they have different paths." + ) + seen_paths.add(path) + else: + _APP_TO_PATHS[app] = {path} + + +# PUBLIC API + # PUBLIC API @@ -279,6 +296,7 @@ def add_routes( "Use `pip install sse_starlette` to install." ) + _register_path_for_app(app, path) well_known_lc_serializer = WellKnownLCSerializer() if hasattr(app, "openapi_tags") and app not in _APP_SEEN: diff --git a/tests/unit_tests/test_server_client.py b/tests/unit_tests/test_server_client.py index ca682c94..8d05a726 100644 --- a/tests/unit_tests/test_server_client.py +++ b/tests/unit_tests/test_server_client.py @@ -1178,3 +1178,11 @@ def test_error_on_bad_path() -> None: with pytest.raises(ValueError): add_routes(app, RunnableLambda(lambda foo: "hello"), path="foo") add_routes(app, RunnableLambda(lambda foo: "hello"), path="/foo") + + +def test_error_on_path_collision() -> None: + """Test error on path collision.""" + app = FastAPI() + add_routes(app, RunnableLambda(lambda foo: "hello"), path="/foo") + with pytest.raises(ValueError): + add_routes(app, RunnableLambda(lambda foo: "hello"), path="/foo")