diff --git a/kor/llms/__init__.py b/kor/llms/__init__.py new file mode 100644 index 0000000..a86e311 --- /dev/null +++ b/kor/llms/__init__.py @@ -0,0 +1,10 @@ +"""Define public interface for llm wrapping package.""" +from .openai import OpenAIChatCompletion, OpenAICompletion +from .typedefs import ChatCompletionModel, CompletionModel + +__all__ = [ + "OpenAIChatCompletion", + "OpenAICompletion", + "CompletionModel", + "ChatCompletionModel", +] diff --git a/kor/llms.py b/kor/llms/openai.py similarity index 73% rename from kor/llms.py rename to kor/llms/openai.py index 65fdc07..0678a4c 100644 --- a/kor/llms.py +++ b/kor/llms/openai.py @@ -1,31 +1,32 @@ -"""Provide standard interface on tops of LLMs.""" -import abc import dataclasses import json import logging import os -import openai - -logger = logging.getLogger(__name__) +from .typedefs import CompletionModel, ChatCompletionModel +try: + import openai +except ImportError: + openai = None -@dataclasses.dataclass(kw_only=True) -class CompletionModel(abc.ABC): - """Abstract completion model interface.""" - def __call__(self, prompt: str) -> str: - """Call the model.""" - raise NotImplementedError() +logger = logging.getLogger(__name__) -@dataclasses.dataclass(kw_only=True) -class ChatCompletionModel(abc.ABC): - """Abstract chat completion model interface.""" +def _set_openai_api_key_if_needed() -> None: + """Set the openai api key if needed.""" + if not openai: + raise ImportError("Missing `openai` dependency.") - def __call__(self, messages: list[dict[str, str]]) -> str: - """Call the model.""" - raise NotImplementedError() + if not openai.api_key: + api_key = os.environ.get("OPENAI_API_KEY") + if not api_key: + raise ValueError( + "Please include OPENAI_API_KEY in the environment " + "or set the openai.api_key."
+ ) + openai.api_key = api_key @@ -41,8 +42,8 @@ class OpenAICompletion(CompletionModel): top_p: float = 1.0 def __post_init__(self) -> None: - """Initialize the LLM model.""" - openai.api_key = os.environ["OPENAI_API_KEY"] + """Set credentials if needed.""" + _set_openai_api_key_if_needed() def __call__(self, prompt: str) -> str: """Invoke the LLM with the given prompt.""" @@ -76,8 +77,8 @@ class OpenAIChatCompletion(ChatCompletionModel): top_p: float = 1.0 def __post_init__(self) -> None: - """Initialize the LLM model.""" - openai.api_key = os.environ["OPENAI_API_KEY"] + """Set credentials if needed.""" + _set_openai_api_key_if_needed() def __call__(self, messages: list[dict[str, str]]) -> str: """Invoke the LLM with the given prompt.""" diff --git a/kor/llms/typedefs.py b/kor/llms/typedefs.py new file mode 100644 index 0000000..264901f --- /dev/null +++ b/kor/llms/typedefs.py @@ -0,0 +1,21 @@ +"""Provide standard interface on top of LLMs.""" +import abc +import dataclasses + + +@dataclasses.dataclass(kw_only=True) +class CompletionModel(abc.ABC): + """Abstract completion model interface.""" + + def __call__(self, prompt: str) -> str: + """Call the model.""" + raise NotImplementedError() + + +@dataclasses.dataclass(kw_only=True) +class ChatCompletionModel(abc.ABC): + """Abstract chat completion model interface.""" + + def __call__(self, messages: list[dict[str, str]]) -> str: + """Call the model.""" + raise NotImplementedError()