Skip to content

Commit

Permalink
Feat/zh036 (#898)
Browse files Browse the repository at this point in the history
  • Loading branch information
yaojin3616 authored Oct 28, 2024
2 parents bef1878 + e6bf853 commit 3e6aa67
Show file tree
Hide file tree
Showing 60 changed files with 2,255 additions and 167 deletions.
4 changes: 2 additions & 2 deletions docker/docker-compose.yml
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ services:

backend:
container_name: bisheng-backend
image: dataelement/bisheng-backend:v0.3.6.dev1
image: dataelement/bisheng-backend:v0.3.7.dev1
ports:
- "7860:7860"
environment:
Expand Down Expand Up @@ -92,7 +92,7 @@ services:

frontend:
container_name: bisheng-frontend
image: dataelement/bisheng-frontend:v0.3.6.dev1
image: dataelement/bisheng-frontend:v0.3.7.dev1
ports:
- "3001:3001"
environment:
Expand Down
Empty file removed src/backend/None
Empty file.
2 changes: 1 addition & 1 deletion src/backend/bisheng/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

try:
# 通过ci去自动修改
__version__ = '0.3.6.dev1'
__version__ = '0.3.7.dev1'
except metadata.PackageNotFoundError:
# Case where package metadata is not available.
__version__ = ''
Expand Down
5 changes: 5 additions & 0 deletions src/backend/bisheng/api/errcode/knowledge.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,3 +20,8 @@ class KnowledgeChunkError(BaseErrorCode):
class KnowledgeSimilarError(BaseErrorCode):
Code: int = 10920
Msg: str = '未配置QA知识库相似问模型'


class KnowledgeQAError(BaseErrorCode):
Code: int = 10930
Msg: str = '该问题已被标注过'
3 changes: 2 additions & 1 deletion src/backend/bisheng/api/router.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
finetune_router, flows_router, group_router, knowledge_router,
qa_router, report_router, server_router, skillcenter_router,
user_router, validate_router, variable_router, audit_router, evaluation_router,
tag_router, llm_router)
tag_router, llm_router,mark_router)
from bisheng.api.v2 import chat_router_rpc, knowledge_router_rpc, rpc_router_rpc, flow_router, assistant_router_rpc
from fastapi import APIRouter

Expand All @@ -27,6 +27,7 @@
router.include_router(evaluation_router)
router.include_router(tag_router)
router.include_router(llm_router)
router.include_router(mark_router)

router_rpc = APIRouter(prefix='/api/v2', )
router_rpc.include_router(knowledge_router_rpc)
Expand Down
142 changes: 142 additions & 0 deletions src/backend/bisheng/api/services/chat_imp.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,82 @@
import asyncio
import json
# 设置 websockets 的日志级别为 NONE
import logging
from collections import defaultdict
from datetime import datetime, timedelta

from bisheng.api.v1.schemas import resp_500
from bisheng.database.base import session_getter
from bisheng.database.models.message import ChatMessage
from pydantic import BaseModel
from websockets import connect

# Maintain a connection pool.
# NOTE(review): this binding is replaced further down in this module by
# `connection_pool = defaultdict(TimedQueue)` — confirm this one is still needed.
connection_pool = defaultdict(asyncio.Queue)
# Only let ERROR-level (and above) records through the `websockets` logger.
logging.getLogger('websockets').setLevel(logging.ERROR)

expire = 600 # Redis expiry. NOTE(review): the original comment said 60s but the value is 600 — confirm the intended duration/unit.


class TimedQueue:
    """Wrap an ``asyncio.Queue`` and remember when it was last used.

    ``last_active`` is set at construction and refreshed at the start of
    every put and get, so a caller can compare it with the current time to
    detect idle queues.

    Despite their names, ``put_nowait`` and ``get_nowait`` are coroutines
    that await the underlying queue's ``put`` / ``get``; they must be
    awaited by the caller.
    """

    def __init__(self):
        self.queue = asyncio.Queue()
        self.last_active = datetime.now()

    def _touch(self):
        """Record the current time as the moment of last activity."""
        self.last_active = datetime.now()

    async def put_nowait(self, item):
        """Refresh the activity timestamp, then enqueue ``item``."""
        self._touch()
        await self.queue.put(item)

    async def get_nowait(self):
        """Refresh the activity timestamp, then dequeue and return one item."""
        self._touch()
        item = await self.queue.get()
        return item

    def empty(self):
        """Return True when the wrapped queue holds no items."""
        return self.queue.empty()

    def qsize(self):
        """Return the number of items currently in the wrapped queue."""
        return self.queue.qsize()


