From 06fea7c778ccf36833c8408ce207ffac60a75f1f Mon Sep 17 00:00:00 2001 From: grvvrmtech Date: Thu, 6 Jun 2024 23:44:40 +0530 Subject: [PATCH 1/2] adding bedrock token usage callback handler --- .../callbacks/bedrock_callback.py | 119 ++++++++++++++++++ libs/aws/langchain_aws/callbacks/manager.py | 45 +++++++ libs/aws/langchain_aws/llms/bedrock.py | 8 +- 3 files changed, 168 insertions(+), 4 deletions(-) create mode 100644 libs/aws/langchain_aws/callbacks/bedrock_callback.py create mode 100644 libs/aws/langchain_aws/callbacks/manager.py diff --git a/libs/aws/langchain_aws/callbacks/bedrock_callback.py b/libs/aws/langchain_aws/callbacks/bedrock_callback.py new file mode 100644 index 00000000..daee8dfc --- /dev/null +++ b/libs/aws/langchain_aws/callbacks/bedrock_callback.py @@ -0,0 +1,119 @@ +import threading +from typing import Any, Dict, List, Union + +from langchain_core.callbacks import BaseCallbackHandler +from langchain_core.outputs import LLMResult + +MODEL_COST_PER_1K_INPUT_TOKENS = { + "anthropic.claude-instant-v1": 0.00265, + "anthropic.claude-v2": 0.008, + "anthropic.claude-v2:1": 0.008, + "anthropic.claude-3-sonnet-20240229-v1:0": 0.003, + "anthropic.claude-3-haiku-20240307-v1:0": 0.00025, + "meta.llama3-70b-instruct-v1:0": 0.00265, + "meta.llama3-8b-instruct-v1:0" : 0.00040, + "meta.llama2-13b-chat-v1" : 0.00075, + "meta.llama2-70b-chat-v1" : 0.00195 +} + +MODEL_COST_PER_1K_OUTPUT_TOKENS = { + "anthropic.claude-instant-v1": 0.0035, + "anthropic.claude-v2": 0.024, + "anthropic.claude-v2:1": 0.024, + "anthropic.claude-3-sonnet-20240229-v1:0": 0.015, + "anthropic.claude-3-haiku-20240307-v1:0": 0.00125, + "meta.llama3-70b-instruct-v1:0": 0.0035, + "meta.llama3-8b-instruct-v1:0" : 0.0006, + "meta.llama2-13b-chat-v1" : 0.00100, + "meta.llama2-70b-chat-v1" : 0.00256 +} + + +def _get_token_cost( + prompt_tokens: int, completion_tokens: int, model_id: Union[str, None] +) -> float: + """Get the cost of tokens for the Claude model.""" + if model_id not in MODEL_COST_PER_1K_INPUT_TOKENS: + raise ValueError( + f"Unknown model: {model_id}. Please provide a valid Bedrock model name." 
+ "Known models are: " + ", ".join(MODEL_COST_PER_1K_INPUT_TOKENS.keys()) + ) + return (prompt_tokens / 1000) * MODEL_COST_PER_1K_INPUT_TOKENS[model_id] + ( + completion_tokens / 1000 + ) * MODEL_COST_PER_1K_OUTPUT_TOKENS[model_id] + + +class BedrockTokenUsageCallbackHandler(BaseCallbackHandler): + """Callback Handler that tracks bedrock anthropic info.""" + + total_tokens: int = 0 + prompt_tokens: int = 0 + completion_tokens: int = 0 + successful_requests: int = 0 + total_cost: float = 0.0 + + def __init__(self) -> None: + super().__init__() + self._lock = threading.Lock() + + def __repr__(self) -> str: + return ( + f"Tokens Used: {self.total_tokens}\n" + f"\tPrompt Tokens: {self.prompt_tokens}\n" + f"\tCompletion Tokens: {self.completion_tokens}\n" + f"Successful Requests: {self.successful_requests}\n" + f"Total Cost (USD): ${self.total_cost}" + ) + + @property + def always_verbose(self) -> bool: + """Whether to call verbose callbacks even if verbose is False.""" + return True + + def on_llm_start( + self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any + ) -> None: + """Print out the prompts.""" + pass + + def on_llm_new_token(self, token: str, **kwargs: Any) -> None: + """Print out the token.""" + pass + + def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None: + """Collect token usage.""" + if response.llm_output is None: + return None + + if "usage" not in response.llm_output: + with self._lock: + self.successful_requests += 1 + return None + + # compute tokens and cost for this request + token_usage = response.llm_output["usage"] + completion_tokens = token_usage.get("completion_tokens", 0) + prompt_tokens = token_usage.get("prompt_tokens", 0) + total_tokens = token_usage.get("total_tokens", 0) + model_id = response.llm_output.get("model_id", None) + total_cost = _get_token_cost( + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + model_id=model_id, + ) + + # update shared state behind lock + with self._lock: + self.total_cost += total_cost + self.total_tokens += total_tokens + self.prompt_tokens += prompt_tokens + self.completion_tokens += completion_tokens + self.successful_requests += 1 + + def __copy__(self) -> "BedrockTokenUsageCallbackHandler": + """Return a copy of the callback handler.""" + return self + + def __deepcopy__(self, memo: Any) -> "BedrockTokenUsageCallbackHandler": + """Return a deep copy of the callback handler.""" + return self diff --git a/libs/aws/langchain_aws/callbacks/manager.py b/libs/aws/langchain_aws/callbacks/manager.py new file mode 100644 index 00000000..962f6d82 --- /dev/null +++ b/libs/aws/langchain_aws/callbacks/manager.py @@ -0,0 +1,45 @@ +from __future__ import annotations + +import logging +from contextlib import contextmanager +from contextvars import ContextVar +from typing import ( + Generator, + Optional, +) + +from langchain_core.tracers.context import register_configure_hook + +from langchain_aws.callbacks.bedrock_callback import ( + BedrockTokenUsageCallbackHandler, +) + +logger = logging.getLogger(__name__) + + +bedrock_callback_var: (ContextVar)[ + Optional[BedrockTokenUsageCallbackHandler] +] = ContextVar("bedrock_anthropic_callback", default=None) + +register_configure_hook(bedrock_callback_var, True) + + +@contextmanager +def get_bedrock_callback() -> ( + Generator[BedrockTokenUsageCallbackHandler, None, None] +): + """Get the Bedrock callback handler in a context manager. + which conveniently exposes token and cost information. 
+ + Returns: + BedrockTokenUsageCallbackHandler: + The Bedrock callback handler. + + Example: + >>> with get_bedrock_callback() as cb: + ... # Use the Bedrock callback handler + """ + cb = BedrockTokenUsageCallbackHandler() + bedrock_callback_var.set(cb) + yield cb + bedrock_callback_var.set(None) \ No newline at end of file diff --git a/libs/aws/langchain_aws/llms/bedrock.py b/libs/aws/langchain_aws/llms/bedrock.py index c20e3365..181ea37e 100644 --- a/libs/aws/langchain_aws/llms/bedrock.py +++ b/libs/aws/langchain_aws/llms/bedrock.py @@ -136,7 +136,7 @@ def _nest_usage_info_token_counts(usage_info: dict) -> dict: def _combine_generation_info_for_llm_result( - chunks_generation_info: List[Dict[str, Any]], provider_stop_code: str + chunks_generation_info: List[Dict[str, Any]], provider_stop_code: str, model_id: str ) -> Dict[str, Any]: """ Returns usage and stop reason information with the intent to pack into an LLMResult @@ -171,7 +171,7 @@ def _combine_generation_info_for_llm_result( total_usage_info["prompt_tokens"] + total_usage_info["completion_tokens"] ) - return {"usage": total_usage_info, "stop_reason": stop_reason} + return {"usage": total_usage_info, "stop_reason": stop_reason, "model_id" : model_id} class LLMInputOutputAdapter: @@ -939,7 +939,7 @@ def _call( if chunk.generation_info is not None ] llm_output = _combine_generation_info_for_llm_result( - chunks_generation_info, provider_stop_code=provider_stop_reason_code + chunks_generation_info, provider_stop_code=provider_stop_reason_code, model_id=self.model_id ) all_generations = [ Generation(text=chunk.text, generation_info=chunk.generation_info) @@ -1031,7 +1031,7 @@ async def _acall( if chunk.generation_info is not None ] llm_output = _combine_generation_info_for_llm_result( - chunks_generation_info, provider_stop_code=provider_stop_reason_code + chunks_generation_info, provider_stop_code=provider_stop_reason_code, model_id=self.model_id ) generations = [ Generation(text=chunk.text, generation_info=chunk.generation_info) From ed6d68d21fb50fdebfdc7b67e62b20e7109b74ba Mon Sep 17 00:00:00 2001 From: grvvrmtech Date: Fri, 7 Jun 2024 12:09:10 +0530 Subject: [PATCH 2/2] adding bug fix in non stream LLM call --- libs/aws/langchain_aws/llms/bedrock.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/libs/aws/langchain_aws/llms/bedrock.py b/libs/aws/langchain_aws/llms/bedrock.py index 181ea37e..0376c3c6 100644 --- a/libs/aws/langchain_aws/llms/bedrock.py +++ b/libs/aws/langchain_aws/llms/bedrock.py @@ -631,7 +631,7 @@ def _prepare_input_and_invoke( if stop is not None: text = enforce_stop_tokens(text, stop) - llm_output = {"usage": usage_info, "stop_reason": stop_reason} + llm_output = {"usage": usage_info, "stop_reason": stop_reason, "model_id": self.model_id} # Verify and raise a callback error if any intervention occurs or a signal is # sent from a Bedrock service,
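
Usage note (appended for reviewers, not part of the patch): the handler introduced above prices requests from the per-1K token tables, so, for example, a single "anthropic.claude-v2" call that consumes 1,000 prompt tokens and 500 completion tokens would be costed at (1000/1000) * 0.008 + (500/1000) * 0.024 = 0.02 USD. Below is a minimal sketch of how the new get_bedrock_callback() context manager is intended to be used. It assumes BedrockLLM is importable from langchain_aws.llms and that AWS credentials and a region are already configured in the environment; the model id is only an illustrative choice from the cost tables.

    from langchain_aws.callbacks.manager import get_bedrock_callback
    from langchain_aws.llms import BedrockLLM

    # Illustrative model choice: any model id present in the
    # MODEL_COST_PER_1K_INPUT_TOKENS table can be costed.
    llm = BedrockLLM(model_id="anthropic.claude-v2")

    with get_bedrock_callback() as cb:
        llm.invoke("Tell me a joke")
        # The handler's on_llm_end should now have read the "usage" and
        # "model_id" keys that bedrock.py packs into llm_output, so the
        # accumulated totals are populated.
        print(cb.prompt_tokens, cb.completion_tokens, cb.total_cost)

Because register_configure_hook registers the context variable, Bedrock calls made inside the with block should be tracked without passing the handler explicitly; a model id missing from the cost tables raises the ValueError defined in _get_token_cost.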