Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Revert "add default configs for all served runnables" #248

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
113 changes: 30 additions & 83 deletions langserve/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
"""
import contextlib
import json
import os
import re
import weakref
from inspect import isclass
Expand Down Expand Up @@ -133,56 +132,6 @@ def _unpack_request_config(
)


def _update_config_with_defaults(
    path: str,
    incomingConfig: RunnableConfig,
    request: Request,
    *,
    endpoint: Optional[str] = None,
) -> RunnableConfig:
    """Merge baseline server-side defaults into a request's runnable config.

    Args:
        path: The path the runnable is served at; becomes the run name.
        incomingConfig: The config supplied by (or derived from) the request.
        request: The incoming HTTP request, used for the user-agent header.
        endpoint: Optional endpoint name (e.g. "invoke") recorded in metadata.

    Returns:
        A merged RunnableConfig with server defaults applied.
    """
    # At present nothing in the defaults is caller-overridable.
    overridable: RunnableConfig = RunnableConfig()

    base_metadata = {
        "__useragent": request.headers.get("user-agent"),
        "__langserve_version": __version__,
    }
    if endpoint:
        base_metadata["__langserve_endpoint"] = endpoint

    # When running on hosted LangServe, stamp deployment provenance as well.
    if os.environ.get("HOSTED_LANGSERVE_ENABLED", "false").lower() == "true":
        for meta_key, env_var in (
            ("__langserve_hosted_git_commit_sha", "HOSTED_LANGSERVE_GIT_COMMIT"),
            (
                "__langserve_hosted_repo_subdirectory_path",
                "HOSTED_LANGSERVE_GIT_REPO_PATH",
            ),
            ("__langserve_hosted_repo_url", "HOSTED_LANGSERVE_GIT_REPO"),
        ):
            base_metadata[meta_key] = os.environ.get(env_var, "")

    non_overridable = RunnableConfig(
        run_name=path,
        metadata=base_metadata,
    )

    # merge_configs is last-writer-wins: overridable defaults first, then the
    # user-provided config, and finally the non-overridable defaults.
    return merge_configs(overridable, incomingConfig, non_overridable)


def _unpack_input(validated_model: BaseModel) -> Any:
"""Unpack the decoded input from the validated model."""
if hasattr(validated_model, "__root__"):
Expand Down Expand Up @@ -283,6 +232,24 @@ def _add_namespace_to_model(namespace: str, model: Type[BaseModel]) -> Type[Base
return model_with_unique_name


def _add_tracing_info_to_metadata(config: Dict[str, Any], request: Request) -> None:
    """Add information useful for tracing and debugging purposes.

    Mutates ``config`` in place: ensures a ``metadata`` dict exists on it and
    stamps that dict with the requesting client's user agent and the running
    langserve version.

    Args:
        config: The config to expand with tracing information.
        request: The request whose headers supply the user agent.
    """
    # setdefault both reuses an existing metadata dict and installs an empty
    # one when absent, replacing the original double-lookup
    # (`config["metadata"] if "metadata" in config else {}`) and the redundant
    # re-assignment back onto config.
    metadata = config.setdefault("metadata", {})
    metadata.update(
        {
            "__useragent": request.headers.get("user-agent"),
            "__langserve_version": __version__,
        }
    )


def _scrub_exceptions_in_event(event: CallbackEventDict) -> CallbackEventDict:
"""Scrub exceptions and change to a serializable format."""
type_ = event["type"]
Expand Down Expand Up @@ -500,8 +467,7 @@ def add_routes(
This parameter may get deprecated!
config_keys: list of config keys that will be accepted, by default
will accept `configurable` key in the config. Will only be used
if the runnable is configurable. Cannot configure run_name,
which is set by default to the path of the API.
if the runnable is configurable.
include_callback_events: Whether to include callback events in the response.
If true, the client will be able to show trace information
including events that occurred on the server side.
Expand Down Expand Up @@ -537,11 +503,6 @@ def add_routes(
f"If specifying path please start it with a `/`"
)

if "run_name" in config_keys:
raise ValueError(
"Cannot configure run_name. Please remove it from config_keys."
)

namespace = path or ""

model_namespace = _replace_non_alphanumeric_with_underscores(path.strip("/"))
Expand Down Expand Up @@ -653,7 +614,7 @@ def _route_name_with_config(name: str) -> str:
BatchResponse = create_batch_response_model(model_namespace, output_type_)

async def _get_config_and_input(
request: Request, config_hash: str, *, endpoint: Optional[str] = None
request: Request, config_hash: str
) -> Tuple[RunnableConfig, Any]:
"""Extract the config and input from the request, validating the request."""
try:
Expand All @@ -664,17 +625,14 @@ async def _get_config_and_input(
body = InvokeRequestShallowValidator.validate(body)

# Merge the config from the path with the config from the body.
user_provided_config = _unpack_request_config(
config = _unpack_request_config(
config_hash,
body.config,
config_keys=config_keys,
model=ConfigPayload,
request=request,
per_req_config_modifier=per_req_config_modifier,
)
config = _update_config_with_defaults(
path, user_provided_config, request, endpoint=endpoint
)
# Unpack the input dynamically using the input schema of the runnable.
# This takes into account changes in the input type when
# using configuration.
Expand All @@ -696,11 +654,10 @@ async def invoke(
"""Invoke the runnable with the given input and config."""
# We do not use the InvokeRequest model here since configurable runnables
# have dynamic schema -- so the validation below is a bit more involved.
config, input_ = await _get_config_and_input(
request, config_hash, endpoint="invoke"
)
config, input_ = await _get_config_and_input(request, config_hash)

event_aggregator = AsyncEventAggregatorCallback()
_add_tracing_info_to_metadata(config, request)
config["callbacks"] = [event_aggregator]
output = await runnable.ainvoke(input_, config=config)

Expand Down Expand Up @@ -795,14 +752,11 @@ async def batch(
# Update the configuration with callbacks
aggregators = [AsyncEventAggregatorCallback() for _ in range(len(inputs))]

final_configs = []
for config_, aggregator in zip(configs_, aggregators):
_add_tracing_info_to_metadata(config_, request)
config_["callbacks"] = [aggregator]
final_configs.append(
_update_config_with_defaults(path, config_, request, endpoint="batch")
)

output = await runnable.abatch(inputs, config=final_configs)
output = await runnable.abatch(inputs, config=configs_)

if include_callback_events:
callback_events = [
Expand Down Expand Up @@ -840,9 +794,7 @@ async def stream(
err_event = {}
validation_exception: Optional[BaseException] = None
try:
config, input_ = await _get_config_and_input(
request, config_hash, endpoint="stream"
)
config, input_ = await _get_config_and_input(request, config_hash)
except BaseException as e:
validation_exception = e
if isinstance(e, RequestValidationError):
Expand Down Expand Up @@ -1040,14 +992,13 @@ async def _stream_log() -> AsyncIterator[dict]:
async def input_schema(request: Request, config_hash: str = "") -> Any:
"""Return the input schema of the runnable."""
with _with_validation_error_translation():
user_provided_config = _unpack_request_config(
config = _unpack_request_config(
config_hash,
config_keys=config_keys,
model=ConfigPayload,
request=request,
per_req_config_modifier=per_req_config_modifier,
)
config = _update_config_with_defaults(path, user_provided_config, request)

return runnable.get_input_schema(config).schema()

Expand All @@ -1064,14 +1015,13 @@ async def input_schema(request: Request, config_hash: str = "") -> Any:
async def output_schema(request: Request, config_hash: str = "") -> Any:
"""Return the output schema of the runnable."""
with _with_validation_error_translation():
user_provided_config = _unpack_request_config(
config = _unpack_request_config(
config_hash,
config_keys=config_keys,
model=ConfigPayload,
request=request,
per_req_config_modifier=per_req_config_modifier,
)
config = _update_config_with_defaults(path, user_provided_config, request)
return runnable.get_output_schema(config).schema()

@app.get(
Expand All @@ -1085,14 +1035,13 @@ async def output_schema(request: Request, config_hash: str = "") -> Any:
async def config_schema(request: Request, config_hash: str = "") -> Any:
"""Return the config schema of the runnable."""
with _with_validation_error_translation():
user_provided_config = _unpack_request_config(
config = _unpack_request_config(
config_hash,
config_keys=config_keys,
model=ConfigPayload,
request=request,
per_req_config_modifier=per_req_config_modifier,
)
config = _update_config_with_defaults(path, user_provided_config, request)
return runnable.with_config(config).config_schema(include=config_keys).schema()

@app.get(
Expand All @@ -1105,16 +1054,14 @@ async def playground(
) -> Any:
"""Return the playground of the runnable."""
with _with_validation_error_translation():
user_provided_config = _unpack_request_config(
config = _unpack_request_config(
config_hash,
config_keys=config_keys,
model=ConfigPayload,
request=request,
per_req_config_modifier=per_req_config_modifier,
)

config = _update_config_with_defaults(path, user_provided_config, request)

if isinstance(app, FastAPI): # type: ignore
base_url = f"{namespace}/playground"
else:
Expand Down
7 changes: 1 addition & 6 deletions tests/unit_tests/test_server_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -882,7 +882,7 @@ async def add_one(x: int) -> int:
server_runnable2,
input_type=int,
path="/add_one_config",
config_keys=["tags", "metadata"],
config_keys=["tags", "run_name", "metadata"],
)

async with get_async_remote_runnable(
Expand All @@ -909,11 +909,8 @@ async def add_one(x: int) -> int:
# will still be added
config_seen = server_runnable_spy.call_args[0][1]
assert "metadata" in config_seen
assert "a" not in config_seen["metadata"]
assert "__useragent" in config_seen["metadata"]
assert "__langserve_version" in config_seen["metadata"]
assert "__langserve_endpoint" in config_seen["metadata"]
assert config_seen["metadata"]["__langserve_endpoint"] == "invoke"

server_runnable2_spy = mocker.spy(server_runnable2, "ainvoke")
async with get_async_remote_runnable(app, path="/add_one_config") as runnable2:
Expand All @@ -926,8 +923,6 @@ async def add_one(x: int) -> int:
assert config_seen["metadata"]["a"] == 5
assert "__useragent" in config_seen["metadata"]
assert "__langserve_version" in config_seen["metadata"]
assert "__langserve_endpoint" in config_seen["metadata"]
assert config_seen["metadata"]["__langserve_endpoint"] == "invoke"


async def test_input_validation_with_lc_types(event_loop: AbstractEventLoop) -> None:
Expand Down