-
Notifications
You must be signed in to change notification settings - Fork 171
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
This CL implements example agents and some essential components in v2…
… style. - Minimal agent - observation component and act components - Basic agent - question of recent memories - tests for agent factories PiperOrigin-RevId: 722673251 Change-Id: Ifda3dd4da3861bf0d77312b94997404acf535861
- Loading branch information
1 parent
221c866
commit 1b21ed8
Showing
11 changed files
with
1,580 additions
and
0 deletions.
There are no files selected for viewing
214 changes: 214 additions & 0 deletions
214
concordia/associative_memory/unstable/basic_associative_memory.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,214 @@ | ||
# Copyright 2023 DeepMind Technologies Limited. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# https://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
|
||
"""An associative memory with basic retrieval methods.""" | ||
|
||
from collections.abc import Callable, Iterable, Sequence
import io
import threading

from concordia.typing import entity_component
import numpy as np
import pandas as pd
|
||
|
||
class AssociativeMemoryBank:
  """Class that implements associative memory.

  Memories are rows of a pandas DataFrame with a 'text' column and an
  'embedding' column holding whatever `sentence_embedder` returned for that
  text. Each distinct text is stored at most once. Every access to the
  DataFrame and to the set of stored hashes happens under an internal lock.
  """

  def __init__(
      self,
      sentence_embedder: Callable[[str], np.ndarray],
  ):
    """Constructor.

    Args:
      sentence_embedder: text embedding model
    """
    self._memory_bank_lock = threading.Lock()
    self._embedder = sentence_embedder

    self._memory_bank = pd.DataFrame(columns=['text', 'embedding'])
    self._stored_hashes = set()

  @staticmethod
  def _hash_text(text: str) -> int:
    """Returns the key used to detect duplicated memories.

    The value comes from the built-in `hash`, which is salted per interpreter
    process for strings, so it is only meaningful inside the current process.

    Args:
      text: the (already newline-free) text of a memory

    Returns:
      The hash of the one-element tuple containing `text`.
    """
    return hash((text,))

  def get_state(self) -> entity_component.ComponentState:
    """Converts the AssociativeMemory to a dictionary.

    Returns:
      A dictionary with the key 'memory_bank' (the DataFrame serialised with
      `DataFrame.to_json`) and the key 'stored_hashes' (a list of the hashes
      of the stored texts).
    """
    with self._memory_bank_lock:
      return {
          'stored_hashes': list(self._stored_hashes),
          'memory_bank': self._memory_bank.to_json(),
      }

  def set_state(self, state: entity_component.ComponentState) -> None:
    """Sets the AssociativeMemory from a dictionary.

    Args:
      state: a dictionary as produced by `get_state`.
    """
    # Passing a literal JSON string to `read_json` is deprecated in pandas, so
    # hand it a file-like object instead.
    memory_bank = pd.read_json(io.StringIO(state['memory_bank']))
    # `_get_top_k_cosine` uses index labels as row positions, which is only
    # valid for a 0..n-1 index.
    memory_bank = memory_bank.reset_index(drop=True)

    # String hashes are salted per process, so hashes saved by another process
    # never match hashes computed here. Recompute them from the restored texts
    # so that de-duplication keeps working; the saved ones are kept as well.
    stored_hashes = set(state['stored_hashes'])
    stored_hashes.update(self._hash_text(text) for text in memory_bank['text'])

    with self._memory_bank_lock:
      self._stored_hashes = stored_hashes
      self._memory_bank = memory_bank

  def add(
      self,
      text: str,
  ) -> None:
    """Adds a text to the memory unless the same text is already stored.

    Newlines in `text` are replaced by spaces before storing and before
    checking for duplicates.

    Args:
      text: what goes into the memory
    """
    # Remove all newline characters from memories.
    text = text.replace('\n', ' ')
    hashed_contents = self._hash_text(text)

    # Cheap early exit: do not pay for embedding a text we already hold.
    with self._memory_bank_lock:
      if hashed_contents in self._stored_hashes:
        return

    # The embedder runs outside the lock so that it does not block other
    # threads using the memory bank.
    contents = {'text': text, 'embedding': self._embedder(text)}
    new_df = pd.Series(contents).to_frame().T.infer_objects()

    with self._memory_bank_lock:
      # Check again: another thread may have stored the same text while the
      # embedding was being computed.
      if hashed_contents in self._stored_hashes:
        return
      self._memory_bank = pd.concat(
          [self._memory_bank, new_df], ignore_index=True
      )
      self._stored_hashes.add(hashed_contents)

  def extend(
      self,
      texts: Iterable[str],
  ) -> None:
    """Adds the texts to the memory.

    Args:
      texts: list of strings to add to the memory
    """
    for text in texts:
      self.add(text)

  def get_data_frame(self) -> pd.DataFrame:
    """Returns a copy of the underlying DataFrame of memories."""
    with self._memory_bank_lock:
      return self._memory_bank.copy()

  def _get_top_k_cosine(self, x: np.ndarray, k: int):
    """Returns the k rows whose embeddings score highest against x.

    The score is the dot product of `x` with each stored embedding. This is
    the cosine similarity only if the embeddings are unit-normalised, which
    this class does not check.

    Args:
      x: The input vector.
      k: The number of rows to return.

    Returns:
      Rows, sorted by score in descending order.
    """
    with self._memory_bank_lock:
      if self._memory_bank.empty:
        # Nothing to rank; also avoids indexing with an empty, possibly
        # non-numeric, index.
        return self._memory_bank.copy()

      similarities = self._memory_bank['embedding'].apply(
          lambda y: np.dot(x, y)
      )
      top_k = similarities.sort_values(ascending=False).head(k)

      # The index is 0..n-1, so labels double as row positions.
      return self._memory_bank.iloc[top_k.index]

  def _pd_to_text(
      self,
      data: pd.DataFrame,
  ) -> Sequence[str]:
    """Formats a dataframe into list of strings.

    Args:
      data: the dataframe to process

    Returns:
      A list of strings, one for each memory
    """
    return data['text'].tolist()

  def retrieve_associative(
      self,
      query: str,
      k: int = 1,
  ) -> Sequence[str]:
    """Retrieve memories associatively.

    Args:
      query: a string to use for retrieval
      k: how many memories to retrieve

    Returns:
      List of strings corresponding to memories, sorted by similarity to the
      query (most similar first)

    Raises:
      ValueError: if `k` is not positive.
    """
    if k <= 0:
      raise ValueError('Limit must be positive.')

    query_embedding = self._embedder(query)
    data = self._get_top_k_cosine(query_embedding, k)
    return self._pd_to_text(data)

  def scan(self, selector_fn: Callable[[str], bool]):
    """Retrieve memories that match the selector function.

    Args:
      selector_fn: a function that takes a string and returns a boolean
        indicating whether the string matches the selector

    Returns:
      List of strings corresponding to memories, in the order they were added
    """
    with self._memory_bank_lock:
      # `astype(bool)` guarantees a boolean mask even when the bank is empty
      # and `apply` cannot infer a boolean dtype from any result.
      is_selected = self._memory_bank['text'].apply(selector_fn).astype(bool)
      data = self._memory_bank[is_selected]
      return self._pd_to_text(data)

  def retrieve_recent(
      self,
      k: int = 1,
  ) -> Sequence[str]:
    """Retrieve memories by recency.

    Args:
      k: number of entries to retrieve

    Returns:
      List of strings corresponding to the last `k` memories added, in the
      order they were added (the most recent one is last)

    Raises:
      ValueError: if `k` is not positive.
    """
    if k <= 0:
      raise ValueError('Limit must be positive.')

    with self._memory_bank_lock:
      return self._pd_to_text(self._memory_bank.iloc[-k:])

  def __len__(self):
    """Returns the number of entries in the memory bank.

    Since memories cannot be deleted, the length cannot decrease, and can be
    used to check if the contents of the memory bank have changed.
    """
    with self._memory_bank_lock:
      return len(self._memory_bank)

  def get_all_memories_as_text(
      self,
  ) -> Sequence[str]:
    """Returns all memories in the memory bank as a sequence of strings."""
    memories_data_frame = self.get_data_frame()
    texts = self._pd_to_text(memories_data_frame)
    return texts
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
# Copyright 2023 DeepMind Technologies Limited. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# https://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
"""Library of components specifically for generative agents.""" | ||
|
||
from concordia.components.agent.unstable import all_similar_memories | ||
from concordia.components.agent.unstable import concat_act_component | ||
from concordia.components.agent.unstable import memory | ||
from concordia.components.agent.unstable import observation | ||
from concordia.components.agent.unstable import question_of_recent_memories |
145 changes: 145 additions & 0 deletions
145
concordia/components/agent/unstable/all_similar_memories.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,145 @@ | ||
# Copyright 2023 DeepMind Technologies Limited. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# https://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
"""Return all memories similar to a prompt and filter them for relevance.""" | ||
|
||
import types | ||
from typing import Mapping | ||
|
||
from concordia.components.agent import action_spec_ignored | ||
from concordia.components.agent.unstable import memory as memory_component | ||
from concordia.document import interactive_document | ||
from concordia.language_model import language_model | ||
from concordia.typing import entity as entity_lib | ||
from concordia.typing import entity_component | ||
from concordia.typing import logging | ||
|
||
|
||
class AllSimilarMemories(action_spec_ignored.ActionSpecIgnored):
  """Get all memories similar to the state of the components and filter them."""

  def __init__(
      self,
      model: language_model.LanguageModel,
      memory_component_name: str = (
          memory_component.DEFAULT_MEMORY_COMPONENT_NAME
      ),
      components: Mapping[
          entity_component.ComponentName, str
      ] = types.MappingProxyType({}),
      num_memories_to_retrieve: int = 25,
      pre_act_key: str = 'Relevant memories',
      logging_channel: logging.LoggingChannel = logging.NoOpLoggingChannel,
  ):
    """Initialize a component to report relevant memories (similar to a prompt).

    Args:
      model: The language model to use.
      memory_component_name: The name of the memory component from which to
        retrieve related memories.
      components: The components to condition the answer on. This is a mapping
        of the component name to a label to use in the prompt.
      num_memories_to_retrieve: The number of memories to retrieve.
      pre_act_key: Prefix to add to the output of the component when called
        in `pre_act`.
      logging_channel: The channel to log debug information to.
    """
    super().__init__(pre_act_key)
    self._model = model
    self._memory_component_name = memory_component_name
    # Copy so that later changes to the caller's mapping do not affect us.
    self._components = dict(components)
    self._num_memories_to_retrieve = num_memories_to_retrieve
    self._logging_channel = logging_channel

  def _make_pre_act_value(self) -> str:
    """Builds the list of relevant memories with two language model calls.

    First the model summarises the pre-act values of the configured
    components. The agent name plus that summary is then used as the query
    for associative retrieval from the memory component. Finally the model is
    asked, in a fresh document, to pick the retrieved statements that matter
    most right now.

    Returns:
      The model's answer to the selection question.
    """
    agent_name = self.get_entity().name
    prompt = interactive_document.InteractiveDocument(self._model)

    component_states = '\n'.join([
        f"{agent_name}'s"
        f' {prefix}:\n{self.get_named_component_pre_act_value(key)}'
        for key, prefix in self._components.items()
    ])
    prompt.statement(f'Statements:\n{component_states}\n')
    prompt_summary = prompt.open_question(
        'Summarize the statements above.', max_tokens=750
    )

    memory = self.get_entity().get_component(
        self._memory_component_name, type_=memory_component.Memory
    )

    query = f'{agent_name}, {prompt_summary}'
    mems = '\n'.join(
        memory.retrieve_associative(
            query=query, limit=self._num_memories_to_retrieve
        )
    )

    question = (
        'Select the subset of the following set of statements that is most '
        f'important for {agent_name} to consider right now. Whenever two '
        'or more statements are not mutually consistent with each other '
        'select whichever statement is more recent. Repeat all the '
        'selected statements verbatim. Do not summarize. Include timestamps. '
        'When in doubt, err on the side of including more, especially for '
        'recent events. As long as they are not inconsistent, recent events '
        'are usually important to consider.'
    )
    # The selection question is asked in a new document derived from `prompt`.
    new_prompt = prompt.new()
    result = new_prompt.open_question(
        f'{question}\nStatements:\n{mems}',
        max_tokens=2000,
        terminators=('\n\n',),
    )

    self._logging_channel({
        'Key': self.get_pre_act_key(),
        'Value': result,
        'Initial chain of thought': prompt.view().text().splitlines(),
        'Query': query,
        'Final chain of thought': new_prompt.view().text().splitlines(),
    })

    return result
