Skip to content

Commit

Permalink
move method
Browse files Browse the repository at this point in the history
  • Loading branch information
jakerachleff committed Nov 13, 2023
1 parent dd25fdb commit 6827aa4
Showing 1 changed file with 55 additions and 47 deletions.
102 changes: 55 additions & 47 deletions langserve/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,45 @@ def _unpack_request_config(
)


def _update_config_with_defaults(
path: str, incomingConfig: RunnableConfig, endpoint: Optional[str] = None
) -> RunnableConfig:
"""Set up some baseline configuration for the underlying runnable."""

# Currently all defaults are non-overridable
overridable_default_config = RunnableConfig()

metadata = {}

if endpoint:
metadata["LangServe Endpoint"] = endpoint

is_hosted = os.environ.get("HOSTED_LANGSERVE_ENABLED", "false").lower() == "true"
if is_hosted:
hosted_metadata = {
"Commit SHA": os.environ.get("HOSTED_LANGSERVE_GIT_COMMIT", ""),
"Git Repo Subdirectory": os.environ.get(
"HOSTED_LANGSERVE_GIT_REPO_PATH", ""
),
"Git Repo URL": os.environ.get("HOSTED_LANGSERVE_GIT_REPO", ""),
}
metadata.update(hosted_metadata)

non_overridable_default_config = RunnableConfig(
run_name=path,
metadata=metadata,
)

# merge_configs is last-writer-wins, so we specifically pass in the
# overridable configs first, then the user provided configs, then
# finally the non-overridable configs
return merge_configs(
overridable_default_config,
incomingConfig,
non_overridable_default_config,
)


def _unpack_input(validated_model: BaseModel) -> Any:
"""Unpack the decoded input from the validated model."""
if hasattr(validated_model, "__root__"):
Expand Down Expand Up @@ -475,11 +514,10 @@ def add_routes(
f"Got an invalid path: {path}. "
f"If specifying path please start it with a `/`"
)

if "run_name" in config_keys:
raise ValueError(
f"Cannot configure run_name. "
f"Please remove it from config_keys."
"Cannot configure run_name. Please remove it from config_keys."
)

namespace = path or ""
Expand Down Expand Up @@ -562,40 +600,8 @@ def _route_name_with_config(name: str) -> str:
InvokeResponse = create_invoke_response_model(model_namespace, output_type_)
BatchResponse = create_batch_response_model(model_namespace, output_type_)

def _update_config_with_defaults(incomingConfig: RunnableConfig) -> RunnableConfig:
"""Set up some baseline configuration for the underlying runnable."""

# Currently all defaults are non-overridable
overridable_default_config = RunnableConfig()

metadata = {}

is_hosted = os.environ.get("HOSTED_LANGSERVE_ENABLED", "false").lower() == "true"
if is_hosted:
hosted_metadata = {
"Commit SHA": os.environ.get("HOSTED_LANGSERVE_GIT_COMMIT", ""),
"Git Repo Subdirectory": os.environ.get("HOSTED_LANGSERVE_GIT_REPO_PATH", ""),
"Git Repo URL": os.environ.get("HOSTED_LANGSERVE_GIT_REPO", ""),

}
metadata.update(hosted_metadata)

non_overridable_default_config = RunnableConfig(
run_name=path,
metadata=metadata,
)

# merge_configs is last-writer-wins, so we specifically pass in the
# overridable configs first, then the user provided configs, then
# finally the non-overridable configs
return merge_configs(
overridable_default_config,
incomingConfig,
non_overridable_default_config,
)

async def _get_config_and_input(
request: Request, config_hash: str
request: Request, config_hash: str, endpoint: Optional[str] = None
) -> Tuple[RunnableConfig, Any]:
"""Extract the config and input from the request, validating the request."""
try:
Expand All @@ -614,7 +620,7 @@ async def _get_config_and_input(
request=request,
per_req_config_modifier=per_req_config_modifier,
)
config = _update_config_with_defaults(user_provided_config)
config = _update_config_with_defaults(path, user_provided_config)
# Unpack the input dynamically using the input schema of the runnable.
# This takes into account changes in the input type when
# using configuration.
Expand All @@ -636,9 +642,11 @@ async def invoke(
"""Invoke the runnable with the given input and config."""
# We do not use the InvokeRequest model here since configurable runnables
# have dynamic schema -- so the validation below is a bit more involved.
user_provided_config, input_ = await _get_config_and_input(request, config_hash)
user_provided_config, input_ = await _get_config_and_input(
request, config_hash, "invoke"
)

config = _update_config_with_defaults(user_provided_config)
config = _update_config_with_defaults(path, user_provided_config)
event_aggregator = AsyncEventAggregatorCallback()
_add_tracing_info_to_metadata(config, request)
config["callbacks"] = [event_aggregator]
Expand Down Expand Up @@ -700,7 +708,8 @@ async def batch(
model=ConfigPayload,
request=request,
per_req_config_modifier=per_req_config_modifier,
) for config in config
)
for config in config
]
elif isinstance(config, dict):
configs = _unpack_request_config(
Expand All @@ -725,7 +734,6 @@ async def batch(
{k: v for k, v in config_.items() if k in config_keys}
for config_ in get_config_list(configs, len(inputs_))
]
print(configs_)

inputs = [
_unpack_input(runnable.with_config(config_).input_schema.validate(input_))
Expand All @@ -739,7 +747,7 @@ async def batch(
for config_, aggregator in zip(configs_, aggregators):
_add_tracing_info_to_metadata(config_, request)
config_["callbacks"] = [aggregator]
final_configs.append(_update_config_with_defaults(config_))
final_configs.append(_update_config_with_defaults(path, config_))

output = await runnable.abatch(inputs, config=final_configs)

Expand Down Expand Up @@ -779,7 +787,7 @@ async def stream(
err_event = {}
validation_exception: Optional[BaseException] = None
try:
config, input_ = await _get_config_and_input(request, config_hash)
config, input_ = await _get_config_and_input(request, config_hash, "stream")
except BaseException as e:
validation_exception = e
if isinstance(e, RequestValidationError):
Expand Down Expand Up @@ -984,7 +992,7 @@ async def input_schema(request: Request, config_hash: str = "") -> Any:
request=request,
per_req_config_modifier=per_req_config_modifier,
)
config = _update_config_with_defaults(user_provided_config)
config = _update_config_with_defaults(path, user_provided_config)

return runnable.get_input_schema(config).schema()

Expand All @@ -1008,7 +1016,7 @@ async def output_schema(request: Request, config_hash: str = "") -> Any:
request=request,
per_req_config_modifier=per_req_config_modifier,
)
config = _update_config_with_defaults(user_provided_config)
config = _update_config_with_defaults(path, user_provided_config)
return runnable.get_output_schema(config).schema()

@app.get(
Expand All @@ -1029,7 +1037,7 @@ async def config_schema(request: Request, config_hash: str = "") -> Any:
request=request,
per_req_config_modifier=per_req_config_modifier,
)
config = _update_config_with_defaults(user_provided_config)
config = _update_config_with_defaults(path, user_provided_config)
return runnable.with_config(config).config_schema(include=config_keys).schema()

@app.get(
Expand All @@ -1049,7 +1057,7 @@ async def playground(
request=request,
per_req_config_modifier=per_req_config_modifier,
)
config = _update_config_with_defaults(user_provided_config)
config = _update_config_with_defaults(path, user_provided_config)

return await serve_playground(
runnable.with_config(config),
Expand Down

0 comments on commit 6827aa4

Please sign in to comment.