Merge pull request #995 from MemPalace/refactor/rfc-001-cleanup

refactor(backends): RFC 001 §10 cleanup — typed results, PalaceRef, registry
This commit is contained in:
Igor Lins e Silva
2026-04-18 15:56:12 -03:00
committed by GitHub
7 changed files with 1303 additions and 94 deletions
+61 -3
View File
@@ -1,6 +1,64 @@
"""Storage backend implementations for MemPalace.""" """Storage backend implementations for MemPalace (RFC 001).
from .base import BaseCollection Public surface:
* :class:`BaseCollection` — per-collection read/write contract.
* :class:`BaseBackend` — per-palace factory contract.
* :class:`PalaceRef` — value object identifying a palace for a backend.
* :class:`QueryResult` / :class:`GetResult` — typed read returns.
* Error classes: :class:`PalaceNotFoundError`, :class:`BackendClosedError`,
:class:`UnsupportedFilterError`, :class:`DimensionMismatchError`,
:class:`EmbedderIdentityMismatchError`.
* Registry: :func:`get_backend`, :func:`register`, :func:`available_backends`,
:func:`resolve_backend_for_palace`.
* In-tree Chroma default: :class:`ChromaBackend`, :class:`ChromaCollection`.
"""
from .base import (
BackendClosedError,
BackendError,
BaseBackend,
BaseCollection,
DimensionMismatchError,
EmbedderIdentityMismatchError,
GetResult,
HealthStatus,
PalaceNotFoundError,
PalaceRef,
QueryResult,
UnsupportedFilterError,
)
from .chroma import ChromaBackend, ChromaCollection from .chroma import ChromaBackend, ChromaCollection
from .registry import (
available_backends,
get_backend,
get_backend_class,
register,
reset_backends,
resolve_backend_for_palace,
unregister,
)
__all__ = ["BaseCollection", "ChromaBackend", "ChromaCollection"] __all__ = [
"BackendClosedError",
"BackendError",
"BaseBackend",
"BaseCollection",
"ChromaBackend",
"ChromaCollection",
"DimensionMismatchError",
"EmbedderIdentityMismatchError",
"GetResult",
"HealthStatus",
"PalaceNotFoundError",
"PalaceRef",
"QueryResult",
"UnsupportedFilterError",
"available_backends",
"get_backend",
"get_backend_class",
"register",
"reset_backends",
"resolve_backend_for_palace",
"unregister",
]
+348 -27
View File
@@ -1,49 +1,370 @@
"""Abstract collection interface for MemPalace storage backends.""" """Storage backend contract for MemPalace (RFC 001).
This module defines the surface every storage backend must implement:
* ``BaseCollection`` — the per-collection read/write interface, kwargs-only.
* ``BaseBackend`` — the per-palace factory, addressed by ``PalaceRef``.
* ``QueryResult`` / ``GetResult`` — typed result dataclasses that replace the
Chroma dict shape as the canonical return type.
* Error classes + ``HealthStatus`` — uniform across backends.
This is the v1 cleanup from RFC 001 §10: full typed results, ``PalaceRef``,
registry-ready ABC. Embedder injection, maintenance hooks, and the full
conformance suite land in follow-up PRs.
"""
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional from dataclasses import dataclass
from typing import ClassVar, Optional
# ---------------------------------------------------------------------------
# Errors
# ---------------------------------------------------------------------------
class BackendError(Exception):
    """Common ancestor of every storage-backend error raised by core."""


class PalaceNotFoundError(BackendError, FileNotFoundError):
    """Signals ``get_collection(create=False)`` against a palace that is missing.

    Also derives from ``FileNotFoundError`` so that legacy callers catching
    that type (pre-#413 seam) continue to work without changes.
    """


class BackendClosedError(BackendError):
    """Signals that a backend method was invoked after ``close()``."""


class UnsupportedFilterError(BackendError):
    """Signals a where-clause operator the backend does not implement.

    The spec (RFC 001 §1.4) forbids silently dropping unknown operators, so
    backends raise this instead.
    """


class DimensionMismatchError(BackendError):
    """Signals a write whose embedding dimension differs from the collection's."""


class EmbedderIdentityMismatchError(BackendError):
    """Signals that the stored embedder model name is not the current one."""
# ---------------------------------------------------------------------------
# Value objects
# ---------------------------------------------------------------------------
@dataclass(frozen=True)
class PalaceRef:
    """Immutable handle identifying a palace for a backend.

    Only ``id`` is mandatory; the other two fields default to ``None``.
    """

    # Always present; the key backends use to cache handles.
    id: str
    # Populated for filesystem-rooted palaces.
    local_path: Optional[str] = None
    # Used by server-mode backends for tenant / prefix routing.
    namespace: Optional[str] = None
@dataclass(frozen=True)
class HealthStatus:
    """Immutable result of a health probe: an ``ok`` flag plus free-form detail."""

    ok: bool
    detail: str = ""

    @classmethod
    def healthy(cls, detail: str = "") -> "HealthStatus":
        """Build a passing status, optionally carrying a detail message."""
        return cls(True, detail)

    @classmethod
    def unhealthy(cls, detail: str) -> "HealthStatus":
        """Build a failing status; the detail message is required."""
        return cls(False, detail)
_TYPED_RESULT_FIELDS = ("ids", "documents", "metadatas", "distances", "embeddings")
class _DictCompatMixin:
"""Transitional dict-protocol access for typed results.
RFC 001 §1.3 spec is attribute access (``result.ids``). The ``result["ids"]``
and ``result.get("ids")`` forms are retained as a migration shim for callers
that predate the typed interface and are scheduled for removal in a follow-
up cleanup. New code MUST use attribute access.
"""
def __getitem__(self, key: str):
if key in _TYPED_RESULT_FIELDS:
return getattr(self, key)
raise KeyError(key)
def get(self, key: str, default=None):
if key in _TYPED_RESULT_FIELDS:
val = getattr(self, key, default)
return default if val is None else val
return default
def __contains__(self, key: object) -> bool:
return key in _TYPED_RESULT_FIELDS and getattr(self, key, None) is not None
@dataclass(frozen=True)
class QueryResult(_DictCompatMixin):
    """Typed return value of ``BaseCollection.query``.

    Every field is a list of lists: the outer dimension is the number of query
    vectors / texts, the inner dimension the hits for that query (possibly
    zero). Fields left out of ``include=`` at the call site hold empty lists of
    the correct outer shape rather than ``None``; the one exception is
    ``embeddings``, which is ``None`` when it was not requested.
    """

    ids: list[list[str]]
    documents: list[list[str]]
    metadatas: list[list[dict]]
    distances: list[list[float]]
    embeddings: Optional[list[list[list[float]]]] = None

    @classmethod
    def empty(cls, num_queries: int = 1, embeddings_requested: bool = False) -> "QueryResult":
        """Build a hit-less result that still has one inner list per query.

        With ``embeddings_requested`` true, ``embeddings`` carries the same
        outer shape with empty hit lists (requested fields keep their outer
        dimension even when empty). Otherwise ``embeddings`` is ``None``,
        signalling that the field was not asked for.
        """

        def blank() -> list:
            # A fresh outer list each time so no two fields share storage.
            return [[] for _ in range(num_queries)]

        return cls(
            ids=blank(),
            documents=blank(),
            metadatas=blank(),
            distances=blank(),
            embeddings=blank() if embeddings_requested else None,
        )
