diff --git a/libs/aws/langchain_aws/chat_models/bedrock.py b/libs/aws/langchain_aws/chat_models/bedrock.py
index 9b0d7c8c..4c27e656 100644
--- a/libs/aws/langchain_aws/chat_models/bedrock.py
+++ b/libs/aws/langchain_aws/chat_models/bedrock.py
@@ -1,5 +1,6 @@
 import logging
 import re
+import warnings
 from collections import defaultdict
 from operator import itemgetter
 from typing import (
@@ -51,7 +52,7 @@
     _combine_generation_info_for_llm_result,
 )
 from langchain_aws.utils import (
-    check_anthropic_tokens_dependencies,
+    anthropic_tokens_supported,
     get_num_tokens_anthropic,
     get_token_ids_anthropic,
 )
@@ -622,29 +623,25 @@ def _combine_llm_outputs(self, llm_outputs: List[Optional[dict]]) -> dict:
         return final_output

     def get_num_tokens(self, text: str) -> int:
-        if self._model_is_anthropic:
-            bad_deps = check_anthropic_tokens_dependencies()
+        if self._model_is_anthropic and not self.custom_get_token_ids:
+            bad_deps = anthropic_tokens_supported()
             if not bad_deps:
                 return get_num_tokens_anthropic(text)
-            else:
-                logger.debug(
-                    "Falling back to default token counting due to incompatible/missing Anthropic dependencies:"
-                )
-                for x in bad_deps:
-                    logger.debug(x)
         return super().get_num_tokens(text)

     def get_token_ids(self, text: str) -> List[int]:
-        if self._model_is_anthropic:
-            bad_deps = check_anthropic_tokens_dependencies()
+        if self._model_is_anthropic and not self.custom_get_token_ids:
+            bad_deps = anthropic_tokens_supported()
             if not bad_deps:
                 return get_token_ids_anthropic(text)
             else:
-                logger.debug(
-                    "Falling back to default token ids retrieval due to incompatible/missing Anthropic dependencies:"
+                warnings.warn(
+                    f"Falling back to default token method due to conflicts with the Anthropic API: {bad_deps}"
+                    f"\n\nFor Anthropic SDK versions > 0.38.0, it is recommended to provide the chat model class with a"
+                    f" custom_get_token_ids method that implements a more accurate tokenizer for Anthropic. "
+                    f"For get_num_tokens, as another alternative, you can implement your own token counter method "
+                    f"using the ChatAnthropic or AnthropicLLM classes."
                 )
-                for x in bad_deps:
-                    logger.debug(x)
         return super().get_token_ids(text)

     def set_system_prompt_with_tools(self, xml_tools_system_prompt: str) -> None:
diff --git a/libs/aws/langchain_aws/llms/bedrock.py b/libs/aws/langchain_aws/llms/bedrock.py
index 05bf936b..83e619ab 100644
--- a/libs/aws/langchain_aws/llms/bedrock.py
+++ b/libs/aws/langchain_aws/llms/bedrock.py
@@ -33,7 +33,7 @@

 from langchain_aws.function_calling import _tools_in_params
 from langchain_aws.utils import (
-    check_anthropic_tokens_dependencies,
+    anthropic_tokens_supported,
     enforce_stop_tokens,
     get_num_tokens_anthropic,
     get_token_ids_anthropic,
@@ -1300,28 +1300,23 @@ async def _acall(
         return "".join([chunk.text for chunk in chunks])

     def get_num_tokens(self, text: str) -> int:
-        if self._model_is_anthropic:
-            bad_deps = check_anthropic_tokens_dependencies()
+        if self._model_is_anthropic and not self.custom_get_token_ids:
+            bad_deps = anthropic_tokens_supported()
             if not bad_deps:
                 return get_num_tokens_anthropic(text)
-            else:
-                logger.debug(
-                    "Falling back to default token counting due to incompatible/missing Anthropic dependencies:"
-                )
-                for x in bad_deps:
-                    logger.debug(x)
-
         return super().get_num_tokens(text)

     def get_token_ids(self, text: str) -> List[int]:
-        if self._model_is_anthropic:
-            bad_deps = check_anthropic_tokens_dependencies()
+        if self._model_is_anthropic and not self.custom_get_token_ids:
+            bad_deps = anthropic_tokens_supported()
             if not bad_deps:
                 return get_token_ids_anthropic(text)
             else:
-                logger.debug(
-                    "Falling back to default token ids retrieval due to incompatible/missing Anthropic dependencies:"
+                warnings.warn(
+                    f"Falling back to default token method due to incompatibilities with the Anthropic API: {bad_deps}"
+                    f"For anthropic versions > 0.38.0, it is recommended to provide a custom_get_token_ids "
+                    f"method to the chat model class that implements the appropriate tokenizer for Anthropic. "
+                    f"Alternately, you can implement your own token counter method using the ChatAnthropic "
+                    f"or AnthropicLLM classes."
                 )
-                for x in bad_deps:
-                    logger.debug(x)
         return super().get_token_ids(text)
diff --git a/libs/aws/langchain_aws/utils.py b/libs/aws/langchain_aws/utils.py
index af2ff78e..2d9b50dd 100644
--- a/libs/aws/langchain_aws/utils.py
+++ b/libs/aws/langchain_aws/utils.py
@@ -1,5 +1,4 @@
 import re
-import sys
 from typing import Any, List

 from packaging import version
@@ -10,14 +9,9 @@ def enforce_stop_tokens(text: str, stop: List[str]) -> str:
     return re.split("|".join(stop), text, maxsplit=1)[0]


-def check_anthropic_tokens_dependencies() -> List[str]:
+def anthropic_tokens_supported() -> List[str]:
     """Check if we have all requirements for Anthropic count_tokens() and get_tokenizer()."""
     bad_deps = []
-
-    python_version = sys.version_info
-    if python_version > (3, 12):
-        bad_deps.append(f"Python 3.12 or earlier required, found {'.'.join(map(str, python_version[:3]))})")
-
     bad_anthropic = None
     try:
         import anthropic
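
The new `and not self.custom_get_token_ids` guard and the warning text both point callers at supplying their own tokenizer. As a rough usage sketch (not part of this patch; the model id, region, and tokenizer body below are placeholders), that could look like:

```python
# Illustrative sketch only. It relies on the `custom_get_token_ids` field that
# ChatBedrock inherits from langchain_core's BaseLanguageModel, which the new
# guard checks before calling the Anthropic SDK token helpers.
from typing import List

from langchain_aws import ChatBedrock


def my_anthropic_token_ids(text: str) -> List[int]:
    # Placeholder tokenizer: byte values stand in for a real
    # Anthropic-appropriate tokenizer you would plug in here.
    return list(text.encode("utf-8"))


llm = ChatBedrock(
    model_id="anthropic.claude-3-sonnet-20240229-v1:0",  # placeholder Anthropic model id
    region_name="us-east-1",  # placeholder region
    custom_get_token_ids=my_anthropic_token_ids,
)

# With custom_get_token_ids set, get_token_ids()/get_num_tokens() fall through to
# the base-class implementation that calls the custom callable, so the fallback
# warning above is never emitted. No Bedrock API call is made for counting.
print(llm.get_num_tokens("hello world"))
```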