Skip to content

Commit

Permalink
update logic of passing params in method
Browse files Browse the repository at this point in the history
  • Loading branch information
MateuszOssGit committed Oct 17, 2024
1 parent ee26a2c commit a84ddb5
Show file tree
Hide file tree
Showing 3 changed files with 104 additions and 33 deletions.
30 changes: 11 additions & 19 deletions libs/ibm/langchain_ibm/chat_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -668,32 +668,24 @@ def _stream(
def _create_message_dicts(
self, messages: List[BaseMessage], stop: Optional[List[str]], **kwargs: Any
) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
params = (
{
**(
self.params.to_dict()
if isinstance(self.params, BaseSchema)
else self.params
)
}
if self.params
else {}
)
params = params | {
**(
kwargs.get("params", {}).to_dict()
if isinstance(kwargs.get("params", {}), BaseSchema)
else kwargs.get("params", {})
)
}
if kwargs.get("params") is not None:
params = kwargs.get("params")
elif self.params is not None:
params = self.params
else:
params = None

if isinstance(params, BaseSchema):
params = params.to_dict()

if stop is not None:
if params and "stop_sequences" in params:
raise ValueError(
"`stop_sequences` found in both the input and default params."
)
params = (params or {}) | {"stop_sequences": stop}
message_dicts = [_convert_message_to_dict(m, self.model_id) for m in messages]
return message_dicts, params
return message_dicts, params or {}

def _create_chat_result(
self, response: dict, generation_info: Optional[Dict] = None
Expand Down
27 changes: 19 additions & 8 deletions libs/ibm/langchain_ibm/embeddings.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
from __future__ import annotations

import logging
from typing import List, Optional, Union
from typing import Any, Dict, List, Optional, Union

from ibm_watsonx_ai import APIClient, Credentials # type: ignore
from ibm_watsonx_ai.foundation_models.embeddings import Embeddings # type: ignore
Expand Down Expand Up @@ -63,7 +61,7 @@ class WatsonxEmbeddings(BaseModel, LangChainEmbeddings):
version: Optional[SecretStr] = None
"""Version of the CPD instance."""

params: Optional[dict] = None
params: Optional[Dict] = None
"""Model parameters to use during request generation."""

verify: Union[str, bool, None] = None
Expand Down Expand Up @@ -151,10 +149,23 @@ def validate_environment(self) -> Self:

return self

def embed_documents(self, texts: List[str]) -> List[List[float]]:
def embed_documents(self, texts: List[str], **kwargs: Any) -> List[List[float]]:
"""Embed search docs."""
return self.watsonx_embed.embed_documents(texts=texts)
params = self._get_embeddings_params(**kwargs)
return self.watsonx_embed.embed_documents(
texts=texts, **(kwargs | {"params": params})
)

def embed_query(self, text: str) -> List[float]:
def embed_query(self, text: str, **kwargs: Any) -> List[float]:
"""Embed query text."""
return self.embed_documents([text])[0]
return self.embed_documents([text], **kwargs)[0]

def _get_embeddings_params(self, **kwargs: Any) -> Dict[str, Any]:
if kwargs.get("params") is not None:
params = kwargs.get("params")
elif self.params is not None:
params = self.params
else:
params = None

return params or {}
80 changes: 74 additions & 6 deletions libs/ibm/tests/integration_tests/test_embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,55 @@ def test_01_generate_embed_documents() -> None:
assert all(isinstance(el, float) for el in generate_embedding[0])


def test_02_generate_embed_query() -> None:
def test_02_generate_embed_documents_with_param() -> None:
embed_params = {
EmbedTextParamsMetaNames.TRUNCATE_INPUT_TOKENS: 3,
}
watsonx_embedding = WatsonxEmbeddings(
model_id=MODEL_ID,
url=URL, # type: ignore[arg-type]
project_id=WX_PROJECT_ID,
params=embed_params,
)
generate_embedding = watsonx_embedding.embed_documents(texts=DOCUMENTS)
assert len(generate_embedding) == len(DOCUMENTS)
assert all(isinstance(el, float) for el in generate_embedding[0])


def test_03_generate_embed_documents_with_param_in_method() -> None:
embed_params = {
EmbedTextParamsMetaNames.TRUNCATE_INPUT_TOKENS: 3,
}
watsonx_embedding = WatsonxEmbeddings(
model_id=MODEL_ID,
url=URL, # type: ignore[arg-type]
project_id=WX_PROJECT_ID,
)
generate_embedding = watsonx_embedding.embed_documents(
texts=DOCUMENTS, params=embed_params
)
assert len(generate_embedding) == len(DOCUMENTS)
assert all(isinstance(el, float) for el in generate_embedding[0])


def test_04_generate_embed_documents_with_param_and_concurrency_limit() -> None:
embed_params = {
EmbedTextParamsMetaNames.TRUNCATE_INPUT_TOKENS: 3,
}
watsonx_embedding = WatsonxEmbeddings(
model_id=MODEL_ID,
url=URL, # type: ignore[arg-type]
project_id=WX_PROJECT_ID,
params=embed_params,
)
generate_embedding = watsonx_embedding.embed_documents(
texts=DOCUMENTS, concurrency_limit=9
)
assert len(generate_embedding) == len(DOCUMENTS)
assert all(isinstance(el, float) for el in generate_embedding[0])


def test_10_generate_embed_query() -> None:
watsonx_embedding = WatsonxEmbeddings(
model_id=MODEL_ID,
url=URL, # type: ignore[arg-type]
Expand All @@ -42,7 +90,24 @@ def test_02_generate_embed_query() -> None:
)


def test_03_generate_embed_documents_with_param() -> None:
def test_11_generate_embed_query_with_params() -> None:
embed_params = {
EmbedTextParamsMetaNames.TRUNCATE_INPUT_TOKENS: 3,
}
watsonx_embedding = WatsonxEmbeddings(
model_id=MODEL_ID,
url=URL, # type: ignore[arg-type]
project_id=WX_PROJECT_ID,
)
generate_embedding = watsonx_embedding.embed_query(
text=DOCUMENTS[0], params=embed_params
)
assert isinstance(generate_embedding, list) and isinstance(
generate_embedding[0], float
)


def test_12_generate_embed_query_with_params_and_concurrency_limit() -> None:
embed_params = {
EmbedTextParamsMetaNames.TRUNCATE_INPUT_TOKENS: 3,
}
Expand All @@ -52,12 +117,15 @@ def test_03_generate_embed_documents_with_param() -> None:
project_id=WX_PROJECT_ID,
params=embed_params,
)
generate_embedding = watsonx_embedding.embed_documents(texts=DOCUMENTS)
assert len(generate_embedding) == len(DOCUMENTS)
assert all(isinstance(el, float) for el in generate_embedding[0])
generate_embedding = watsonx_embedding.embed_query(
text=DOCUMENTS[0], concurrency_limit=9
)
assert isinstance(generate_embedding, list) and isinstance(
generate_embedding[0], float
)


def test_10_generate_embed_query_with_client_initialization() -> None:
def test_20_generate_embed_query_with_client_initialization() -> None:
watsonx_client = APIClient(
credentials={
"url": URL,
Expand Down

0 comments on commit a84ddb5

Please sign in to comment.