From d6638b50640fed5a7a7a6b1dfda8567acb24b6f3 Mon Sep 17 00:00:00 2001
From: Jared Van Bortel
Date: Thu, 5 Dec 2024 14:35:48 -0500
Subject: [PATCH 1/4] python: load templates from model files, and add legacy
 template warning

Signed-off-by: Jared Van Bortel
---
 .../include/gpt4all-backend/llmodel_c.h       |  2 +
 gpt4all-backend/src/llmodel_c.cpp             | 16 ++++-
 gpt4all-bindings/python/gpt4all/_pyllmodel.py | 28 ++++++--
 gpt4all-bindings/python/gpt4all/gpt4all.py    | 65 ++++++++++++++++---
 4 files changed, 95 insertions(+), 16 deletions(-)

diff --git a/gpt4all-backend/include/gpt4all-backend/llmodel_c.h b/gpt4all-backend/include/gpt4all-backend/llmodel_c.h
index 271475bae480..3f7b0851ba1c 100644
--- a/gpt4all-backend/include/gpt4all-backend/llmodel_c.h
+++ b/gpt4all-backend/include/gpt4all-backend/llmodel_c.h
@@ -312,6 +312,8 @@ int32_t llmodel_count_prompt_tokens(llmodel_model model, const char *prompt, con
 
 void llmodel_model_foreach_special_token(llmodel_model model, llmodel_special_token_callback callback);
 
+const char *llmodel_model_chat_template(const char *model_path, const char **error);
+
 #ifdef __cplusplus
 }
 #endif
diff --git a/gpt4all-backend/src/llmodel_c.cpp b/gpt4all-backend/src/llmodel_c.cpp
index a8c5554da0d3..cb44da716952 100644
--- a/gpt4all-backend/src/llmodel_c.cpp
+++ b/gpt4all-backend/src/llmodel_c.cpp
@@ -34,11 +34,11 @@ llmodel_model llmodel_model_create(const char *model_path)
     return fres;
 }
 
-static void llmodel_set_error(const char **errptr, const char *message)
+static void llmodel_set_error(const char **errptr, std::string message)
 {
     thread_local static std::string last_error_message;
     if (errptr) {
-        last_error_message = message;
+        last_error_message = std::move(message);
         *errptr = last_error_message.c_str();
     }
 }
@@ -318,3 +318,15 @@ void llmodel_model_foreach_special_token(llmodel_model model, llmodel_special_to
     for (auto &[name, token] : wrapper->llModel->specialTokens())
         callback(name.c_str(), token.c_str());
 }
+
+const char *llmodel_model_chat_template(const char *model_path, const char **error)
+{
+    static std::string s_chatTemplate;
+    auto res = LLModel::Implementation::chatTemplate(model_path);
+    if (res) {
+        s_chatTemplate = *res;
+        return s_chatTemplate.c_str();
+    }
+    llmodel_set_error(error, std::move(res.error()));
+    return nullptr;
+}
diff --git a/gpt4all-bindings/python/gpt4all/_pyllmodel.py b/gpt4all-bindings/python/gpt4all/_pyllmodel.py
index 616ce80a3533..c59f1dc5c694 100644
--- a/gpt4all-bindings/python/gpt4all/_pyllmodel.py
+++ b/gpt4all-bindings/python/gpt4all/_pyllmodel.py
@@ -227,6 +227,9 @@ class LLModelGPUDevice(ctypes.Structure):
 llmodel.llmodel_model_foreach_special_token.argtypes = [ctypes.c_void_p, SpecialTokenCallback]
 llmodel.llmodel_model_foreach_special_token.restype = None
 
+llmodel.llmodel_model_chat_template.argtypes = [ctypes.c_char_p, ctypes.POINTER(ctypes.c_char_p)]
+llmodel.llmodel_model_chat_template.restype = ctypes.c_char_p
+
 ResponseCallbackType = Callable[[int, str], bool]
 RawResponseCallbackType = Callable[[int, bytes], bool]
 EmbCancelCallbackType: TypeAlias = 'Callable[[list[int], str], bool]'
@@ -290,10 +293,7 @@ def __init__(self, model_path: str, n_ctx: int, ngl: int, backend: str):
             raise RuntimeError(f"Unable to instantiate model: {errmsg}")
 
         self.model: ctypes.c_void_p | None = model
-        self.special_tokens_map: dict[str, str] = {}
-        llmodel.llmodel_model_foreach_special_token(
-            self.model, lambda n, t: self.special_tokens_map.__setitem__(n.decode(), t.decode()),
-        )
+        self._special_tokens_map: dict[str, str] | None = None
 
     def __del__(self, llmodel=llmodel):
         if hasattr(self, 'model'):
@@ -320,6 +320,26 @@ def device(self) -> str | None:
         dev = llmodel.llmodel_model_gpu_device_name(self.model)
         return None if dev is None else dev.decode()
 
+    @property
+    def builtin_chat_template(self) -> str:
+        err = ctypes.c_char_p()
+        tmpl = llmodel.llmodel_model_chat_template(self.model_path, ctypes.byref(err))
+        if tmpl is not None:
+            return tmpl.decode()
+        s = err.value
+        raise ValueError('Failed to get chat template', 'null' if s is None else s.decode())
+
+    @property
+    def special_tokens_map(self) -> dict[str, str]:
+        if self.model is None:
+            self._raise_closed()
+        if self._special_tokens_map is None:
+            tokens: dict[str, str] = {}
+            cb = SpecialTokenCallback(lambda n, t: tokens.__setitem__(n.decode(), t.decode()))
+            llmodel.llmodel_model_foreach_special_token(self.model, cb)
+            self._special_tokens_map = tokens
+        return self._special_tokens_map
+
     def count_prompt_tokens(self, prompt: str) -> int:
         if self.model is None:
             self._raise_closed()
diff --git a/gpt4all-bindings/python/gpt4all/gpt4all.py b/gpt4all-bindings/python/gpt4all/gpt4all.py
index 84b236c996dc..390eb410c7d2 100644
--- a/gpt4all-bindings/python/gpt4all/gpt4all.py
+++ b/gpt4all-bindings/python/gpt4all/gpt4all.py
@@ -62,8 +62,9 @@ class MessageType(TypedDict):
 
 
 class ChatSession(NamedTuple):
-    template: jinja2.Template
-    history: list[MessageType]
+    template:        jinja2.Template
+    template_source: str
+    history:         list[MessageType]
 
 
 class Embed4All:
@@ -193,6 +194,16 @@ class GPT4All:
     Python class that handles instantiation, downloading, generation and chat with GPT4All models.
     """
 
+    RE_LEGACY_SYSPROMPT = re.compile(
+        r"(?:^|\s)(?:### *System\b|S(?:ystem|YSTEM):)|<\|(?:im_(?:start|end)|(?:start|end)_header_id|eot_id|SYSTEM_TOKEN)\|>|<<SYS>>",
+        re.MULTILINE,
+    )
+
+    RE_JINJA_LIKE = re.compile(
+        r"\{%.*%\}.*\{\{.*\}\}.*\{%.*%\}",
+        re.DOTALL,
+    )
+
     def __init__(
         self,
         model_name: str,
@@ -260,6 +271,7 @@ def __init__(
 
         # Retrieve model and download if allowed
         self.config: ConfigType = self.retrieve_model(model_name, model_path=model_path, allow_download=allow_download, verbose=verbose)
+        self._was_allow_download = allow_download
         self.model = LLModel(self.config["path"], n_ctx, ngl, backend)
         if device_init is not None:
             self.model.init_gpu(device_init)
@@ -300,6 +312,10 @@ def current_chat_session(self, history: list[MessageType]) -> None:
             raise ValueError("current_chat_session may only be set when there is an active chat session")
         self._chat_session.history[:] = history
 
+    @property
+    def current_chat_template(self) -> str | None:
+        return None if self._chat_session is None else self._chat_session.template_source
+
     @staticmethod
     def list_models() -> list[ConfigType]:
         """
@@ -598,11 +614,19 @@ def stream() -> Iterator[str]:
             self._chat_session.history.append(MessageType(role="assistant", content=full_response))
         return full_response
 
+    @classmethod
+    def is_legacy_chat_template(cls, tmpl: str) -> bool:
+        """A fairly reliable heuristic for detecting templates that don't look like Jinja templates."""
+        return bool(re.search(r"%[12]\b", tmpl) or not cls.RE_JINJA_LIKE.search(tmpl)
+                    or not re.search(r"\bcontent\b", tmpl))
+
     @contextmanager
     def chat_session(
         self,
+        *,
         system_message: str | Literal[False] | None = None,
         chat_template: str | None = None,
+        warn_legacy: bool = True,
     ):
         """
         Context manager to hold an inference optimized chat session with a GPT4All model.
@@ -614,22 +638,43 @@ def chat_session(
 
         if system_message is None:
             system_message = self.config.get("systemMessage", False)
+        elif system_message is not False and warn_legacy and (m := self.RE_LEGACY_SYSPROMPT.search(system_message)):
+            print(
+                "Warning: chat_session() was passed a system message that is not plain text. System messages "
+                f"containing {m.group()!r} or with any special prefix/suffix are no longer supported.\nTo disable this "
+                "warning, pass warn_legacy=False.",
+                file=sys.stderr,
+            )
 
         if chat_template is None:
-            if "name" not in self.config:
-                raise ValueError("For sideloaded models or with allow_download=False, you must specify a chat template.")
-            if "chatTemplate" not in self.config:
-                raise NotImplementedError("This model appears to have a built-in chat template, but loading it is not "
-                                          "currently implemented. Please pass a template to chat_session() directly.")
-            if (tmpl := self.config["chatTemplate"]) is None:
-                raise ValueError(f"The model {self.config['name']!r} does not support chat.")
-            chat_template = tmpl
+            if "chatTemplate" in self.config:
+                if (tmpl := self.config["chatTemplate"]) is None:
+                    raise ValueError(f"The model {self.config['name']!r} does not support chat.")
+                chat_template = tmpl
+            else:
+                try:
+                    chat_template = self.model.builtin_chat_template
+                except ValueError as e:
+                    if len(e.args) >= 2 and isinstance(err := e.args[1], str):
+                        msg = (f"Failed to load default chat template from model: {err}\n"
+                               "Please pass a template to chat_session() directly.")
+                        if not self._was_allow_download:
+                            msg += " If this is a built-in model, consider setting allow_download to True."
+                        raise ValueError(msg) from None
+                    raise
+        elif warn_legacy and self.is_legacy_chat_template(chat_template):
+            print(
+                "Warning: chat_session() was passed a chat template that is not in Jinja format. Old-style prompt "
+                "templates are no longer supported.\nTo disable this warning, pass warn_legacy=False.",
+                file=sys.stderr,
+            )
 
         history = []
         if system_message is not False:
             history.append(MessageType(role="system", content=system_message))
         self._chat_session = ChatSession(
             template=_jinja_env.from_string(chat_template),
+            template_source=chat_template,
             history=history,
         )
         try:

From 5b8682a6554999efd8771e59a27c354913d74d89 Mon Sep 17 00:00:00 2001
From: Jared Van Bortel
Date: Thu, 5 Dec 2024 14:36:26 -0500
Subject: [PATCH 2/4] python: bump major version for breaking change to chat
 templates

Signed-off-by: Jared Van Bortel
---
 gpt4all-bindings/python/setup.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/gpt4all-bindings/python/setup.py b/gpt4all-bindings/python/setup.py
index b316adc0ec4b..e76fe0ac3a8d 100644
--- a/gpt4all-bindings/python/setup.py
+++ b/gpt4all-bindings/python/setup.py
@@ -68,7 +68,7 @@ def get_long_description():
 
 setup(
     name=package_name,
-    version="2.8.3.dev0",
+    version="3.0.0.dev0",
     description="Python bindings for GPT4All",
     long_description=get_long_description(),
     long_description_content_type="text/markdown",

From b3cc8602336aa938e8cc82fe746f5bae7767ea59 Mon Sep 17 00:00:00 2001
From: Jared Van Bortel
Date: Thu, 5 Dec 2024 17:25:57 -0500
Subject: [PATCH 3/4] remove misplaced assignment to argtypes with wrong values

Signed-off-by: Jared Van Bortel
---
 gpt4all-bindings/python/gpt4all/_pyllmodel.py | 2 --
 1 file changed, 2 deletions(-)

diff --git a/gpt4all-bindings/python/gpt4all/_pyllmodel.py b/gpt4all-bindings/python/gpt4all/_pyllmodel.py
index c59f1dc5c694..71f74508f3c5 100644
--- a/gpt4all-bindings/python/gpt4all/_pyllmodel.py
+++ b/gpt4all-bindings/python/gpt4all/_pyllmodel.py
@@ -351,8 +351,6 @@ def count_prompt_tokens(self, prompt: str) -> int:
             raise RuntimeError(f'Unable to count prompt tokens: {errmsg}')
         return n_tok
 
-    llmodel.llmodel_count_prompt_tokens.argtypes = [ctypes.c_void_p, ctypes.c_char_p]
-
     @staticmethod
     def list_gpus(mem_required: int = 0) -> list[str]:
         """

From 260ad4b1630c3f3af788b3b0430554d80c0e87ca Mon Sep 17 00:00:00 2001
From: Jared Van Bortel
Date: Thu, 5 Dec 2024 17:28:00 -0500
Subject: [PATCH 4/4] resolve some flake8 complaints about new code

Signed-off-by: Jared Van Bortel
---
 gpt4all-bindings/python/gpt4all/_pyllmodel.py |  2 +-
 gpt4all-bindings/python/gpt4all/gpt4all.py    | 54 +++++++++++--------
 2 files changed, 32 insertions(+), 24 deletions(-)

diff --git a/gpt4all-bindings/python/gpt4all/_pyllmodel.py b/gpt4all-bindings/python/gpt4all/_pyllmodel.py
index 71f74508f3c5..9cc480c31193 100644
--- a/gpt4all-bindings/python/gpt4all/_pyllmodel.py
+++ b/gpt4all-bindings/python/gpt4all/_pyllmodel.py
@@ -9,7 +9,7 @@
 import threading
 from enum import Enum
 from queue import Queue
-from typing import TYPE_CHECKING, Any, Callable, Generic, Iterable, Iterator, Literal, NoReturn, TypeVar, overload
+from typing import TYPE_CHECKING, Any, Callable, Generic, Iterator, Literal, NoReturn, TypeVar, overload
 
 if sys.version_info >= (3, 9):
     import importlib.resources as importlib_resources
diff --git a/gpt4all-bindings/python/gpt4all/gpt4all.py b/gpt4all-bindings/python/gpt4all/gpt4all.py
index 390eb410c7d2..b933052a7c45 100644
--- a/gpt4all-bindings/python/gpt4all/gpt4all.py
+++ b/gpt4all-bindings/python/gpt4all/gpt4all.py
@@ -37,9 +37,9 @@
 ConfigType: TypeAlias = "dict[str, Any]"
 
 
-# Environment setup adapted from HF transformers
 @_operator_call
 def _jinja_env() -> ImmutableSandboxedEnvironment:
+    # Environment setup adapted from HF transformers
     def raise_exception(message: str) -> NoReturn:
         raise jinja2.exceptions.TemplateError(message)
 
@@ -56,15 +56,17 @@ def strftime_now(fmt: str) -> str:
     return env
 
 
-class MessageType(TypedDict):
+class Message(TypedDict):
+    """A message in a chat with a GPT4All model."""
+
     role: str
     content: str
 
 
-class ChatSession(NamedTuple):
+class _ChatSession(NamedTuple):
     template:        jinja2.Template
     template_source: str
-    history:         list[MessageType]
+    history:         list[Message]
 
 
 class Embed4All:
@@ -195,7 +197,8 @@ class GPT4All:
     """
 
     RE_LEGACY_SYSPROMPT = re.compile(
-        r"(?:^|\s)(?:### *System\b|S(?:ystem|YSTEM):)|<\|(?:im_(?:start|end)|(?:start|end)_header_id|eot_id|SYSTEM_TOKEN)\|>|<<SYS>>",
+        r"(?:^|\s)(?:### *System\b|S(?:ystem|YSTEM):)|"
+        r"<\|(?:im_(?:start|end)|(?:start|end)_header_id|eot_id|SYSTEM_TOKEN)\|>|<<SYS>>",
         re.MULTILINE,
     )
 
@@ -244,7 +247,7 @@ def __init__(
         """
 
         self.model_type = model_type
-        self._chat_session: ChatSession | None = None
+        self._chat_session: _ChatSession | None = None
 
         device_init = None
         if sys.platform == "darwin":
@@ -303,11 +306,12 @@ def device(self) -> str | None:
         return self.model.device
 
     @property
-    def current_chat_session(self) -> list[MessageType] | None:
+    def current_chat_session(self) -> list[Message] | None:
+        """The message history of the current chat session."""
         return None if self._chat_session is None else self._chat_session.history
 
     @current_chat_session.setter
-    def current_chat_session(self, history: list[MessageType]) -> None:
+    def current_chat_session(self, history: list[Message]) -> None:
         if self._chat_session is None:
             raise ValueError("current_chat_session may only be set when there is an active chat session")
         self._chat_session.history[:] = history
@@ -585,13 +589,13 @@ def _callback_wrapper(token_id: int, response: str) -> bool:
         last_msg_rendered = prompt
         if self._chat_session is not None:
             session = self._chat_session
-            def render(messages: list[MessageType]) -> str:
+            def render(messages: list[Message]) -> str:
                 return session.template.render(
                     messages=messages,
                     add_generation_prompt=True,
                     **self.model.special_tokens_map,
                 )
-            session.history.append(MessageType(role="user", content=prompt))
+            session.history.append(Message(role="user", content=prompt))
             prompt = render(session.history)
             if len(session.history) > 1:
                 last_msg_rendered = render(session.history[-1:])
@@ -606,20 +610,14 @@ def render(messages: list[MessageType]) -> str:
             def stream() -> Iterator[str]:
                 yield from self.model.prompt_model_streaming(prompt, _callback_wrapper, **generate_kwargs)
                 if self._chat_session is not None:
-                    self._chat_session.history.append(MessageType(role="assistant", content=full_response))
+                    self._chat_session.history.append(Message(role="assistant", content=full_response))
             return stream()
 
         self.model.prompt_model(prompt, _callback_wrapper, **generate_kwargs)
         if self._chat_session is not None:
-            self._chat_session.history.append(MessageType(role="assistant", content=full_response))
+            self._chat_session.history.append(Message(role="assistant", content=full_response))
         return full_response
 
-    @classmethod
-    def is_legacy_chat_template(cls, tmpl: str) -> bool:
-        """A fairly reliable heuristic for detecting templates that don't look like Jinja templates."""
-        return bool(re.search(r"%[12]\b", tmpl) or not cls.RE_JINJA_LIKE.search(tmpl)
-                    or not re.search(r"\bcontent\b", tmpl))
-
     @contextmanager
     def chat_session(
         self,
@@ -632,10 +630,14 @@ def chat_session(
         Context manager to hold an inference optimized chat session with a GPT4All model.
 
         Args:
-            system_message: An initial instruction for the model, None to use the model default, or False to disable. Defaults to None.
+            system_message: An initial instruction for the model, None to use the model default, or False to disable.
+                Defaults to None.
             chat_template: Jinja template for the conversation, or None to use the model default. Defaults to None.
-        """
+            warn_legacy: Whether to warn about legacy system prompts or prompt templates. Defaults to True.
+        Raises:
+            ValueError: If no valid chat template was found.
+        """
 
         if system_message is None:
             system_message = self.config.get("systemMessage", False)
         elif system_message is not False and warn_legacy and (m := self.RE_LEGACY_SYSPROMPT.search(system_message)):
@@ -662,7 +664,7 @@ def chat_session(
                             msg += " If this is a built-in model, consider setting allow_download to True."
                         raise ValueError(msg) from None
                     raise
-        elif warn_legacy and self.is_legacy_chat_template(chat_template):
+        elif warn_legacy and self._is_legacy_chat_template(chat_template):
             print(
                 "Warning: chat_session() was passed a chat template that is not in Jinja format. Old-style prompt "
                 "templates are no longer supported.\nTo disable this warning, pass warn_legacy=False.",
@@ -671,8 +673,8 @@ def chat_session(
 
         history = []
         if system_message is not False:
-            history.append(MessageType(role="system", content=system_message))
-        self._chat_session = ChatSession(
+            history.append(Message(role="system", content=system_message))
+        self._chat_session = _ChatSession(
             template=_jinja_env.from_string(chat_template),
             template_source=chat_template,
             history=history,
@@ -692,6 +694,12 @@ def list_gpus() -> list[str]:
         """
         return LLModel.list_gpus()
 
+    @classmethod
+    def _is_legacy_chat_template(cls, tmpl: str) -> bool:
+        # check if tmpl does not look like a Jinja template
+        return bool(re.search(r"%[12]\b", tmpl) or not cls.RE_JINJA_LIKE.search(tmpl)
+                    or not re.search(r"\bcontent\b", tmpl))
+
 
 def append_extension_if_missing(model_name):
     if not model_name.endswith((".bin", ".gguf")):
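
Editor's note: for reference, a minimal usage sketch of the Python bindings API these patches introduce, assuming the series is applied. It is illustrative only; the GGUF filename is a placeholder, and the behavior shown (loading the template embedded in the model file, the current_chat_template property, and the warn_legacy flag) follows the code added in PATCH 1/4.

    from gpt4all import GPT4All

    # Hypothetical sideloaded model file; substitute any local GGUF already in the model path.
    model = GPT4All("my-sideloaded-model.Q4_0.gguf", allow_download=False)

    # With these patches, chat_session() no longer requires an explicit chat_template for
    # sideloaded models: when the model has no "chatTemplate" entry in models3.json, the
    # template stored in the GGUF metadata is loaded via llmodel_model_chat_template().
    with model.chat_session(system_message="You are a helpful assistant."):
        print(model.generate("Why is the sky blue?", max_tokens=128))
        # The Jinja template actually in use is exposed by the new property.
        print(model.current_chat_template)

    # Old-style prompt templates (e.g. with %1/%2 placeholders) now trigger a warning on
    # stderr; warn_legacy=False suppresses the warning, but the string is still compiled
    # as a Jinja template, so %1 is passed through literally rather than substituted.
    legacy = "### Human:\n%1\n### Assistant:\n"
    with model.chat_session(chat_template=legacy, warn_legacy=False):
        pass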