Skip to content

Commit

Permalink
Close #28 - Implement model selection
Browse files Browse the repository at this point in the history
Implement model selection for the current sources. The model can be
set with a 'model' key in each source's configuration.
  • Loading branch information
w0rp committed Sep 19, 2023
1 parent f8b9a4d commit 07713e2
Show file tree
Hide file tree
Showing 9 changed files with 81 additions and 14 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
.coverage
.directory
.tox
.tox-docker
/env
__pycache__
tags
14 changes: 8 additions & 6 deletions autoload/neural/config.vim
Original file line number Diff line number Diff line change
Expand Up @@ -20,19 +20,21 @@ let s:defaults = {
\ 'source': {
\ 'openai': {
\ 'api_key': '',
\ 'temperature': 0.2,
\ 'top_p': 1,
\ 'frequency_penalty': 0.1,
\ 'max_tokens': 1024,
\ 'model': 'text-davinci-003',
\ 'presence_penalty': 0.1,
\ 'frequency_penalty': 0.1,
\ 'temperature': 0.2,
\ 'top_p': 1,
\ },
\ 'chatgpt': {
\ 'api_key': '',
\ 'temperature': 0.2,
\ 'top_p': 1,
\ 'frequency_penalty': 0.1,
\ 'max_tokens': 2048,
\ 'model': 'gpt-3.5-turbo',
\ 'presence_penalty': 0.1,
\ 'frequency_penalty': 0.1,
\ 'temperature': 0.2,
\ 'top_p': 1,
\ },
\ },
\}
Expand Down
18 changes: 18 additions & 0 deletions doc/neural.txt
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,15 @@ g:neural.source.openai.max_tokens *g:neural.source.openai.max_tokens*
This translates to roughly `¾` of a word (e.g. `100 tokens ~= 75 words`).


g:neural.source.openai.model *g:neural.source.openai.model*
*vim.g.neural.source.openai.model*
Type: |String|
Default: `'text-davinci-003'`

The model to use for OpenAI. Please consult OpenAI's documentation for more
information on models: https://platform.openai.com/docs/models/overview


g:neural.source.openai.presence_penalty
*g:neural.source.openai.presence_penalty*
*vim.g.neural.source.openai.presence_penalty*
Expand Down Expand Up @@ -289,6 +298,15 @@ g:neural.source.chatgpt.max_tokens *g:neural.source.chatgpt.max_tokens*
See Also: |g:neural.source.openai.max_tokens|


g:neural.source.chatgpt.model *g:neural.source.chatgpt.model*
*vim.g.neural.source.chatgpt.model*
Type: |String|
Default: `'gpt-3.5-turbo'`

The model to use for ChatGPT. Please consult OpenAI's documentation for more
information on models: https://platform.openai.com/docs/models/overview