@dataclass(frozen=True)
class GetResult(_DictCompatMixin):
    """Typed return value of ``BaseCollection.get``.

    ``ids``, ``documents`` and ``metadatas`` are flat lists; ``embeddings``
    defaults to ``None``.
    """

    ids: list[str]
    documents: list[str]
    metadatas: list[dict]
    embeddings: Optional[list[list[float]]] = None

    @classmethod
    def empty(cls) -> "GetResult":
        """Build a result with no rows and no embeddings."""
        return cls([], [], [])
# ---------------------------------------------------------------------------
# Collection contract
# ---------------------------------------------------------------------------
class BaseCollection(ABC):
    """Per-collection read/write surface every backend must implement.

    All data-carrying parameters are keyword-only. ``add``, ``upsert``,
    ``query``, ``get``, ``delete`` and ``count`` are abstract;
    ``estimated_count``, ``close``, ``health`` and ``update`` ship with
    overridable defaults.
    """

    @abstractmethod
    def add(
        self,
        *,
        documents: list[str],
        ids: list[str],
        metadatas: Optional[list[dict]] = None,
        embeddings: Optional[list[list[float]]] = None,
    ) -> None: ...

    @abstractmethod
    def upsert(
        self,
        *,
        documents: list[str],
        ids: list[str],
        metadatas: Optional[list[dict]] = None,
        embeddings: Optional[list[list[float]]] = None,
    ) -> None: ...

    @abstractmethod
    def query(
        self,
        *,
        query_texts: Optional[list[str]] = None,
        query_embeddings: Optional[list[list[float]]] = None,
        n_results: int = 10,
        where: Optional[dict] = None,
        where_document: Optional[dict] = None,
        include: Optional[list[str]] = None,
    ) -> QueryResult: ...

    @abstractmethod
    def get(
        self,
        *,
        ids: Optional[list[str]] = None,
        where: Optional[dict] = None,
        where_document: Optional[dict] = None,
        limit: Optional[int] = None,
        offset: Optional[int] = None,
        include: Optional[list[str]] = None,
    ) -> GetResult: ...

    @abstractmethod
    def delete(
        self,
        *,
        ids: Optional[list[str]] = None,
        where: Optional[dict] = None,
    ) -> None: ...

    @abstractmethod
    def count(self) -> int: ...

    # ------------------------------------------------------------------
    # Optional methods with ABC defaults (spec §1.2)
    # ------------------------------------------------------------------

    def estimated_count(self) -> int:
        """Return a record count; the default simply delegates to ``count()``."""
        return self.count()

    def close(self) -> None:
        """Release collection resources. Default: no-op."""
        return None

    def health(self) -> HealthStatus:
        """Report collection health. Default: always healthy."""
        return HealthStatus.healthy()

    def update(
        self,
        *,
        ids: list[str],
        documents: Optional[list[str]] = None,
        metadatas: Optional[list[dict]] = None,
        embeddings: Optional[list[list[float]]] = None,
    ) -> None:
        """Default non-atomic update: get + merge + upsert.

        Backends advertising ``supports_update`` MUST override with an atomic
        single-round-trip implementation.

        Supplied documents replace the stored ones; supplied metadata is
        merged key-by-key over the stored metadata; ``embeddings`` is forwarded
        to ``upsert`` unchanged.

        Raises:
            ValueError: no field to update was given, or a supplied list does
                not have the same length as ``ids``.
            KeyError: at least one id is not present in the collection.
                Nothing is written in that case.
        """
        if documents is None and metadatas is None and embeddings is None:
            raise ValueError("update requires at least one of documents, metadatas, embeddings")
        n = len(ids)
        for label, value in (
            ("documents", documents),
            ("metadatas", metadatas),
            ("embeddings", embeddings),
        ):
            if value is not None and len(value) != n:
                raise ValueError(f"{label} length {len(value)} does not match ids length {n}")

        existing = self.get(ids=ids, include=["documents", "metadatas"])
        by_id = {
            rid: (existing.documents[i], existing.metadatas[i])
            for i, rid in enumerate(existing.ids)
        }

        # An update must never create records: falling back to an empty
        # document for an unknown id would silently insert a blank row.
        missing = [rid for rid in ids if rid not in by_id]
        if missing:
            raise KeyError(f"cannot update ids that do not exist: {missing!r}")

        merged_docs: list[str] = []
        merged_metas: list[dict] = []
        for i, rid in enumerate(ids):
            prev_doc, prev_meta = by_id[rid]
            merged_docs.append(documents[i] if documents is not None else prev_doc)
            # Copy before merging so the dict returned by ``get`` is untouched.
            new_meta = dict(prev_meta or {})
            if metadatas is not None:
                new_meta.update(metadatas[i] or {})
            merged_metas.append(new_meta)

        self.upsert(
            documents=merged_docs,
            ids=list(ids),
            metadatas=merged_metas,
            embeddings=embeddings,
        )
# ---------------------------------------------------------------------------
# Backend contract
# ---------------------------------------------------------------------------
class BaseBackend(ABC):
    """Long-lived factory that serves many palaces (RFC 001 §2).

    Contract for implementers: construction is lightweight (no I/O, no
    network), all connection work is deferred to ``get_collection``, and
    concurrent ``get_collection`` calls across different palaces must be
    thread-safe.

    Only ``get_collection`` is abstract; the lifecycle, health and detection
    hooks below have no-op defaults.
    """

    name: ClassVar[str]
    spec_version: ClassVar[str] = "1.0"
    capabilities: ClassVar[frozenset[str]] = frozenset()

    @abstractmethod
    def get_collection(
        self,
        *,
        palace: PalaceRef,
        collection_name: str,
        create: bool = False,
        options: Optional[dict] = None,
    ) -> BaseCollection: ...

    def close_palace(self, palace: PalaceRef) -> None:
        """Evict cached handles for a single palace. Default: no-op."""

    def close(self) -> None:
        """Shut down the entire backend. Default: no-op."""

    def health(self, palace: Optional[PalaceRef] = None) -> HealthStatus:
        """Report backend health. Default: always healthy, whatever ``palace`` is."""
        return HealthStatus.healthy()

    # Optional detection hint used by selection priority (RFC 001 §3.3 (4)):
    @classmethod
    def detect(cls, path: str) -> bool:  # pragma: no cover - default hook
        """Tell whether ``path`` belongs to this backend. Default: never."""
        return False
