def _update_config_with_defaults(
    path: str, incomingConfig: RunnableConfig, endpoint: Optional[str] = None
) -> RunnableConfig:
    """Wrap a request-supplied config with LangServe's default configuration.

    Args:
        path: Route path; used as the ``run_name`` of the forced defaults.
        incomingConfig: The config that arrived with the request.
        endpoint: Optional endpoint name. When truthy it is recorded in the
            forced metadata under the ``"LangServe Endpoint"`` key.

    Returns:
        The result of ``merge_configs`` applied, in order, to the overridable
        defaults (currently empty), ``incomingConfig`` and the forced defaults.
    """
    env = os.environ
    forced_metadata = {}

    if endpoint:
        forced_metadata["LangServe Endpoint"] = endpoint

    # Deployment details are attached only when the hosted flag is set to
    # "true" (compared case-insensitively); absent variables become "".
    hosted_flag = env.get("HOSTED_LANGSERVE_ENABLED", "false")
    if hosted_flag.lower() == "true":
        forced_metadata["Commit SHA"] = env.get("HOSTED_LANGSERVE_GIT_COMMIT", "")
        forced_metadata["Git Repo Subdirectory"] = env.get(
            "HOSTED_LANGSERVE_GIT_REPO_PATH", ""
        )
        forced_metadata["Git Repo URL"] = env.get("HOSTED_LANGSERVE_GIT_REPO", "")

    # Argument order matters: merge_configs is treated as last-writer-wins, so
    # the overridable defaults go first (there are none yet, hence the empty
    # config), then the caller's config, and the forced defaults go last.
    return merge_configs(
        RunnableConfig(),
        incomingConfig,
        RunnableConfig(run_name=path, metadata=forced_metadata),
    )
) namespace = path or "" @@ -562,40 +600,8 @@ def _route_name_with_config(name: str) -> str: InvokeResponse = create_invoke_response_model(model_namespace, output_type_) BatchResponse = create_batch_response_model(model_namespace, output_type_) - def _update_config_with_defaults(incomingConfig: RunnableConfig) -> RunnableConfig: - """Set up some baseline configuration for the underlying runnable.""" - - # Currently all defaults are non-overridable - overridable_default_config = RunnableConfig() - - metadata = {} - - is_hosted = os.environ.get("HOSTED_LANGSERVE_ENABLED", "false").lower() == "true" - if is_hosted: - hosted_metadata = { - "Commit SHA": os.environ.get("HOSTED_LANGSERVE_GIT_COMMIT", ""), - "Git Repo Subdirectory": os.environ.get("HOSTED_LANGSERVE_GIT_REPO_PATH", ""), - "Git Repo URL": os.environ.get("HOSTED_LANGSERVE_GIT_REPO", ""), - - } - metadata.update(hosted_metadata) - - non_overridable_default_config = RunnableConfig( - run_name=path, - metadata=metadata, - ) - - # merge_configs is last-writer-wins, so we specifically pass in the - # overridable configs first, then the user provided configs, then - # finally the non-overridable configs - return merge_configs( - overridable_default_config, - incomingConfig, - non_overridable_default_config, - ) - async def _get_config_and_input( - request: Request, config_hash: str + request: Request, config_hash: str, endpoint: Optional[str] = None ) -> Tuple[RunnableConfig, Any]: """Extract the config and input from the request, validating the request.""" try: @@ -614,7 +620,7 @@ async def _get_config_and_input( request=request, per_req_config_modifier=per_req_config_modifier, ) - config = _update_config_with_defaults(user_provided_config) + config = _update_config_with_defaults(path, user_provided_config) # Unpack the input dynamically using the input schema of the runnable. # This takes into account changes in the input type when # using configuration. 
@@ -636,9 +642,11 @@ async def invoke( """Invoke the runnable with the given input and config.""" # We do not use the InvokeRequest model here since configurable runnables # have dynamic schema -- so the validation below is a bit more involved. - user_provided_config, input_ = await _get_config_and_input(request, config_hash) + user_provided_config, input_ = await _get_config_and_input( + request, config_hash, "invoke" + ) - config = _update_config_with_defaults(user_provided_config) + config = _update_config_with_defaults(path, user_provided_config) event_aggregator = AsyncEventAggregatorCallback() _add_tracing_info_to_metadata(config, request) config["callbacks"] = [event_aggregator] @@ -700,7 +708,8 @@ async def batch( model=ConfigPayload, request=request, per_req_config_modifier=per_req_config_modifier, - ) for config in config + ) + for config in config ] elif isinstance(config, dict): configs = _unpack_request_config( @@ -725,7 +734,6 @@ async def batch( {k: v for k, v in config_.items() if k in config_keys} for config_ in get_config_list(configs, len(inputs_)) ] - print(configs_) inputs = [ _unpack_input(runnable.with_config(config_).input_schema.validate(input_)) @@ -739,7 +747,7 @@ async def batch( for config_, aggregator in zip(configs_, aggregators): _add_tracing_info_to_metadata(config_, request) config_["callbacks"] = [aggregator] - final_configs.append(_update_config_with_defaults(config_)) + final_configs.append(_update_config_with_defaults(path, config_)) output = await runnable.abatch(inputs, config=final_configs) @@ -779,7 +787,7 @@ async def stream( err_event = {} validation_exception: Optional[BaseException] = None try: - config, input_ = await _get_config_and_input(request, config_hash) + config, input_ = await _get_config_and_input(request, config_hash, "stream") except BaseException as e: validation_exception = e if isinstance(e, RequestValidationError): @@ -984,7 +992,7 @@ async def input_schema(request: Request, config_hash: str = "") -> 
Any: request=request, per_req_config_modifier=per_req_config_modifier, ) - config = _update_config_with_defaults(user_provided_config) + config = _update_config_with_defaults(path, user_provided_config) return runnable.get_input_schema(config).schema() @@ -1008,7 +1016,7 @@ async def output_schema(request: Request, config_hash: str = "") -> Any: request=request, per_req_config_modifier=per_req_config_modifier, ) - config = _update_config_with_defaults(user_provided_config) + config = _update_config_with_defaults(path, user_provided_config) return runnable.get_output_schema(config).schema() @app.get( @@ -1029,7 +1037,7 @@ async def config_schema(request: Request, config_hash: str = "") -> Any: request=request, per_req_config_modifier=per_req_config_modifier, ) - config = _update_config_with_defaults(user_provided_config) + config = _update_config_with_defaults(path, user_provided_config) return runnable.with_config(config).config_schema(include=config_keys).schema() @app.get( @@ -1049,7 +1057,7 @@ async def playground( request=request, per_req_config_modifier=per_req_config_modifier, ) - config = _update_config_with_defaults(user_provided_config) + config = _update_config_with_defaults(path, user_provided_config) return await serve_playground( runnable.with_config(config),