|
||
|
||
class AllSimilarMemoriesWithoutPreAct(
    action_spec_ignored.ActionSpecIgnored
):
  """An AllSimilarMemories component that does not output its state to pre_act.

  An inner `AllSimilarMemories` instance receives the `set_entity` and
  `update` calls and supplies `get_pre_act_value`, while `pre_act` on this
  wrapper always returns the empty string.
  """

  def __init__(self, *args, **kwargs):
    """Creates the wrapped `AllSimilarMemories` from the given arguments."""
    self._component = AllSimilarMemories(*args, **kwargs)

  # --- Calls delegated to the wrapped component. ---

  def set_entity(self, entity: entity_component.EntityWithComponents) -> None:
    """Attaches the wrapped component to `entity`."""
    self._component.set_entity(entity)

  def update(self) -> None:
    """Forwards the update call to the wrapped component."""
    self._component.update()

  def get_pre_act_value(self) -> str:
    """Returns the pre-act value of the wrapped component."""
    return self._component.get_pre_act_value()

  # --- Calls answered with an empty string by the wrapper itself. ---

  def _make_pre_act_value(self) -> str:
    """Returns an empty string; the wrapper has no value of its own."""
    return ''

  def pre_act(
      self,
      unused_action_spec: entity_lib.ActionSpec,
  ) -> str:
    """Returns an empty string regardless of the action spec."""
    del unused_action_spec
    return ''
Oops, something went wrong.