python: follow-up to Jinja PR #3225

Open · wants to merge 4 commits into base: main
2 changes: 2 additions & 0 deletions gpt4all-backend/include/gpt4all-backend/llmodel_c.h
@@ -312,6 +312,8 @@ int32_t llmodel_count_prompt_tokens(llmodel_model model, const char *prompt, con

void llmodel_model_foreach_special_token(llmodel_model model, llmodel_special_token_callback callback);

const char *llmodel_model_chat_template(const char *model_path, const char **error);

#ifdef __cplusplus
}
#endif
16 changes: 14 additions & 2 deletions gpt4all-backend/src/llmodel_c.cpp
@@ -34,11 +34,11 @@ llmodel_model llmodel_model_create(const char *model_path)
return fres;
}

static void llmodel_set_error(const char **errptr, const char *message)
static void llmodel_set_error(const char **errptr, std::string message)
{
thread_local static std::string last_error_message;
if (errptr) {
last_error_message = message;
last_error_message = std::move(message);
*errptr = last_error_message.c_str();
}
}
@@ -318,3 +318,15 @@ void llmodel_model_foreach_special_token(llmodel_model model, llmodel_special_to
for (auto &[name, token] : wrapper->llModel->specialTokens())
callback(name.c_str(), token.c_str());
}

const char *llmodel_model_chat_template(const char *model_path, const char **error)
{
static std::string s_chatTemplate;
auto res = LLModel::Implementation::chatTemplate(model_path);
if (res) {
s_chatTemplate = *res;
return s_chatTemplate.c_str();
}
llmodel_set_error(error, std::move(res.error()));
return nullptr;
}
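
As an aside, a minimal sketch of how this new C entry point could be driven through ctypes, assuming the shared library has already been built and located (the library name and model path below are placeholders, not values from this diff):

```python
import ctypes

# Hypothetical library handle; the real loader in _pyllmodel.py resolves the
# platform-specific filename and search path.
lib = ctypes.CDLL("libllmodel.so")

lib.llmodel_model_chat_template.argtypes = [ctypes.c_char_p, ctypes.POINTER(ctypes.c_char_p)]
lib.llmodel_model_chat_template.restype = ctypes.c_char_p

def read_chat_template(model_path: str) -> str:
    # On failure the C function returns NULL and points *error at a
    # library-owned message; on success it returns the template text.
    err = ctypes.c_char_p()
    tmpl = lib.llmodel_model_chat_template(model_path.encode(), ctypes.byref(err))
    if tmpl is None:
        raise ValueError(err.value.decode() if err.value is not None else "unknown error")
    return tmpl.decode()
```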
32 changes: 25 additions & 7 deletions gpt4all-bindings/python/gpt4all/_pyllmodel.py
@@ -9,7 +9,7 @@
import threading
from enum import Enum
from queue import Queue
from typing import TYPE_CHECKING, Any, Callable, Generic, Iterable, Iterator, Literal, NoReturn, TypeVar, overload
from typing import TYPE_CHECKING, Any, Callable, Generic, Iterator, Literal, NoReturn, TypeVar, overload

if sys.version_info >= (3, 9):
import importlib.resources as importlib_resources
@@ -227,6 +227,9 @@ class LLModelGPUDevice(ctypes.Structure):
llmodel.llmodel_model_foreach_special_token.argtypes = [ctypes.c_void_p, SpecialTokenCallback]
llmodel.llmodel_model_foreach_special_token.restype = None

llmodel.llmodel_model_chat_template.argtypes = [ctypes.c_char_p, ctypes.POINTER(ctypes.c_char_p)]
llmodel.llmodel_model_chat_template.restype = ctypes.c_char_p

ResponseCallbackType = Callable[[int, str], bool]
RawResponseCallbackType = Callable[[int, bytes], bool]
EmbCancelCallbackType: TypeAlias = 'Callable[[list[int], str], bool]'
@@ -290,10 +293,7 @@ def __init__(self, model_path: str, n_ctx: int, ngl: int, backend: str):

raise RuntimeError(f"Unable to instantiate model: {errmsg}")
self.model: ctypes.c_void_p | None = model
self.special_tokens_map: dict[str, str] = {}
llmodel.llmodel_model_foreach_special_token(
self.model, lambda n, t: self.special_tokens_map.__setitem__(n.decode(), t.decode()),
)
self._special_tokens_map: dict[str, str] | None = None

def __del__(self, llmodel=llmodel):
if hasattr(self, 'model'):
@@ -320,6 +320,26 @@ def device(self) -> str | None:
dev = llmodel.llmodel_model_gpu_device_name(self.model)
return None if dev is None else dev.decode()

@property
def builtin_chat_template(self) -> str:
err = ctypes.c_char_p()
tmpl = llmodel.llmodel_model_chat_template(self.model_path, ctypes.byref(err))
if tmpl is not None:
return tmpl.decode()
s = err.value
raise ValueError('Failed to get chat template', 'null' if s is None else s.decode())

@property
def special_tokens_map(self) -> dict[str, str]:
if self.model is None:
self._raise_closed()
if self._special_tokens_map is None:
tokens: dict[str, str] = {}
cb = SpecialTokenCallback(lambda n, t: tokens.__setitem__(n.decode(), t.decode()))
llmodel.llmodel_model_foreach_special_token(self.model, cb)
self._special_tokens_map = tokens
return self._special_tokens_map

def count_prompt_tokens(self, prompt: str) -> int:
if self.model is None:
self._raise_closed()
@@ -331,8 +351,6 @@ def count_prompt_tokens(self, prompt: str) -> int:
raise RuntimeError(f'Unable to count prompt tokens: {errmsg}')
return n_tok

llmodel.llmodel_count_prompt_tokens.argtypes = [ctypes.c_void_p, ctypes.c_char_p]

@staticmethod
def list_gpus(mem_required: int = 0) -> list[str]:
"""
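
A rough usage sketch of the two new LLModel accessors; the constructor arguments and the "cpu" backend string are assumptions for illustration rather than values taken from this change:

```python
from gpt4all._pyllmodel import LLModel

# Path, context size, GPU layer count, and backend string are assumed values.
model = LLModel("/path/to/model.gguf", n_ctx=2048, ngl=0, backend="cpu")

try:
    # Reads the chat template embedded in the GGUF; raises ValueError whose
    # second argument carries the C-side error string when one is available.
    print(model.builtin_chat_template[:200])
except ValueError as e:
    detail = e.args[1] if len(e.args) >= 2 else "unknown error"
    print(f"no usable chat template: {detail}")

# Built lazily on first access and cached for subsequent lookups.
print(model.special_tokens_map)
```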
101 changes: 77 additions & 24 deletions gpt4all-bindings/python/gpt4all/gpt4all.py
@@ -37,9 +37,9 @@

ConfigType: TypeAlias = "dict[str, Any]"

# Environment setup adapted from HF transformers
@_operator_call
def _jinja_env() -> ImmutableSandboxedEnvironment:
# Environment setup adapted from HF transformers
def raise_exception(message: str) -> NoReturn:
raise jinja2.exceptions.TemplateError(message)

@@ -56,14 +56,17 @@ def strftime_now(fmt: str) -> str:
return env


class MessageType(TypedDict):
class Message(TypedDict):
"""A message in a chat with a GPT4All model."""

role: str
content: str


class ChatSession(NamedTuple):
template: jinja2.Template
history: list[MessageType]
class _ChatSession(NamedTuple):
template: jinja2.Template
template_source: str
history: list[Message]


class Embed4All:
@@ -193,6 +196,17 @@ class GPT4All:
Python class that handles instantiation, downloading, generation and chat with GPT4All models.
"""

RE_LEGACY_SYSPROMPT = re.compile(
r"(?:^|\s)(?:### *System\b|S(?:ystem|YSTEM):)|"
r"<\|(?:im_(?:start|end)|(?:start|end)_header_id|eot_id|SYSTEM_TOKEN)\|>|<<SYS>>",
re.MULTILINE,
)

RE_JINJA_LIKE = re.compile(
r"\{%.*%\}.*\{\{.*\}\}.*\{%.*%\}",
re.DOTALL,
)

def __init__(
self,
model_name: str,
@@ -233,7 +247,7 @@ def __init__(
"""

self.model_type = model_type
self._chat_session: ChatSession | None = None
self._chat_session: _ChatSession | None = None

device_init = None
if sys.platform == "darwin":
@@ -260,6 +274,7 @@ def __init__(

# Retrieve model and download if allowed
self.config: ConfigType = self.retrieve_model(model_name, model_path=model_path, allow_download=allow_download, verbose=verbose)
self._was_allow_download = allow_download
self.model = LLModel(self.config["path"], n_ctx, ngl, backend)
if device_init is not None:
self.model.init_gpu(device_init)
@@ -291,15 +306,20 @@ def device(self) -> str | None:
return self.model.device

@property
def current_chat_session(self) -> list[MessageType] | None:
def current_chat_session(self) -> list[Message] | None:
"""The message history of the current chat session."""
return None if self._chat_session is None else self._chat_session.history

@current_chat_session.setter
def current_chat_session(self, history: list[MessageType]) -> None:
def current_chat_session(self, history: list[Message]) -> None:
if self._chat_session is None:
raise ValueError("current_chat_session may only be set when there is an active chat session")
self._chat_session.history[:] = history

@property
def current_chat_template(self) -> str | None:
return None if self._chat_session is None else self._chat_session.template_source

@staticmethod
def list_models() -> list[ConfigType]:
"""
@@ -569,13 +589,13 @@ def _callback_wrapper(token_id: int, response: str) -> bool:
last_msg_rendered = prompt
if self._chat_session is not None:
session = self._chat_session
def render(messages: list[MessageType]) -> str:
def render(messages: list[Message]) -> str:
return session.template.render(
messages=messages,
add_generation_prompt=True,
**self.model.special_tokens_map,
)
session.history.append(MessageType(role="user", content=prompt))
session.history.append(Message(role="user", content=prompt))
prompt = render(session.history)
if len(session.history) > 1:
last_msg_rendered = render(session.history[-1:])
@@ -590,46 +610,73 @@ def render(messages: list[MessageType]) -> str:
def stream() -> Iterator[str]:
yield from self.model.prompt_model_streaming(prompt, _callback_wrapper, **generate_kwargs)
if self._chat_session is not None:
self._chat_session.history.append(MessageType(role="assistant", content=full_response))
self._chat_session.history.append(Message(role="assistant", content=full_response))
return stream()

self.model.prompt_model(prompt, _callback_wrapper, **generate_kwargs)
if self._chat_session is not None:
self._chat_session.history.append(MessageType(role="assistant", content=full_response))
self._chat_session.history.append(Message(role="assistant", content=full_response))
return full_response

@contextmanager
def chat_session(
self,
*,
system_message: str | Literal[False] | None = None,
chat_template: str | None = None,
warn_legacy: bool = True,
):
"""
Context manager to hold an inference optimized chat session with a GPT4All model.

Args:
system_message: An initial instruction for the model, None to use the model default, or False to disable. Defaults to None.
system_message: An initial instruction for the model, None to use the model default, or False to disable.
Defaults to None.
chat_template: Jinja template for the conversation, or None to use the model default. Defaults to None.
"""
warn_legacy: Whether to warn about legacy system prompts or prompt templates. Defaults to True.

Raises:
ValueError: If no valid chat template was found.
"""
if system_message is None:
system_message = self.config.get("systemMessage", False)
elif system_message is not False and warn_legacy and (m := self.RE_LEGACY_SYSPROMPT.search(system_message)):
print(
"Warning: chat_session() was passed a system message that is not plain text. System messages "
f"containing {m.group()!r} or with any special prefix/suffix are no longer supported.\nTo disable this "
"warning, pass warn_legacy=False.",
file=sys.stderr,
)

if chat_template is None:
if "name" not in self.config:
raise ValueError("For sideloaded models or with allow_download=False, you must specify a chat template.")
if "chatTemplate" not in self.config:
raise NotImplementedError("This model appears to have a built-in chat template, but loading it is not "
"currently implemented. Please pass a template to chat_session() directly.")
if (tmpl := self.config["chatTemplate"]) is None:
raise ValueError(f"The model {self.config['name']!r} does not support chat.")
chat_template = tmpl
if "chatTemplate" in self.config:
if (tmpl := self.config["chatTemplate"]) is None:
raise ValueError(f"The model {self.config['name']!r} does not support chat.")
chat_template = tmpl
else:
try:
chat_template = self.model.builtin_chat_template
except ValueError as e:
if len(e.args) >= 2 and isinstance(err := e.args[1], str):
msg = (f"Failed to load default chat template from model: {err}\n"
"Please pass a template to chat_session() directly.")
if not self._was_allow_download:
msg += " If this is a built-in model, consider setting allow_download to True."
raise ValueError(msg) from None
raise
elif warn_legacy and self._is_legacy_chat_template(chat_template):
print(
"Warning: chat_session() was passed a chat template that is not in Jinja format. Old-style prompt "
"templates are no longer supported.\nTo disable this warning, pass warn_legacy=False.",
file=sys.stderr,
)

history = []
if system_message is not False:
history.append(MessageType(role="system", content=system_message))
self._chat_session = ChatSession(
history.append(Message(role="system", content=system_message))
self._chat_session = _ChatSession(
template=_jinja_env.from_string(chat_template),
template_source=chat_template,
history=history,
)
try:
@@ -647,6 +694,12 @@ def list_gpus() -> list[str]:
"""
return LLModel.list_gpus()

@classmethod
def _is_legacy_chat_template(cls, tmpl: str) -> bool:
# check if tmpl does not look like a Jinja template
return bool(re.search(r"%[12]\b", tmpl) or not cls.RE_JINJA_LIKE.search(tmpl)
or not re.search(r"\bcontent\b", tmpl))


def append_extension_if_missing(model_name):
if not model_name.endswith((".bin", ".gguf")):
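
A sketch of how the reworked chat_session() might be called with an explicitly supplied Jinja template; the model name and template below are placeholders rather than anything taken from this PR:

```python
from gpt4all import GPT4All

# Model name is a placeholder; any locally available chat model works the same way.
model = GPT4All("Llama-3.2-1B-Instruct-Q4_0.gguf")

# A minimal ChatML-style Jinja template (assumed, not shipped with any model).
template = (
    "{% for message in messages %}"
    "<|im_start|>{{ message['role'] }}\n{{ message['content'] }}<|im_end|>\n"
    "{% endfor %}"
    "{% if add_generation_prompt %}<|im_start|>assistant\n{% endif %}"
)

with model.chat_session(system_message="You are a terse assistant.", chat_template=template):
    print(model.current_chat_template == template)  # the session exposes its template source
    print(model.generate("Name three Nordic capitals.", max_tokens=64))
    print(model.current_chat_session)                # accumulated message history
```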
2 changes: 1 addition & 1 deletion gpt4all-bindings/python/setup.py
@@ -68,7 +68,7 @@ def get_long_description():

setup(
name=package_name,
version="2.8.3.dev0",
version="3.0.0.dev0",
description="Python bindings for GPT4All",
long_description=get_long_description(),
long_description_content_type="text/markdown",