From 598f62a0848e29244dce6b03019588a4c1197725 Mon Sep 17 00:00:00 2001
From: Jared Van Bortel <jared@nomic.ai>
Date: Thu, 5 Dec 2024 17:28:00 -0500
Subject: [PATCH] resolve some flake8 complaints about new code

Signed-off-by: Jared Van Bortel <jared@nomic.ai>
---
 gpt4all-bindings/python/gpt4all/_pyllmodel.py |  2 +-
 gpt4all-bindings/python/gpt4all/gpt4all.py    | 54 +++++++++++--------
 2 files changed, 32 insertions(+), 24 deletions(-)

diff --git a/gpt4all-bindings/python/gpt4all/_pyllmodel.py b/gpt4all-bindings/python/gpt4all/_pyllmodel.py
index 71f74508f3c5b..9cc480c311933 100644
--- a/gpt4all-bindings/python/gpt4all/_pyllmodel.py
+++ b/gpt4all-bindings/python/gpt4all/_pyllmodel.py
@@ -9,7 +9,7 @@
 import threading
 from enum import Enum
 from queue import Queue
-from typing import TYPE_CHECKING, Any, Callable, Generic, Iterable, Iterator, Literal, NoReturn, TypeVar, overload
+from typing import TYPE_CHECKING, Any, Callable, Generic, Iterator, Literal, NoReturn, TypeVar, overload
 
 if sys.version_info >= (3, 9):
     import importlib.resources as importlib_resources
diff --git a/gpt4all-bindings/python/gpt4all/gpt4all.py b/gpt4all-bindings/python/gpt4all/gpt4all.py
index 390eb410c7d21..b933052a7c456 100644
--- a/gpt4all-bindings/python/gpt4all/gpt4all.py
+++ b/gpt4all-bindings/python/gpt4all/gpt4all.py
@@ -37,9 +37,9 @@
 
 ConfigType: TypeAlias = "dict[str, Any]"
 
 
-# Environment setup adapted from HF transformers
 @_operator_call
 def _jinja_env() -> ImmutableSandboxedEnvironment:
+    # Environment setup adapted from HF transformers
     def raise_exception(message: str) -> NoReturn:
         raise jinja2.exceptions.TemplateError(message)
 
@@ -56,15 +56,17 @@ def strftime_now(fmt: str) -> str:
     return env
 
 
-class MessageType(TypedDict):
+class Message(TypedDict):
+    """A message in a chat with a GPT4All model."""
+
     role: str
     content: str
 
 
-class ChatSession(NamedTuple):
+class _ChatSession(NamedTuple):
     template: jinja2.Template
     template_source: str
-    history: list[MessageType]
+    history: list[Message]
 
 
 class Embed4All:
@@ -195,7 +197,8 @@ class GPT4All:
     """
 
     RE_LEGACY_SYSPROMPT = re.compile(
-        r"(?:^|\s)(?:### *System\b|S(?:ystem|YSTEM):)|<\|(?:im_(?:start|end)|(?:start|end)_header_id|eot_id|SYSTEM_TOKEN)\|>|<<SYS>>",
+        r"(?:^|\s)(?:### *System\b|S(?:ystem|YSTEM):)|"
+        r"<\|(?:im_(?:start|end)|(?:start|end)_header_id|eot_id|SYSTEM_TOKEN)\|>|<<SYS>>",
         re.MULTILINE,
     )
 
@@ -244,7 +247,7 @@ def __init__(
         """
 
         self.model_type = model_type
-        self._chat_session: ChatSession | None = None
+        self._chat_session: _ChatSession | None = None
 
         device_init = None
         if sys.platform == "darwin":
@@ -303,11 +306,12 @@ def device(self) -> str | None:
         return self.model.device
 
     @property
-    def current_chat_session(self) -> list[MessageType] | None:
+    def current_chat_session(self) -> list[Message] | None:
+        """The message history of the current chat session."""
         return None if self._chat_session is None else self._chat_session.history
 
     @current_chat_session.setter
-    def current_chat_session(self, history: list[MessageType]) -> None:
+    def current_chat_session(self, history: list[Message]) -> None:
         if self._chat_session is None:
             raise ValueError("current_chat_session may only be set when there is an active chat session")
         self._chat_session.history[:] = history
@@ -585,13 +589,13 @@ def _callback_wrapper(token_id: int, response: str) -> bool:
         last_msg_rendered = prompt
         if self._chat_session is not None:
             session = self._chat_session
-            def render(messages: list[MessageType]) -> str:
+            def render(messages: list[Message]) -> str:
                 return session.template.render(
                     messages=messages,
                     add_generation_prompt=True,
                     **self.model.special_tokens_map,
                 )
-            session.history.append(MessageType(role="user", content=prompt))
+            session.history.append(Message(role="user", content=prompt))
             prompt = render(session.history)
             if len(session.history) > 1:
                 last_msg_rendered = render(session.history[-1:])
@@ -606,20 +610,14 @@ def render(messages: list[MessageType]) -> str:
         def stream() -> Iterator[str]:
             yield from self.model.prompt_model_streaming(prompt, _callback_wrapper, **generate_kwargs)
             if self._chat_session is not None:
-                self._chat_session.history.append(MessageType(role="assistant", content=full_response))
+                self._chat_session.history.append(Message(role="assistant", content=full_response))
             return stream()
 
         self.model.prompt_model(prompt, _callback_wrapper, **generate_kwargs)
         if self._chat_session is not None:
-            self._chat_session.history.append(MessageType(role="assistant", content=full_response))
+            self._chat_session.history.append(Message(role="assistant", content=full_response))
         return full_response
 
-    @classmethod
-    def is_legacy_chat_template(cls, tmpl: str) -> bool:
-        """A fairly reliable heuristic for detecting templates that don't look like Jinja templates."""
-        return bool(re.search(r"%[12]\b", tmpl) or not cls.RE_JINJA_LIKE.search(tmpl)
-                    or not re.search(r"\bcontent\b", tmpl))
-
     @contextmanager
     def chat_session(
         self,
@@ -632,10 +630,14 @@ def chat_session(
         Context manager to hold an inference optimized chat session with a GPT4All model.
 
         Args:
-            system_message: An initial instruction for the model, None to use the model default, or False to disable. Defaults to None.
+            system_message: An initial instruction for the model, None to use the model default, or False to disable.
+                Defaults to None.
             chat_template: Jinja template for the conversation, or None to use the model default. Defaults to None.
-        """
+            warn_legacy: Whether to warn about legacy system prompts or prompt templates. Defaults to True.
 
+        Raises:
+            ValueError: If no valid chat template was found.
+        """
         if system_message is None:
             system_message = self.config.get("systemMessage", False)
         elif system_message is not False and warn_legacy and (m := self.RE_LEGACY_SYSPROMPT.search(system_message)):
@@ -662,7 +664,7 @@ def chat_session(
                     msg += " If this is a built-in model, consider setting allow_download to True."
                     raise ValueError(msg) from None
                 raise
-        elif warn_legacy and self.is_legacy_chat_template(chat_template):
+        elif warn_legacy and self._is_legacy_chat_template(chat_template):
             print(
                 "Warning: chat_session() was passed a chat template that is not in Jinja format. Old-style prompt "
                 "templates are no longer supported.\nTo disable this warning, pass warn_legacy=False.",
@@ -671,8 +673,8 @@ def chat_session(
 
         history = []
         if system_message is not False:
-            history.append(MessageType(role="system", content=system_message))
-        self._chat_session = ChatSession(
+            history.append(Message(role="system", content=system_message))
+        self._chat_session = _ChatSession(
             template=_jinja_env.from_string(chat_template),
             template_source=chat_template,
             history=history,
@@ -692,6 +694,12 @@ def list_gpus() -> list[str]:
         """
         return LLModel.list_gpus()
 
+    @classmethod
+    def _is_legacy_chat_template(cls, tmpl: str) -> bool:
+        # check if tmpl does not look like a Jinja template
+        return bool(re.search(r"%[12]\b", tmpl) or not cls.RE_JINJA_LIKE.search(tmpl)
+                    or not re.search(r"\bcontent\b", tmpl))
+
 
 def append_extension_if_missing(model_name):
     if not model_name.endswith((".bin", ".gguf")):
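
For reviewers, a minimal usage sketch of the renamed public API touched by this
patch (the Message TypedDict and the chat_session() context manager). The model
file name is illustrative only and assumed to be already downloaded; it is not
part of this change:

    from gpt4all import GPT4All

    model = GPT4All("Meta-Llama-3-8B-Instruct.Q4_0.gguf")  # illustrative model file
    with model.chat_session(system_message="You are a concise assistant."):
        model.generate("Name three primary colors.", max_tokens=64)
        # current_chat_session is a list[Message]; each entry is a plain dict
        # with "role" and "content" keys, matching the Message TypedDict above.
        for msg in model.current_chat_session:
            print(f"{msg['role']}: {msg['content']}")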