# ---------------------------------------------------------------------------
# Adapter utilities
# ---------------------------------------------------------------------------
# Keys the Chroma ``include=`` parameter accepts.
_VALID_INCLUDE_KEYS = frozenset({"documents", "metadatas", "distances", "embeddings"})
@dataclass
class _IncludeSpec:
"""Resolve an ``include=`` parameter with spec-mandated defaults."""
documents: bool = True
metadatas: bool = True
distances: bool = True # only meaningful for query
embeddings: bool = False
@classmethod
def resolve(
cls, include: Optional[list[str]], *, default_distances: bool = True
) -> "_IncludeSpec":
if include is None:
return cls(
documents=True,
metadatas=True,
distances=default_distances,
embeddings=False,
)
keys = {k for k in include if k in _VALID_INCLUDE_KEYS}
return cls(
documents="documents" in keys,
metadatas="metadatas" in keys,
distances="distances" in keys,
embeddings="embeddings" in keys,
)
+443 -37
View File
@@ -1,17 +1,54 @@
"""ChromaDB-backed MemPalace collection adapter.""" """ChromaDB-backed MemPalace storage backend (RFC 001 reference implementation)."""
import logging import logging
import os import os
import sqlite3 import sqlite3
from typing import Any, Optional
import chromadb import chromadb
from .base import BaseCollection from .base import (
BaseBackend,
BaseCollection,
GetResult,
HealthStatus,
PalaceNotFoundError,
PalaceRef,
QueryResult,
UnsupportedFilterError,
_IncludeSpec,
)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def _fix_blob_seq_ids(palace_path: str): _REQUIRED_OPERATORS = frozenset({"$eq", "$ne", "$in", "$nin", "$and", "$or", "$contains"})
_OPTIONAL_OPERATORS = frozenset({"$gt", "$gte", "$lt", "$lte"})
_SUPPORTED_OPERATORS = _REQUIRED_OPERATORS | _OPTIONAL_OPERATORS
def _validate_where(where: Optional[dict]) -> None:
"""Scan a where-clause for unknown operators and raise ``UnsupportedFilterError``.
Spec (RFC 001 §1.4): silent dropping of unknown operators is forbidden.
"""
if not where:
return
stack = [where]
while stack:
node = stack.pop()
if not isinstance(node, dict):
continue
for k, v in node.items():
if k.startswith("$") and k not in _SUPPORTED_OPERATORS:
raise UnsupportedFilterError(f"operator {k!r} not supported by chroma backend")
if isinstance(v, dict):
stack.append(v)
elif isinstance(v, list):
stack.extend(x for x in v if isinstance(x, dict))
def _fix_blob_seq_ids(palace_path: str) -> None:
"""Fix ChromaDB 0.6.x -> 1.5.x migration bug: BLOB seq_ids -> INTEGER. """Fix ChromaDB 0.6.x -> 1.5.x migration bug: BLOB seq_ids -> INTEGER.
ChromaDB 0.6.x stored seq_id as big-endian 8-byte BLOBs. ChromaDB 1.5.x ChromaDB 0.6.x stored seq_id as big-endian 8-byte BLOBs. ChromaDB 1.5.x
@@ -43,62 +80,328 @@ def _fix_blob_seq_ids(palace_path: str):
logger.exception("Could not fix BLOB seq_ids in %s", db_path) logger.exception("Could not fix BLOB seq_ids in %s", db_path)
# ---------------------------------------------------------------------------
# Collection adapter
# ---------------------------------------------------------------------------
def _as_list(v: Any) -> list:
"""Coerce possibly-None scalar-or-list into a list (defensive for chroma nulls)."""
if v is None:
return []
if isinstance(v, list):
return v
return [v]
class ChromaCollection(BaseCollection):
    """Thin adapter translating ChromaDB dict returns into typed results.

    Optional arguments left at ``None`` are never forwarded to the wrapped
    collection, so Chroma only sees the arguments the caller actually set.
    """

    def __init__(self, collection):
        self._collection = collection

    @staticmethod
    def _drop_unset(**fields: Any) -> dict[str, Any]:
        """Return ``fields`` without the entries whose value is ``None``."""
        return {name: value for name, value in fields.items() if value is not None}

    # ------------------------------------------------------------------
    # Writes
    # ------------------------------------------------------------------

    def add(self, *, documents, ids, metadatas=None, embeddings=None):
        """Forward an insert to the wrapped collection."""
        optional = self._drop_unset(metadatas=metadatas, embeddings=embeddings)
        self._collection.add(documents=documents, ids=ids, **optional)

    def upsert(self, *, documents, ids, metadatas=None, embeddings=None):
        """Forward an insert-or-replace to the wrapped collection."""
        optional = self._drop_unset(metadatas=metadatas, embeddings=embeddings)
        self._collection.upsert(documents=documents, ids=ids, **optional)

    def update(
        self,
        *,
        ids,
        documents=None,
        metadatas=None,
        embeddings=None,
    ):
        """Forward an update to the wrapped collection.

        Raises:
            ValueError: none of ``documents``, ``metadatas``, ``embeddings``
                was supplied.
        """
        changes = self._drop_unset(
            documents=documents, metadatas=metadatas, embeddings=embeddings
        )
        if not changes:
            raise ValueError("update requires at least one of documents, metadatas, embeddings")
        self._collection.update(ids=ids, **changes)

    # ------------------------------------------------------------------
    # Reads
    # ------------------------------------------------------------------

    def query(
        self,
        *,
        query_texts=None,
        query_embeddings=None,
        n_results=10,
        where=None,
        where_document=None,
        include=None,
    ) -> QueryResult:
        """Run a similarity query and wrap the raw dict in a ``QueryResult``.

        Exactly one of ``query_texts`` / ``query_embeddings`` must be given and
        it must be non-empty. Both filter arguments are validated before the
        call is issued.

        Raises:
            ValueError: both or neither query input was given, or the given
                one is empty.
        """
        _validate_where(where)
        _validate_where(where_document)

        has_texts = query_texts is not None
        has_vectors = query_embeddings is not None
        if has_texts == has_vectors:
            raise ValueError("query requires exactly one of query_texts or query_embeddings")
        probes = query_texts if has_texts else query_embeddings
        if not probes:
            raise ValueError("query input must be a non-empty list")

        wanted = _IncludeSpec.resolve(include, default_distances=True)
        chroma_include: list[str] = [
            field_name
            for field_name in ("documents", "metadatas", "distances", "embeddings")
            if getattr(wanted, field_name)
        ]

        raw = self._collection.query(
            n_results=n_results,
            include=chroma_include,
            **self._drop_unset(
                query_texts=query_texts,
                query_embeddings=query_embeddings,
                where=where,
                where_document=where_document,
            ),
        )

        hit_ids = raw.get("ids") or []
        if not hit_ids:
            # Nothing came back: keep one (empty) inner list per query.
            return QueryResult.empty(
                num_queries=len(probes),
                embeddings_requested=wanted.embeddings,
            )

        def rows(field_name):
            # A missing/None field becomes one empty list per query, and a
            # None inner list becomes an empty list.
            outer = raw.get(field_name) or [[] for _ in hit_ids]
            return [(inner or []) for inner in outer]

        vectors = raw.get("embeddings") if wanted.embeddings else None
        return QueryResult(
            ids=[(inner or []) for inner in hit_ids],
            documents=rows("documents"),
            metadatas=rows("metadatas"),
            distances=rows("distances"),
            embeddings=None if vectors is None else [list(inner) for inner in vectors],
        )

    def get(
        self,
        *,
        ids=None,
        where=None,
        where_document=None,
        limit=None,
        offset=None,
        include=None,
    ) -> GetResult:
        """Fetch records by id and/or filter and wrap them in a ``GetResult``.

        Fields not selected by ``include`` come back as empty lists
        (``embeddings`` as ``None``). Selected document / metadata lists are
        padded up to the length of ``ids``.
        """
        _validate_where(where)
        _validate_where(where_document)

        wanted = _IncludeSpec.resolve(include, default_distances=False)
        chroma_include: list[str] = [
            field_name
            for field_name in ("documents", "metadatas", "embeddings")
            if getattr(wanted, field_name)
        ]

        raw = self._collection.get(
            include=chroma_include,
            **self._drop_unset(
                ids=ids,
                where=where,
                where_document=where_document,
                limit=limit,
                offset=offset,
            ),
        )

        out_ids = list(raw.get("ids") or [])
        out_docs = list(raw.get("documents") or []) if wanted.documents else []
        out_metas = list(raw.get("metadatas") or []) if wanted.metadatas else []
        out_embeds = raw.get("embeddings") if wanted.embeddings else None

        # Pad doc/meta lists to match ids so downstream zipping is safe.
        doc_gap = len(out_ids) - len(out_docs)
        if wanted.documents and doc_gap > 0:
            out_docs = out_docs + [""] * doc_gap
        meta_gap = len(out_ids) - len(out_metas)
        if wanted.metadatas and meta_gap > 0:
            out_metas = out_metas + [{}] * meta_gap

        return GetResult(
            ids=out_ids,
            documents=out_docs,
            metadatas=out_metas,
            embeddings=None if out_embeds is None else [list(v) for v in out_embeds],
        )

    def delete(self, *, ids=None, where=None):
        """Validate ``where`` and forward the delete to the wrapped collection."""
        _validate_where(where)
        self._collection.delete(**self._drop_unset(ids=ids, where=where))

    def count(self):
        """Return the wrapped collection's record count."""
        return self._collection.count()
