Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Memory chat #1726

Open
wants to merge 15 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
106 changes: 88 additions & 18 deletions camel/agents/chat_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,10 @@
import json
import logging
import textwrap
import uuid
from collections import defaultdict
from datetime import datetime
from pathlib import Path
from typing import (
TYPE_CHECKING,
Any,
Expand Down Expand Up @@ -59,6 +62,7 @@
)
from camel.prompts import TextPrompt
from camel.responses import ChatAgentResponse
from camel.storages.key_value_storages.json import JsonStorage
from camel.toolkits import FunctionTool
from camel.types import (
ChatCompletion,
Expand Down Expand Up @@ -138,6 +142,8 @@ class ChatAgent(BaseAgent):
the next model in ModelManager. (default: :str:`round_robin`)
single_iteration (bool): Whether to let the agent perform only one
model calling at each step. (default: :obj:`False`)
agent_id (str, optional): The ID of the agent. If not provided, a
random UUID will be generated. (default: :obj:`None`)
"""

def __init__(
Expand All @@ -157,6 +163,7 @@ def __init__(
response_terminators: Optional[List[ResponseTerminator]] = None,
scheduling_strategy: str = "round_robin",
single_iteration: bool = False,
agent_id: Optional[str] = None,
) -> None:
# Set up model backend
self.model_backend = ModelManager(
Expand All @@ -171,16 +178,25 @@ def __init__(
scheduling_strategy=scheduling_strategy,
)
self.model_type = self.model_backend.model_type
# Assign unique ID
self.agent_id = agent_id if agent_id else str(uuid.uuid4())

# Set up memory
context_creator = ScoreBasedContextCreator(
self.model_backend.token_counter,
token_limit or self.model_backend.token_limit,
)

self.memory: AgentMemory = memory or ChatHistoryMemory(
context_creator, window_size=message_window_size
context_creator,
window_size=message_window_size,
agent_id=self.agent_id,
)

# So we don't have to pass agent_id when we define memory
if memory is not None:
memory.agent_id = self.agent_id

# Set up system message and initialize messages
self._original_system_message = (
BaseMessage.make_assistant_message(
Expand Down Expand Up @@ -321,9 +337,77 @@ def update_memory(
role (OpenAIBackendRole): The backend role type.
"""
self.memory.write_record(
MemoryRecord(message=message, role_at_backend=role)
MemoryRecord(
message=message,
role_at_backend=role,
timestamp=datetime.now().timestamp(),
agent_id=self.agent_id,
)
)

def load_memory(self, memory: AgentMemory) -> None:
    r"""Load the provided memory into the agent.

    Replays every record returned by ``memory.retrieve()`` through this
    agent's own memory via ``write_record``, so the records are appended
    to (not replacing) the current memory contents.

    Args:
        memory (AgentMemory): The memory whose records are copied into
            the agent.
    """
    for context_record in memory.retrieve():
        self.memory.write_record(context_record.memory_record)
    # Use the module logger instead of print: library code should not
    # write directly to stdout.
    logger.info(f"Memory loaded from {memory}")

def load_memory_from_path(self, path: str) -> None:
    r"""Load memory records from a JSON file into the agent's memory.

    Args:
        path (str): Path to a JSON memory file previously written with
            :obj:`JsonStorage` (e.g. by :meth:`save_memory`).

    Raises:
        ValueError: If the file contains no records.
    """
    json_store = JsonStorage(Path(path))
    all_records = json_store.load()

    if not all_records:
        # NOTE: no agent_id filtering is performed here; the error
        # message must not claim otherwise.
        raise ValueError(f"No records found in {path}")

    # Deserialize everything first, then write in a single batched call
    # instead of one write_records() call per record.
    records = [MemoryRecord.from_dict(record) for record in all_records]
    self.memory.write_records(records)
    logger.info(f"Memory loaded from {path}")

def save_memory(self, path: str) -> None:
    r"""Persist the agent's current memory to a JSON file.

    Retrieves the current conversation records from memory and writes
    them to ``path`` using :obj:`JsonStorage`, in a format readable by
    :meth:`load_memory_from_path`.

    Args:
        path (str): Target file path to store the JSON data.
    """
    json_store = JsonStorage(Path(path))
    context_records = self.memory.retrieve()
    to_save = [cr.memory_record.to_dict() for cr in context_records]
    json_store.save(to_save)
    # Use the module logger instead of print: library code should not
    # write directly to stdout.
    logger.info(f"Memory saved to {path}")

def clear_memory(self) -> None:
    r"""Reset the agent's memory to its initial state.

    Wipes all stored records, then re-seeds the system message (when one
    is configured) so the agent behaves as if freshly constructed.
    """
    self.memory.clear()
    sys_msg = self.system_message
    if sys_msg is not None:
        self.update_memory(sys_msg, OpenAIBackendRole.SYSTEM)

def _generate_system_message_for_output_language(
self,
) -> Optional[BaseMessage]:
Expand Down Expand Up @@ -694,18 +778,11 @@ def _get_model_response(
f"index: {self.model_backend.current_model_index}",
exc_info=exc,
)
error_info = str(exc)

if not response and self.model_backend.num_models > 1:
if not response:
raise ModelProcessingError(
"Unable to process messages: none of the provided models "
"run succesfully."
)
elif not response:
raise ModelProcessingError(
f"Unable to process messages: the only provided model "
f"did not run succesfully. Error: {error_info}"
)

logger.info(
f"Model {self.model_backend.model_type}, "
Expand Down Expand Up @@ -739,18 +816,11 @@ async def _aget_model_response(
f"index: {self.model_backend.current_model_index}",
exc_info=exc,
)
error_info = str(exc)

if not response and self.model_backend.num_models > 1:
if not response:
raise ModelProcessingError(
"Unable to process messages: none of the provided models "
"run succesfully."
)
elif not response:
raise ModelProcessingError(
f"Unable to process messages: the only provided model "
f"did not run succesfully. Error: {error_info}"
)

logger.info(
f"Model {self.model_backend.model_type}, "
Expand Down
52 changes: 51 additions & 1 deletion camel/memories/agent_memories.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,13 +33,16 @@ class ChatHistoryMemory(AgentMemory):
window_size (int, optional): The number of recent chat messages to
retrieve. If not provided, the entire chat history will be
retrieved. (default: :obj:`None`)
agent_id (str, optional): The ID of the agent associated with the chat
history.
"""

def __init__(
self,
context_creator: BaseContextCreator,
storage: Optional[BaseKeyValueStorage] = None,
window_size: Optional[int] = None,
agent_id: Optional[str] = None,
) -> None:
if window_size is not None and not isinstance(window_size, int):
raise TypeError("`window_size` must be an integer or None.")
Expand All @@ -48,6 +51,15 @@ def __init__(
self._context_creator = context_creator
self._window_size = window_size
self._chat_history_block = ChatHistoryBlock(storage=storage)
self._agent_id = agent_id

@property
def agent_id(self) -> Optional[str]:
    r"""Optional[str]: ID of the agent tied to this chat history."""
    return self._agent_id

@agent_id.setter
def agent_id(self, value: Optional[str]) -> None:
    # Allows the owning agent to stamp its ID after construction.
    self._agent_id = value

def retrieve(self) -> List[ContextRecord]:
records = self._chat_history_block.retrieve(self._window_size)
Expand All @@ -63,6 +75,12 @@ def retrieve(self) -> List[ContextRecord]:
return records

def write_records(self, records: List[MemoryRecord]) -> None:
    r"""Write records into the chat history, stamping missing agent IDs.

    Records whose ``agent_id`` is the empty string inherit this memory's
    ``agent_id`` (when one is set) before being stored.
    """
    owner_id = self.agent_id  # hoisted: invariant across the loop
    if owner_id is not None:
        for record in records:
            # Leave records that already carry an ID untouched.
            if record.agent_id == "":
                record.agent_id = owner_id
    self._chat_history_block.write_records(records)

def get_context_creator(self) -> BaseContextCreator:
Expand All @@ -84,20 +102,32 @@ class VectorDBMemory(AgentMemory):
(default: :obj:`None`)
retrieve_limit (int, optional): The maximum number of messages
to be added into the context. (default: :obj:`3`)
agent_id (str, optional): The ID of the agent associated with
the messages stored in the vector database.
"""

def __init__(
self,
context_creator: BaseContextCreator,
storage: Optional[BaseVectorStorage] = None,
retrieve_limit: int = 3,
agent_id: Optional[str] = None,
) -> None:
self._context_creator = context_creator
self._retrieve_limit = retrieve_limit
self._vectordb_block = VectorDBBlock(storage=storage)
self._agent_id = agent_id

self._current_topic: str = ""

@property
def agent_id(self) -> Optional[str]:
    r"""Optional[str]: ID of the agent whose messages are stored here."""
    return self._agent_id

@agent_id.setter
def agent_id(self, value: Optional[str]) -> None:
    # Allows the owning agent to stamp its ID after construction.
    self._agent_id = value

def retrieve(self) -> List[ContextRecord]:
return self._vectordb_block.retrieve(
self._current_topic,
Expand All @@ -109,6 +139,13 @@ def write_records(self, records: List[MemoryRecord]) -> None:
for record in records:
if record.role_at_backend == OpenAIBackendRole.USER:
self._current_topic = record.message.content

# assign the agent_id to the record
if record.agent_id == "":
# if the agent memory has an agent_id, use it
if self.agent_id is not None:
record.agent_id = self.agent_id

self._vectordb_block.write_records(records)

def get_context_creator(self) -> BaseContextCreator:
Expand All @@ -133,6 +170,8 @@ class LongtermAgentMemory(AgentMemory):
(default: :obj:`None`)
retrieve_limit (int, optional): The maximum number of messages
to be added into the context. (default: :obj:`3`)
agent_id (str, optional): The ID of the agent associated with the chat
history and the messages stored in the vector database.
"""

def __init__(
Expand All @@ -141,12 +180,22 @@ def __init__(
chat_history_block: Optional[ChatHistoryBlock] = None,
vector_db_block: Optional[VectorDBBlock] = None,
retrieve_limit: int = 3,
agent_id: Optional[str] = None,
) -> None:
self.chat_history_block = chat_history_block or ChatHistoryBlock()
self.vector_db_block = vector_db_block or VectorDBBlock()
self.retrieve_limit = retrieve_limit
self._context_creator = context_creator
self._current_topic: str = ""
self._agent_id = agent_id

@property
def agent_id(self) -> Optional[str]:
    r"""Optional[str]: ID of the agent owning this long-term memory."""
    return self._agent_id

@agent_id.setter
def agent_id(self, value: Optional[str]) -> None:
    # Allows the owning agent to stamp its ID after construction.
    self._agent_id = value

def get_context_creator(self) -> BaseContextCreator:
r"""Returns the context creator used by the memory.
Expand All @@ -166,7 +215,8 @@ def retrieve(self) -> List[ContextRecord]:
"""
chat_history = self.chat_history_block.retrieve()
vector_db_retrieve = self.vector_db_block.retrieve(
self._current_topic, self.retrieve_limit
self._current_topic,
self.retrieve_limit,
)
return chat_history[:1] + vector_db_retrieve + chat_history[1:]

Expand Down
24 changes: 23 additions & 1 deletion camel/memories/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========

from abc import ABC, abstractmethod
from typing import List, Tuple
from typing import List, Optional, Tuple

from camel.memories.records import ContextRecord, MemoryRecord
from camel.messages import OpenAIMessage
Expand Down Expand Up @@ -112,6 +112,16 @@ class AgentMemory(MemoryBlock, ABC):
the memory records stored within the AgentMemory.
"""

@property
@abstractmethod
def agent_id(self) -> Optional[str]:
    r"""Get the ID of the agent associated with this memory, if any."""
    pass

@agent_id.setter
@abstractmethod
def agent_id(self, val: Optional[str]) -> None:
    r"""Set the ID of the agent associated with this memory."""
    pass

@abstractmethod
def retrieve(self) -> List[ContextRecord]:
r"""Get a record list from the memory for creating model context.
Expand All @@ -138,3 +148,15 @@ def get_context(self) -> Tuple[List[OpenAIMessage], int]:
context in OpenAIMessage format and the total token count.
"""
return self.get_context_creator().create_context(self.retrieve())

def __repr__(self) -> str:
    r"""Return a string representation of the AgentMemory.

    Returns:
        str: ``ClassName(agent_id='<id>')`` when a truthy agent_id is
            set, otherwise ``ClassName()``.
    """
    cls_name = type(self).__name__
    aid = getattr(self, '_agent_id', None)
    suffix = f"agent_id='{aid}'" if aid else ""
    return f"{cls_name}({suffix})"
5 changes: 5 additions & 0 deletions camel/memories/records.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ class MemoryRecord(BaseModel):
key-value pairs that provide more information. If not given, it
will be an empty `Dict`.
timestamp (float, optional): The timestamp when the record was created.
agent_id (str): The identifier of the agent associated with this
memory.
"""

model_config = ConfigDict(arbitrary_types_allowed=True)
Expand All @@ -50,6 +52,7 @@ class MemoryRecord(BaseModel):
timestamp: float = Field(
default_factory=lambda: datetime.now(timezone.utc).timestamp()
)
agent_id: str = Field(default="")

_MESSAGE_TYPES: ClassVar[dict] = {
"BaseMessage": BaseMessage,
Expand All @@ -73,6 +76,7 @@ def from_dict(cls, record_dict: Dict[str, Any]) -> "MemoryRecord":
role_at_backend=record_dict["role_at_backend"],
extra_info=record_dict["extra_info"],
timestamp=record_dict["timestamp"],
agent_id=record_dict["agent_id"],
)

def to_dict(self) -> Dict[str, Any]:
Expand All @@ -88,6 +92,7 @@ def to_dict(self) -> Dict[str, Any]:
"role_at_backend": self.role_at_backend,
"extra_info": self.extra_info,
"timestamp": self.timestamp,
"agent_id": self.agent_id,
}

def to_openai_message(self) -> OpenAIMessage:
Expand Down
Loading