Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Added WatsonxRerank integration, update logic of passing params #33

Merged
merged 17 commits into from
Oct 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion libs/ibm/langchain_ibm/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from langchain_ibm.chat_models import ChatWatsonx
from langchain_ibm.embeddings import WatsonxEmbeddings
from langchain_ibm.llms import WatsonxLLM
from langchain_ibm.rerank import WatsonxRerank

__all__ = ["WatsonxLLM", "WatsonxEmbeddings", "ChatWatsonx"]
__all__ = ["WatsonxLLM", "WatsonxEmbeddings", "ChatWatsonx", "WatsonxRerank"]
25 changes: 4 additions & 21 deletions libs/ibm/langchain_ibm/chat_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
from ibm_watsonx_ai import APIClient, Credentials # type: ignore
from ibm_watsonx_ai.foundation_models import ModelInference # type: ignore
from ibm_watsonx_ai.foundation_models.schema import ( # type: ignore
BaseSchema,
TextChatParameters,
)
from langchain_core.callbacks import CallbackManagerForLLMRun
Expand Down Expand Up @@ -74,7 +73,7 @@
from pydantic import BaseModel, ConfigDict, Field, SecretStr, model_validator
from typing_extensions import Self

from langchain_ibm.utils import check_for_attribute
from langchain_ibm.utils import check_for_attribute, extract_params

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -668,32 +667,16 @@ def _stream(
def _create_message_dicts(
self, messages: List[BaseMessage], stop: Optional[List[str]], **kwargs: Any
) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
params = (
{
**(
self.params.to_dict()
if isinstance(self.params, BaseSchema)
else self.params
)
}
if self.params
else {}
)
params = params | {
**(
kwargs.get("params", {}).to_dict()
if isinstance(kwargs.get("params", {}), BaseSchema)
else kwargs.get("params", {})
)
}
params = extract_params(kwargs, self.params)

if stop is not None:
if params and "stop_sequences" in params:
raise ValueError(
"`stop_sequences` found in both the input and default params."
)
params = (params or {}) | {"stop_sequences": stop}
message_dicts = [_convert_message_to_dict(m, self.model_id) for m in messages]
return message_dicts, params
return message_dicts, params or {}

def _create_chat_result(
self, response: dict, generation_info: Optional[Dict] = None
Expand Down
19 changes: 10 additions & 9 deletions libs/ibm/langchain_ibm/embeddings.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
from __future__ import annotations

import logging
from typing import List, Optional, Union
from typing import Any, Dict, List, Optional, Union

from ibm_watsonx_ai import APIClient, Credentials # type: ignore
from ibm_watsonx_ai.foundation_models.embeddings import Embeddings # type: ignore
Expand All @@ -10,7 +8,7 @@
from pydantic import BaseModel, ConfigDict, Field, SecretStr, model_validator
from typing_extensions import Self

from langchain_ibm.utils import check_for_attribute
from langchain_ibm.utils import check_for_attribute, extract_params

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -63,7 +61,7 @@ class WatsonxEmbeddings(BaseModel, LangChainEmbeddings):
version: Optional[SecretStr] = None
"""Version of the CPD instance."""

params: Optional[dict] = None
params: Optional[Dict] = None
"""Model parameters to use during request generation."""

verify: Union[str, bool, None] = None
Expand Down Expand Up @@ -151,10 +149,13 @@ def validate_environment(self) -> Self:

return self

def embed_documents(self, texts: List[str]) -> List[List[float]]:
def embed_documents(self, texts: List[str], **kwargs: Any) -> List[List[float]]:
"""Embed search docs."""
return self.watsonx_embed.embed_documents(texts=texts)
params = extract_params(kwargs, self.params)
return self.watsonx_embed.embed_documents(
texts=texts, **(kwargs | {"params": params})
)

def embed_query(self, text: str) -> List[float]:
def embed_query(self, text: str, **kwargs: Any) -> List[float]:
"""Embed query text."""
return self.embed_documents([text])[0]
return self.embed_documents([text], **kwargs)[0]
11 changes: 4 additions & 7 deletions libs/ibm/langchain_ibm/llms.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from pydantic import ConfigDict, Field, SecretStr, model_validator
from typing_extensions import Self

from langchain_ibm.utils import check_for_attribute
from langchain_ibm.utils import check_for_attribute, extract_params

logger = logging.getLogger(__name__)
textgen_valid_params = [
Expand Down Expand Up @@ -305,12 +305,9 @@ def _override_chat_params(
def _get_chat_params(
self, stop: Optional[List[str]] = None, **kwargs: Any
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
params = (
{**self.params, **kwargs.pop("params", {})}
if self.params
else kwargs.pop("params", {})
)
params, kwargs = self._override_chat_params(params, **kwargs)
params = extract_params(kwargs, self.params)

params, kwargs = self._override_chat_params(params or {}, **kwargs)
if stop is not None:
if params and "stop_sequences" in params:
raise ValueError(
Expand Down
232 changes: 232 additions & 0 deletions libs/ibm/langchain_ibm/rerank.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,232 @@
from __future__ import annotations

from copy import deepcopy
from typing import Any, Dict, List, Optional, Sequence, Union

from ibm_watsonx_ai import APIClient, Credentials # type: ignore
from ibm_watsonx_ai.foundation_models import Rerank # type: ignore
from ibm_watsonx_ai.foundation_models.schema import ( # type: ignore
RerankParameters,
)
from langchain_core.callbacks import Callbacks
from langchain_core.documents import BaseDocumentCompressor, Document
from langchain_core.utils.utils import secret_from_env
from pydantic import ConfigDict, Field, SecretStr, model_validator
from typing_extensions import Self

from langchain_ibm.utils import check_for_attribute, extract_params


class WatsonxRerank(BaseDocumentCompressor):
"""Document compressor that uses `watsonx Rerank API`."""

model_id: str
"""Type of model to use."""

project_id: Optional[str] = None
"""ID of the Watson Studio project."""

space_id: Optional[str] = None
"""ID of the Watson Studio space."""

url: SecretStr = Field(
alias="url", default_factory=secret_from_env("WATSONX_URL", default=None)
)
"""URL to the Watson Machine Learning or CPD instance."""

apikey: Optional[SecretStr] = Field(
alias="apikey", default_factory=secret_from_env("WATSONX_APIKEY", default=None)
)
"""API key to the Watson Machine Learning or CPD instance."""

token: Optional[SecretStr] = Field(
alias="token", default_factory=secret_from_env("WATSONX_TOKEN", default=None)
)
"""Token to the CPD instance."""

password: Optional[SecretStr] = Field(
alias="password",
default_factory=secret_from_env("WATSONX_PASSWORD", default=None),
)
"""Password to the CPD instance."""

username: Optional[SecretStr] = Field(
alias="username",
default_factory=secret_from_env("WATSONX_USERNAME", default=None),
)
"""Username to the CPD instance."""

instance_id: Optional[SecretStr] = Field(
alias="instance_id",
default_factory=secret_from_env("WATSONX_INSTANCE_ID", default=None),
)
"""Instance_id of the CPD instance."""

version: Optional[SecretStr] = None
"""Version of the CPD instance."""

params: Optional[Union[dict, RerankParameters]] = None
"""Model parameters to use during request generation."""

verify: Union[str, bool, None] = None
"""You can pass one of following as verify:
* the path to a CA_BUNDLE file
* the path of directory with certificates of trusted CAs
* True - default path to truststore will be taken
* False - no verification will be made"""

validate_model: bool = True
"""Model ID validation."""

streaming: bool = False
""" Whether to stream the results or not. """

watsonx_rerank: Rerank = Field(default=None, exclude=True) #: :meta private:

watsonx_client: Optional[APIClient] = Field(default=None, exclude=True)

model_config = ConfigDict(
arbitrary_types_allowed=True,
extra="forbid",
protected_namespaces=(),
)

@property
def lc_secrets(self) -> Dict[str, str]:
"""A map of constructor argument names to secret ids.

For example:
{
"url": "WATSONX_URL",
"apikey": "WATSONX_APIKEY",
"token": "WATSONX_TOKEN",
"password": "WATSONX_PASSWORD",
"username": "WATSONX_USERNAME",
"instance_id": "WATSONX_INSTANCE_ID",
}
"""
return {
"url": "WATSONX_URL",
"apikey": "WATSONX_APIKEY",
"token": "WATSONX_TOKEN",
"password": "WATSONX_PASSWORD",
"username": "WATSONX_USERNAME",
"instance_id": "WATSONX_INSTANCE_ID",
}

@model_validator(mode="after")
def validate_environment(self) -> Self:
"""Validate that credentials and python package exists in environment."""
if isinstance(self.watsonx_client, APIClient):
watsonx_rerank = Rerank(
model_id=self.model_id,
params=self.params,
api_client=self.watsonx_client,
project_id=self.project_id,
space_id=self.space_id,
verify=self.verify,
)
self.watsonx_rerank = watsonx_rerank

else:
check_for_attribute(self.url, "url", "WATSONX_URL")

if "cloud.ibm.com" in self.url.get_secret_value():
check_for_attribute(self.apikey, "apikey", "WATSONX_APIKEY")
else:
if not self.token and not self.password and not self.apikey:
raise ValueError(
"Did not find 'token', 'password' or 'apikey',"
" please add an environment variable"
" `WATSONX_TOKEN`, 'WATSONX_PASSWORD' or 'WATSONX_APIKEY' "
"which contains it,"
" or pass 'token', 'password' or 'apikey'"
" as a named parameter."
)
elif self.token:
check_for_attribute(self.token, "token", "WATSONX_TOKEN")
elif self.password:
check_for_attribute(self.password, "password", "WATSONX_PASSWORD")
check_for_attribute(self.username, "username", "WATSONX_USERNAME")
elif self.apikey:
check_for_attribute(self.apikey, "apikey", "WATSONX_APIKEY")
check_for_attribute(self.username, "username", "WATSONX_USERNAME")

if not self.instance_id:
check_for_attribute(
self.instance_id, "instance_id", "WATSONX_INSTANCE_ID"
)

credentials = Credentials(
url=self.url.get_secret_value() if self.url else None,
api_key=self.apikey.get_secret_value() if self.apikey else None,
token=self.token.get_secret_value() if self.token else None,
password=self.password.get_secret_value() if self.password else None,
username=self.username.get_secret_value() if self.username else None,
instance_id=self.instance_id.get_secret_value()
if self.instance_id
else None,
version=self.version.get_secret_value() if self.version else None,
verify=self.verify,
)

watsonx_rerank = Rerank(
model_id=self.model_id,
credentials=credentials,
params=self.params,
project_id=self.project_id,
space_id=self.space_id,
verify=self.verify,
)
self.watsonx_rerank = watsonx_rerank

return self

def rerank(
self,
documents: Sequence[Union[str, Document, dict]],
query: str,
**kwargs: Any,
) -> List[Dict[str, Any]]:
if len(documents) == 0: # to avoid empty api call
return []
docs = [
doc.page_content if isinstance(doc, Document) else doc for doc in documents
]
params = extract_params(kwargs, self.params)

results = self.watsonx_rerank.generate(
query=query, inputs=docs, **(kwargs | {"params": params})
)
result_dicts = []
for res in results["results"]:
result_dicts.append(
{"index": res.get("index"), "relevance_score": res.get("score")}
)
return result_dicts

def compress_documents(
self,
documents: Sequence[Document],
query: str,
callbacks: Optional[Callbacks] = None,
**kwargs: Any,
) -> Sequence[Document]:
"""
Compress documents using watsonx's rerank API.

Args:
documents: A sequence of documents to compress.
query: The query to use for compressing the documents.
callbacks: Callbacks to run during the compression process.

Returns:
A sequence of compressed documents.
"""
compressed = []
for res in self.rerank(documents, query, **kwargs):
doc = documents[res["index"]]
doc_copy = Document(doc.page_content, metadata=deepcopy(doc.metadata))
doc_copy.metadata["relevance_score"] = res["relevance_score"]
compressed.append(doc_copy)
return compressed
20 changes: 20 additions & 0 deletions libs/ibm/langchain_ibm/utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
from typing import Any, Dict, Optional, Union

from ibm_watsonx_ai.foundation_models.schema import BaseSchema # type: ignore
from pydantic import SecretStr


Expand All @@ -8,3 +11,20 @@ def check_for_attribute(value: SecretStr | None, key: str, env_key: str) -> None
f" `{env_key}` which contains it, or pass"
f" `{key}` as a named parameter."
)


def extract_params(
kwargs: Dict[str, Any],
default_params: Optional[Union[BaseSchema, Dict[str, Any]]] = None,
) -> Dict[str, Any]:
if kwargs.get("params") is not None:
params = kwargs.pop("params")
elif default_params is not None:
params = default_params
else:
params = None

if isinstance(params, BaseSchema):
params = params.to_dict()

return params or {}
Loading
Loading