diff --git a/README.md b/README.md
index 4c5021ad75b..b5b3136e373 100644
--- a/README.md
+++ b/README.md
@@ -13,7 +13,7 @@ MaxKB 是一款基于 LLM 大语言模型的知识库问答系统。MaxKB = Max
- **开箱即用**:支持直接上传文档、自动爬取在线文档,支持文本自动拆分、向量化,智能问答交互体验好;
- **无缝嵌入**:支持零编码快速嵌入到第三方业务系统;
-- **多模型支持**:支持对接主流的大模型,包括本地私有大模型(如 Llama 2)、Azure OpenAI 和百度千帆大模型等。
+- **多模型支持**:支持对接主流的大模型,包括本地私有大模型(如 Llama 2、Llama 3)、通义千问、OpenAI、Azure OpenAI、Kimi 和百度千帆大模型等。
## 快速开始
@@ -53,7 +53,7 @@ docker run -d --name=maxkb -p 8080:8080 -v ~/.maxkb:/var/lib/postgresql/data 1pa
- 后端:[Python / Django](https://www.djangoproject.com/)
- LangChain:[LangChain](https://www.langchain.com/)
- 向量数据库:[PostgreSQL / pgvector](https://www.postgresql.org/)
-- 大模型:Azure OpenAI、百度千帆大模型、[Ollama](https://github.com/ollama/ollama)
+- 大模型:Azure OpenAI、OpenAI、百度千帆大模型、[Ollama](https://github.com/ollama/ollama)、通义千问、Kimi
## Star History
diff --git a/apps/setting/models_provider/constants/model_provider_constants.py b/apps/setting/models_provider/constants/model_provider_constants.py
index 3816795e59f..0a46cbfa41f 100644
--- a/apps/setting/models_provider/constants/model_provider_constants.py
+++ b/apps/setting/models_provider/constants/model_provider_constants.py
@@ -14,6 +14,8 @@
from setting.models_provider.impl.qwen_model_provider.qwen_model_provider import QwenModelProvider
from setting.models_provider.impl.wenxin_model_provider.wenxin_model_provider import WenxinModelProvider
from setting.models_provider.impl.kimi_model_provider.kimi_model_provider import KimiModelProvider
+from setting.models_provider.impl.xf_model_provider.xf_model_provider import XunFeiModelProvider
+from setting.models_provider.impl.zhipu_model_provider.zhipu_model_provider import ZhiPuModelProvider
class ModelProvideConstants(Enum):
@@ -23,3 +25,5 @@ class ModelProvideConstants(Enum):
model_openai_provider = OpenAIModelProvider()
model_kimi_provider = KimiModelProvider()
model_qwen_provider = QwenModelProvider()
+ model_zhipu_provider = ZhiPuModelProvider()
+ model_xf_provider = XunFeiModelProvider()
diff --git a/apps/setting/models_provider/impl/xf_model_provider/__init__.py b/apps/setting/models_provider/impl/xf_model_provider/__init__.py
new file mode 100644
index 00000000000..c743b4e183a
--- /dev/null
+++ b/apps/setting/models_provider/impl/xf_model_provider/__init__.py
@@ -0,0 +1,8 @@
+# coding=utf-8
+"""
+ @project: maxkb
+ @Author:虎
+ @file: __init__.py
+ @date:2024/04/19 15:55
+ @desc:
+"""
\ No newline at end of file
diff --git a/apps/setting/models_provider/impl/xf_model_provider/icon/xf_icon_svg b/apps/setting/models_provider/impl/xf_model_provider/icon/xf_icon_svg
new file mode 100644
index 00000000000..b74e351e2ee
--- /dev/null
+++ b/apps/setting/models_provider/impl/xf_model_provider/icon/xf_icon_svg
@@ -0,0 +1 @@
+
\ No newline at end of file
diff --git a/apps/setting/models_provider/impl/xf_model_provider/model/xf_chat_model.py b/apps/setting/models_provider/impl/xf_model_provider/model/xf_chat_model.py
new file mode 100644
index 00000000000..a09d48092c9
--- /dev/null
+++ b/apps/setting/models_provider/impl/xf_model_provider/model/xf_chat_model.py
@@ -0,0 +1,43 @@
+# coding=utf-8
+"""
+ @project: maxkb
+ @Author:虎
+ @file: xf_chat_model.py
+ @date:2024/04/19 15:55
+ @desc:
+"""
+
+from typing import List, Optional, Any, Iterator
+
+from langchain_community.chat_models import ChatSparkLLM
+from langchain_community.chat_models.sparkllm import _convert_message_to_dict, _convert_delta_to_message_chunk
+from langchain_core.callbacks import CallbackManagerForLLMRun
+from langchain_core.messages import BaseMessage, AIMessageChunk
+from langchain_core.outputs import ChatGenerationChunk
+
+
+class XFChatSparkLLM(ChatSparkLLM):
+ def _stream(
+ self,
+ messages: List[BaseMessage],
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[CallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+ ) -> Iterator[ChatGenerationChunk]:
+ default_chunk_class = AIMessageChunk
+
+ self.client.arun(
+ [_convert_message_to_dict(m) for m in messages],
+ self.spark_user_id,
+ self.model_kwargs,
+ True,
+ )
+ for content in self.client.subscribe(timeout=self.request_timeout):
+ if "data" not in content:
+ continue
+ delta = content["data"]
+ chunk = _convert_delta_to_message_chunk(delta, default_chunk_class)
+ cg_chunk = ChatGenerationChunk(message=chunk)
+ if run_manager:
+ run_manager.on_llm_new_token(str(chunk.content), chunk=cg_chunk)
+ yield cg_chunk
diff --git a/apps/setting/models_provider/impl/xf_model_provider/xf_model_provider.py b/apps/setting/models_provider/impl/xf_model_provider/xf_model_provider.py
new file mode 100644
index 00000000000..28059c5c69b
--- /dev/null
+++ b/apps/setting/models_provider/impl/xf_model_provider/xf_model_provider.py
@@ -0,0 +1,103 @@
+# coding=utf-8
+"""
+ @project: maxkb
+ @Author:虎
+ @file: xf_model_provider.py
+ @date:2024/04/19 14:47
+ @desc:
+"""
+import os
+from typing import Dict
+
+from langchain.schema import HumanMessage
+from langchain_community.chat_models import ChatSparkLLM
+
+from common import forms
+from common.exception.app_exception import AppApiException
+from common.forms import BaseForm
+from common.util.file_util import get_file_content
+from setting.models_provider.base_model_provider import ModelProvideInfo, ModelTypeConst, BaseModelCredential, \
+ ModelInfo, IModelProvider, ValidCode
+from setting.models_provider.impl.xf_model_provider.model.xf_chat_model import XFChatSparkLLM
+from smartdoc.conf import PROJECT_DIR
+import ssl
+
+ssl._create_default_https_context = ssl.create_default_context  # assign the factory function, not a context instance: http.client calls this as a callable
+
+
+class XunFeiLLMModelCredential(BaseForm, BaseModelCredential):
+
+ def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], raise_exception=False):
+ model_type_list = XunFeiModelProvider().get_model_type_list()
+ if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
+ raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')
+
+ for key in ['spark_api_url', 'spark_app_id', 'spark_api_key', 'spark_api_secret']:
+ if key not in model_credential:
+ if raise_exception:
+ raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段')
+ else:
+ return False
+ try:
+ model = XunFeiModelProvider().get_model(model_type, model_name, model_credential)
+ model.invoke([HumanMessage(content='你好')])
+ except Exception as e:
+ if isinstance(e, AppApiException):
+ raise e
+ if raise_exception:
+ raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}')
+ else:
+ return False
+ return True
+
+ def encryption_dict(self, model: Dict[str, object]):
+ return {**model, 'spark_api_secret': super().encryption(model.get('spark_api_secret', ''))}
+
+ spark_api_url = forms.TextInputField('API 域名', required=True)
+ spark_app_id = forms.TextInputField('APP ID', required=True)
+ spark_api_key = forms.PasswordInputField("API Key", required=True)
+ spark_api_secret = forms.PasswordInputField('API Secret', required=True)
+
+
+xunfei_model_credential = XunFeiLLMModelCredential()
+
+model_dict = {
+    'generalv3.5': ModelInfo('generalv3.5', '', ModelTypeConst.LLM, xunfei_model_credential),
+    'generalv3': ModelInfo('generalv3', '', ModelTypeConst.LLM, xunfei_model_credential),
+    'generalv2': ModelInfo('generalv2', '', ModelTypeConst.LLM, xunfei_model_credential)
+}
+
+
+class XunFeiModelProvider(IModelProvider):
+
+    def get_dialogue_number(self):
+        return 3
+
+    def get_model(self, model_type, model_name, model_credential: Dict[str, object], **model_kwargs) -> XFChatSparkLLM:
+        spark_chat = XFChatSparkLLM(
+            spark_app_id=model_credential.get('spark_app_id'),
+            spark_api_key=model_credential.get('spark_api_key'),
+            spark_api_secret=model_credential.get('spark_api_secret'),
+            spark_api_url=model_credential.get('spark_api_url'),
+            spark_llm_domain=model_name
+        )
+        return spark_chat
+
+    def get_model_credential(self, model_type, model_name):
+        if model_name in model_dict:
+            return model_dict.get(model_name).model_credential
+        return xunfei_model_credential
+
+ def get_model_provide_info(self):
+ return ModelProvideInfo(provider='model_xf_provider', name='讯飞星火', icon=get_file_content(
+ os.path.join(PROJECT_DIR, "apps", "setting", 'models_provider', 'impl', 'xf_model_provider', 'icon',
+ 'xf_icon_svg')))
+
+ def get_model_list(self, model_type: str):
+ if model_type is None:
+ raise AppApiException(500, '模型类型不能为空')
+ return [model_dict.get(key).to_dict() for key in
+ list(filter(lambda key: model_dict.get(key).model_type == model_type, model_dict.keys()))]
+
+ def get_model_type_list(self):
+ return [{'key': "大语言模型", 'value': "LLM"}]
diff --git a/apps/setting/models_provider/impl/zhipu_model_provider/__init__.py b/apps/setting/models_provider/impl/zhipu_model_provider/__init__.py
new file mode 100644
index 00000000000..e69de29bb2d
diff --git a/apps/setting/models_provider/impl/zhipu_model_provider/icon/zhipuai_icon_svg b/apps/setting/models_provider/impl/zhipu_model_provider/icon/zhipuai_icon_svg
new file mode 100644
index 00000000000..f39fedcbbc3
--- /dev/null
+++ b/apps/setting/models_provider/impl/zhipu_model_provider/icon/zhipuai_icon_svg
@@ -0,0 +1 @@
+
\ No newline at end of file
diff --git a/apps/setting/models_provider/impl/zhipu_model_provider/zhipu_model_provider.py b/apps/setting/models_provider/impl/zhipu_model_provider/zhipu_model_provider.py
new file mode 100644
index 00000000000..b84bb3d15cc
--- /dev/null
+++ b/apps/setting/models_provider/impl/zhipu_model_provider/zhipu_model_provider.py
@@ -0,0 +1,93 @@
+# coding=utf-8
+"""
+ @project: maxkb
+ @Author:虎
+ @file: zhipu_model_provider.py
+ @date:2024/04/19 13:05
+ @desc:
+"""
+import os
+from typing import Dict
+
+from langchain.schema import HumanMessage
+from langchain_community.chat_models import ChatZhipuAI
+
+from common import forms
+from common.exception.app_exception import AppApiException
+from common.forms import BaseForm
+from common.util.file_util import get_file_content
+from setting.models_provider.base_model_provider import ModelProvideInfo, ModelTypeConst, BaseModelCredential, \
+ ModelInfo, IModelProvider, ValidCode
+from smartdoc.conf import PROJECT_DIR
+
+
+class ZhiPuLLMModelCredential(BaseForm, BaseModelCredential):
+
+ def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], raise_exception=False):
+ model_type_list = ZhiPuModelProvider().get_model_type_list()
+ if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
+ raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')
+ for key in ['api_key']:
+ if key not in model_credential:
+ if raise_exception:
+ raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段')
+ else:
+ return False
+ try:
+ model = ZhiPuModelProvider().get_model(model_type, model_name, model_credential)
+ model.invoke([HumanMessage(content='你好')])
+ except Exception as e:
+ if isinstance(e, AppApiException):
+ raise e
+ if raise_exception:
+ raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}')
+ else:
+ return False
+ return True
+
+ def encryption_dict(self, model: Dict[str, object]):
+ return {**model, 'api_key': super().encryption(model.get('api_key', ''))}
+
+ api_key = forms.PasswordInputField('API Key', required=True)
+
+
+zhipu_model_credential = ZhiPuLLMModelCredential()
+
+model_dict = {
+    'glm-4': ModelInfo('glm-4', '', ModelTypeConst.LLM, zhipu_model_credential),
+    'glm-4v': ModelInfo('glm-4v', '', ModelTypeConst.LLM, zhipu_model_credential),
+    'glm-3-turbo': ModelInfo('glm-3-turbo', '', ModelTypeConst.LLM, zhipu_model_credential)
+}
+
+
+class ZhiPuModelProvider(IModelProvider):
+
+    def get_dialogue_number(self):
+        return 3
+
+    def get_model(self, model_type, model_name, model_credential: Dict[str, object], **model_kwargs) -> ChatZhipuAI:
+        zhipuai_chat = ChatZhipuAI(
+            temperature=0.5,
+            api_key=model_credential.get('api_key'),
+            model=model_name
+        )
+        return zhipuai_chat
+
+    def get_model_credential(self, model_type, model_name):
+        if model_name in model_dict:
+            return model_dict.get(model_name).model_credential
+        return zhipu_model_credential
+
+ def get_model_provide_info(self):
+ return ModelProvideInfo(provider='model_zhipu_provider', name='智谱AI', icon=get_file_content(
+ os.path.join(PROJECT_DIR, "apps", "setting", 'models_provider', 'impl', 'zhipu_model_provider', 'icon',
+ 'zhipuai_icon_svg')))
+
+ def get_model_list(self, model_type: str):
+ if model_type is None:
+ raise AppApiException(500, '模型类型不能为空')
+ return [model_dict.get(key).to_dict() for key in
+ list(filter(lambda key: model_dict.get(key).model_type == model_type, model_dict.keys()))]
+
+ def get_model_type_list(self):
+ return [{'key': "大语言模型", 'value': "LLM"}]
diff --git a/apps/users/serializers/user_serializers.py b/apps/users/serializers/user_serializers.py
index 6672a4cb545..352bf6f96eb 100644
--- a/apps/users/serializers/user_serializers.py
+++ b/apps/users/serializers/user_serializers.py
@@ -418,7 +418,8 @@ def get_user_profile(user: User):
permission_list = get_user_dynamics_permission(str(user.id))
permission_list += [p.value for p in get_permission_list_by_role(RoleConstants[user.role])]
return {'id': user.id, 'username': user.username, 'email': user.email, 'role': user.role,
- 'permissions': [str(p) for p in permission_list]}
+ 'permissions': [str(p) for p in permission_list],
+ 'is_edit_password': user.password == 'd880e722c47a34d8e9fce789fc62389d' if user.role == 'ADMIN' else False}
@staticmethod
def get_response_body_api():
diff --git a/pyproject.toml b/pyproject.toml
index 35d4a42f074..f70b250bbce 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -33,6 +33,10 @@ pymupdf = "1.24.1"
python-docx = "^1.1.0"
xlwt = "^1.3.0"
dashscope = "^1.17.0"
+zhipuai = "^2.0.1"
+httpx = "^0.27.0"
+httpx-sse = "^0.4.0"
+websocket-client = "^1.7.0"
[build-system]
requires = ["poetry-core"]
diff --git a/ui/src/api/type/user.ts b/ui/src/api/type/user.ts
index 04cbd41409a..6724252c959 100644
--- a/ui/src/api/type/user.ts
+++ b/ui/src/api/type/user.ts
@@ -19,6 +19,10 @@ interface User {
* 用户权限
*/
permissions: Array
+ /**
+ * 是否需要修改密码
+ */
+ is_edit_password?: boolean
}
interface LoginRequest {
diff --git a/ui/src/layout/components/top-bar/avatar/index.vue b/ui/src/layout/components/top-bar/avatar/index.vue
index 19bdcdbb50b..d6a889cf304 100644
--- a/ui/src/layout/components/top-bar/avatar/index.vue
+++ b/ui/src/layout/components/top-bar/avatar/index.vue
@@ -30,16 +30,19 @@
+