From 804b3cd4935ecc4974a96ad26ddfb56420df73fb Mon Sep 17 00:00:00 2001 From: Vadym Barda Date: Thu, 17 Oct 2024 17:40:30 -0400 Subject: [PATCH] fix: reset docs on each interaction (#387) --- backend/retrieval_graph/graph.py | 2 +- backend/utils.py | 6 +++++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/backend/retrieval_graph/graph.py b/backend/retrieval_graph/graph.py index 7489d7e0..02b1d30c 100644 --- a/backend/retrieval_graph/graph.py +++ b/backend/retrieval_graph/graph.py @@ -145,7 +145,7 @@ class Plan(TypedDict): {"role": "system", "content": configuration.research_plan_system_prompt} ] + state.messages response = cast(Plan, await model.ainvoke(messages)) - return {"steps": response["steps"]} + return {"steps": response["steps"], "documents": "delete"} async def conduct_research(state: AgentState) -> dict[str, Any]: diff --git a/backend/utils.py b/backend/utils.py index 124a38d4..d4dbc5bf 100644 --- a/backend/utils.py +++ b/backend/utils.py @@ -6,7 +6,7 @@ """ import uuid -from typing import Any, Optional, Union +from typing import Any, Literal, Optional, Union from langchain.chat_models import init_chat_model from langchain_core.documents import Document @@ -89,6 +89,7 @@ def reduce_docs( list[dict[str, Any]], list[str], str, + Literal["delete"], ], ) -> list[Document]: """Reduce and process documents based on the input type. @@ -101,6 +102,9 @@ def reduce_docs( new (Union[Sequence[Document], Sequence[dict[str, Any]], Sequence[str], str, Literal["delete"]]): The new input to process. Can be a sequence of Documents, dictionaries, strings, or a single string. """ + if new == "delete": + return [] + existing_list = list(existing) if existing else [] if isinstance(new, str): return existing_list + [