Skip to content

Commit

Permalink
feat: support siliconCloud rerank
Browse files Browse the repository at this point in the history
  • Loading branch information
wxg0103 committed Feb 6, 2025
1 parent 8957b77 commit 1234096
Show file tree
Hide file tree
Showing 17 changed files with 100 additions and 29 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from setting.models_provider.impl.aliyun_bai_lian_model_provider.model.tti import QwenTextToImageModel
from setting.models_provider.impl.aliyun_bai_lian_model_provider.model.tts import AliyunBaiLianTextToSpeech
from smartdoc.conf import PROJECT_DIR
from django.utils.translation import gettext_lazy as _, gettext
from django.utils.translation import gettext as _, gettext

aliyun_bai_lian_model_credential = AliyunBaiLianRerankerCredential()
aliyun_bai_lian_tts_model_credential = AliyunBaiLianTTSModelCredential()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from setting.models_provider.impl.aws_bedrock_model_provider.model.embedding import BedrockEmbeddingModel
from setting.models_provider.impl.aws_bedrock_model_provider.model.llm import BedrockModel
from smartdoc.conf import PROJECT_DIR
from django.utils.translation import gettext_lazy as _
from django.utils.translation import gettext as _


def _create_model_info(model_name, description, model_type, credential_class, model_class):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from setting.models_provider.impl.deepseek_model_provider.credential.llm import DeepSeekLLMModelCredential
from setting.models_provider.impl.deepseek_model_provider.model.llm import DeepSeekChatModel
from smartdoc.conf import PROJECT_DIR
from django.utils.translation import gettext_lazy as _
from django.utils.translation import gettext as _
deepseek_llm_model_credential = DeepSeekLLMModelCredential()

deepseek_chat = ModelInfo('deepseek-chat', _('Good at common conversational tasks, supports 32K contexts'), ModelTypeConst.LLM,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from setting.models_provider.impl.gemini_model_provider.model.llm import GeminiChatModel
from setting.models_provider.impl.gemini_model_provider.model.stt import GeminiSpeechToText
from smartdoc.conf import PROJECT_DIR
from django.utils.translation import gettext_lazy as _
from django.utils.translation import gettext as _


gemini_llm_model_credential = GeminiLLMModelCredential()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from setting.models_provider.impl.local_model_provider.model.embedding import LocalEmbedding
from setting.models_provider.impl.local_model_provider.model.reranker import LocalReranker
from smartdoc.conf import PROJECT_DIR
from django.utils.translation import gettext_lazy as _
from django.utils.translation import gettext as _

embedding_text2vec_base_chinese = ModelInfo('shibing624/text2vec-base-chinese', '', ModelTypeConst.EMBEDDING,
LocalEmbeddingCredential(), LocalEmbedding)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from setting.models_provider.impl.ollama_model_provider.model.image import OllamaImage
from setting.models_provider.impl.ollama_model_provider.model.llm import OllamaChatModel
from smartdoc.conf import PROJECT_DIR
from django.utils.translation import gettext_lazy as _
from django.utils.translation import gettext as _

""

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from setting.models_provider.impl.qwen_model_provider.model.llm import QwenChatModel
from setting.models_provider.impl.qwen_model_provider.model.tti import QwenTextToImageModel
from smartdoc.conf import PROJECT_DIR
from django.utils.translation import gettext_lazy as _
from django.utils.translation import gettext as _

qwen_model_credential = OpenAILLMModelCredential()
qwenvl_model_credential = QwenVLModelCredential()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
from common.exception.app_exception import AppApiException
from common.forms import BaseForm
from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode
from setting.models_provider.impl.aliyun_bai_lian_model_provider.model.reranker import AliyunBaiLianReranker
from setting.models_provider.impl.siliconCloud_model_provider.model.reranker import SiliconCloudReranker


Expand All @@ -26,7 +25,7 @@ def is_valid(self, model_type: str, model_name, model_credential: Dict[str, obje
if not model_type == 'RERANKER':
raise AppApiException(ValidCode.valid_error.value,
_('{model_type} Model type is not supported').format(model_type=model_type))
for key in ['dashscope_api_key']:
for key in ['api_base', 'api_key']:
if key not in model_credential:
if raise_exception:
raise AppApiException(ValidCode.valid_error.value, _('{key} is required').format(key=key))
Expand All @@ -47,6 +46,6 @@ def is_valid(self, model_type: str, model_name, model_credential: Dict[str, obje
return True

def encryption_dict(self, model: Dict[str, object]):
return {**model, 'dashscope_api_key': super().encryption(model.get('dashscope_api_key', ''))}

dashscope_api_key = forms.PasswordInputField('API Key', required=True)
return {**model, 'api_key': super().encryption(model.get('api_key', ''))}
api_base = forms.TextInputField('API URL', required=True)
api_key = forms.PasswordInputField('API Key', required=True)
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,92 @@
"""
@project: MaxKB
@Author:虎
@file: reranker.py.py
@date:2024/9/2 16:42
@desc:
@file: siliconcloud_reranker.py
@date:2024/9/10 9:45
@desc: SiliconCloud 文档重排封装
"""
from typing import Dict

from langchain_community.document_compressors import DashScopeRerank
from typing import Sequence, Optional, Any, Dict
import requests

from langchain_core.callbacks import Callbacks
from langchain_core.documents import BaseDocumentCompressor, Document

from setting.models_provider.base_model_provider import MaxKBBaseModel
from django.utils.translation import gettext as _


class SiliconCloudReranker(MaxKBBaseModel, BaseDocumentCompressor):
api_base: Optional[str]
"""SiliconCloud API URL"""
model: Optional[str]
"""SiliconCloud 重排模型 ID"""
api_key: Optional[str]
"""API Key"""

top_n: Optional[int] = 3 # 取前 N 个最相关的结果

class SiliconCloudReranker(MaxKBBaseModel, DashScopeRerank):
@staticmethod
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
return SiliconCloudReranker(model=model_name, dashscope_api_key=model_credential.get('dashscope_api_key'),
top_n=model_kwargs.get('top_n', 3))
return SiliconCloudReranker(
api_base=model_credential.get('api_base'),
model=model_name,
api_key=model_credential.get('api_key'),
top_n=model_kwargs.get('top_n', 3)
)

def __init__(
self, api_base: Optional[str] = None, model: Optional[str] = None, top_n=3,
api_key: Optional[str] = None
):
super().__init__()

if not api_base:
raise ValueError(_('Please provide server URL'))

if not model:
raise ValueError(_('Please provide the model'))

if not api_key:
raise ValueError(_('Please provide the API Key'))

self.api_base = api_base
self.model = model
self.api_key = api_key
self.top_n = top_n

def compress_documents(self, documents: Sequence[Document], query: str, callbacks: Optional[Callbacks] = None) -> \
Sequence[Document]:
if not documents:
return []

# 预处理文本
texts = [doc.page_content for doc in documents]

# 发送请求到 SiliconCloud API
headers = {
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json"
}
payload = {
"model": self.model,
"query": query,
"documents": texts,
"top_n": self.top_n
}

response = requests.post(f"{self.api_base}/rerank", json=payload, headers=headers)

if response.status_code != 200:
raise RuntimeError(f"SiliconCloud API 请求失败: {response.text}")

res = response.json()

# 解析返回结果
return [
Document(
page_content=item.get('document', {}).get('text', ''),
metadata={'relevance_score': item.get('relevance_score')}
)
for item in res.get('results', [])
]
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from setting.models_provider.impl.siliconCloud_model_provider.model.stt import SiliconCloudSpeechToText
from setting.models_provider.impl.siliconCloud_model_provider.model.tti import SiliconCloudTextToImage
from smartdoc.conf import PROJECT_DIR
from django.utils.translation import gettext_lazy as _
from django.utils.translation import gettext as _

openai_llm_model_credential = SiliconCloudLLMModelCredential()
openai_stt_model_credential = SiliconCloudSTTModelCredential()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from setting.models_provider.impl.tencent_model_provider.model.llm import TencentModel
from setting.models_provider.impl.tencent_model_provider.model.tti import TencentTextToImageModel
from smartdoc.conf import PROJECT_DIR
from django.utils.translation import gettext_lazy as _
from django.utils.translation import gettext as _

def _create_model_info(model_name, description, model_type, credential_class, model_class):
return ModelInfo(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from setting.models_provider.impl.vllm_model_provider.model.image import VllmImage
from setting.models_provider.impl.vllm_model_provider.model.llm import VllmChatModel
from smartdoc.conf import PROJECT_DIR
from django.utils.translation import gettext_lazy as _
from django.utils.translation import gettext as _

v_llm_model_credential = VLLMModelCredential()
image_model_credential = VllmImageModelCredential()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from setting.models_provider.impl.volcanic_engine_model_provider.model.tts import VolcanicEngineTextToSpeech

from smartdoc.conf import PROJECT_DIR
from django.utils.translation import gettext_lazy as _
from django.utils.translation import gettext as _

volcanic_engine_llm_model_credential = OpenAILLMModelCredential()
volcanic_engine_stt_model_credential = VolcanicEngineSTTModelCredential()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from setting.models_provider.impl.wenxin_model_provider.model.embedding import QianfanEmbeddings
from setting.models_provider.impl.wenxin_model_provider.model.llm import QianfanChatModel
from smartdoc.conf import PROJECT_DIR
from django.utils.translation import gettext_lazy as _
from django.utils.translation import gettext as _

win_xin_llm_model_credential = WenxinLLMModelCredential()
qianfan_embedding_credential = QianfanEmbeddingCredential()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from setting.models_provider.impl.xf_model_provider.model.stt import XFSparkSpeechToText
from setting.models_provider.impl.xf_model_provider.model.tts import XFSparkTextToSpeech
from smartdoc.conf import PROJECT_DIR
from django.utils.translation import gettext_lazy as _
from django.utils.translation import gettext as _

ssl._create_default_https_context = ssl.create_default_context()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,7 @@
from setting.models_provider.impl.xinference_model_provider.model.tti import XinferenceTextToImage
from setting.models_provider.impl.xinference_model_provider.model.tts import XInferenceTextToSpeech
from smartdoc.conf import PROJECT_DIR
from django.utils.translation import gettext_lazy as _

from django.utils.translation import gettext as _

xinference_llm_model_credential = XinferenceLLMModelCredential()
xinference_stt_model_credential = XInferenceSTTModelCredential()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from setting.models_provider.impl.zhipu_model_provider.model.llm import ZhipuChatModel
from setting.models_provider.impl.zhipu_model_provider.model.tti import ZhiPuTextToImage
from smartdoc.conf import PROJECT_DIR
from django.utils.translation import gettext_lazy as _
from django.utils.translation import gettext as _

qwen_model_credential = ZhiPuLLMModelCredential()
zhipu_image_model_credential = ZhiPuImageModelCredential()
Expand Down

0 comments on commit 1234096

Please sign in to comment.