From b3ae46a2051347c05aa4f7af65d306ccd838af8a Mon Sep 17 00:00:00 2001
From: Michael Chin
Date: Thu, 16 Jan 2025 07:29:39 -0800
Subject: [PATCH] Add string input support for revised Neptune chains (#329)

Updated `create_neptune_opencypher_qa_chain` and `create_neptune_sparql_qa_chain` to accept plain string queries on `invoke`, in addition to the existing dict format. This restores consistency with the input format of the older `langchain-community` Neptune chains.
---
 .../langchain_aws/chains/graph_qa/neptune_cypher.py | 12 +++++++++---
 .../langchain_aws/chains/graph_qa/neptune_sparql.py | 12 +++++++++---
 2 files changed, 18 insertions(+), 6 deletions(-)

diff --git a/libs/aws/langchain_aws/chains/graph_qa/neptune_cypher.py b/libs/aws/langchain_aws/chains/graph_qa/neptune_cypher.py
index e191e174..8c6a35a3 100644
--- a/libs/aws/langchain_aws/chains/graph_qa/neptune_cypher.py
+++ b/libs/aws/langchain_aws/chains/graph_qa/neptune_cypher.py
@@ -1,7 +1,7 @@
 from __future__ import annotations
 
 import re
-from typing import Any, Optional
+from typing import Any, Optional, Union
 
 from langchain_core.language_models import BaseLanguageModel
 from langchain_core.prompts.base import BasePromptTemplate
@@ -90,7 +90,7 @@ def create_neptune_opencypher_qa_chain(
     return_direct: bool = False,
     extra_instructions: Optional[str] = None,
     allow_dangerous_requests: bool = False,
-) -> Runnable[dict[str, Any], dict]:
+) -> Runnable:
     """Chain for question-answering against a Neptune graph
     by generating openCypher statements.
 
@@ -133,6 +133,11 @@ def create_neptune_opencypher_qa_chain(
     _cypher_prompt = cypher_prompt or get_prompt(llm)
     cypher_generation_chain = _cypher_prompt | llm
 
+    def normalize_input(raw_input: Union[str, dict]) -> dict:
+        if isinstance(raw_input, str):
+            return {"query": raw_input}
+        return raw_input
+
     def execute_graph_query(cypher_query: str) -> dict:
         return graph.query(cypher_query)
 
@@ -164,7 +169,8 @@ def format_response(inputs: dict) -> dict:
         return final_response
 
     chain_result = (
-        RunnablePassthrough.assign(cypher_generation_inputs=get_cypher_inputs)
+        normalize_input
+        | RunnablePassthrough.assign(cypher_generation_inputs=get_cypher_inputs)
         | {
             "query": lambda x: x["query"],
             "cypher": (lambda x: x["cypher_generation_inputs"])
diff --git a/libs/aws/langchain_aws/chains/graph_qa/neptune_sparql.py b/libs/aws/langchain_aws/chains/graph_qa/neptune_sparql.py
index 69b43d66..65bc0782 100644
--- a/libs/aws/langchain_aws/chains/graph_qa/neptune_sparql.py
+++ b/libs/aws/langchain_aws/chains/graph_qa/neptune_sparql.py
@@ -4,7 +4,7 @@
 
 from __future__ import annotations
 
-from typing import Any, Optional
+from typing import Any, Optional, Union
 
 from langchain_core.language_models import BaseLanguageModel
 from langchain_core.prompts.base import BasePromptTemplate
@@ -64,7 +64,7 @@ def create_neptune_sparql_qa_chain(
     extra_instructions: Optional[str] = None,
     allow_dangerous_requests: bool = False,
     examples: Optional[str] = None,
-) -> Runnable[dict[str, Any], dict]:
+) -> Runnable[Any, dict]:
     """Chain for question-answering against a Neptune graph
     by generating SPARQL statements.
 
@@ -106,6 +106,11 @@ def create_neptune_sparql_qa_chain(
     _sparql_prompt = sparql_prompt or get_prompt(examples)
     sparql_generation_chain = _sparql_prompt | llm
 
+    def normalize_input(raw_input: Union[str, dict]) -> dict:
+        if isinstance(raw_input, str):
+            return {"query": raw_input}
+        return raw_input
+
     def execute_graph_query(sparql_query: str) -> dict:
         return graph.query(sparql_query)
 
@@ -137,7 +142,8 @@ def format_response(inputs: dict) -> dict:
         return final_response
 
     chain_result = (
-        RunnablePassthrough.assign(sparql_generation_inputs=get_sparql_inputs)
+        normalize_input
+        | RunnablePassthrough.assign(sparql_generation_inputs=get_sparql_inputs)
         | {
             "query": lambda x: x["query"],
            "sparql": (lambda x: x["sparql_generation_inputs"])
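
Usage sketch of the new input handling (illustrative only, not part of the patch): the Neptune endpoint, Bedrock model id, and constructor arguments below are placeholders, and it assumes `NeptuneGraph` and the chain factory are importable from `langchain_aws` as in the current package layout.

```python
from langchain_aws import ChatBedrockConverse
from langchain_aws.chains import create_neptune_opencypher_qa_chain
from langchain_aws.graphs import NeptuneGraph

# Placeholder endpoint and model id; substitute your own cluster and model.
graph = NeptuneGraph(host="<neptune-cluster-endpoint>", port=8182)
llm = ChatBedrockConverse(model="anthropic.claude-3-5-sonnet-20240620-v1:0")

chain = create_neptune_opencypher_qa_chain(
    llm=llm,
    graph=graph,
    allow_dangerous_requests=True,
)

# Both invocation styles are accepted after this change:
result_from_str = chain.invoke("How many airports are in the graph?")
result_from_dict = chain.invoke({"query": "How many airports are in the graph?"})
```

The `normalize_input` step at the head of each chain simply wraps a bare string as `{"query": ...}`, so existing dict-based callers are unaffected.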