class ChromaBackend: # ---------------------------------------------------------------------------
"""Factory for MemPalace's default ChromaDB backend.""" # Backend
# ---------------------------------------------------------------------------
class ChromaBackend(BaseBackend):
"""MemPalace's default ChromaDB backend.
Maintains two caches:
* ``self._clients`` — ``palace_path -> PersistentClient`` for callers
using the ``PalaceRef`` / :meth:`get_collection` path.
* An inode+mtime freshness check absorbed from ``mcp_server._get_client``
(merged via #757) ensuring a palace rebuild on disk is detected on the
next :meth:`get_collection` call.
"""
name = "chroma"
capabilities = frozenset(
{
"supports_embeddings_in",
"supports_embeddings_passthrough",
"supports_embeddings_out",
"supports_metadata_filters",
"supports_contains_fast",
"local_mode",
}
)
def __init__(self): def __init__(self):
# Per-instance client cache: palace_path -> chromadb.PersistentClient # palace_path -> PersistentClient
self._clients: dict = {} self._clients: dict[str, Any] = {}
# palace_path -> (inode, mtime) of chroma.sqlite3 at cache time.
self._freshness: dict[str, tuple[int, float]] = {}
self._closed = False
# ------------------------------------------------------------------ # ------------------------------------------------------------------
# Internal helpers # Internal helpers
# ------------------------------------------------------------------ # ------------------------------------------------------------------
@staticmethod
def _db_stat(palace_path: str) -> tuple[int, float]:
"""Return ``(inode, mtime)`` of ``chroma.sqlite3`` or ``(0, 0.0)`` if absent."""
db_path = os.path.join(palace_path, "chroma.sqlite3")
try:
st = os.stat(db_path)
return (st.st_ino, st.st_mtime)
except OSError:
return (0, 0.0)
def _client(self, palace_path: str): def _client(self, palace_path: str):
"""Return a cached PersistentClient for *palace_path*, creating one if needed.""" """Return a cached ``PersistentClient``, rebuilding on inode/mtime change.
if palace_path not in self._clients:
Handles the palace-rebuild case (repair/nuke/purge) by invalidating the
cache when ``chroma.sqlite3`` changes on disk. Mirrors the semantics of
``mcp_server._get_client`` (merged via #757):
* DB file missing while we hold a cached client → drop the cache so we
do not serve stale data after a rebuild that has not yet re-created
the DB.
* Transition 0 → nonzero stat (DB created after cache) counts as a
change, so the cached client is replaced with one that sees the DB.
* FAT/exFAT filesystems return inode 0; we never fire inode comparisons
when either side is 0 (safe fallback) but still honor mtime.
* Mtime change uses an epsilon (0.01 s) to tolerate FS timestamp
granularity without thrashing.
"""
if self._closed:
from .base import BackendClosedError # late import avoids cycles at module load
raise BackendClosedError("ChromaBackend has been closed")
cached = self._clients.get(palace_path)
cached_inode, cached_mtime = self._freshness.get(palace_path, (0, 0.0))
current_inode, current_mtime = self._db_stat(palace_path)
db_path = os.path.join(palace_path, "chroma.sqlite3")
# DB was present when cache was built but is now missing → invalidate.
if cached is not None and not os.path.isfile(db_path):
self._clients.pop(palace_path, None)
self._freshness.pop(palace_path, None)
cached = None
cached_inode, cached_mtime = 0, 0.0
inode_changed = current_inode != 0 and cached_inode != 0 and current_inode != cached_inode
# Transition from no-stat (0.0) to a real stat counts as a change so we
# pick up a DB that was created after the cache was built.
mtime_appeared = cached_mtime == 0.0 and current_mtime != 0.0
mtime_changed = (
current_mtime != 0.0
and cached_mtime != 0.0
and abs(current_mtime - cached_mtime) > 0.01
)
if cached is None or inode_changed or mtime_changed or mtime_appeared:
_fix_blob_seq_ids(palace_path) _fix_blob_seq_ids(palace_path)
self._clients[palace_path] = chromadb.PersistentClient(path=palace_path) cached = chromadb.PersistentClient(path=palace_path)
return self._clients[palace_path] self._clients[palace_path] = cached
# Re-stat after the client constructor runs: chromadb creates
# chroma.sqlite3 lazily, so the stat captured before the call
# may still be (0, 0.0) on first open.
self._freshness[palace_path] = self._db_stat(palace_path)
return cached
# ------------------------------------------------------------------ # ------------------------------------------------------------------
# Public static helpers (for callers that manage their own caching) # Public static helpers (legacy; prefer :meth:`get_collection`)
# ------------------------------------------------------------------ # ------------------------------------------------------------------
@staticmethod @staticmethod
def make_client(palace_path: str): def make_client(palace_path: str):
"""Create and return a fresh PersistentClient (fix BLOB seq_ids first). """Create a fresh ``PersistentClient`` (fixes BLOB seq_ids first).
Intended for long-lived callers (e.g. mcp_server) that keep their own Deprecated-ish: exposed for legacy long-lived callers that manage their
inode/mtime-based client cache. own client cache. New code should obtain a collection through
:meth:`get_collection` which manages caching internally.
""" """
_fix_blob_seq_ids(palace_path) _fix_blob_seq_ids(palace_path)
return chromadb.PersistentClient(path=palace_path) return chromadb.PersistentClient(path=palace_path)
@@ -109,12 +412,31 @@ class ChromaBackend:
return chromadb.__version__ return chromadb.__version__
# ------------------------------------------------------------------ # ------------------------------------------------------------------
# Collection lifecycle # BaseBackend surface
# ------------------------------------------------------------------ # ------------------------------------------------------------------
def get_collection(self, palace_path: str, collection_name: str, create: bool = False): def get_collection(
self,
*args,
**kwargs,
) -> ChromaCollection:
"""Obtain a collection for a palace.
Supports two calling conventions during the RFC 001 transition:
* New (preferred): ``get_collection(palace=PalaceRef, collection_name=...,
create=False, options=None)``.
* Legacy: ``get_collection(palace_path, collection_name, create=False)``
— still used by callers not yet migrated.
"""
palace_ref, collection_name, create, options = _normalize_get_collection_args(args, kwargs)
palace_path = palace_ref.local_path
if palace_path is None:
raise PalaceNotFoundError("ChromaBackend requires PalaceRef.local_path")
if not create and not os.path.isdir(palace_path): if not create and not os.path.isdir(palace_path):
raise FileNotFoundError(palace_path) raise PalaceNotFoundError(palace_path)
if create: if create:
os.makedirs(palace_path, exist_ok=True) os.makedirs(palace_path, exist_ok=True)
@@ -124,29 +446,113 @@ class ChromaBackend:
pass pass
client = self._client(palace_path) client = self._client(palace_path)
hnsw_space = "cosine"
if options and isinstance(options, dict):
hnsw_space = options.get("hnsw_space", hnsw_space)
if create: if create:
collection = client.get_or_create_collection( collection = client.get_or_create_collection(
collection_name, metadata={"hnsw:space": "cosine"} collection_name, metadata={"hnsw:space": hnsw_space}
) )
else: else:
collection = client.get_collection(collection_name) collection = client.get_collection(collection_name)
return ChromaCollection(collection) return ChromaCollection(collection)
def get_or_create_collection( def close_palace(self, palace) -> None:
self, palace_path: str, collection_name: str """Drop cached handles for ``palace``. Accepts ``PalaceRef`` or legacy path str."""
) -> "ChromaCollection": path = palace.local_path if isinstance(palace, PalaceRef) else palace
"""Shorthand for get_collection(..., create=True).""" if path is None:
return
self._clients.pop(path, None)
self._freshness.pop(path, None)
def close(self) -> None:
self._clients.clear()
self._freshness.clear()
self._closed = True
def health(self, palace: Optional[PalaceRef] = None) -> HealthStatus:
if self._closed:
return HealthStatus.unhealthy("backend closed")
return HealthStatus.healthy()
@classmethod
def detect(cls, path: str) -> bool:
return os.path.isfile(os.path.join(path, "chroma.sqlite3"))
# ------------------------------------------------------------------
# Legacy (pre-RFC 001) surface — retained while callers migrate.
# ------------------------------------------------------------------
def get_or_create_collection(self, palace_path: str, collection_name: str) -> ChromaCollection:
"""Legacy shim for ``get_collection(..., create=True)`` by path string."""
return self.get_collection(palace_path, collection_name, create=True) return self.get_collection(palace_path, collection_name, create=True)
def delete_collection(self, palace_path: str, collection_name: str) -> None: def delete_collection(self, palace_path: str, collection_name: str) -> None:
"""Delete *collection_name* from the palace at *palace_path*.""" """Delete ``collection_name`` from the palace at ``palace_path``."""
self._client(palace_path).delete_collection(collection_name) self._client(palace_path).delete_collection(collection_name)
def create_collection( def create_collection(
self, palace_path: str, collection_name: str, hnsw_space: str = "cosine" self, palace_path: str, collection_name: str, hnsw_space: str = "cosine"
) -> "ChromaCollection": ) -> ChromaCollection:
"""Create (not get-or-create) *collection_name* with cosine HNSW space.""" """Create (not get-or-create) ``collection_name`` with the given HNSW space."""
collection = self._client(palace_path).create_collection( collection = self._client(palace_path).create_collection(
collection_name, metadata={"hnsw:space": hnsw_space} collection_name, metadata={"hnsw:space": hnsw_space}
) )
return ChromaCollection(collection) return ChromaCollection(collection)
def _normalize_get_collection_args(args, kwargs):
"""Unify legacy positional ``(palace_path, collection_name, create)`` calls
with the new kwargs-only ``(palace=PalaceRef, collection_name=..., create=...)``.
Returns ``(PalaceRef, collection_name, create, options)``.
"""
# New-style: palace= kwarg with a PalaceRef (spec path).
if "palace" in kwargs:
palace_ref = kwargs.pop("palace")
if not isinstance(palace_ref, PalaceRef):
raise TypeError("palace= must be a PalaceRef instance")
collection_name = kwargs.pop("collection_name")
create = kwargs.pop("create", False)
options = kwargs.pop("options", None)
if kwargs:
raise TypeError(f"unexpected kwargs: {sorted(kwargs)}")
if args:
raise TypeError("positional args not allowed with palace= kwarg")
return palace_ref, collection_name, create, options
# Legacy: first positional is a path string.
if args:
palace_path = args[0]
rest = list(args[1:])
collection_name = kwargs.pop("collection_name", None) or (rest.pop(0) if rest else None)
if collection_name is None:
raise TypeError("collection_name is required")
create = kwargs.pop("create", False)
if rest:
create = rest.pop(0)
if kwargs:
raise TypeError(f"unexpected kwargs: {sorted(kwargs)}")
return (
PalaceRef(id=palace_path, local_path=palace_path),
collection_name,
bool(create),
None,
)
# Legacy kwargs-only (palace_path=..., collection_name=..., create=...)
if "palace_path" in kwargs:
palace_path = kwargs.pop("palace_path")
collection_name = kwargs.pop("collection_name")
create = kwargs.pop("create", False)
if kwargs:
raise TypeError(f"unexpected kwargs: {sorted(kwargs)}")
return (
PalaceRef(id=palace_path, local_path=palace_path),
collection_name,
bool(create),
None,
)
raise TypeError("get_collection requires palace= or a positional palace_path")
+189
View File
@@ -0,0 +1,189 @@
"""Backend registry + entry-point discovery (RFC 001 §3).
Third-party backends ship as installable packages that declare a
``mempalace.backends`` entry point::
# pyproject.toml of mempalace-postgres
[project.entry-points."mempalace.backends"]
postgres = "mempalace_postgres:PostgresBackend"
MemPalace discovers them lazily, once per process, the first time the registry
is queried. In-tree tests and local development can register manually via
:func:`register`. Explicit registration wins on name conflict (matches
RFC 001 §3.2).
"""
from __future__ import annotations
import logging
from importlib import metadata
from threading import Lock
from typing import Optional, Type
from .base import BaseBackend
logger = logging.getLogger(__name__)
# Entry-point group that third-party backend packages declare.
_ENTRY_POINT_GROUP = "mempalace.backends"
# Backend name -> backend class (explicit, built-in and discovered).
_registry: dict[str, Type[BaseBackend]] = {}
# Backend name -> cached instance handed out by get_backend().
_instances: dict[str, BaseBackend] = {}
# Names registered through register(); discovery never overrides these.
_explicit: set[str] = set()
# Set to True once entry-point discovery has completed for this process.
_discovered = False
# Taken by register/unregister/get_backend/reset_backends and by discovery.
_lock = Lock()
def register(name: str, backend_cls: Type[BaseBackend]) -> None:
    """Bind ``name`` to ``backend_cls`` in the registry.

    The name is recorded as explicitly registered, so entry-point discovery
    will not replace it (RFC 001 §3.2). Any instance already cached for
    ``name`` is dropped so the next ``get_backend`` builds one from the new
    class.
    """
    with _lock:
        # Forget the stale instance first; it belongs to the previous class.
        _instances.pop(name, None)
        _explicit.add(name)
        _registry[name] = backend_cls
