Update LLM interface (#5)
Update LLM interface organization
eyurtsev authored Mar 9, 2023
1 parent ba360b7 commit 94fd6e1
Showing 3 changed files with 53 additions and 21 deletions.
10 changes: 10 additions & 0 deletions kor/llms/__init__.py
@@ -0,0 +1,10 @@
"""Define public interface for llm wrapping package."""
from .openai import OpenAIChatCompletion, OpenAICompletion
from .typedefs import ChatCompletionModel, CompletionModel

__all__ = [
"OpenAIChatCompletion",
"OpenAICompletion",
"CompletionModel",
"ChatCompletionModel",
]
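
For orientation, a minimal usage sketch of the re-exported interface (hypothetical, not part of this commit; it assumes the package is importable as `kor`, that OPENAI_API_KEY is set, and that OpenAICompletion is constructible with its default fields):

# Hypothetical usage sketch, not part of this commit; assumes `kor` is
# installed and OPENAI_API_KEY is set in the environment.
from kor.llms import OpenAICompletion

llm = OpenAICompletion()  # default dataclass fields; the full field set is not shown in this diff
completion = llm("Write one sentence about LLM wrappers.")  # performs a real OpenAI API call
print(completion)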
43 changes: 22 additions & 21 deletions kor/llms.py → kor/llms/openai.py
@@ -1,31 +1,32 @@
"""Provide standard interface on tops of LLMs."""
import abc
import dataclasses
import json
import logging
import os

import openai

logger = logging.getLogger(__name__)
from .typedefs import CompletionModel, ChatCompletionModel

try:
import openai
except ImportError:
openai = None

@dataclasses.dataclass(kw_only=True)
class CompletionModel(abc.ABC):
"""Abstract completion model interface."""

def __call__(self, prompt: str) -> str:
"""Call the model."""
raise NotImplementedError()
logger = logging.getLogger(__name__)


@dataclasses.dataclass(kw_only=True)
class ChatCompletionModel(abc.ABC):
"""Abstract chat completion model interface."""
def _set_openai_api_key_if_needed() -> None:
"""Set the openai api key if needed."""
if not openai:
raise ImportError("Missing `openai` dependency.")

def __call__(self, messages: list[dict[str, str]]) -> str:
"""Call the model."""
raise NotImplementedError()
if not openai.api_key:
api_key = os.environ.get("OPENAI_API_KEY")
if not api_key:
if "OPENAI_API_KEY" not in os.environ:
raise ValueError(
"Please include OPENAI_API_KEY in the environment or set the openai.api_key."
)
openai.api_key = api_key


@dataclasses.dataclass(kw_only=True)
@@ -41,8 +42,8 @@ class OpenAICompletion(CompletionModel):
     top_p: float = 1.0
 
     def __post_init__(self) -> None:
-        """Initialize the LLM model."""
-        openai.api_key = os.environ["OPENAI_API_KEY"]
+        """Set credentials if needed."""
+        _set_openai_api_key_if_needed()
 
     def __call__(self, prompt: str) -> str:
         """Invoke the LLM with the given prompt."""
@@ -76,8 +77,8 @@ class OpenAIChatCompletion(ChatCompletionModel):
     top_p: float = 1.0
 
     def __post_init__(self) -> None:
-        """Initialize the LLM model."""
-        openai.api_key = os.environ["OPENAI_API_KEY"]
+        """Set credentials if needed."""
+        _set_openai_api_key_if_needed()
 
     def __call__(self, messages: list[dict[str, str]]) -> str:
         """Invoke the LLM with the given messages."""
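As a rough illustration of the new credential flow (hypothetical, not part of this commit): a key already present on openai.api_key is respected, and the environment variable is only consulted in __post_init__ when no key has been set.

# Hypothetical illustration of the lazy key setup, not part of this commit.
import os

os.environ.setdefault("OPENAI_API_KEY", "sk-placeholder")  # or set openai.api_key directly

from kor.llms import OpenAIChatCompletion

chat_llm = OpenAIChatCompletion()  # __post_init__ calls _set_openai_api_key_if_needed()
reply = chat_llm([{"role": "user", "content": "Hello!"}])  # performs a real OpenAI API call
print(reply)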
21 changes: 21 additions & 0 deletions kor/llms/typedefs.py
@@ -0,0 +1,21 @@
"""Provide standard interface on tops of LLMs."""
import abc
import dataclasses


@dataclasses.dataclass(kw_only=True)
class CompletionModel(abc.ABC):
"""Abstract completion model interface."""

def __call__(self, prompt: str) -> str:
"""Call the model."""
raise NotImplementedError()


@dataclasses.dataclass(kw_only=True)
class ChatCompletionModel(abc.ABC):
"""Abstract chat completion model interface."""

def __call__(self, messages: list[dict[str, str]]) -> str:
"""Call the model."""
raise NotImplementedError()
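
Splitting the abstract classes into kor/llms/typedefs.py lets backends other than OpenAI implement the same interface. A toy sketch (hypothetical, not part of this commit):

# Hypothetical subclass of the abstract interface, not part of this commit.
import dataclasses

from kor.llms import CompletionModel


@dataclasses.dataclass(kw_only=True)
class EchoCompletion(CompletionModel):
    """Toy completion model that echoes the prompt; handy for tests."""

    prefix: str = "echo: "

    def __call__(self, prompt: str) -> str:
        """Return the prompt with a fixed prefix instead of calling an API."""
        return f"{self.prefix}{prompt}"


print(EchoCompletion()("hello"))  # -> echo: hello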
