feat: Add _agenerate implementation (#12)
* Add _agenerate implementation

---------------

Co-authored-by: Mateusz Szewczyk <[email protected]>
Wojciech-Rebisz and MateuszOssGit authored Sep 17, 2024
1 parent fa4b8d3 commit 240f370
Showing 4 changed files with 87 additions and 26 deletions.
50 changes: 46 additions & 4 deletions libs/ibm/langchain_ibm/llms.py
@@ -6,7 +6,10 @@
 from ibm_watsonx_ai import APIClient, Credentials  # type: ignore
 from ibm_watsonx_ai.foundation_models import Model, ModelInference  # type: ignore
 from ibm_watsonx_ai.metanames import GenTextParamsMetaNames  # type: ignore
-from langchain_core.callbacks import CallbackManagerForLLMRun
+from langchain_core.callbacks import (
+    AsyncCallbackManagerForLLMRun,
+    CallbackManagerForLLMRun,
+)
 from langchain_core.language_models.llms import BaseLLM
 from langchain_core.outputs import Generation, GenerationChunk, LLMResult
 from langchain_core.utils.utils import secret_from_env
@@ -349,9 +352,9 @@ def _stream_response_to_generation_chunk(
         return GenerationChunk(
             text=stream_response["results"][0]["generated_text"],
             generation_info=dict(
-                finish_reason=None
-                if finish_reason == "not_finished"
-                else finish_reason,
+                finish_reason=(
+                    None if finish_reason == "not_finished" else finish_reason
+                ),
                 llm_output={
                     "model_id": self.model_id,
                     "deployment_id": self.deployment_id,
@@ -383,6 +386,20 @@ def _call(
         )
         return result.generations[0][0].text
 
+    async def _acall(
+        self,
+        prompt: str,
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> str:
+        """Async version of the _call method."""
+
+        result = await self._agenerate(
+            prompts=[prompt], stop=stop, run_manager=run_manager, **kwargs
+        )
+        return result.generations[0][0].text
+
     def _generate(
         self,
         prompts: List[str],
@@ -431,6 +448,31 @@ def _generate(
         )
         return self._create_llm_result(response)
 
+    async def _agenerate(
+        self,
+        prompts: List[str],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+        stream: Optional[bool] = None,
+        **kwargs: Any,
+    ) -> LLMResult:
+        """Async run the LLM on the given prompt and input."""
+        params, kwargs = self._get_chat_params(stop=stop, **kwargs)
+        params = self._validate_chat_params(params)
+        if stream:
+            return await super()._agenerate(
+                prompts=prompts, stop=stop, run_manager=run_manager, **kwargs
+            )
+        else:
+            responses = [
+                await self.watsonx_model.agenerate(
+                    prompt=prompt, params=params, **kwargs
+                )
+                for prompt in prompts
+            ]
+
+            return self._create_llm_result(responses)
+
     def _stream(
         self,
         prompt: str,
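For anyone trying the new async path end to end, here is a minimal usage sketch. It is not part of the commit: the model ID and project ID below are placeholders, and it assumes credentials are supplied via the environment (e.g. the WATSONX_APIKEY variable the package reads):

    import asyncio

    from langchain_ibm import WatsonxLLM

    async def main() -> None:
        # Placeholder IDs; substitute your own model and project.
        llm = WatsonxLLM(
            model_id="ibm/granite-13b-instruct-v2",
            url="https://us-south.ml.cloud.ibm.com",
            project_id="<your-project-id>",
        )
        # ainvoke now resolves through the _acall/_agenerate pair added here.
        print(await llm.ainvoke("What is the color of grass?"))

    asyncio.run(main())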
39 changes: 19 additions & 20 deletions libs/ibm/poetry.lock

Some generated files are not rendered by default.

4 changes: 2 additions & 2 deletions libs/ibm/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "langchain-ibm"
-version = "0.2.0"
+version = "0.2.1"
 description = "An integration package connecting IBM watsonx.ai and LangChain"
 authors = ["IBM"]
 readme = "README.md"
@@ -13,7 +13,7 @@ license = "MIT"
 [tool.poetry.dependencies]
 python = ">=3.10,<4.0"
 langchain-core = ">=0.3.0,<0.4"
-ibm-watsonx-ai = "^1.0.8"
+ibm-watsonx-ai = "^1.1.9"
 
 [tool.poetry.group.test]
 optional = true
20 changes: 20 additions & 0 deletions libs/ibm/tests/integration_tests/test_llms.py
@@ -422,6 +422,16 @@ async def test_watsonx_ainvoke() -> None:
     assert isinstance(response, str)
 
 
+async def test_watsonx_acall() -> None:
+    watsonxllm = WatsonxLLM(
+        model_id=MODEL_ID,
+        url="https://us-south.ml.cloud.ibm.com",  # type: ignore[arg-type]
+        project_id=WX_PROJECT_ID,
+    )
+    response = await watsonxllm._acall("what is the color of the grass?")
+    assert "green" in response.lower()
+
+
 async def test_watsonx_agenerate() -> None:
     watsonxllm = WatsonxLLM(
         model_id=MODEL_ID,
@@ -435,6 +445,16 @@ async def test_watsonx_agenerate() -> None:
     assert response.llm_output["token_usage"]["generated_token_count"] != 0  # type: ignore
 
 
+async def test_watsonx_agenerate_with_stream() -> None:
+    watsonxllm = WatsonxLLM(
+        model_id=MODEL_ID,
+        url="https://us-south.ml.cloud.ibm.com",  # type: ignore[arg-type]
+        project_id=WX_PROJECT_ID,
+    )
+    response = await watsonxllm.agenerate(["What color sunflower is?"], stream=True)
+    assert "yellow" in response.generations[0][0].text.lower()
+
+
 def test_get_num_tokens() -> None:
     watsonxllm = WatsonxLLM(
         model_id=MODEL_ID,
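The new streaming test exercises the stream=True branch of _agenerate, which defers to the base class implementation. A companion sketch, with the same placeholder IDs and credential assumptions as above:

    import asyncio

    from langchain_ibm import WatsonxLLM

    async def main() -> None:
        llm = WatsonxLLM(
            model_id="ibm/granite-13b-instruct-v2",  # placeholder
            url="https://us-south.ml.cloud.ibm.com",
            project_id="<your-project-id>",  # placeholder
        )
        # stream=True is forwarded through agenerate's kwargs into _agenerate.
        result = await llm.agenerate(["What color is a sunflower?"], stream=True)
        print(result.generations[0][0].text)

    asyncio.run(main())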
