Initial commit of Amazon Q Business runnable for langchain #301
base: main
Changes from 3 commits
@@ -0,0 +1,185 @@
from typing import Any, Dict, List, Optional
from langchain_core.callbacks.manager import (
    AsyncCallbackManagerForLLMRun,
    CallbackManagerForLLMRun,
)
import logging
from langchain_core.language_models import LLM
from pydantic import ConfigDict, model_validator
import asyncio
from typing_extensions import Self

class AmazonQ(LLM):
    """Amazon Q LLM wrapper.

    To authenticate, the AWS client uses the following methods to
    automatically load credentials:
    https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html

    Make sure the credentials / roles used have the required policies to
    access the Amazon Q service.
    """

    region_name: Optional[str] = None
    """AWS region name. If not provided, will be extracted from environment."""

    streaming: bool = False
    """Whether to stream the results or not."""

    client: Any = None
    """Amazon Q client."""

    application_id: Optional[str] = None
    """Amazon Q application ID."""

    _last_response: Optional[Dict] = None
    """Store the full response from the most recent Amazon Q call."""

    parent_message_id: Optional[str] = None
    """ID of the parent message when continuing an existing conversation."""

    conversation_id: Optional[str] = None
    """ID of an existing Amazon Q conversation to continue."""

    chat_mode: str = "RETRIEVAL_MODE"
    """Chat mode to use when calling Amazon Q."""

    credentials_profile_name: Optional[str] = None
    """The name of the profile in the ~/.aws/credentials or ~/.aws/config files, which
    has either access keys or role information specified.
    If not specified, the default credential profile or, if on an EC2 instance,
    credentials from IMDS will be used.
    See: https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html
    """

    model_config = ConfigDict(
        extra="forbid",
    )

    @property
    def _llm_type(self) -> str:
        """Return type of llm."""
        return "amazon_q"

    def __init__(self, **kwargs: Any) -> None:
        """Initialize the Amazon Q client."""
        super().__init__(**kwargs)

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Call out to Amazon Q service.

        Args:
            prompt: The prompt to pass into the model.
            stop: Optional list of stop words to use when generating.

        Returns:
            The string generated by the model.

        Example:
            .. code-block:: python

                response = llm.invoke("Tell me a joke.")
        """
        try:
            print("Prompt Length (Amazon Q ChatSync API takes a maximum of 7000 chars)")
            print(len(prompt))

            # Prepare the request
            request = {
                'applicationId': self.application_id,
                'userMessage': prompt,
                'chatMode': self.chat_mode,
            }
            if self.conversation_id:
                request.update({
                    'conversationId': self.conversation_id,
                    'parentMessageId': self.parent_message_id,
                })

            # Call Amazon Q
            response = self.client.chat_sync(**request)

Review comment: Before attempting to call the chat_sync API, can we run some validation code to check that the QBusiness boto3 client is set up correctly? Refer here for an example: https://github.com/langchain-ai/langchain-aws/blob/main/libs/aws/langchain_aws/llms/sagemaker_endpoint.py#L271

Reply: I implemented this as well. Let me know if that satisfies the requirements.

Reply: Looks good, thanks!
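
A sketch of the kind of pre-call validation the reviewer is asking for, loosely modeled on the linked sagemaker_endpoint example (the helper name and exact checks are illustrative, not the PR's actual implementation):

from typing import Any

def _validate_client(client: Any) -> None:
    # Illustrative only: confirm we were handed something that looks like
    # a boto3 "qbusiness" client before calling chat_sync on it.
    if client is None:
        raise ValueError(
            "Amazon Q client is not initialized. Provide a boto3 "
            "'qbusiness' client or let validate_environment create one."
        )
    if not hasattr(client, "chat_sync"):
        raise ValueError(
            "The provided client does not expose chat_sync; expected a "
            "boto3 'qbusiness' client."
        )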

            self._last_response = response

            # Extract the response text
            if 'systemMessage' in response:
                return response["systemMessage"]
            else:
                raise ValueError("Unexpected response format from Amazon Q")

        except Exception as e:
            if "Prompt Length" in str(e):
                logging.info(f"Prompt Length: {len(prompt)}")
                print(f"""Prompt:
                {prompt}""")
                raise ValueError(f"Error raised by Amazon Q service: {e}")

Review comment: Remove this line, as it's redundant with L123. Or change the ValueError message here to specifically reference the prompt length constraint.

            raise ValueError(f"Error raised by Amazon Q service: {e}")
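
The reviewer's second option could look roughly like this (an illustrative sketch, not the committed code; the helper name is hypothetical and the 7000-character figure comes from this review thread):

def _raise_q_error(e: Exception, prompt: str) -> None:
    # Illustrative: a single raise, with the prompt-length constraint
    # called out explicitly when it is the likely culprit.
    if len(prompt) > 7000:  # ChatSync limit cited in this review
        raise ValueError(
            f"Error raised by Amazon Q service: {e}. Note that the "
            f"ChatSync API accepts at most 7000 characters and this "
            f"prompt is {len(prompt)} characters long."
        ) from e
    raise ValueError(f"Error raised by Amazon Q service: {e}") from e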
    def get_last_response(self) -> Dict:
        """Method to access the full response from the last call."""
        return self._last_response

    async def _acall(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Async call to Amazon Q service."""

        def _execute_call():
            return self._call(prompt, stop=stop, **kwargs)

        # Run the synchronous call in a thread pool
        return await asyncio.get_running_loop().run_in_executor(
            None, _execute_call
        )

    @property
    def _identifying_params(self) -> Dict[str, Any]:
        """Get the identifying parameters."""
        return {
            "region_name": self.region_name,
        }

    @model_validator(mode="after")
    def validate_environment(self) -> Self:
        """Validate that AWS credentials and the boto3 package exist in the environment.

        Does nothing if a client is provided externally.
        """
        if self.client is not None:
            return self

        try:
            import boto3

            try:
                if self.credentials_profile_name is not None:
                    session = boto3.Session(profile_name=self.credentials_profile_name)
                else:
                    # use default credentials
                    session = boto3.Session()

                self.client = session.client(
                    "qbusiness", region_name=self.region_name
                )

            except Exception as e:
                raise ValueError(
                    "Could not load credentials to authenticate with AWS client. "
                    "Please check that credentials in the specified "
                    "profile name are valid."
                ) from e

        except ImportError:
            raise ImportError(
                "Could not import boto3 python package. "
                "Please install it with `pip install boto3`."
            )
        return self
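
For orientation, a minimal usage sketch of the wrapper as submitted (the application ID, region, and prompt are placeholder values, and the import path will depend on where this module ends up exposed):

llm = AmazonQ(
    application_id="your-application-id",  # placeholder Amazon Q app ID
    region_name="us-east-1",               # placeholder region
)

print(llm.invoke("Tell me a joke."))

# The wrapper also keeps the full ChatSync response from the most
# recent call (conversation ID, source attributions, etc.):
full_response = llm.get_last_response()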

Review comment: These look like debug print statements. Should they be removed or replaced with logger statements?

Reply: I've been debating that. Currently the Amazon Q ChatSync API only supports a max of 7000 chars. It's helpful to debug in the case that you are passing along additional context with the prompt. I think a good place for that log would be in the error handling block, which I added. Let me know what you think.

Reply: Thanks - let's remove L83/84 in favor of the new changes.
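
As a rough illustration of the direction agreed on above (hypothetical, not the code actually merged), the debug prints could become logger calls confined to the error-handling path:

import logging

logger = logging.getLogger(__name__)

MAX_PROMPT_CHARS = 7000  # ChatSync limit cited in this thread

def _log_prompt_and_raise(e: Exception, prompt: str) -> None:
    # Illustrative: log the prompt length (and the prompt itself at
    # debug level) only when a call fails, instead of printing on
    # every request.
    logger.info(
        "Prompt length: %d (ChatSync max is %d)", len(prompt), MAX_PROMPT_CHARS
    )
    logger.debug("Prompt: %s", prompt)
    raise ValueError(f"Error raised by Amazon Q service: {e}") from e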