From 240f3706acdd485e4ef6ea6c32a190aed3214792 Mon Sep 17 00:00:00 2001
From: Wojciech-Rebisz <147821486+Wojciech-Rebisz@users.noreply.github.com>
Date: Tue, 17 Sep 2024 12:33:52 +0200
Subject: [PATCH] feat: Add _agenerate implementation (#12)

* Add _agenerate implementation

---------------

Co-authored-by: Mateusz Szewczyk
---
 libs/ibm/langchain_ibm/llms.py                | 50 +++++++++++++++++--
 libs/ibm/poetry.lock                          | 39 +++++++--------
 libs/ibm/pyproject.toml                       |  4 +-
 libs/ibm/tests/integration_tests/test_llms.py | 20 ++++++++
 4 files changed, 87 insertions(+), 26 deletions(-)

diff --git a/libs/ibm/langchain_ibm/llms.py b/libs/ibm/langchain_ibm/llms.py
index 69d00c7..0cc5f5e 100644
--- a/libs/ibm/langchain_ibm/llms.py
+++ b/libs/ibm/langchain_ibm/llms.py
@@ -6,7 +6,10 @@
 from ibm_watsonx_ai import APIClient, Credentials  # type: ignore
 from ibm_watsonx_ai.foundation_models import Model, ModelInference  # type: ignore
 from ibm_watsonx_ai.metanames import GenTextParamsMetaNames  # type: ignore
-from langchain_core.callbacks import CallbackManagerForLLMRun
+from langchain_core.callbacks import (
+    AsyncCallbackManagerForLLMRun,
+    CallbackManagerForLLMRun,
+)
 from langchain_core.language_models.llms import BaseLLM
 from langchain_core.outputs import Generation, GenerationChunk, LLMResult
 from langchain_core.utils.utils import secret_from_env
@@ -349,9 +352,9 @@ def _stream_response_to_generation_chunk(
         return GenerationChunk(
             text=stream_response["results"][0]["generated_text"],
             generation_info=dict(
-                finish_reason=None
-                if finish_reason == "not_finished"
-                else finish_reason,
+                finish_reason=(
+                    None if finish_reason == "not_finished" else finish_reason
+                ),
                 llm_output={
                     "model_id": self.model_id,
                     "deployment_id": self.deployment_id,
@@ -383,6 +386,20 @@ def _call(
         )
         return result.generations[0][0].text
 
+    async def _acall(
+        self,
+        prompt: str,
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> str:
+        """Async version of the _call method."""
+
+        result = await self._agenerate(
+            prompts=[prompt], stop=stop, run_manager=run_manager, **kwargs
+        )
+        return result.generations[0][0].text
+
     def _generate(
         self,
         prompts: List[str],
@@ -431,6 +448,31 @@ def _generate(
         )
         return self._create_llm_result(response)
 
+    async def _agenerate(
+        self,
+        prompts: List[str],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+        stream: Optional[bool] = None,
+        **kwargs: Any,
+    ) -> LLMResult:
+        """Async run the LLM on the given prompt and input."""
+        params, kwargs = self._get_chat_params(stop=stop, **kwargs)
+        params = self._validate_chat_params(params)
+        if stream:
+            return await super()._agenerate(
+                prompts=prompts, stop=stop, run_manager=run_manager, **kwargs
+            )
+        else:
+            responses = [
+                await self.watsonx_model.agenerate(
+                    prompt=prompt, params=params, **kwargs
+                )
+                for prompt in prompts
+            ]
+
+            return self._create_llm_result(responses)
+
     def _stream(
         self,
         prompt: str,
diff --git a/libs/ibm/poetry.lock b/libs/ibm/poetry.lock
index b1264b4..ed95f7f 100644
--- a/libs/ibm/poetry.lock
+++ b/libs/ibm/poetry.lock
@@ -345,47 +345,47 @@ zstd = ["zstandard (>=0.18.0)"]
 
 [[package]]
 name = "ibm-cos-sdk"
-version = "2.13.5"
+version = "2.13.6"
 description = "IBM SDK for Python"
 optional = false
 python-versions = ">=3.8"
 files = [
-    {file = "ibm-cos-sdk-2.13.5.tar.gz", hash = "sha256:1aff7f9863ac9072a3db2f0053bec99478b26f3fb5fa797ce96a15bbb13cd40e"},
+    {file = "ibm-cos-sdk-2.13.6.tar.gz", hash = "sha256:171cf2ae4ab662a4b8ab58dcf4ac994b0577d6c92d78490295fd7704a83978f6"},
 ]
 
 [package.dependencies]
-ibm-cos-sdk-core = "2.13.5"
-ibm-cos-sdk-s3transfer = "2.13.5"
+ibm-cos-sdk-core = "2.13.6"
+ibm-cos-sdk-s3transfer = "2.13.6"
 jmespath = ">=0.10.0,<=1.0.1"
 
 [[package]]
 name = "ibm-cos-sdk-core"
-version = "2.13.5"
+version = "2.13.6"
 description = "Low-level, data-driven core of IBM SDK for Python"
 optional = false
 python-versions = ">=3.6"
 files = [
-    {file = "ibm-cos-sdk-core-2.13.5.tar.gz", hash = "sha256:d3a99d8b06b3f8c00b1a9501f85538d592463e63ddf8cec32672ab5a0b107b83"},
+    {file = "ibm-cos-sdk-core-2.13.6.tar.gz", hash = "sha256:dd41fb789eeb65546501afabcd50e78846ab4513b6ad4042e410b6a14ff88413"},
 ]
 
 [package.dependencies]
 jmespath = ">=0.10.0,<=1.0.1"
 python-dateutil = ">=2.9.0,<3.0.0"
-requests = ">=2.32.0,<2.32.3"
-urllib3 = ">=1.26.18,<3"
+requests = ">=2.32.3,<3.0"
+urllib3 = {version = ">=1.26.18,<2.2", markers = "python_version >= \"3.10\""}
 
 [[package]]
 name = "ibm-cos-sdk-s3transfer"
-version = "2.13.5"
+version = "2.13.6"
 description = "IBM S3 Transfer Manager"
 optional = false
 python-versions = ">=3.8"
 files = [
-    {file = "ibm-cos-sdk-s3transfer-2.13.5.tar.gz", hash = "sha256:9649b1f2201c6de96ff5a6b5a3686de3a809e6ef3b8b12c7c4f2f7ce72da7749"},
+    {file = "ibm-cos-sdk-s3transfer-2.13.6.tar.gz", hash = "sha256:e0acce6f380c47d11e07c6765b684b4ababbf5c66cc0503bc246469a1e2b9790"},
 ]
 
 [package.dependencies]
-ibm-cos-sdk-core = "2.13.5"
+ibm-cos-sdk-core = "2.13.6"
 
 [[package]]
 name = "ibm-watsonx-ai"
@@ -526,7 +526,7 @@ typing-extensions = ">=4.7"
 type = "git"
 url = "https://github.com/langchain-ai/langchain.git"
 reference = "HEAD"
-resolved_reference = "8a2f2fc30b960dcf4baab7cd12a99f5ef9c0df16"
+resolved_reference = "d8952b8e8c1ee824f2bce27988d401bfcfd96779"
 subdirectory = "libs/core"
 
 [[package]]
@@ -1129,13 +1129,13 @@ files = [
 
 [[package]]
 name = "requests"
-version = "2.32.2"
+version = "2.32.3"
 description = "Python HTTP for Humans."
 optional = false
 python-versions = ">=3.8"
 files = [
-    {file = "requests-2.32.2-py3-none-any.whl", hash = "sha256:fc06670dd0ed212426dfeb94fc1b983d917c4f9847c863f313c9dfaaffb7c23c"},
-    {file = "requests-2.32.2.tar.gz", hash = "sha256:dd951ff5ecf3e3b3aa26b40703ba77495dab41da839ae72ef3c8e5d8e2433289"},
+    {file = "requests-2.32.3-py3-none-any.whl", hash = "sha256:70761cfe03c773ceb22aa2f671b4757976145175cdfca038c02654d061d6dcc6"},
+    {file = "requests-2.32.3.tar.gz", hash = "sha256:55365417734eb18255590a9ff9eb97e9e1da868d4ccd6402399eaf68af20a760"},
 ]
 
 [package.dependencies]
@@ -1289,18 +1289,17 @@ files = [
 
 [[package]]
 name = "urllib3"
-version = "2.2.3"
+version = "2.1.0"
 description = "HTTP library with thread-safe connection pooling, file post, and more."
 optional = false
 python-versions = ">=3.8"
 files = [
-    {file = "urllib3-2.2.3-py3-none-any.whl", hash = "sha256:ca899ca043dcb1bafa3e262d73aa25c465bfb49e0bd9dd5d59f1d0acba2f8fac"},
-    {file = "urllib3-2.2.3.tar.gz", hash = "sha256:e7d814a81dad81e6caf2ec9fdedb284ecc9c73076b62654547cc64ccdcae26e9"},
+    {file = "urllib3-2.1.0-py3-none-any.whl", hash = "sha256:55901e917a5896a349ff771be919f8bd99aff50b79fe58fec595eb37bbc56bb3"},
+    {file = "urllib3-2.1.0.tar.gz", hash = "sha256:df7aa8afb0148fa78488e7899b2c59b5f4ffcfa82e6c54ccb9dd37c1d7b52d54"},
 ]
 
 [package.extras]
 brotli = ["brotli (>=1.0.9)", "brotlicffi (>=0.8.0)"]
-h2 = ["h2 (>=4,<5)"]
 socks = ["pysocks (>=1.5.6,!=1.5.7,<2.0)"]
 zstd = ["zstandard (>=0.18.0)"]
 
@@ -1368,4 +1367,4 @@ type = ["pytest-mypy"]
 [metadata]
 lock-version = "2.0"
 python-versions = ">=3.10,<4.0"
-content-hash = "befb90970136d5cc65dfcc9a6d58a666f9cd711cdde50c6fa2f0be73ff2f8f00"
+content-hash = "84b71aacace41099670cb831f589df0a2a9b27aef7dc6a88792873f5abe6265e"
diff --git a/libs/ibm/pyproject.toml b/libs/ibm/pyproject.toml
index 039ffbe..dd9bdc3 100644
--- a/libs/ibm/pyproject.toml
+++ b/libs/ibm/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "langchain-ibm"
-version = "0.2.0"
+version = "0.2.1"
 description = "An integration package connecting IBM watsonx.ai and LangChain"
 authors = ["IBM"]
 readme = "README.md"
@@ -13,7 +13,7 @@ license = "MIT"
 [tool.poetry.dependencies]
 python = ">=3.10,<4.0"
 langchain-core = ">=0.3.0,<0.4"
-ibm-watsonx-ai = "^1.0.8"
+ibm-watsonx-ai = "^1.1.9"
 
 [tool.poetry.group.test]
 optional = true
diff --git a/libs/ibm/tests/integration_tests/test_llms.py b/libs/ibm/tests/integration_tests/test_llms.py
index 353f5b4..5a8cfb7 100644
--- a/libs/ibm/tests/integration_tests/test_llms.py
+++ b/libs/ibm/tests/integration_tests/test_llms.py
@@ -422,6 +422,16 @@ async def test_watsonx_ainvoke() -> None:
     assert isinstance(response, str)
 
 
+async def test_watsonx_acall() -> None:
+    watsonxllm = WatsonxLLM(
+        model_id=MODEL_ID,
+        url="https://us-south.ml.cloud.ibm.com",  # type: ignore[arg-type]
+        project_id=WX_PROJECT_ID,
+    )
+    response = await watsonxllm._acall("what is the color of the grass?")
+    assert "green" in response.lower()
+
+
 async def test_watsonx_agenerate() -> None:
     watsonxllm = WatsonxLLM(
         model_id=MODEL_ID,
@@ -435,6 +445,16 @@ async def test_watsonx_agenerate() -> None:
     assert response.llm_output["token_usage"]["generated_token_count"] != 0  # type: ignore
 
 
+async def test_watsonx_agenerate_with_stream() -> None:
+    watsonxllm = WatsonxLLM(
+        model_id=MODEL_ID,
+        url="https://us-south.ml.cloud.ibm.com",  # type: ignore[arg-type]
+        project_id=WX_PROJECT_ID,
+    )
+    response = await watsonxllm.agenerate(["What color sunflower is?"], stream=True)
+    assert "yellow" in response.generations[0][0].text.lower()
+
+
 def test_get_num_tokens() -> None:
     watsonxllm = WatsonxLLM(
         model_id=MODEL_ID,
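
Note (not part of the patch): a minimal usage sketch of the async API this change enables. The model_id, project_id, and prompts below are illustrative placeholders, and valid watsonx.ai credentials are assumed to be available in the environment (e.g. WATSONX_APIKEY):

    import asyncio

    from langchain_ibm import WatsonxLLM

    async def main() -> None:
        llm = WatsonxLLM(
            model_id="ibm/granite-13b-instruct-v2",  # placeholder model id
            url="https://us-south.ml.cloud.ibm.com",
            project_id="YOUR_PROJECT_ID",  # placeholder
        )
        # ainvoke now dispatches to the native _agenerate added in this patch,
        # rather than running the sync _generate in a thread executor.
        text = await llm.ainvoke("What is the color of the grass?")
        print(text)

        # stream=True makes _agenerate defer to the base-class implementation,
        # mirroring the sync streaming path (see the new integration test).
        result = await llm.agenerate(["What is the color of the grass?"], stream=True)
        print(result.generations[0][0].text)

    asyncio.run(main())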