-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathserver.py
206 lines (159 loc) · 6.68 KB
/
server.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
#!/usr/bin/env python3
# github.com/deadbits/vector-embedding-api
# server.py
import os
import sys
import time
import argparse
import hashlib
import logging
import configparser
import openai
from typing import Dict, List, Union, Optional
from collections import OrderedDict
from flask import Flask, request, jsonify, abort
from sentence_transformers import SentenceTransformer
# Flask application instance; routes are registered on it further down.
app = Flask(__name__)
# Route INFO-and-above records to the root handler; module-level logger
# is the one used throughout the rest of this file.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class Config:
    """Thin wrapper around :mod:`configparser` for the server's INI config.

    Exits the process if the config file does not exist (fail-fast at
    startup). Missing keys are treated as optional: lookups log an error
    and return ``None`` / the supplied default rather than raising.
    """

    def __init__(self, config_file: str):
        """Load and parse *config_file*; exit(1) if it is missing."""
        self.config_file = config_file
        self.config = configparser.ConfigParser()

        if not os.path.exists(self.config_file):
            # The server cannot run without its configuration.
            logging.error(f'Config file not found: {self.config_file}')
            sys.exit(1)

        logging.info(f'Loading config file: {self.config_file}')
        self.config.read(self.config_file)

    def get_val(self, section: str, key: str) -> Optional[str]:
        """Return the string value for section/key, or None if absent."""
        try:
            return self.config.get(section, key)
        except (configparser.NoSectionError, configparser.NoOptionError) as err:
            # Absent settings are expected (both backends are optional).
            logging.error(f'Config file missing section: {section} - {err}')
            return None

    def get_bool(self, section: str, key: str, default: bool = False) -> bool:
        """Return the boolean value for section/key, or *default* on failure.

        ``ValueError`` covers values getboolean cannot interpret (e.g. "maybe").
        """
        try:
            return self.config.getboolean(section, key)
        except (configparser.NoSectionError, configparser.NoOptionError, ValueError) as err:
            # Report the actual default being returned, not a hard-coded "False".
            logging.error(f'Failed to parse boolean - returning default "{default}": {section} - {err}')
            return default
class EmbeddingCache:
    """Bounded in-memory LRU cache mapping (text, model_type) -> embedding.

    Keys are SHA-256 digests of ``text + model_type`` so arbitrarily long
    inputs never bloat the key space. When the cache grows past
    ``max_size`` the least-recently-used entry is evicted.
    """

    def __init__(self, max_size: int = 500):
        logger.info(f'Created in-memory cache; max size={max_size}')
        self.cache: OrderedDict = OrderedDict()
        self.max_size = max_size

    def get_cache_key(self, text: str, model_type: str) -> str:
        """Derive a deterministic cache key from the text and model name."""
        return hashlib.sha256((text + model_type).encode()).hexdigest()

    def get(self, text: str, model_type: str):
        """Return the cached embedding or None; refreshes recency on a hit."""
        key = self.get_cache_key(text, model_type)
        if key not in self.cache:
            return None
        # Mark as most-recently-used so hot entries survive eviction
        # (without this the cache degrades to FIFO).
        self.cache.move_to_end(key)
        return self.cache[key]

    def set(self, text: str, model_type: str, embedding) -> None:
        """Insert or refresh an entry, evicting the LRU entry when full."""
        key = self.get_cache_key(text, model_type)
        self.cache[key] = embedding
        self.cache.move_to_end(key)  # re-writing an existing key counts as a use
        if len(self.cache) > self.max_size:
            self.cache.popitem(last=False)
class EmbeddingGenerator:
    """Produces embeddings via the OpenAI API or a local sentence-transformers model.

    Either backend may be disabled by passing ``None`` for its argument;
    ``generate`` returns an error result (never raises) if the requested
    backend is unavailable or fails.
    """

    def __init__(self, sbert_model: Optional[str] = None, openai_key: Optional[str] = None):
        self.sbert_model = sbert_model
        self.openai_key = openai_key
        # Explicitly None when the local backend is disabled, so generate()
        # can report a clear error instead of an AttributeError.
        self.model = None

        if self.sbert_model is not None:
            try:
                self.model = SentenceTransformer(self.sbert_model)
                logger.info(f'enabled model: {self.sbert_model}')
            except Exception as err:
                logger.error(f'Failed to load SentenceTransformer model "{self.sbert_model}": {err}')
                sys.exit(1)

        if self.openai_key is not None:
            openai.api_key = self.openai_key
            try:
                # Cheap authenticated call to validate the key up front.
                openai.Model.list()
                logger.info('enabled model: text-embedding-ada-002')
            except Exception as err:
                logger.error(f'Failed to connect to OpenAI API; disabling OpenAI model: {err}')

    def generate(self, text_batch: List[str], model_type: str) -> Dict[str, Union[str, float, list]]:
        """Embed *text_batch* with the requested backend.

        Args:
            text_batch: list of strings to embed as one batch.
            model_type: 'openai' selects the OpenAI API; anything else
                selects the local sentence-transformers model.

        Returns:
            dict with keys status ('success'/'error'), message, model,
            elapsed (milliseconds), and embeddings (list of vectors).
        """
        start_time = time.time()
        result = {
            'status': 'success',
            'message': '',
            'model': '',
            'elapsed': 0,
            'embeddings': []
        }

        if model_type == 'openai':
            try:
                response = openai.Embedding.create(input=text_batch, model='text-embedding-ada-002')
                result['embeddings'] = [data['embedding'] for data in response['data']]
                result['model'] = 'text-embedding-ada-002'
            except Exception as err:
                logger.error(f'Failed to get OpenAI embeddings: {err}')
                result['status'] = 'error'
                result['message'] = str(err)
        else:
            try:
                if self.model is None:
                    raise RuntimeError('sentence-transformers model is not configured')
                # No hard-coded device: let sentence-transformers pick CUDA
                # when available and fall back to CPU otherwise (the original
                # device='cuda' crashed on CPU-only hosts).
                embedding = self.model.encode(text_batch, batch_size=len(text_batch)).tolist()
                result['embeddings'] = embedding
                result['model'] = self.sbert_model
            except Exception as err:
                logger.error(f'Failed to get sentence-transformers embeddings: {err}')
                result['status'] = 'error'
                result['message'] = str(err)

        result['elapsed'] = (time.time() - start_time) * 1000
        return result
