From dd9b6815128e7d6211169b392bbb1e8593485bcc Mon Sep 17 00:00:00 2001
From: KaiJye
Date: Thu, 30 May 2024 15:37:23 +0800
Subject: [PATCH] Fix Cohere request/response handling in Bedrock LLM adapter

---
 libs/aws/langchain_aws/llms/bedrock.py | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/libs/aws/langchain_aws/llms/bedrock.py b/libs/aws/langchain_aws/llms/bedrock.py
index c20e3365..a77bc20d 100644
--- a/libs/aws/langchain_aws/llms/bedrock.py
+++ b/libs/aws/langchain_aws/llms/bedrock.py
@@ -211,8 +211,10 @@ def prepare_input(
                 input_body["prompt"] = _human_assistant_format(prompt)
                 if "max_tokens_to_sample" not in input_body:
                     input_body["max_tokens_to_sample"] = 1024
-        elif provider in ("ai21", "cohere", "meta", "mistral"):
+        elif provider in ("ai21", "meta", "mistral"):
             input_body["prompt"] = prompt
+        elif provider == "cohere":
+            input_body["message"] = prompt
         elif provider == "amazon":
             input_body = dict()
             input_body["inputText"] = prompt
@@ -238,7 +240,7 @@ def prepare_output(cls, provider: str, response: Any) -> dict:
         if provider == "ai21":
             text = response_body.get("completions")[0].get("data").get("text")
         elif provider == "cohere":
-            text = response_body.get("generations")[0].get("text")
+            text = response_body.get("text")
         elif provider == "meta":
             text = response_body.get("generation")
         elif provider == "mistral":