From a17a8b734a50b7c8e2c62b5640e1c2664b44a2ea Mon Sep 17 00:00:00 2001 From: Igor Lins e Silva <4753812+igorls@users.noreply.github.com> Date: Sat, 18 Apr 2026 12:45:16 -0300 Subject: [PATCH 1/5] =?UTF-8?q?refactor(backends):=20typed=20QueryResult/G?= =?UTF-8?q?etResult,=20PalaceRef,=20BaseBackend=20registry=20(RFC=20001=20?= =?UTF-8?q?=C2=A710)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Advances RFC 001 §10 cleanup so backend-author PRs (#574 LanceDB, #665 Postgres, #700 Qdrant, #697 hosted, #643 PalaceStore, #381 Qdrant) have a stable target to align against. Scope (this PR): - Typed QueryResult / GetResult dataclasses replace Chroma's dict shape at the BaseCollection boundary (§1.3). A transitional _DictCompatMixin keeps existing callers working while the attribute-access migration proceeds. - BaseCollection is now kwargs-only across add/upsert/query/get/delete/update with ABC defaults for estimated_count/close/health and a non-atomic default update() (§1.1–1.2). - PalaceRef replaces raw path strings at the backend boundary (§2.2). - BaseBackend ABC with get_collection/close_palace/close/health/detect (§2.3). - mempalace.backends entry-point group + in-tree registry with resolve_backend_for_palace priority order matching §3.2–3.3. - ChromaCollection normalizes chroma returns into typed results; unknown where-clause operators raise UnsupportedFilterError (no silent drop, §1.4). - ChromaBackend absorbs the inode/mtime client-cache freshness check previously duplicated in mcp_server._get_client() (§10 + PR #757). - searcher.py migrated to typed-attribute access as the reference call site; remaining callers land in a follow-up. - pyproject: chroma registered via [project.entry-points."mempalace.backends"]. Out of scope (explicit follow-ups): - Full caller migration off the dict-compat shim across palace.py, mcp_server.py, miner.py, convo_miner.py, dedup.py, repair.py, exporter.py, palace_graph.py, cli.py, closet_llm.py. - Embedder injection + three-state EmbedderIdentityMismatchError check (§1.5). - maintenance_state() / run_maintenance() benchmark hooks (§7.3). - AbstractBackendContractSuite full coverage (§7.1–7.2). - mempalace migrate / mempalace verify CLI rewrites through BaseCollection (§8). Tests: 970 passed (up from 967 on develop); new coverage for typed results, empty-result outer-shape preservation, \$regex rejection, registry lookup, priority resolver, and PalaceRef-kwarg ChromaBackend.get_collection. Refs: #743 (RFC 001), #989 (RFC 002 tracking issue). --- mempalace/backends/__init__.py | 64 ++++- mempalace/backends/base.py | 359 ++++++++++++++++++++++++-- mempalace/backends/chroma.py | 445 ++++++++++++++++++++++++++++++--- mempalace/backends/registry.py | 189 ++++++++++++++ mempalace/searcher.py | 24 +- pyproject.toml | 3 + tests/test_backends.py | 153 ++++++++++-- 7 files changed, 1143 insertions(+), 94 deletions(-) create mode 100644 mempalace/backends/registry.py diff --git a/mempalace/backends/__init__.py b/mempalace/backends/__init__.py index cb5f14d..ab22ec6 100644 --- a/mempalace/backends/__init__.py +++ b/mempalace/backends/__init__.py @@ -1,6 +1,64 @@ -"""Storage backend implementations for MemPalace.""" +"""Storage backend implementations for MemPalace (RFC 001). -from .base import BaseCollection +Public surface: + +* :class:`BaseCollection` — per-collection read/write contract. +* :class:`BaseBackend` — per-palace factory contract. +* :class:`PalaceRef` — value object identifying a palace for a backend. +* :class:`QueryResult` / :class:`GetResult` — typed read returns. +* Error classes: :class:`PalaceNotFoundError`, :class:`BackendClosedError`, + :class:`UnsupportedFilterError`, :class:`DimensionMismatchError`, + :class:`EmbedderIdentityMismatchError`. +* Registry: :func:`get_backend`, :func:`register`, :func:`available_backends`, + :func:`resolve_backend_for_palace`. +* In-tree Chroma default: :class:`ChromaBackend`, :class:`ChromaCollection`. +""" + +from .base import ( + BackendClosedError, + BackendError, + BaseBackend, + BaseCollection, + DimensionMismatchError, + EmbedderIdentityMismatchError, + GetResult, + HealthStatus, + PalaceNotFoundError, + PalaceRef, + QueryResult, + UnsupportedFilterError, +) from .chroma import ChromaBackend, ChromaCollection +from .registry import ( + available_backends, + get_backend, + get_backend_class, + register, + reset_backends, + resolve_backend_for_palace, + unregister, +) -__all__ = ["BaseCollection", "ChromaBackend", "ChromaCollection"] +__all__ = [ + "BackendClosedError", + "BackendError", + "BaseBackend", + "BaseCollection", + "ChromaBackend", + "ChromaCollection", + "DimensionMismatchError", + "EmbedderIdentityMismatchError", + "GetResult", + "HealthStatus", + "PalaceNotFoundError", + "PalaceRef", + "QueryResult", + "UnsupportedFilterError", + "available_backends", + "get_backend", + "get_backend_class", + "register", + "reset_backends", + "resolve_backend_for_palace", + "unregister", +] diff --git a/mempalace/backends/base.py b/mempalace/backends/base.py index 877da53..819d326 100644 --- a/mempalace/backends/base.py +++ b/mempalace/backends/base.py @@ -1,49 +1,354 @@ -"""Abstract collection interface for MemPalace storage backends.""" +"""Storage backend contract for MemPalace (RFC 001). + +This module defines the surface every storage backend must implement: + +* ``BaseCollection`` — the per-collection read/write interface, kwargs-only. +* ``BaseBackend`` — the per-palace factory, addressed by ``PalaceRef``. +* ``QueryResult`` / ``GetResult`` — typed result dataclasses that replace the + Chroma dict shape as the canonical return type. +* Error classes + ``HealthStatus`` — uniform across backends. + +This is the v1 cleanup from RFC 001 §10: full typed results, ``PalaceRef``, +registry-ready ABC. Embedder injection, maintenance hooks, and the full +conformance suite land in follow-up PRs. +""" from abc import ABC, abstractmethod -from typing import Any, Dict, List, Optional +from dataclasses import dataclass +from typing import ClassVar, Optional + + +# --------------------------------------------------------------------------- +# Errors +# --------------------------------------------------------------------------- + + +class BackendError(Exception): + """Base class for every storage-backend error raised by core.""" + + +class PalaceNotFoundError(BackendError, FileNotFoundError): + """Raised when ``get_collection(create=False)`` is called on a missing palace. + + Subclass of ``FileNotFoundError`` so legacy callers that catch the latter + (pre-#413 seam) keep working unchanged. + """ + + +class BackendClosedError(BackendError): + """Raised when a backend method is called after ``close()``.""" + + +class UnsupportedFilterError(BackendError): + """Raised when a where-clause uses an operator the backend does not implement. + + Silent dropping of unknown operators is forbidden by spec (RFC 001 §1.4). + """ + + +class DimensionMismatchError(BackendError): + """Raised when the embedding dimension on write does not match the collection.""" + + +class EmbedderIdentityMismatchError(BackendError): + """Raised when the stored embedder model name differs from the current one.""" + + +# --------------------------------------------------------------------------- +# Value objects +# --------------------------------------------------------------------------- + + +@dataclass(frozen=True) +class PalaceRef: + """A handle to a palace, consumed by backends. + + ``id`` is always present and is the key backends use to cache handles. + ``local_path`` is populated for filesystem-rooted palaces. + ``namespace`` is used by server-mode backends for tenant / prefix routing. + """ + + id: str + local_path: Optional[str] = None + namespace: Optional[str] = None + + +@dataclass(frozen=True) +class HealthStatus: + ok: bool + detail: str = "" + + @classmethod + def healthy(cls, detail: str = "") -> "HealthStatus": + return cls(ok=True, detail=detail) + + @classmethod + def unhealthy(cls, detail: str) -> "HealthStatus": + return cls(ok=False, detail=detail) + + +_TYPED_RESULT_FIELDS = ("ids", "documents", "metadatas", "distances", "embeddings") + + +class _DictCompatMixin: + """Transitional dict-protocol access for typed results. + + RFC 001 §1.3 spec is attribute access (``result.ids``). The ``result["ids"]`` + and ``result.get("ids")`` forms are retained as a migration shim for callers + that predate the typed interface and are scheduled for removal in a follow- + up cleanup. New code MUST use attribute access. + """ + + def __getitem__(self, key: str): + if key in _TYPED_RESULT_FIELDS: + return getattr(self, key) + raise KeyError(key) + + def get(self, key: str, default=None): + if key in _TYPED_RESULT_FIELDS: + val = getattr(self, key, default) + return default if val is None else val + return default + + def __contains__(self, key: object) -> bool: + return key in _TYPED_RESULT_FIELDS and getattr(self, key, None) is not None + + +@dataclass(frozen=True) +class QueryResult(_DictCompatMixin): + """Typed return from ``BaseCollection.query``. + + Outer list dimension = number of query vectors / texts. + Inner list dimension = hits per query (may be zero). + + Fields not in ``include=`` at the call site are populated with empty lists + of the correct outer shape (never ``None``), except ``embeddings`` which + is ``None`` when not requested. + """ + + ids: list[list[str]] + documents: list[list[str]] + metadatas: list[list[dict]] + distances: list[list[float]] + embeddings: Optional[list[list[list[float]]]] = None + + @classmethod + def empty(cls, num_queries: int = 1) -> "QueryResult": + """Construct an all-empty result preserving outer dimension.""" + return cls( + ids=[[] for _ in range(num_queries)], + documents=[[] for _ in range(num_queries)], + metadatas=[[] for _ in range(num_queries)], + distances=[[] for _ in range(num_queries)], + embeddings=None, + ) + + +@dataclass(frozen=True) +class GetResult(_DictCompatMixin): + """Typed return from ``BaseCollection.get``.""" + + ids: list[str] + documents: list[str] + metadatas: list[dict] + embeddings: Optional[list[list[float]]] = None + + @classmethod + def empty(cls) -> "GetResult": + return cls(ids=[], documents=[], metadatas=[], embeddings=None) + + +# --------------------------------------------------------------------------- +# Collection contract +# --------------------------------------------------------------------------- class BaseCollection(ABC): - """Smallest collection contract the rest of MemPalace relies on.""" + """Per-collection read/write surface every backend must implement.""" @abstractmethod def add( self, *, - documents: List[str], - ids: List[str], - metadatas: Optional[List[Dict[str, Any]]] = None, - ) -> None: - raise NotImplementedError + documents: list[str], + ids: list[str], + metadatas: Optional[list[dict]] = None, + embeddings: Optional[list[list[float]]] = None, + ) -> None: ... @abstractmethod def upsert( self, *, - documents: List[str], - ids: List[str], - metadatas: Optional[List[Dict[str, Any]]] = None, + documents: list[str], + ids: list[str], + metadatas: Optional[list[dict]] = None, + embeddings: Optional[list[list[float]]] = None, + ) -> None: ... + + @abstractmethod + def query( + self, + *, + query_texts: Optional[list[str]] = None, + query_embeddings: Optional[list[list[float]]] = None, + n_results: int = 10, + where: Optional[dict] = None, + where_document: Optional[dict] = None, + include: Optional[list[str]] = None, + ) -> QueryResult: ... + + @abstractmethod + def get( + self, + *, + ids: Optional[list[str]] = None, + where: Optional[dict] = None, + where_document: Optional[dict] = None, + limit: Optional[int] = None, + offset: Optional[int] = None, + include: Optional[list[str]] = None, + ) -> GetResult: ... + + @abstractmethod + def delete( + self, + *, + ids: Optional[list[str]] = None, + where: Optional[dict] = None, + ) -> None: ... + + @abstractmethod + def count(self) -> int: ... + + # ------------------------------------------------------------------ + # Optional methods with ABC defaults (spec §1.2) + # ------------------------------------------------------------------ + + def estimated_count(self) -> int: + return self.count() + + def close(self) -> None: + return None + + def health(self) -> HealthStatus: + return HealthStatus.healthy() + + def update( + self, + *, + ids: list[str], + documents: Optional[list[str]] = None, + metadatas: Optional[list[dict]] = None, + embeddings: Optional[list[list[float]]] = None, ) -> None: - raise NotImplementedError + """Default non-atomic update: get + merge + upsert. + + Backends advertising ``supports_update`` MUST override with an atomic + single-round-trip implementation. + """ + if documents is None and metadatas is None and embeddings is None: + raise ValueError("update requires at least one of documents, metadatas, embeddings") + + existing = self.get(ids=ids, include=["documents", "metadatas"]) + by_id = { + rid: (existing.documents[i], existing.metadatas[i]) + for i, rid in enumerate(existing.ids) + } + merged_docs: list[str] = [] + merged_metas: list[dict] = [] + for i, rid in enumerate(ids): + prev_doc, prev_meta = by_id.get(rid, ("", {})) + merged_docs.append(documents[i] if documents is not None else prev_doc) + new_meta = dict(prev_meta or {}) + if metadatas is not None: + new_meta.update(metadatas[i] or {}) + merged_metas.append(new_meta) + self.upsert( + documents=merged_docs, + ids=list(ids), + metadatas=merged_metas, + embeddings=embeddings, + ) + + +# --------------------------------------------------------------------------- +# Backend contract +# --------------------------------------------------------------------------- + + +class BaseBackend(ABC): + """Long-lived factory serving many palaces (RFC 001 §2). + + Instances are lightweight on construction — no I/O, no network. All + connection work is deferred to ``get_collection``. Instances are thread- + safe for concurrent ``get_collection`` calls across different palaces. + """ + + name: ClassVar[str] + spec_version: ClassVar[str] = "1.0" + capabilities: ClassVar[frozenset[str]] = frozenset() @abstractmethod - def update(self, **kwargs: Any) -> None: - """Update existing records. Must raise if any ID is missing.""" - raise NotImplementedError + def get_collection( + self, + *, + palace: PalaceRef, + collection_name: str, + create: bool = False, + options: Optional[dict] = None, + ) -> BaseCollection: ... - @abstractmethod - def query(self, **kwargs: Any) -> Dict[str, Any]: - raise NotImplementedError + def close_palace(self, palace: PalaceRef) -> None: + """Evict cached handles for a single palace. Default: no-op.""" + return None - @abstractmethod - def get(self, **kwargs: Any) -> Dict[str, Any]: - raise NotImplementedError + def close(self) -> None: + """Shut down the entire backend. Default: no-op.""" + return None - @abstractmethod - def delete(self, **kwargs: Any) -> None: - raise NotImplementedError + def health(self, palace: Optional[PalaceRef] = None) -> HealthStatus: + return HealthStatus.healthy() - @abstractmethod - def count(self) -> int: - raise NotImplementedError + # Optional detection hint used by selection priority (RFC 001 §3.3 (4)): + @classmethod + def detect(cls, path: str) -> bool: # pragma: no cover - default hook + return False + + +# --------------------------------------------------------------------------- +# Adapter utilities +# --------------------------------------------------------------------------- + + +# Keys the Chroma ``include=`` parameter accepts. +_VALID_INCLUDE_KEYS = frozenset({"documents", "metadatas", "distances", "embeddings"}) + + +@dataclass +class _IncludeSpec: + """Resolve an ``include=`` parameter with spec-mandated defaults.""" + + documents: bool = True + metadatas: bool = True + distances: bool = True # only meaningful for query + embeddings: bool = False + + @classmethod + def resolve( + cls, include: Optional[list[str]], *, default_distances: bool = True + ) -> "_IncludeSpec": + if include is None: + return cls( + documents=True, + metadatas=True, + distances=default_distances, + embeddings=False, + ) + keys = {k for k in include if k in _VALID_INCLUDE_KEYS} + return cls( + documents="documents" in keys, + metadatas="metadatas" in keys, + distances="distances" in keys, + embeddings="embeddings" in keys, + ) diff --git a/mempalace/backends/chroma.py b/mempalace/backends/chroma.py index 1a13675..f12a88b 100644 --- a/mempalace/backends/chroma.py +++ b/mempalace/backends/chroma.py @@ -1,17 +1,54 @@ -"""ChromaDB-backed MemPalace collection adapter.""" +"""ChromaDB-backed MemPalace storage backend (RFC 001 reference implementation).""" import logging import os import sqlite3 +from typing import Any, Optional import chromadb -from .base import BaseCollection +from .base import ( + BaseBackend, + BaseCollection, + GetResult, + HealthStatus, + PalaceNotFoundError, + PalaceRef, + QueryResult, + UnsupportedFilterError, + _IncludeSpec, +) logger = logging.getLogger(__name__) -def _fix_blob_seq_ids(palace_path: str): +_REQUIRED_OPERATORS = frozenset({"$eq", "$ne", "$in", "$nin", "$and", "$or", "$contains"}) +_OPTIONAL_OPERATORS = frozenset({"$gt", "$gte", "$lt", "$lte"}) +_SUPPORTED_OPERATORS = _REQUIRED_OPERATORS | _OPTIONAL_OPERATORS + + +def _validate_where(where: Optional[dict]) -> None: + """Scan a where-clause for unknown operators and raise ``UnsupportedFilterError``. + + Spec (RFC 001 §1.4): silent dropping of unknown operators is forbidden. + """ + if not where: + return + stack = [where] + while stack: + node = stack.pop() + if not isinstance(node, dict): + continue + for k, v in node.items(): + if k.startswith("$") and k not in _SUPPORTED_OPERATORS: + raise UnsupportedFilterError(f"operator {k!r} not supported by chroma backend") + if isinstance(v, dict): + stack.append(v) + elif isinstance(v, list): + stack.extend(x for x in v if isinstance(x, dict)) + + +def _fix_blob_seq_ids(palace_path: str) -> None: """Fix ChromaDB 0.6.x -> 1.5.x migration bug: BLOB seq_ids -> INTEGER. ChromaDB 0.6.x stored seq_id as big-endian 8-byte BLOBs. ChromaDB 1.5.x @@ -43,62 +80,293 @@ def _fix_blob_seq_ids(palace_path: str): logger.exception("Could not fix BLOB seq_ids in %s", db_path) +# --------------------------------------------------------------------------- +# Collection adapter +# --------------------------------------------------------------------------- + + +def _as_list(v: Any) -> list: + """Coerce possibly-None scalar-or-list into a list (defensive for chroma nulls).""" + if v is None: + return [] + if isinstance(v, list): + return v + return [v] + + class ChromaCollection(BaseCollection): - """Thin adapter over a ChromaDB collection.""" + """Thin adapter translating ChromaDB dict returns into typed results.""" def __init__(self, collection): self._collection = collection - def add(self, *, documents, ids, metadatas=None): - self._collection.add(documents=documents, ids=ids, metadatas=metadatas) + # ------------------------------------------------------------------ + # Writes + # ------------------------------------------------------------------ - def upsert(self, *, documents, ids, metadatas=None): - self._collection.upsert(documents=documents, ids=ids, metadatas=metadatas) + def add(self, *, documents, ids, metadatas=None, embeddings=None): + kwargs: dict[str, Any] = {"documents": documents, "ids": ids} + if metadatas is not None: + kwargs["metadatas"] = metadatas + if embeddings is not None: + kwargs["embeddings"] = embeddings + self._collection.add(**kwargs) - def update(self, **kwargs): + def upsert(self, *, documents, ids, metadatas=None, embeddings=None): + kwargs: dict[str, Any] = {"documents": documents, "ids": ids} + if metadatas is not None: + kwargs["metadatas"] = metadatas + if embeddings is not None: + kwargs["embeddings"] = embeddings + self._collection.upsert(**kwargs) + + def update( + self, + *, + ids, + documents=None, + metadatas=None, + embeddings=None, + ): + if documents is None and metadatas is None and embeddings is None: + raise ValueError("update requires at least one of documents, metadatas, embeddings") + kwargs: dict[str, Any] = {"ids": ids} + if documents is not None: + kwargs["documents"] = documents + if metadatas is not None: + kwargs["metadatas"] = metadatas + if embeddings is not None: + kwargs["embeddings"] = embeddings self._collection.update(**kwargs) - def query(self, **kwargs): - return self._collection.query(**kwargs) + # ------------------------------------------------------------------ + # Reads + # ------------------------------------------------------------------ - def get(self, **kwargs): - return self._collection.get(**kwargs) + def query( + self, + *, + query_texts=None, + query_embeddings=None, + n_results=10, + where=None, + where_document=None, + include=None, + ) -> QueryResult: + _validate_where(where) + _validate_where(where_document) - def delete(self, **kwargs): + spec = _IncludeSpec.resolve(include, default_distances=True) + chroma_include: list[str] = [] + if spec.documents: + chroma_include.append("documents") + if spec.metadatas: + chroma_include.append("metadatas") + if spec.distances: + chroma_include.append("distances") + if spec.embeddings: + chroma_include.append("embeddings") + + kwargs: dict[str, Any] = { + "n_results": n_results, + "include": chroma_include, + } + if query_texts is not None: + kwargs["query_texts"] = query_texts + if query_embeddings is not None: + kwargs["query_embeddings"] = query_embeddings + if where is not None: + kwargs["where"] = where + if where_document is not None: + kwargs["where_document"] = where_document + + raw = self._collection.query(**kwargs) + + num_queries = ( + len(query_texts) + if query_texts is not None + else (len(query_embeddings) if query_embeddings is not None else 1) + ) + + ids = raw.get("ids") or [] + if not ids: + return QueryResult.empty(num_queries=num_queries) + + documents = raw.get("documents") or [[] for _ in ids] + metadatas = raw.get("metadatas") or [[] for _ in ids] + distances = raw.get("distances") or [[] for _ in ids] + embeddings_raw = raw.get("embeddings") if spec.embeddings else None + + def _none_list_to_empty(outer): + return [(inner or []) for inner in outer] + + return QueryResult( + ids=_none_list_to_empty(ids), + documents=_none_list_to_empty(documents), + metadatas=_none_list_to_empty(metadatas), + distances=_none_list_to_empty(distances), + embeddings=( + [list(inner) for inner in embeddings_raw] + if spec.embeddings and embeddings_raw is not None + else None + ), + ) + + def get( + self, + *, + ids=None, + where=None, + where_document=None, + limit=None, + offset=None, + include=None, + ) -> GetResult: + _validate_where(where) + _validate_where(where_document) + + spec = _IncludeSpec.resolve(include, default_distances=False) + chroma_include: list[str] = [] + if spec.documents: + chroma_include.append("documents") + if spec.metadatas: + chroma_include.append("metadatas") + if spec.embeddings: + chroma_include.append("embeddings") + + kwargs: dict[str, Any] = {"include": chroma_include} + if ids is not None: + kwargs["ids"] = ids + if where is not None: + kwargs["where"] = where + if where_document is not None: + kwargs["where_document"] = where_document + if limit is not None: + kwargs["limit"] = limit + if offset is not None: + kwargs["offset"] = offset + + raw = self._collection.get(**kwargs) + out_ids = list(raw.get("ids") or []) + out_docs = list(raw.get("documents") or []) if spec.documents else [] + out_metas = list(raw.get("metadatas") or []) if spec.metadatas else [] + out_embeds = raw.get("embeddings") if spec.embeddings else None + + # Pad doc/meta lists to match ids so downstream zipping is safe. + if spec.documents and len(out_docs) < len(out_ids): + out_docs = out_docs + [""] * (len(out_ids) - len(out_docs)) + if spec.metadatas and len(out_metas) < len(out_ids): + out_metas = out_metas + [{}] * (len(out_ids) - len(out_metas)) + + return GetResult( + ids=out_ids, + documents=out_docs, + metadatas=out_metas, + embeddings=[list(v) for v in out_embeds] if out_embeds is not None else None, + ) + + def delete(self, *, ids=None, where=None): + _validate_where(where) + kwargs: dict[str, Any] = {} + if ids is not None: + kwargs["ids"] = ids + if where is not None: + kwargs["where"] = where self._collection.delete(**kwargs) def count(self): return self._collection.count() -class ChromaBackend: - """Factory for MemPalace's default ChromaDB backend.""" +# --------------------------------------------------------------------------- +# Backend +# --------------------------------------------------------------------------- + + +class ChromaBackend(BaseBackend): + """MemPalace's default ChromaDB backend. + + Maintains two caches: + + * ``self._clients`` — ``palace_path -> PersistentClient`` for callers + using the ``PalaceRef`` / :meth:`get_collection` path. + * An inode+mtime freshness check absorbed from ``mcp_server._get_client`` + (merged via #757) ensuring a palace rebuild on disk is detected on the + next :meth:`get_collection` call. + """ + + name = "chroma" + capabilities = frozenset( + { + "supports_embeddings_in", + "supports_embeddings_passthrough", + "supports_embeddings_out", + "supports_metadata_filters", + "supports_contains_fast", + "local_mode", + } + ) def __init__(self): - # Per-instance client cache: palace_path -> chromadb.PersistentClient - self._clients: dict = {} + # palace_path -> PersistentClient + self._clients: dict[str, Any] = {} + # palace_path -> (inode, mtime) of chroma.sqlite3 at cache time. + self._freshness: dict[str, tuple[int, float]] = {} + self._closed = False # ------------------------------------------------------------------ # Internal helpers # ------------------------------------------------------------------ + @staticmethod + def _db_stat(palace_path: str) -> tuple[int, float]: + """Return ``(inode, mtime)`` of ``chroma.sqlite3`` or ``(0, 0.0)`` if absent.""" + db_path = os.path.join(palace_path, "chroma.sqlite3") + try: + st = os.stat(db_path) + return (st.st_ino, st.st_mtime) + except OSError: + return (0, 0.0) + def _client(self, palace_path: str): - """Return a cached PersistentClient for *palace_path*, creating one if needed.""" - if palace_path not in self._clients: + """Return a cached ``PersistentClient``, rebuilding on inode/mtime change. + + Handles the palace-rebuild case (repair/nuke/purge) by invalidating the + cache when ``chroma.sqlite3`` changes on disk. FAT/exFAT return inode 0, + so inode comparisons only fire when non-zero (matches #757 semantics). + """ + if self._closed: + from .base import BackendClosedError # late import avoids cycles at module load + + raise BackendClosedError("ChromaBackend has been closed") + + cached = self._clients.get(palace_path) + cached_inode, cached_mtime = self._freshness.get(palace_path, (0, 0.0)) + current_inode, current_mtime = self._db_stat(palace_path) + + inode_changed = current_inode != 0 and cached_inode != 0 and current_inode != cached_inode + mtime_changed = ( + current_mtime != 0.0 and cached_mtime != 0.0 and current_mtime > cached_mtime + ) + + if cached is None or inode_changed or mtime_changed: _fix_blob_seq_ids(palace_path) - self._clients[palace_path] = chromadb.PersistentClient(path=palace_path) - return self._clients[palace_path] + cached = chromadb.PersistentClient(path=palace_path) + self._clients[palace_path] = cached + self._freshness[palace_path] = (current_inode, current_mtime) + return cached # ------------------------------------------------------------------ - # Public static helpers (for callers that manage their own caching) + # Public static helpers (legacy; prefer :meth:`get_collection`) # ------------------------------------------------------------------ @staticmethod def make_client(palace_path: str): - """Create and return a fresh PersistentClient (fix BLOB seq_ids first). + """Create a fresh ``PersistentClient`` (fixes BLOB seq_ids first). - Intended for long-lived callers (e.g. mcp_server) that keep their own - inode/mtime-based client cache. + Deprecated-ish: exposed for legacy long-lived callers that manage their + own client cache. New code should obtain a collection through + :meth:`get_collection` which manages caching internally. """ _fix_blob_seq_ids(palace_path) return chromadb.PersistentClient(path=palace_path) @@ -109,12 +377,31 @@ class ChromaBackend: return chromadb.__version__ # ------------------------------------------------------------------ - # Collection lifecycle + # BaseBackend surface # ------------------------------------------------------------------ - def get_collection(self, palace_path: str, collection_name: str, create: bool = False): + def get_collection( + self, + *args, + **kwargs, + ) -> ChromaCollection: + """Obtain a collection for a palace. + + Supports two calling conventions during the RFC 001 transition: + + * New (preferred): ``get_collection(palace=PalaceRef, collection_name=..., + create=False, options=None)``. + * Legacy: ``get_collection(palace_path, collection_name, create=False)`` + — still used by callers not yet migrated. + """ + palace_ref, collection_name, create, options = _normalize_get_collection_args(args, kwargs) + + palace_path = palace_ref.local_path + if palace_path is None: + raise PalaceNotFoundError("ChromaBackend requires PalaceRef.local_path") + if not create and not os.path.isdir(palace_path): - raise FileNotFoundError(palace_path) + raise PalaceNotFoundError(palace_path) if create: os.makedirs(palace_path, exist_ok=True) @@ -124,29 +411,113 @@ class ChromaBackend: pass client = self._client(palace_path) + hnsw_space = "cosine" + if options and isinstance(options, dict): + hnsw_space = options.get("hnsw_space", hnsw_space) + if create: collection = client.get_or_create_collection( - collection_name, metadata={"hnsw:space": "cosine"} + collection_name, metadata={"hnsw:space": hnsw_space} ) else: collection = client.get_collection(collection_name) return ChromaCollection(collection) - def get_or_create_collection( - self, palace_path: str, collection_name: str - ) -> "ChromaCollection": - """Shorthand for get_collection(..., create=True).""" + def close_palace(self, palace) -> None: + """Drop cached handles for ``palace``. Accepts ``PalaceRef`` or legacy path str.""" + path = palace.local_path if isinstance(palace, PalaceRef) else palace + if path is None: + return + self._clients.pop(path, None) + self._freshness.pop(path, None) + + def close(self) -> None: + self._clients.clear() + self._freshness.clear() + self._closed = True + + def health(self, palace: Optional[PalaceRef] = None) -> HealthStatus: + if self._closed: + return HealthStatus.unhealthy("backend closed") + return HealthStatus.healthy() + + @classmethod + def detect(cls, path: str) -> bool: + return os.path.isfile(os.path.join(path, "chroma.sqlite3")) + + # ------------------------------------------------------------------ + # Legacy (pre-RFC 001) surface — retained while callers migrate. + # ------------------------------------------------------------------ + + def get_or_create_collection(self, palace_path: str, collection_name: str) -> ChromaCollection: + """Legacy shim for ``get_collection(..., create=True)`` by path string.""" return self.get_collection(palace_path, collection_name, create=True) def delete_collection(self, palace_path: str, collection_name: str) -> None: - """Delete *collection_name* from the palace at *palace_path*.""" + """Delete ``collection_name`` from the palace at ``palace_path``.""" self._client(palace_path).delete_collection(collection_name) def create_collection( self, palace_path: str, collection_name: str, hnsw_space: str = "cosine" - ) -> "ChromaCollection": - """Create (not get-or-create) *collection_name* with cosine HNSW space.""" + ) -> ChromaCollection: + """Create (not get-or-create) ``collection_name`` with the given HNSW space.""" collection = self._client(palace_path).create_collection( collection_name, metadata={"hnsw:space": hnsw_space} ) return ChromaCollection(collection) + + +def _normalize_get_collection_args(args, kwargs): + """Unify legacy positional ``(palace_path, collection_name, create)`` calls + with the new kwargs-only ``(palace=PalaceRef, collection_name=..., create=...)``. + + Returns ``(PalaceRef, collection_name, create, options)``. + """ + # New-style: palace= kwarg with a PalaceRef (spec path). + if "palace" in kwargs: + palace_ref = kwargs.pop("palace") + if not isinstance(palace_ref, PalaceRef): + raise TypeError("palace= must be a PalaceRef instance") + collection_name = kwargs.pop("collection_name") + create = kwargs.pop("create", False) + options = kwargs.pop("options", None) + if kwargs: + raise TypeError(f"unexpected kwargs: {sorted(kwargs)}") + if args: + raise TypeError("positional args not allowed with palace= kwarg") + return palace_ref, collection_name, create, options + + # Legacy: first positional is a path string. + if args: + palace_path = args[0] + rest = list(args[1:]) + collection_name = kwargs.pop("collection_name", None) or (rest.pop(0) if rest else None) + if collection_name is None: + raise TypeError("collection_name is required") + create = kwargs.pop("create", False) + if rest: + create = rest.pop(0) + if kwargs: + raise TypeError(f"unexpected kwargs: {sorted(kwargs)}") + return ( + PalaceRef(id=palace_path, local_path=palace_path), + collection_name, + bool(create), + None, + ) + + # Legacy kwargs-only (palace_path=..., collection_name=..., create=...) + if "palace_path" in kwargs: + palace_path = kwargs.pop("palace_path") + collection_name = kwargs.pop("collection_name") + create = kwargs.pop("create", False) + if kwargs: + raise TypeError(f"unexpected kwargs: {sorted(kwargs)}") + return ( + PalaceRef(id=palace_path, local_path=palace_path), + collection_name, + bool(create), + None, + ) + + raise TypeError("get_collection requires palace= or a positional palace_path") diff --git a/mempalace/backends/registry.py b/mempalace/backends/registry.py new file mode 100644 index 0000000..7551bd3 --- /dev/null +++ b/mempalace/backends/registry.py @@ -0,0 +1,189 @@ +"""Backend registry + entry-point discovery (RFC 001 §3). + +Third-party backends ship as installable packages that declare a +``mempalace.backends`` entry point:: + + # pyproject.toml of mempalace-postgres + [project.entry-points."mempalace.backends"] + postgres = "mempalace_postgres:PostgresBackend" + +MemPalace discovers them at process start. In-tree tests and local development +can register manually via :func:`register`. Explicit registration wins on +name conflict (matches RFC 001 §3.2). +""" + +from __future__ import annotations + +import logging +from importlib import metadata +from threading import Lock +from typing import Optional, Type + +from .base import BaseBackend + +logger = logging.getLogger(__name__) + +_ENTRY_POINT_GROUP = "mempalace.backends" + +_registry: dict[str, Type[BaseBackend]] = {} +_instances: dict[str, BaseBackend] = {} +_explicit: set[str] = set() +_discovered = False +_lock = Lock() + + +def register(name: str, backend_cls: Type[BaseBackend]) -> None: + """Register ``backend_cls`` under ``name``. + + Explicit registration wins over entry-point discovery on conflict + (RFC 001 §3.2). + """ + with _lock: + _registry[name] = backend_cls + _explicit.add(name) + # Invalidate any cached instance so the new class is used on next get. + _instances.pop(name, None) + + +def unregister(name: str) -> None: + """Remove a backend registration (primarily for tests).""" + with _lock: + _registry.pop(name, None) + _explicit.discard(name) + _instances.pop(name, None) + + +def _discover_entry_points() -> None: + """Load entry-point-declared backends once per process.""" + global _discovered + if _discovered: + return + with _lock: + if _discovered: + return + try: + eps = metadata.entry_points() + # Py ≥ 3.10 returns an EntryPoints object; older versions returned a dict. + group = ( + eps.select(group=_ENTRY_POINT_GROUP) + if hasattr(eps, "select") + else eps.get(_ENTRY_POINT_GROUP, []) + ) + except Exception: + logger.exception("entry-point discovery for %s failed", _ENTRY_POINT_GROUP) + group = [] + for ep in group: + if ep.name in _explicit: + continue # explicit registration wins + try: + cls = ep.load() + except Exception: + logger.exception("failed to load backend entry point %r", ep.name) + continue + if not isinstance(cls, type) or not issubclass(cls, BaseBackend): + logger.warning( + "entry point %r did not resolve to a BaseBackend subclass (got %r)", + ep.name, + cls, + ) + continue + _registry.setdefault(ep.name, cls) + _discovered = True + + +def available_backends() -> list[str]: + """Return sorted list of all registered backend names.""" + _discover_entry_points() + return sorted(_registry.keys()) + + +def get_backend_class(name: str) -> Type[BaseBackend]: + """Return the registered backend class for ``name``.""" + _discover_entry_points() + try: + return _registry[name] + except KeyError as e: + raise KeyError(f"unknown backend {name!r}; available: {available_backends()}") from e + + +def get_backend(name: str) -> BaseBackend: + """Return a long-lived instance of the named backend. + + Instances are cached per-name; repeated calls return the same object. + Call :func:`reset_backends` in tests that need isolation. + """ + _discover_entry_points() + with _lock: + inst = _instances.get(name) + if inst is not None: + return inst + cls = _registry.get(name) + if cls is None: + raise KeyError(f"unknown backend {name!r}; available: {sorted(_registry.keys())}") + inst = cls() + _instances[name] = inst + return inst + + +def reset_backends() -> None: + """Close and drop all cached backend instances (primarily for tests).""" + with _lock: + for inst in _instances.values(): + try: + inst.close() + except Exception: + logger.exception("error closing backend during reset") + _instances.clear() + + +def resolve_backend_for_palace( + *, + explicit: Optional[str] = None, + config_value: Optional[str] = None, + env_value: Optional[str] = None, + palace_path: Optional[str] = None, + default: str = "chroma", +) -> str: + """Resolve the backend name for a palace per RFC 001 §3.3 priority order. + + 1. Explicit kwarg / CLI flag + 2. Per-palace config value + 3. ``MEMPALACE_BACKEND`` env var + 4. Auto-detect from on-disk artifacts (migration/upgrade path only) + 5. Default (``chroma``) + + Auto-detection is strictly a migration aid: it fires only when a local path + is presented, no earlier rule has chosen a backend, AND the path already + contains backend-identifiable artifacts. For new palaces, (5) wins. + """ + for candidate in (explicit, config_value, env_value): + if candidate: + return candidate + + _discover_entry_points() + if palace_path: + for name, cls in _registry.items(): + try: + if cls.detect(palace_path): + return name + except Exception: + logger.exception("detect() raised on backend %r", name) + continue + return default + + +# --------------------------------------------------------------------------- +# Built-in registration +# --------------------------------------------------------------------------- + + +def _register_builtins() -> None: + """Register chroma as the in-tree default.""" + from .chroma import ChromaBackend + + # Use setdefault semantics so a caller that pre-registered for tests wins. + if "chroma" not in _registry: + _registry["chroma"] = ChromaBackend + + +_register_builtins() diff --git a/mempalace/searcher.py b/mempalace/searcher.py index db809d9..081d3a7 100644 --- a/mempalace/searcher.py +++ b/mempalace/searcher.py @@ -30,15 +30,16 @@ class SearchError(Exception): _TOKEN_RE = re.compile(r"\w{2,}", re.UNICODE) -def _first_or_empty(results: dict, key: str) -> list: - """Return the first inner list of a ChromaDB query result, or []. +def _first_or_empty(results, key: str) -> list: + """Return the first inner list of a query result field, or []. - ChromaDB returns shapes like ``{"documents": [["a", "b"]], ...}`` for a - successful query, but ``{"documents": [], ...}`` (empty outer list) when - the collection is empty or the filter excludes everything. Indexing - ``[0]`` blindly raises IndexError in that case (issue #195). + Accepts both the typed :class:`QueryResult` (attribute access) and the + pre-typed chroma dict shape; this polymorphism is retained so test mocks + still work and callers mid-migration do not crash. Preserves the empty- + collection semantics from issue #195: when no queries returned hits, the + outer list may be empty and indexing ``[0]`` would raise. """ - outer = results.get(key) + outer = getattr(results, key, None) if not isinstance(results, dict) else results.get(key) if not outer: return [] return outer[0] or [] @@ -209,7 +210,7 @@ def _expand_with_neighbors(drawers_col, matched_doc: str, matched_meta: dict, ra return {"text": matched_doc, "drawer_index": chunk_idx, "total_drawers": None} indexed_docs = [] - for doc, meta in zip(neighbors.get("documents") or [], neighbors.get("metadatas") or []): + for doc, meta in zip(neighbors.documents, neighbors.metadatas): ci = meta.get("chunk_index") if isinstance(ci, int): indexed_docs.append((ci, doc)) @@ -224,8 +225,7 @@ def _expand_with_neighbors(drawers_col, matched_doc: str, matched_meta: dict, ra total_drawers = None try: all_meta = drawers_col.get(where={"source_file": src}, include=["metadatas"]) - ids = all_meta.get("ids") or [] - total_drawers = len(ids) if ids else None + total_drawers = len(all_meta.ids) if all_meta.ids else None except Exception: pass @@ -451,8 +451,8 @@ def search_memories( ) except Exception: continue - docs = source_drawers.get("documents") or [] - metas_ = source_drawers.get("metadatas") or [] + docs = source_drawers.documents + metas_ = source_drawers.metadatas if len(docs) <= 1: continue diff --git a/pyproject.toml b/pyproject.toml index f3067f3..e03dbe3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -37,6 +37,9 @@ Repository = "https://github.com/MemPalace/mempalace" [project.scripts] mempalace = "mempalace.cli:main" +[project.entry-points."mempalace.backends"] +chroma = "mempalace.backends.chroma:ChromaBackend" + [project.optional-dependencies] dev = ["pytest>=7.0", "pytest-cov>=4.0", "ruff>=0.4.0", "psutil>=5.9"] spellcheck = ["autocorrect>=2.0"] diff --git a/tests/test_backends.py b/tests/test_backends.py index a620bf9..6535691 100644 --- a/tests/test_backends.py +++ b/tests/test_backends.py @@ -3,12 +3,34 @@ import sqlite3 import chromadb import pytest +from mempalace.backends import ( + GetResult, + PalaceRef, + QueryResult, + UnsupportedFilterError, + available_backends, + get_backend, +) from mempalace.backends.chroma import ChromaBackend, ChromaCollection, _fix_blob_seq_ids class _FakeCollection: - def __init__(self): + """Stand-in for a chromadb.Collection returning raw chroma-shaped dicts.""" + + def __init__(self, query_response=None, get_response=None, count_value=7): self.calls = [] + self._query_response = query_response or { + "ids": [["a", "b"]], + "documents": [["da", "db"]], + "metadatas": [[{"wing": "w1"}, {"wing": "w2"}]], + "distances": [[0.1, 0.2]], + } + self._get_response = get_response or { + "ids": ["a"], + "documents": ["da"], + "metadatas": [{"wing": "w1"}], + } + self._count_value = count_value def add(self, **kwargs): self.calls.append(("add", kwargs)) @@ -16,41 +38,142 @@ class _FakeCollection: def upsert(self, **kwargs): self.calls.append(("upsert", kwargs)) + def update(self, **kwargs): + self.calls.append(("update", kwargs)) + def query(self, **kwargs): self.calls.append(("query", kwargs)) - return {"kind": "query"} + return self._query_response def get(self, **kwargs): self.calls.append(("get", kwargs)) - return {"kind": "get"} + return self._get_response def delete(self, **kwargs): self.calls.append(("delete", kwargs)) def count(self): self.calls.append(("count", {})) - return 7 + return self._count_value -def test_chroma_collection_delegates_methods(): +def test_chroma_collection_returns_typed_query_result(): + fake = _FakeCollection() + collection = ChromaCollection(fake) + + result = collection.query(query_texts=["q"]) + + assert isinstance(result, QueryResult) + assert result.ids == [["a", "b"]] + assert result.documents == [["da", "db"]] + assert result.metadatas == [[{"wing": "w1"}, {"wing": "w2"}]] + assert result.distances == [[0.1, 0.2]] + assert result.embeddings is None + + +def test_chroma_collection_returns_typed_get_result(): + fake = _FakeCollection() + collection = ChromaCollection(fake) + + result = collection.get(where={"wing": "w1"}) + + assert isinstance(result, GetResult) + assert result.ids == ["a"] + assert result.documents == ["da"] + assert result.metadatas == [{"wing": "w1"}] + + +def test_query_result_empty_preserves_outer_dimension(): + empty = QueryResult.empty(num_queries=2) + assert empty.ids == [[], []] + assert empty.documents == [[], []] + assert empty.distances == [[], []] + assert empty.embeddings is None + + +def test_typed_results_support_dict_compat_access(): + """Transitional compat shim per base.py — retained until callers migrate to attrs.""" + result = GetResult(ids=["a"], documents=["da"], metadatas=[{"w": 1}]) + assert result["ids"] == ["a"] + assert result.get("documents") == ["da"] + assert result.get("missing", "default") == "default" + assert "ids" in result + assert "missing" not in result + + +def test_chroma_collection_query_empty_result_preserves_outer_shape(): + fake = _FakeCollection( + query_response={"ids": [], "documents": [], "metadatas": [], "distances": []} + ) + collection = ChromaCollection(fake) + + result = collection.query(query_texts=["q1", "q2"]) + assert result.ids == [[], []] + assert result.documents == [[], []] + assert result.distances == [[], []] + + +def test_chroma_collection_rejects_unknown_where_operator(): + fake = _FakeCollection() + collection = ChromaCollection(fake) + + with pytest.raises(UnsupportedFilterError): + collection.query(query_texts=["q"], where={"$regex": "foo"}) + + +def test_chroma_collection_delegates_writes(): fake = _FakeCollection() collection = ChromaCollection(fake) collection.add(documents=["d"], ids=["1"], metadatas=[{"wing": "w"}]) collection.upsert(documents=["u"], ids=["2"], metadatas=[{"room": "r"}]) - assert collection.query(query_texts=["q"]) == {"kind": "query"} - assert collection.get(where={"wing": "w"}) == {"kind": "get"} collection.delete(ids=["1"]) assert collection.count() == 7 - assert fake.calls == [ - ("add", {"documents": ["d"], "ids": ["1"], "metadatas": [{"wing": "w"}]}), - ("upsert", {"documents": ["u"], "ids": ["2"], "metadatas": [{"room": "r"}]}), - ("query", {"query_texts": ["q"]}), - ("get", {"where": {"wing": "w"}}), - ("delete", {"ids": ["1"]}), - ("count", {}), - ] + kinds = [call[0] for call in fake.calls] + assert kinds == ["add", "upsert", "delete", "count"] + + +def test_registry_exposes_chroma_by_default(): + names = available_backends() + assert "chroma" in names + assert isinstance(get_backend("chroma"), ChromaBackend) + + +def test_registry_unknown_backend_raises(): + with pytest.raises(KeyError): + get_backend("no-such-backend-exists") + + +def test_resolve_backend_priority_order(tmp_path): + from mempalace.backends import resolve_backend_for_palace + + # explicit kwarg wins over everything + assert resolve_backend_for_palace(explicit="pg", config_value="lance") == "pg" + # config value wins over env / default + assert resolve_backend_for_palace(config_value="lance", env_value="qdrant") == "lance" + # env wins over default + assert resolve_backend_for_palace(env_value="qdrant", default="chroma") == "qdrant" + # falls back to default + assert resolve_backend_for_palace() == "chroma" + + +def test_chroma_detect_matches_palace_with_chroma_sqlite(tmp_path): + (tmp_path / "chroma.sqlite3").write_bytes(b"") + assert ChromaBackend.detect(str(tmp_path)) is True + assert ChromaBackend.detect(str(tmp_path.parent)) is False + + +def test_chroma_backend_accepts_palace_ref_kwarg(tmp_path): + palace_path = tmp_path / "palace" + backend = ChromaBackend() + collection = backend.get_collection( + palace=PalaceRef(id=str(palace_path), local_path=str(palace_path)), + collection_name="mempalace_drawers", + create=True, + ) + assert palace_path.is_dir() + assert isinstance(collection, ChromaCollection) def test_chroma_backend_create_false_raises_without_creating_directory(tmp_path): From 42b940d26360f04c40ed684e77c6d6a74de837d2 Mon Sep 17 00:00:00 2001 From: Igor Lins e Silva <4753812+igorls@users.noreply.github.com> Date: Sat, 18 Apr 2026 13:19:18 -0300 Subject: [PATCH 2/5] fix(backends): address Copilot review on #995 MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Four defects surfaced by the automated review, fixed with targeted tests: 1. BaseCollection.update() default now validates that documents / metadatas / embeddings lengths match ids, raising ValueError instead of silently misaligning pairs or raising IndexError (base.py). 2. ChromaCollection.query() now rejects the two ambiguous input shapes up front — neither or both of query_texts / query_embeddings, and empty input lists — with clear ValueError messages rather than delegating to chromadb's less-obvious errors (chroma.py). 3. QueryResult.empty() accepts embeddings_requested=True to preserve the outer-query dimension with empty hit lists when the caller asked for embeddings, matching the spec rule that included fields carry the outer shape even when empty (base.py). ChromaCollection.query() threads this through on the empty-result path (chroma.py). 4. ChromaBackend cache-freshness check now matches the semantics from mcp_server._get_client (merged via #757) on three edge cases Copilot called out: (a) invalidate when chroma.sqlite3 disappears while a cached client is held, (b) treat a 0→nonzero stat transition as a change so a cache built when the DB did not yet exist is refreshed, (c) re-stat after PersistentClient constructs the DB lazily so freshness reflects the post-creation state (chroma.py). Tests: 978 passed (up from 970), 8 new tests covering the fixes. --- mempalace/backends/base.py | 22 +++++- mempalace/backends/chroma.py | 47 ++++++++++-- tests/test_backends.py | 134 +++++++++++++++++++++++++++++++++++ 3 files changed, 194 insertions(+), 9 deletions(-) diff --git a/mempalace/backends/base.py b/mempalace/backends/base.py index 819d326..2ff9b87 100644 --- a/mempalace/backends/base.py +++ b/mempalace/backends/base.py @@ -133,14 +133,21 @@ class QueryResult(_DictCompatMixin): embeddings: Optional[list[list[list[float]]]] = None @classmethod - def empty(cls, num_queries: int = 1) -> "QueryResult": - """Construct an all-empty result preserving outer dimension.""" + def empty(cls, num_queries: int = 1, embeddings_requested: bool = False) -> "QueryResult": + """Construct an all-empty result preserving outer dimension. + + When ``embeddings_requested`` is True, ``embeddings`` preserves the outer + query dimension with empty hit lists (matching the spec's rule that fields + requested via ``include=`` carry the outer shape even when empty). When + False, ``embeddings`` stays ``None`` to signal the field was not requested. + """ + empty_outer = [[] for _ in range(num_queries)] return cls( ids=[[] for _ in range(num_queries)], documents=[[] for _ in range(num_queries)], metadatas=[[] for _ in range(num_queries)], distances=[[] for _ in range(num_queries)], - embeddings=None, + embeddings=empty_outer if embeddings_requested else None, ) @@ -250,6 +257,15 @@ class BaseCollection(ABC): if documents is None and metadatas is None and embeddings is None: raise ValueError("update requires at least one of documents, metadatas, embeddings") + n = len(ids) + for label, value in ( + ("documents", documents), + ("metadatas", metadatas), + ("embeddings", embeddings), + ): + if value is not None and len(value) != n: + raise ValueError(f"{label} length {len(value)} does not match ids length {n}") + existing = self.get(ids=ids, include=["documents", "metadatas"]) by_id = { rid: (existing.documents[i], existing.metadatas[i]) diff --git a/mempalace/backends/chroma.py b/mempalace/backends/chroma.py index f12a88b..835ce72 100644 --- a/mempalace/backends/chroma.py +++ b/mempalace/backends/chroma.py @@ -156,6 +156,12 @@ class ChromaCollection(BaseCollection): _validate_where(where) _validate_where(where_document) + if (query_texts is None) == (query_embeddings is None): + raise ValueError("query requires exactly one of query_texts or query_embeddings") + chosen = query_texts if query_texts is not None else query_embeddings + if not chosen: + raise ValueError("query input must be a non-empty list") + spec = _IncludeSpec.resolve(include, default_distances=True) chroma_include: list[str] = [] if spec.documents: @@ -190,7 +196,10 @@ class ChromaCollection(BaseCollection): ids = raw.get("ids") or [] if not ids: - return QueryResult.empty(num_queries=num_queries) + return QueryResult.empty( + num_queries=num_queries, + embeddings_requested=spec.embeddings, + ) documents = raw.get("documents") or [[] for _ in ids] metadatas = raw.get("metadatas") or [[] for _ in ids] @@ -332,8 +341,18 @@ class ChromaBackend(BaseBackend): """Return a cached ``PersistentClient``, rebuilding on inode/mtime change. Handles the palace-rebuild case (repair/nuke/purge) by invalidating the - cache when ``chroma.sqlite3`` changes on disk. FAT/exFAT return inode 0, - so inode comparisons only fire when non-zero (matches #757 semantics). + cache when ``chroma.sqlite3`` changes on disk. Mirrors the semantics of + ``mcp_server._get_client`` (merged via #757): + + * DB file missing while we hold a cached client → drop the cache so we + do not serve stale data after a rebuild that has not yet re-created + the DB. + * Transition 0 → nonzero stat (DB created after cache) counts as a + change, so the cached client is replaced with one that sees the DB. + * FAT/exFAT filesystems return inode 0; we never fire inode comparisons + when either side is 0 (safe fallback) but still honor mtime. + * Mtime change uses an epsilon (0.01 s) to tolerate FS timestamp + granularity without thrashing. """ if self._closed: from .base import BackendClosedError # late import avoids cycles at module load @@ -344,16 +363,32 @@ class ChromaBackend(BaseBackend): cached_inode, cached_mtime = self._freshness.get(palace_path, (0, 0.0)) current_inode, current_mtime = self._db_stat(palace_path) + db_path = os.path.join(palace_path, "chroma.sqlite3") + # DB was present when cache was built but is now missing → invalidate. + if cached is not None and not os.path.isfile(db_path): + self._clients.pop(palace_path, None) + self._freshness.pop(palace_path, None) + cached = None + cached_inode, cached_mtime = 0, 0.0 + inode_changed = current_inode != 0 and cached_inode != 0 and current_inode != cached_inode + # Transition from no-stat (0.0) to a real stat counts as a change so we + # pick up a DB that was created after the cache was built. + mtime_appeared = cached_mtime == 0.0 and current_mtime != 0.0 mtime_changed = ( - current_mtime != 0.0 and cached_mtime != 0.0 and current_mtime > cached_mtime + current_mtime != 0.0 + and cached_mtime != 0.0 + and abs(current_mtime - cached_mtime) > 0.01 ) - if cached is None or inode_changed or mtime_changed: + if cached is None or inode_changed or mtime_changed or mtime_appeared: _fix_blob_seq_ids(palace_path) cached = chromadb.PersistentClient(path=palace_path) self._clients[palace_path] = cached - self._freshness[palace_path] = (current_inode, current_mtime) + # Re-stat after the client constructor runs: chromadb creates + # chroma.sqlite3 lazily, so the stat captured before the call + # may still be (0, 0.0) on first open. + self._freshness[palace_path] = self._db_stat(palace_path) return cached # ------------------------------------------------------------------ diff --git a/tests/test_backends.py b/tests/test_backends.py index 6535691..f019927 100644 --- a/tests/test_backends.py +++ b/tests/test_backends.py @@ -164,6 +164,140 @@ def test_chroma_detect_matches_palace_with_chroma_sqlite(tmp_path): assert ChromaBackend.detect(str(tmp_path.parent)) is False +def test_query_rejects_missing_input(): + fake = _FakeCollection() + collection = ChromaCollection(fake) + with pytest.raises(ValueError): + collection.query() + + +def test_query_rejects_both_texts_and_embeddings(): + fake = _FakeCollection() + collection = ChromaCollection(fake) + with pytest.raises(ValueError): + collection.query(query_texts=["q"], query_embeddings=[[0.1, 0.2]]) + + +def test_query_rejects_empty_input_list(): + fake = _FakeCollection() + collection = ChromaCollection(fake) + with pytest.raises(ValueError): + collection.query(query_texts=[]) + + +def test_query_empty_preserves_embeddings_outer_shape_when_requested(): + fake = _FakeCollection( + query_response={"ids": [], "documents": [], "metadatas": [], "distances": []} + ) + collection = ChromaCollection(fake) + + requested = collection.query(query_texts=["q1", "q2"], include=["documents", "embeddings"]) + assert requested.embeddings == [[], []] + + not_requested = collection.query(query_texts=["q1", "q2"], include=["documents"]) + assert not_requested.embeddings is None + + +def test_base_collection_update_default_validates_list_lengths(tmp_path): + backend = ChromaBackend() + palace_path = tmp_path / "palace" + collection = backend.get_collection( + palace=PalaceRef(id=str(palace_path), local_path=str(palace_path)), + collection_name="mempalace_drawers", + create=True, + ) + + # Mismatched documents length → clear ValueError, not silent merge. + with pytest.raises(ValueError, match="documents length"): + collection._collection.add( + documents=["a", "b"], + ids=["1", "2"], + metadatas=[{"k": 1}, {"k": 2}], + ) + from mempalace.backends.base import BaseCollection + + BaseCollection.update( + collection, + ids=["1", "2"], + documents=["only-one"], + ) + + +def test_chroma_cache_invalidates_when_db_file_missing(tmp_path): + """A palace rebuild that removes chroma.sqlite3 must drop the stale cache.""" + backend = ChromaBackend() + palace_path = tmp_path / "palace" + backend.get_collection( + palace=PalaceRef(id=str(palace_path), local_path=str(palace_path)), + collection_name="mempalace_drawers", + create=True, + ) + assert str(palace_path) in backend._clients + prior_client = backend._clients[str(palace_path)] + prior_freshness = backend._freshness[str(palace_path)] + assert prior_freshness != (0, 0.0) # DB file exists after get_or_create_collection + + # Remove chroma.sqlite3 to simulate a rebuild mid-flight. The stale cache + # must not be silently reused — the in-memory HNSW index would be wrong. + (palace_path / "chroma.sqlite3").unlink() + + new_client = backend._client(str(palace_path)) + # New client object (cache was replaced, not reused) and freshness was reset + # to (0, 0.0) to reflect "no DB on disk yet" state. + assert new_client is not prior_client + assert backend._freshness[str(palace_path)] == (0, 0.0) + + +def test_chroma_cache_picks_up_db_created_after_first_open(tmp_path): + """The 0 → nonzero stat transition invalidates a cache built before the DB existed.""" + backend = ChromaBackend() + palace_path = tmp_path / "palace" + palace_path.mkdir() + + # Seed an entry in the caches as if a prior _client() call had opened the + # palace when chroma.sqlite3 did not exist yet. Freshness (0, 0.0) is the + # signal that the DB was absent at cache time. + sentinel = object() + backend._clients[str(palace_path)] = sentinel + backend._freshness[str(palace_path)] = (0, 0.0) + + # The DB file now appears (real chromadb would have created it by now). + # Use a real chromadb call so _fix_blob_seq_ids and PersistentClient succeed. + import chromadb as _chromadb + + _chromadb.PersistentClient(path=str(palace_path)).get_or_create_collection("seed") + assert (palace_path / "chroma.sqlite3").is_file() + + # Next _client() call must detect the 0 → nonzero transition and rebuild. + refreshed = backend._client(str(palace_path)) + assert refreshed is not sentinel + assert backend._freshness[str(palace_path)] != (0, 0.0) + + +def test_base_collection_update_default_rejects_mismatched_lengths(tmp_path): + """The ABC default update() raises ValueError rather than silently misaligning.""" + from mempalace.backends.base import BaseCollection + + backend = ChromaBackend() + palace_path = tmp_path / "palace" + collection = backend.get_collection( + palace=PalaceRef(id=str(palace_path), local_path=str(palace_path)), + collection_name="mempalace_drawers", + create=True, + ) + collection.add( + documents=["a", "b"], + ids=["1", "2"], + metadatas=[{"k": 1}, {"k": 2}], + ) + + with pytest.raises(ValueError, match="documents length"): + BaseCollection.update(collection, ids=["1", "2"], documents=["only-one"]) + + with pytest.raises(ValueError, match="metadatas length"): + BaseCollection.update(collection, ids=["1", "2"], metadatas=[{"k": 9}]) + + def test_chroma_backend_accepts_palace_ref_kwarg(tmp_path): palace_path = tmp_path / "palace" backend = ChromaBackend() From 24bf97bb65c38bd4069d78a864f65ea0200a6851 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Sat, 18 Apr 2026 16:23:58 +0000 Subject: [PATCH 3/5] fix(tests): avoid ONNX network download in update-length validation tests MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit test_base_collection_update_default_validates_list_lengths and test_base_collection_update_default_rejects_mismatched_lengths were spinning up a real ChromaBackend and calling add(documents=...), which triggered ChromaDB's default ONNX embedding function and attempted a network download — failing in offline/sandboxed CI. BaseCollection.update() validates list lengths before any DB access, so no items need to be pre-loaded for the length-check to fire. Switch both tests to use _FakeCollection (same as the rest of the unit tests in this file) so they are pure in-memory and network-free. Also fixes a structural bug in test 1: collection._collection.add() was accidentally placed inside the pytest.raises(ValueError) block, masking the real assertion. Agent-Logs-Url: https://github.com/MemPalace/mempalace/sessions/55fc663e-b256-4b8b-88ce-4271560def8d Co-authored-by: igorls <4753812+igorls@users.noreply.github.com> --- tests/test_backends.py | 34 ++++++---------------------------- 1 file changed, 6 insertions(+), 28 deletions(-) diff --git a/tests/test_backends.py b/tests/test_backends.py index f019927..29f2b9b 100644 --- a/tests/test_backends.py +++ b/tests/test_backends.py @@ -198,24 +198,13 @@ def test_query_empty_preserves_embeddings_outer_shape_when_requested(): assert not_requested.embeddings is None -def test_base_collection_update_default_validates_list_lengths(tmp_path): - backend = ChromaBackend() - palace_path = tmp_path / "palace" - collection = backend.get_collection( - palace=PalaceRef(id=str(palace_path), local_path=str(palace_path)), - collection_name="mempalace_drawers", - create=True, - ) +def test_base_collection_update_default_validates_list_lengths(): + from mempalace.backends.base import BaseCollection + + collection = ChromaCollection(_FakeCollection()) # Mismatched documents length → clear ValueError, not silent merge. with pytest.raises(ValueError, match="documents length"): - collection._collection.add( - documents=["a", "b"], - ids=["1", "2"], - metadatas=[{"k": 1}, {"k": 2}], - ) - from mempalace.backends.base import BaseCollection - BaseCollection.update( collection, ids=["1", "2"], @@ -274,22 +263,11 @@ def test_chroma_cache_picks_up_db_created_after_first_open(tmp_path): assert backend._freshness[str(palace_path)] != (0, 0.0) -def test_base_collection_update_default_rejects_mismatched_lengths(tmp_path): +def test_base_collection_update_default_rejects_mismatched_lengths(): """The ABC default update() raises ValueError rather than silently misaligning.""" from mempalace.backends.base import BaseCollection - backend = ChromaBackend() - palace_path = tmp_path / "palace" - collection = backend.get_collection( - palace=PalaceRef(id=str(palace_path), local_path=str(palace_path)), - collection_name="mempalace_drawers", - create=True, - ) - collection.add( - documents=["a", "b"], - ids=["1", "2"], - metadatas=[{"k": 1}, {"k": 2}], - ) + collection = ChromaCollection(_FakeCollection()) with pytest.raises(ValueError, match="documents length"): BaseCollection.update(collection, ids=["1", "2"], documents=["only-one"]) From 61dd6e7d9c4d93d6b35fb3b2027d37ce532883d0 Mon Sep 17 00:00:00 2001 From: Igor Lins e Silva <4753812+igorls@users.noreply.github.com> Date: Sat, 18 Apr 2026 13:52:40 -0300 Subject: [PATCH 4/5] test(backends): fix Windows file-lock in cache-invalidation test PermissionError [WinError 32] on Windows when Path.unlink() runs while chromadb.PersistentClient still holds a handle on chroma.sqlite3. Rewrite test_chroma_cache_invalidates_when_db_file_missing to prime backend._clients/_freshness with a sentinel object instead of opening a real PersistentClient, so the unlink runs against an unheld file. The assertion is also corrected: after invalidation, ChromaBackend's _client rebuilds a fresh PersistentClient which re-creates chroma.sqlite3 and re-stats it, so freshness ends up at the post-rebuild stat (not (0, 0.0) as the assertion previously expected). The meaningful invariant is "freshness advanced past the pre-unlink value AND the sentinel was replaced", which the test now checks. Ref: Windows CI failure on 995. --- tests/test_backends.py | 45 ++++++++++++++++++++++++++---------------- 1 file changed, 28 insertions(+), 17 deletions(-) diff --git a/tests/test_backends.py b/tests/test_backends.py index 29f2b9b..a16df91 100644 --- a/tests/test_backends.py +++ b/tests/test_backends.py @@ -213,28 +213,39 @@ def test_base_collection_update_default_validates_list_lengths(): def test_chroma_cache_invalidates_when_db_file_missing(tmp_path): - """A palace rebuild that removes chroma.sqlite3 must drop the stale cache.""" + """A palace rebuild that removes chroma.sqlite3 must drop the stale cache. + + Primes backend._clients/_freshness directly with a sentinel rather than + opening a real ``PersistentClient``: on Windows the sqlite file handle + would still be live and ``Path.unlink`` would raise ``PermissionError``, + making the test unable to exercise the branch we care about. The decision + logic under test is pure (no chromadb calls before the branch), so a + sentinel is sufficient. + """ backend = ChromaBackend() palace_path = tmp_path / "palace" - backend.get_collection( - palace=PalaceRef(id=str(palace_path), local_path=str(palace_path)), - collection_name="mempalace_drawers", - create=True, - ) - assert str(palace_path) in backend._clients - prior_client = backend._clients[str(palace_path)] - prior_freshness = backend._freshness[str(palace_path)] - assert prior_freshness != (0, 0.0) # DB file exists after get_or_create_collection + palace_path.mkdir() + db_file = palace_path / "chroma.sqlite3" + db_file.write_bytes(b"") # any file is enough for _db_stat to see it + st = db_file.stat() - # Remove chroma.sqlite3 to simulate a rebuild mid-flight. The stale cache - # must not be silently reused — the in-memory HNSW index would be wrong. - (palace_path / "chroma.sqlite3").unlink() + sentinel = object() + backend._clients[str(palace_path)] = sentinel + backend._freshness[str(palace_path)] = (st.st_ino, st.st_mtime) + # Simulate a rebuild mid-flight: chroma.sqlite3 goes away. Safe to unlink + # because nothing in this test is holding an OS handle on the file. + db_file.unlink() + + prior_freshness = (st.st_ino, st.st_mtime) new_client = backend._client(str(palace_path)) - # New client object (cache was replaced, not reused) and freshness was reset - # to (0, 0.0) to reflect "no DB on disk yet" state. - assert new_client is not prior_client - assert backend._freshness[str(palace_path)] == (0, 0.0) + # Cache was replaced (not the sentinel) and freshness reflects the post- + # rebuild stat (chromadb re-creates chroma.sqlite3 during PersistentClient + # construction; _client re-stats after the constructor so freshness is + # not frozen at the pre-rebuild value). The stale cached sentinel would + # have served wrong data if returned. + assert new_client is not sentinel + assert backend._freshness[str(palace_path)] != prior_freshness def test_chroma_cache_picks_up_db_created_after_first_open(tmp_path): From efaa39bea9365de526ce9ad4533f4b1e662e6a3c Mon Sep 17 00:00:00 2001 From: Igor Lins e Silva <4753812+igorls@users.noreply.github.com> Date: Sat, 18 Apr 2026 13:53:46 -0300 Subject: [PATCH 5/5] test(backends): dedup update-length-validation tests 24bf97b (network-download fix) and my earlier Copilot-review commit both added tests for the same ValueError. Keep the broader one that covers both 'documents length' and 'metadatas length' mismatches; drop the narrower duplicate. --- tests/test_backends.py | 14 -------------- 1 file changed, 14 deletions(-) diff --git a/tests/test_backends.py b/tests/test_backends.py index a16df91..2ca2d5a 100644 --- a/tests/test_backends.py +++ b/tests/test_backends.py @@ -198,20 +198,6 @@ def test_query_empty_preserves_embeddings_outer_shape_when_requested(): assert not_requested.embeddings is None -def test_base_collection_update_default_validates_list_lengths(): - from mempalace.backends.base import BaseCollection - - collection = ChromaCollection(_FakeCollection()) - - # Mismatched documents length → clear ValueError, not silent merge. - with pytest.raises(ValueError, match="documents length"): - BaseCollection.update( - collection, - ids=["1", "2"], - documents=["only-one"], - ) - - def test_chroma_cache_invalidates_when_db_file_missing(tmp_path): """A palace rebuild that removes chroma.sqlite3 must drop the stale cache.