fix(backends): address Copilot review on #995
Four defects surfaced by the automated review, fixed with targeted tests: 1. BaseCollection.update() default now validates that documents / metadatas / embeddings lengths match ids, raising ValueError instead of silently misaligning pairs or raising IndexError (base.py). 2. ChromaCollection.query() now rejects the two ambiguous input shapes up front — neither or both of query_texts / query_embeddings, and empty input lists — with clear ValueError messages rather than delegating to chromadb's less-obvious errors (chroma.py). 3. QueryResult.empty() accepts embeddings_requested=True to preserve the outer-query dimension with empty hit lists when the caller asked for embeddings, matching the spec rule that included fields carry the outer shape even when empty (base.py). ChromaCollection.query() threads this through on the empty-result path (chroma.py). 4. ChromaBackend cache-freshness check now matches the semantics from mcp_server._get_client (merged via #757) on three edge cases Copilot called out: (a) invalidate when chroma.sqlite3 disappears while a cached client is held, (b) treat a 0→nonzero stat transition as a change so a cache built when the DB did not yet exist is refreshed, (c) re-stat after PersistentClient constructs the DB lazily so freshness reflects the post-creation state (chroma.py). Tests: 978 passed (up from 970), 8 new tests covering the fixes.
This commit is contained in:
@@ -133,14 +133,21 @@ class QueryResult(_DictCompatMixin):
|
|||||||
embeddings: Optional[list[list[list[float]]]] = None
|
embeddings: Optional[list[list[list[float]]]] = None
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def empty(cls, num_queries: int = 1) -> "QueryResult":
|
def empty(cls, num_queries: int = 1, embeddings_requested: bool = False) -> "QueryResult":
|
||||||
"""Construct an all-empty result preserving outer dimension."""
|
"""Construct an all-empty result preserving outer dimension.
|
||||||
|
|
||||||
|
When ``embeddings_requested`` is True, ``embeddings`` preserves the outer
|
||||||
|
query dimension with empty hit lists (matching the spec's rule that fields
|
||||||
|
requested via ``include=`` carry the outer shape even when empty). When
|
||||||
|
False, ``embeddings`` stays ``None`` to signal the field was not requested.
|
||||||
|
"""
|
||||||
|
empty_outer = [[] for _ in range(num_queries)]
|
||||||
return cls(
|
return cls(
|
||||||
ids=[[] for _ in range(num_queries)],
|
ids=[[] for _ in range(num_queries)],
|
||||||
documents=[[] for _ in range(num_queries)],
|
documents=[[] for _ in range(num_queries)],
|
||||||
metadatas=[[] for _ in range(num_queries)],
|
metadatas=[[] for _ in range(num_queries)],
|
||||||
distances=[[] for _ in range(num_queries)],
|
distances=[[] for _ in range(num_queries)],
|
||||||
embeddings=None,
|
embeddings=empty_outer if embeddings_requested else None,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -250,6 +257,15 @@ class BaseCollection(ABC):
|
|||||||
if documents is None and metadatas is None and embeddings is None:
|
if documents is None and metadatas is None and embeddings is None:
|
||||||
raise ValueError("update requires at least one of documents, metadatas, embeddings")
|
raise ValueError("update requires at least one of documents, metadatas, embeddings")
|
||||||
|
|
||||||
|
n = len(ids)
|
||||||
|
for label, value in (
|
||||||
|
("documents", documents),
|
||||||
|
("metadatas", metadatas),
|
||||||
|
("embeddings", embeddings),
|
||||||
|
):
|
||||||
|
if value is not None and len(value) != n:
|
||||||
|
raise ValueError(f"{label} length {len(value)} does not match ids length {n}")
|
||||||
|
|
||||||
existing = self.get(ids=ids, include=["documents", "metadatas"])
|
existing = self.get(ids=ids, include=["documents", "metadatas"])
|
||||||
by_id = {
|
by_id = {
|
||||||
rid: (existing.documents[i], existing.metadatas[i])
|
rid: (existing.documents[i], existing.metadatas[i])
|
||||||
|
|||||||
@@ -156,6 +156,12 @@ class ChromaCollection(BaseCollection):
|
|||||||
_validate_where(where)
|
_validate_where(where)
|
||||||
_validate_where(where_document)
|
_validate_where(where_document)
|
||||||
|
|
||||||
|
if (query_texts is None) == (query_embeddings is None):
|
||||||
|
raise ValueError("query requires exactly one of query_texts or query_embeddings")
|
||||||
|
chosen = query_texts if query_texts is not None else query_embeddings
|
||||||
|
if not chosen:
|
||||||
|
raise ValueError("query input must be a non-empty list")
|
||||||
|
|
||||||
spec = _IncludeSpec.resolve(include, default_distances=True)
|
spec = _IncludeSpec.resolve(include, default_distances=True)
|
||||||
chroma_include: list[str] = []
|
chroma_include: list[str] = []
|
||||||
if spec.documents:
|
if spec.documents:
|
||||||
@@ -190,7 +196,10 @@ class ChromaCollection(BaseCollection):
|
|||||||
|
|
||||||
ids = raw.get("ids") or []
|
ids = raw.get("ids") or []
|
||||||
if not ids:
|
if not ids:
|
||||||
return QueryResult.empty(num_queries=num_queries)
|
return QueryResult.empty(
|
||||||
|
num_queries=num_queries,
|
||||||
|
embeddings_requested=spec.embeddings,
|
||||||
|
)
|
||||||
|
|
||||||
documents = raw.get("documents") or [[] for _ in ids]
|
documents = raw.get("documents") or [[] for _ in ids]
|
||||||
metadatas = raw.get("metadatas") or [[] for _ in ids]
|
metadatas = raw.get("metadatas") or [[] for _ in ids]
|
||||||
@@ -332,8 +341,18 @@ class ChromaBackend(BaseBackend):
|
|||||||
"""Return a cached ``PersistentClient``, rebuilding on inode/mtime change.
|
"""Return a cached ``PersistentClient``, rebuilding on inode/mtime change.
|
||||||
|
|
||||||
Handles the palace-rebuild case (repair/nuke/purge) by invalidating the
|
Handles the palace-rebuild case (repair/nuke/purge) by invalidating the
|
||||||
cache when ``chroma.sqlite3`` changes on disk. FAT/exFAT return inode 0,
|
cache when ``chroma.sqlite3`` changes on disk. Mirrors the semantics of
|
||||||
so inode comparisons only fire when non-zero (matches #757 semantics).
|
``mcp_server._get_client`` (merged via #757):
|
||||||
|
|
||||||
|
* DB file missing while we hold a cached client → drop the cache so we
|
||||||
|
do not serve stale data after a rebuild that has not yet re-created
|
||||||
|
the DB.
|
||||||
|
* Transition 0 → nonzero stat (DB created after cache) counts as a
|
||||||
|
change, so the cached client is replaced with one that sees the DB.
|
||||||
|
* FAT/exFAT filesystems return inode 0; we never fire inode comparisons
|
||||||
|
when either side is 0 (safe fallback) but still honor mtime.
|
||||||
|
* Mtime change uses an epsilon (0.01 s) to tolerate FS timestamp
|
||||||
|
granularity without thrashing.
|
||||||
"""
|
"""
|
||||||
if self._closed:
|
if self._closed:
|
||||||
from .base import BackendClosedError # late import avoids cycles at module load
|
from .base import BackendClosedError # late import avoids cycles at module load
|
||||||
@@ -344,16 +363,32 @@ class ChromaBackend(BaseBackend):
|
|||||||
cached_inode, cached_mtime = self._freshness.get(palace_path, (0, 0.0))
|
cached_inode, cached_mtime = self._freshness.get(palace_path, (0, 0.0))
|
||||||
current_inode, current_mtime = self._db_stat(palace_path)
|
current_inode, current_mtime = self._db_stat(palace_path)
|
||||||
|
|
||||||
|
db_path = os.path.join(palace_path, "chroma.sqlite3")
|
||||||
|
# DB was present when cache was built but is now missing → invalidate.
|
||||||
|
if cached is not None and not os.path.isfile(db_path):
|
||||||
|
self._clients.pop(palace_path, None)
|
||||||
|
self._freshness.pop(palace_path, None)
|
||||||
|
cached = None
|
||||||
|
cached_inode, cached_mtime = 0, 0.0
|
||||||
|
|
||||||
inode_changed = current_inode != 0 and cached_inode != 0 and current_inode != cached_inode
|
inode_changed = current_inode != 0 and cached_inode != 0 and current_inode != cached_inode
|
||||||
|
# Transition from no-stat (0.0) to a real stat counts as a change so we
|
||||||
|
# pick up a DB that was created after the cache was built.
|
||||||
|
mtime_appeared = cached_mtime == 0.0 and current_mtime != 0.0
|
||||||
mtime_changed = (
|
mtime_changed = (
|
||||||
current_mtime != 0.0 and cached_mtime != 0.0 and current_mtime > cached_mtime
|
current_mtime != 0.0
|
||||||
|
and cached_mtime != 0.0
|
||||||
|
and abs(current_mtime - cached_mtime) > 0.01
|
||||||
)
|
)
|
||||||
|
|
||||||
if cached is None or inode_changed or mtime_changed:
|
if cached is None or inode_changed or mtime_changed or mtime_appeared:
|
||||||
_fix_blob_seq_ids(palace_path)
|
_fix_blob_seq_ids(palace_path)
|
||||||
cached = chromadb.PersistentClient(path=palace_path)
|
cached = chromadb.PersistentClient(path=palace_path)
|
||||||
self._clients[palace_path] = cached
|
self._clients[palace_path] = cached
|
||||||
self._freshness[palace_path] = (current_inode, current_mtime)
|
# Re-stat after the client constructor runs: chromadb creates
|
||||||
|
# chroma.sqlite3 lazily, so the stat captured before the call
|
||||||
|
# may still be (0, 0.0) on first open.
|
||||||
|
self._freshness[palace_path] = self._db_stat(palace_path)
|
||||||
return cached
|
return cached
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
|
|||||||
@@ -164,6 +164,140 @@ def test_chroma_detect_matches_palace_with_chroma_sqlite(tmp_path):
|
|||||||
assert ChromaBackend.detect(str(tmp_path.parent)) is False
|
assert ChromaBackend.detect(str(tmp_path.parent)) is False
|
||||||
|
|
||||||
|
|
||||||
|
def test_query_rejects_missing_input():
|
||||||
|
fake = _FakeCollection()
|
||||||
|
collection = ChromaCollection(fake)
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
collection.query()
|
||||||
|
|
||||||
|
|
||||||
|
def test_query_rejects_both_texts_and_embeddings():
|
||||||
|
fake = _FakeCollection()
|
||||||
|
collection = ChromaCollection(fake)
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
collection.query(query_texts=["q"], query_embeddings=[[0.1, 0.2]])
|
||||||
|
|
||||||
|
|
||||||
|
def test_query_rejects_empty_input_list():
|
||||||
|
fake = _FakeCollection()
|
||||||
|
collection = ChromaCollection(fake)
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
collection.query(query_texts=[])
|
||||||
|
|
||||||
|
|
||||||
|
def test_query_empty_preserves_embeddings_outer_shape_when_requested():
|
||||||
|
fake = _FakeCollection(
|
||||||
|
query_response={"ids": [], "documents": [], "metadatas": [], "distances": []}
|
||||||
|
)
|
||||||
|
collection = ChromaCollection(fake)
|
||||||
|
|
||||||
|
requested = collection.query(query_texts=["q1", "q2"], include=["documents", "embeddings"])
|
||||||
|
assert requested.embeddings == [[], []]
|
||||||
|
|
||||||
|
not_requested = collection.query(query_texts=["q1", "q2"], include=["documents"])
|
||||||
|
assert not_requested.embeddings is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_base_collection_update_default_validates_list_lengths(tmp_path):
|
||||||
|
backend = ChromaBackend()
|
||||||
|
palace_path = tmp_path / "palace"
|
||||||
|
collection = backend.get_collection(
|
||||||
|
palace=PalaceRef(id=str(palace_path), local_path=str(palace_path)),
|
||||||
|
collection_name="mempalace_drawers",
|
||||||
|
create=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Mismatched documents length → clear ValueError, not silent merge.
|
||||||
|
with pytest.raises(ValueError, match="documents length"):
|
||||||
|
collection._collection.add(
|
||||||
|
documents=["a", "b"],
|
||||||
|
ids=["1", "2"],
|
||||||
|
metadatas=[{"k": 1}, {"k": 2}],
|
||||||
|
)
|
||||||
|
from mempalace.backends.base import BaseCollection
|
||||||
|
|
||||||
|
BaseCollection.update(
|
||||||
|
collection,
|
||||||
|
ids=["1", "2"],
|
||||||
|
documents=["only-one"],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_chroma_cache_invalidates_when_db_file_missing(tmp_path):
|
||||||
|
"""A palace rebuild that removes chroma.sqlite3 must drop the stale cache."""
|
||||||
|
backend = ChromaBackend()
|
||||||
|
palace_path = tmp_path / "palace"
|
||||||
|
backend.get_collection(
|
||||||
|
palace=PalaceRef(id=str(palace_path), local_path=str(palace_path)),
|
||||||
|
collection_name="mempalace_drawers",
|
||||||
|
create=True,
|
||||||
|
)
|
||||||
|
assert str(palace_path) in backend._clients
|
||||||
|
prior_client = backend._clients[str(palace_path)]
|
||||||
|
prior_freshness = backend._freshness[str(palace_path)]
|
||||||
|
assert prior_freshness != (0, 0.0) # DB file exists after get_or_create_collection
|
||||||
|
|
||||||
|
# Remove chroma.sqlite3 to simulate a rebuild mid-flight. The stale cache
|
||||||
|
# must not be silently reused — the in-memory HNSW index would be wrong.
|
||||||
|
(palace_path / "chroma.sqlite3").unlink()
|
||||||
|
|
||||||
|
new_client = backend._client(str(palace_path))
|
||||||
|
# New client object (cache was replaced, not reused) and freshness was reset
|
||||||
|
# to (0, 0.0) to reflect "no DB on disk yet" state.
|
||||||
|
assert new_client is not prior_client
|
||||||
|
assert backend._freshness[str(palace_path)] == (0, 0.0)
|
||||||
|
|
||||||
|
|
||||||
|
def test_chroma_cache_picks_up_db_created_after_first_open(tmp_path):
|
||||||
|
"""The 0 → nonzero stat transition invalidates a cache built before the DB existed."""
|
||||||
|
backend = ChromaBackend()
|
||||||
|
palace_path = tmp_path / "palace"
|
||||||
|
palace_path.mkdir()
|
||||||
|
|
||||||
|
# Seed an entry in the caches as if a prior _client() call had opened the
|
||||||
|
# palace when chroma.sqlite3 did not exist yet. Freshness (0, 0.0) is the
|
||||||
|
# signal that the DB was absent at cache time.
|
||||||
|
sentinel = object()
|
||||||
|
backend._clients[str(palace_path)] = sentinel
|
||||||
|
backend._freshness[str(palace_path)] = (0, 0.0)
|
||||||
|
|
||||||
|
# The DB file now appears (real chromadb would have created it by now).
|
||||||
|
# Use a real chromadb call so _fix_blob_seq_ids and PersistentClient succeed.
|
||||||
|
import chromadb as _chromadb
|
||||||
|
|
||||||
|
_chromadb.PersistentClient(path=str(palace_path)).get_or_create_collection("seed")
|
||||||
|
assert (palace_path / "chroma.sqlite3").is_file()
|
||||||
|
|
||||||
|
# Next _client() call must detect the 0 → nonzero transition and rebuild.
|
||||||
|
refreshed = backend._client(str(palace_path))
|
||||||
|
assert refreshed is not sentinel
|
||||||
|
assert backend._freshness[str(palace_path)] != (0, 0.0)
|
||||||
|
|
||||||
|
|
||||||
|
def test_base_collection_update_default_rejects_mismatched_lengths(tmp_path):
|
||||||
|
"""The ABC default update() raises ValueError rather than silently misaligning."""
|
||||||
|
from mempalace.backends.base import BaseCollection
|
||||||
|
|
||||||
|
backend = ChromaBackend()
|
||||||
|
palace_path = tmp_path / "palace"
|
||||||
|
collection = backend.get_collection(
|
||||||
|
palace=PalaceRef(id=str(palace_path), local_path=str(palace_path)),
|
||||||
|
collection_name="mempalace_drawers",
|
||||||
|
create=True,
|
||||||
|
)
|
||||||
|
collection.add(
|
||||||
|
documents=["a", "b"],
|
||||||
|
ids=["1", "2"],
|
||||||
|
metadatas=[{"k": 1}, {"k": 2}],
|
||||||
|
)
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="documents length"):
|
||||||
|
BaseCollection.update(collection, ids=["1", "2"], documents=["only-one"])
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="metadatas length"):
|
||||||
|
BaseCollection.update(collection, ids=["1", "2"], metadatas=[{"k": 9}])
|
||||||
|
|
||||||
|
|
||||||
def test_chroma_backend_accepts_palace_ref_kwarg(tmp_path):
|
def test_chroma_backend_accepts_palace_ref_kwarg(tmp_path):
|
||||||
palace_path = tmp_path / "palace"
|
palace_path = tmp_path / "palace"
|
||||||
backend = ChromaBackend()
|
backend = ChromaBackend()
|
||||||
|
|||||||
Reference in New Issue
Block a user