def unregister(name: str) -> None:
    """Forget ``name`` entirely (primarily for tests).

    Removes the class registration, the explicit-registration marker and any
    cached instance. Unknown names are ignored.
    """
    with _lock:
        _explicit.discard(name)
        for table in (_registry, _instances):
            table.pop(name, None)
def _discover_entry_points() -> None:
    """Load entry-point-declared backends once per process.

    Entry points are enumerated and imported *without* holding ``_lock``.
    ``_lock`` is a non-reentrant ``threading.Lock`` that ``register()`` and
    the other registry accessors also take, so a plugin module that calls
    back into the registry while it is being imported would deadlock this
    thread if the lock were held across ``ep.load()``. Only the final merge
    into ``_registry`` runs under the lock.

    Enumeration and load failures are logged and the affected entry point is
    skipped. Names that were registered explicitly are never overridden, and
    the first successfully loaded entry point wins for a given name.
    """
    global _discovered
    if _discovered:
        return

    try:
        eps = metadata.entry_points()
        # Py ≥ 3.10 returns an EntryPoints object; older versions returned a dict.
        group = (
            eps.select(group=_ENTRY_POINT_GROUP)
            if hasattr(eps, "select")
            else eps.get(_ENTRY_POINT_GROUP, [])
        )
    except Exception:
        logger.exception("entry-point discovery for %s failed", _ENTRY_POINT_GROUP)
        group = []

    # Import plugin code outside the lock (see docstring).
    loaded: dict[str, Type[BaseBackend]] = {}
    for ep in group:
        if ep.name in _explicit or ep.name in loaded:
            continue  # explicit registration / earlier entry point wins
        try:
            cls = ep.load()
        except Exception:
            logger.exception("failed to load backend entry point %r", ep.name)
            continue
        if not isinstance(cls, type) or not issubclass(cls, BaseBackend):
            logger.warning(
                "entry point %r did not resolve to a BaseBackend subclass (got %r)",
                ep.name,
                cls,
            )
            continue
        loaded[ep.name] = cls

    with _lock:
        if _discovered:
            return  # another thread finished discovery while we were loading
        for name, cls in loaded.items():
            # Re-check under the lock: register() may have run meanwhile.
            if name not in _explicit:
                _registry.setdefault(name, cls)
        _discovered = True
