Skip to content

Commit

Permalink
fix all endpoints
Browse files — browse the repository at this point in the history
  • Loading branch information
jakerachleff committed Nov 10, 2023
1 parent 971e454 commit dd25fdb
Showing 1 changed file with 43 additions and 13 deletions.
56 changes: 43 additions & 13 deletions langserve/server.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -439,7 +439,8 @@ def add_routes(
Favor using runnable.with_types(input_type=..., output_type=...) instead.
This parameter may get deprecated!
config_keys: list of config keys that will be accepted, by default
no config keys are accepted.
no config keys are accepted. Cannot configure run_name,
which is set by default to the path of the API
include_callback_events: Whether to include callback events in the response.
If true, the client will be able to show trace information
including events that occurred on the server side.
Expand Down Expand Up @@ -474,6 +475,12 @@ def add_routes(
f"Got an invalid path: {path}. "
f"If specifying path please start it with a `/`"
)

if "run_name" in config_keys:
raise ValueError(
f"Cannot configure run_name. "
f"Please remove it from config_keys."
)

namespace = path or ""

Expand Down Expand Up @@ -555,9 +562,14 @@ def _route_name_with_config(name: str) -> str:
InvokeResponse = create_invoke_response_model(model_namespace, output_type_)
BatchResponse = create_batch_response_model(model_namespace, output_type_)

def _get_default_base_config() -> RunnableConfig:
def _update_config_with_defaults(incomingConfig: RunnableConfig) -> RunnableConfig:
"""Set up some baseline configuration for the underlying runnable."""

# Currently all defaults are non-overridable
overridable_default_config = RunnableConfig()

metadata = {}

is_hosted = os.environ.get("HOSTED_LANGSERVE_ENABLED", "false").lower() == "true"
if is_hosted:
hosted_metadata = {
Expand All @@ -567,11 +579,21 @@ def _get_default_base_config() -> RunnableConfig:

}
metadata.update(hosted_metadata)
return RunnableConfig(

non_overridable_default_config = RunnableConfig(
run_name=path,
metadata=metadata,
)

# merge_configs is last-writer-wins, so we specifically pass in the
# overridable configs first, then the user provided configs, then
# finally the non-overridable configs
return merge_configs(
overridable_default_config,
incomingConfig,
non_overridable_default_config,
)

async def _get_config_and_input(
request: Request, config_hash: str
) -> Tuple[RunnableConfig, Any]:
Expand All @@ -584,14 +606,15 @@ async def _get_config_and_input(
body = InvokeRequestShallowValidator.validate(body)

# Merge the config from the path with the config from the body.
config = _unpack_request_config(
user_provided_config = _unpack_request_config(
config_hash,
body.config,
keys=config_keys,
model=ConfigPayload,
request=request,
per_req_config_modifier=per_req_config_modifier,
)
config = _update_config_with_defaults(user_provided_config)
# Unpack the input dynamically using the input schema of the runnable.
# This takes into account changes in the input type when
# using configuration.
Expand All @@ -613,9 +636,9 @@ async def invoke(
"""Invoke the runnable with the given input and config."""
# We do not use the InvokeRequest model here since configurable runnables
# have dynamic schema -- so the validation below is a bit more involved.
config, input_ = await _get_config_and_input(request, config_hash)
user_provided_config, input_ = await _get_config_and_input(request, config_hash)

config = merge_configs(_get_default_base_config(), config)
config = _update_config_with_defaults(user_provided_config)
event_aggregator = AsyncEventAggregatorCallback()
_add_tracing_info_to_metadata(config, request)
config["callbacks"] = [event_aggregator]
Expand Down Expand Up @@ -677,8 +700,7 @@ async def batch(
model=ConfigPayload,
request=request,
per_req_config_modifier=per_req_config_modifier,
)
for config in config
) for config in config
]
elif isinstance(config, dict):
configs = _unpack_request_config(
Expand All @@ -703,6 +725,7 @@ async def batch(
{k: v for k, v in config_.items() if k in config_keys}
for config_ in get_config_list(configs, len(inputs_))
]
print(configs_)

inputs = [
_unpack_input(runnable.with_config(config_).input_schema.validate(input_))
Expand All @@ -712,11 +735,13 @@ async def batch(
# Update the configuration with callbacks
aggregators = [AsyncEventAggregatorCallback() for _ in range(len(inputs))]

final_configs = []
for config_, aggregator in zip(configs_, aggregators):
_add_tracing_info_to_metadata(config_, request)
config_["callbacks"] = [aggregator]
final_configs.append(_update_config_with_defaults(config_))

output = await runnable.abatch(inputs, config=configs_)
output = await runnable.abatch(inputs, config=final_configs)

if include_callback_events:
callback_events = [
Expand Down Expand Up @@ -952,13 +977,14 @@ async def _stream_log() -> AsyncIterator[dict]:
async def input_schema(request: Request, config_hash: str = "") -> Any:
"""Return the input schema of the runnable."""
with _with_validation_error_translation():
config = _unpack_request_config(
user_provided_config = _unpack_request_config(
config_hash,
keys=config_keys,
model=ConfigPayload,
request=request,
per_req_config_modifier=per_req_config_modifier,
)
config = _update_config_with_defaults(user_provided_config)

return runnable.get_input_schema(config).schema()

Expand All @@ -975,13 +1001,14 @@ async def input_schema(request: Request, config_hash: str = "") -> Any:
async def output_schema(request: Request, config_hash: str = "") -> Any:
"""Return the output schema of the runnable."""
with _with_validation_error_translation():
config = _unpack_request_config(
user_provided_config = _unpack_request_config(
config_hash,
keys=config_keys,
model=ConfigPayload,
request=request,
per_req_config_modifier=per_req_config_modifier,
)
config = _update_config_with_defaults(user_provided_config)
return runnable.get_output_schema(config).schema()

@app.get(
Expand All @@ -995,13 +1022,14 @@ async def output_schema(request: Request, config_hash: str = "") -> Any:
async def config_schema(request: Request, config_hash: str = "") -> Any:
"""Return the config schema of the runnable."""
with _with_validation_error_translation():
config = _unpack_request_config(
user_provided_config = _unpack_request_config(
config_hash,
keys=config_keys,
model=ConfigPayload,
request=request,
per_req_config_modifier=per_req_config_modifier,
)
config = _update_config_with_defaults(user_provided_config)
return runnable.with_config(config).config_schema(include=config_keys).schema()

@app.get(
Expand All @@ -1014,13 +1042,15 @@ async def playground(
) -> Any:
"""Return the playground of the runnable."""
with _with_validation_error_translation():
config = _unpack_request_config(
user_provided_config = _unpack_request_config(
config_hash,
keys=config_keys,
model=ConfigPayload,
request=request,
per_req_config_modifier=per_req_config_modifier,
)
config = _update_config_with_defaults(user_provided_config)

return await serve_playground(
runnable.with_config(config),
runnable.with_config(config).input_schema,
Expand Down

0 comments on commit dd25fdb

Please sign in to comment.