
Initial commit of Amazon Q Business runnable for langchain #301

Open · wants to merge 12 commits into base: main
146 changes: 146 additions & 0 deletions libs/aws/langchain_aws/llms/q_business.py
@@ -0,0 +1,146 @@
import asyncio
from typing import Any, Dict, List, Optional

import boto3
from langchain_core.callbacks.manager import (
    AsyncCallbackManagerForLLMRun,
    CallbackManagerForLLMRun,
)
from langchain_core.language_models import LLM
from pydantic import ConfigDict

class AmazonQ(LLM):
    """Amazon Q LLM wrapper.

    To authenticate, the AWS client uses the following methods to
    automatically load credentials:
    https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html

    Make sure the credentials / roles used have the required policies to
    access the Amazon Q service.
    """

    region_name: Optional[str] = None
    """AWS region name. If not provided, will be extracted from environment."""

    streaming: bool = False
    """Whether to stream the results or not."""

    client: Any = None
    """Amazon Q (QBusiness) boto3 client."""

    application_id: Optional[str] = None
    """Amazon Q application ID."""

    _last_response: Optional[Dict] = None
    """Stores the full response from the most recent Amazon Q call."""

    parent_message_id: Optional[str] = None
    """ID of the parent message, used to continue an existing conversation."""

    conversation_id: Optional[str] = None
    """ID of an existing Amazon Q conversation to continue."""

    chat_mode: str = "RETRIEVAL_MODE"
    """Chat mode to use when calling Amazon Q."""

    model_config = ConfigDict(
        extra="forbid",
    )

    @property
    def _llm_type(self) -> str:
        """Return type of llm."""
        return "amazon_q"

    def __init__(self, **kwargs: Any) -> None:
        """Initialize the Amazon Q client."""
        super().__init__(**kwargs)

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Call out to Amazon Q service.

        Args:
            prompt: The prompt to pass into the model.
            stop: Optional list of stop words to use when generating.

        Returns:
            The string generated by the model.

        Example:
            .. code-block:: python

                response = llm.invoke("Tell me a joke.")
        """
        try:
            print("Prompt Length (Amazon Q ChatSync API takes a maximum of 7000 chars)")
            print(len(prompt))
Collaborator
These look like debug print statements. Should they be removed or replaced with logger statements?

Author
I've been debating that. Currently the Amazon Q ChatSync API only supports a maximum of 7000 characters, so it's helpful for debugging when you are passing additional context along with the prompt.

I think a good place for that log would be in the error-handling block, which I added. Let me know what you think.

Collaborator
Thanks - let's remove L83/84 in favor of the new changes.
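For reference, a minimal sketch of a logger-based replacement for the prints, assuming the 7000-character ChatSync limit discussed above (the MAX_PROMPT_CHARS name and warning wording are illustrative, not the committed change):

import logging

logger = logging.getLogger(__name__)

MAX_PROMPT_CHARS = 7000  # ChatSync limit noted in this thread (assumed constant name)

def check_prompt_length(prompt: str) -> None:
    """Warn instead of printing, so callers control verbosity via logging config."""
    if len(prompt) > MAX_PROMPT_CHARS:
        logger.warning(
            "Prompt is %d characters; Amazon Q ChatSync accepts at most %d.",
            len(prompt),
            MAX_PROMPT_CHARS,
        )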


            # Prepare the request
            request = {
                "applicationId": self.application_id,
                "userMessage": prompt,
                "chatMode": self.chat_mode,
            }
            if self.conversation_id:
                # Continue an existing conversation.
                request["conversationId"] = self.conversation_id
                request["parentMessageId"] = self.parent_message_id

            # Call Amazon Q
            response = self.client.chat_sync(**request)
Collaborator
Before attempting to call the chat_sync API, can we run some validation code to check whether the QBusiness boto client has already been defined externally? If it doesn't exist, we should attempt to create it.

Refer here for an example: https://github.com/langchain-ai/langchain-aws/blob/main/libs/aws/langchain_aws/llms/sagemaker_endpoint.py#L271

Author
I implemented this as well. Let me know if that satisfies the requirements.

Collaborator
Looks good, thanks!
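Following the sagemaker_endpoint pattern linked above, the lazy client creation might look roughly like this sketch (the validator name and session handling are assumptions; the merged change may differ; only the client-validation portion of the class is shown):

import boto3
from pydantic import BaseModel, model_validator
from typing import Any, Optional

class AmazonQ(BaseModel):
    """Sketch of the client-validation portion only, not the full class."""

    client: Any = None
    region_name: Optional[str] = None

    @model_validator(mode="after")
    def validate_environment(self) -> "AmazonQ":
        # Reuse an externally supplied QBusiness client if one was passed in;
        # otherwise create one from the default boto3 credential chain.
        if self.client is None:
            session = boto3.Session(region_name=self.region_name)
            self.client = session.client("qbusiness")
        return self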

            self._last_response = response

            # Extract the response text
            if "systemMessage" in response:
                return response["systemMessage"]
            else:
                raise ValueError("Unexpected response format from Amazon Q")

        except Exception as e:
            raise ValueError(f"Error raised by Amazon Q service: {e}")

    def get_last_response(self) -> Optional[Dict]:
        """Access the full response from the last call."""
        return self._last_response

    async def _acall(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Async call to Amazon Q service."""

        def _execute_call() -> str:
            return self._call(prompt, stop=stop, **kwargs)

        # Run the synchronous call in a thread pool
        return await asyncio.get_running_loop().run_in_executor(
            None, _execute_call
        )

    @property
    def _identifying_params(self) -> Dict[str, Any]:
        """Get the identifying parameters."""
        return {
            "region_name": self.region_name,
        }
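End-to-end usage of the class as defined in this diff might look like the following sketch (the application ID and region are placeholders, and the conversation-continuation keys assume the ChatSync response shape):

import boto3

from langchain_aws.llms.q_business import AmazonQ

# Build the QBusiness client explicitly (or let the class create it, once
# the validation discussed above is merged).
client = boto3.client("qbusiness", region_name="us-east-1")

llm = AmazonQ(
    client=client,
    application_id="your-q-business-application-id",  # placeholder
    chat_mode="RETRIEVAL_MODE",
)

print(llm.invoke("Summarize our onboarding documentation."))

# Continue the same conversation using IDs from the previous full response.
last = llm.get_last_response()
llm.conversation_id = last.get("conversationId")
llm.parent_message_id = last.get("systemMessageId")
print(llm.invoke("What are the first three steps?"))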