From bfad05a89487f213e51463ea041ac3d89559b3e9 Mon Sep 17 00:00:00 2001 From: jakerachleff Date: Mon, 20 Nov 2023 14:00:07 -0800 Subject: [PATCH] Revert "add default configs for all served runnables (#209)" This reverts commit b60b52daf6e9782d91e7c4e7576bd55465ff837a. --- langserve/server.py | 113 +++++++------------------ tests/unit_tests/test_server_client.py | 7 +- 2 files changed, 31 insertions(+), 89 deletions(-) diff --git a/langserve/server.py b/langserve/server.py index a80d6d50..7d9a4407 100644 --- a/langserve/server.py +++ b/langserve/server.py @@ -7,7 +7,6 @@ """ import contextlib import json -import os import re import weakref from inspect import isclass @@ -133,56 +132,6 @@ def _unpack_request_config( ) -def _update_config_with_defaults( - path: str, - incomingConfig: RunnableConfig, - request: Request, - *, - endpoint: Optional[str] = None, -) -> RunnableConfig: - """Set up some baseline configuration for the underlying runnable.""" - - # Currently all defaults are non-overridable - overridable_default_config = RunnableConfig() - - metadata = { - "__useragent": request.headers.get("user-agent"), - "__langserve_version": __version__, - } - - if endpoint: - metadata["__langserve_endpoint"] = endpoint - - is_hosted = os.environ.get("HOSTED_LANGSERVE_ENABLED", "false").lower() == "true" - if is_hosted: - hosted_metadata = { - "__langserve_hosted_git_commit_sha": os.environ.get( - "HOSTED_LANGSERVE_GIT_COMMIT", "" - ), - "__langserve_hosted_repo_subdirectory_path": os.environ.get( - "HOSTED_LANGSERVE_GIT_REPO_PATH", "" - ), - "__langserve_hosted_repo_url": os.environ.get( - "HOSTED_LANGSERVE_GIT_REPO", "" - ), - } - metadata.update(hosted_metadata) - - non_overridable_default_config = RunnableConfig( - run_name=path, - metadata=metadata, - ) - - # merge_configs is last-writer-wins, so we specifically pass in the - # overridable configs first, then the user provided configs, then - # finally the non-overridable configs - return merge_configs( - overridable_default_config, - incomingConfig, - non_overridable_default_config, - ) - - def _unpack_input(validated_model: BaseModel) -> Any: """Unpack the decoded input from the validated model.""" if hasattr(validated_model, "__root__"): @@ -283,6 +232,24 @@ def _add_namespace_to_model(namespace: str, model: Type[BaseModel]) -> Type[Base return model_with_unique_name +def _add_tracing_info_to_metadata(config: Dict[str, Any], request: Request) -> None: + """Add information useful for tracing and debugging purposes. + + Args: + config: The config to expand with tracing information. + request: The request to use for expanding the metadata. + """ + + metadata = config["metadata"] if "metadata" in config else {} + + info = { + "__useragent": request.headers.get("user-agent"), + "__langserve_version": __version__, + } + metadata.update(info) + config["metadata"] = metadata + + def _scrub_exceptions_in_event(event: CallbackEventDict) -> CallbackEventDict: """Scrub exceptions and change to a serializable format.""" type_ = event["type"] @@ -500,8 +467,7 @@ def add_routes( This parameter may get deprecated! config_keys: list of config keys that will be accepted, by default will accept `configurable` key in the config. Will only be used - if the runnable is configurable. Cannot configure run_name, - which is set by default to the path of the API. + if the runnable is configurable. include_callback_events: Whether to include callback events in the response. If true, the client will be able to show trace information including events that occurred on the server side. @@ -537,11 +503,6 @@ def add_routes( f"If specifying path please start it with a `/`" ) - if "run_name" in config_keys: - raise ValueError( - "Cannot configure run_name. Please remove it from config_keys." - ) - namespace = path or "" model_namespace = _replace_non_alphanumeric_with_underscores(path.strip("/")) @@ -653,7 +614,7 @@ def _route_name_with_config(name: str) -> str: BatchResponse = create_batch_response_model(model_namespace, output_type_) async def _get_config_and_input( - request: Request, config_hash: str, *, endpoint: Optional[str] = None + request: Request, config_hash: str ) -> Tuple[RunnableConfig, Any]: """Extract the config and input from the request, validating the request.""" try: @@ -664,7 +625,7 @@ async def _get_config_and_input( body = InvokeRequestShallowValidator.validate(body) # Merge the config from the path with the config from the body. - user_provided_config = _unpack_request_config( + config = _unpack_request_config( config_hash, body.config, config_keys=config_keys, @@ -672,9 +633,6 @@ async def _get_config_and_input( request=request, per_req_config_modifier=per_req_config_modifier, ) - config = _update_config_with_defaults( - path, user_provided_config, request, endpoint=endpoint - ) # Unpack the input dynamically using the input schema of the runnable. # This takes into account changes in the input type when # using configuration. @@ -696,11 +654,10 @@ async def invoke( """Invoke the runnable with the given input and config.""" # We do not use the InvokeRequest model here since configurable runnables # have dynamic schema -- so the validation below is a bit more involved. - config, input_ = await _get_config_and_input( - request, config_hash, endpoint="invoke" - ) + config, input_ = await _get_config_and_input(request, config_hash) event_aggregator = AsyncEventAggregatorCallback() + _add_tracing_info_to_metadata(config, request) config["callbacks"] = [event_aggregator] output = await runnable.ainvoke(input_, config=config) @@ -795,14 +752,11 @@ async def batch( # Update the configuration with callbacks aggregators = [AsyncEventAggregatorCallback() for _ in range(len(inputs))] - final_configs = [] for config_, aggregator in zip(configs_, aggregators): + _add_tracing_info_to_metadata(config_, request) config_["callbacks"] = [aggregator] - final_configs.append( - _update_config_with_defaults(path, config_, request, endpoint="batch") - ) - output = await runnable.abatch(inputs, config=final_configs) + output = await runnable.abatch(inputs, config=configs_) if include_callback_events: callback_events = [ @@ -840,9 +794,7 @@ async def stream( err_event = {} validation_exception: Optional[BaseException] = None try: - config, input_ = await _get_config_and_input( - request, config_hash, endpoint="stream" - ) + config, input_ = await _get_config_and_input(request, config_hash) except BaseException as e: validation_exception = e if isinstance(e, RequestValidationError): @@ -1040,14 +992,13 @@ async def _stream_log() -> AsyncIterator[dict]: async def input_schema(request: Request, config_hash: str = "") -> Any: """Return the input schema of the runnable.""" with _with_validation_error_translation(): - user_provided_config = _unpack_request_config( + config = _unpack_request_config( config_hash, config_keys=config_keys, model=ConfigPayload, request=request, per_req_config_modifier=per_req_config_modifier, ) - config = _update_config_with_defaults(path, user_provided_config, request) return runnable.get_input_schema(config).schema() @@ -1064,14 +1015,13 @@ async def input_schema(request: Request, config_hash: str = "") -> Any: async def output_schema(request: Request, config_hash: str = "") -> Any: """Return the output schema of the runnable.""" with _with_validation_error_translation(): - user_provided_config = _unpack_request_config( + config = _unpack_request_config( config_hash, config_keys=config_keys, model=ConfigPayload, request=request, per_req_config_modifier=per_req_config_modifier, ) - config = _update_config_with_defaults(path, user_provided_config, request) return runnable.get_output_schema(config).schema() @app.get( @@ -1085,14 +1035,13 @@ async def output_schema(request: Request, config_hash: str = "") -> Any: async def config_schema(request: Request, config_hash: str = "") -> Any: """Return the config schema of the runnable.""" with _with_validation_error_translation(): - user_provided_config = _unpack_request_config( + config = _unpack_request_config( config_hash, config_keys=config_keys, model=ConfigPayload, request=request, per_req_config_modifier=per_req_config_modifier, ) - config = _update_config_with_defaults(path, user_provided_config, request) return runnable.with_config(config).config_schema(include=config_keys).schema() @app.get( @@ -1105,7 +1054,7 @@ async def playground( ) -> Any: """Return the playground of the runnable.""" with _with_validation_error_translation(): - user_provided_config = _unpack_request_config( + config = _unpack_request_config( config_hash, config_keys=config_keys, model=ConfigPayload, @@ -1113,8 +1062,6 @@ async def playground( per_req_config_modifier=per_req_config_modifier, ) - config = _update_config_with_defaults(path, user_provided_config, request) - if isinstance(app, FastAPI): # type: ignore base_url = f"{namespace}/playground" else: diff --git a/tests/unit_tests/test_server_client.py b/tests/unit_tests/test_server_client.py index 4c62b854..05c60635 100644 --- a/tests/unit_tests/test_server_client.py +++ b/tests/unit_tests/test_server_client.py @@ -882,7 +882,7 @@ async def add_one(x: int) -> int: server_runnable2, input_type=int, path="/add_one_config", - config_keys=["tags", "metadata"], + config_keys=["tags", "run_name", "metadata"], ) async with get_async_remote_runnable( @@ -909,11 +909,8 @@ async def add_one(x: int) -> int: # will still be added config_seen = server_runnable_spy.call_args[0][1] assert "metadata" in config_seen - assert "a" not in config_seen["metadata"] assert "__useragent" in config_seen["metadata"] assert "__langserve_version" in config_seen["metadata"] - assert "__langserve_endpoint" in config_seen["metadata"] - assert config_seen["metadata"]["__langserve_endpoint"] == "invoke" server_runnable2_spy = mocker.spy(server_runnable2, "ainvoke") async with get_async_remote_runnable(app, path="/add_one_config") as runnable2: @@ -926,8 +923,6 @@ async def add_one(x: int) -> int: assert config_seen["metadata"]["a"] == 5 assert "__useragent" in config_seen["metadata"] assert "__langserve_version" in config_seen["metadata"] - assert "__langserve_endpoint" in config_seen["metadata"] - assert config_seen["metadata"]["__langserve_endpoint"] == "invoke" async def test_input_validation_with_lc_types(event_loop: AbstractEventLoop) -> None: