Skip to content

Commit

Permalink
Remove lazy_import and enable IDE code completion and jump-to-definition (#63)
Browse files Browse the repository at this point in the history
Signed-off-by: ChengZi <[email protected]>
  • Loading branch information
zc277584121 authored Jan 16, 2025
1 parent 0da992b commit fa4021e
Show file tree
Hide file tree
Showing 25 changed files with 352 additions and 406 deletions.
60 changes: 10 additions & 50 deletions src/pymilvus/model/dense/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,13 @@
from pymilvus.model.dense.openai import OpenAIEmbeddingFunction
from pymilvus.model.dense.sentence_transformer import SentenceTransformerEmbeddingFunction
from pymilvus.model.dense.voyageai import VoyageEmbeddingFunction
from pymilvus.model.dense.jinaai import JinaEmbeddingFunction
from pymilvus.model.dense.onnx import OnnxEmbeddingFunction
from pymilvus.model.dense.cohere import CohereEmbeddingFunction
from pymilvus.model.dense.mistralai import MistralAIEmbeddingFunction
from pymilvus.model.dense.nomic import NomicEmbeddingFunction
from pymilvus.model.dense.instructor import InstructorEmbeddingFunction

__all__ = [
"OpenAIEmbeddingFunction",
"SentenceTransformerEmbeddingFunction",
Expand All @@ -9,53 +19,3 @@
"NomicEmbeddingFunction",
"InstructorEmbeddingFunction"
]

from pymilvus.model.utils.lazy_import import LazyImport

# Module-level LazyImport handles, one per dense-embedding backend submodule.
# Each handle is constructed with a short local name, this module's globals(),
# and the dotted path of the submodule it stands for. The wrapper functions
# defined further down in this module reach the real implementations through
# these handles (e.g. ``jinaai.JinaEmbeddingFunction``).
# NOTE(review): presumably the target submodule is only imported on first
# attribute access -- confirm against pymilvus.model.utils.lazy_import.
jinaai = LazyImport("jinaai", globals(), "pymilvus.model.dense.jinaai")
openai = LazyImport("openai", globals(), "pymilvus.model.dense.openai")
sentence_transformer = LazyImport(
"sentence_transformer", globals(), "pymilvus.model.dense.sentence_transformer"
)
voyageai = LazyImport("voyageai", globals(), "pymilvus.model.dense.voyageai")
onnx = LazyImport("onnx", globals(), "pymilvus.model.dense.onnx")
cohere = LazyImport("cohere", globals(), "pymilvus.model.dense.cohere")
mistralai = LazyImport("mistralai", globals(), "pymilvus.model.dense.mistralai")
nomic = LazyImport("nomic", globals(), "pymilvus.model.dense.nomic")
instructor = LazyImport("instructor", globals(), "pymilvus.model.dense.instructor")


def JinaEmbeddingFunction(*args, **kwargs):
    """Forward construction to ``jinaai.JinaEmbeddingFunction``.

    ``jinaai`` is the module-level handle created for
    ``pymilvus.model.dense.jinaai``. All positional and keyword arguments
    are passed through unchanged, and the callee's result is returned as-is.
    """
    factory = jinaai.JinaEmbeddingFunction
    return factory(*args, **kwargs)


def OpenAIEmbeddingFunction(*args, **kwargs):
    """Forward construction to ``openai.OpenAIEmbeddingFunction``.

    ``openai`` here is the module-level handle created for
    ``pymilvus.model.dense.openai``. All positional and keyword arguments
    are passed through unchanged, and the callee's result is returned as-is.
    """
    factory = openai.OpenAIEmbeddingFunction
    return factory(*args, **kwargs)


def SentenceTransformerEmbeddingFunction(*args, **kwargs):
    """Forward construction to
    ``sentence_transformer.SentenceTransformerEmbeddingFunction``.

    ``sentence_transformer`` is the module-level handle created for
    ``pymilvus.model.dense.sentence_transformer``. All positional and keyword
    arguments are passed through unchanged, and the callee's result is
    returned as-is.
    """
    factory = sentence_transformer.SentenceTransformerEmbeddingFunction
    return factory(*args, **kwargs)


def VoyageEmbeddingFunction(*args, **kwargs):
    """Forward construction to ``voyageai.VoyageEmbeddingFunction``.

    ``voyageai`` here is the module-level handle created for
    ``pymilvus.model.dense.voyageai``. All positional and keyword arguments
    are passed through unchanged, and the callee's result is returned as-is.
    """
    factory = voyageai.VoyageEmbeddingFunction
    return factory(*args, **kwargs)


def OnnxEmbeddingFunction(*args, **kwargs):
    """Forward construction to ``onnx.OnnxEmbeddingFunction``.

    ``onnx`` here is the module-level handle created for
    ``pymilvus.model.dense.onnx``. All positional and keyword arguments are
    passed through unchanged, and the callee's result is returned as-is.
    """
    factory = onnx.OnnxEmbeddingFunction
    return factory(*args, **kwargs)


def CohereEmbeddingFunction(*args, **kwargs):
    """Forward construction to ``cohere.CohereEmbeddingFunction``.

    ``cohere`` here is the module-level handle created for
    ``pymilvus.model.dense.cohere``. All positional and keyword arguments
    are passed through unchanged, and the callee's result is returned as-is.
    """
    factory = cohere.CohereEmbeddingFunction
    return factory(*args, **kwargs)


def MistralAIEmbeddingFunction(*args, **kwargs):
    """Forward construction to ``mistralai.MistralAIEmbeddingFunction``.

    ``mistralai`` here is the module-level handle created for
    ``pymilvus.model.dense.mistralai``. All positional and keyword arguments
    are passed through unchanged, and the callee's result is returned as-is.
    """
    factory = mistralai.MistralAIEmbeddingFunction
    return factory(*args, **kwargs)


def NomicEmbeddingFunction(*args, **kwargs):
    """Forward construction to ``nomic.NomicEmbeddingFunction``.

    ``nomic`` here is the module-level handle created for
    ``pymilvus.model.dense.nomic``. All positional and keyword arguments are
    passed through unchanged, and the callee's result is returned as-is.
    """
    factory = nomic.NomicEmbeddingFunction
    return factory(*args, **kwargs)


def InstructorEmbeddingFunction(*args, **kwargs):
    """Forward construction to ``instructor.InstructorEmbeddingFunction``.

    ``instructor`` is the module-level handle created for
    ``pymilvus.model.dense.instructor``. All positional and keyword arguments
    are passed through unchanged, and the callee's result is returned as-is.
    """
    factory = instructor.InstructorEmbeddingFunction
    return factory(*args, **kwargs)
6 changes: 4 additions & 2 deletions src/pymilvus/model/dense/cohere.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,7 @@
from pymilvus.model.base import BaseEmbeddingFunction
from pymilvus.model.utils import import_cohere

import_cohere()
import cohere


class CohereEmbeddingFunction(BaseEmbeddingFunction):
def __init__(self,
Expand All @@ -22,6 +21,9 @@ def __init__(self,
self.embedding_types = embedding_types
self.truncate = truncate

import_cohere()
import cohere

if isinstance(embedding_types, list):
if len(embedding_types) > 1:
raise ValueError("Only one embedding type can be specified using current PyMilvus model library.")
Expand Down
7 changes: 2 additions & 5 deletions src/pymilvus/model/dense/instructor.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,7 @@
import numpy as np

from pymilvus.model.base import BaseEmbeddingFunction
from pymilvus.model.utils import import_sentence_transformers, import_huggingface_hub

import_sentence_transformers()
import_huggingface_hub()

from .instructor_embedding.instructor_impl import Instructor

class InstructorEmbeddingFunction(BaseEmbeddingFunction):
def __init__(
Expand All @@ -20,6 +15,8 @@ def __init__(
normalize_embeddings: bool = True,
**kwargs,
):
from .instructor_embedding.instructor_impl import Instructor

self.model_name = model_name
self.query_instruction = query_instruction
self.doc_instruction = doc_instruction
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,12 @@
from collections import OrderedDict
from typing import Union

from pymilvus.model.utils import import_sentence_transformers, import_huggingface_hub, import_torch

import_sentence_transformers()
import_huggingface_hub()
import_torch()

import numpy as np
import torch
from sentence_transformers import SentenceTransformer
Expand Down
6 changes: 4 additions & 2 deletions src/pymilvus/model/dense/mistralai.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,7 @@
from pymilvus.model.base import BaseEmbeddingFunction
from pymilvus.model.utils import import_mistralai

import_mistralai()
from mistralai import Mistral


class MistralAIEmbeddingFunction(BaseEmbeddingFunction):
def __init__(
Expand All @@ -16,6 +15,9 @@ def __init__(
model_name: str = "mistral-embed",
**kwargs,
):
import_mistralai()
from mistralai import Mistral

self._mistral_model_meta_info = defaultdict(dict)
self._mistral_model_meta_info[model_name]["dim"] = 1024 # fixed dimension

Expand Down
5 changes: 3 additions & 2 deletions src/pymilvus/model/dense/nomic.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@
from pymilvus.model.base import BaseEmbeddingFunction
from pymilvus.model.utils import import_nomic

import_nomic()
from nomic import embed

class NomicEmbeddingFunction(BaseEmbeddingFunction):
def __init__(
Expand Down Expand Up @@ -54,6 +52,9 @@ def _encode_document(self, document: str) -> np.array:
return self._encode([document], task_type="search_document")[0]

def _call_nomic_api(self, texts: List[str], task_type: str):
import_nomic()
from nomic import embed

embeddings_batch_response = embed.text(
texts=texts,
**self._encode_config
Expand Down
6 changes: 5 additions & 1 deletion src/pymilvus/model/dense/onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,18 @@
import onnxruntime

from transformers import AutoTokenizer, AutoConfig
from huggingface_hub import hf_hub_download
import numpy as np
from typing import List

from pymilvus.model.base import BaseEmbeddingFunction
from pymilvus.model.utils import import_huggingface_hub


class OnnxEmbeddingFunction(BaseEmbeddingFunction):
def __init__(self, model_name: str = "GPTCache/paraphrase-albert-onnx", tokenizer_name: str = "GPTCache/paraphrase-albert-small-v2"):
import_huggingface_hub()
from huggingface_hub import hf_hub_download

self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
self.model_name = model_name
onnx_model_path = hf_hub_download(repo_id=model_name, filename="model.onnx")
Expand Down
5 changes: 3 additions & 2 deletions src/pymilvus/model/dense/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@
from pymilvus.model.base import BaseEmbeddingFunction
from pymilvus.model.utils import import_openai

import_openai()
from openai import OpenAI

class OpenAIEmbeddingFunction(BaseEmbeddingFunction):
def __init__(
Expand All @@ -18,6 +16,9 @@ def __init__(
dimensions: Optional[int] = None,
**kwargs,
):
import_openai()
from openai import OpenAI

self._openai_model_meta_info = defaultdict(dict)
self._openai_model_meta_info["text-embedding-3-small"]["dim"] = 1536
self._openai_model_meta_info["text-embedding-3-large"]["dim"] = 3072
Expand Down
4 changes: 2 additions & 2 deletions src/pymilvus/model/dense/sentence_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@
from pymilvus.model.base import BaseEmbeddingFunction
from pymilvus.model.utils import import_sentence_transformers

import_sentence_transformers()
from sentence_transformers import SentenceTransformer

class SentenceTransformerEmbeddingFunction(BaseEmbeddingFunction):
def __init__(
Expand All @@ -19,6 +17,8 @@ def __init__(
normalize_embeddings: bool = True,
**kwargs,
):
import_sentence_transformers()
from sentence_transformers import SentenceTransformer
self.model_name = model_name
self.query_instruction = query_instruction
self.doc_instruction = doc_instruction
Expand Down
6 changes: 3 additions & 3 deletions src/pymilvus/model/dense/voyageai.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,6 @@
from pymilvus.model.base import BaseEmbeddingFunction
from pymilvus.model.utils import import_voyageai

import_voyageai()
import voyageai


class VoyageEmbeddingFunction(BaseEmbeddingFunction):
def __init__(self,
Expand All @@ -19,6 +16,9 @@ def __init__(self,
truncate: Optional[bool] = None,
dimension: Optional[int] = None,
**kwargs):
import_voyageai()
import voyageai

self.model_name = model_name
self.truncate = truncate
self._voyageai_model_meta_info = defaultdict(dict)
Expand Down
14 changes: 3 additions & 11 deletions src/pymilvus/model/hybrid/__init__.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,4 @@
__all__ = ["BGEM3EmbeddingFunction", "MGTEEmbeddingFunction"]

from pymilvus.model.utils.lazy_import import LazyImport

bge_m3 = LazyImport("bge_m3", globals(), "pymilvus.model.hybrid.bge_m3")
mgte = LazyImport("mgte", globals(), "pymilvus.model.hybrid.mgte")
from pymilvus.model.hybrid.bge_m3 import BGEM3EmbeddingFunction
from pymilvus.model.hybrid.mgte import MGTEEmbeddingFunction

def BGEM3EmbeddingFunction(*args, **kwargs):
    """Forward construction to ``bge_m3.BGEM3EmbeddingFunction``.

    ``bge_m3`` is the module-level handle created for
    ``pymilvus.model.hybrid.bge_m3``. All positional and keyword arguments
    are passed through unchanged, and the callee's result is returned as-is.
    """
    factory = bge_m3.BGEM3EmbeddingFunction
    return factory(*args, **kwargs)

def MGTEEmbeddingFunction(*args, **kwargs):
    """Forward construction to ``mgte.MGTEEmbeddingFunction``.

    ``mgte`` is the module-level handle created for
    ``pymilvus.model.hybrid.mgte``. All positional and keyword arguments are
    passed through unchanged, and the callee's result is returned as-is.
    """
    factory = mgte.MGTEEmbeddingFunction
    return factory(*args, **kwargs)
__all__ = ["BGEM3EmbeddingFunction", "MGTEEmbeddingFunction"]
25 changes: 14 additions & 11 deletions src/pymilvus/model/hybrid/bge_m3.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,6 @@
from pymilvus.model.utils import import_FlagEmbedding, import_datasets
from pymilvus.model.sparse.utils import stack_sparse_embeddings

# Module-level dependency setup for the BGE-M3 embedding function.
# NOTE(review): import_datasets / import_FlagEmbedding come from
# pymilvus.model.utils; presumably they verify that the optional packages are
# available before the import below -- confirm against that module.
import_datasets()
import_FlagEmbedding()

# Import BGEM3FlagModel, intercepting AttributeError so that a hint can be
# printed for a known Google Colab failure mode.
try:
from FlagEmbedding import BGEM3FlagModel
except AttributeError as e:
import sys
# The hint is printed only when google.colab is loaded and the error text
# mentions "ListView"; \033[91m ... \033[0m wraps each message in ANSI red.
if "google.colab" in sys.modules and "ListView" in str(e):
print("\033[91mIt looks like you're running on Google Colab. Please restart the session to resolve this issue.\033[0m")
print("\033[91mFor further details, visit: https://github.com/milvus-io/milvus-model/issues/32.\033[0m")
# NOTE(review): indentation was lost in this view, so it is unclear whether
# this re-raise is unconditional or applies only to the Colab branch --
# confirm against the repository.
raise

logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
Expand All @@ -37,6 +26,20 @@ def __init__(
return_colbert_vecs: bool = False,
**kwargs,
):
import_datasets()
import_FlagEmbedding()

try:
from FlagEmbedding import BGEM3FlagModel
except AttributeError as e:
import sys
if "google.colab" in sys.modules and "ListView" in str(e):
print(
"\033[91mIt looks like you're running on Google Colab. Please restart the session to resolve this issue.\033[0m")
print(
"\033[91mFor further details, visit: https://github.com/milvus-io/milvus-model/issues/32.\033[0m")
raise

self.model_name = model_name
self.batch_size = batch_size
self.normalize_embeddings = normalize_embeddings
Expand Down
Loading

0 comments on commit fa4021e

Please sign in to comment.