Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add configuration options for changing embedding model name and NIM endpoint url at runtime #440

Open
wants to merge 13 commits into
base: main
Choose a base branch
from
40 changes: 36 additions & 4 deletions client/src/nv_ingest_client/primitives/tasks/embed.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

import logging
from typing import Dict
from typing import Optional

from pydantic import BaseModel, root_validator

Expand All @@ -17,6 +18,9 @@


class EmbedTaskSchema(BaseModel):
model_name: Optional[str] = None
endpoint_url: Optional[str] = None
api_key: Optional[str] = None
filter_errors: bool = False

@root_validator(pre=True)
Expand All @@ -42,7 +46,15 @@ class EmbedTask(Task):
Object for document embedding task
"""

def __init__(self, text: bool = None, tables: bool = None, filter_errors: bool = False) -> None:
def __init__(
self,
model_name: str = None,
endpoint_url: str = None,
api_key: str = None,
text: bool = None,
tables: bool = None,
filter_errors: bool = False,
) -> None:
"""
Setup Embed Task Config
"""
Expand All @@ -58,6 +70,9 @@ def __init__(self, text: bool = None, tables: bool = None, filter_errors: bool =
"'tables' parameter is deprecated and will be ignored. Future versions will remove this argument."
)

self._model_name = model_name
self._endpoint_url = endpoint_url
self._api_key = api_key
self._filter_errors = filter_errors

def __str__(self) -> str:
    """
    Return a human-readable summary of the task configuration.

    Only the options that were explicitly configured are listed; the API key
    value is never echoed back.
    """
    info = ""
    info += "Embed Task:\n"

    if self._model_name:
        info += f"  model_name: {self._model_name}\n"
    if self._endpoint_url:
        info += f"  endpoint_url: {self._endpoint_url}\n"
    if self._api_key:
        # Redact the credential so it never lands in logs or consoles.
        info += "  api_key: [redacted]\n"
    info += f"  filter_errors: {self._filter_errors}\n"

    return info

def to_dict(self) -> Dict:
    """
    Convert to a dict for submission to redis.

    Returns
    -------
    Dict
        ``{"type": "embed", "task_properties": {...}}`` where the properties
        contain only the overrides that were actually supplied, plus
        ``filter_errors``.
    """
    # NOTE: the original text carried a dead, immediately-shadowed
    # ``task_properties = {"filter_errors": False}`` assignment (a leftover
    # removed-diff line); it is dropped here.
    task_properties = {}

    # Omit unset overrides so the server-side defaults apply.
    if self._model_name:
        task_properties["model_name"] = self._model_name

    if self._endpoint_url:
        task_properties["endpoint_url"] = self._endpoint_url

    if self._api_key:
        task_properties["api_key"] = self._api_key

    task_properties["filter_errors"] = self._filter_errors

    return {"type": "embed", "task_properties": task_properties}
42 changes: 23 additions & 19 deletions src/nv_ingest/modules/transforms/embed_extractions.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,8 @@
def _make_async_request(
prompts: List[str],
api_key: str,
embedding_nim_endpoint: str,
embedding_model: str,
endpoint_url: str,
model_name: str,
encoding_format: str,
input_type: str,
truncate: str,
Expand All @@ -51,12 +51,12 @@ def _make_async_request(
try:
client = OpenAI(
api_key=api_key,
base_url=embedding_nim_endpoint,
base_url=endpoint_url,
)

resp = client.embeddings.create(
input=prompts,
model=embedding_model,
model=model_name,
encoding_format=encoding_format,
extra_body={"input_type": input_type, "truncate": truncate},
)
Expand Down Expand Up @@ -85,8 +85,8 @@ def _make_async_request(
def _async_request_handler(
prompts: List[str],
api_key: str,
embedding_nim_endpoint: str,
embedding_model: str,
endpoint_url: str,
model_name: str,
encoding_format: str,
input_type: str,
truncate: str,
Expand All @@ -103,8 +103,8 @@ def _async_request_handler(
_make_async_request,
prompts=prompt_batch,
api_key=api_key,
embedding_nim_endpoint=embedding_nim_endpoint,
embedding_model=embedding_model,
endpoint_url=endpoint_url,
model_name=model_name,
encoding_format=encoding_format,
input_type=input_type,
truncate=truncate,
Expand All @@ -120,8 +120,8 @@ def _async_request_handler(
def _async_runner(
prompts: List[str],
api_key: str,
embedding_nim_endpoint: str,
embedding_model: str,
endpoint_url: str,
model_name: str,
encoding_format: str,
input_type: str,
truncate: str,
Expand All @@ -133,8 +133,8 @@ def _async_runner(
results = _async_request_handler(
prompts,
api_key,
embedding_nim_endpoint,
embedding_model,
endpoint_url,
model_name,
encoding_format,
input_type,
truncate,
Expand Down Expand Up @@ -236,8 +236,8 @@ def _generate_embeddings(
ctrl_msg: ControlMessage,
batch_size: int,
api_key: str,
embedding_nim_endpoint: str,
embedding_model: str,
endpoint_url: str,
model_name: str,
encoding_format: str,
input_type: str,
truncate: str,
Expand Down Expand Up @@ -292,8 +292,8 @@ def _generate_embeddings(
content_embeddings = _async_runner(
filtered_content_batches,
api_key,
embedding_nim_endpoint,
embedding_model,
endpoint_url,
model_name,
encoding_format,
input_type,
truncate,
Expand Down Expand Up @@ -355,14 +355,18 @@ def embed_extractions_fn(message: ControlMessage):
try:
task_props = message.remove_task("embed")
model_dump = task_props.model_dump()

model_name = model_dump.get("model_name") or validated_config.model_name
endpoint_url = model_dump.get("endpoint_url") or validated_config.endpoint_url
api_key = model_dump.get("api_key") or validated_config.api_key
filter_errors = model_dump.get("filter_errors", False)

return _generate_embeddings(
message,
validated_config.batch_size, # This parameter is now ignored in _generate_embeddings.
validated_config.api_key,
validated_config.embedding_nim_endpoint,
validated_config.embedding_model,
api_key,
endpoint_url,
model_name,
validated_config.encoding_format,
validated_config.input_type,
validated_config.truncate,
Expand Down
4 changes: 2 additions & 2 deletions src/nv_ingest/schemas/embed_extractions_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@
class EmbedExtractionsSchema(BaseModel):
api_key: str = "api_key"
batch_size: int = 8192
embedding_model: str = "nvidia/nv-embedqa-e5-v5"
embedding_nim_endpoint: str = "http://embedding:8000/v1"
model_name: str = "nvidia/nv-embedqa-e5-v5"
endpoint_url: str = "http://embedding:8000/v1"
encoding_format: str = "float"
httpx_log_level: LogLevel = LogLevel.WARNING
input_type: str = "passage"
Expand Down
3 changes: 3 additions & 0 deletions src/nv_ingest/schemas/ingest_job_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,9 @@ class IngestTaskDedupSchema(BaseModelNoExt):


class IngestTaskEmbedSchema(BaseModelNoExt):
    """Schema for an ingest "embed" task's properties.

    All overrides are optional; when a field is omitted (None) the pipeline's
    configured defaults are used instead.
    """

    # Embedding model identifier override (e.g. "nvidia/nv-embedqa-e5-v5") — TODO confirm accepted values.
    model_name: Optional[str] = None
    # Base URL of the embedding NIM endpoint to call at runtime.
    endpoint_url: Optional[str] = None
    # Credential forwarded to the embedding service.
    api_key: Optional[str] = None
    # When True, rows that failed embedding are filtered out of the result.
    filter_errors: bool = False


Expand Down
6 changes: 3 additions & 3 deletions src/nv_ingest/util/pipeline/stage_builders.py
Original file line number Diff line number Diff line change
Expand Up @@ -422,12 +422,12 @@ def add_embed_extractions_stage(pipe, morpheus_pipeline_config, ingest_config):
"",
)
embedding_nim_endpoint = os.getenv("EMBEDDING_NIM_ENDPOINT", "http://embedding:8000/v1")
embedding_model = os.getenv("EMBEDDING_NIM_MODEL_NAME", "nvidia/nv-embedqa-e5-v5")
embedding_model = os.getenv("EMBEDDING_NIM_MODEL_NAME", "nvidia/llama-3.2-nv-embedqa-1b-v2")

text_embed_extraction_config = {
"api_key": api_key,
"embedding_nim_endpoint": embedding_nim_endpoint,
"embedding_model": embedding_model,
"endpoint_url": embedding_nim_endpoint,
"model_name": embedding_model,
}

embed_extractions_stage = pipe.add_stage(
Expand Down
90 changes: 90 additions & 0 deletions tests/nv_ingest_client/primitives/tasks/test_embed.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES.
# All rights reserved.
# SPDX-License-Identifier: Apache-2.0

import pytest
from nv_ingest_client.primitives.tasks.embed import EmbedTask


# Initialization and Property Setting


def test_embed_task_initialization():
    """Constructor arguments must land on the expected private attributes."""
    config = {
        "model_name": "nvidia/llama-3.2-nv-embedqa-1b-v2",
        "endpoint_url": "http://embedding:8000/v1",
        "api_key": "API_KEY",
    }
    task = EmbedTask(**config)

    assert task._model_name == config["model_name"]
    assert task._endpoint_url == config["endpoint_url"]
    assert task._api_key == config["api_key"]


# String Representation Tests


def test_embed_task_str_representation():
    """str(task) lists every configured option and redacts the API key."""
    task = EmbedTask(
        model_name="nvidia/nv-embedqa-e5-v5",
        endpoint_url="http://localhost:8024/v1",
        api_key="API_KEY",
        filter_errors=True,
    )

    expected_lines = [
        "Embed Task:",
        "  model_name: nvidia/nv-embedqa-e5-v5",
        "  endpoint_url: http://localhost:8024/v1",
        "  api_key: [redacted]",
        "  filter_errors: True",
    ]
    assert str(task) == "\n".join(expected_lines) + "\n"


# Dictionary Representation Tests


@pytest.mark.parametrize(
    "model_name, endpoint_url, api_key, filter_errors",
    [
        ("meta-llama/Llama-3.2-1B", "http://embedding:8012/v1", "TEST", False),
        ("nvidia/nv-embedqa-mistral-7b-v2", "http://localhost:8000/v1", "12345", True),
        ("nvidia/nv-embedqa-e5-v5", "http://embedding:8000/v1", "key", True),
        (None, None, None, False),  # Test default parameters
    ],
)
def test_embed_task_to_dict(model_name, endpoint_url, api_key, filter_errors):
    """to_dict emits only the options that were supplied, plus filter_errors."""
    task = EmbedTask(
        model_name=model_name,
        endpoint_url=endpoint_url,
        api_key=api_key,
        filter_errors=filter_errors,
    )

    # Build the expected properties, skipping any override left at None.
    properties = {
        key: value
        for key, value in (
            ("model_name", model_name),
            ("endpoint_url", endpoint_url),
            ("api_key", api_key),
        )
        if value is not None
    }
    properties["filter_errors"] = filter_errors

    assert task.to_dict() == {
        "type": "embed",
        "task_properties": properties,
    }, "The to_dict method did not return the expected dictionary representation"


# Default Parameter Handling


def test_embed_task_default_params():
    """A bare EmbedTask renders and serializes with only filter_errors set."""
    task = EmbedTask()

    rendered = str(task)
    assert "Embed Task:" in rendered
    assert "filter_errors: False" in rendered

    assert task.to_dict() == {
        "type": "embed",
        "task_properties": {"filter_errors": False},
    }
Loading