@app.route('/health', methods=['GET'])
def health_check():
    """Report which embedding backends are enabled plus cache statistics."""
    models = {
        "openai": True if embedding_generator.openai_key else 'disabled',
        'sentence-transformers': embedding_generator.sbert_model or 'disabled',
    }

    cache_enabled = embedding_cache is not None
    cache_info = {
        "enabled": cache_enabled,
        "size": len(embedding_cache.cache) if cache_enabled else None,
        "max_size": embedding_cache.max_size if cache_enabled else None,
    }

    return jsonify({"models": models, "cache": cache_info})
@app.route('/submit', methods=['POST'])
def submit_text():
    """Embed submitted text and return a one-element list with the result dict.

    Request JSON: {"text": str | [str, ...], "model": "local" | "openai"}.
    A single string is accepted as a convenience and embedded as a
    one-item batch (previously it was silently iterated character by
    character). Responds 400 on a missing/non-JSON body or non-string items.
    """
    data = request.get_json(silent=True)
    if data is None:
        # Non-JSON bodies previously crashed with an AttributeError (500).
        abort(400, 'Missing text data to embed')

    text_data = data.get('text')
    model_type = data.get('model', 'local').lower()

    if text_data is None:
        abort(400, 'Missing text data to embed')

    if isinstance(text_data, str):
        text_data = [text_data]

    if not all(isinstance(text, str) for text in text_data):
        abort(400, 'all data must be text strings')

    # The cache was previously write-only; serve from it when the whole
    # batch is already cached.
    if embedding_cache:
        cached = [embedding_cache.get(text, model_type) for text in text_data]
        if text_data and all(vec is not None for vec in cached):
            logger.info('served from cache')
            model_name = 'text-embedding-ada-002' if model_type == 'openai' else embedding_generator.sbert_model
            return jsonify([{
                'status': 'success',
                'message': '',
                'model': model_name,
                'elapsed': 0,
                'embeddings': cached,
            }])

    result = embedding_generator.generate(text_data, model_type)

    if embedding_cache and result['status'] == 'success':
        for text, embedding in zip(text_data, result['embeddings']):
            embedding_cache.set(text, model_type, embedding)
        logger.info('added to cache')

    return jsonify([result])
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '-c', '--config',
        help='config file',
        type=str,
        required=True
    )
    args = parser.parse_args()

    conf = Config(args.config)

    openai_key = conf.get_val('main', 'openai_api_key')
    sbert_model = conf.get_val('main', 'sent_transformers_model')
    use_cache = conf.get_bool('main', 'use_cache', default=False)

    # Default matches EmbeddingCache's own default; previously a missing
    # or malformed cache_max crashed with int(None) / ValueError.
    max_cache_size = 500
    if use_cache:
        try:
            max_cache_size = int(conf.get_val('main', 'cache_max'))
        except (TypeError, ValueError):
            logger.warning(f'Invalid or missing cache_max; using default {max_cache_size}')

    # logger.warn is a deprecated alias; use logger.warning.
    if openai_key is None:
        logger.warning('No OpenAI API key set in configuration file: server.conf')
    if sbert_model is None:
        logger.warning('No transformer model set in configuration file: server.conf')
    if openai_key is None and sbert_model is None:
        # With both backends unconfigured the server can do nothing useful.
        logger.error('No sbert model set *and* no openAI key set; exiting')
        sys.exit(1)

    embedding_cache = EmbeddingCache(max_cache_size) if use_cache else None
    embedding_generator = EmbeddingGenerator(sbert_model, openai_key)

    # NOTE(review): debug=True enables the Werkzeug debugger/reloader and is
    # unsafe outside local development — confirm before deploying.
    app.run(debug=True)