g:neural.source.chatgpt.presence_penalty
*g:neural.source.chatgpt.presence_penalty*
*vim.g.neural.source.chatgpt.presence_penalty*
Expand Down
10 changes: 9 additions & 1 deletion neural_sources/chatgpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,15 @@ class Config:
def __init__(
self,
api_key: str,
model: str,
temperature: float,
top_p: float,
max_tokens: int,
presence_penalty: float,
frequency_penalty: float,
):
self.api_key = api_key
self.model = model
self.temperature = temperature
self.top_p = top_p
self.max_tokens = max_tokens
Expand All @@ -43,7 +45,7 @@ def get_chatgpt_completion(
"Authorization": f"Bearer {config.api_key}"
}
data = {
"model": "gpt-3.5-turbo",
"model": config.model,
"messages": (
[{"role": "user", "content": prompt}]
if isinstance(prompt, str) else
Expand Down Expand Up @@ -100,6 +102,11 @@ def load_config(raw_config: Dict[str, Any]) -> Config:
if not isinstance(api_key, str) or not api_key: # type: ignore
raise ValueError("chatgpt.api_key is not defined")

model = raw_config.get('model')

if not isinstance(model, str) or not model:
raise ValueError("chatgpt.model is not defined")

temperature = raw_config.get('temperature', 0.2)

if not isinstance(temperature, (int, float)):
Expand Down Expand Up @@ -127,6 +134,7 @@ def load_config(raw_config: Dict[str, Any]) -> Config:

return Config(
api_key=api_key,
model=model,
temperature=temperature,
top_p=top_p,
max_tokens=max_tokens,
Expand Down
10 changes: 9 additions & 1 deletion neural_sources/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,15 @@ class Config:
def __init__(
self,
api_key: str,
model: str,
temperature: float,
top_p: float,
max_tokens: int,
presence_penalty: float,
frequency_penalty: float,
):
self.api_key = api_key
self.model = model
self.temperature = temperature
self.top_p = top_p
self.max_tokens = max_tokens
Expand All @@ -40,7 +42,7 @@ def get_openai_completion(config: Config, prompt: str) -> None:
"Authorization": f"Bearer {config.api_key}"
}
data = {
"model": "text-davinci-003",
"model": config.model,
"prompt": prompt,
"temperature": config.temperature,
"max_tokens": config.max_tokens,
Expand Down Expand Up @@ -88,6 +90,11 @@ def load_config(raw_config: Dict[str, Any]) -> Config:
if not isinstance(api_key, str) or not api_key: # type: ignore
raise ValueError("openai.api_key is not defined")

model = raw_config.get('model')

if not isinstance(model, str) or not model:
raise ValueError("openai.model is not defined")

temperature = raw_config.get('temperature', 0.2)

if not isinstance(temperature, (int, float)):
Expand Down Expand Up @@ -115,6 +122,7 @@ def load_config(raw_config: Dict[str, Any]) -> Config:

return Config(
api_key=api_key,
model=model,
temperature=temperature,
top_p=top_p,
max_tokens=max_tokens,
Expand Down
5 changes: 4 additions & 1 deletion test/python/test_chatgpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
def get_valid_config() -> Dict[str, Any]:
return {
"api_key": ".",
"model": "foo",
"prompt": "say hello",
"temperature": 1,
"top_p": 1,
Expand All @@ -34,8 +35,10 @@ def test_load_config_errors():
for modification, expected_error in [
({}, "chatgpt.api_key is not defined"),
({"api_key": ""}, "chatgpt.api_key is not defined"),
({"api_key": "."}, "chatgpt.model is not defined"),
({"model": ""}, "chatgpt.model is not defined"),
(
{"api_key": ".", "temperature": "x"},
{"model": "x", "temperature": "x"},
"chatgpt.temperature is invalid"
),
(
Expand Down
5 changes: 4 additions & 1 deletion test/python/test_openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
def get_valid_config() -> Dict[str, Any]:
return {
"api_key": ".",
"model": "foo",
"prompt": "say hello",
"temperature": 1,
"top_p": 1,
Expand All @@ -34,8 +35,10 @@ def test_load_config_errors():
for modification, expected_error in [
({}, "openai.api_key is not defined"),
({"api_key": ""}, "openai.api_key is not defined"),
({"api_key": "."}, "openai.model is not defined"),
({"model": ""}, "openai.model is not defined"),
(
{"api_key": ".", "temperature": "x"},
{"model": "x", "temperature": "x"},
"openai.temperature is invalid"
),
(
Expand Down
10 changes: 9 additions & 1 deletion test/script/run-python
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,15 @@ echo '========================================'
echo 'tox warnings/errors follow:'
echo

tox_command='tox --workdir .tox-docker --skip-missing-interpreters=false'
# Use a different tox directory when running outside of Docker, to avoid
# issues with the Docker-mounted directory being written to as root, among
# other differences.
if [ "$NO_DOCKER" -eq 1 ]; then
tox_dir=.tox
else
tox_dir=.tox-docker
fi

tox_command="tox --workdir $tox_dir --skip-missing-interpreters=false"

set -o pipefail

Expand Down
22 changes: 19 additions & 3 deletions test/vim/test_config.vader
Original file line number Diff line number Diff line change
Expand Up @@ -42,14 +42,30 @@ Execute(The default openai settings should be correct):
AssertEqual
\ {
\ 'api_key': '',
\ 'temperature': 0.2,
\ 'top_p': 1,
\ 'frequency_penalty': 0.1,
\ 'max_tokens': 1024,
\ 'model': 'text-davinci-003',
\ 'presence_penalty': 0.1,
\ 'frequency_penalty': 0.1,
\ 'temperature': 0.2,
\ 'top_p': 1,
\ },
\ get(g:neural.source, 'openai')

Execute(The default chatgpt settings should be correct):
call neural#config#Load()

AssertEqual
\ {
\ 'api_key': '',
\ 'frequency_penalty': 0.1,
\ 'max_tokens': 2048,
\ 'model': 'gpt-3.5-turbo',
\ 'presence_penalty': 0.1,
\ 'temperature': 0.2,
\ 'top_p': 1,
\ },
\ get(g:neural.source, 'chatgpt')

Execute(Settings should be merged correctly):
for s:i in range(2)
if s:i == 0
Expand Down

0 comments on commit 07713e2

Please sign in to comment.