Skip to content

Commit

Permalink
Upgrade ChromaDB to >=0.6.0 and fix broken tests (#530)
Browse files Browse the repository at this point in the history
Co-authored-by: Philip Meier <[email protected]>
  • Loading branch information
smokestacklightnin and pmeier authored Jan 10, 2025
1 parent 891050e commit 37923c5
Show file tree
Hide file tree
Showing 5 changed files with 66 additions and 15 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ Repository = "https://github.com/Quansight/ragna"
[project.optional-dependencies]
# to update the array below, run scripts/update_optional_dependencies.py
all = [
"chromadb<=0.5.11,>=0.4.13",
"chromadb>=0.6.0",
"httpx_sse",
"ijson",
"lancedb>=0.2",
Expand Down
28 changes: 20 additions & 8 deletions ragna/source_storages/_chroma.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,20 @@ class Chroma(VectorDatabaseSourceStorage):
!!! info "Required packages"
- `chromadb>=0.4.13`
- `chromadb>=0.6.0`
!!! warning
The `NE` and `NOT_IN` metadata filter operators behave differently in Chroma
than the other builtin source storages. With most other source storages,
given a key-value pair `(key, value)`, the operators `NE` and `NOT_IN` return
only the sources with a metadata key `key` and a value not equal to or
not in, respectively, `value`. To contrast, the `NE` and `NOT_IN` metadata filter
operators in `ChromaDB` return everything described in the preceding sentence,
together with all sources that do not have the metadata key `key`.
For more information, see the notes for `v0.5.12` in the
[`ChromaDB` migration guide](https://docs.trychroma.com/production/administration/migration).
"""

# Note that this class has no extra requirements, since the chromadb package is
Expand All @@ -39,7 +52,7 @@ def __init__(self) -> None:
)

def list_corpuses(self) -> list[str]:
return [collection.name for collection in self._client.list_collections()]
return [str(c) for c in self._client.list_collections()]

def _get_collection(
self, corpus_name: str, *, create: bool = False
Expand All @@ -49,15 +62,14 @@ def _get_collection(
corpus_name, embedding_function=self._embedding_function
)

collections = list(self._client.list_collections())
if not collections:
corpuses = self.list_corpuses()
if not corpuses:
raise_no_corpuses_available(self)

try:
return next(
collection
for collection in collections
if collection.name == corpus_name
return self._client.get_collection(
name=next(name for name in corpuses if name == corpus_name),
embedding_function=self._embedding_function,
)
except StopIteration:
raise_non_existing_corpus(self, corpus_name)
Expand Down
2 changes: 1 addition & 1 deletion ragna/source_storages/_lancedb.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ class LanceDB(VectorDatabaseSourceStorage):
!!! info "Required packages"
- `chromadb>=0.4.13`
- `chromadb>=0.6.0`
- `lancedb>=0.2`
- `pyarrow`
"""
Expand Down
2 changes: 1 addition & 1 deletion ragna/source_storages/_vector_database.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def requirements(cls) -> list[Requirement]:
# to manage and mostly not even used by the vector DB. Chroma provides a
# wrapper around a compiled embedding function that has only minimal
# requirements. We use this as base for all of our Vector DBs.
PackageRequirement("chromadb<=0.5.11,>=0.4.13"),
PackageRequirement("chromadb>=0.6.0"),
PackageRequirement("tiktoken"),
]

Expand Down
47 changes: 43 additions & 4 deletions tests/source_storages/test_source_storages.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from ragna.core import (
LocalDocument,
MetadataFilter,
MetadataOperator,
PlainTextDocumentHandler,
RagnaException,
)
Expand Down Expand Up @@ -69,7 +70,9 @@
MetadataFilter.and_(
[
MetadataFilter.eq("key", "other_value"),
MetadataFilter.ne("other_key", "other_value"),
MetadataFilter.in_(
"other_key", ["some_value", "other_value"]
),
]
),
]
Expand Down Expand Up @@ -104,7 +107,13 @@
@pytest.mark.parametrize(
"source_storage_cls", set(SOURCE_STORAGES) - {RagnaDemoSourceStorage}
)
def test_smoke(tmp_local_root, source_storage_cls, metadata_filter, expected_idcs):
def test_smoke(
tmp_local_root,
source_storage_cls,
metadata_filter,
expected_idcs,
chroma_override=False,
):
document_root = tmp_local_root / "documents"
document_root.mkdir()
documents = []
Expand Down Expand Up @@ -135,13 +144,43 @@ def test_smoke(tmp_local_root, source_storage_cls, metadata_filter, expected_idc
num_tokens=num_tokens,
)

actual_idcs = sorted(map(int, (source.document_name for source in sources)))
assert actual_idcs == expected_idcs
if (
not (
source_storage_cls is Chroma
and isinstance(metadata_filter, MetadataFilter)
and metadata_filter.operator
in {
MetadataOperator.NE,
MetadataOperator.NOT_IN,
}
)
or chroma_override
):
actual_idcs = sorted(map(int, (source.document_name for source in sources)))
assert actual_idcs == expected_idcs

# Should be able to call .store() multiple times
source_storage.store(corpus_name, documents)


@pytest.mark.parametrize(
("metadata_filter", "expected_idcs"),
[
pytest.param(MetadataFilter.ne("key", "value"), [2, 3, 4, 5, 6], id="ne"),
pytest.param(
MetadataFilter.not_in("key", ["foo", "bar"]), [0, 1, 2, 3, 4], id="not_in"
),
],
)
def test_chroma_ne_nin_non_existing_keys(
tmp_local_root, metadata_filter, expected_idcs
):
# See https://github.com/Quansight/ragna/issues/523 for details
test_smoke(
tmp_local_root, Chroma, metadata_filter, expected_idcs, chroma_override=True
)


@pytest.mark.parametrize("source_storage_cls", [Chroma, LanceDB])
def test_corpus_names(tmp_local_root, source_storage_cls):
document_root = tmp_local_root / "documents"
Expand Down

0 comments on commit 37923c5

Please sign in to comment.