Fix Bedrock token count and IDs for Anthropic models #341

Open · wants to merge 7 commits into main
Changes from 4 commits
2 changes: 1 addition & 1 deletion libs/aws/langchain_aws/chains/graph_qa/neptune_cypher.py
@@ -1,7 +1,7 @@
 from __future__ import annotations
 
 import re
-from typing import Any, Optional, Union
+from typing import Optional, Union
 
 from langchain_core.language_models import BaseLanguageModel
 from langchain_core.prompts.base import BasePromptTemplate
30 changes: 24 additions & 6 deletions libs/aws/langchain_aws/chat_models/bedrock.py
@@ -1,3 +1,4 @@
+import logging
 import re
 from collections import defaultdict
 from operator import itemgetter
@@ -50,10 +51,13 @@
     _combine_generation_info_for_llm_result,
 )
 from langchain_aws.utils import (
+    check_anthropic_tokens_dependencies,
     get_num_tokens_anthropic,
     get_token_ids_anthropic,
 )

+logger = logging.getLogger(__name__)


def _convert_one_message_to_text_llama(message: BaseMessage) -> str:
if isinstance(message, ChatMessage):
@@ -619,15 +623,29 @@ def _combine_llm_outputs(self, llm_outputs: List[Optional[dict]]) -> dict:

     def get_num_tokens(self, text: str) -> int:
         if self._model_is_anthropic:
-            return get_num_tokens_anthropic(text)
-        else:
-            return super().get_num_tokens(text)
+            bad_deps = check_anthropic_tokens_dependencies()
+            if not bad_deps:
+                return get_num_tokens_anthropic(text)
+            else:
+                logger.debug(
+                    "Falling back to default token counting due to incompatible/missing Anthropic dependencies:"
+                )
+                for x in bad_deps:
+                    logger.debug(x)
Comment from @3coins (Collaborator), Jan 31, 2025:

Would this cause duplicate logger messages, since super().get_num_tokens(text) ends up calling the get_token_ids method? Also, a debug-level message would fall under the radar for most users unless they have debug logging switched on - should we rather use a warning?

import warnings

if self._model_is_anthropic and not self.custom_get_token_ids:
    warnings.warn(
        "Falling back to default token counting due to incompatible anthropic count_tokens API. "
        "For anthropic versions > 0.38.0, it is recommended to provide a custom_get_token_ids "
        "method to the chat model class that implements the appropriate tokenizer for Anthropic. "
        "Alternately, you can implement your own token counter method using the ChatAnthropic "
        "or AnthropicLLM classes."
    )

Reply from the author:

Thanks for the catch - removed the warning from get_num_tokens, and updated the message as suggested.

+
+        return super().get_num_tokens(text)

     def get_token_ids(self, text: str) -> List[int]:
         if self._model_is_anthropic:
-            return get_token_ids_anthropic(text)
-        else:
-            return super().get_token_ids(text)
+            bad_deps = check_anthropic_tokens_dependencies()
+            if not bad_deps:
+                return get_token_ids_anthropic(text)
+            else:
+                logger.debug(
+                    "Falling back to default token ids retrieval due to incompatible/missing Anthropic dependencies:"
+                )
+                for x in bad_deps:
+                    logger.debug(x)
+        return super().get_token_ids(text)

def set_system_prompt_with_tools(self, xml_tools_system_prompt: str) -> None:
"""Workaround to bind. Sets the system prompt with tools"""
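The fallback logic this diff adds to get_num_tokens and get_token_ids can be sketched in isolation. In this sketch, count_tokens_with_fallback and the stub callables are illustrative names only, not part of the PR:

```python
import logging
from typing import Callable, List

logger = logging.getLogger(__name__)


def count_tokens_with_fallback(
    text: str,
    check_deps: Callable[[], List[str]],    # returns a list of problems; empty means OK
    anthropic_count: Callable[[str], int],  # exact Anthropic tokenizer
    default_count: Callable[[str], int],    # base-class approximation
) -> int:
    # Use the exact tokenizer only when the dependency check passes;
    # otherwise log each problem at debug level and fall back.
    bad_deps = check_deps()
    if not bad_deps:
        return anthropic_count(text)
    logger.debug(
        "Falling back to default token counting due to incompatible/missing Anthropic dependencies:"
    )
    for dep in bad_deps:
        logger.debug(dep)
    return default_count(text)


# Demonstration with stub counters: the dependency check reports a problem,
# so the default (whitespace-split) counter is used instead of the stub tokenizer.
n = count_tokens_with_fallback(
    "one two three",
    check_deps=lambda: ["anthropic<=0.38.0 required, found 0.40.0."],
    anthropic_count=lambda t: 99,
    default_count=lambda t: len(t.split()),
)
print(n)  # 3
```

When the check returns an empty list, the same call routes to the exact tokenizer, mirroring the `if not bad_deps` branch in the diff.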
30 changes: 24 additions & 6 deletions libs/aws/langchain_aws/llms/bedrock.py
@@ -33,11 +33,14 @@

 from langchain_aws.function_calling import _tools_in_params
 from langchain_aws.utils import (
+    check_anthropic_tokens_dependencies,
     enforce_stop_tokens,
     get_num_tokens_anthropic,
     get_token_ids_anthropic,
 )

+logger = logging.getLogger(__name__)

AMAZON_BEDROCK_TRACE_KEY = "amazon-bedrock-trace"
GUARDRAILS_BODY_KEY = "amazon-bedrock-guardrailAction"
HUMAN_PROMPT = "\n\nHuman:"
@@ -1298,12 +1301,27 @@ async def _acall(

     def get_num_tokens(self, text: str) -> int:
         if self._model_is_anthropic:
-            return get_num_tokens_anthropic(text)
-        else:
-            return super().get_num_tokens(text)
+            bad_deps = check_anthropic_tokens_dependencies()
+            if not bad_deps:
+                return get_num_tokens_anthropic(text)
+            else:
+                logger.debug(
+                    "Falling back to default token counting due to incompatible/missing Anthropic dependencies:"
+                )
+                for x in bad_deps:
+                    logger.debug(x)
+
+        return super().get_num_tokens(text)

     def get_token_ids(self, text: str) -> List[int]:
         if self._model_is_anthropic:
-            return get_token_ids_anthropic(text)
-        else:
-            return super().get_token_ids(text)
+            bad_deps = check_anthropic_tokens_dependencies()
+            if not bad_deps:
+                return get_token_ids_anthropic(text)
+            else:
+                logger.debug(
+                    "Falling back to default token ids retrieval due to incompatible/missing Anthropic dependencies:"
+                )
+                for x in bad_deps:
+                    logger.debug(x)
+        return super().get_token_ids(text)
42 changes: 36 additions & 6 deletions libs/aws/langchain_aws/utils.py
@@ -1,21 +1,51 @@
 import re
+import sys
 from typing import Any, List
 
+from packaging import version
 
 
 def enforce_stop_tokens(text: str, stop: List[str]) -> str:
     """Cut off the text as soon as any stop words occur."""
     return re.split("|".join(stop), text, maxsplit=1)[0]
 
 
-def _get_anthropic_client() -> Any:
+def check_anthropic_tokens_dependencies() -> List[str]:
Comment from @3coins (Collaborator):

nit: Would anthropic_tokens_supported be a better name here?

Reply from the author:

Updated this 👍

+    """Check if we have all requirements for Anthropic count_tokens() and get_tokenizer()."""
+    bad_deps = []
+
+    python_version = sys.version_info
+    if python_version > (3, 12):
+        bad_deps.append(
+            f"Python 3.12 or earlier required, found "
+            f"{'.'.join(map(str, python_version[:3]))}"
+        )
+
+    bad_anthropic = None
+    try:
+        import anthropic
+
+        anthropic_version = version.parse(anthropic.__version__)
+        if anthropic_version > version.parse("0.38.0"):
+            bad_anthropic = anthropic_version
+    except ImportError:
+        bad_anthropic = "none installed"
+
+    bad_httpx = None
+    try:
+        import httpx
+
+        httpx_version = version.parse(httpx.__version__)
+        if httpx_version > version.parse("0.27.2"):
+            bad_httpx = httpx_version
     except ImportError:
-        raise ImportError(
-            "Could not import anthropic python package. "
-            "This is needed in order to accurately tokenize the text "
-            "for anthropic models. Please install it with `pip install anthropic`."
-        )
+        bad_httpx = "none installed"
+
+    if bad_anthropic:
+        bad_deps.append(f"anthropic<=0.38.0 required, found {bad_anthropic}.")
+    if bad_httpx:
+        bad_deps.append(f"httpx<=0.27.2 required, found {bad_httpx}.")

Comment from @3coins (Collaborator):

Curious about these other checks here - would only checking the anthropic version not suffice?

Reply from the author:

Regarding bad_deps.append(f"httpx<=0.27.2 required, found {bad_httpx}."):

The langsmith transitive dependency of langchain-aws installs the latest httpx version:

Collecting httpx<1,>=0.23.0 (from langsmith<0.4,>=0.1.125->langchain-core<0.4.0,>=0.3.27->langchain-aws==0.2.11)
  Using cached httpx-0.28.1-py3-none-any.whl.metadata (7.1 kB)

However, this is incompatible with anthropic==0.38.0, which breaks on anything higher than httpx==0.27.2. Unfortunately, anthropic==0.38.0 itself does not cap httpx at the version it requires, and so it may break after picking up the newer version installed for the LangSmith SDK:

Collecting httpx<1,>=0.23.0 (from anthropic==0.38.0)
  Using cached httpx-0.28.1-py3-none-any.whl.metadata (7.1 kB)

So, users must install httpx<=0.27.2 to get the Anthropic tokens methods working.
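As a practical workaround for the conflict described above, both caps can be applied together at install time. A minimal sketch (the version caps are taken from this discussion, not an official recommendation):

```shell
# Pin httpx below the version LangSmith would otherwise pull in,
# alongside the last anthropic release with the count_tokens API.
pip install "anthropic<=0.38.0" "httpx<=0.27.2"
```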

Reply from the author:

Regarding bad_deps.append(f"Python 3.12 or earlier required, found {'.'.join(map(str, python_version[:3]))}"):

anthropic==0.38.0 used to be incompatible with Python 3.13 due to issues with a transitive dependency (pyo3-ffi): anthropics/anthropic-sdk-python#718

But pyo3-ffi was recently updated, and installation on 3.13 appears to be fixed now - this check will be removed, as it is no longer necessary.

+    return bad_deps


 def _get_anthropic_client() -> Any:
     import anthropic
 
     return anthropic.Anthropic()
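The version gates in check_anthropic_tokens_dependencies rely on packaging.version.parse, which compares release segments numerically rather than lexically. A small illustration (is_supported is a hypothetical helper for this sketch, not part of the PR):

```python
from packaging import version

# Numeric comparison: "0.9.0" < "0.27.2" as versions,
# even though a plain string comparison would order them the other way.
assert version.parse("0.9.0") < version.parse("0.27.2")
assert "0.9.0" > "0.27.2"  # lexicographic string comparison disagrees


def is_supported(installed: str, cap: str) -> bool:
    """True when the installed version does not exceed the supported cap."""
    return version.parse(installed) <= version.parse(cap)


print(is_supported("0.27.2", "0.27.2"))  # True  (the cap is inclusive)
print(is_supported("0.28.1", "0.27.2"))  # False (too new, like the httpx case above)
```

This is why the PR parses `anthropic.__version__` and `httpx.__version__` instead of comparing the raw strings.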

