Chat completion: messages, stop, temperature, top_p, model, max_completion_tokens (#13)

* chore: Fix small oversights

* Add max_completion_tokens

* Remove min_tokens

* Add stop example

* Be strict with model

* Set top_p default to 1.0

* Limit temperature to 2
vidas authored Dec 13, 2024
1 parent 83c6e00 commit 8e8b450
Showing 6 changed files with 44 additions and 49 deletions.
9 changes: 6 additions & 3 deletions examples/openai-client.py
@@ -1,7 +1,7 @@
from openai import OpenAI

model = "llama"
# model = "smolm2-135m"
# model = "llama"
model = "smolm2-135m"
# model = "olmo-7b"
uri = "http://localhost:8000/v1/"

@@ -28,7 +28,10 @@
stream = client.chat.completions.create(
model=model,
messages=messages,
max_tokens=500,
max_completion_tokens=200,
stop=["4.", "sushi"],
top_p=0.3,
# temperature=2.0,
stream=True
)

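For reference, below is a minimal, self-contained sketch of what the updated example call looks like after this change. It assumes a NekkoAPI server at http://localhost:8000/v1/ with a "smolm2-135m" model alias configured; the api_key placeholder and the messages content are illustrative stand-ins, not taken from the repository.

```python
# Sketch of the updated chat-completion example (assumes a local NekkoAPI
# server and a configured "smolm2-135m" model alias; messages are illustrative).
from openai import OpenAI

model = "smolm2-135m"
uri = "http://localhost:8000/v1/"

client = OpenAI(base_url=uri, api_key="sk-no-key-required")  # placeholder key

messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "List three popular foods in Japan."},
]

stream = client.chat.completions.create(
    model=model,
    messages=messages,
    max_completion_tokens=200,   # upper bound on generated tokens
    stop=["4.", "sushi"],        # generation stops at either sequence
    top_p=0.3,                   # nucleus sampling; temperature left at default
    stream=True,
)

for chunk in stream:
    if chunk.choices and chunk.choices[0].delta.content:
        print(chunk.choices[0].delta.content, end="", flush=True)
print()
```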
46 changes: 16 additions & 30 deletions llama_cpp/server/app.py
@@ -17,7 +17,7 @@
from fastapi import Depends, FastAPI, APIRouter, Request, HTTPException, status, Body
from fastapi.middleware import Middleware
from fastapi.middleware.cors import CORSMiddleware
from fastapi.security import HTTPBearer
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from sse_starlette.sse import EventSourceResponse
from starlette_context.plugins import RequestIdPlugin # type: ignore
from starlette_context.middleware import RawContextMiddleware
@@ -119,9 +119,8 @@ def create_app(
server_settings = ServerSettings.model_validate(config_file_settings)
model_settings = config_file_settings.models

# TODO: remove settings argument altogether.
if server_settings is None and model_settings is None:
if settings is None:
settings = Settings()
server_settings = ServerSettings.model_validate(settings)
model_settings = [ModelSettings.model_validate(settings)]

@@ -133,7 +132,7 @@
middleware = [Middleware(RawContextMiddleware, plugins=(RequestIdPlugin(),))]
app = FastAPI(
middleware=middleware,
title="🦙 llama.cpp Python API",
title="NekkoAPI",
version=llama_cpp.__version__,
root_path=server_settings.root_path,
)
@@ -191,8 +190,8 @@ def _logit_bias_tokens_to_input_ids(
) -> Dict[str, float]:
to_bias: Dict[str, float] = {}
for token, score in logit_bias.items():
token = token.encode("utf-8")
for input_id in llama.tokenize(token, add_bos=False, special=True):
token_encoded = token.encode("utf-8")
for input_id in llama.tokenize(token_encoded, add_bos=False, special=True):
to_bias[str(input_id)] = score
return to_bias

@@ -203,7 +202,7 @@ def _logit_bias_tokens_to_input_ids(

async def authenticate(
settings: Settings = Depends(get_server_settings),
authorization: Optional[str] = Depends(bearer_scheme),
authorization: Optional[HTTPAuthorizationCredentials] = Depends(bearer_scheme),
):
# Skip API key check if it's not set in settings
if settings.api_key is None:
@@ -291,7 +290,6 @@ async def create_completion(
"best_of",
"logit_bias_type",
"user",
"min_tokens",
}
kwargs = body.model_dump(exclude=exclude)

@@ -305,15 +303,6 @@
if body.grammar is not None:
kwargs["grammar"] = llama_cpp.LlamaGrammar.from_string(body.grammar)

if body.min_tokens > 0:
_min_tokens_logits_processor = llama_cpp.LogitsProcessorList(
[llama_cpp.MinTokensLogitsProcessor(body.min_tokens, llama.token_eos())]
)
if "logits_processor" not in kwargs:
kwargs["logits_processor"] = _min_tokens_logits_processor
else:
kwargs["logits_processor"].extend(_min_tokens_logits_processor)

try:
iterator_or_completion: Union[
llama_cpp.CreateCompletionResponse,
@@ -334,7 +323,7 @@ def iterator() -> Iterator[llama_cpp.CreateCompletionStreamResponse]:
yield from iterator_or_completion
exit_stack.close()

send_chan, recv_chan = anyio.create_memory_object_stream(10)
send_chan, recv_chan = anyio.create_memory_object_stream(10) # type: ignore
return EventSourceResponse(
recv_chan,
data_sender_callable=partial( # type: ignore
@@ -345,7 +334,7 @@ def iterator() -> Iterator[llama_cpp.CreateCompletionStreamResponse]:
on_complete=exit_stack.close,
),
sep="\n",
ping_message_factory=_ping_message_factory,
ping_message_factory=_ping_message_factory, # type: ignore
)
else:
exit_stack.close()
@@ -490,8 +479,10 @@ async def create_chat_completion(
"n",
"logit_bias_type",
"user",
"min_tokens",
"max_completion_tokens",
}
# TODO: use whitelisting and only include permitted fields.
# TODO: only leave OpenAI API compatible fields.
kwargs = body.model_dump(exclude=exclude)
llama = llama_proxy(body.model)
if body.logit_bias is not None:
@@ -504,14 +495,9 @@
if body.grammar is not None:
kwargs["grammar"] = llama_cpp.LlamaGrammar.from_string(body.grammar)

if body.min_tokens > 0:
_min_tokens_logits_processor = llama_cpp.LogitsProcessorList(
[llama_cpp.MinTokensLogitsProcessor(body.min_tokens, llama.token_eos())]
)
if "logits_processor" not in kwargs:
kwargs["logits_processor"] = _min_tokens_logits_processor
else:
kwargs["logits_processor"].extend(_min_tokens_logits_processor)
# Override max_tokens with max_completion_tokens.
if body.max_completion_tokens is not None:
kwargs["max_tokens"] = body.max_completion_tokens

try:
iterator_or_completion: Union[
@@ -532,7 +518,7 @@ def iterator() -> Iterator[llama_cpp.ChatCompletionChunk]:
yield from iterator_or_completion
exit_stack.close()

send_chan, recv_chan = anyio.create_memory_object_stream(10)
send_chan, recv_chan = anyio.create_memory_object_stream(10) # type: ignore
return EventSourceResponse(
recv_chan,
data_sender_callable=partial( # type: ignore
Expand All @@ -543,7 +529,7 @@ def iterator() -> Iterator[llama_cpp.ChatCompletionChunk]:
on_complete=exit_stack.close,
),
sep="\n",
ping_message_factory=_ping_message_factory,
ping_message_factory=_ping_message_factory, # type: ignore
)
else:
exit_stack.close()
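The new max_completion_tokens handling above simply maps the OpenAI-style field onto the max_tokens keyword that the underlying llama-cpp call still expects, with max_completion_tokens taking precedence when both are supplied. A rough sketch of that precedence, using simplified stand-in names rather than the actual handler code:

```python
# Illustrative sketch of the max_completion_tokens -> max_tokens mapping
# (simplified; the real handler builds kwargs from the pydantic request body).
from typing import Any, Dict, Optional


def build_sampling_kwargs(
    max_tokens: Optional[int],
    max_completion_tokens: Optional[int],
) -> Dict[str, Any]:
    kwargs: Dict[str, Any] = {"max_tokens": max_tokens}
    # max_completion_tokens, when provided, overrides the deprecated max_tokens.
    if max_completion_tokens is not None:
        kwargs["max_tokens"] = max_completion_tokens
    return kwargs


assert build_sampling_kwargs(None, 200) == {"max_tokens": 200}
assert build_sampling_kwargs(500, None) == {"max_tokens": 500}
```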
7 changes: 5 additions & 2 deletions llama_cpp/server/model.py
@@ -24,21 +24,24 @@ def __init__(self, models: List[ModelSettings]) -> None:
self._current_model: Optional[llama_cpp.Llama] = None
self._current_model_alias: Optional[str] = None

# TODO: there should be no such thing as default model.
self._default_model_settings: ModelSettings = models[0]
self._default_model_alias: str = self._default_model_settings.model_alias # type: ignore

# TODO: when _default_model is removed, what do we set as
# current model and do we load it?
# Load default model
self._current_model = self.load_llama_from_model_settings(
self._default_model_settings
)
self._current_model_alias = self._default_model_alias

def __call__(self, model: Optional[str] = None) -> llama_cpp.Llama:
def __call__(self, model: str = None) -> llama_cpp.Llama:
if model is None:
model = self._default_model_alias

if model not in self._model_settings_dict:
model = self._default_model_alias
raise ValueError(f"Model {model} not found.")

if model == self._current_model_alias:
if self._current_model is not None:
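The "Be strict with model" change above replaces the old silent fallback to the default model with an explicit error for unknown aliases. The snippet below is a simplified, hypothetical stand-in for that lookup behaviour, not the real LlamaProxy implementation:

```python
# Hypothetical illustration of the stricter model lookup; the real proxy
# loads llama_cpp.Llama instances from ModelSettings entries.
from typing import Dict, Optional


class ModelRegistry:
    def __init__(self, models: Dict[str, object], default_alias: str) -> None:
        self._models = models
        self._default_alias = default_alias

    def resolve(self, model: Optional[str] = None) -> object:
        if model is None:
            model = self._default_alias
        if model not in self._models:
            # Previously this fell back to the default model; now it is an error.
            raise ValueError(f"Model {model} not found.")
        return self._models[model]


registry = ModelRegistry({"smolm2-135m": object()}, "smolm2-135m")
registry.resolve("smolm2-135m")      # ok
# registry.resolve("unknown-model")  # would raise ValueError
```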
1 change: 1 addition & 0 deletions llama_cpp/server/settings.py
@@ -103,6 +103,7 @@ class ModelSettings(BaseSettings):
offload_kqv: bool = Field(
default=True, description="Whether to offload kqv to the GPU."
)
# TODO: default this to True?
flash_attn: bool = Field(
default=False, description="Whether to use flash attention."
)
27 changes: 14 additions & 13 deletions llama_cpp/server/types.py
@@ -16,24 +16,20 @@
default=16, ge=1, description="The maximum number of tokens to generate."
)

min_tokens_field = Field(
default=0,
ge=0,
description="The minimum number of tokens to generate. It may return fewer tokens if another condition is met (e.g. max_tokens, stop).",
)

temperature_field = Field(
default=0.8,
default=1.0,
ge=0.0,
le=2.0,
description="Adjust the randomness of the generated text.\n\n"
+ "Temperature is a hyperparameter that controls the randomness of the generated text. It affects the probability distribution of the model's output tokens. A higher temperature (e.g., 1.5) makes the output more random and creative, while a lower temperature (e.g., 0.5) makes the output more focused, deterministic, and conservative. The default value is 0.8, which provides a balance between randomness and determinism. At the extreme, a temperature of 0 will always pick the most likely next token, leading to identical outputs in each run.",
+ "Temperature is a hyperparameter that controls the randomness of the generated text. It affects the probability distribution of the model's output tokens. A higher temperature (e.g., 1.5) makes the output more random and creative, while a lower temperature (e.g., 0.5) makes the output more focused, deterministic, and conservative. The default value is 1, which provides a balance between randomness and determinism. At the extreme, a temperature of 0 will always pick the most likely next token, leading to identical outputs in each run. We recommend to alter this or top_p, but not both.",
)

top_p_field = Field(
default=0.95,
default=1.0,
ge=0.0,
le=1.0,
description="Limit the next token selection to a subset of tokens with a cumulative probability above a threshold P.\n\n"
+ "Top-p sampling, also known as nucleus sampling, is another text generation method that selects the next token from a subset of tokens that together have a cumulative probability of at least p. This method provides a balance between diversity and quality by considering both the probabilities of tokens and the number of tokens to sample from. A higher value for top_p (e.g., 0.95) will lead to more diverse text, while a lower value (e.g., 0.5) will generate more focused and conservative text.",
+ "Top-p sampling, also known as nucleus sampling, is another text generation method that selects the next token from a subset of tokens that together have a cumulative probability of at least p. This method provides a balance between diversity and quality by considering both the probabilities of tokens and the number of tokens to sample from. A higher value for top_p (e.g., 0.95) will lead to more diverse text, while a lower value (e.g., 0.5) will generate more focused and conservative text. We recomment altering this or temperature, but not both.",
)

min_p_field = Field(
@@ -117,7 +113,6 @@ class CreateCompletionRequest(BaseModel):
max_tokens: Optional[int] = Field(
default=16, ge=0, description="The maximum number of tokens to generate."
)
min_tokens: int = min_tokens_field
temperature: float = temperature_field
top_p: float = top_p_field
min_p: float = min_p_field
@@ -212,8 +207,13 @@ class CreateChatCompletionRequest(BaseModel):
max_tokens: Optional[int] = Field(
default=None,
description="The maximum number of tokens to generate. Defaults to inf",
deprecated="Deprecated in favor of max_completion_tokens",
)
max_completion_tokens: Optional[int] = Field(
gt=0,
default=None,
description="An upper bound for the number of tokens that can be generated for a completion. Defaults to inf",
)
min_tokens: int = min_tokens_field
logprobs: Optional[bool] = Field(
default=False,
description="Whether to output the logprobs or not. Default is True",
@@ -236,8 +236,9 @@ class CreateChatCompletionRequest(BaseModel):
default=None,
)

model: str = model_field

# ignored or currently unsupported
model: Optional[str] = model_field
n: Optional[int] = 1
user: Optional[str] = Field(None)

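Together, the tightened field definitions mean out-of-range values are rejected at request-validation time: temperature must lie in [0, 2] (default 1.0), top_p defaults to 1.0, and max_completion_tokens must be positive when set. A reduced pydantic sketch of just these constraints (a stand-in, not the full request model):

```python
# Reduced stand-in for the chat-completion request constraints changed here.
from typing import Optional

from pydantic import BaseModel, Field, ValidationError


class ChatParams(BaseModel):
    temperature: float = Field(default=1.0, ge=0.0, le=2.0)
    top_p: float = Field(default=1.0, ge=0.0, le=1.0)
    max_completion_tokens: Optional[int] = Field(default=None, gt=0)


print(ChatParams())  # defaults: temperature=1.0 top_p=1.0 max_completion_tokens=None

try:
    ChatParams(temperature=2.5)  # above the new le=2.0 bound
except ValidationError as exc:
    print(exc.errors()[0]["msg"])
```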
3 changes: 2 additions & 1 deletion pyproject.toml
@@ -7,7 +7,7 @@ name = "nekko_api"
dynamic = ["version"]
description = "OpenAI API compatible llama.cpp server"
readme = "README.md"
license = { text = "MIT" }
license = { text = "Apache License 2.0" }
authors = [
]
dependencies = [
@@ -48,6 +48,7 @@ test = [
"huggingface-hub>=0.23.0"
]
dev = [
"mypy>=1.13.0",
"black>=23.3.0",
"twine>=4.0.2",
"mkdocs>=1.4.3",
