mllm_client.py
#!/usr/bin/env python3
import argparse
import os
import subprocess
import tempfile
from anthropic import Anthropic
from openai import OpenAI
import sys, shutil
import boto3
from botocore.config import Config
import json, logging, time
from dataclasses import dataclass, field
from typing import List, Optional, Dict, Any, Generator, TypedDict, Literal
import hashlib

JSONSchema = Any

model_mapping: Dict[str, str] = {
    'claude-sonnet': 'claude-3-5-sonnet-latest',
    'claude-haiku': 'claude-3-5-haiku-latest',
    'gemini': 'gemini-exp-1206',
}
@dataclass
class FunctionDefinition:
    name: str
    description: str
    parameters: Dict[str, JSONSchema]

    def to_openai_dict(self) -> Dict[str, Any]:
        return {
            "name": self.name,
            "description": self.description,
            "parameters": self.parameters,
        }


@dataclass
class FunctionCall:
    name: str
    arguments: Dict[str, Any]


class Message(TypedDict):
    content: str
    role: Literal['user', 'assistant']
def stream_llm_output_uncached(
    messages: List[Message],
    model: str = 'claude-3-5-sonnet-latest',
    tokens: Optional[int] = None,
    functions: Optional[List[FunctionDefinition]] = None,
    function_call: Optional[str] = None,
    prediction: Optional[str] = None,
) -> Generator[str | FunctionCall, None, None]:
    """
    Stream a completion from the selected model, yielding text chunks and,
    for function-calling responses, FunctionCall objects.

    Parameters beyond the basics:
    - functions: list of function definitions the model may call
    - function_call: force a specific function by name
    - prediction: predicted-output content passed through to OpenAI
    """
    assert len(messages) > 0
    model = model_mapping.get(model, model)

    # Resolve the provider from the model name prefix.
    if model.startswith('claude'):
        provider = 'anthropic'
    elif model.startswith('gemini'):
        provider = 'google'
    elif model.startswith('bedrock'):
        provider = 'bedrock'
    else:
        provider = 'openai'

    logging.info('LLM call (model = %s, characters = %d)', model, sum(len(m['content']) for m in messages))
    start_time = time.time()

    if tokens is None:
        # Default output budget; reasoning (o1) models get a larger cap.
        if model.startswith('o1'):
            tokens = 1024 * 30
        else:
            tokens = 1024 * 8
    if provider == 'google':
        from google.generativeai import GenerativeModel
        import google.generativeai as genai
        gemini_api_key = open(os.path.expanduser('~/kbox/gemini-api-key')).read().strip()
        genai.configure(api_key=gemini_api_key)
        model_inst = GenerativeModel(model)
        # Flatten the conversation into a single role-prefixed prompt for Gemini.
        prompt = "\n".join(f"{m['role']}: {m['content']}" for m in messages)
        response = model_inst.generate_content(prompt, stream=True)
        for chunk in response:
            if chunk.text:
                yield chunk.text
    elif provider == 'openai':
        if 'OPENAI_API_KEY' not in os.environ:
            os.environ['OPENAI_API_KEY'] = open(os.path.expanduser('~/kbox/openai')).read().strip().split('=')[1]
        client = OpenAI()
        kwargs: dict[str, Any] = {}

        # Allow an 'o1:low|medium|high' suffix to select reasoning effort.
        effort = 'medium'
        if model.startswith('o1:'):
            base, suffix = model.split(':', 1)
            suffix = suffix.strip()
            if suffix in ('low', 'medium', 'high'):
                effort = suffix
            model = base
        if model == 'o1':
            kwargs['reasoning_effort'] = effort

        if functions:
            kwargs["functions"] = [f.to_openai_dict() for f in functions]
        if function_call:
            kwargs["function_call"] = {"name": function_call}

        # Pass a predicted-output payload if given; otherwise cap completion tokens.
        if prediction is not None:
            kwargs['prediction'] = {
                "type": "content",
                "content": prediction
            }
        else:
            kwargs['max_completion_tokens'] = tokens

        # Stream plain text responses; function calls and o1 models are handled non-streaming.
        stream = (functions is None or len(functions) == 0) and not model.startswith('o1')
        response = client.chat.completions.create(  # type: ignore
            model=model,
            messages=messages,  # type: ignore
            stream=stream,
            **kwargs)
        if stream:
            for chunk in response:
                delta = chunk.choices[0].delta
                if hasattr(delta, 'content') and delta.content is not None:
                    yield delta.content
        else:
            choice = response.choices[0].message
            if choice.function_call:
                function_call_data = choice.function_call
                yield FunctionCall(
                    name=function_call_data.name,
                    arguments=json.loads(function_call_data.arguments)
                )
            else:
                yield choice.content
    elif provider == 'anthropic':
        if 'ANTHROPIC_API_KEY' not in os.environ:
            os.environ['ANTHROPIC_API_KEY'] = open(os.path.expanduser('~/kbox/claude')).read().strip()
        anthropic_client = Anthropic()
        kwargs = {
            "model": model,
            "max_tokens": tokens,
            "messages": messages,
        }
        if functions:
            # Anthropic tools use an input_schema field rather than OpenAI's parameters.
            kwargs["tools"] = [{
                "name": f.name,
                "description": f.description,
                "input_schema": f.parameters,
            } for f in functions]
        stream1 = anthropic_client.messages.stream(**kwargs)
        with stream1 as stream2:
            for text in stream2.text_stream:
                yield text
            # Tool-call handling is currently disabled (guarded by `if False`).
            if False and stream2.tool_calls:
                for tool_call in stream2.tool_calls:
                    yield FunctionCall(name=tool_call.function.name,
                                       arguments=tool_call.function.arguments)
    else:
        assert False, provider

    logging.info('LLM call finished (model = %s, characters = %d, time = %.1f)',
                 model, sum(len(m['content']) for m in messages),
                 time.time() - start_time)
def stream_llm_output(
    messages: List[Message],
    model: str = 'claude-3-5-sonnet-latest',
    tokens: Optional[int] = None,
    functions: Optional[List[FunctionDefinition]] = None,
    function_call: Optional[str] = None,
    prediction: Optional[str] = None,
) -> Generator[str | FunctionCall, None, None]:
    """
    Wrapper around stream_llm_output_uncached that adds on-disk caching when
    MLLM_DISK_CACHE_LOCATION is set.
    """
    cache_location = os.environ.get('MLLM_DISK_CACHE_LOCATION')
    if cache_location:
        # Derive the cache key from a hash of all inputs.
        cache_key_data = {
            'messages': messages,
            'model': model,
            'tokens': tokens,
            'functions': [f.to_openai_dict() for f in functions] if functions else None,
            'function_call': function_call
        }
        if prediction is not None:
            cache_key_data['prediction'] = prediction
        cache_key_json = json.dumps(cache_key_data, sort_keys=True)
        cache_hash = hashlib.sha256(cache_key_json.encode('utf-8')).hexdigest()
        cache_file = os.path.join(cache_location, f"{cache_hash}.json")

        if os.path.exists(cache_file):
            # Cache hit: replay the stored outputs.
            with open(cache_file, 'r') as f:
                cached_output = json.load(f)
            for item in cached_output:
                if isinstance(item, dict) and 'FunctionCall' in item:
                    function_call_data = item['FunctionCall']
                    yield FunctionCall(**function_call_data)
                else:
                    yield item
            return
        else:
            # Cache miss: generate, stream through, and record the outputs.
            outputs = []
            for output in stream_llm_output_uncached(messages, model, tokens, functions, function_call, prediction):
                outputs.append(output)
                try:
                    yield output
                except GeneratorExit:
                    # Consumer closed the stream early; whatever was produced so far gets cached.
                    break
            # Serialize outputs (FunctionCall objects become tagged dicts) and write the cache file.
            serializable_outputs = []
            for item in outputs:
                if isinstance(item, FunctionCall):
                    serializable_outputs.append({'FunctionCall': item.__dict__})
                else:
                    serializable_outputs.append(item)
            os.makedirs(cache_location, exist_ok=True)
            with open(cache_file, 'w') as f:
                json.dump(serializable_outputs, f)
    else:
        # No caching configured; delegate directly.
        yield from stream_llm_output_uncached(messages, model, tokens, functions, function_call, prediction)
def llm_output(*args, **kwargs) -> str:
    """Collect the full streamed text output as a single string."""
    result = []
    for part in stream_llm_output(*args, **kwargs):
        assert not isinstance(part, FunctionCall)
        result.append(part)
    return ''.join(result)


def llm_one_function_call(*args, **kwargs) -> FunctionCall:
    """Return the first FunctionCall yielded by the model, or raise if none occurs."""
    for part in stream_llm_output(*args, **kwargs):
        if isinstance(part, FunctionCall):
            return part
    raise Exception('no function called')
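

# Minimal usage sketch, not part of the client itself: it assumes the API key files
# referenced above exist (or the corresponding environment variables are set), and the
# model names and the get_weather tool below are illustrative examples only. Setting
# MLLM_DISK_CACHE_LOCATION makes repeated identical calls replay from the disk cache.
if __name__ == '__main__':
    logging.basicConfig(level=logging.INFO)

    # Plain streamed completion.
    for part in stream_llm_output(
        [{'role': 'user', 'content': 'Say hello in one short sentence.'}],
        model='claude-sonnet',
    ):
        print(part, end='', flush=True)
    print()

    # Function calling via an OpenAI model, forcing a specific (hypothetical) tool.
    weather_fn = FunctionDefinition(
        name='get_weather',
        description='Look up the current weather for a city.',
        parameters={
            'type': 'object',
            'properties': {'city': {'type': 'string'}},
            'required': ['city'],
        },
    )
    call = llm_one_function_call(
        [{'role': 'user', 'content': 'What is the weather in Paris?'}],
        model='gpt-4o',
        functions=[weather_fn],
        function_call='get_weather',
    )
    print(call.name, call.arguments)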