diff --git a/src/backend/bisheng/api/v1/callback.py b/src/backend/bisheng/api/v1/callback.py index 368881a25..e7953f448 100644 --- a/src/backend/bisheng/api/v1/callback.py +++ b/src/backend/bisheng/api/v1/callback.py @@ -1,6 +1,7 @@ import asyncio import copy import json +from queue import Queue from typing import Any, Dict, List, Union from bisheng.api.v1.schemas import ChatResponse @@ -19,7 +20,7 @@ class AsyncStreamingLLMCallbackHandler(AsyncCallbackHandler): """Callback handler for streaming LLM responses.""" - def __init__(self, websocket: WebSocket, flow_id: str, chat_id: str, user_id: int = None): + def __init__(self, websocket: WebSocket, flow_id: str, chat_id: str, user_id: int = None, **kwargs: Any): self.websocket = websocket self.flow_id = flow_id self.chat_id = chat_id @@ -385,6 +386,21 @@ async def on_tool_end(self, output: str, **kwargs: Any) -> Any: class AsyncGptsDebugCallbackHandler(AsyncGptsLLMCallbackHandler): + def __init__(self, websocket: WebSocket, flow_id: str, chat_id: str, user_id: int = None, **kwargs: Any): + super().__init__(websocket, flow_id, chat_id, user_id, **kwargs) + self.stream_queue: Queue = kwargs.get('stream_queue') + + async def on_llm_new_token(self, token: str, **kwargs: Any) -> None: + logger.debug(f'on_llm_new_token token={token} kwargs={kwargs}') + resp = ChatResponse(message=token, + type='stream', + flow_id=self.flow_id, + chat_id=self.chat_id) + + # 将流式输出内容放入到队列内,以方便中断流式输出后,可以将内容记录到数据库 + await self.websocket.send_json(resp.dict()) + self.stream_queue.put(token) + @staticmethod def parse_tool_category(tool_name) -> (str, str): """ diff --git a/src/backend/bisheng/chat/client.py b/src/backend/bisheng/chat/client.py index cfd8e7cb6..3aff9599c 100644 --- a/src/backend/bisheng/chat/client.py +++ b/src/backend/bisheng/chat/client.py @@ -1,12 +1,10 @@ import json -import os -import time -from typing import Dict +from typing import Dict, Callable from uuid import UUID, uuid4 +from queue import Queue from loguru import logger from langchain_core.messages import AIMessage, HumanMessage -from langchain.tools.render import format_tool_to_openai_tool from fastapi import WebSocket, status, Request from bisheng.api.services.assistant_agent import AssistantAgent @@ -20,6 +18,7 @@ from bisheng.database.models.message import ChatMessageDao from bisheng.settings import settings from bisheng.api.utils import get_request_ip +from bisheng.utils.threadpool import ThreadPoolManager, thread_pool class ChatClient: @@ -42,6 +41,10 @@ def __init__(self, request: Request, client_key: str, client_id: str, chat_id: s # 和模型对话时传入的 完整的历史对话轮数 self.latest_history_num = 5 self.gpts_conf = settings.get_from_db('gpts') + # 异步任务列表 + self.task_ids = [] + # 流式输出的队列,用来接受流式输出的内容 + self.stream_queue = Queue() async def send_message(self, message: str): await self.websocket.send_text(message) @@ -51,15 +54,33 @@ async def send_json(self, message: ChatMessage): async def handle_message(self, message: Dict[any, any]): trace_id = uuid4().hex + logger.info(f'client_id={self.client_key} trace_id={trace_id} message={message}') with logger.contextualize(trace_id=trace_id): - # 处理客户端发过来的信息 + # 处理客户端发过来的信息, 提交到线程池内执行 if self.work_type == WorkType.GPTS: - await self.handle_gpts_message(message) + thread_pool.submit(trace_id, + self.wrapper_task, + trace_id, + self.handle_gpts_message, + message, + trace_id=trace_id) + # await self.handle_gpts_message(message) - async def add_message(self, msg_type: str, message: str, category: str): + async def wrapper_task(self, task_id: str, fn: Callable, *args, **kwargs): + # 包装处理函数为异步任务 + self.task_ids.append(task_id) + try: + # 执行处理函数 + await fn(*args, **kwargs) + finally: + # 执行完成后将任务id从列表移除 + self.task_ids.remove(task_id) + + async def add_message(self, msg_type: str, message: str, category: str, remark: str = ''): self.chat_history.append({ 'category': category, - 'message': message + 'message': message, + 'remark': remark }) if not self.chat_id: # debug模式无需保存历史 @@ -75,6 +96,7 @@ async def add_message(self, msg_type: str, message: str, category: str): flow_id=self.client_id, chat_id=self.chat_id, user_id=self.user_id, + remark=remark, )) # 记录审计日志, 是新建会话 if len(self.chat_history) <= 1: @@ -142,7 +164,8 @@ async def init_chat_history(self): for one in res: self.chat_history.append({ 'message': one.message, - 'category': one.category + 'category': one.category, + 'remark': one.remark }) async def get_latest_history(self): @@ -152,13 +175,15 @@ async def get_latest_history(self): is_answer = True # 从聊天历史里获取 for i in range(len(self.chat_history) - 1, -1, -1): + one_item = self.chat_history[i] if find_i >= self.latest_history_num: break - if self.chat_history[i]['category'] == 'answer' and is_answer: - tmp.insert(0, AIMessage(content=self.chat_history[i]['message'])) + # 不包含中断的答案 + if one_item['category'] == 'answer' and one_item.get('remark') != 'break_answer' and is_answer: + tmp.insert(0, AIMessage(content=one_item['message'])) is_answer = False - elif self.chat_history[i]['category'] == 'question' and not is_answer: - tmp.insert(0, HumanMessage(content=json.loads(self.chat_history[i]['message'])['input'])) + elif one_item['category'] == 'question' and not is_answer: + tmp.insert(0, HumanMessage(content=json.loads(one_item['message'])['input'])) is_answer = True find_i += 1 @@ -171,16 +196,36 @@ async def init_gpts_callback(self): 'websocket': self.websocket, 'flow_id': self.client_id, 'chat_id': self.chat_id, - 'user_id': self.user_id + 'user_id': self.user_id, + 'stream_queue': self.stream_queue, })] self.gpts_async_callback = async_callbacks + async def stop_handle_message(self, message: Dict[any, any]): + # 中止流式输出, 因为最新的任务id是中止任务的id,不能取消自己 + logger.info(f'need stop agent, client_key: {self.client_key}, message: {message}') + + # 中止之前的处理函数 + thread_pool.cancel_task(self.task_ids[:-1]) + + # 将流式输出的内容写到数据库内 + answer = '' + while not self.stream_queue.empty(): + msg = self.stream_queue.get() + answer += msg + + # 有流式输出内容的话,记录流式输出内容到数据库 + if answer.strip(): + res = await self.add_message('bot', answer, 'answer', 'break_answer') + await self.send_response('answer', 'end', answer, message_id=res.id if res else None) + await self.send_response('processing', 'close', '') + async def handle_gpts_message(self, message: Dict[any, any]): if not message: return logger.debug(f'receive client message, client_key: {self.client_key} message: {message}') if message.get('action') == 'stop': - logger.info(f'need stop agent, client_key: {self.client_key}, message: {message}') + await self.stop_handle_message(message) return inputs = message.get('inputs', {}) diff --git a/src/backend/bisheng/database/models/message.py b/src/backend/bisheng/database/models/message.py index c4b21d15e..a072ee09d 100644 --- a/src/backend/bisheng/database/models/message.py +++ b/src/backend/bisheng/database/models/message.py @@ -26,7 +26,8 @@ class MessageBase(SQLModelSerializable): receiver: Optional[Dict] = Field(index=False, default=None, description='autogen 的发送方') intermediate_steps: Optional[str] = Field(sa_column=Column(Text), description='过程日志') files: Optional[str] = Field(sa_column=Column(String(length=4096)), description='上传的文件等') - remark: Optional[str] = Field(sa_column=Column(String(length=4096)), description='备注') + remark: Optional[str] = Field(sa_column=Column(String(length=4096)), + description='备注。break_answer: 中断的回复不作为history传给模型') create_time: Optional[datetime] = Field( sa_column=Column(DateTime, nullable=False, server_default=text('CURRENT_TIMESTAMP'))) update_time: Optional[datetime] = Field( @@ -88,11 +89,19 @@ def get_latest_message_by_chatid(cls, chat_id: str): @classmethod def get_latest_message_by_chat_ids(cls, chat_ids: list[str], category: str = None): - statement = select(ChatMessage).where(ChatMessage.chat_id.in_(chat_ids)) + """ + 获取每个会话最近的一次消息内容 + """ + statement = select(ChatMessage.chat_id, func.max(ChatMessage.id)).where(ChatMessage.chat_id.in_(chat_ids)) if category: statement = statement.where(ChatMessage.category == category) - statement = statement.order_by(ChatMessage.create_time.desc()).limit(1) + statement = statement.group_by(ChatMessage.chat_id) with session_getter() as session: + # 获取最新的id列表 + res = session.exec(statement).all() + ids = [one[1] for one in res] + # 获取消息的具体内容 + statement = select(ChatMessage).where(ChatMessage.id.in_(ids)) return session.exec(statement).all() @classmethod diff --git a/src/frontend/src/components/bs-comp/chatComponent/MessagePanne.tsx b/src/frontend/src/components/bs-comp/chatComponent/MessagePanne.tsx index c5b7667ac..45ff786fc 100644 --- a/src/frontend/src/components/bs-comp/chatComponent/MessagePanne.tsx +++ b/src/frontend/src/components/bs-comp/chatComponent/MessagePanne.tsx @@ -70,8 +70,9 @@ export default function MessagePanne({ useName, guideWord, loadMore }) { type = 'separator' } else if (msg.files?.length) { type = 'file' - } else if (['tool', 'flow', 'knowledge'].includes(msg.category)){ - // || msg.category === 'processing') { // 项目演示? + } else if (['tool', 'flow', 'knowledge'].includes(msg.category) + || (msg.category === 'processing' && msg.thought.indexOf(`status_code`) === -1) + ) { // 项目演示? type = 'runLog' } else if (msg.thought) { type = 'system' diff --git a/src/frontend/src/components/bs-comp/chatComponent/MessageSystem.tsx b/src/frontend/src/components/bs-comp/chatComponent/MessageSystem.tsx index 102a70e72..8320831a2 100644 --- a/src/frontend/src/components/bs-comp/chatComponent/MessageSystem.tsx +++ b/src/frontend/src/components/bs-comp/chatComponent/MessageSystem.tsx @@ -34,15 +34,10 @@ export default function MessageSystem({ data }) { const border = { system: 'border-slate-500', question: 'border-amber-500', processing: 'border-cyan-600', answer: 'border-lime-600', report: 'border-slate-500', guide: 'border-none' } - // 中英去掉最终的回答(report) - // if(data.category === 'report') return null - return