def available_backends() -> list[str]:
    """List every registered backend name in sorted order.

    Triggers entry-point discovery first so discovered backends are included.
    """
    _discover_entry_points()
    names = list(_registry)
    names.sort()
    return names
def get_backend_class(name: str) -> Type[BaseBackend]:
    """Look up the backend class registered under ``name``.

    Runs entry-point discovery first. Raises ``KeyError`` listing the
    available backend names when ``name`` is not registered.
    """
    _discover_entry_points()
    try:
        backend_cls = _registry[name]
    except KeyError as e:
        raise KeyError(f"unknown backend {name!r}; available: {available_backends()}") from e
    return backend_cls
def get_backend(name: str) -> BaseBackend:
    """Return the shared, long-lived instance of the backend called ``name``.

    The first call for a name instantiates the registered class with no
    arguments and caches the result; later calls return that same object.
    Raises ``KeyError`` when no backend is registered under ``name``.
    Tests that need isolation should call :func:`reset_backends`.
    """
    _discover_entry_points()
    with _lock:
        cached = _instances.get(name)
        if cached is None:
            backend_cls = _registry.get(name)
            if backend_cls is None:
                raise KeyError(f"unknown backend {name!r}; available: {sorted(_registry.keys())}")
            cached = backend_cls()
            _instances[name] = cached
        return cached
def reset_backends() -> None:
    """Close every cached backend instance and empty the cache (primarily for tests).

    An exception raised by one instance's ``close()`` is logged and does not
    stop the remaining instances from being closed.
    """
    with _lock:
        cached = list(_instances.values())
        for backend in cached:
            try:
                backend.close()
            except Exception:
                logger.exception("error closing backend during reset")
        _instances.clear()
def resolve_backend_for_palace(
    *,
    explicit: Optional[str] = None,
    config_value: Optional[str] = None,
    env_value: Optional[str] = None,
    palace_path: Optional[str] = None,
    default: str = "chroma",
) -> str:
    """Pick the backend name for a palace using the RFC 001 §3.3 priority order.

    1. Explicit kwarg / CLI flag
    2. Per-palace config value
    3. ``MEMPALACE_BACKEND`` env var
    4. Auto-detect from on-disk artifacts (migration/upgrade path only)
    5. Default (``chroma``)

    The first non-empty value among (1)-(3) is returned as-is. Auto-detection
    only runs when ``palace_path`` is given: each registered backend class is
    asked via ``detect(palace_path)`` and the first one that answers truthy
    wins. A ``detect`` that raises is logged and skipped. If nothing matches,
    ``default`` is returned, which is what happens for new palaces.
    """
    chosen = explicit or config_value or env_value
    if chosen:
        return chosen

    _discover_entry_points()
    if not palace_path:
        return default

    for name, cls in _registry.items():
        try:
            matched = cls.detect(palace_path)
        except Exception:
            logger.exception("detect() raised on backend %r", name)
            continue
        if matched:
            return name
    return default
# ---------------------------------------------------------------------------
# Built-in registration
# ---------------------------------------------------------------------------
def _register_builtins() -> None:
    """Make the in-tree chroma backend available under the name ``chroma``.

    An existing ``chroma`` registration is left alone, so a caller that
    pre-registered its own class (e.g. for tests) wins.
    """
    from .chroma import ChromaBackend

    _registry.setdefault("chroma", ChromaBackend)
