refactor(backends): typed QueryResult/GetResult, PalaceRef, BaseBackend registry (RFC 001 §10)
Advances RFC 001 §10 cleanup so backend-author PRs (#574 LanceDB, #665 Postgres, #700 Qdrant, #697 hosted, #643 PalaceStore, #381 Qdrant) have a stable target to align against.

Scope (this PR):
- Typed QueryResult / GetResult dataclasses replace Chroma's dict shape at the BaseCollection boundary (§1.3). A transitional _DictCompatMixin keeps existing callers working while the attribute-access migration proceeds.
- BaseCollection is now kwargs-only across add/upsert/query/get/delete/update with ABC defaults for estimated_count/close/health and a non-atomic default update() (§1.1–1.2).
- PalaceRef replaces raw path strings at the backend boundary (§2.2).
- BaseBackend ABC with get_collection/close_palace/close/health/detect (§2.3).
- mempalace.backends entry-point group + in-tree registry with resolve_backend_for_palace priority order matching §3.2–3.3.
- ChromaCollection normalizes chroma returns into typed results; unknown where-clause operators raise UnsupportedFilterError (no silent drop, §1.4).
- ChromaBackend absorbs the inode/mtime client-cache freshness check previously duplicated in mcp_server._get_client() (§10 + PR #757).
- searcher.py migrated to typed-attribute access as the reference call site; remaining callers land in a follow-up.
- pyproject: chroma registered via [project.entry-points."mempalace.backends"].

Out of scope (explicit follow-ups):
- Full caller migration off the dict-compat shim across palace.py, mcp_server.py, miner.py, convo_miner.py, dedup.py, repair.py, exporter.py, palace_graph.py, cli.py, closet_llm.py.
- Embedder injection + three-state EmbedderIdentityMismatchError check (§1.5).
- maintenance_state() / run_maintenance() benchmark hooks (§7.3).
- AbstractBackendContractSuite full coverage (§7.1–7.2).
- mempalace migrate / mempalace verify CLI rewrites through BaseCollection (§8).
Tests: 970 passed (up from 967 on develop); new coverage for typed results, empty-result outer-shape preservation, $regex rejection, registry lookup, priority resolver, and PalaceRef-kwarg ChromaBackend.get_collection. Refs: #743 (RFC 001), #989 (RFC 002 tracking issue).
This commit is contained in:
@@ -1,6 +1,64 @@
|
||||
"""Storage backend implementations for MemPalace."""
|
||||
"""Storage backend implementations for MemPalace (RFC 001).
|
||||
|
||||
from .base import BaseCollection
|
||||
Public surface:
|
||||
|
||||
* :class:`BaseCollection` — per-collection read/write contract.
|
||||
* :class:`BaseBackend` — per-palace factory contract.
|
||||
* :class:`PalaceRef` — value object identifying a palace for a backend.
|
||||
* :class:`QueryResult` / :class:`GetResult` — typed read returns.
|
||||
* Error classes: :class:`PalaceNotFoundError`, :class:`BackendClosedError`,
|
||||
:class:`UnsupportedFilterError`, :class:`DimensionMismatchError`,
|
||||
:class:`EmbedderIdentityMismatchError`.
|
||||
* Registry: :func:`get_backend`, :func:`register`, :func:`available_backends`,
|
||||
:func:`resolve_backend_for_palace`.
|
||||
* In-tree Chroma default: :class:`ChromaBackend`, :class:`ChromaCollection`.
|
||||
"""
|
||||
|
||||
from .base import (
|
||||
BackendClosedError,
|
||||
BackendError,
|
||||
BaseBackend,
|
||||
BaseCollection,
|
||||
DimensionMismatchError,
|
||||
EmbedderIdentityMismatchError,
|
||||
GetResult,
|
||||
HealthStatus,
|
||||
PalaceNotFoundError,
|
||||
PalaceRef,
|
||||
QueryResult,
|
||||
UnsupportedFilterError,
|
||||
)
|
||||
from .chroma import ChromaBackend, ChromaCollection
|
||||
from .registry import (
|
||||
available_backends,
|
||||
get_backend,
|
||||
get_backend_class,
|
||||
register,
|
||||
reset_backends,
|
||||
resolve_backend_for_palace,
|
||||
unregister,
|
||||
)
|
||||
|
||||
__all__ = ["BaseCollection", "ChromaBackend", "ChromaCollection"]
|
||||
__all__ = [
|
||||
"BackendClosedError",
|
||||
"BackendError",
|
||||
"BaseBackend",
|
||||
"BaseCollection",
|
||||
"ChromaBackend",
|
||||
"ChromaCollection",
|
||||
"DimensionMismatchError",
|
||||
"EmbedderIdentityMismatchError",
|
||||
"GetResult",
|
||||
"HealthStatus",
|
||||
"PalaceNotFoundError",
|
||||
"PalaceRef",
|
||||
"QueryResult",
|
||||
"UnsupportedFilterError",
|
||||
"available_backends",
|
||||
"get_backend",
|
||||
"get_backend_class",
|
||||
"register",
|
||||
"reset_backends",
|
||||
"resolve_backend_for_palace",
|
||||
"unregister",
|
||||
]
|
||||
|
||||
+332
-27
@@ -1,49 +1,354 @@
|
||||
"""Abstract collection interface for MemPalace storage backends."""
|
||||
"""Storage backend contract for MemPalace (RFC 001).
|
||||
|
||||
This module defines the surface every storage backend must implement:
|
||||
|
||||
* ``BaseCollection`` — the per-collection read/write interface, kwargs-only.
|
||||
* ``BaseBackend`` — the per-palace factory, addressed by ``PalaceRef``.
|
||||
* ``QueryResult`` / ``GetResult`` — typed result dataclasses that replace the
|
||||
Chroma dict shape as the canonical return type.
|
||||
* Error classes + ``HealthStatus`` — uniform across backends.
|
||||
|
||||
This is the v1 cleanup from RFC 001 §10: full typed results, ``PalaceRef``,
|
||||
registry-ready ABC. Embedder injection, maintenance hooks, and the full
|
||||
conformance suite land in follow-up PRs.
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, List, Optional
|
||||
from dataclasses import dataclass
|
||||
from typing import ClassVar, Optional
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Errors
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class BackendError(Exception):
    """Root of the storage-backend exception hierarchy raised by core."""


class PalaceNotFoundError(BackendError, FileNotFoundError):
    """Signals ``get_collection(create=False)`` against a palace that is missing.

    Also derives from ``FileNotFoundError`` so that legacy callers catching
    that built-in (pre-#413 seam) continue to work without modification.
    """


class BackendClosedError(BackendError):
    """Signals use of a backend method after ``close()`` has been called."""


class UnsupportedFilterError(BackendError):
    """Signals a where-clause operator that the backend does not implement.

    RFC 001 §1.4 forbids silently dropping unknown operators, so backends
    raise this instead.
    """


class DimensionMismatchError(BackendError):
    """Signals a written embedding whose dimension differs from the collection's."""


class EmbedderIdentityMismatchError(BackendError):
    """Signals that the stored embedder model name is not the current one."""
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Value objects
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class PalaceRef:
    """Immutable palace handle that backends receive instead of a raw path.

    Fields:
        id: Always set; the key backends use when caching handles.
        local_path: Set for palaces rooted on the filesystem.
        namespace: Tenant / prefix routing hint for server-mode backends.
    """

    id: str
    local_path: Optional[str] = None
    namespace: Optional[str] = None
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class HealthStatus:
    """Immutable health report: an ``ok`` flag plus a free-form ``detail`` string."""

    ok: bool
    detail: str = ""

    @classmethod
    def healthy(cls, detail: str = "") -> "HealthStatus":
        """Build a passing status, optionally carrying ``detail``."""
        return cls(True, detail)

    @classmethod
    def unhealthy(cls, detail: str) -> "HealthStatus":
        """Build a failing status; ``detail`` is required."""
        return cls(False, detail)
|
||||
|
||||
|
||||
_TYPED_RESULT_FIELDS = ("ids", "documents", "metadatas", "distances", "embeddings")
|
||||
|
||||
|
||||
class _DictCompatMixin:
|
||||
"""Transitional dict-protocol access for typed results.
|
||||
|
||||
RFC 001 §1.3 spec is attribute access (``result.ids``). The ``result["ids"]``
|
||||
and ``result.get("ids")`` forms are retained as a migration shim for callers
|
||||
that predate the typed interface and are scheduled for removal in a follow-
|
||||
up cleanup. New code MUST use attribute access.
|
||||
"""
|
||||
|
||||
def __getitem__(self, key: str):
|
||||
if key in _TYPED_RESULT_FIELDS:
|
||||
return getattr(self, key)
|
||||
raise KeyError(key)
|
||||
|
||||
def get(self, key: str, default=None):
|
||||
if key in _TYPED_RESULT_FIELDS:
|
||||
val = getattr(self, key, default)
|
||||
return default if val is None else val
|
||||
return default
|
||||
|
||||
def __contains__(self, key: object) -> bool:
|
||||
return key in _TYPED_RESULT_FIELDS and getattr(self, key, None) is not None
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class QueryResult(_DictCompatMixin):
    """Typed value returned by ``BaseCollection.query``.

    Every field is a list of lists: the outer index selects the query
    vector / text, the inner list holds that query's hits and may be empty.

    Contract: a field left out of ``include=`` at the call site is still
    populated with empty lists of the correct outer shape (never ``None``).
    ``embeddings`` is the exception and is ``None`` when not requested.
    """

    ids: list[list[str]]
    documents: list[list[str]]
    metadatas: list[list[dict]]
    distances: list[list[float]]
    embeddings: Optional[list[list[list[float]]]] = None

    @classmethod
    def empty(cls, num_queries: int = 1) -> "QueryResult":
        """Return a hit-less result whose outer length is ``num_queries``."""

        def blank() -> list:
            # A fresh outer list per field so no two fields share storage.
            return [[] for _ in range(num_queries)]

        return cls(blank(), blank(), blank(), blank(), None)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class GetResult(_DictCompatMixin):
    """Typed value returned by ``BaseCollection.get``.

    ``ids``, ``documents`` and ``metadatas`` are flat lists; ``embeddings``
    defaults to ``None``.
    """

    ids: list[str]
    documents: list[str]
    metadatas: list[dict]
    embeddings: Optional[list[list[float]]] = None

    @classmethod
    def empty(cls) -> "GetResult":
        """Return a result with no rows and no embeddings."""
        return cls([], [], [], None)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Collection contract
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class BaseCollection(ABC):
|
||||
"""Smallest collection contract the rest of MemPalace relies on."""
|
||||
"""Per-collection read/write surface every backend must implement."""
|
||||
|
||||
@abstractmethod
|
||||
def add(
|
||||
self,
|
||||
*,
|
||||
documents: List[str],
|
||||
ids: List[str],
|
||||
metadatas: Optional[List[Dict[str, Any]]] = None,
|
||||
) -> None:
|
||||
raise NotImplementedError
|
||||
documents: list[str],
|
||||
ids: list[str],
|
||||
metadatas: Optional[list[dict]] = None,
|
||||
embeddings: Optional[list[list[float]]] = None,
|
||||
) -> None: ...
|
||||
|
||||
@abstractmethod
|
||||
def upsert(
|
||||
self,
|
||||
*,
|
||||
documents: List[str],
|
||||
ids: List[str],
|
||||
metadatas: Optional[List[Dict[str, Any]]] = None,
|
||||
documents: list[str],
|
||||
ids: list[str],
|
||||
metadatas: Optional[list[dict]] = None,
|
||||
embeddings: Optional[list[list[float]]] = None,
|
||||
) -> None: ...
|
||||
|
||||
@abstractmethod
|
||||
def query(
|
||||
self,
|
||||
*,
|
||||
query_texts: Optional[list[str]] = None,
|
||||
query_embeddings: Optional[list[list[float]]] = None,
|
||||
n_results: int = 10,
|
||||
where: Optional[dict] = None,
|
||||
where_document: Optional[dict] = None,
|
||||
include: Optional[list[str]] = None,
|
||||
) -> QueryResult: ...
|
||||
|
||||
@abstractmethod
|
||||
def get(
|
||||
self,
|
||||
*,
|
||||
ids: Optional[list[str]] = None,
|
||||
where: Optional[dict] = None,
|
||||
where_document: Optional[dict] = None,
|
||||
limit: Optional[int] = None,
|
||||
offset: Optional[int] = None,
|
||||
include: Optional[list[str]] = None,
|
||||
) -> GetResult: ...
|
||||
|
||||
@abstractmethod
|
||||
def delete(
|
||||
self,
|
||||
*,
|
||||
ids: Optional[list[str]] = None,
|
||||
where: Optional[dict] = None,
|
||||
) -> None: ...
|
||||
|
||||
@abstractmethod
|
||||
def count(self) -> int: ...
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Optional methods with ABC defaults (spec §1.2)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def estimated_count(self) -> int:
|
||||
return self.count()
|
||||
|
||||
def close(self) -> None:
|
||||
return None
|
||||
|
||||
def health(self) -> HealthStatus:
|
||||
return HealthStatus.healthy()
|
||||
|
||||
def update(
|
||||
self,
|
||||
*,
|
||||
ids: list[str],
|
||||
documents: Optional[list[str]] = None,
|
||||
metadatas: Optional[list[dict]] = None,
|
||||
embeddings: Optional[list[list[float]]] = None,
|
||||
) -> None:
|
||||
raise NotImplementedError
|
||||
"""Default non-atomic update: get + merge + upsert.
|
||||
|
||||
Backends advertising ``supports_update`` MUST override with an atomic
|
||||
single-round-trip implementation.
|
||||
"""
|
||||
if documents is None and metadatas is None and embeddings is None:
|
||||
raise ValueError("update requires at least one of documents, metadatas, embeddings")
|
||||
|
||||
existing = self.get(ids=ids, include=["documents", "metadatas"])
|
||||
by_id = {
|
||||
rid: (existing.documents[i], existing.metadatas[i])
|
||||
for i, rid in enumerate(existing.ids)
|
||||
}
|
||||
merged_docs: list[str] = []
|
||||
merged_metas: list[dict] = []
|
||||
for i, rid in enumerate(ids):
|
||||
prev_doc, prev_meta = by_id.get(rid, ("", {}))
|
||||
merged_docs.append(documents[i] if documents is not None else prev_doc)
|
||||
new_meta = dict(prev_meta or {})
|
||||
if metadatas is not None:
|
||||
new_meta.update(metadatas[i] or {})
|
||||
merged_metas.append(new_meta)
|
||||
self.upsert(
|
||||
documents=merged_docs,
|
||||
ids=list(ids),
|
||||
metadatas=merged_metas,
|
||||
embeddings=embeddings,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Backend contract
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class BaseBackend(ABC):
|
||||
"""Long-lived factory serving many palaces (RFC 001 §2).
|
||||
|
||||
Instances are lightweight on construction — no I/O, no network. All
|
||||
connection work is deferred to ``get_collection``. Instances are thread-
|
||||
safe for concurrent ``get_collection`` calls across different palaces.
|
||||
"""
|
||||
|
||||
name: ClassVar[str]
|
||||
spec_version: ClassVar[str] = "1.0"
|
||||
capabilities: ClassVar[frozenset[str]] = frozenset()
|
||||
|
||||
@abstractmethod
|
||||
def update(self, **kwargs: Any) -> None:
|
||||
"""Update existing records. Must raise if any ID is missing."""
|
||||
raise NotImplementedError
|
||||
def get_collection(
|
||||
self,
|
||||
*,
|
||||
palace: PalaceRef,
|
||||
collection_name: str,
|
||||
create: bool = False,
|
||||
options: Optional[dict] = None,
|
||||
) -> BaseCollection: ...
|
||||
|
||||
@abstractmethod
|
||||
def query(self, **kwargs: Any) -> Dict[str, Any]:
|
||||
raise NotImplementedError
|
||||
def close_palace(self, palace: PalaceRef) -> None:
|
||||
"""Evict cached handles for a single palace. Default: no-op."""
|
||||
return None
|
||||
|
||||
@abstractmethod
|
||||
def get(self, **kwargs: Any) -> Dict[str, Any]:
|
||||
raise NotImplementedError
|
||||
def close(self) -> None:
|
||||
"""Shut down the entire backend. Default: no-op."""
|
||||
return None
|
||||
|
||||
@abstractmethod
|
||||
def delete(self, **kwargs: Any) -> None:
|
||||
raise NotImplementedError
|
||||
def health(self, palace: Optional[PalaceRef] = None) -> HealthStatus:
|
||||
return HealthStatus.healthy()
|
||||
|
||||
@abstractmethod
|
||||
def count(self) -> int:
|
||||
raise NotImplementedError
|
||||
# Optional detection hint used by selection priority (RFC 001 §3.3 (4)):
|
||||
@classmethod
|
||||
def detect(cls, path: str) -> bool: # pragma: no cover - default hook
|
||||
return False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Adapter utilities
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
# Keys the Chroma ``include=`` parameter accepts.
|
||||
_VALID_INCLUDE_KEYS = frozenset({"documents", "metadatas", "distances", "embeddings"})
|
||||
|
||||
|
||||
@dataclass
|
||||
class _IncludeSpec:
|
||||
"""Resolve an ``include=`` parameter with spec-mandated defaults."""
|
||||
|
||||
documents: bool = True
|
||||
metadatas: bool = True
|
||||
distances: bool = True # only meaningful for query
|
||||
embeddings: bool = False
|
||||
|
||||
@classmethod
|
||||
def resolve(
|
||||
cls, include: Optional[list[str]], *, default_distances: bool = True
|
||||
) -> "_IncludeSpec":
|
||||
if include is None:
|
||||
return cls(
|
||||
documents=True,
|
||||
metadatas=True,
|
||||
distances=default_distances,
|
||||
embeddings=False,
|
||||
)
|
||||
keys = {k for k in include if k in _VALID_INCLUDE_KEYS}
|
||||
return cls(
|
||||
documents="documents" in keys,
|
||||
metadatas="metadatas" in keys,
|
||||
distances="distances" in keys,
|
||||
embeddings="embeddings" in keys,
|
||||
)
|
||||
|
||||
+408
-37
@@ -1,17 +1,54 @@
|
||||
"""ChromaDB-backed MemPalace collection adapter."""
|
||||
"""ChromaDB-backed MemPalace storage backend (RFC 001 reference implementation)."""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import sqlite3
|
||||
from typing import Any, Optional
|
||||
|
||||
import chromadb
|
||||
|
||||
from .base import BaseCollection
|
||||
from .base import (
|
||||
BaseBackend,
|
||||
BaseCollection,
|
||||
GetResult,
|
||||
HealthStatus,
|
||||
PalaceNotFoundError,
|
||||
PalaceRef,
|
||||
QueryResult,
|
||||
UnsupportedFilterError,
|
||||
_IncludeSpec,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _fix_blob_seq_ids(palace_path: str):
|
||||
_REQUIRED_OPERATORS = frozenset({"$eq", "$ne", "$in", "$nin", "$and", "$or", "$contains"})
|
||||
_OPTIONAL_OPERATORS = frozenset({"$gt", "$gte", "$lt", "$lte"})
|
||||
_SUPPORTED_OPERATORS = _REQUIRED_OPERATORS | _OPTIONAL_OPERATORS
|
||||
|
||||
|
||||
def _validate_where(where: Optional[dict]) -> None:
|
||||
"""Scan a where-clause for unknown operators and raise ``UnsupportedFilterError``.
|
||||
|
||||
Spec (RFC 001 §1.4): silent dropping of unknown operators is forbidden.
|
||||
"""
|
||||
if not where:
|
||||
return
|
||||
stack = [where]
|
||||
while stack:
|
||||
node = stack.pop()
|
||||
if not isinstance(node, dict):
|
||||
continue
|
||||
for k, v in node.items():
|
||||
if k.startswith("$") and k not in _SUPPORTED_OPERATORS:
|
||||
raise UnsupportedFilterError(f"operator {k!r} not supported by chroma backend")
|
||||
if isinstance(v, dict):
|
||||
stack.append(v)
|
||||
elif isinstance(v, list):
|
||||
stack.extend(x for x in v if isinstance(x, dict))
|
||||
|
||||
|
||||
def _fix_blob_seq_ids(palace_path: str) -> None:
|
||||
"""Fix ChromaDB 0.6.x -> 1.5.x migration bug: BLOB seq_ids -> INTEGER.
|
||||
|
||||
ChromaDB 0.6.x stored seq_id as big-endian 8-byte BLOBs. ChromaDB 1.5.x
|
||||
@@ -43,62 +80,293 @@ def _fix_blob_seq_ids(palace_path: str):
|
||||
logger.exception("Could not fix BLOB seq_ids in %s", db_path)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Collection adapter
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _as_list(v: Any) -> list:
|
||||
"""Coerce possibly-None scalar-or-list into a list (defensive for chroma nulls)."""
|
||||
if v is None:
|
||||
return []
|
||||
if isinstance(v, list):
|
||||
return v
|
||||
return [v]
|
||||
|
||||
|
||||
class ChromaCollection(BaseCollection):
|
||||
"""Thin adapter over a ChromaDB collection."""
|
||||
"""Thin adapter translating ChromaDB dict returns into typed results."""
|
||||
|
||||
def __init__(self, collection):
|
||||
self._collection = collection
|
||||
|
||||
def add(self, *, documents, ids, metadatas=None):
|
||||
self._collection.add(documents=documents, ids=ids, metadatas=metadatas)
|
||||
# ------------------------------------------------------------------
|
||||
# Writes
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def upsert(self, *, documents, ids, metadatas=None):
|
||||
self._collection.upsert(documents=documents, ids=ids, metadatas=metadatas)
|
||||
def add(self, *, documents, ids, metadatas=None, embeddings=None):
|
||||
kwargs: dict[str, Any] = {"documents": documents, "ids": ids}
|
||||
if metadatas is not None:
|
||||
kwargs["metadatas"] = metadatas
|
||||
if embeddings is not None:
|
||||
kwargs["embeddings"] = embeddings
|
||||
self._collection.add(**kwargs)
|
||||
|
||||
def update(self, **kwargs):
|
||||
def upsert(self, *, documents, ids, metadatas=None, embeddings=None):
|
||||
kwargs: dict[str, Any] = {"documents": documents, "ids": ids}
|
||||
if metadatas is not None:
|
||||
kwargs["metadatas"] = metadatas
|
||||
if embeddings is not None:
|
||||
kwargs["embeddings"] = embeddings
|
||||
self._collection.upsert(**kwargs)
|
||||
|
||||
def update(
|
||||
self,
|
||||
*,
|
||||
ids,
|
||||
documents=None,
|
||||
metadatas=None,
|
||||
embeddings=None,
|
||||
):
|
||||
if documents is None and metadatas is None and embeddings is None:
|
||||
raise ValueError("update requires at least one of documents, metadatas, embeddings")
|
||||
kwargs: dict[str, Any] = {"ids": ids}
|
||||
if documents is not None:
|
||||
kwargs["documents"] = documents
|
||||
if metadatas is not None:
|
||||
kwargs["metadatas"] = metadatas
|
||||
if embeddings is not None:
|
||||
kwargs["embeddings"] = embeddings
|
||||
self._collection.update(**kwargs)
|
||||
|
||||
def query(self, **kwargs):
|
||||
return self._collection.query(**kwargs)
|
||||
# ------------------------------------------------------------------
|
||||
# Reads
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def get(self, **kwargs):
|
||||
return self._collection.get(**kwargs)
|
||||
def query(
|
||||
self,
|
||||
*,
|
||||
query_texts=None,
|
||||
query_embeddings=None,
|
||||
n_results=10,
|
||||
where=None,
|
||||
where_document=None,
|
||||
include=None,
|
||||
) -> QueryResult:
|
||||
_validate_where(where)
|
||||
_validate_where(where_document)
|
||||
|
||||
def delete(self, **kwargs):
|
||||
spec = _IncludeSpec.resolve(include, default_distances=True)
|
||||
chroma_include: list[str] = []
|
||||
if spec.documents:
|
||||
chroma_include.append("documents")
|
||||
if spec.metadatas:
|
||||
chroma_include.append("metadatas")
|
||||
if spec.distances:
|
||||
chroma_include.append("distances")
|
||||
if spec.embeddings:
|
||||
chroma_include.append("embeddings")
|
||||
|
||||
kwargs: dict[str, Any] = {
|
||||
"n_results": n_results,
|
||||
"include": chroma_include,
|
||||
}
|
||||
if query_texts is not None:
|
||||
kwargs["query_texts"] = query_texts
|
||||
if query_embeddings is not None:
|
||||
kwargs["query_embeddings"] = query_embeddings
|
||||
if where is not None:
|
||||
kwargs["where"] = where
|
||||
if where_document is not None:
|
||||
kwargs["where_document"] = where_document
|
||||
|
||||
raw = self._collection.query(**kwargs)
|
||||
|
||||
num_queries = (
|
||||
len(query_texts)
|
||||
if query_texts is not None
|
||||
else (len(query_embeddings) if query_embeddings is not None else 1)
|
||||
)
|
||||
|
||||
ids = raw.get("ids") or []
|
||||
if not ids:
|
||||
return QueryResult.empty(num_queries=num_queries)
|
||||
|
||||
documents = raw.get("documents") or [[] for _ in ids]
|
||||
metadatas = raw.get("metadatas") or [[] for _ in ids]
|
||||
distances = raw.get("distances") or [[] for _ in ids]
|
||||
embeddings_raw = raw.get("embeddings") if spec.embeddings else None
|
||||
|
||||
def _none_list_to_empty(outer):
|
||||
return [(inner or []) for inner in outer]
|
||||
|
||||
return QueryResult(
|
||||
ids=_none_list_to_empty(ids),
|
||||
documents=_none_list_to_empty(documents),
|
||||
metadatas=_none_list_to_empty(metadatas),
|
||||
distances=_none_list_to_empty(distances),
|
||||
embeddings=(
|
||||
[list(inner) for inner in embeddings_raw]
|
||||
if spec.embeddings and embeddings_raw is not None
|
||||
else None
|
||||
),
|
||||
)
|
||||
|
||||
def get(
|
||||
self,
|
||||
*,
|
||||
ids=None,
|
||||
where=None,
|
||||
where_document=None,
|
||||
limit=None,
|
||||
offset=None,
|
||||
include=None,
|
||||
) -> GetResult:
|
||||
_validate_where(where)
|
||||
_validate_where(where_document)
|
||||
|
||||
spec = _IncludeSpec.resolve(include, default_distances=False)
|
||||
chroma_include: list[str] = []
|
||||
if spec.documents:
|
||||
chroma_include.append("documents")
|
||||
if spec.metadatas:
|
||||
chroma_include.append("metadatas")
|
||||
if spec.embeddings:
|
||||
chroma_include.append("embeddings")
|
||||
|
||||
kwargs: dict[str, Any] = {"include": chroma_include}
|
||||
if ids is not None:
|
||||
kwargs["ids"] = ids
|
||||
if where is not None:
|
||||
kwargs["where"] = where
|
||||
if where_document is not None:
|
||||
kwargs["where_document"] = where_document
|
||||
if limit is not None:
|
||||
kwargs["limit"] = limit
|
||||
if offset is not None:
|
||||
kwargs["offset"] = offset
|
||||
|
||||
raw = self._collection.get(**kwargs)
|
||||
out_ids = list(raw.get("ids") or [])
|
||||
out_docs = list(raw.get("documents") or []) if spec.documents else []
|
||||
out_metas = list(raw.get("metadatas") or []) if spec.metadatas else []
|
||||
out_embeds = raw.get("embeddings") if spec.embeddings else None
|
||||
|
||||
# Pad doc/meta lists to match ids so downstream zipping is safe.
|
||||
if spec.documents and len(out_docs) < len(out_ids):
|
||||
out_docs = out_docs + [""] * (len(out_ids) - len(out_docs))
|
||||
if spec.metadatas and len(out_metas) < len(out_ids):
|
||||
out_metas = out_metas + [{}] * (len(out_ids) - len(out_metas))
|
||||
|
||||
return GetResult(
|
||||
ids=out_ids,
|
||||
documents=out_docs,
|
||||
metadatas=out_metas,
|
||||
embeddings=[list(v) for v in out_embeds] if out_embeds is not None else None,
|
||||
)
|
||||
|
||||
def delete(self, *, ids=None, where=None):
|
||||
_validate_where(where)
|
||||
kwargs: dict[str, Any] = {}
|
||||
if ids is not None:
|
||||
kwargs["ids"] = ids
|
||||
if where is not None:
|
||||
kwargs["where"] = where
|
||||
self._collection.delete(**kwargs)
|
||||
|
||||
def count(self):
|
||||
return self._collection.count()
|
||||
|
||||
|
||||
class ChromaBackend:
|
||||
"""Factory for MemPalace's default ChromaDB backend."""
|
||||
# ---------------------------------------------------------------------------
|
||||
# Backend
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class ChromaBackend(BaseBackend):
|
||||
"""MemPalace's default ChromaDB backend.
|
||||
|
||||
Maintains two caches:
|
||||
|
||||
* ``self._clients`` — ``palace_path -> PersistentClient`` for callers
|
||||
using the ``PalaceRef`` / :meth:`get_collection` path.
|
||||
* An inode+mtime freshness check absorbed from ``mcp_server._get_client``
|
||||
(merged via #757) ensuring a palace rebuild on disk is detected on the
|
||||
next :meth:`get_collection` call.
|
||||
"""
|
||||
|
||||
name = "chroma"
|
||||
capabilities = frozenset(
|
||||
{
|
||||
"supports_embeddings_in",
|
||||
"supports_embeddings_passthrough",
|
||||
"supports_embeddings_out",
|
||||
"supports_metadata_filters",
|
||||
"supports_contains_fast",
|
||||
"local_mode",
|
||||
}
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
# Per-instance client cache: palace_path -> chromadb.PersistentClient
|
||||
self._clients: dict = {}
|
||||
# palace_path -> PersistentClient
|
||||
self._clients: dict[str, Any] = {}
|
||||
# palace_path -> (inode, mtime) of chroma.sqlite3 at cache time.
|
||||
self._freshness: dict[str, tuple[int, float]] = {}
|
||||
self._closed = False
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Internal helpers
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
@staticmethod
|
||||
def _db_stat(palace_path: str) -> tuple[int, float]:
|
||||
"""Return ``(inode, mtime)`` of ``chroma.sqlite3`` or ``(0, 0.0)`` if absent."""
|
||||
db_path = os.path.join(palace_path, "chroma.sqlite3")
|
||||
try:
|
||||
st = os.stat(db_path)
|
||||
return (st.st_ino, st.st_mtime)
|
||||
except OSError:
|
||||
return (0, 0.0)
|
||||
|
||||
def _client(self, palace_path: str):
|
||||
"""Return a cached PersistentClient for *palace_path*, creating one if needed."""
|
||||
if palace_path not in self._clients:
|
||||
"""Return a cached ``PersistentClient``, rebuilding on inode/mtime change.
|
||||
|
||||
Handles the palace-rebuild case (repair/nuke/purge) by invalidating the
|
||||
cache when ``chroma.sqlite3`` changes on disk. FAT/exFAT return inode 0,
|
||||
so inode comparisons only fire when non-zero (matches #757 semantics).
|
||||
"""
|
||||
if self._closed:
|
||||
from .base import BackendClosedError # late import avoids cycles at module load
|
||||
|
||||
raise BackendClosedError("ChromaBackend has been closed")
|
||||
|
||||
cached = self._clients.get(palace_path)
|
||||
cached_inode, cached_mtime = self._freshness.get(palace_path, (0, 0.0))
|
||||
current_inode, current_mtime = self._db_stat(palace_path)
|
||||
|
||||
inode_changed = current_inode != 0 and cached_inode != 0 and current_inode != cached_inode
|
||||
mtime_changed = (
|
||||
current_mtime != 0.0 and cached_mtime != 0.0 and current_mtime > cached_mtime
|
||||
)
|
||||
|
||||
if cached is None or inode_changed or mtime_changed:
|
||||
_fix_blob_seq_ids(palace_path)
|
||||
self._clients[palace_path] = chromadb.PersistentClient(path=palace_path)
|
||||
return self._clients[palace_path]
|
||||
cached = chromadb.PersistentClient(path=palace_path)
|
||||
self._clients[palace_path] = cached
|
||||
self._freshness[palace_path] = (current_inode, current_mtime)
|
||||
return cached
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Public static helpers (for callers that manage their own caching)
|
||||
# Public static helpers (legacy; prefer :meth:`get_collection`)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
@staticmethod
|
||||
def make_client(palace_path: str):
|
||||
"""Create and return a fresh PersistentClient (fix BLOB seq_ids first).
|
||||
"""Create a fresh ``PersistentClient`` (fixes BLOB seq_ids first).
|
||||
|
||||
Intended for long-lived callers (e.g. mcp_server) that keep their own
|
||||
inode/mtime-based client cache.
|
||||
Deprecated-ish: exposed for legacy long-lived callers that manage their
|
||||
own client cache. New code should obtain a collection through
|
||||
:meth:`get_collection` which manages caching internally.
|
||||
"""
|
||||
_fix_blob_seq_ids(palace_path)
|
||||
return chromadb.PersistentClient(path=palace_path)
|
||||
@@ -109,12 +377,31 @@ class ChromaBackend:
|
||||
return chromadb.__version__
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Collection lifecycle
|
||||
# BaseBackend surface
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def get_collection(self, palace_path: str, collection_name: str, create: bool = False):
|
||||
def get_collection(
|
||||
self,
|
||||
*args,
|
||||
**kwargs,
|
||||
) -> ChromaCollection:
|
||||
"""Obtain a collection for a palace.
|
||||
|
||||
Supports two calling conventions during the RFC 001 transition:
|
||||
|
||||
* New (preferred): ``get_collection(palace=PalaceRef, collection_name=...,
|
||||
create=False, options=None)``.
|
||||
* Legacy: ``get_collection(palace_path, collection_name, create=False)``
|
||||
— still used by callers not yet migrated.
|
||||
"""
|
||||
palace_ref, collection_name, create, options = _normalize_get_collection_args(args, kwargs)
|
||||
|
||||
palace_path = palace_ref.local_path
|
||||
if palace_path is None:
|
||||
raise PalaceNotFoundError("ChromaBackend requires PalaceRef.local_path")
|
||||
|
||||
if not create and not os.path.isdir(palace_path):
|
||||
raise FileNotFoundError(palace_path)
|
||||
raise PalaceNotFoundError(palace_path)
|
||||
|
||||
if create:
|
||||
os.makedirs(palace_path, exist_ok=True)
|
||||
@@ -124,29 +411,113 @@ class ChromaBackend:
|
||||
pass
|
||||
|
||||
client = self._client(palace_path)
|
||||
hnsw_space = "cosine"
|
||||
if options and isinstance(options, dict):
|
||||
hnsw_space = options.get("hnsw_space", hnsw_space)
|
||||
|
||||
if create:
|
||||
collection = client.get_or_create_collection(
|
||||
collection_name, metadata={"hnsw:space": "cosine"}
|
||||
collection_name, metadata={"hnsw:space": hnsw_space}
|
||||
)
|
||||
else:
|
||||
collection = client.get_collection(collection_name)
|
||||
return ChromaCollection(collection)
|
||||
|
||||
def get_or_create_collection(
|
||||
self, palace_path: str, collection_name: str
|
||||
) -> "ChromaCollection":
|
||||
"""Shorthand for get_collection(..., create=True)."""
|
||||
def close_palace(self, palace) -> None:
    """Forget the cached client and freshness record of a single palace.

    ``palace`` is either a ``PalaceRef`` (its ``local_path`` is used) or a
    legacy path string. When the resolved path is ``None`` nothing happens.
    """
    if isinstance(palace, PalaceRef):
        path = palace.local_path
    else:
        path = palace
    if path is None:
        return
    # Both per-path caches are keyed by the same path string.
    for cache in (self._clients, self._freshness):
        cache.pop(path, None)
|
||||
|
||||
def close(self) -> None:
    """Empty the client and freshness caches and flag the backend as closed."""
    for cache in (self._clients, self._freshness):
        cache.clear()
    self._closed = True
|
||||
|
||||
def health(self, palace: Optional[PalaceRef] = None) -> HealthStatus:
    """Report backend health: unhealthy once :meth:`close` has run, healthy otherwise.

    ``palace`` is accepted but not consulted by this implementation.
    """
    if not self._closed:
        return HealthStatus.healthy()
    return HealthStatus.unhealthy("backend closed")
|
||||
|
||||
@classmethod
def detect(cls, path: str) -> bool:
    """Return ``True`` when ``path`` contains a ``chroma.sqlite3`` file."""
    marker = os.path.join(path, "chroma.sqlite3")
    return os.path.isfile(marker)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Legacy (pre-RFC 001) surface — retained while callers migrate.
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def get_or_create_collection(self, palace_path: str, collection_name: str) -> ChromaCollection:
    """Legacy path-string shim: same as ``get_collection(..., create=True)``."""
    collection = self.get_collection(palace_path, collection_name, create=True)
    return collection
|
||||
|
||||
def delete_collection(self, palace_path: str, collection_name: str) -> None:
    """Remove ``collection_name`` from the palace stored at ``palace_path``."""
    client = self._client(palace_path)
    client.delete_collection(collection_name)
|
||||
|
||||
def create_collection(
    self, palace_path: str, collection_name: str, hnsw_space: str = "cosine"
) -> ChromaCollection:
    """Create ``collection_name`` (plain create, not get-or-create) and wrap it.

    ``hnsw_space`` is stored in the collection metadata under ``hnsw:space``.
    """
    client = self._client(palace_path)
    raw_collection = client.create_collection(
        collection_name, metadata={"hnsw:space": hnsw_space}
    )
    return ChromaCollection(raw_collection)
|
||||
|
||||
|
||||
def _normalize_get_collection_args(args, kwargs):
    """Unify the legacy and RFC 001 calling conventions of ``get_collection``.

    Accepted shapes:

    * New: ``palace=PalaceRef, collection_name=..., create=False, options=None``
      (keywords only).
    * Legacy positional: ``(palace_path, collection_name, create=False)``;
      ``collection_name`` and ``create`` may be given by keyword instead.
    * Legacy keyword-only: ``palace_path=..., collection_name=..., create=False``.

    Args:
        args: Positional arguments exactly as ``get_collection`` received them.
        kwargs: Keyword arguments exactly as ``get_collection`` received them.
            The mapping is copied and never mutated.

    Returns:
        ``(PalaceRef, collection_name, create, options)``. Legacy calls wrap the
        path as ``PalaceRef(id=path, local_path=path)``, coerce ``create`` to
        ``bool`` and always yield ``options=None``.

    Raises:
        TypeError: for every malformed call — no palace given, ``palace=`` not
            a ``PalaceRef``, positionals mixed with ``palace=``, a missing
            ``collection_name``, unknown keywords, more than three positionals,
            or an argument supplied both positionally and by keyword.
    """
    # Work on a copy so the caller's dict is left untouched.
    remaining = dict(kwargs)

    def take_collection_name():
        # A missing name is a signature error: report it as TypeError (as the
        # positional path does) instead of leaking KeyError from the lookup.
        if "collection_name" not in remaining:
            raise TypeError("collection_name is required")
        return remaining.pop("collection_name")

    def reject_leftovers():
        if remaining:
            raise TypeError(f"unexpected kwargs: {sorted(remaining)}")

    def legacy_result(palace_path, collection_name, create):
        return (
            PalaceRef(id=palace_path, local_path=palace_path),
            collection_name,
            bool(create),
            None,
        )

    # New-style: palace= kwarg with a PalaceRef (spec path).
    if "palace" in remaining:
        palace_ref = remaining.pop("palace")
        if not isinstance(palace_ref, PalaceRef):
            raise TypeError("palace= must be a PalaceRef instance")
        collection_name = take_collection_name()
        create = remaining.pop("create", False)
        options = remaining.pop("options", None)
        reject_leftovers()
        if args:
            raise TypeError("positional args not allowed with palace= kwarg")
        return palace_ref, collection_name, create, options

    # Legacy: first positional is a path string. This emulates the old
    # ``(palace_path, collection_name, create=False)`` signature, so calls
    # that signature would have rejected are rejected here too.
    if args:
        if len(args) > 3:
            raise TypeError(
                f"get_collection takes at most 3 positional arguments ({len(args)} given)"
            )
        palace_path, *positional = args

        if positional:
            if "collection_name" in remaining:
                raise TypeError("got multiple values for argument 'collection_name'")
            collection_name = positional[0]
        else:
            # A falsy keyword value counts as "not provided", as it did before.
            collection_name = remaining.pop("collection_name", None) or None
        if collection_name is None:
            raise TypeError("collection_name is required")

        if len(positional) > 1:
            if "create" in remaining:
                raise TypeError("got multiple values for argument 'create'")
            create = positional[1]
        else:
            create = remaining.pop("create", False)

        reject_leftovers()
        return legacy_result(palace_path, collection_name, create)

    # Legacy kwargs-only (palace_path=..., collection_name=..., create=...)
    if "palace_path" in remaining:
        palace_path = remaining.pop("palace_path")
        collection_name = take_collection_name()
        create = remaining.pop("create", False)
        reject_leftovers()
        return legacy_result(palace_path, collection_name, create)

    raise TypeError("get_collection requires palace= or a positional palace_path")
|
||||
|
||||
@@ -0,0 +1,189 @@
|
||||
"""Backend registry + entry-point discovery (RFC 001 §3).
|
||||
|
||||
Third-party backends ship as installable packages that declare a
|
||||
``mempalace.backends`` entry point::
|
||||
|
||||
# pyproject.toml of mempalace-postgres
|
||||
[project.entry-points."mempalace.backends"]
|
||||
postgres = "mempalace_postgres:PostgresBackend"
|
||||
|
||||
MemPalace discovers them lazily, on the first registry lookup in a process.
In-tree tests and local development can register manually via :func:`register`.
Explicit registration wins on name conflict (matches RFC 001 §3.2).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from importlib import metadata
|
||||
from threading import Lock
|
||||
from typing import Optional, Type
|
||||
|
||||
from .base import BaseBackend
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_ENTRY_POINT_GROUP = "mempalace.backends"
|
||||
|
||||
_registry: dict[str, Type[BaseBackend]] = {}
|
||||
_instances: dict[str, BaseBackend] = {}
|
||||
_explicit: set[str] = set()
|
||||
_discovered = False
|
||||
_lock = Lock()
|
||||
|
||||
|
||||
def register(name: str, backend_cls: Type[BaseBackend]) -> None:
    """Bind ``name`` to ``backend_cls`` in the registry.

    The name is also recorded as explicitly registered, which makes it take
    precedence over an entry point of the same name (RFC 001 §3.2).
    """
    with _lock:
        # An instance built from a previously registered class is stale now;
        # drop it so the next get_backend() constructs one from backend_cls.
        _instances.pop(name, None)
        _explicit.add(name)
        _registry[name] = backend_cls
|
||||
|
||||
|
||||
def unregister(name: str) -> None:
    """Forget ``name``: its class, its explicit marker and any cached instance.

    Unknown names are ignored. Primarily a test helper.
    """
    with _lock:
        for table in (_registry, _instances):
            table.pop(name, None)
        _explicit.discard(name)
|
||||
|
||||
|
||||
def _discover_entry_points() -> None:
    """Load entry-point-declared backends into the registry.

    Returns immediately once a scan has completed in this process. Entry
    points whose name was registered explicitly are skipped, and an
    already-present registry entry is never overwritten. Failures to
    enumerate or load entry points are logged and otherwise ignored, so one
    broken plugin cannot prevent the others from loading.

    Plugin modules are imported *without* holding ``_lock``. ``_lock`` is a
    plain (non-reentrant) ``Lock``, so a plugin that calls :func:`register`
    or :func:`get_backend` at import time would otherwise deadlock against
    this function. The lock is taken only for the final registry merge.
    Threads racing on the first call may each scan; the merge is idempotent.
    """
    global _discovered
    if _discovered:
        return

    try:
        eps = metadata.entry_points()
        # Py ≥ 3.10 returns an EntryPoints object; older versions returned a dict.
        if hasattr(eps, "select"):
            group = eps.select(group=_ENTRY_POINT_GROUP)
        else:
            group = eps.get(_ENTRY_POINT_GROUP, [])
    except Exception:
        logger.exception("entry-point discovery for %s failed", _ENTRY_POINT_GROUP)
        group = []

    loaded = []
    for ep in group:
        if ep.name in _explicit:
            continue  # explicit registration wins; do not even import the plugin
        try:
            cls = ep.load()
        except Exception:
            logger.exception("failed to load backend entry point %r", ep.name)
            continue
        if not isinstance(cls, type) or not issubclass(cls, BaseBackend):
            logger.warning(
                "entry point %r did not resolve to a BaseBackend subclass (got %r)",
                ep.name,
                cls,
            )
            continue
        loaded.append((ep.name, cls))

    with _lock:
        for name, cls in loaded:
            # Re-check under the lock: register() may have claimed the name
            # while the plugin modules were being imported.
            if name not in _explicit:
                _registry.setdefault(name, cls)
        _discovered = True
|
||||
|
||||
|
||||
def available_backends() -> list[str]:
    """List every registered backend name in sorted order, after running discovery."""
    _discover_entry_points()
    names = list(_registry)
    names.sort()
    return names
|
||||
|
||||
|
||||
def get_backend_class(name: str) -> Type[BaseBackend]:
    """Look up the backend class registered as ``name``.

    Raises:
        KeyError: if ``name`` is not registered; the message lists the
            available backend names.
    """
    _discover_entry_points()
    try:
        backend_cls = _registry[name]
    except KeyError as e:
        raise KeyError(f"unknown backend {name!r}; available: {available_backends()}") from e
    return backend_cls
|
||||
|
||||
|
||||
def get_backend(name: str) -> BaseBackend:
    """Fetch the shared, long-lived instance of backend ``name``.

    The first call for a name constructs the backend with no arguments and
    caches it; later calls hand back that same object. Tests that need
    isolation should call :func:`reset_backends`.

    Raises:
        KeyError: if ``name`` is not registered.
    """
    _discover_entry_points()
    with _lock:
        backend = _instances.get(name)
        if backend is None:
            cls = _registry.get(name)
            if cls is None:
                raise KeyError(f"unknown backend {name!r}; available: {sorted(_registry.keys())}")
            backend = cls()
            _instances[name] = backend
        return backend
|
||||
|
||||
|
||||
def reset_backends() -> None:
    """Close every cached backend instance, then empty the instance cache.

    An error raised by one backend's ``close()`` is logged and does not stop
    the remaining backends from being closed. Primarily a test helper.
    """
    with _lock:
        for backend in _instances.values():
            try:
                backend.close()
            except Exception:
                logger.exception("error closing backend during reset")
        _instances.clear()
|
||||
|
||||
|
||||
def resolve_backend_for_palace(
    *,
    explicit: Optional[str] = None,
    config_value: Optional[str] = None,
    env_value: Optional[str] = None,
    palace_path: Optional[str] = None,
    default: str = "chroma",
) -> str:
    """Resolve the backend name for a palace per RFC 001 §3.3 priority order.

    1. Explicit kwarg / CLI flag
    2. Per-palace config value
    3. ``MEMPALACE_BACKEND`` env var
    4. Auto-detect from on-disk artifacts (migration/upgrade path only)
    5. Default (``chroma``)

    Auto-detection is strictly a migration aid: it fires only when a local path
    is presented, no earlier rule has chosen a backend, AND the path already
    contains backend-identifiable artifacts. For new palaces, (5) wins.

    Args:
        explicit: Backend name from a kwarg / CLI flag; empty or ``None`` is skipped.
        config_value: Backend name from per-palace config; empty or ``None`` is skipped.
        env_value: Backend name from the environment; empty or ``None`` is skipped.
        palace_path: Local palace path offered to each backend's ``detect()``.
        default: Name returned when nothing else selects a backend.

    Returns:
        The chosen backend name. The name is not checked against the registry.
    """
    for candidate in (explicit, config_value, env_value):
        if candidate:
            return candidate

    _discover_entry_points()
    if palace_path:
        # Snapshot under the lock every writer takes: iterating the live dict
        # while another thread registers/unregisters raises RuntimeError.
        # detect() itself runs outside the lock so it may call back into the
        # registry freely.
        with _lock:
            candidates = list(_registry.items())
        for name, cls in candidates:
            try:
                if cls.detect(palace_path):
                    return name
            except Exception:
                # A misbehaving detect() must not block the other backends.
                logger.exception("detect() raised on backend %r", name)
    return default
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Built-in registration
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _register_builtins() -> None:
    """Make the in-tree chroma backend available under the name ``chroma``."""
    from .chroma import ChromaBackend

    # setdefault keeps an existing "chroma" entry, so a caller that
    # pre-registered (e.g. for tests) wins.
    _registry.setdefault("chroma", ChromaBackend)
|
||||
|
||||
|
||||
_register_builtins()
|
||||
+12
-12
@@ -30,15 +30,16 @@ class SearchError(Exception):
|
||||
_TOKEN_RE = re.compile(r"\w{2,}", re.UNICODE)
|
||||
|
||||
|
||||
def _first_or_empty(results, key: str) -> list:
    """Return the first inner list of the ``key`` field of a query result, or [].

    ``results`` may be a plain dict (the raw chroma shape, looked up with
    ``.get``) or any object exposing the field as an attribute, such as the
    typed :class:`QueryResult`. Keeping both forms working lets test mocks
    and callers that are mid-migration share this helper.

    The outer list can be empty or missing when no query produced hits, in
    which case indexing ``[0]`` blindly would raise (issue #195); ``[]`` is
    returned instead. A falsy first element also yields ``[]``.
    """
    if isinstance(results, dict):
        per_query = results.get(key)
    else:
        per_query = getattr(results, key, None)
    if per_query:
        return per_query[0] or []
    return []
|
||||
@@ -209,7 +210,7 @@ def _expand_with_neighbors(drawers_col, matched_doc: str, matched_meta: dict, ra
|
||||
return {"text": matched_doc, "drawer_index": chunk_idx, "total_drawers": None}
|
||||
|
||||
indexed_docs = []
|
||||
for doc, meta in zip(neighbors.get("documents") or [], neighbors.get("metadatas") or []):
|
||||
for doc, meta in zip(neighbors.documents, neighbors.metadatas):
|
||||
ci = meta.get("chunk_index")
|
||||
if isinstance(ci, int):
|
||||
indexed_docs.append((ci, doc))
|
||||
@@ -224,8 +225,7 @@ def _expand_with_neighbors(drawers_col, matched_doc: str, matched_meta: dict, ra
|
||||
total_drawers = None
|
||||
try:
|
||||
all_meta = drawers_col.get(where={"source_file": src}, include=["metadatas"])
|
||||
ids = all_meta.get("ids") or []
|
||||
total_drawers = len(ids) if ids else None
|
||||
total_drawers = len(all_meta.ids) if all_meta.ids else None
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
@@ -451,8 +451,8 @@ def search_memories(
|
||||
)
|
||||
except Exception:
|
||||
continue
|
||||
docs = source_drawers.get("documents") or []
|
||||
metas_ = source_drawers.get("metadatas") or []
|
||||
docs = source_drawers.documents
|
||||
metas_ = source_drawers.metadatas
|
||||
if len(docs) <= 1:
|
||||
continue
|
||||
|
||||
|
||||
@@ -37,6 +37,9 @@ Repository = "https://github.com/MemPalace/mempalace"
|
||||
[project.scripts]
|
||||
mempalace = "mempalace.cli:main"
|
||||
|
||||
[project.entry-points."mempalace.backends"]
|
||||
chroma = "mempalace.backends.chroma:ChromaBackend"
|
||||
|
||||
[project.optional-dependencies]
|
||||
dev = ["pytest>=7.0", "pytest-cov>=4.0", "ruff>=0.4.0", "psutil>=5.9"]
|
||||
spellcheck = ["autocorrect>=2.0"]
|
||||
|
||||
+138
-15
@@ -3,12 +3,34 @@ import sqlite3
|
||||
import chromadb
|
||||
import pytest
|
||||
|
||||
from mempalace.backends import (
|
||||
GetResult,
|
||||
PalaceRef,
|
||||
QueryResult,
|
||||
UnsupportedFilterError,
|
||||
available_backends,
|
||||
get_backend,
|
||||
)
|
||||
from mempalace.backends.chroma import ChromaBackend, ChromaCollection, _fix_blob_seq_ids
|
||||
|
||||
|
||||
class _FakeCollection:
|
||||
def __init__(self):
|
||||
"""Stand-in for a chromadb.Collection returning raw chroma-shaped dicts."""
|
||||
|
||||
def __init__(self, query_response=None, get_response=None, count_value=7):
|
||||
self.calls = []
|
||||
self._query_response = query_response or {
|
||||
"ids": [["a", "b"]],
|
||||
"documents": [["da", "db"]],
|
||||
"metadatas": [[{"wing": "w1"}, {"wing": "w2"}]],
|
||||
"distances": [[0.1, 0.2]],
|
||||
}
|
||||
self._get_response = get_response or {
|
||||
"ids": ["a"],
|
||||
"documents": ["da"],
|
||||
"metadatas": [{"wing": "w1"}],
|
||||
}
|
||||
self._count_value = count_value
|
||||
|
||||
def add(self, **kwargs):
|
||||
self.calls.append(("add", kwargs))
|
||||
@@ -16,41 +38,142 @@ class _FakeCollection:
|
||||
def upsert(self, **kwargs):
|
||||
self.calls.append(("upsert", kwargs))
|
||||
|
||||
def update(self, **kwargs):
|
||||
self.calls.append(("update", kwargs))
|
||||
|
||||
def query(self, **kwargs):
|
||||
self.calls.append(("query", kwargs))
|
||||
return {"kind": "query"}
|
||||
return self._query_response
|
||||
|
||||
def get(self, **kwargs):
|
||||
self.calls.append(("get", kwargs))
|
||||
return {"kind": "get"}
|
||||
return self._get_response
|
||||
|
||||
def delete(self, **kwargs):
|
||||
self.calls.append(("delete", kwargs))
|
||||
|
||||
def count(self):
|
||||
self.calls.append(("count", {}))
|
||||
return 7
|
||||
return self._count_value
|
||||
|
||||
|
||||
def test_chroma_collection_delegates_methods():
|
||||
def test_chroma_collection_returns_typed_query_result():
    """``ChromaCollection.query`` turns the fake's raw dict into a ``QueryResult``."""
    fake = _FakeCollection()
    collection = ChromaCollection(fake)

    result = collection.query(query_texts=["q"])

    assert isinstance(result, QueryResult)
    assert result.ids == [["a", "b"]]
    assert result.documents == [["da", "db"]]
    assert result.metadatas == [[{"wing": "w1"}, {"wing": "w2"}]]
    assert result.distances == [[0.1, 0.2]]
    # The fake's response carries no "embeddings" key.
    assert result.embeddings is None
|
||||
|
||||
|
||||
def test_chroma_collection_returns_typed_get_result():
    """``ChromaCollection.get`` turns the fake's raw dict into a ``GetResult``."""
    fake = _FakeCollection()
    collection = ChromaCollection(fake)

    result = collection.get(where={"wing": "w1"})

    assert isinstance(result, GetResult)
    assert result.ids == ["a"]
    assert result.documents == ["da"]
    assert result.metadatas == [{"wing": "w1"}]
|
||||
|
||||
|
||||
def test_query_result_empty_preserves_outer_dimension():
    """``QueryResult.empty(num_queries=2)`` yields one empty inner list per query."""
    empty = QueryResult.empty(num_queries=2)
    assert empty.ids == [[], []]
    assert empty.documents == [[], []]
    assert empty.distances == [[], []]
    assert empty.embeddings is None
|
||||
|
||||
|
||||
def test_typed_results_support_dict_compat_access():
    """Transitional compat shim per base.py — retained until callers migrate to attrs.

    Covers item access, ``get`` with and without a default, and ``in``.
    """
    result = GetResult(ids=["a"], documents=["da"], metadatas=[{"w": 1}])
    assert result["ids"] == ["a"]
    assert result.get("documents") == ["da"]
    assert result.get("missing", "default") == "default"
    assert "ids" in result
    assert "missing" not in result
|
||||
|
||||
|
||||
def test_chroma_collection_query_empty_result_preserves_outer_shape():
    """An all-empty chroma response for two query texts yields two empty inner lists."""
    fake = _FakeCollection(
        query_response={"ids": [], "documents": [], "metadatas": [], "distances": []}
    )
    collection = ChromaCollection(fake)

    result = collection.query(query_texts=["q1", "q2"])
    assert result.ids == [[], []]
    assert result.documents == [[], []]
    assert result.distances == [[], []]
|
||||
|
||||
|
||||
def test_chroma_collection_rejects_unknown_where_operator():
    """A ``$regex`` where-clause raises ``UnsupportedFilterError``."""
    fake = _FakeCollection()
    collection = ChromaCollection(fake)

    with pytest.raises(UnsupportedFilterError):
        collection.query(query_texts=["q"], where={"$regex": "foo"})
|
||||
|
||||
|
||||
def test_chroma_collection_delegates_writes():
    """add / upsert / delete / count each reach the wrapped collection, in call order."""
    fake = _FakeCollection()
    collection = ChromaCollection(fake)

    collection.add(documents=["d"], ids=["1"], metadatas=[{"wing": "w"}])
    collection.upsert(documents=["u"], ids=["2"], metadatas=[{"room": "r"}])
    collection.delete(ids=["1"])
    assert collection.count() == 7

    # Only the operation names are compared, not the forwarded kwargs.
    kinds = [call[0] for call in fake.calls]
    assert kinds == ["add", "upsert", "delete", "count"]
|
||||
|
||||
|
||||
def test_registry_exposes_chroma_by_default():
    """``chroma`` is listed by the registry and resolves to a ``ChromaBackend`` instance."""
    names = available_backends()
    assert "chroma" in names
    assert isinstance(get_backend("chroma"), ChromaBackend)
|
||||
|
||||
|
||||
def test_registry_unknown_backend_raises():
    """Asking for an unregistered backend name raises ``KeyError``."""
    with pytest.raises(KeyError):
        get_backend("no-such-backend-exists")
|
||||
|
||||
|
||||
def test_resolve_backend_priority_order(tmp_path):
    """Resolver precedence: explicit > config > env > default.

    No ``palace_path`` is passed, so the auto-detect step is not exercised here.
    """
    from mempalace.backends import resolve_backend_for_palace

    # explicit kwarg wins over everything
    assert resolve_backend_for_palace(explicit="pg", config_value="lance") == "pg"
    # config value wins over env / default
    assert resolve_backend_for_palace(config_value="lance", env_value="qdrant") == "lance"
    # env wins over default
    assert resolve_backend_for_palace(env_value="qdrant", default="chroma") == "qdrant"
    # falls back to default
    assert resolve_backend_for_palace() == "chroma"
|
||||
|
||||
|
||||
def test_chroma_detect_matches_palace_with_chroma_sqlite(tmp_path):
    """``detect`` is true for a directory holding ``chroma.sqlite3`` and false for its parent."""
    (tmp_path / "chroma.sqlite3").write_bytes(b"")
    assert ChromaBackend.detect(str(tmp_path)) is True
    assert ChromaBackend.detect(str(tmp_path.parent)) is False
|
||||
|
||||
|
||||
def test_chroma_backend_accepts_palace_ref_kwarg(tmp_path):
    """New-style ``palace=PalaceRef`` call with ``create=True`` makes the directory and returns a ``ChromaCollection``."""
    palace_path = tmp_path / "palace"
    backend = ChromaBackend()
    collection = backend.get_collection(
        palace=PalaceRef(id=str(palace_path), local_path=str(palace_path)),
        collection_name="mempalace_drawers",
        create=True,
    )
    assert palace_path.is_dir()
    assert isinstance(collection, ChromaCollection)
|
||||
|
||||
|
||||
def test_chroma_backend_create_false_raises_without_creating_directory(tmp_path):
|
||||
|
||||
Reference in New Issue
Block a user