async def clean_inactive_queues(queue: defaultdict, timeout_threshold: timedelta):
    """Periodically drop queues that have been idle longer than a threshold.

    Runs forever: on each pass it scans ``queue`` and, for every entry whose
    ``last_active`` is older than ``timeout_threshold``, empties the entry
    and deletes its key. It then sleeps for ``timeout_threshold`` before the
    next pass.

    Args:
        queue: Mapping of key -> timed queue. Each value must expose
            ``last_active`` (a ``datetime``), ``empty()`` and the wrapped
            ``asyncio.Queue`` as ``queue``.
        timeout_threshold: Maximum idle time, also used as the scan interval.
    """
    while True:
        current_time = datetime.now()
        # Iterate over a snapshot so entries can be deleted during the scan.
        for key, timed_queue in list(queue.items()):
            if current_time - timed_queue.last_active <= timeout_threshold:
                continue
            # Drain through the wrapped asyncio.Queue. The wrapper's own
            # get_nowait() is a coroutine function: calling it without
            # ``await`` removes nothing, so the loop below would never end
            # and would block the event loop.
            while not timed_queue.empty():
                timed_queue.queue.get_nowait()
            del queue[key]
        await asyncio.sleep(timeout_threshold.total_seconds())


# Maintain a connection pool: one TimedQueue per key, created on first access.
connection_pool = defaultdict(TimedQueue)
# NOTE(review): clean_inactive_queues is an `async def`, so this bare call only
# creates a coroutine object; it is neither awaited nor scheduled as a task
# here, so the cleanup loop does not run from this statement — confirm how the
# cleaner is meant to be started.
clean_inactive_queues(connection_pool, timedelta(minutes=5))


async def get_connection(uri, identifier):
    """Return a WebSocket connection for ``identifier``.

    If the pool entry for ``identifier`` has no connection available, a new
    connection to ``uri`` is opened and placed in the pool first. The
    connection is then taken out of the pool and returned to the caller.
    """
    if connection_pool[identifier].empty():
        # Nothing pooled for this identifier: open a new WebSocket and pool it.
        fresh = await connect(uri)
        await connection_pool[identifier].put_nowait(fresh)

    # Take a connection out of the pool and hand it to the caller.
    return await connection_pool[identifier].get_nowait()


async def release_connection(identifier, websocket):
    """Release a WebSocket connection by putting it back into the pool entry for ``identifier``."""
    pool_entry = connection_pool[identifier]
    await pool_entry.put_nowait(websocket)


def comment_answer(message_id: int, comment: str):
Expand All @@ -9,3 +86,68 @@ def comment_answer(message_id: int, comment: str):
message.remark = comment[:4096]
session.add(message)
session.commit()


class ContentStreamResp(BaseModel):
    """Message payload carried inside one response choice."""

    # Author of the content.
    role: str
    # Text of the message.
    content: str


class ChoiceStreamResp(BaseModel):
    """One response choice, serialisable as a ``{"choices":[...]}`` chunk.

    ``str(instance)`` renders the choice as a JSON object wrapped in a
    ``choices`` array and terminated by a blank line. ``index`` is emitted
    as a quoted string, matching the existing wire format.
    """

    index: int = 0
    # NOTE(review): the default of 0 is not a ContentStreamResp, so __str__
    # fails when delta is omitted; kept as-is so the constructor signature
    # does not change — confirm whether delta should be required.
    delta: ContentStreamResp = 0
    session_id: str

    def __str__(self) -> str:
        # session_id is JSON-encoded rather than pasted between quotes so
        # that quotes or backslashes in it cannot produce invalid JSON; for
        # ordinary ids the output is identical to plain interpolation.
        choice = '{"index": "%s", "delta": %s, "session_id": %s}' % (
            self.index,
            json.dumps(self.delta.dict(), ensure_ascii=False),
            json.dumps(self.session_id, ensure_ascii=False),
        )
        return '{"choices":[%s]}\n\n' % choice