_register_builtins()
+12 -12
View File
@@ -30,15 +30,16 @@ class SearchError(Exception):
_TOKEN_RE = re.compile(r"\w{2,}", re.UNICODE) _TOKEN_RE = re.compile(r"\w{2,}", re.UNICODE)
def _first_or_empty(results: dict, key: str) -> list: def _first_or_empty(results, key: str) -> list:
"""Return the first inner list of a ChromaDB query result, or []. """Return the first inner list of a query result field, or [].
ChromaDB returns shapes like ``{"documents": [["a", "b"]], ...}`` for a Accepts both the typed :class:`QueryResult` (attribute access) and the
successful query, but ``{"documents": [], ...}`` (empty outer list) when pre-typed chroma dict shape; this polymorphism is retained so test mocks
the collection is empty or the filter excludes everything. Indexing still work and callers mid-migration do not crash. Preserves the empty-
``[0]`` blindly raises IndexError in that case (issue #195). collection semantics from issue #195: when no queries returned hits, the
outer list may be empty and indexing ``[0]`` would raise.
""" """
outer = results.get(key) outer = getattr(results, key, None) if not isinstance(results, dict) else results.get(key)
if not outer: if not outer:
return [] return []
return outer[0] or [] return outer[0] or []
@@ -209,7 +210,7 @@ def _expand_with_neighbors(drawers_col, matched_doc: str, matched_meta: dict, ra
return {"text": matched_doc, "drawer_index": chunk_idx, "total_drawers": None} return {"text": matched_doc, "drawer_index": chunk_idx, "total_drawers": None}
indexed_docs = [] indexed_docs = []
for doc, meta in zip(neighbors.get("documents") or [], neighbors.get("metadatas") or []): for doc, meta in zip(neighbors.documents, neighbors.metadatas):
ci = meta.get("chunk_index") ci = meta.get("chunk_index")
if isinstance(ci, int): if isinstance(ci, int):
indexed_docs.append((ci, doc)) indexed_docs.append((ci, doc))
@@ -224,8 +225,7 @@ def _expand_with_neighbors(drawers_col, matched_doc: str, matched_meta: dict, ra
total_drawers = None total_drawers = None
try: try:
all_meta = drawers_col.get(where={"source_file": src}, include=["metadatas"]) all_meta = drawers_col.get(where={"source_file": src}, include=["metadatas"])
ids = all_meta.get("ids") or [] total_drawers = len(all_meta.ids) if all_meta.ids else None
total_drawers = len(ids) if ids else None
except Exception: except Exception:
pass pass
@@ -451,8 +451,8 @@ def search_memories(
) )
except Exception: except Exception:
continue continue
docs = source_drawers.get("documents") or [] docs = source_drawers.documents
metas_ = source_drawers.get("metadatas") or [] metas_ = source_drawers.metadatas
if len(docs) <= 1: if len(docs) <= 1:
continue continue
+3
View File
@@ -37,6 +37,9 @@ Repository = "https://github.com/MemPalace/mempalace"
[project.scripts] [project.scripts]
mempalace = "mempalace.cli:main" mempalace = "mempalace.cli:main"
[project.entry-points."mempalace.backends"]
chroma = "mempalace.backends.chroma:ChromaBackend"
[project.optional-dependencies] [project.optional-dependencies]
dev = ["pytest>=7.0", "pytest-cov>=4.0", "ruff>=0.4.0", "psutil>=5.9"] dev = ["pytest>=7.0", "pytest-cov>=4.0", "ruff>=0.4.0", "psutil>=5.9"]
spellcheck = ["autocorrect>=2.0"] spellcheck = ["autocorrect>=2.0"]
+247 -15
View File
@@ -3,12 +3,34 @@ import sqlite3
import chromadb import chromadb
import pytest import pytest
from mempalace.backends import (
GetResult,
PalaceRef,
QueryResult,
UnsupportedFilterError,
available_backends,
get_backend,
)
from mempalace.backends.chroma import ChromaBackend, ChromaCollection, _fix_blob_seq_ids from mempalace.backends.chroma import ChromaBackend, ChromaCollection, _fix_blob_seq_ids
class _FakeCollection: class _FakeCollection:
def __init__(self): """Stand-in for a chromadb.Collection returning raw chroma-shaped dicts."""
def __init__(self, query_response=None, get_response=None, count_value=7):
self.calls = [] self.calls = []
self._query_response = query_response or {
"ids": [["a", "b"]],
"documents": [["da", "db"]],
"metadatas": [[{"wing": "w1"}, {"wing": "w2"}]],
"distances": [[0.1, 0.2]],
}
self._get_response = get_response or {
"ids": ["a"],
"documents": ["da"],
"metadatas": [{"wing": "w1"}],
}
self._count_value = count_value
def add(self, **kwargs): def add(self, **kwargs):
self.calls.append(("add", kwargs)) self.calls.append(("add", kwargs))
@@ -16,41 +38,251 @@ class _FakeCollection:
def upsert(self, **kwargs): def upsert(self, **kwargs):
self.calls.append(("upsert", kwargs)) self.calls.append(("upsert", kwargs))
def update(self, **kwargs):
self.calls.append(("update", kwargs))
def query(self, **kwargs): def query(self, **kwargs):
self.calls.append(("query", kwargs)) self.calls.append(("query", kwargs))
return {"kind": "query"} return self._query_response
def get(self, **kwargs): def get(self, **kwargs):
self.calls.append(("get", kwargs)) self.calls.append(("get", kwargs))
return {"kind": "get"} return self._get_response
def delete(self, **kwargs): def delete(self, **kwargs):
self.calls.append(("delete", kwargs)) self.calls.append(("delete", kwargs))
def count(self): def count(self):
self.calls.append(("count", {})) self.calls.append(("count", {}))
return 7 return self._count_value
def test_chroma_collection_returns_typed_query_result():
    """``query`` wraps the raw chroma dict in a typed ``QueryResult``."""
    fake = _FakeCollection()
    collection = ChromaCollection(fake)

    result = collection.query(query_texts=["q"])

    assert isinstance(result, QueryResult)
    assert result.ids == [["a", "b"]]
    assert result.documents == [["da", "db"]]
    assert result.metadatas == [[{"wing": "w1"}, {"wing": "w2"}]]
    assert result.distances == [[0.1, 0.2]]
    # The fake's canned response carries no embeddings.
    assert result.embeddings is None
def test_chroma_collection_returns_typed_get_result():
    """``get`` converts the raw chroma dict into a typed ``GetResult``."""
    wrapped = ChromaCollection(_FakeCollection())

    fetched = wrapped.get(where={"wing": "w1"})

    assert isinstance(fetched, GetResult)
    assert fetched.ids == ["a"]
    assert fetched.documents == ["da"]
    assert fetched.metadatas == [{"wing": "w1"}]
def test_query_result_empty_preserves_outer_dimension():
    """``QueryResult.empty`` yields one empty inner list per query."""
    blank = QueryResult.empty(num_queries=2)

    assert blank.ids == [[], []]
    assert blank.documents == [[], []]
    assert blank.distances == [[], []]
    assert blank.embeddings is None
def test_typed_results_support_dict_compat_access():
    """Typed results still answer ``[]``, ``.get`` and ``in`` like a dict.

    Transitional compat shim per base.py, retained until callers migrate to
    attribute access.
    """
    typed = GetResult(ids=["a"], documents=["da"], metadatas=[{"w": 1}])

    # Subscript and .get() access.
    assert typed["ids"] == ["a"]
    assert typed.get("documents") == ["da"]
    assert typed.get("missing", "default") == "default"

    # Membership checks.
    assert "ids" in typed
    assert "missing" not in typed
def test_chroma_collection_query_empty_result_preserves_outer_shape():
    """An all-empty raw response still produces one inner list per query text."""
    raw_empty = {"ids": [], "documents": [], "metadatas": [], "distances": []}
    wrapped = ChromaCollection(_FakeCollection(query_response=raw_empty))

    outcome = wrapped.query(query_texts=["q1", "q2"])

    assert outcome.ids == [[], []]
    assert outcome.documents == [[], []]
    assert outcome.distances == [[], []]
def test_chroma_collection_rejects_unknown_where_operator():
    """A ``where`` filter using ``$regex`` raises ``UnsupportedFilterError``."""
    wrapped = ChromaCollection(_FakeCollection())
    bad_filter = {"$regex": "foo"}

    with pytest.raises(UnsupportedFilterError):
        wrapped.query(query_texts=["q"], where=bad_filter)
def test_chroma_collection_delegates_writes():
    """Write methods and ``count`` are forwarded to the wrapped collection in call order."""
    fake = _FakeCollection()
    collection = ChromaCollection(fake)

    collection.add(documents=["d"], ids=["1"], metadatas=[{"wing": "w"}])
    collection.upsert(documents=["u"], ids=["2"], metadatas=[{"room": "r"}])
    collection.delete(ids=["1"])
    assert collection.count() == 7

    # Only the method names are compared; forwarded kwargs are not pinned here.
    kinds = [call[0] for call in fake.calls]
    assert kinds == ["add", "upsert", "delete", "count"]


def test_registry_exposes_chroma_by_default():
    """``chroma`` is registered out of the box and resolves to ``ChromaBackend``."""
    names = available_backends()
    assert "chroma" in names
    assert isinstance(get_backend("chroma"), ChromaBackend)
def test_registry_unknown_backend_raises():
    """Asking for an unregistered backend name raises ``KeyError``."""
    missing_name = "no-such-backend-exists"
    with pytest.raises(KeyError):
        get_backend(missing_name)
def test_resolve_backend_priority_order(tmp_path):
    """explicit > config value > env value > default."""
    from mempalace.backends import resolve_backend_for_palace

    cases = [
        # explicit kwarg wins over everything
        ({"explicit": "pg", "config_value": "lance"}, "pg"),
        # config value wins over env / default
        ({"config_value": "lance", "env_value": "qdrant"}, "lance"),
        # env wins over default
        ({"env_value": "qdrant", "default": "chroma"}, "qdrant"),
        # falls back to default
        ({}, "chroma"),
    ]
    for kwargs, expected in cases:
        assert resolve_backend_for_palace(**kwargs) == expected
def test_chroma_detect_matches_palace_with_chroma_sqlite(tmp_path):
    """``detect`` is True only for a directory that holds ``chroma.sqlite3``."""
    marker = tmp_path / "chroma.sqlite3"
    marker.write_bytes(b"")

    with_db = str(tmp_path)
    without_db = str(tmp_path.parent)
    assert ChromaBackend.detect(with_db) is True
    assert ChromaBackend.detect(without_db) is False
def test_query_rejects_missing_input():
    """``query`` with neither texts nor embeddings raises ``ValueError``."""
    wrapped = ChromaCollection(_FakeCollection())
    with pytest.raises(ValueError):
        wrapped.query()
def test_query_rejects_both_texts_and_embeddings():
    """Supplying texts and embeddings together raises ``ValueError``."""
    wrapped = ChromaCollection(_FakeCollection())
    conflicting = {"query_texts": ["q"], "query_embeddings": [[0.1, 0.2]]}
    with pytest.raises(ValueError):
        wrapped.query(**conflicting)
def test_query_rejects_empty_input_list():
    """An empty ``query_texts`` list raises ``ValueError``."""
    wrapped = ChromaCollection(_FakeCollection())
    no_texts = []
    with pytest.raises(ValueError):
        wrapped.query(query_texts=no_texts)
def test_query_empty_preserves_embeddings_outer_shape_when_requested():
    """Empty results carry per-query empty embeddings only when embeddings were requested."""
    raw_empty = {"ids": [], "documents": [], "metadatas": [], "distances": []}
    wrapped = ChromaCollection(_FakeCollection(query_response=raw_empty))
    texts = ["q1", "q2"]

    with_embeddings = wrapped.query(query_texts=texts, include=["documents", "embeddings"])
    assert with_embeddings.embeddings == [[], []]

    without_embeddings = wrapped.query(query_texts=texts, include=["documents"])
    assert without_embeddings.embeddings is None
def test_chroma_cache_invalidates_when_db_file_missing(tmp_path):
    """Removing chroma.sqlite3 (as a palace rebuild does) must evict the cached client.

    The caches are primed by hand with a sentinel instead of a real
    ``PersistentClient``: on Windows a live sqlite handle would make
    ``Path.unlink`` raise ``PermissionError`` and the branch under test could
    never be reached. The decision logic is pure (no chromadb calls before
    the branch), so a sentinel is sufficient.
    """
    backend = ChromaBackend()
    palace_dir = tmp_path / "palace"
    palace_dir.mkdir()
    cache_key = str(palace_dir)

    sqlite_file = palace_dir / "chroma.sqlite3"
    sqlite_file.write_bytes(b"")  # any file is enough for _db_stat to see it
    stat = sqlite_file.stat()
    stale_freshness = (stat.st_ino, stat.st_mtime)

    placeholder = object()
    backend._clients[cache_key] = placeholder
    backend._freshness[cache_key] = stale_freshness

    # Simulate a rebuild mid-flight. Unlinking is safe because nothing in
    # this test holds an OS handle on the file.
    sqlite_file.unlink()

    rebuilt = backend._client(cache_key)

    # The placeholder must be gone (returning it would serve stale data) and
    # freshness must reflect the post-rebuild stat: chromadb re-creates
    # chroma.sqlite3 while constructing PersistentClient, and _client
    # re-stats after the constructor.
    assert rebuilt is not placeholder
    assert backend._freshness[cache_key] != stale_freshness
def test_chroma_cache_picks_up_db_created_after_first_open(tmp_path):
    """A cache entry made before the DB existed is rebuilt once the DB appears (0 → nonzero stat)."""
    backend = ChromaBackend()
    palace_dir = tmp_path / "palace"
    palace_dir.mkdir()
    cache_key = str(palace_dir)
    absent_marker = (0, 0.0)  # freshness recorded when chroma.sqlite3 was absent

    # Pretend an earlier _client() call cached this palace before the DB file
    # had been created.
    placeholder = object()
    backend._clients[cache_key] = placeholder
    backend._freshness[cache_key] = absent_marker

    # Create the DB with a real chromadb call so _fix_blob_seq_ids and
    # PersistentClient succeed on the next open.
    import chromadb as _chromadb

    seed_client = _chromadb.PersistentClient(path=cache_key)
    seed_client.get_or_create_collection("seed")
    assert (palace_dir / "chroma.sqlite3").is_file()

    # The next _client() call must notice the transition and rebuild.
    rebuilt = backend._client(cache_key)
    assert rebuilt is not placeholder
    assert backend._freshness[cache_key] != absent_marker
def test_base_collection_update_default_rejects_mismatched_lengths():
    """The ABC's default ``update`` raises ``ValueError`` instead of silently misaligning."""
    from mempalace.backends.base import BaseCollection

    wrapped = ChromaCollection(_FakeCollection())
    two_ids = ["1", "2"]

    with pytest.raises(ValueError, match="documents length"):
        BaseCollection.update(wrapped, ids=two_ids, documents=["only-one"])

    with pytest.raises(ValueError, match="metadatas length"):
        BaseCollection.update(wrapped, ids=two_ids, metadatas=[{"k": 9}])
def test_chroma_backend_accepts_palace_ref_kwarg(tmp_path):
    """``get_collection`` accepts the new-style ``palace=PalaceRef`` call and creates the palace dir."""
    palace_dir = tmp_path / "palace"
    location = str(palace_dir)
    ref = PalaceRef(id=location, local_path=location)

    backend = ChromaBackend()
    opened = backend.get_collection(
        palace=ref,
        collection_name="mempalace_drawers",
        create=True,
    )

    assert palace_dir.is_dir()
    assert isinstance(opened, ChromaCollection)
def test_chroma_backend_create_false_raises_without_creating_directory(tmp_path): def test_chroma_backend_create_false_raises_without_creating_directory(tmp_path):