Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Initial commit of Amazon Q Business runnable for langchain #301

Open
wants to merge 15 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions libs/aws/langchain_aws/runnables/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from langchain_aws.runnables.q_business import AmazonQ

__all__ = ["AmazonQ"]
167 changes: 167 additions & 0 deletions libs/aws/langchain_aws/runnables/q_business.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,167 @@
from typing import Any, Dict, List, Optional
from langchain_core.callbacks.manager import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun
)
import logging
from langchain_core.language_models import LLM
from pydantic import ConfigDict, model_validator
import json
import asyncio
import boto3
from typing_extensions import Self

class AmazonQ(LLM):
"""Amazon Q LLM wrapper.

To authenticate, the AWS client uses the following methods to
automatically load credentials:
https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html

Make sure the credentials / roles used have the required policies to
access the Amazon Q service.
"""

region_name: Optional[str] = None
"""AWS region name. If not provided, will be extracted from environment."""

credentials: Optional[Any] = None
"""Amazon Q credentials used to instantiate the client if the client is not provided."""

client: Any = None
"""Amazon Q client."""

application_id: str = None
"""Amazon Q client."""

_last_response: Dict = None # Add this to store the full response
"""Store the full response from Amazon Q."""

parent_message_id: Optional[str] = None
"""AWS region name. If not provided, will be extracted from environment."""

conversation_id: Optional[str] = None
"""AWS region name. If not provided, will be extracted from environment."""

chat_mode: str = "RETRIEVAL_MODE"
"""AWS region name. If not provided, will be extracted from environment."""

model_config = ConfigDict(
extra="forbid",
)

@property
def _llm_type(self) -> str:
"""Return type of llm."""
return "amazon_q"
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remove this method for non-LLM class

def __init__(self, **kwargs: Any) -> None:
"""Initialize the Amazon Q client."""
super().__init__(**kwargs)

def _call(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

"""Call out to Amazon Q service.

Args:
prompt: The prompt to pass into the model.
stop: Optional list of stop words to use when generating.

Returns:
The string generated by the model.

Example:
.. code-block:: python

response = llm.invoke("Tell me a joke.")
"""
try:
# Prepare the request
request = {
'applicationId': self.application_id,
'userMessage': prompt,
'chatMode':self.chat_mode,
}
if self.conversation_id:
request.update({
'conversationId': self.conversation_id,
'parentMessageId': self.parent_message_id,
})

# Call Amazon Q
response = self.client.chat_sync(**request)
self._last_response = response

# Extract the response text
if 'systemMessage' in response:
return response["systemMessage"]
else:
raise ValueError("Unexpected response format from Amazon Q")

except Exception as e:
if "Prompt Length" in str(e):
logging.info(f"Prompt Length: {len(prompt)}")
print(f"""Prompt:
{prompt}""")
raise ValueError(f"Error raised by Amazon Q service: {e}")

def get_last_response(self) -> Dict:
"""Method to access the full response from the last call"""
return self._last_response

async def _acall(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str:
"""Async call to Amazon Q service."""

def _execute_call():
return self._call(prompt, stop=stop, **kwargs)

# Run the synchronous call in a thread pool
return await asyncio.get_running_loop().run_in_executor(
None, _execute_call
)

@model_validator(mode="after")
def validate_environment(self) -> Self:
"""Dont do anything if client provided externally"""
if self.client is not None:
return self
michaelnchin marked this conversation as resolved.
Show resolved Hide resolved
#If the client is not provided, and the user_id is not provided in the class constructor, throw an error saying one or the other needs to be provided
if self.credentials is None:
raise ValueError(
"Either the credentials or the client needs to be provided."
)

"""Validate that AWS credentials to and python package exists in environment."""
try:
import boto3

try:
if self.region_name is not None:
self.client = boto3.client('qbusiness', self.region_name, **self.credentials)
michaelnchin marked this conversation as resolved.
Show resolved Hide resolved
else:
# use default region
self.client = boto3.client('qbusiness', **self.credentials)
michaelnchin marked this conversation as resolved.
Show resolved Hide resolved

except Exception as e:
raise ValueError(
"Could not load credentials to authenticate with AWS client. "
"Please check that credentials in the specified "
"profile name are valid."
) from e

except ImportError:
raise ImportError(
"Could not import boto3 python package. "
"Please install it with `pip install boto3`."
)
return self