Skip to content

Commit

Permalink
Add string input support for revised Neptune chains (#329)
Browse files Browse the repository at this point in the history
Updated `create_neptune_opencypher_qa_chain` and
`create_neptune_sparql_qa_chain` to accept base string type queries on
invoke, in addition to the current dict format.

This restores consistency with the input format of the older
`langchain-community` Neptune chains.
  • Loading branch information
michaelnchin authored Jan 16, 2025
1 parent 38c28fa commit b3ae46a
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 6 deletions.
12 changes: 9 additions & 3 deletions libs/aws/langchain_aws/chains/graph_qa/neptune_cypher.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

import re
from typing import Any, Optional
from typing import Any, Optional, Union

from langchain_core.language_models import BaseLanguageModel
from langchain_core.prompts.base import BasePromptTemplate
Expand Down Expand Up @@ -90,7 +90,7 @@ def create_neptune_opencypher_qa_chain(
return_direct: bool = False,
extra_instructions: Optional[str] = None,
allow_dangerous_requests: bool = False,
) -> Runnable[dict[str, Any], dict]:
) -> Runnable:
"""Chain for question-answering against a Neptune graph
by generating openCypher statements.
Expand Down Expand Up @@ -133,6 +133,11 @@ def create_neptune_opencypher_qa_chain(
_cypher_prompt = cypher_prompt or get_prompt(llm)
cypher_generation_chain = _cypher_prompt | llm

def normalize_input(raw_input: Union[str, dict]) -> dict:
if isinstance(raw_input, str):
return {"query": raw_input}
return raw_input

def execute_graph_query(cypher_query: str) -> dict:
return graph.query(cypher_query)

Expand Down Expand Up @@ -164,7 +169,8 @@ def format_response(inputs: dict) -> dict:
return final_response

chain_result = (
RunnablePassthrough.assign(cypher_generation_inputs=get_cypher_inputs)
normalize_input
| RunnablePassthrough.assign(cypher_generation_inputs=get_cypher_inputs)
| {
"query": lambda x: x["query"],
"cypher": (lambda x: x["cypher_generation_inputs"])
Expand Down
12 changes: 9 additions & 3 deletions libs/aws/langchain_aws/chains/graph_qa/neptune_sparql.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from __future__ import annotations

from typing import Any, Optional
from typing import Any, Optional, Union

from langchain_core.language_models import BaseLanguageModel
from langchain_core.prompts.base import BasePromptTemplate
Expand Down Expand Up @@ -64,7 +64,7 @@ def create_neptune_sparql_qa_chain(
extra_instructions: Optional[str] = None,
allow_dangerous_requests: bool = False,
examples: Optional[str] = None,
) -> Runnable[dict[str, Any], dict]:
) -> Runnable[Any, dict]:
"""Chain for question-answering against a Neptune graph
by generating SPARQL statements.
Expand Down Expand Up @@ -106,6 +106,11 @@ def create_neptune_sparql_qa_chain(
_sparql_prompt = sparql_prompt or get_prompt(examples)
sparql_generation_chain = _sparql_prompt | llm

def normalize_input(raw_input: Union[str, dict]) -> dict:
if isinstance(raw_input, str):
return {"query": raw_input}
return raw_input

def execute_graph_query(sparql_query: str) -> dict:
return graph.query(sparql_query)

Expand Down Expand Up @@ -137,7 +142,8 @@ def format_response(inputs: dict) -> dict:
return final_response

chain_result = (
RunnablePassthrough.assign(sparql_generation_inputs=get_sparql_inputs)
normalize_input
| RunnablePassthrough.assign(sparql_generation_inputs=get_sparql_inputs)
| {
"query": lambda x: x["query"],
"sparql": (lambda x: x["sparql_generation_inputs"])
Expand Down

0 comments on commit b3ae46a

Please sign in to comment.