-
Notifications
You must be signed in to change notification settings - Fork 116
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Initial commit of Amazon Q Business runnable for langchain #301
base: main
Are you sure you want to change the base?
Changes from all commits
096b24c
c9785c1
40c8306
38665d0
e52d3ee
b7378ea
79c32b5
49e7b4c
99ea5d9
2f0834c
62a926f
802dfb6
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
from langchain_aws.runnables.q_business import AmazonQ | ||
|
||
__all__ = ["AmazonQ"] |
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
@@ -0,0 +1,162 @@ | ||||||
import logging | ||||||
from typing import Any, Dict, Optional | ||||||
|
||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
from langchain_core._api.beta_decorator import beta | ||||||
from langchain_core.runnables import Runnable | ||||||
from langchain_core.runnables.config import RunnableConfig | ||||||
from pydantic import ConfigDict | ||||||
from typing_extensions import Self | ||||||
|
||||||
|
||||||
@beta(message="This API is in beta and can change in future.") | ||||||
class AmazonQ(Runnable[str, str]): | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
"""Amazon Q Runnable wrapper. | ||||||
|
||||||
To authenticate, the AWS client uses the following methods to | ||||||
automatically load credentials: | ||||||
https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html | ||||||
|
||||||
Make sure the credentials / roles used have the required policies to | ||||||
access the Amazon Q service. | ||||||
""" | ||||||
|
||||||
region_name: Optional[str] = None | ||||||
"""AWS region name. If not provided, will be extracted from environment.""" | ||||||
|
||||||
credentials: Optional[Any] = None | ||||||
"""Amazon Q credentials used to instantiate the client if the client is not provided.""" | ||||||
|
||||||
client: Optional[Any] = None | ||||||
"""Amazon Q client.""" | ||||||
|
||||||
application_id: str = None | ||||||
|
||||||
_last_response: Dict = None # Add this to store the full response | ||||||
"""Store the full response from Amazon Q.""" | ||||||
|
||||||
parent_message_id: Optional[str] = None | ||||||
|
||||||
conversation_id: Optional[str] = None | ||||||
|
||||||
chat_mode: str = "RETRIEVAL_MODE" | ||||||
|
||||||
model_config = ConfigDict( | ||||||
extra="forbid", | ||||||
) | ||||||
|
||||||
def __init__( | ||||||
self, | ||||||
region_name: Optional[str] = None, | ||||||
credentials: Optional[Any] = None, | ||||||
client: Optional[Any] = None, | ||||||
application_id: str = None, | ||||||
parent_message_id: Optional[str] = None, | ||||||
conversation_id: Optional[str] = None, | ||||||
chat_mode: str = "RETRIEVAL_MODE", | ||||||
): | ||||||
self.region_name = region_name | ||||||
self.credentials = credentials | ||||||
self.client = client or self.validate_environment() | ||||||
self.application_id = application_id | ||||||
self.parent_message_id = parent_message_id | ||||||
self.conversation_id = conversation_id | ||||||
self.chat_mode = chat_mode | ||||||
|
||||||
def invoke( | ||||||
self, | ||||||
input: Any, | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
config: Optional[RunnableConfig] = None, | ||||||
**kwargs: Any | ||||||
) -> str: | ||||||
"""Call out to Amazon Q service. | ||||||
|
||||||
Args: | ||||||
input: The prompt to pass into the model. | ||||||
|
||||||
Returns: | ||||||
The string generated by the model. | ||||||
|
||||||
Example: | ||||||
.. code-block:: python | ||||||
|
||||||
model = AmazonQ( | ||||||
credentials=your_credentials, | ||||||
application_id=your_app_id | ||||||
) | ||||||
response = model.invoke("Tell me a joke") | ||||||
""" | ||||||
try: | ||||||
# Prepare the request | ||||||
request = { | ||||||
'applicationId': self.application_id, | ||||||
'userMessage': self.convert_langchain_messages_to_q_input(input), # Langchain's input comes in the form of an array of "messages". We must convert to a single string for Amazon Q's use | ||||||
'chatMode': self.chat_mode, | ||||||
} | ||||||
if self.conversation_id: | ||||||
request.update({ | ||||||
'conversationId': self.conversation_id, | ||||||
'parentMessageId': self.parent_message_id, | ||||||
}) | ||||||
|
||||||
# Call Amazon Q | ||||||
response = self.client.chat_sync(**request) | ||||||
self._last_response = response | ||||||
|
||||||
# Extract the response text | ||||||
if 'systemMessage' in response: | ||||||
return response["systemMessage"] | ||||||
else: | ||||||
raise ValueError("Unexpected response format from Amazon Q") | ||||||
|
||||||
except Exception as e: | ||||||
if "Prompt Length" in str(e): | ||||||
logging.info(f"Prompt Length: {len(input)}") | ||||||
print(f"""Prompt: | ||||||
{input}""") | ||||||
raise ValueError(f"Error raised by Amazon Q service: {e}") | ||||||
|
||||||
def get_last_response(self) -> Dict: | ||||||
"""Method to access the full response from the last call""" | ||||||
return self._last_response | ||||||
|
||||||
def validate_environment(self) -> Self: | ||||||
"""Don't do anything if client provided externally""" | ||||||
#If the client is not provided, and the user_id is not provided in the class constructor, throw an error saying one or the other needs to be provided | ||||||
if self.credentials is None: | ||||||
raise ValueError( | ||||||
"Either the credentials or the client needs to be provided." | ||||||
) | ||||||
|
||||||
"""Validate that AWS credentials to and python package exists in environment.""" | ||||||
try: | ||||||
import boto3 | ||||||
|
||||||
try: | ||||||
if self.region_name is not None: | ||||||
client = boto3.client('qbusiness', self.region_name, **self.credentials) | ||||||
else: | ||||||
# use default region | ||||||
client = boto3.client('qbusiness', **self.credentials) | ||||||
|
||||||
except Exception as e: | ||||||
raise ValueError( | ||||||
"Could not load credentials to authenticate with AWS client. " | ||||||
"Please check that credentials in the specified " | ||||||
"profile name are valid." | ||||||
) from e | ||||||
|
||||||
except ImportError: | ||||||
raise ImportError( | ||||||
"Could not import boto3 python package. " | ||||||
"Please install it with `pip install boto3`." | ||||||
) | ||||||
return client | ||||||
def convert_langchain_messages_to_q_input(self, input: Any) -> str: | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Add a conditional in this function to check for and immediately return string type inputs. |
||||||
# Messages must be of type human', 'user', 'ai', 'assistant', or 'system | ||||||
# Instead of logically formulating a message. We will allow langchain users to have their messages | ||||||
# Added line by line the way they ordered them in the chain. We will prefix the content with the type, | ||||||
# Hopefully this will inform Amazon Q how each message in the chain should be interpreted | ||||||
messagesToStringArray = [] | ||||||
for message in input.to_messages(): # Returns List[BaseMessage] | ||||||
messagesToStringArray.append(message.type + ": " + message.content) | ||||||
return "\n".join(messagesToStringArray) | ||||||
Comment on lines
+159
to
+162
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. PromptValue has a But test it out and ensure that the value returned is still in the desired format. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.