-
Notifications
You must be signed in to change notification settings - Fork 1.7k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
4aa1b58
commit 8fe1a14
Showing
6 changed files
with
158 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
8 changes: 8 additions & 0 deletions
8
apps/setting/models_provider/impl/xf_model_provider/__init__.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
# coding=utf-8 | ||
""" | ||
@project: maxkb | ||
@Author:虎 | ||
@file: __init__.py.py | ||
@date:2024/04/19 15:55 | ||
@desc: | ||
""" |
1 change: 1 addition & 0 deletions
1
apps/setting/models_provider/impl/xf_model_provider/icon/xf_icon_svg
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
<svg t="1713509569091" class="icon" viewBox="0 0 1024 1024" version="1.1" xmlns="http://www.w3.org/2000/svg" p-id="4361" xmlns:xlink="http://www.w3.org/1999/xlink" width="100%" height="100%" ><path d="M500.1216 971.40736c-55.58272-4.28032-102.11328-16.7936-147.74272-39.7312-115.0976-57.83552-192.88064-168.5504-208.81408-297.12384-2.8672-23.18336-2.90816-69.18144-0.06144-93.3888a387.95264 387.95264 0 0 1 82.65728-196.46464c6.22592-7.7824 47.616-49.7664 94.49472-95.92832 112.8448-111.104 111.53408-109.85472 113.29536-108.09344 1.024 1.024 0.75776 5.59104-0.8192 14.45888-3.13344 17.57184-3.13344 55.99232 0.04096 77.0048 9.66656 64.49152 37.66272 124.3136 87.16288 186.32704 12.6976 15.91296 59.22816 63.24224 76.57472 77.88544 20.13184 16.9984 33.1776 37.53984 39.34208 61.8496 4.07552 16.0768 4.07552 42.10688 0 58.20416-10.24 40.57088-40.8576 72.58112-81.6128 85.38112-9.35936 2.92864-13.84448 3.39968-32.68608 3.39968s-23.3472-0.47104-32.768-3.39968c-29.02016-9.07264-56.40192-30.06464-32.52224-24.94464 12.94336 2.7648 29.65504-3.2768 37.49888-13.57824 10.81344-14.1312 12.57472-29.53216 5.09952-44.48256-3.76832-7.53664-6.8608-10.91584-19.12832-20.82816-33.1776-26.86976-65.7408-59.5968-88.8832-89.25184-11.81696-15.17568-28.8768-40.59136-33.95584-50.5856-1.92512-3.7888-4.15744-6.90176-4.95616-6.90176-2.00704 0-17.92 24.43264-24.73984 37.96992-7.84384 15.52384-15.33952 37.888-19.12832 57.0368-4.64896 23.3472-4.64896 59.14624-0.04096 82.16576 17.77664 88.63744 82.78016 154.74688 171.02848 173.8752 12.3904 2.70336 19.39456 3.23584 42.496 3.23584 23.08096 0 30.1056-0.53248 42.47552-3.23584 43.04896-9.3184 78.4384-28.672 109.6704-59.904 32.72704-32.72704 52.26496-69.44768 61.29664-115.3024 4.54656-23.06048 4.56704-56.36096 0.04096-79.29856-3.87072-19.5584-8.76544-35.16416-15.85152-50.3808-6.69696-14.35648-6.0416-15.21664 9.1136-12.0832 40.89856 8.45824 85.6064 31.41632 114.40128 58.75712 34.6112 32.84992 49.27488 65.45408 49.27488 109.71136 0 24.00256-3.4816 41.6768-13.35296 68.17792-20.54144 54.9888-50.54464 100.61824-93.7984 142.66368-51.26144 49.80736-116.8384 85.03296-183.95136 98.79552-30.45376 6.2464-76.53376 9.89184-101.1712 7.9872z m391.53664-433.93024c-17.05984-32.93184-41.75872-56.48384-76.8-73.19552-18.80064-8.97024-35.67616-14.52032-68.75136-22.58944-44.46208-10.8544-66.2528-18.2272-93.16352-31.62112-26.2144-13.04576-46.16192-27.3408-66.3552-47.5136-26.70592-26.74688-42.63936-52.4288-54.35392-87.67488-10.36288-31.27296-10.0352-27.2384-10.62912-128.96256-0.45056-78.37696-0.2048-92.0576 1.51552-92.0576 1.1264 0 45.8752 42.98752 99.40992 95.51872 190.95552 187.37152 194.58048 191.11936 216.7808 224.78848 20.13184 30.53568 39.26016 72.0896 48.76288 105.92256 4.7104 16.73216 11.0592 48.18944 11.81696 58.40896 0.7168 9.8304-2.84672 9.35936-8.23296-1.024z" fill="#3DC8F9" p-id="4362"></path><path d="M523.12064 53.8624c-1.7408 0-1.96608 13.68064-1.51552 92.0576 0.57344 101.74464 0.24576 97.6896 10.6496 128.96256 11.6736 35.2256 27.62752 60.928 54.33344 87.6544 20.19328 20.19328 40.1408 34.48832 66.3552 47.5136 26.91072 13.4144 48.70144 20.80768 93.14304 31.6416 33.09568 8.0896 49.9712 13.6192 68.75136 22.58944 35.04128 16.71168 59.74016 40.2432 76.8 73.19552 5.40672 10.38336 8.97024 10.8544 8.25344 1.024-0.75776-10.21952-7.12704-41.6768-11.81696-58.40896-9.50272-33.83296-28.63104-75.3664-48.76288-105.92256-22.20032-33.66912-25.82528-37.41696-216.7808-224.78848-53.5552-52.5312-98.304-95.51872-99.40992-95.51872z" fill="#EA0100" p-id="4363"></path><path d="M391.3728 762.30656s86.2208 88.41216 241.9712 100.06528c155.7504 11.63264 193.536-45.44512 193.536-45.44512s76.3904-100.864 65.08544-177.54112c-11.32544-76.67712-71.02464-131.21536-174.96064-154.89024 0 0 31.90784 80.2816 20.5824 128.14336-11.32544 47.86176-20.0704 138.93632-159.5392 186.7776 0 0-102.68672 30.208-186.6752-37.10976z" fill="#1652D8" p-id="4364"></path></svg> |
43 changes: 43 additions & 0 deletions
43
apps/setting/models_provider/impl/xf_model_provider/model/xf_chat_model.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
# coding=utf-8 | ||
""" | ||
@project: maxkb | ||
@Author:虎 | ||
@file: __init__.py.py | ||
@date:2024/04/19 15:55 | ||
@desc: | ||
""" | ||
|
||
from typing import List, Optional, Any, Iterator | ||
|
||
from langchain_community.chat_models import ChatSparkLLM | ||
from langchain_community.chat_models.sparkllm import _convert_message_to_dict, _convert_delta_to_message_chunk | ||
from langchain_core.callbacks import CallbackManagerForLLMRun | ||
from langchain_core.messages import BaseMessage, AIMessageChunk | ||
from langchain_core.outputs import ChatGenerationChunk | ||
|
||
|
||
class XFChatSparkLLM(ChatSparkLLM): | ||
def _stream( | ||
self, | ||
messages: List[BaseMessage], | ||
stop: Optional[List[str]] = None, | ||
run_manager: Optional[CallbackManagerForLLMRun] = None, | ||
**kwargs: Any, | ||
) -> Iterator[ChatGenerationChunk]: | ||
default_chunk_class = AIMessageChunk | ||
|
||
self.client.arun( | ||
[_convert_message_to_dict(m) for m in messages], | ||
self.spark_user_id, | ||
self.model_kwargs, | ||
True, | ||
) | ||
for content in self.client.subscribe(timeout=self.request_timeout): | ||
if "data" not in content: | ||
continue | ||
delta = content["data"] | ||
chunk = _convert_delta_to_message_chunk(delta, default_chunk_class) | ||
cg_chunk = ChatGenerationChunk(message=chunk) | ||
if run_manager: | ||
run_manager.on_llm_new_token(str(chunk.content), chunk=cg_chunk) | ||
yield cg_chunk |
103 changes: 103 additions & 0 deletions
103
apps/setting/models_provider/impl/xf_model_provider/xf_model_provider.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,103 @@ | ||
# coding=utf-8 | ||
""" | ||
@project: maxkb | ||
@Author:虎 | ||
@file: xf_model_provider.py | ||
@date:2024/04/19 14:47 | ||
@desc: | ||
""" | ||
import os | ||
from typing import Dict | ||
|
||
from langchain.schema import HumanMessage | ||
from langchain_community.chat_models import ChatSparkLLM | ||
|
||
from common import forms | ||
from common.exception.app_exception import AppApiException | ||
from common.forms import BaseForm | ||
from common.util.file_util import get_file_content | ||
from setting.models_provider.base_model_provider import ModelProvideInfo, ModelTypeConst, BaseModelCredential, \ | ||
ModelInfo, IModelProvider, ValidCode | ||
from setting.models_provider.impl.xf_model_provider.model.xf_chat_model import XFChatSparkLLM | ||
from smartdoc.conf import PROJECT_DIR | ||
import ssl | ||
|
||
ssl._create_default_https_context = ssl.create_default_context() | ||
|
||
|
||
class XunFeiLLMModelCredential(BaseForm, BaseModelCredential): | ||
|
||
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], raise_exception=False): | ||
model_type_list = XunFeiModelProvider().get_model_type_list() | ||
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))): | ||
raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持') | ||
|
||
for key in ['spark_api_url', 'spark_app_id', 'spark_api_key', 'spark_api_secret']: | ||
if key not in model_credential: | ||
if raise_exception: | ||
raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段') | ||
else: | ||
return False | ||
try: | ||
model = XunFeiModelProvider().get_model(model_type, model_name, model_credential) | ||
model.invoke([HumanMessage(content='你好')]) | ||
except Exception as e: | ||
if isinstance(e, AppApiException): | ||
raise e | ||
if raise_exception: | ||
raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}') | ||
else: | ||
return False | ||
return True | ||
|
||
def encryption_dict(self, model: Dict[str, object]): | ||
return {**model, 'spark_api_secret': super().encryption(model.get('spark_api_secret', ''))} | ||
|
||
spark_api_url = forms.TextInputField('API 域名', required=True) | ||
spark_app_id = forms.TextInputField('APP ID', required=True) | ||
spark_api_key = forms.PasswordInputField("API Key", required=True) | ||
spark_api_secret = forms.PasswordInputField('API Secret', required=True) | ||
|
||
|
||
qwen_model_credential = XunFeiLLMModelCredential() | ||
|
||
model_dict = { | ||
'generalv3.5': ModelInfo('generalv3.5', '', ModelTypeConst.LLM, qwen_model_credential), | ||
'generalv3': ModelInfo('generalv3', '', ModelTypeConst.LLM, qwen_model_credential), | ||
'generalv2': ModelInfo('generalv2', '', ModelTypeConst.LLM, qwen_model_credential) | ||
} | ||
|
||
|
||
class XunFeiModelProvider(IModelProvider): | ||
|
||
def get_dialogue_number(self): | ||
return 3 | ||
|
||
def get_model(self, model_type, model_name, model_credential: Dict[str, object], **model_kwargs) -> XFChatSparkLLM: | ||
zhipuai_chat = XFChatSparkLLM( | ||
spark_app_id=model_credential.get('spark_app_id'), | ||
spark_api_key=model_credential.get('spark_api_key'), | ||
spark_api_secret=model_credential.get('spark_api_secret'), | ||
spark_api_url=model_credential.get('spark_api_url'), | ||
spark_llm_domain=model_name | ||
) | ||
return zhipuai_chat | ||
|
||
def get_model_credential(self, model_type, model_name): | ||
if model_name in model_dict: | ||
return model_dict.get(model_name).model_credential | ||
return qwen_model_credential | ||
|
||
def get_model_provide_info(self): | ||
return ModelProvideInfo(provider='model_xf_provider', name='讯飞星火', icon=get_file_content( | ||
os.path.join(PROJECT_DIR, "apps", "setting", 'models_provider', 'impl', 'xf_model_provider', 'icon', | ||
'xf_icon_svg'))) | ||
|
||
def get_model_list(self, model_type: str): | ||
if model_type is None: | ||
raise AppApiException(500, '模型类型不能为空') | ||
return [model_dict.get(key).to_dict() for key in | ||
list(filter(lambda key: model_dict.get(key).model_type == model_type, model_dict.keys()))] | ||
|
||
def get_model_type_list(self): | ||
return [{'key': "大语言模型", 'value': "LLM"}] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters