Skip to content

Commit

Permalink
Added lightweight engine support (#8)
Browse files Browse the repository at this point in the history
  • Loading branch information
MateuszOssGit authored Jul 23, 2024
1 parent 0c81f32 commit 74a231e
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 10 deletions.
16 changes: 15 additions & 1 deletion libs/ibm/langchain_ibm/llms.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import os
from typing import Any, Dict, Iterator, List, Mapping, Optional, Tuple, Union

from ibm_watsonx_ai import Credentials # type: ignore
from ibm_watsonx_ai import APIClient, Credentials # type: ignore
from ibm_watsonx_ai.foundation_models import Model, ModelInference # type: ignore
from ibm_watsonx_ai.metanames import GenTextParamsMetaNames # type: ignore
from langchain_core.callbacks import CallbackManagerForLLMRun
Expand Down Expand Up @@ -97,6 +97,8 @@ class WatsonxLLM(BaseLLM):

watsonx_model: ModelInference = Field(default=None, exclude=True) #: :meta private:

watsonx_client: APIClient = Field(default=None) #: :meta private:

class Config:
"""Configuration for this pydantic object."""

Expand Down Expand Up @@ -145,6 +147,18 @@ def validate_environment(cls, values: Dict) -> Dict:
getattr(values["watsonx_model"], "_client"), "default_space_id"
)
values["params"] = getattr(values["watsonx_model"], "params")

elif isinstance(values.get("watsonx_client"), APIClient):
watsonx_model = ModelInference(
model_id=values["model_id"],
params=values["params"],
api_client=values["watsonx_client"],
project_id=values["project_id"],
space_id=values["space_id"],
verify=values["verify"],
)
values["watsonx_model"] = watsonx_model

else:
values["url"] = convert_to_secret_str(
get_from_dict_or_env(values, "url", "WATSONX_URL")
Expand Down
16 changes: 8 additions & 8 deletions libs/ibm/poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

20 changes: 19 additions & 1 deletion libs/ibm/tests/integration_tests/test_llms.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import os

from ibm_watsonx_ai import Credentials # type: ignore
from ibm_watsonx_ai import APIClient, Credentials # type: ignore
from ibm_watsonx_ai.foundation_models import Model, ModelInference # type: ignore
from ibm_watsonx_ai.foundation_models.utils.enums import ( # type: ignore
DecodingMethods,
Expand Down Expand Up @@ -432,3 +432,21 @@ def test_get_num_tokens() -> None:
)
num_tokens = watsonxllm.get_num_tokens("What color sunflower is?")
assert num_tokens > 0


def test_init_with_client() -> None:
watsonx_client = APIClient(
credentials={
"url": "https://us-south.ml.cloud.ibm.com",
"apikey": WX_APIKEY,
}
)
watsonxllm = WatsonxLLM(
model_id=MODEL_ID,
watsonx_client=watsonx_client,
project_id=WX_PROJECT_ID,
)
response = watsonxllm.invoke("What color sunflower is?")
print(f"\nResponse: {response}")
assert isinstance(response, str)
assert len(response) > 0

0 comments on commit 74a231e

Please sign in to comment.