Add string input support for revised Neptune chains #329

Merged 1 commit on Jan 16, 2025
libs/aws/langchain_aws/chains/graph_qa/neptune_cypher.py (12 changes: 9 additions, 3 deletions)

@@ -1,7 +1,7 @@
 from __future__ import annotations
 
 import re
-from typing import Any, Optional
+from typing import Any, Optional, Union
 
 from langchain_core.language_models import BaseLanguageModel
 from langchain_core.prompts.base import BasePromptTemplate
@@ -90,7 +90,7 @@ def create_neptune_opencypher_qa_chain(
     return_direct: bool = False,
     extra_instructions: Optional[str] = None,
     allow_dangerous_requests: bool = False,
-) -> Runnable[dict[str, Any], dict]:
+) -> Runnable:
     """Chain for question-answering against a Neptune graph
     by generating openCypher statements.
 
@@ -133,6 +133,11 @@ def create_neptune_opencypher_qa_chain(
     _cypher_prompt = cypher_prompt or get_prompt(llm)
     cypher_generation_chain = _cypher_prompt | llm
 
+    def normalize_input(raw_input: Union[str, dict]) -> dict:
+        if isinstance(raw_input, str):
+            return {"query": raw_input}
+        return raw_input
+
     def execute_graph_query(cypher_query: str) -> dict:
         return graph.query(cypher_query)
 
@@ -164,7 +169,8 @@ def format_response(inputs: dict) -> dict:
         return final_response
 
     chain_result = (
-        RunnablePassthrough.assign(cypher_generation_inputs=get_cypher_inputs)
+        normalize_input
+        | RunnablePassthrough.assign(cypher_generation_inputs=get_cypher_inputs)
         | {
             "query": lambda x: x["query"],
             "cypher": (lambda x: x["cypher_generation_inputs"])
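With normalize_input placed at the head of the chain, the openCypher QA chain accepts either a bare string or the existing {"query": ...} dict. Below is a minimal usage sketch, not part of this PR: the ChatBedrockConverse/NeptuneGraph import paths, the model id, and the endpoint placeholder are assumptions for illustration.

# Usage sketch (not part of this diff). Import paths, model id, and
# NeptuneGraph constructor arguments below are assumptions; adjust to your setup.
from langchain_aws import ChatBedrockConverse
from langchain_aws.chains.graph_qa.neptune_cypher import (
    create_neptune_opencypher_qa_chain,
)
from langchain_aws.graphs import NeptuneGraph

llm = ChatBedrockConverse(model="anthropic.claude-3-5-sonnet-20240620-v1:0")
graph = NeptuneGraph(host="<your-neptune-endpoint>", port=8182)

chain = create_neptune_opencypher_qa_chain(
    llm=llm,
    graph=graph,
    allow_dangerous_requests=True,
)

# Both forms reach normalize_input; the bare string is wrapped as {"query": ...}.
chain.invoke("How many airports are in the graph?")
chain.invoke({"query": "How many airports are in the graph?"})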
libs/aws/langchain_aws/chains/graph_qa/neptune_sparql.py (12 changes: 9 additions, 3 deletions)

@@ -4,7 +4,7 @@
 
 from __future__ import annotations
 
-from typing import Any, Optional
+from typing import Any, Optional, Union
 
 from langchain_core.language_models import BaseLanguageModel
 from langchain_core.prompts.base import BasePromptTemplate
@@ -64,7 +64,7 @@ def create_neptune_sparql_qa_chain(
     extra_instructions: Optional[str] = None,
     allow_dangerous_requests: bool = False,
     examples: Optional[str] = None,
-) -> Runnable[dict[str, Any], dict]:
+) -> Runnable[Any, dict]:
     """Chain for question-answering against a Neptune graph
     by generating SPARQL statements.
 
@@ -106,6 +106,11 @@ def create_neptune_sparql_qa_chain(
     _sparql_prompt = sparql_prompt or get_prompt(examples)
     sparql_generation_chain = _sparql_prompt | llm
 
+    def normalize_input(raw_input: Union[str, dict]) -> dict:
+        if isinstance(raw_input, str):
+            return {"query": raw_input}
+        return raw_input
+
     def execute_graph_query(sparql_query: str) -> dict:
         return graph.query(sparql_query)
 
@@ -137,7 +142,8 @@ def format_response(inputs: dict) -> dict:
         return final_response
 
     chain_result = (
-        RunnablePassthrough.assign(sparql_generation_inputs=get_sparql_inputs)
+        normalize_input
+        | RunnablePassthrough.assign(sparql_generation_inputs=get_sparql_inputs)
        | {
            "query": lambda x: x["query"],
            "sparql": (lambda x: x["sparql_generation_inputs"])
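The SPARQL chain gets the same treatment. A short sketch under the same caveats, not part of this PR: the NeptuneRdfGraph and ChatBedrockConverse import paths and constructor arguments are assumptions.

# Usage sketch (not part of this diff). Import paths and arguments are assumptions.
from langchain_aws import ChatBedrockConverse
from langchain_aws.chains.graph_qa.neptune_sparql import (
    create_neptune_sparql_qa_chain,
)
from langchain_aws.graphs import NeptuneRdfGraph

llm = ChatBedrockConverse(model="anthropic.claude-3-5-sonnet-20240620-v1:0")
rdf_graph = NeptuneRdfGraph(host="<your-neptune-endpoint>", port=8182)

sparql_chain = create_neptune_sparql_qa_chain(
    llm=llm,
    graph=rdf_graph,
    allow_dangerous_requests=True,
)

# String and dict inputs are both normalized to {"query": ...}.
sparql_chain.invoke("What classes are defined in the ontology?")
sparql_chain.invoke({"query": "What classes are defined in the ontology?"})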