From 4f58ed35496c39916b3a9e085fca4d60efe1415f Mon Sep 17 00:00:00 2001 From: isaac hershenson Date: Wed, 22 Jan 2025 17:05:11 -0800 Subject: [PATCH 1/3] draft --- libs/core/langchain_core/prompts/chat.py | 10 +++++- .../tests/unit_tests/prompts/test_chat.py | 35 +++++++++++++++++++ 2 files changed, 44 insertions(+), 1 deletion(-) diff --git a/libs/core/langchain_core/prompts/chat.py b/libs/core/langchain_core/prompts/chat.py index 8133d3b3687e9..9be53f915dfbd 100644 --- a/libs/core/langchain_core/prompts/chat.py +++ b/libs/core/langchain_core/prompts/chat.py @@ -1461,7 +1461,15 @@ def _convert_to_message( _message = _create_template_from_message_type( "human", message, template_format=template_format ) - elif isinstance(message, tuple): + elif isinstance(message, (tuple, dict)): + if isinstance(message, dict): + if sorted(message.keys()) != ["content", "role"]: + msg = ( + "Expected dict to have keys 'role' and 'content'." + f" Got: {message}" + ) + raise ValueError(msg) + message = (message["role"], message["content"]) if len(message) != 2: msg = f"Expected 2-tuple of (role, template), got {message}" raise ValueError(msg) diff --git a/libs/core/tests/unit_tests/prompts/test_chat.py b/libs/core/tests/unit_tests/prompts/test_chat.py index cad31d03ef929..460f53ace3af5 100644 --- a/libs/core/tests/unit_tests/prompts/test_chat.py +++ b/libs/core/tests/unit_tests/prompts/test_chat.py @@ -831,6 +831,41 @@ def test_chat_prompt_message_placeholder_tuple() -> None: assert optional_prompt.format_messages() == [] +def test_chat_prompt_message_placeholder_dict() -> None: + prompt = ChatPromptTemplate([{"role": "placeholder", "content": "{convo}"}]) + assert prompt.format_messages(convo=[("user", "foo")]) == [ + HumanMessage(content="foo") + ] + + assert prompt.format_messages() == [] + + # Is optional = True + optional_prompt = ChatPromptTemplate( + [{"role": "placeholder", "content": ["{convo}", False]}] + ) + assert optional_prompt.format_messages(convo=[("user", "foo")]) == [ + HumanMessage(content="foo") + ] + with pytest.raises(KeyError): + assert optional_prompt.format_messages() == [] + + +def test_chat_prompt_message_dict() -> None: + prompt = ChatPromptTemplate( + [{"role": "system", "content": "foo"}, {"role": "user", "content": "bar"}] + ) + assert prompt.format_messages() == [ + SystemMessage(content="foo"), + HumanMessage(content="bar"), + ] + + with pytest.raises(ValueError): + ChatPromptTemplate([{"role": "system", "content": False}]) + + with pytest.raises(ValueError): + ChatPromptTemplate([{"role": "foo", "content": "foo"}]) + + async def test_messages_prompt_accepts_list() -> None: prompt = ChatPromptTemplate([MessagesPlaceholder("history")]) value = prompt.invoke([("user", "Hi there")]) # type: ignore From ac27b293e4e68c09b5144dcf71f26c97d569ee39 Mon Sep 17 00:00:00 2001 From: Isaac Francisco <78627776+isahers1@users.noreply.github.com> Date: Wed, 22 Jan 2025 17:15:27 -0800 Subject: [PATCH 2/3] Update libs/core/langchain_core/prompts/chat.py Co-authored-by: Erick Friis --- libs/core/langchain_core/prompts/chat.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/libs/core/langchain_core/prompts/chat.py b/libs/core/langchain_core/prompts/chat.py index 9be53f915dfbd..8e0dc1a8967b1 100644 --- a/libs/core/langchain_core/prompts/chat.py +++ b/libs/core/langchain_core/prompts/chat.py @@ -1463,7 +1463,7 @@ def _convert_to_message( ) elif isinstance(message, (tuple, dict)): if isinstance(message, dict): - if sorted(message.keys()) != ["content", "role"]: + if set(message.keys()) != {"content", "role"}: msg = ( "Expected dict to have keys 'role' and 'content'." f" Got: {message}" From 7418d89416b513fa6b134cfc73373d6470ea832f Mon Sep 17 00:00:00 2001 From: Isaac Francisco <78627776+isahers1@users.noreply.github.com> Date: Wed, 22 Jan 2025 17:15:37 -0800 Subject: [PATCH 3/3] Update libs/core/langchain_core/prompts/chat.py Co-authored-by: Erick Friis --- libs/core/langchain_core/prompts/chat.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/libs/core/langchain_core/prompts/chat.py b/libs/core/langchain_core/prompts/chat.py index 8e0dc1a8967b1..84b80f565a160 100644 --- a/libs/core/langchain_core/prompts/chat.py +++ b/libs/core/langchain_core/prompts/chat.py @@ -1465,7 +1465,7 @@ def _convert_to_message( if isinstance(message, dict): if set(message.keys()) != {"content", "role"}: msg = ( - "Expected dict to have keys 'role' and 'content'." + "Expected dict to have exact keys 'role' and 'content'." f" Got: {message}" ) raise ValueError(msg)