Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

rfc zep search mem #9135

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 35 additions & 2 deletions libs/langchain/langchain/memory/zep_memory.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
from __future__ import annotations

from typing import Any, Dict, Optional
from typing import Any, Dict, List, Optional

from langchain.memory import ConversationBufferMemory
from langchain.memory.chat_memory import BaseChatMemory
from langchain.memory.chat_message_histories import ZepChatMessageHistory
from langchain.memory.utils import get_prompt_input_key


class ZepMemory(ConversationBufferMemory):
class _ZepMemory(BaseChatMemory):
"""Persist your chain history to the Zep Memory Server.

The number of messages returned by Zep and when the Zep server summarizes chat
Expand Down Expand Up @@ -51,6 +53,7 @@ class ZepMemory(ConversationBufferMemory):
"""

chat_memory: ZepChatMessageHistory
memory_key: str = "history" #: :meta private:

def __init__(
self,
Expand Down Expand Up @@ -102,6 +105,14 @@ def __init__(
memory_key=memory_key,
)

@property
def memory_variables(self) -> List[str]:
"""Will always return list of memory variables.

:meta private:
"""
return [self.memory_key]

def save_context(
self,
inputs: Dict[str, Any],
Expand All @@ -122,3 +133,25 @@ def save_context(
input_str, output_str = self._get_input_output(inputs, outputs)
self.chat_memory.add_user_message(input_str, metadata=metadata)
self.chat_memory.add_ai_message(output_str, metadata=metadata)


class ZepSearchMemory(_ZepMemory):
top_k: int = 4

def _get_prompt_input_key(self, inputs: Dict[str, Any]) -> str:
"""Get the input key for the prompt."""
if self.input_key is None:
return get_prompt_input_key(inputs, self.memory_variables)
return self.input_key

def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, str]:
"""Return history buffer."""
input_key = self._get_prompt_input_key(inputs)
query = inputs[input_key]
results = self.chat_memory.search(query, limit=self.top_k)
result = "\n".join([r.message.pop("content") for r in results])
return {self.memory_key: result}


class ZepBufferMemory(_ZepMemory, ConversationBufferMemory):
""""""
2 changes: 1 addition & 1 deletion libs/langchain/langchain/retrievers/zep.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def create_client(cls, values: dict) -> dict:
try:
from zep_python import ZepClient
except ImportError:
raise ValueError(
raise ImportError(
"Could not import zep-python package. "
"Please install it with `pip install zep-python`."
)
Expand Down