From dd25fdbd97aa1d05197d3751325224856636b6ae Mon Sep 17 00:00:00 2001 From: Jake Rachleff Date: Fri, 10 Nov 2023 15:46:48 -0800 Subject: [PATCH] fix all endpoints --- langserve/server.py | 56 ++++++++++++++++++++++++++++++++++----------- 1 file changed, 43 insertions(+), 13 deletions(-) diff --git a/langserve/server.py b/langserve/server.py index abac32ea..bee2f64b 100644 --- a/langserve/server.py +++ b/langserve/server.py @@ -439,7 +439,8 @@ def add_routes( Favor using runnable.with_types(input_type=..., output_type=...) instead. This parameter may get deprecated! config_keys: list of config keys that will be accepted, by default - no config keys are accepted. + no config keys are accepted. Cannot configure run_name, + which is set by default to the path of the API include_callback_events: Whether to include callback events in the response. If true, the client will be able to show trace information including events that occurred on the server side. @@ -474,6 +475,12 @@ def add_routes( f"Got an invalid path: {path}. " f"If specifying path please start it with a `/`" ) + + if "run_name" in config_keys: + raise ValueError( + f"Cannot configure run_name. " + f"Please remove it from config_keys." + ) namespace = path or "" @@ -555,9 +562,14 @@ def _route_name_with_config(name: str) -> str: InvokeResponse = create_invoke_response_model(model_namespace, output_type_) BatchResponse = create_batch_response_model(model_namespace, output_type_) - def _get_default_base_config() -> RunnableConfig: + def _update_config_with_defaults(incomingConfig: RunnableConfig) -> RunnableConfig: """Set up some baseline configuration for the underlying runnable.""" + + # Currently all defaults are non-overridable + overridable_default_config = RunnableConfig() + metadata = {} + is_hosted = os.environ.get("HOSTED_LANGSERVE_ENABLED", "false").lower() == "true" if is_hosted: hosted_metadata = { @@ -567,11 +579,21 @@ def _get_default_base_config() -> RunnableConfig: } metadata.update(hosted_metadata) - return RunnableConfig( + + non_overridable_default_config = RunnableConfig( run_name=path, metadata=metadata, ) + # merge_configs is last-writer-wins, so we specifically pass in the + # overridable configs first, then the user provided configs, then + # finally the non-overridable configs + return merge_configs( + overridable_default_config, + incomingConfig, + non_overridable_default_config, + ) + async def _get_config_and_input( request: Request, config_hash: str ) -> Tuple[RunnableConfig, Any]: @@ -584,7 +606,7 @@ async def _get_config_and_input( body = InvokeRequestShallowValidator.validate(body) # Merge the config from the path with the config from the body. - config = _unpack_request_config( + user_provided_config = _unpack_request_config( config_hash, body.config, keys=config_keys, @@ -592,6 +614,7 @@ async def _get_config_and_input( request=request, per_req_config_modifier=per_req_config_modifier, ) + config = _update_config_with_defaults(user_provided_config) # Unpack the input dynamically using the input schema of the runnable. # This takes into account changes in the input type when # using configuration. @@ -613,9 +636,9 @@ async def invoke( """Invoke the runnable with the given input and config.""" # We do not use the InvokeRequest model here since configurable runnables # have dynamic schema -- so the validation below is a bit more involved. - config, input_ = await _get_config_and_input(request, config_hash) + user_provided_config, input_ = await _get_config_and_input(request, config_hash) - config = merge_configs(_get_default_base_config(), config) + config = _update_config_with_defaults(user_provided_config) event_aggregator = AsyncEventAggregatorCallback() _add_tracing_info_to_metadata(config, request) config["callbacks"] = [event_aggregator] @@ -677,8 +700,7 @@ async def batch( model=ConfigPayload, request=request, per_req_config_modifier=per_req_config_modifier, - ) - for config in config + ) for config in config ] elif isinstance(config, dict): configs = _unpack_request_config( @@ -703,6 +725,7 @@ async def batch( {k: v for k, v in config_.items() if k in config_keys} for config_ in get_config_list(configs, len(inputs_)) ] + print(configs_) inputs = [ _unpack_input(runnable.with_config(config_).input_schema.validate(input_)) @@ -712,11 +735,13 @@ async def batch( # Update the configuration with callbacks aggregators = [AsyncEventAggregatorCallback() for _ in range(len(inputs))] + final_configs = [] for config_, aggregator in zip(configs_, aggregators): _add_tracing_info_to_metadata(config_, request) config_["callbacks"] = [aggregator] + final_configs.append(_update_config_with_defaults(config_)) - output = await runnable.abatch(inputs, config=configs_) + output = await runnable.abatch(inputs, config=final_configs) if include_callback_events: callback_events = [ @@ -952,13 +977,14 @@ async def _stream_log() -> AsyncIterator[dict]: async def input_schema(request: Request, config_hash: str = "") -> Any: """Return the input schema of the runnable.""" with _with_validation_error_translation(): - config = _unpack_request_config( + user_provided_config = _unpack_request_config( config_hash, keys=config_keys, model=ConfigPayload, request=request, per_req_config_modifier=per_req_config_modifier, ) + config = _update_config_with_defaults(user_provided_config) return runnable.get_input_schema(config).schema() @@ -975,13 +1001,14 @@ async def input_schema(request: Request, config_hash: str = "") -> Any: async def output_schema(request: Request, config_hash: str = "") -> Any: """Return the output schema of the runnable.""" with _with_validation_error_translation(): - config = _unpack_request_config( + user_provided_config = _unpack_request_config( config_hash, keys=config_keys, model=ConfigPayload, request=request, per_req_config_modifier=per_req_config_modifier, ) + config = _update_config_with_defaults(user_provided_config) return runnable.get_output_schema(config).schema() @app.get( @@ -995,13 +1022,14 @@ async def output_schema(request: Request, config_hash: str = "") -> Any: async def config_schema(request: Request, config_hash: str = "") -> Any: """Return the config schema of the runnable.""" with _with_validation_error_translation(): - config = _unpack_request_config( + user_provided_config = _unpack_request_config( config_hash, keys=config_keys, model=ConfigPayload, request=request, per_req_config_modifier=per_req_config_modifier, ) + config = _update_config_with_defaults(user_provided_config) return runnable.with_config(config).config_schema(include=config_keys).schema() @app.get( @@ -1014,13 +1042,15 @@ async def playground( ) -> Any: """Return the playground of the runnable.""" with _with_validation_error_translation(): - config = _unpack_request_config( + user_provided_config = _unpack_request_config( config_hash, keys=config_keys, model=ConfigPayload, request=request, per_req_config_modifier=per_req_config_modifier, ) + config = _update_config_with_defaults(user_provided_config) + return await serve_playground( runnable.with_config(config), runnable.with_config(config).input_schema,