refactor(backends): typed QueryResult/GetResult, PalaceRef, BaseBackend registry (RFC 001 §10)

Advances RFC 001 §10 cleanup so backend-author PRs (#574 LanceDB, #665 Postgres,
#700 Qdrant, #697 hosted, #643 PalaceStore, #381 Qdrant) have a stable target
to align against.

Scope (this PR):

- Typed QueryResult / GetResult dataclasses replace Chroma's dict shape at
  the BaseCollection boundary (§1.3). A transitional _DictCompatMixin keeps
  existing callers working while the attribute-access migration proceeds.
- BaseCollection is now kwargs-only across add/upsert/query/get/delete/update
  with ABC defaults for estimated_count/close/health and a non-atomic default
  update() (§1.1–1.2).
- PalaceRef replaces raw path strings at the backend boundary (§2.2).
- BaseBackend ABC with get_collection/close_palace/close/health/detect (§2.3).
- mempalace.backends entry-point group + in-tree registry with
  resolve_backend_for_palace priority order matching §3.2–3.3.
- ChromaCollection normalizes chroma returns into typed results; unknown
  where-clause operators raise UnsupportedFilterError (no silent drop, §1.4).
- ChromaBackend absorbs the inode/mtime client-cache freshness check
  previously duplicated in mcp_server._get_client() (§10 + PR #757).
- searcher.py migrated to typed-attribute access as the reference call
  site; remaining callers land in a follow-up.
- pyproject: chroma registered via [project.entry-points."mempalace.backends"].
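
The typed-result plus dict-compat arrangement from the first bullet can be sketched as follows. This is a minimal, self-contained illustration and not the code in this PR: `DictCompat` and `GetResultSketch` are hypothetical stand-ins for `_DictCompatMixin` and `GetResult`, trimmed to the fields needed to show both access styles side by side.

```python
from dataclasses import dataclass
from typing import Optional

# Field names reachable through the legacy dict protocol (assumed subset).
_FIELDS = ("ids", "documents", "metadatas", "embeddings")


class DictCompat:
    """Transitional shim so legacy ``result["ids"]`` callers keep working."""

    def __getitem__(self, key: str):
        if key in _FIELDS:
            return getattr(self, key)
        raise KeyError(key)

    def get(self, key: str, default=None):
        # Unknown keys and fields that are None both fall back to the default.
        value = getattr(self, key, None) if key in _FIELDS else None
        return default if value is None else value


@dataclass(frozen=True)
class GetResultSketch(DictCompat):
    ids: list[str]
    documents: list[str]
    metadatas: list[dict]
    embeddings: Optional[list[list[float]]] = None


result = GetResultSketch(ids=["a"], documents=["hello"], metadatas=[{"wing": "x"}])

new_style = result.ids        # attribute access, the form new code should use
old_style = result["ids"]     # dict access, kept only for unmigrated callers
fallback = result.get("embeddings", "not requested")
```

Both access styles read the same underlying attribute, so a caller can be migrated one call site at a time without changing what it observes.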

Out of scope (explicit follow-ups):

- Full caller migration off the dict-compat shim across palace.py,
  mcp_server.py, miner.py, convo_miner.py, dedup.py, repair.py, exporter.py,
  palace_graph.py, cli.py, closet_llm.py.
- Embedder injection + three-state EmbedderIdentityMismatchError check (§1.5).
- maintenance_state() / run_maintenance() benchmark hooks (§7.3).
- AbstractBackendContractSuite full coverage (§7.1–7.2).
- mempalace migrate / mempalace verify CLI rewrites through BaseCollection (§8).

Tests: 970 passed (up from 967 on develop); new coverage for typed results,
empty-result outer-shape preservation, $regex rejection, registry lookup,
priority resolver, and PalaceRef-kwarg ChromaBackend.get_collection.
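
As a rough illustration of the `$regex` rejection mentioned above, the following sketch walks a where-clause and raises on any `$`-prefixed key outside an allow-list. It is modelled on the `_validate_where` helper in this commit's diff, but the names `validate_where` and `SUPPORTED`, and the local error class, are stand-ins defined here so the block runs on its own.

```python
class UnsupportedFilterError(Exception):
    """Local stand-in for the error class named in this commit."""


# Operator allow-list, assumed to match the required + optional sets in the diff.
SUPPORTED = frozenset(
    {"$eq", "$ne", "$in", "$nin", "$and", "$or", "$contains",
     "$gt", "$gte", "$lt", "$lte"}
)


def validate_where(where):
    """Raise UnsupportedFilterError if any nested operator is unknown."""
    if not where:
        return
    stack = [where]
    while stack:
        node = stack.pop()
        if not isinstance(node, dict):
            continue
        for key, value in node.items():
            if key.startswith("$") and key not in SUPPORTED:
                raise UnsupportedFilterError(f"operator {key!r} not supported")
            if isinstance(value, dict):
                stack.append(value)
            elif isinstance(value, list):
                # $and / $or carry a list of sub-clauses.
                stack.extend(item for item in value if isinstance(item, dict))


# A clause built only from supported operators passes silently.
validate_where({"$and": [{"wing": {"$eq": "north"}}, {"year": {"$gte": 2020}}]})

# An unknown operator is rejected instead of being dropped.
try:
    validate_where({"title": {"$regex": "^foo"}})
    rejected = None
except UnsupportedFilterError as exc:
    rejected = str(exc)
```

Raising before the query reaches the store keeps a typo or an unsupported operator from silently widening the result set.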

Refs: #743 (RFC 001), #989 (RFC 002 tracking issue).
Igor Lins e Silva
2026-04-18 12:45:16 -03:00
parent e4a2cd48a2
commit a17a8b734a
7 changed files with 1143 additions and 94 deletions
+61 -3
@@ -1,6 +1,64 @@
"""Storage backend implementations for MemPalace."""
"""Storage backend implementations for MemPalace (RFC 001).
from .base import BaseCollection
Public surface:
* :class:`BaseCollection` — per-collection read/write contract.
* :class:`BaseBackend` — per-palace factory contract.
* :class:`PalaceRef` — value object identifying a palace for a backend.
* :class:`QueryResult` / :class:`GetResult` — typed read returns.
* Error classes: :class:`PalaceNotFoundError`, :class:`BackendClosedError`,
:class:`UnsupportedFilterError`, :class:`DimensionMismatchError`,
:class:`EmbedderIdentityMismatchError`.
* Registry: :func:`get_backend`, :func:`register`, :func:`available_backends`,
:func:`resolve_backend_for_palace`.
* In-tree Chroma default: :class:`ChromaBackend`, :class:`ChromaCollection`.
"""
from .base import (
BackendClosedError,
BackendError,
BaseBackend,
BaseCollection,
DimensionMismatchError,
EmbedderIdentityMismatchError,
GetResult,
HealthStatus,
PalaceNotFoundError,
PalaceRef,
QueryResult,
UnsupportedFilterError,
)
from .chroma import ChromaBackend, ChromaCollection
from .registry import (
available_backends,
get_backend,
get_backend_class,
register,
reset_backends,
resolve_backend_for_palace,
unregister,
)
__all__ = ["BaseCollection", "ChromaBackend", "ChromaCollection"]
__all__ = [
"BackendClosedError",
"BackendError",
"BaseBackend",
"BaseCollection",
"ChromaBackend",
"ChromaCollection",
"DimensionMismatchError",
"EmbedderIdentityMismatchError",
"GetResult",
"HealthStatus",
"PalaceNotFoundError",
"PalaceRef",
"QueryResult",
"UnsupportedFilterError",
"available_backends",
"get_backend",
"get_backend_class",
"register",
"reset_backends",
"resolve_backend_for_palace",
"unregister",
]
+332 -27
@@ -1,49 +1,354 @@
"""Abstract collection interface for MemPalace storage backends."""
"""Storage backend contract for MemPalace (RFC 001).
This module defines the surface every storage backend must implement:
* ``BaseCollection`` — the per-collection read/write interface, kwargs-only.
* ``BaseBackend`` — the per-palace factory, addressed by ``PalaceRef``.
* ``QueryResult`` / ``GetResult`` — typed result dataclasses that replace the
Chroma dict shape as the canonical return type.
* Error classes + ``HealthStatus`` — uniform across backends.
This is the v1 cleanup from RFC 001 §10: full typed results, ``PalaceRef``,
registry-ready ABC. Embedder injection, maintenance hooks, and the full
conformance suite land in follow-up PRs.
"""
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional
from dataclasses import dataclass
from typing import ClassVar, Optional
# ---------------------------------------------------------------------------
# Errors
# ---------------------------------------------------------------------------
class BackendError(Exception):
"""Base class for every storage-backend error raised by core."""
class PalaceNotFoundError(BackendError, FileNotFoundError):
"""Raised when ``get_collection(create=False)`` is called on a missing palace.
Subclass of ``FileNotFoundError`` so legacy callers that catch the latter
(pre-#413 seam) keep working unchanged.
"""
class BackendClosedError(BackendError):
"""Raised when a backend method is called after ``close()``."""
class UnsupportedFilterError(BackendError):
"""Raised when a where-clause uses an operator the backend does not implement.
Silent dropping of unknown operators is forbidden by spec (RFC 001 §1.4).
"""
class DimensionMismatchError(BackendError):
"""Raised when the embedding dimension on write does not match the collection."""
class EmbedderIdentityMismatchError(BackendError):
"""Raised when the stored embedder model name differs from the current one."""
# ---------------------------------------------------------------------------
# Value objects
# ---------------------------------------------------------------------------
@dataclass(frozen=True)
class PalaceRef:
"""A handle to a palace, consumed by backends.
``id`` is always present and is the key backends use to cache handles.
``local_path`` is populated for filesystem-rooted palaces.
``namespace`` is used by server-mode backends for tenant / prefix routing.
"""
id: str
local_path: Optional[str] = None
namespace: Optional[str] = None
@dataclass(frozen=True)
class HealthStatus:
ok: bool
detail: str = ""
@classmethod
def healthy(cls, detail: str = "") -> "HealthStatus":
return cls(ok=True, detail=detail)
@classmethod
def unhealthy(cls, detail: str) -> "HealthStatus":
return cls(ok=False, detail=detail)
_TYPED_RESULT_FIELDS = ("ids", "documents", "metadatas", "distances", "embeddings")
class _DictCompatMixin:
"""Transitional dict-protocol access for typed results.
RFC 001 §1.3 spec is attribute access (``result.ids``). The ``result["ids"]``
and ``result.get("ids")`` forms are retained as a migration shim for callers
that predate the typed interface and are scheduled for removal in a follow-
up cleanup. New code MUST use attribute access.
"""
def __getitem__(self, key: str):
if key in _TYPED_RESULT_FIELDS:
return getattr(self, key)
raise KeyError(key)
def get(self, key: str, default=None):
if key in _TYPED_RESULT_FIELDS:
val = getattr(self, key, default)
return default if val is None else val
return default
def __contains__(self, key: object) -> bool:
return key in _TYPED_RESULT_FIELDS and getattr(self, key, None) is not None
@dataclass(frozen=True)
class QueryResult(_DictCompatMixin):
"""Typed return from ``BaseCollection.query``.
Outer list dimension = number of query vectors / texts.
Inner list dimension = hits per query (may be zero).
Fields not in ``include=`` at the call site are populated with empty lists
of the correct outer shape (never ``None``), except ``embeddings`` which
is ``None`` when not requested.
"""
ids: list[list[str]]
documents: list[list[str]]
metadatas: list[list[dict]]
distances: list[list[float]]
embeddings: Optional[list[list[list[float]]]] = None
@classmethod
def empty(cls, num_queries: int = 1) -> "QueryResult":
"""Construct an all-empty result preserving outer dimension."""
return cls(
ids=[[] for _ in range(num_queries)],
documents=[[] for _ in range(num_queries)],
metadatas=[[] for _ in range(num_queries)],
distances=[[] for _ in range(num_queries)],
embeddings=None,
)
@dataclass(frozen=True)
class GetResult(_DictCompatMixin):
"""Typed return from ``BaseCollection.get``."""
ids: list[str]
documents: list[str]
metadatas: list[dict]
embeddings: Optional[list[list[float]]] = None
@classmethod
def empty(cls) -> "GetResult":
return cls(ids=[], documents=[], metadatas=[], embeddings=None)
# ---------------------------------------------------------------------------
# Collection contract
# ---------------------------------------------------------------------------
class BaseCollection(ABC):
"""Smallest collection contract the rest of MemPalace relies on."""
"""Per-collection read/write surface every backend must implement."""
@abstractmethod
def add(
self,
*,
documents: List[str],
ids: List[str],
metadatas: Optional[List[Dict[str, Any]]] = None,
) -> None:
raise NotImplementedError
documents: list[str],
ids: list[str],
metadatas: Optional[list[dict]] = None,
embeddings: Optional[list[list[float]]] = None,
) -> None: ...
@abstractmethod
def upsert(
self,
*,
documents: List[str],
ids: List[str],
metadatas: Optional[List[Dict[str, Any]]] = None,
documents: list[str],
ids: list[str],
metadatas: Optional[list[dict]] = None,
embeddings: Optional[list[list[float]]] = None,
) -> None: ...
@abstractmethod
def query(
self,
*,
query_texts: Optional[list[str]] = None,
query_embeddings: Optional[list[list[float]]] = None,
n_results: int = 10,
where: Optional[dict] = None,
where_document: Optional[dict] = None,
include: Optional[list[str]] = None,
) -> QueryResult: ...
@abstractmethod
def get(
self,
*,
ids: Optional[list[str]] = None,
where: Optional[dict] = None,
where_document: Optional[dict] = None,
limit: Optional[int] = None,
offset: Optional[int] = None,
include: Optional[list[str]] = None,
) -> GetResult: ...
@abstractmethod
def delete(
self,
*,
ids: Optional[list[str]] = None,
where: Optional[dict] = None,
) -> None: ...
@abstractmethod
def count(self) -> int: ...
# ------------------------------------------------------------------
# Optional methods with ABC defaults (spec §1.2)
# ------------------------------------------------------------------
def estimated_count(self) -> int:
return self.count()
def close(self) -> None:
return None
def health(self) -> HealthStatus:
return HealthStatus.healthy()
def update(
self,
*,
ids: list[str],
documents: Optional[list[str]] = None,
metadatas: Optional[list[dict]] = None,
embeddings: Optional[list[list[float]]] = None,
) -> None:
raise NotImplementedError
"""Default non-atomic update: get + merge + upsert.
Backends advertising ``supports_update`` MUST override with an atomic
single-round-trip implementation.
"""
if documents is None and metadatas is None and embeddings is None:
raise ValueError("update requires at least one of documents, metadatas, embeddings")
existing = self.get(ids=ids, include=["documents", "metadatas"])
by_id = {
rid: (existing.documents[i], existing.metadatas[i])
for i, rid in enumerate(existing.ids)
}
merged_docs: list[str] = []
merged_metas: list[dict] = []
for i, rid in enumerate(ids):
prev_doc, prev_meta = by_id.get(rid, ("", {}))
merged_docs.append(documents[i] if documents is not None else prev_doc)
new_meta = dict(prev_meta or {})
if metadatas is not None:
new_meta.update(metadatas[i] or {})
merged_metas.append(new_meta)
self.upsert(
documents=merged_docs,
ids=list(ids),
metadatas=merged_metas,
embeddings=embeddings,
)
# ---------------------------------------------------------------------------
# Backend contract
# ---------------------------------------------------------------------------
class BaseBackend(ABC):
"""Long-lived factory serving many palaces (RFC 001 §2).
Instances are lightweight on construction — no I/O, no network. All
connection work is deferred to ``get_collection``. Instances are thread-
safe for concurrent ``get_collection`` calls across different palaces.
"""
name: ClassVar[str]
spec_version: ClassVar[str] = "1.0"
capabilities: ClassVar[frozenset[str]] = frozenset()
@abstractmethod
def update(self, **kwargs: Any) -> None:
"""Update existing records. Must raise if any ID is missing."""
raise NotImplementedError
def get_collection(
self,
*,
palace: PalaceRef,
collection_name: str,
create: bool = False,
options: Optional[dict] = None,
) -> BaseCollection: ...
@abstractmethod
def query(self, **kwargs: Any) -> Dict[str, Any]:
raise NotImplementedError
def close_palace(self, palace: PalaceRef) -> None:
"""Evict cached handles for a single palace. Default: no-op."""
return None
@abstractmethod
def get(self, **kwargs: Any) -> Dict[str, Any]:
raise NotImplementedError
def close(self) -> None:
"""Shut down the entire backend. Default: no-op."""
return None
@abstractmethod
def delete(self, **kwargs: Any) -> None:
raise NotImplementedError
def health(self, palace: Optional[PalaceRef] = None) -> HealthStatus:
return HealthStatus.healthy()
@abstractmethod
def count(self) -> int:
raise NotImplementedError
# Optional detection hint used by selection priority (RFC 001 §3.3 (4)):
@classmethod
def detect(cls, path: str) -> bool: # pragma: no cover - default hook
return False
# ---------------------------------------------------------------------------
# Adapter utilities
# ---------------------------------------------------------------------------
# Keys the Chroma ``include=`` parameter accepts.
_VALID_INCLUDE_KEYS = frozenset({"documents", "metadatas", "distances", "embeddings"})
@dataclass
class _IncludeSpec:
"""Resolve an ``include=`` parameter with spec-mandated defaults."""
documents: bool = True
metadatas: bool = True
distances: bool = True # only meaningful for query
embeddings: bool = False
@classmethod
def resolve(
cls, include: Optional[list[str]], *, default_distances: bool = True
) -> "_IncludeSpec":
if include is None:
return cls(
documents=True,
metadatas=True,
distances=default_distances,
embeddings=False,
)
keys = {k for k in include if k in _VALID_INCLUDE_KEYS}
return cls(
documents="documents" in keys,
metadatas="metadatas" in keys,
distances="distances" in keys,
embeddings="embeddings" in keys,
)
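
The default `update()` in the base.py diff above is described as "get + merge + upsert". The sketch below shows that pattern on a toy in-memory store. `ToyCollection` is hypothetical and exists only for illustration; its `get` returns a plain dict rather than the typed `GetResult` used in the PR.

```python
class ToyCollection:
    """Hypothetical in-memory store; not part of MemPalace."""

    def __init__(self) -> None:
        self._rows: dict[str, tuple[str, dict]] = {}

    def upsert(self, *, documents, ids, metadatas) -> None:
        for rid, doc, meta in zip(ids, documents, metadatas):
            self._rows[rid] = (doc, meta)

    def get(self, *, ids) -> dict:
        return {rid: self._rows[rid] for rid in ids if rid in self._rows}

    def update(self, *, ids, documents=None, metadatas=None) -> None:
        if documents is None and metadatas is None:
            raise ValueError("update requires documents or metadatas")
        existing = self.get(ids=ids)  # step 1: read current rows
        merged_docs, merged_metas = [], []
        for i, rid in enumerate(ids):
            prev_doc, prev_meta = existing.get(rid, ("", {}))
            # step 2: new values win, untouched fields are carried over
            merged_docs.append(documents[i] if documents is not None else prev_doc)
            meta = dict(prev_meta)
            if metadatas is not None:
                meta.update(metadatas[i] or {})
            merged_metas.append(meta)
        # step 3: write back; a concurrent writer between steps 1 and 3 can be lost
        self.upsert(documents=merged_docs, ids=list(ids), metadatas=merged_metas)


store = ToyCollection()
store.upsert(documents=["old text"], ids=["d1"], metadatas=[{"room": "attic", "v": 1}])
store.update(ids=["d1"], metadatas=[{"v": 2}])
row = store.get(ids=["d1"])["d1"]
```

The read and the write are separate steps, which is why the default is called non-atomic. Note also that, in this sketch as in the diff's default, an id that does not exist yet is written with an empty document rather than rejected.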
+408 -37
@@ -1,17 +1,54 @@
"""ChromaDB-backed MemPalace collection adapter."""
"""ChromaDB-backed MemPalace storage backend (RFC 001 reference implementation)."""
import logging
import os
import sqlite3
from typing import Any, Optional
import chromadb
from .base import BaseCollection
from .base import (
BaseBackend,
BaseCollection,
GetResult,
HealthStatus,
PalaceNotFoundError,
PalaceRef,
QueryResult,
UnsupportedFilterError,
_IncludeSpec,
)
logger = logging.getLogger(__name__)
def _fix_blob_seq_ids(palace_path: str):
_REQUIRED_OPERATORS = frozenset({"$eq", "$ne", "$in", "$nin", "$and", "$or", "$contains"})
_OPTIONAL_OPERATORS = frozenset({"$gt", "$gte", "$lt", "$lte"})
_SUPPORTED_OPERATORS = _REQUIRED_OPERATORS | _OPTIONAL_OPERATORS
def _validate_where(where: Optional[dict]) -> None:
"""Scan a where-clause for unknown operators and raise ``UnsupportedFilterError``.
Spec (RFC 001 §1.4): silent dropping of unknown operators is forbidden.
"""
if not where:
return
stack = [where]
while stack:
node = stack.pop()
if not isinstance(node, dict):
continue
for k, v in node.items():
if k.startswith("$") and k not in _SUPPORTED_OPERATORS:
raise UnsupportedFilterError(f"operator {k!r} not supported by chroma backend")
if isinstance(v, dict):
stack.append(v)
elif isinstance(v, list):
stack.extend(x for x in v if isinstance(x, dict))
def _fix_blob_seq_ids(palace_path: str) -> None:
"""Fix ChromaDB 0.6.x -> 1.5.x migration bug: BLOB seq_ids -> INTEGER.
ChromaDB 0.6.x stored seq_id as big-endian 8-byte BLOBs. ChromaDB 1.5.x
@@ -43,62 +80,293 @@ def _fix_blob_seq_ids(palace_path: str):
logger.exception("Could not fix BLOB seq_ids in %s", db_path)
# ---------------------------------------------------------------------------
# Collection adapter
# ---------------------------------------------------------------------------
def _as_list(v: Any) -> list:
"""Coerce possibly-None scalar-or-list into a list (defensive for chroma nulls)."""
if v is None:
return []
if isinstance(v, list):
return v
return [v]
class ChromaCollection(BaseCollection):
"""Thin adapter over a ChromaDB collection."""
"""Thin adapter translating ChromaDB dict returns into typed results."""
def __init__(self, collection):
self._collection = collection
def add(self, *, documents, ids, metadatas=None):
self._collection.add(documents=documents, ids=ids, metadatas=metadatas)
# ------------------------------------------------------------------
# Writes
# ------------------------------------------------------------------
def upsert(self, *, documents, ids, metadatas=None):
self._collection.upsert(documents=documents, ids=ids, metadatas=metadatas)
def add(self, *, documents, ids, metadatas=None, embeddings=None):
kwargs: dict[str, Any] = {"documents": documents, "ids": ids}
if metadatas is not None:
kwargs["metadatas"] = metadatas
if embeddings is not None:
kwargs["embeddings"] = embeddings
self._collection.add(**kwargs)
def update(self, **kwargs):
def upsert(self, *, documents, ids, metadatas=None, embeddings=None):
kwargs: dict[str, Any] = {"documents": documents, "ids": ids}
if metadatas is not None:
kwargs["metadatas"] = metadatas
if embeddings is not None:
kwargs["embeddings"] = embeddings
self._collection.upsert(**kwargs)
def update(
self,
*,
ids,
documents=None,
metadatas=None,
embeddings=None,
):
if documents is None and metadatas is None and embeddings is None:
raise ValueError("update requires at least one of documents, metadatas, embeddings")
kwargs: dict[str, Any] = {"ids": ids}
if documents is not None:
kwargs["documents"] = documents
if metadatas is not None:
kwargs["metadatas"] = metadatas
if embeddings is not None:
kwargs["embeddings"] = embeddings
self._collection.update(**kwargs)
def query(self, **kwargs):
return self._collection.query(**kwargs)
# ------------------------------------------------------------------
# Reads
# ------------------------------------------------------------------
def get(self, **kwargs):
return self._collection.get(**kwargs)
def query(
self,
*,
query_texts=None,
query_embeddings=None,
n_results=10,
where=None,
where_document=None,
include=None,
) -> QueryResult:
_validate_where(where)
_validate_where(where_document)
def delete(self, **kwargs):
spec = _IncludeSpec.resolve(include, default_distances=True)
chroma_include: list[str] = []
if spec.documents:
chroma_include.append("documents")
if spec.metadatas:
chroma_include.append("metadatas")
if spec.distances:
chroma_include.append("distances")
if spec.embeddings:
chroma_include.append("embeddings")
kwargs: dict[str, Any] = {
"n_results": n_results,
"include": chroma_include,
}
if query_texts is not None:
kwargs["query_texts"] = query_texts
if query_embeddings is not None:
kwargs["query_embeddings"] = query_embeddings
if where is not None:
kwargs["where"] = where
if where_document is not None:
kwargs["where_document"] = where_document
raw = self._collection.query(**kwargs)
num_queries = (
len(query_texts)
if query_texts is not None
else (len(query_embeddings) if query_embeddings is not None else 1)
)
ids = raw.get("ids") or []
if not ids:
return QueryResult.empty(num_queries=num_queries)
documents = raw.get("documents") or [[] for _ in ids]
metadatas = raw.get("metadatas") or [[] for _ in ids]
distances = raw.get("distances") or [[] for _ in ids]
embeddings_raw = raw.get("embeddings") if spec.embeddings else None
def _none_list_to_empty(outer):
return [(inner or []) for inner in outer]
return QueryResult(
ids=_none_list_to_empty(ids),
documents=_none_list_to_empty(documents),
metadatas=_none_list_to_empty(metadatas),
distances=_none_list_to_empty(distances),
embeddings=(
[list(inner) for inner in embeddings_raw]
if spec.embeddings and embeddings_raw is not None
else None
),
)
def get(
self,
*,
ids=None,
where=None,
where_document=None,
limit=None,
offset=None,
include=None,
) -> GetResult:
_validate_where(where)
_validate_where(where_document)
spec = _IncludeSpec.resolve(include, default_distances=False)
chroma_include: list[str] = []
if spec.documents:
chroma_include.append("documents")
if spec.metadatas:
chroma_include.append("metadatas")
if spec.embeddings:
chroma_include.append("embeddings")
kwargs: dict[str, Any] = {"include": chroma_include}
if ids is not None:
kwargs["ids"] = ids
if where is not None:
kwargs["where"] = where
if where_document is not None:
kwargs["where_document"] = where_document
if limit is not None:
kwargs["limit"] = limit
if offset is not None:
kwargs["offset"] = offset
raw = self._collection.get(**kwargs)
out_ids = list(raw.get("ids") or [])
out_docs = list(raw.get("documents") or []) if spec.documents else []
out_metas = list(raw.get("metadatas") or []) if spec.metadatas else []
out_embeds = raw.get("embeddings") if spec.embeddings else None
# Pad doc/meta lists to match ids so downstream zipping is safe.
if spec.documents and len(out_docs) < len(out_ids):
out_docs = out_docs + [""] * (len(out_ids) - len(out_docs))
if spec.metadatas and len(out_metas) < len(out_ids):
out_metas = out_metas + [{}] * (len(out_ids) - len(out_metas))
return GetResult(
ids=out_ids,
documents=out_docs,
metadatas=out_metas,
embeddings=[list(v) for v in out_embeds] if out_embeds is not None else None,
)
def delete(self, *, ids=None, where=None):
_validate_where(where)
kwargs: dict[str, Any] = {}
if ids is not None:
kwargs["ids"] = ids
if where is not None:
kwargs["where"] = where
self._collection.delete(**kwargs)
def count(self):
return self._collection.count()
class ChromaBackend:
"""Factory for MemPalace's default ChromaDB backend."""
# ---------------------------------------------------------------------------
# Backend
# ---------------------------------------------------------------------------
class ChromaBackend(BaseBackend):
"""MemPalace's default ChromaDB backend.
Maintains two caches:
* ``self._clients`` — ``palace_path -> PersistentClient`` for callers
using the ``PalaceRef`` / :meth:`get_collection` path.
* An inode+mtime freshness check absorbed from ``mcp_server._get_client``
(merged via #757) ensuring a palace rebuild on disk is detected on the
next :meth:`get_collection` call.
"""
name = "chroma"
capabilities = frozenset(
{
"supports_embeddings_in",
"supports_embeddings_passthrough",
"supports_embeddings_out",
"supports_metadata_filters",
"supports_contains_fast",
"local_mode",
}
)
def __init__(self):
# Per-instance client cache: palace_path -> chromadb.PersistentClient
self._clients: dict = {}
# palace_path -> PersistentClient
self._clients: dict[str, Any] = {}
# palace_path -> (inode, mtime) of chroma.sqlite3 at cache time.
self._freshness: dict[str, tuple[int, float]] = {}
self._closed = False
# ------------------------------------------------------------------
# Internal helpers
# ------------------------------------------------------------------
@staticmethod
def _db_stat(palace_path: str) -> tuple[int, float]:
"""Return ``(inode, mtime)`` of ``chroma.sqlite3`` or ``(0, 0.0)`` if absent."""
db_path = os.path.join(palace_path, "chroma.sqlite3")
try:
st = os.stat(db_path)
return (st.st_ino, st.st_mtime)
except OSError:
return (0, 0.0)
def _client(self, palace_path: str):
"""Return a cached PersistentClient for *palace_path*, creating one if needed."""
if palace_path not in self._clients:
"""Return a cached ``PersistentClient``, rebuilding on inode/mtime change.
Handles the palace-rebuild case (repair/nuke/purge) by invalidating the
cache when ``chroma.sqlite3`` changes on disk. FAT/exFAT return inode 0,
so inode comparisons only fire when non-zero (matches #757 semantics).
"""
if self._closed:
from .base import BackendClosedError # late import avoids cycles at module load
raise BackendClosedError("ChromaBackend has been closed")
cached = self._clients.get(palace_path)
cached_inode, cached_mtime = self._freshness.get(palace_path, (0, 0.0))
current_inode, current_mtime = self._db_stat(palace_path)
inode_changed = current_inode != 0 and cached_inode != 0 and current_inode != cached_inode
mtime_changed = (
current_mtime != 0.0 and cached_mtime != 0.0 and current_mtime > cached_mtime
)
if cached is None or inode_changed or mtime_changed:
_fix_blob_seq_ids(palace_path)
self._clients[palace_path] = chromadb.PersistentClient(path=palace_path)
return self._clients[palace_path]
cached = chromadb.PersistentClient(path=palace_path)
self._clients[palace_path] = cached
self._freshness[palace_path] = (current_inode, current_mtime)
return cached
# ------------------------------------------------------------------
# Public static helpers (for callers that manage their own caching)
# Public static helpers (legacy; prefer :meth:`get_collection`)
# ------------------------------------------------------------------
@staticmethod
def make_client(palace_path: str):
"""Create and return a fresh PersistentClient (fix BLOB seq_ids first).
"""Create a fresh ``PersistentClient`` (fixes BLOB seq_ids first).
Intended for long-lived callers (e.g. mcp_server) that keep their own
inode/mtime-based client cache.
Deprecated-ish: exposed for legacy long-lived callers that manage their
own client cache. New code should obtain a collection through
:meth:`get_collection` which manages caching internally.
"""
_fix_blob_seq_ids(palace_path)
return chromadb.PersistentClient(path=palace_path)
@@ -109,12 +377,31 @@ class ChromaBackend:
return chromadb.__version__
# ------------------------------------------------------------------
# Collection lifecycle
# BaseBackend surface
# ------------------------------------------------------------------
def get_collection(self, palace_path: str, collection_name: str, create: bool = False):
def get_collection(
self,
*args,
**kwargs,
) -> ChromaCollection:
"""Obtain a collection for a palace.
Supports two calling conventions during the RFC 001 transition:
* New (preferred): ``get_collection(palace=PalaceRef, collection_name=...,
create=False, options=None)``.
* Legacy: ``get_collection(palace_path, collection_name, create=False)``
— still used by callers not yet migrated.
"""
palace_ref, collection_name, create, options = _normalize_get_collection_args(args, kwargs)
palace_path = palace_ref.local_path
if palace_path is None:
raise PalaceNotFoundError("ChromaBackend requires PalaceRef.local_path")
if not create and not os.path.isdir(palace_path):
raise FileNotFoundError(palace_path)
raise PalaceNotFoundError(palace_path)
if create:
os.makedirs(palace_path, exist_ok=True)
@@ -124,29 +411,113 @@ class ChromaBackend:
pass
client = self._client(palace_path)
hnsw_space = "cosine"
if options and isinstance(options, dict):
hnsw_space = options.get("hnsw_space", hnsw_space)
if create:
collection = client.get_or_create_collection(
collection_name, metadata={"hnsw:space": "cosine"}
collection_name, metadata={"hnsw:space": hnsw_space}
)
else:
collection = client.get_collection(collection_name)
return ChromaCollection(collection)
def get_or_create_collection(
self, palace_path: str, collection_name: str
) -> "ChromaCollection":
"""Shorthand for get_collection(..., create=True)."""
def close_palace(self, palace) -> None:
"""Drop cached handles for ``palace``. Accepts ``PalaceRef`` or legacy path str."""
path = palace.local_path if isinstance(palace, PalaceRef) else palace
if path is None:
return
self._clients.pop(path, None)
self._freshness.pop(path, None)
def close(self) -> None:
self._clients.clear()
self._freshness.clear()
self._closed = True
def health(self, palace: Optional[PalaceRef] = None) -> HealthStatus:
if self._closed:
return HealthStatus.unhealthy("backend closed")
return HealthStatus.healthy()
@classmethod
def detect(cls, path: str) -> bool:
return os.path.isfile(os.path.join(path, "chroma.sqlite3"))
# ------------------------------------------------------------------
# Legacy (pre-RFC 001) surface — retained while callers migrate.
# ------------------------------------------------------------------
def get_or_create_collection(self, palace_path: str, collection_name: str) -> ChromaCollection:
"""Legacy shim for ``get_collection(..., create=True)`` by path string."""
return self.get_collection(palace_path, collection_name, create=True)
def delete_collection(self, palace_path: str, collection_name: str) -> None:
"""Delete *collection_name* from the palace at *palace_path*."""
"""Delete ``collection_name`` from the palace at ``palace_path``."""
self._client(palace_path).delete_collection(collection_name)
def create_collection(
self, palace_path: str, collection_name: str, hnsw_space: str = "cosine"
) -> "ChromaCollection":
"""Create (not get-or-create) *collection_name* with cosine HNSW space."""
) -> ChromaCollection:
"""Create (not get-or-create) ``collection_name`` with the given HNSW space."""
collection = self._client(palace_path).create_collection(
collection_name, metadata={"hnsw:space": hnsw_space}
)
return ChromaCollection(collection)
def _normalize_get_collection_args(args, kwargs):
"""Unify legacy positional ``(palace_path, collection_name, create)`` calls
with the new kwargs-only ``(palace=PalaceRef, collection_name=..., create=...)``.
Returns ``(PalaceRef, collection_name, create, options)``.
"""
# New-style: palace= kwarg with a PalaceRef (spec path).
if "palace" in kwargs:
palace_ref = kwargs.pop("palace")
if not isinstance(palace_ref, PalaceRef):
raise TypeError("palace= must be a PalaceRef instance")
collection_name = kwargs.pop("collection_name")
create = kwargs.pop("create", False)
options = kwargs.pop("options", None)
if kwargs:
raise TypeError(f"unexpected kwargs: {sorted(kwargs)}")
if args:
raise TypeError("positional args not allowed with palace= kwarg")
return palace_ref, collection_name, create, options
# Legacy: first positional is a path string.
if args:
palace_path = args[0]
rest = list(args[1:])
collection_name = kwargs.pop("collection_name", None) or (rest.pop(0) if rest else None)
if collection_name is None:
raise TypeError("collection_name is required")
create = kwargs.pop("create", False)
if rest:
create = rest.pop(0)
if kwargs:
raise TypeError(f"unexpected kwargs: {sorted(kwargs)}")
return (
PalaceRef(id=palace_path, local_path=palace_path),
collection_name,
bool(create),
None,
)
# Legacy kwargs-only (palace_path=..., collection_name=..., create=...)
if "palace_path" in kwargs:
palace_path = kwargs.pop("palace_path")
collection_name = kwargs.pop("collection_name")
create = kwargs.pop("create", False)
if kwargs:
raise TypeError(f"unexpected kwargs: {sorted(kwargs)}")
return (
PalaceRef(id=palace_path, local_path=palace_path),
collection_name,
bool(create),
None,
)
raise TypeError("get_collection requires palace= or a positional palace_path")
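
The inode/mtime freshness rule used by `ChromaBackend._client` in the chroma.py diff above can be isolated as a small predicate. The sketch below is an illustration under that reading of the diff: `db_stat` and `is_stale` are names chosen here, and the "no cached client yet" case, which also triggers a rebuild in the diff, is left out.

```python
import os
import tempfile


def db_stat(db_path: str) -> tuple[int, float]:
    """Return (inode, mtime) for db_path, or (0, 0.0) if it cannot be stat'ed."""
    try:
        st = os.stat(db_path)
        return (st.st_ino, st.st_mtime)
    except OSError:
        return (0, 0.0)


def is_stale(cached: tuple[int, float], current: tuple[int, float]) -> bool:
    """True when the file on disk looks newer or different than the cached stat."""
    cached_inode, cached_mtime = cached
    current_inode, current_mtime = current
    # Inode 0 means "unknown" (e.g. FAT/exFAT), so it never counts as a change.
    inode_changed = (
        current_inode != 0 and cached_inode != 0 and current_inode != cached_inode
    )
    mtime_changed = (
        current_mtime != 0.0 and cached_mtime != 0.0 and current_mtime > cached_mtime
    )
    return inode_changed or mtime_changed


replaced_file = is_stale((10, 100.0), (11, 100.0))   # new inode: file was recreated
touched_file = is_stale((0, 100.0), (0, 200.0))      # no inode info, newer mtime
unchanged = is_stale((10, 100.0), (10, 100.0))       # nothing moved

with tempfile.TemporaryDirectory() as tmp:
    missing = db_stat(os.path.join(tmp, "chroma.sqlite3"))  # file does not exist
```

Treating a zero inode as "unknown" rather than as a value keeps filesystems without inode support from forcing a client rebuild on every call.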
+189
@@ -0,0 +1,189 @@
"""Backend registry + entry-point discovery (RFC 001 §3).

Third-party backends ship as installable packages that declare a
``mempalace.backends`` entry point::

    # pyproject.toml of mempalace-postgres
    [project.entry-points."mempalace.backends"]
    postgres = "mempalace_postgres:PostgresBackend"

MemPalace discovers them once per process, on the first registry lookup.
In-tree tests and local development can register manually via
:func:`register`. Explicit registration wins on name conflict (matches
RFC 001 §3.2).
"""

from __future__ import annotations

import logging
from importlib import metadata
from threading import Lock
from typing import Optional, Type

from .base import BaseBackend

logger = logging.getLogger(__name__)

_ENTRY_POINT_GROUP = "mempalace.backends"

_registry: dict[str, Type[BaseBackend]] = {}
_instances: dict[str, BaseBackend] = {}
_explicit: set[str] = set()
_discovered = False
_lock = Lock()


def register(name: str, backend_cls: Type[BaseBackend]) -> None:
    """Register ``backend_cls`` under ``name``.

    Explicit registration wins over entry-point discovery on conflict
    (RFC 001 §3.2).
    """
    with _lock:
        _registry[name] = backend_cls
        _explicit.add(name)
        # Invalidate any cached instance so the new class is used on next get.
        _instances.pop(name, None)


def unregister(name: str) -> None:
    """Remove a backend registration (primarily for tests)."""
    with _lock:
        _registry.pop(name, None)
        _explicit.discard(name)
        _instances.pop(name, None)


def _discover_entry_points() -> None:
    """Load entry-point-declared backends once per process."""
    global _discovered
    if _discovered:
        return
    with _lock:
        # Re-check under the lock: another thread may have finished discovery.
        if _discovered:
            return
        try:
            eps = metadata.entry_points()
            # On Python 3.10+ the result supports .select(group=...); on 3.8/3.9
            # entry_points() returns a plain dict keyed by group name.
            group = (
                eps.select(group=_ENTRY_POINT_GROUP)
                if hasattr(eps, "select")
                else eps.get(_ENTRY_POINT_GROUP, [])
            )
        except Exception:
            logger.exception("entry-point discovery for %s failed", _ENTRY_POINT_GROUP)
            group = []
        for ep in group:
            if ep.name in _explicit:
                continue  # explicit registration wins
            try:
                cls = ep.load()
            except Exception:
                logger.exception("failed to load backend entry point %r", ep.name)
                continue
            if not isinstance(cls, type) or not issubclass(cls, BaseBackend):
                logger.warning(
                    "entry point %r did not resolve to a BaseBackend subclass (got %r)",
                    ep.name,
                    cls,
                )
                continue
            _registry.setdefault(ep.name, cls)
        _discovered = True


def available_backends() -> list[str]:
    """Return sorted list of all registered backend names."""
    _discover_entry_points()
    return sorted(_registry.keys())


def get_backend_class(name: str) -> Type[BaseBackend]:
    """Return the registered backend class for ``name``."""
    _discover_entry_points()
    try:
        return _registry[name]
    except KeyError as e:
        raise KeyError(f"unknown backend {name!r}; available: {available_backends()}") from e


def get_backend(name: str) -> BaseBackend:
    """Return a long-lived instance of the named backend.

    Instances are cached per-name; repeated calls return the same object.
    Call :func:`reset_backends` in tests that need isolation.
    """
    _discover_entry_points()
    with _lock:
        inst = _instances.get(name)
        if inst is not None:
            return inst
        cls = _registry.get(name)
        if cls is None:
            raise KeyError(f"unknown backend {name!r}; available: {sorted(_registry.keys())}")
        inst = cls()
        _instances[name] = inst
        return inst


def reset_backends() -> None:
    """Close and drop all cached backend instances (primarily for tests)."""
    with _lock:
        for inst in _instances.values():
            try:
                inst.close()
            except Exception:
                logger.exception("error closing backend during reset")
        _instances.clear()


def resolve_backend_for_palace(
    *,
    explicit: Optional[str] = None,
    config_value: Optional[str] = None,
    env_value: Optional[str] = None,
    palace_path: Optional[str] = None,
    default: str = "chroma",
) -> str:
    """Resolve the backend name for a palace per RFC 001 §3.3 priority order.

    1. Explicit kwarg / CLI flag
    2. Per-palace config value
    3. ``MEMPALACE_BACKEND`` env var
    4. Auto-detect from on-disk artifacts (migration/upgrade path only)
    5. Default (``chroma``)

    Auto-detection is strictly a migration aid: it fires only when a local path
    is presented, no earlier rule has chosen a backend, AND the path already
    contains backend-identifiable artifacts. For new palaces, (5) wins.
    """
    for candidate in (explicit, config_value, env_value):
        if candidate:
            return candidate

    _discover_entry_points()
    if palace_path:
        for name, cls in _registry.items():
            try:
                if cls.detect(palace_path):
                    return name
            except Exception:
                logger.exception("detect() raised on backend %r", name)
                continue

    return default


# ---------------------------------------------------------------------------
# Built-in registration
# ---------------------------------------------------------------------------


def _register_builtins() -> None:
    """Register chroma as the in-tree default."""
    from .chroma import ChromaBackend

    # Use setdefault semantics so a caller that pre-registered for tests wins.
    if "chroma" not in _registry:
        _registry["chroma"] = ChromaBackend


_register_builtins()
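
The priority order in `resolve_backend_for_palace` can be illustrated without importing mempalace. The sketch below is a standalone restatement under assumptions: `DETECTORS` is a hypothetical stand-in for the registry (name to `detect` predicate), `resolve` is a hypothetical function name, and the `chroma.sqlite3` marker file is taken from the detect test later in this PR. The `"lance"` default is used only to make the auto-detect step visible.

```python
import os
import tempfile
from typing import Callable, Dict, Optional

# Hypothetical stand-in for the registry: backend name -> detect(path) predicate.
DETECTORS: Dict[str, Callable[[str], bool]] = {
    "chroma": lambda path: os.path.isfile(os.path.join(path, "chroma.sqlite3")),
}


def resolve(
    *,
    explicit: Optional[str] = None,
    config_value: Optional[str] = None,
    env_value: Optional[str] = None,
    palace_path: Optional[str] = None,
    default: str = "chroma",
) -> str:
    # Rules 1-3: the first non-empty value wins, in this order.
    for candidate in (explicit, config_value, env_value):
        if candidate:
            return candidate
    # Rule 4: auto-detect only when a path is given and artifacts already exist.
    if palace_path:
        for name, detect in DETECTORS.items():
            if detect(palace_path):
                return name
    # Rule 5: fall back to the default.
    return default


with tempfile.TemporaryDirectory() as existing, tempfile.TemporaryDirectory() as fresh:
    # Simulate a palace that already holds chroma data.
    open(os.path.join(existing, "chroma.sqlite3"), "wb").close()

    detected = resolve(palace_path=existing, default="lance")
    new_palace = resolve(palace_path=fresh, default="lance")
    env_beats_detect = resolve(env_value="qdrant", palace_path=existing)
```

Here `detected` comes from rule 4, `new_palace` falls through to rule 5, and `env_beats_detect` shows that an environment value short-circuits detection even when artifacts are present.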
+12 -12
@@ -30,15 +30,16 @@ class SearchError(Exception):
 _TOKEN_RE = re.compile(r"\w{2,}", re.UNICODE)


-def _first_or_empty(results: dict, key: str) -> list:
-    """Return the first inner list of a ChromaDB query result, or [].
+def _first_or_empty(results, key: str) -> list:
+    """Return the first inner list of a query result field, or [].

-    ChromaDB returns shapes like ``{"documents": [["a", "b"]], ...}`` for a
-    successful query, but ``{"documents": [], ...}`` (empty outer list) when
-    the collection is empty or the filter excludes everything. Indexing
-    ``[0]`` blindly raises IndexError in that case (issue #195).
+    Accepts both the typed :class:`QueryResult` (attribute access) and the
+    pre-typed chroma dict shape; this polymorphism is retained so test mocks
+    still work and callers mid-migration do not crash. Preserves the empty-
+    collection semantics from issue #195: when no queries returned hits, the
+    outer list may be empty and indexing ``[0]`` would raise.
     """
-    outer = results.get(key)
+    outer = getattr(results, key, None) if not isinstance(results, dict) else results.get(key)
     if not outer:
         return []
     return outer[0] or []
@@ -209,7 +210,7 @@ def _expand_with_neighbors(drawers_col, matched_doc: str, matched_meta: dict, ra
         return {"text": matched_doc, "drawer_index": chunk_idx, "total_drawers": None}

     indexed_docs = []
-    for doc, meta in zip(neighbors.get("documents") or [], neighbors.get("metadatas") or []):
+    for doc, meta in zip(neighbors.documents, neighbors.metadatas):
         ci = meta.get("chunk_index")
         if isinstance(ci, int):
             indexed_docs.append((ci, doc))
@@ -224,8 +225,7 @@ def _expand_with_neighbors(drawers_col, matched_doc: str, matched_meta: dict, ra
     total_drawers = None
     try:
         all_meta = drawers_col.get(where={"source_file": src}, include=["metadatas"])
-        ids = all_meta.get("ids") or []
-        total_drawers = len(ids) if ids else None
+        total_drawers = len(all_meta.ids) if all_meta.ids else None
     except Exception:
         pass
@@ -451,8 +451,8 @@ def search_memories(
             )
         except Exception:
             continue
-        docs = source_drawers.get("documents") or []
-        metas_ = source_drawers.get("metadatas") or []
+        docs = source_drawers.documents
+        metas_ = source_drawers.metadatas
         if len(docs) <= 1:
             continue
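
The dual dict/attribute handling in `_first_or_empty` can be checked in isolation. In this sketch `FakeQueryResult` is a hypothetical stand-in for the typed `QueryResult` (only a `documents` field is modelled), and `first_or_empty` repeats the new function body from the hunk above.

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class FakeQueryResult:
    """Stand-in for the typed QueryResult; only `documents` is modelled."""

    documents: List[List[str]] = field(default_factory=list)


def first_or_empty(results, key: str) -> list:
    # Dicts use .get(); anything else is read by attribute.
    outer = getattr(results, key, None) if not isinstance(results, dict) else results.get(key)
    if not outer:
        return []  # empty outer list: no [0] to index (issue #195)
    return outer[0] or []


from_dict = first_or_empty({"documents": [["a", "b"]]}, "documents")
from_typed = first_or_empty(FakeQueryResult(documents=[["x"]]), "documents")
from_empty_dict = first_or_empty({"documents": []}, "documents")
from_empty_typed = first_or_empty(FakeQueryResult(), "documents")
```

Both input shapes give the same answers, including for the empty outer list that originally motivated the helper.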
+3
@@ -37,6 +37,9 @@ Repository = "https://github.com/MemPalace/mempalace"
 [project.scripts]
 mempalace = "mempalace.cli:main"

+[project.entry-points."mempalace.backends"]
+chroma = "mempalace.backends.chroma:ChromaBackend"
+
 [project.optional-dependencies]
 dev = ["pytest>=7.0", "pytest-cov>=4.0", "ruff>=0.4.0", "psutil>=5.9"]
 spellcheck = ["autocorrect>=2.0"]
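
To see what the new entry-point group exposes in a given environment, the standard library alone is enough. The helper below is a hypothetical diagnostic, not part of this PR; `backend_entry_points` is an illustrative name, and the result is an empty dict when no package declaring the group is installed.

```python
from importlib import metadata


def backend_entry_points(group: str = "mempalace.backends") -> dict:
    """Map entry-point name -> "module:attr" target for one group."""
    eps = metadata.entry_points()
    # Newer Pythons expose .select(); 3.8/3.9 return a dict keyed by group.
    selected = eps.select(group=group) if hasattr(eps, "select") else eps.get(group, [])
    return {ep.name: ep.value for ep in selected}


declared = backend_entry_points()
for name, target in sorted(declared.items()):
    print(f"{name} -> {target}")
```

With this PR installed, the `chroma` line above should appear in that listing; the targets are only names at this point, since nothing is imported until `ep.load()` is called.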
+138 -15
@@ -3,12 +3,34 @@ import sqlite3
 import chromadb
 import pytest

+from mempalace.backends import (
+    GetResult,
+    PalaceRef,
+    QueryResult,
+    UnsupportedFilterError,
+    available_backends,
+    get_backend,
+)
 from mempalace.backends.chroma import ChromaBackend, ChromaCollection, _fix_blob_seq_ids


 class _FakeCollection:
-    def __init__(self):
+    """Stand-in for a chromadb.Collection returning raw chroma-shaped dicts."""
+
+    def __init__(self, query_response=None, get_response=None, count_value=7):
         self.calls = []
+        self._query_response = query_response or {
+            "ids": [["a", "b"]],
+            "documents": [["da", "db"]],
+            "metadatas": [[{"wing": "w1"}, {"wing": "w2"}]],
+            "distances": [[0.1, 0.2]],
+        }
+        self._get_response = get_response or {
+            "ids": ["a"],
+            "documents": ["da"],
+            "metadatas": [{"wing": "w1"}],
+        }
+        self._count_value = count_value

     def add(self, **kwargs):
         self.calls.append(("add", kwargs))
@@ -16,41 +38,142 @@ class _FakeCollection:
     def upsert(self, **kwargs):
         self.calls.append(("upsert", kwargs))

+    def update(self, **kwargs):
+        self.calls.append(("update", kwargs))
+
     def query(self, **kwargs):
         self.calls.append(("query", kwargs))
-        return {"kind": "query"}
+        return self._query_response

     def get(self, **kwargs):
         self.calls.append(("get", kwargs))
-        return {"kind": "get"}
+        return self._get_response

     def delete(self, **kwargs):
         self.calls.append(("delete", kwargs))

     def count(self):
         self.calls.append(("count", {}))
-        return 7
+        return self._count_value


-def test_chroma_collection_delegates_methods():
+def test_chroma_collection_returns_typed_query_result():
     fake = _FakeCollection()
     collection = ChromaCollection(fake)

+    result = collection.query(query_texts=["q"])
+    assert isinstance(result, QueryResult)
+    assert result.ids == [["a", "b"]]
+    assert result.documents == [["da", "db"]]
+    assert result.metadatas == [[{"wing": "w1"}, {"wing": "w2"}]]
+    assert result.distances == [[0.1, 0.2]]
+    assert result.embeddings is None
+
+
+def test_chroma_collection_returns_typed_get_result():
+    fake = _FakeCollection()
+    collection = ChromaCollection(fake)
+
+    result = collection.get(where={"wing": "w1"})
+    assert isinstance(result, GetResult)
+    assert result.ids == ["a"]
+    assert result.documents == ["da"]
+    assert result.metadatas == [{"wing": "w1"}]
+
+
+def test_query_result_empty_preserves_outer_dimension():
+    empty = QueryResult.empty(num_queries=2)
+    assert empty.ids == [[], []]
+    assert empty.documents == [[], []]
+    assert empty.distances == [[], []]
+    assert empty.embeddings is None
+
+
+def test_typed_results_support_dict_compat_access():
+    """Transitional compat shim per base.py; retained until callers migrate to attrs."""
+    result = GetResult(ids=["a"], documents=["da"], metadatas=[{"w": 1}])
+    assert result["ids"] == ["a"]
+    assert result.get("documents") == ["da"]
+    assert result.get("missing", "default") == "default"
+    assert "ids" in result
+    assert "missing" not in result
+
+
+def test_chroma_collection_query_empty_result_preserves_outer_shape():
+    fake = _FakeCollection(
+        query_response={"ids": [], "documents": [], "metadatas": [], "distances": []}
+    )
+    collection = ChromaCollection(fake)
+    result = collection.query(query_texts=["q1", "q2"])
+    assert result.ids == [[], []]
+    assert result.documents == [[], []]
+    assert result.distances == [[], []]
+
+
+def test_chroma_collection_rejects_unknown_where_operator():
+    fake = _FakeCollection()
+    collection = ChromaCollection(fake)
+    with pytest.raises(UnsupportedFilterError):
+        collection.query(query_texts=["q"], where={"$regex": "foo"})
+
+
+def test_chroma_collection_delegates_writes():
+    fake = _FakeCollection()
+    collection = ChromaCollection(fake)
+
     collection.add(documents=["d"], ids=["1"], metadatas=[{"wing": "w"}])
     collection.upsert(documents=["u"], ids=["2"], metadatas=[{"room": "r"}])
-    assert collection.query(query_texts=["q"]) == {"kind": "query"}
-    assert collection.get(where={"wing": "w"}) == {"kind": "get"}
     collection.delete(ids=["1"])
     assert collection.count() == 7

-    assert fake.calls == [
-        ("add", {"documents": ["d"], "ids": ["1"], "metadatas": [{"wing": "w"}]}),
-        ("upsert", {"documents": ["u"], "ids": ["2"], "metadatas": [{"room": "r"}]}),
-        ("query", {"query_texts": ["q"]}),
-        ("get", {"where": {"wing": "w"}}),
-        ("delete", {"ids": ["1"]}),
-        ("count", {}),
-    ]
+    kinds = [call[0] for call in fake.calls]
+    assert kinds == ["add", "upsert", "delete", "count"]
+
+
+def test_registry_exposes_chroma_by_default():
+    names = available_backends()
+    assert "chroma" in names
+    assert isinstance(get_backend("chroma"), ChromaBackend)
+
+
+def test_registry_unknown_backend_raises():
+    with pytest.raises(KeyError):
+        get_backend("no-such-backend-exists")
+
+
+def test_resolve_backend_priority_order(tmp_path):
+    from mempalace.backends import resolve_backend_for_palace
+
+    # explicit kwarg wins over everything
+    assert resolve_backend_for_palace(explicit="pg", config_value="lance") == "pg"
+    # config value wins over env / default
+    assert resolve_backend_for_palace(config_value="lance", env_value="qdrant") == "lance"
+    # env wins over default
+    assert resolve_backend_for_palace(env_value="qdrant", default="chroma") == "qdrant"
+    # falls back to default
+    assert resolve_backend_for_palace() == "chroma"
+
+
+def test_chroma_detect_matches_palace_with_chroma_sqlite(tmp_path):
+    (tmp_path / "chroma.sqlite3").write_bytes(b"")
+    assert ChromaBackend.detect(str(tmp_path)) is True
+    assert ChromaBackend.detect(str(tmp_path.parent)) is False
+
+
+def test_chroma_backend_accepts_palace_ref_kwarg(tmp_path):
+    palace_path = tmp_path / "palace"
+    backend = ChromaBackend()
+    collection = backend.get_collection(
+        palace=PalaceRef(id=str(palace_path), local_path=str(palace_path)),
+        collection_name="mempalace_drawers",
+        create=True,
+    )
+    assert palace_path.is_dir()
+    assert isinstance(collection, ChromaCollection)
+
+
def test_chroma_backend_create_false_raises_without_creating_directory(tmp_path):