async def event_stream(
    webosocket: connect,
    message: str,
    session_id: str,
    model: str,
    streaming: bool,
):
    """Send one message over a WebSocket and yield the reply as text chunks.

    The request is sent as JSON with keys ``inputs``, ``flow_id`` and
    ``chat_id``. Incoming frames are parsed as JSON and dispatched on their
    ``type`` field:

    * streaming: every frame whose type is not ``end`` and that has a
      non-empty ``message`` is yielded immediately as a ``ChoiceStreamResp``
      string; on ``close`` a final ``data: [DONE]`` is yielded.
    * non-streaming: the ``message`` of the ``end`` frame is remembered and,
      if non-empty, yielded as a single ``{"choices":[...]}`` document when
      the ``close`` frame arrives.

    On ``close`` the connection is handed back via ``release_connection``
    and the generator finishes. If sending or receiving raises, a serialised
    ``resp_500`` is yielded and the generator finishes without calling
    ``release_connection``.

    Args:
        webosocket: Open connection exposing ``send`` and ``recv``.
        message: User input to send.
        session_id: Chat id; also the key used when releasing the connection.
        model: Sent as ``flow_id``.
        streaming: Whether to yield incremental chunks or one final answer.
    """
    request = {'inputs': message, 'flow_id': model, 'chat_id': session_id}
    try:
        await webosocket.send(json.dumps(request, ensure_ascii=False))
    except Exception as e:
        yield json.dumps(resp_500(message=str(e)).__dict__)
        return

    # Final answer collected in non-streaming mode; emitted on 'close'.
    final_answer = ''
    while True:
        try:
            raw = await webosocket.recv()
        except Exception as e:
            yield json.dumps(resp_500(message=str(e)).__dict__)
            break
        if raw is None:
            continue

        # Dispatch on the frame's type.
        res = json.loads(raw)
        frame_type = res.get('type')

        if streaming:
            text = res.get('message')
            if frame_type != 'end' and text:
                yield str(
                    ChoiceStreamResp(
                        index=0,
                        session_id=session_id,
                        delta=ContentStreamResp(role='assistant', content=text),
                    ))
        elif frame_type == 'end':
            # Whether anything was stored here decides if 'close' emits output.
            final_answer = res.get('message')

        if frame_type != 'close':
            continue

        if streaming:
            yield 'data: [DONE]'
        elif final_answer:
            choice = ChoiceStreamResp(index=0,
                                      session_id=session_id,
                                      delta=ContentStreamResp(role='assistant',
                                                              content=final_answer),
                                      finish_reason='stop')
            yield '{"choices":[%s]}' % (json.dumps(choice.dict()))
        # Release the connection back to the pool.
        await release_connection(session_id, webosocket)
        break
2 changes: 1 addition & 1 deletion src/backend/bisheng/api/services/finetune.py
Original file line number Diff line number Diff line change
Expand Up @@ -406,7 +406,7 @@ def parse_log_data(cls, log_data: str) -> List[Dict[str, str]]:
sub_data = {'step': None, 'loss': None}
elem = elem.strip()
elem_data = json.loads(elem)
if elem_data['loss'] is None:
if elem_data.get('loss', None) is None:
continue
sub_data['step'] = elem_data['current_steps']
sub_data['loss'] = elem_data['loss']
Expand Down
2 changes: 2 additions & 0 deletions src/backend/bisheng/api/v1/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from bisheng.api.v1.evaluation import router as evaluation_router
from bisheng.api.v1.tag import router as tag_router
from bisheng.api.v1.llm import router as llm_router
from bisheng.api.v1.mark_task import router as mark_router

__all__ = [
'chat_router',
Expand All @@ -38,4 +39,5 @@
'audit_router',
'tag_router',
'llm_router',
'mark_router',
]
8 changes: 8 additions & 0 deletions src/backend/bisheng/api/v1/callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -509,6 +509,10 @@ async def on_tool_end(self, output: str, **kwargs: Any) -> Any:
# 从tool cache中获取input信息
input_info = self.tool_cache.get(kwargs.get('run_id').hex)
if input_info:
if not self.chat_id:
# 说明是调试界面,不用持久化数据
self.tool_cache.pop(kwargs.get('run_id').hex)
return
output_info.update(input_info['input'])
intermediate_steps = f'{input_info["steps"]}\n\n{intermediate_steps}'
ChatMessageDao.insert_one(
Expand Down Expand Up @@ -541,6 +545,10 @@ async def on_tool_error(self, error: Union[Exception, KeyboardInterrupt],
await self.websocket.send_json(resp.dict())

# 保存工具调用记录
if not self.chat_id:
# 说明是调试界面,不用持久化数据
self.tool_cache.pop(kwargs.get('run_id').hex)
return
tool_name, tool_category = self.parse_tool_category(kwargs.get('name'))
self.tool_cache.pop(kwargs.get('run_id').hex)
ChatMessageDao.insert_one(
Expand Down
Loading

0 comments on commit 3e6aa67

Please sign in to comment.