Merge pull request #995 from MemPalace/refactor/rfc-001-cleanup
refactor(backends): RFC 001 §10 cleanup — typed results, PalaceRef, registry
This commit is contained in:
@@ -1,6 +1,64 @@
|
|||||||
"""Storage backend implementations for MemPalace."""
|
"""Storage backend implementations for MemPalace (RFC 001).
|
||||||
|
|
||||||
from .base import BaseCollection
|
Public surface:
|
||||||
|
|
||||||
|
* :class:`BaseCollection` — per-collection read/write contract.
|
||||||
|
* :class:`BaseBackend` — per-palace factory contract.
|
||||||
|
* :class:`PalaceRef` — value object identifying a palace for a backend.
|
||||||
|
* :class:`QueryResult` / :class:`GetResult` — typed read returns.
|
||||||
|
* Error classes: :class:`PalaceNotFoundError`, :class:`BackendClosedError`,
|
||||||
|
:class:`UnsupportedFilterError`, :class:`DimensionMismatchError`,
|
||||||
|
:class:`EmbedderIdentityMismatchError`.
|
||||||
|
* Registry: :func:`get_backend`, :func:`register`, :func:`available_backends`,
|
||||||
|
:func:`resolve_backend_for_palace`.
|
||||||
|
* In-tree Chroma default: :class:`ChromaBackend`, :class:`ChromaCollection`.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from .base import (
|
||||||
|
BackendClosedError,
|
||||||
|
BackendError,
|
||||||
|
BaseBackend,
|
||||||
|
BaseCollection,
|
||||||
|
DimensionMismatchError,
|
||||||
|
EmbedderIdentityMismatchError,
|
||||||
|
GetResult,
|
||||||
|
HealthStatus,
|
||||||
|
PalaceNotFoundError,
|
||||||
|
PalaceRef,
|
||||||
|
QueryResult,
|
||||||
|
UnsupportedFilterError,
|
||||||
|
)
|
||||||
from .chroma import ChromaBackend, ChromaCollection
|
from .chroma import ChromaBackend, ChromaCollection
|
||||||
|
from .registry import (
|
||||||
|
available_backends,
|
||||||
|
get_backend,
|
||||||
|
get_backend_class,
|
||||||
|
register,
|
||||||
|
reset_backends,
|
||||||
|
resolve_backend_for_palace,
|
||||||
|
unregister,
|
||||||
|
)
|
||||||
|
|
||||||
__all__ = ["BaseCollection", "ChromaBackend", "ChromaCollection"]
|
__all__ = [
|
||||||
|
"BackendClosedError",
|
||||||
|
"BackendError",
|
||||||
|
"BaseBackend",
|
||||||
|
"BaseCollection",
|
||||||
|
"ChromaBackend",
|
||||||
|
"ChromaCollection",
|
||||||
|
"DimensionMismatchError",
|
||||||
|
"EmbedderIdentityMismatchError",
|
||||||
|
"GetResult",
|
||||||
|
"HealthStatus",
|
||||||
|
"PalaceNotFoundError",
|
||||||
|
"PalaceRef",
|
||||||
|
"QueryResult",
|
||||||
|
"UnsupportedFilterError",
|
||||||
|
"available_backends",
|
||||||
|
"get_backend",
|
||||||
|
"get_backend_class",
|
||||||
|
"register",
|
||||||
|
"reset_backends",
|
||||||
|
"resolve_backend_for_palace",
|
||||||
|
"unregister",
|
||||||
|
]
|
||||||
|
|||||||
+348
-27
@@ -1,49 +1,370 @@
|
|||||||
"""Abstract collection interface for MemPalace storage backends."""
|
"""Storage backend contract for MemPalace (RFC 001).
|
||||||
|
|
||||||
|
This module defines the surface every storage backend must implement:
|
||||||
|
|
||||||
|
* ``BaseCollection`` — the per-collection read/write interface, kwargs-only.
|
||||||
|
* ``BaseBackend`` — the per-palace factory, addressed by ``PalaceRef``.
|
||||||
|
* ``QueryResult`` / ``GetResult`` — typed result dataclasses that replace the
|
||||||
|
Chroma dict shape as the canonical return type.
|
||||||
|
* Error classes + ``HealthStatus`` — uniform across backends.
|
||||||
|
|
||||||
|
This is the v1 cleanup from RFC 001 §10: full typed results, ``PalaceRef``,
|
||||||
|
registry-ready ABC. Embedder injection, maintenance hooks, and the full
|
||||||
|
conformance suite land in follow-up PRs.
|
||||||
|
"""
|
||||||
|
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import Any, Dict, List, Optional
|
from dataclasses import dataclass
|
||||||
|
from typing import ClassVar, Optional
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Errors
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class BackendError(Exception):
|
||||||
|
"""Base class for every storage-backend error raised by core."""
|
||||||
|
|
||||||
|
|
||||||
|
class PalaceNotFoundError(BackendError, FileNotFoundError):
|
||||||
|
"""Raised when ``get_collection(create=False)`` is called on a missing palace.
|
||||||
|
|
||||||
|
Subclass of ``FileNotFoundError`` so legacy callers that catch the latter
|
||||||
|
(pre-#413 seam) keep working unchanged.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
class BackendClosedError(BackendError):
|
||||||
|
"""Raised when a backend method is called after ``close()``."""
|
||||||
|
|
||||||
|
|
||||||
|
class UnsupportedFilterError(BackendError):
|
||||||
|
"""Raised when a where-clause uses an operator the backend does not implement.
|
||||||
|
|
||||||
|
Silent dropping of unknown operators is forbidden by spec (RFC 001 §1.4).
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
class DimensionMismatchError(BackendError):
|
||||||
|
"""Raised when the embedding dimension on write does not match the collection."""
|
||||||
|
|
||||||
|
|
||||||
|
class EmbedderIdentityMismatchError(BackendError):
|
||||||
|
"""Raised when the stored embedder model name differs from the current one."""
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Value objects
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class PalaceRef:
|
||||||
|
"""A handle to a palace, consumed by backends.
|
||||||
|
|
||||||
|
``id`` is always present and is the key backends use to cache handles.
|
||||||
|
``local_path`` is populated for filesystem-rooted palaces.
|
||||||
|
``namespace`` is used by server-mode backends for tenant / prefix routing.
|
||||||
|
"""
|
||||||
|
|
||||||
|
id: str
|
||||||
|
local_path: Optional[str] = None
|
||||||
|
namespace: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class HealthStatus:
|
||||||
|
ok: bool
|
||||||
|
detail: str = ""
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def healthy(cls, detail: str = "") -> "HealthStatus":
|
||||||
|
return cls(ok=True, detail=detail)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def unhealthy(cls, detail: str) -> "HealthStatus":
|
||||||
|
return cls(ok=False, detail=detail)
|
||||||
|
|
||||||
|
|
||||||
|
_TYPED_RESULT_FIELDS = ("ids", "documents", "metadatas", "distances", "embeddings")
|
||||||
|
|
||||||
|
|
||||||
|
class _DictCompatMixin:
|
||||||
|
"""Transitional dict-protocol access for typed results.
|
||||||
|
|
||||||
|
RFC 001 §1.3 spec is attribute access (``result.ids``). The ``result["ids"]``
|
||||||
|
and ``result.get("ids")`` forms are retained as a migration shim for callers
|
||||||
|
that predate the typed interface and are scheduled for removal in a follow-
|
||||||
|
up cleanup. New code MUST use attribute access.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __getitem__(self, key: str):
|
||||||
|
if key in _TYPED_RESULT_FIELDS:
|
||||||
|
return getattr(self, key)
|
||||||
|
raise KeyError(key)
|
||||||
|
|
||||||
|
def get(self, key: str, default=None):
|
||||||
|
if key in _TYPED_RESULT_FIELDS:
|
||||||
|
val = getattr(self, key, default)
|
||||||
|
return default if val is None else val
|
||||||
|
return default
|
||||||
|
|
||||||
|
def __contains__(self, key: object) -> bool:
|
||||||
|
return key in _TYPED_RESULT_FIELDS and getattr(self, key, None) is not None
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class QueryResult(_DictCompatMixin):
|
||||||
|
"""Typed return from ``BaseCollection.query``.
|
||||||
|
|
||||||
|
Outer list dimension = number of query vectors / texts.
|
||||||
|
Inner list dimension = hits per query (may be zero).
|
||||||
|
|
||||||
|
Fields not in ``include=`` at the call site are populated with empty lists
|
||||||
|
of the correct outer shape (never ``None``), except ``embeddings`` which
|
||||||
|
is ``None`` when not requested.
|
||||||
|
"""
|
||||||
|
|
||||||
|
ids: list[list[str]]
|
||||||
|
documents: list[list[str]]
|
||||||
|
metadatas: list[list[dict]]
|
||||||
|
distances: list[list[float]]
|
||||||
|
embeddings: Optional[list[list[list[float]]]] = None
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def empty(cls, num_queries: int = 1, embeddings_requested: bool = False) -> "QueryResult":
|
||||||
|
"""Construct an all-empty result preserving outer dimension.
|
||||||
|
|
||||||
|
When ``embeddings_requested`` is True, ``embeddings`` preserves the outer
|
||||||
|
query dimension with empty hit lists (matching the spec's rule that fields
|
||||||
|
requested via ``include=`` carry the outer shape even when empty). When
|
||||||
|
False, ``embeddings`` stays ``None`` to signal the field was not requested.
|
||||||
|
"""
|
||||||
|
empty_outer = [[] for _ in range(num_queries)]
|
||||||
|
return cls(
|
||||||
|
ids=[[] for _ in range(num_queries)],
|
||||||
|
documents=[[] for _ in range(num_queries)],
|
||||||
|
metadatas=[[] for _ in range(num_queries)],
|
||||||
|
distances=[[] for _ in range(num_queries)],
|
||||||
|
embeddings=empty_outer if embeddings_requested else None,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class GetResult(_DictCompatMixin):
|
||||||
|
"""Typed return from ``BaseCollection.get``."""
|
||||||
|
|
||||||
|
ids: list[str]
|
||||||
|
documents: list[str]
|
||||||
|
metadatas: list[dict]
|
||||||
|
embeddings: Optional[list[list[float]]] = None
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def empty(cls) -> "GetResult":
|
||||||
|
return cls(ids=[], documents=[], metadatas=[], embeddings=None)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Collection contract
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
class BaseCollection(ABC):
|
class BaseCollection(ABC):
|
||||||
"""Smallest collection contract the rest of MemPalace relies on."""
|
"""Per-collection read/write surface every backend must implement."""
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def add(
|
def add(
|
||||||
self,
|
self,
|
||||||
*,
|
*,
|
||||||
documents: List[str],
|
documents: list[str],
|
||||||
ids: List[str],
|
ids: list[str],
|
||||||
metadatas: Optional[List[Dict[str, Any]]] = None,
|
metadatas: Optional[list[dict]] = None,
|
||||||
) -> None:
|
embeddings: Optional[list[list[float]]] = None,
|
||||||
raise NotImplementedError
|
) -> None: ...
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def upsert(
|
def upsert(
|
||||||
self,
|
self,
|
||||||
*,
|
*,
|
||||||
documents: List[str],
|
documents: list[str],
|
||||||
ids: List[str],
|
ids: list[str],
|
||||||
metadatas: Optional[List[Dict[str, Any]]] = None,
|
metadatas: Optional[list[dict]] = None,
|
||||||
|
embeddings: Optional[list[list[float]]] = None,
|
||||||
|
) -> None: ...
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def query(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
query_texts: Optional[list[str]] = None,
|
||||||
|
query_embeddings: Optional[list[list[float]]] = None,
|
||||||
|
n_results: int = 10,
|
||||||
|
where: Optional[dict] = None,
|
||||||
|
where_document: Optional[dict] = None,
|
||||||
|
include: Optional[list[str]] = None,
|
||||||
|
) -> QueryResult: ...
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
ids: Optional[list[str]] = None,
|
||||||
|
where: Optional[dict] = None,
|
||||||
|
where_document: Optional[dict] = None,
|
||||||
|
limit: Optional[int] = None,
|
||||||
|
offset: Optional[int] = None,
|
||||||
|
include: Optional[list[str]] = None,
|
||||||
|
) -> GetResult: ...
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def delete(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
ids: Optional[list[str]] = None,
|
||||||
|
where: Optional[dict] = None,
|
||||||
|
) -> None: ...
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def count(self) -> int: ...
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Optional methods with ABC defaults (spec §1.2)
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
def estimated_count(self) -> int:
|
||||||
|
return self.count()
|
||||||
|
|
||||||
|
def close(self) -> None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
def health(self) -> HealthStatus:
|
||||||
|
return HealthStatus.healthy()
|
||||||
|
|
||||||
|
def update(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
ids: list[str],
|
||||||
|
documents: Optional[list[str]] = None,
|
||||||
|
metadatas: Optional[list[dict]] = None,
|
||||||
|
embeddings: Optional[list[list[float]]] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
raise NotImplementedError
|
"""Default non-atomic update: get + merge + upsert.
|
||||||
|
|
||||||
|
Backends advertising ``supports_update`` MUST override with an atomic
|
||||||
|
single-round-trip implementation.
|
||||||
|
"""
|
||||||
|
if documents is None and metadatas is None and embeddings is None:
|
||||||
|
raise ValueError("update requires at least one of documents, metadatas, embeddings")
|
||||||
|
|
||||||
|
n = len(ids)
|
||||||
|
for label, value in (
|
||||||
|
("documents", documents),
|
||||||
|
("metadatas", metadatas),
|
||||||
|
("embeddings", embeddings),
|
||||||
|
):
|
||||||
|
if value is not None and len(value) != n:
|
||||||
|
raise ValueError(f"{label} length {len(value)} does not match ids length {n}")
|
||||||
|
|
||||||
|
existing = self.get(ids=ids, include=["documents", "metadatas"])
|
||||||
|
by_id = {
|
||||||
|
rid: (existing.documents[i], existing.metadatas[i])
|
||||||
|
for i, rid in enumerate(existing.ids)
|
||||||
|
}
|
||||||
|
merged_docs: list[str] = []
|
||||||
|
merged_metas: list[dict] = []
|
||||||
|
for i, rid in enumerate(ids):
|
||||||
|
prev_doc, prev_meta = by_id.get(rid, ("", {}))
|
||||||
|
merged_docs.append(documents[i] if documents is not None else prev_doc)
|
||||||
|
new_meta = dict(prev_meta or {})
|
||||||
|
if metadatas is not None:
|
||||||
|
new_meta.update(metadatas[i] or {})
|
||||||
|
merged_metas.append(new_meta)
|
||||||
|
self.upsert(
|
||||||
|
documents=merged_docs,
|
||||||
|
ids=list(ids),
|
||||||
|
metadatas=merged_metas,
|
||||||
|
embeddings=embeddings,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Backend contract
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class BaseBackend(ABC):
|
||||||
|
"""Long-lived factory serving many palaces (RFC 001 §2).
|
||||||
|
|
||||||
|
Instances are lightweight on construction — no I/O, no network. All
|
||||||
|
connection work is deferred to ``get_collection``. Instances are thread-
|
||||||
|
safe for concurrent ``get_collection`` calls across different palaces.
|
||||||
|
"""
|
||||||
|
|
||||||
|
name: ClassVar[str]
|
||||||
|
spec_version: ClassVar[str] = "1.0"
|
||||||
|
capabilities: ClassVar[frozenset[str]] = frozenset()
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def update(self, **kwargs: Any) -> None:
|
def get_collection(
|
||||||
"""Update existing records. Must raise if any ID is missing."""
|
self,
|
||||||
raise NotImplementedError
|
*,
|
||||||
|
palace: PalaceRef,
|
||||||
|
collection_name: str,
|
||||||
|
create: bool = False,
|
||||||
|
options: Optional[dict] = None,
|
||||||
|
) -> BaseCollection: ...
|
||||||
|
|
||||||
@abstractmethod
|
def close_palace(self, palace: PalaceRef) -> None:
|
||||||
def query(self, **kwargs: Any) -> Dict[str, Any]:
|
"""Evict cached handles for a single palace. Default: no-op."""
|
||||||
raise NotImplementedError
|
return None
|
||||||
|
|
||||||
@abstractmethod
|
def close(self) -> None:
|
||||||
def get(self, **kwargs: Any) -> Dict[str, Any]:
|
"""Shut down the entire backend. Default: no-op."""
|
||||||
raise NotImplementedError
|
return None
|
||||||
|
|
||||||
@abstractmethod
|
def health(self, palace: Optional[PalaceRef] = None) -> HealthStatus:
|
||||||
def delete(self, **kwargs: Any) -> None:
|
return HealthStatus.healthy()
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
@abstractmethod
|
# Optional detection hint used by selection priority (RFC 001 §3.3 (4)):
|
||||||
def count(self) -> int:
|
@classmethod
|
||||||
raise NotImplementedError
|
def detect(cls, path: str) -> bool: # pragma: no cover - default hook
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Adapter utilities
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
# Keys the Chroma ``include=`` parameter accepts.
|
||||||
|
_VALID_INCLUDE_KEYS = frozenset({"documents", "metadatas", "distances", "embeddings"})
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class _IncludeSpec:
|
||||||
|
"""Resolve an ``include=`` parameter with spec-mandated defaults."""
|
||||||
|
|
||||||
|
documents: bool = True
|
||||||
|
metadatas: bool = True
|
||||||
|
distances: bool = True # only meaningful for query
|
||||||
|
embeddings: bool = False
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def resolve(
|
||||||
|
cls, include: Optional[list[str]], *, default_distances: bool = True
|
||||||
|
) -> "_IncludeSpec":
|
||||||
|
if include is None:
|
||||||
|
return cls(
|
||||||
|
documents=True,
|
||||||
|
metadatas=True,
|
||||||
|
distances=default_distances,
|
||||||
|
embeddings=False,
|
||||||
|
)
|
||||||
|
keys = {k for k in include if k in _VALID_INCLUDE_KEYS}
|
||||||
|
return cls(
|
||||||
|
documents="documents" in keys,
|
||||||
|
metadatas="metadatas" in keys,
|
||||||
|
distances="distances" in keys,
|
||||||
|
embeddings="embeddings" in keys,
|
||||||
|
)
|
||||||
|
|||||||
+443
-37
@@ -1,17 +1,54 @@
|
|||||||
"""ChromaDB-backed MemPalace collection adapter."""
|
"""ChromaDB-backed MemPalace storage backend (RFC 001 reference implementation)."""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import sqlite3
|
import sqlite3
|
||||||
|
from typing import Any, Optional
|
||||||
|
|
||||||
import chromadb
|
import chromadb
|
||||||
|
|
||||||
from .base import BaseCollection
|
from .base import (
|
||||||
|
BaseBackend,
|
||||||
|
BaseCollection,
|
||||||
|
GetResult,
|
||||||
|
HealthStatus,
|
||||||
|
PalaceNotFoundError,
|
||||||
|
PalaceRef,
|
||||||
|
QueryResult,
|
||||||
|
UnsupportedFilterError,
|
||||||
|
_IncludeSpec,
|
||||||
|
)
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def _fix_blob_seq_ids(palace_path: str):
|
_REQUIRED_OPERATORS = frozenset({"$eq", "$ne", "$in", "$nin", "$and", "$or", "$contains"})
|
||||||
|
_OPTIONAL_OPERATORS = frozenset({"$gt", "$gte", "$lt", "$lte"})
|
||||||
|
_SUPPORTED_OPERATORS = _REQUIRED_OPERATORS | _OPTIONAL_OPERATORS
|
||||||
|
|
||||||
|
|
||||||
|
def _validate_where(where: Optional[dict]) -> None:
|
||||||
|
"""Scan a where-clause for unknown operators and raise ``UnsupportedFilterError``.
|
||||||
|
|
||||||
|
Spec (RFC 001 §1.4): silent dropping of unknown operators is forbidden.
|
||||||
|
"""
|
||||||
|
if not where:
|
||||||
|
return
|
||||||
|
stack = [where]
|
||||||
|
while stack:
|
||||||
|
node = stack.pop()
|
||||||
|
if not isinstance(node, dict):
|
||||||
|
continue
|
||||||
|
for k, v in node.items():
|
||||||
|
if k.startswith("$") and k not in _SUPPORTED_OPERATORS:
|
||||||
|
raise UnsupportedFilterError(f"operator {k!r} not supported by chroma backend")
|
||||||
|
if isinstance(v, dict):
|
||||||
|
stack.append(v)
|
||||||
|
elif isinstance(v, list):
|
||||||
|
stack.extend(x for x in v if isinstance(x, dict))
|
||||||
|
|
||||||
|
|
||||||
|
def _fix_blob_seq_ids(palace_path: str) -> None:
|
||||||
"""Fix ChromaDB 0.6.x -> 1.5.x migration bug: BLOB seq_ids -> INTEGER.
|
"""Fix ChromaDB 0.6.x -> 1.5.x migration bug: BLOB seq_ids -> INTEGER.
|
||||||
|
|
||||||
ChromaDB 0.6.x stored seq_id as big-endian 8-byte BLOBs. ChromaDB 1.5.x
|
ChromaDB 0.6.x stored seq_id as big-endian 8-byte BLOBs. ChromaDB 1.5.x
|
||||||
@@ -43,62 +80,328 @@ def _fix_blob_seq_ids(palace_path: str):
|
|||||||
logger.exception("Could not fix BLOB seq_ids in %s", db_path)
|
logger.exception("Could not fix BLOB seq_ids in %s", db_path)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Collection adapter
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def _as_list(v: Any) -> list:
|
||||||
|
"""Coerce possibly-None scalar-or-list into a list (defensive for chroma nulls)."""
|
||||||
|
if v is None:
|
||||||
|
return []
|
||||||
|
if isinstance(v, list):
|
||||||
|
return v
|
||||||
|
return [v]
|
||||||
|
|
||||||
|
|
||||||
class ChromaCollection(BaseCollection):
|
class ChromaCollection(BaseCollection):
|
||||||
"""Thin adapter over a ChromaDB collection."""
|
"""Thin adapter translating ChromaDB dict returns into typed results."""
|
||||||
|
|
||||||
def __init__(self, collection):
|
def __init__(self, collection):
|
||||||
self._collection = collection
|
self._collection = collection
|
||||||
|
|
||||||
def add(self, *, documents, ids, metadatas=None):
|
# ------------------------------------------------------------------
|
||||||
self._collection.add(documents=documents, ids=ids, metadatas=metadatas)
|
# Writes
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
def upsert(self, *, documents, ids, metadatas=None):
|
def add(self, *, documents, ids, metadatas=None, embeddings=None):
|
||||||
self._collection.upsert(documents=documents, ids=ids, metadatas=metadatas)
|
kwargs: dict[str, Any] = {"documents": documents, "ids": ids}
|
||||||
|
if metadatas is not None:
|
||||||
|
kwargs["metadatas"] = metadatas
|
||||||
|
if embeddings is not None:
|
||||||
|
kwargs["embeddings"] = embeddings
|
||||||
|
self._collection.add(**kwargs)
|
||||||
|
|
||||||
def update(self, **kwargs):
|
def upsert(self, *, documents, ids, metadatas=None, embeddings=None):
|
||||||
|
kwargs: dict[str, Any] = {"documents": documents, "ids": ids}
|
||||||
|
if metadatas is not None:
|
||||||
|
kwargs["metadatas"] = metadatas
|
||||||
|
if embeddings is not None:
|
||||||
|
kwargs["embeddings"] = embeddings
|
||||||
|
self._collection.upsert(**kwargs)
|
||||||
|
|
||||||
|
def update(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
ids,
|
||||||
|
documents=None,
|
||||||
|
metadatas=None,
|
||||||
|
embeddings=None,
|
||||||
|
):
|
||||||
|
if documents is None and metadatas is None and embeddings is None:
|
||||||
|
raise ValueError("update requires at least one of documents, metadatas, embeddings")
|
||||||
|
kwargs: dict[str, Any] = {"ids": ids}
|
||||||
|
if documents is not None:
|
||||||
|
kwargs["documents"] = documents
|
||||||
|
if metadatas is not None:
|
||||||
|
kwargs["metadatas"] = metadatas
|
||||||
|
if embeddings is not None:
|
||||||
|
kwargs["embeddings"] = embeddings
|
||||||
self._collection.update(**kwargs)
|
self._collection.update(**kwargs)
|
||||||
|
|
||||||
def query(self, **kwargs):
|
# ------------------------------------------------------------------
|
||||||
return self._collection.query(**kwargs)
|
# Reads
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
def get(self, **kwargs):
|
def query(
|
||||||
return self._collection.get(**kwargs)
|
self,
|
||||||
|
*,
|
||||||
|
query_texts=None,
|
||||||
|
query_embeddings=None,
|
||||||
|
n_results=10,
|
||||||
|
where=None,
|
||||||
|
where_document=None,
|
||||||
|
include=None,
|
||||||
|
) -> QueryResult:
|
||||||
|
_validate_where(where)
|
||||||
|
_validate_where(where_document)
|
||||||
|
|
||||||
def delete(self, **kwargs):
|
if (query_texts is None) == (query_embeddings is None):
|
||||||
|
raise ValueError("query requires exactly one of query_texts or query_embeddings")
|
||||||
|
chosen = query_texts if query_texts is not None else query_embeddings
|
||||||
|
if not chosen:
|
||||||
|
raise ValueError("query input must be a non-empty list")
|
||||||
|
|
||||||
|
spec = _IncludeSpec.resolve(include, default_distances=True)
|
||||||
|
chroma_include: list[str] = []
|
||||||
|
if spec.documents:
|
||||||
|
chroma_include.append("documents")
|
||||||
|
if spec.metadatas:
|
||||||
|
chroma_include.append("metadatas")
|
||||||
|
if spec.distances:
|
||||||
|
chroma_include.append("distances")
|
||||||
|
if spec.embeddings:
|
||||||
|
chroma_include.append("embeddings")
|
||||||
|
|
||||||
|
kwargs: dict[str, Any] = {
|
||||||
|
"n_results": n_results,
|
||||||
|
"include": chroma_include,
|
||||||
|
}
|
||||||
|
if query_texts is not None:
|
||||||
|
kwargs["query_texts"] = query_texts
|
||||||
|
if query_embeddings is not None:
|
||||||
|
kwargs["query_embeddings"] = query_embeddings
|
||||||
|
if where is not None:
|
||||||
|
kwargs["where"] = where
|
||||||
|
if where_document is not None:
|
||||||
|
kwargs["where_document"] = where_document
|
||||||
|
|
||||||
|
raw = self._collection.query(**kwargs)
|
||||||
|
|
||||||
|
num_queries = (
|
||||||
|
len(query_texts)
|
||||||
|
if query_texts is not None
|
||||||
|
else (len(query_embeddings) if query_embeddings is not None else 1)
|
||||||
|
)
|
||||||
|
|
||||||
|
ids = raw.get("ids") or []
|
||||||
|
if not ids:
|
||||||
|
return QueryResult.empty(
|
||||||
|
num_queries=num_queries,
|
||||||
|
embeddings_requested=spec.embeddings,
|
||||||
|
)
|
||||||
|
|
||||||
|
documents = raw.get("documents") or [[] for _ in ids]
|
||||||
|
metadatas = raw.get("metadatas") or [[] for _ in ids]
|
||||||
|
distances = raw.get("distances") or [[] for _ in ids]
|
||||||
|
embeddings_raw = raw.get("embeddings") if spec.embeddings else None
|
||||||
|
|
||||||
|
def _none_list_to_empty(outer):
|
||||||
|
return [(inner or []) for inner in outer]
|
||||||
|
|
||||||
|
return QueryResult(
|
||||||
|
ids=_none_list_to_empty(ids),
|
||||||
|
documents=_none_list_to_empty(documents),
|
||||||
|
metadatas=_none_list_to_empty(metadatas),
|
||||||
|
distances=_none_list_to_empty(distances),
|
||||||
|
embeddings=(
|
||||||
|
[list(inner) for inner in embeddings_raw]
|
||||||
|
if spec.embeddings and embeddings_raw is not None
|
||||||
|
else None
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
def get(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
ids=None,
|
||||||
|
where=None,
|
||||||
|
where_document=None,
|
||||||
|
limit=None,
|
||||||
|
offset=None,
|
||||||
|
include=None,
|
||||||
|
) -> GetResult:
|
||||||
|
_validate_where(where)
|
||||||
|
_validate_where(where_document)
|
||||||
|
|
||||||
|
spec = _IncludeSpec.resolve(include, default_distances=False)
|
||||||
|
chroma_include: list[str] = []
|
||||||
|
if spec.documents:
|
||||||
|
chroma_include.append("documents")
|
||||||
|
if spec.metadatas:
|
||||||
|
chroma_include.append("metadatas")
|
||||||
|
if spec.embeddings:
|
||||||
|
chroma_include.append("embeddings")
|
||||||
|
|
||||||
|
kwargs: dict[str, Any] = {"include": chroma_include}
|
||||||
|
if ids is not None:
|
||||||
|
kwargs["ids"] = ids
|
||||||
|
if where is not None:
|
||||||
|
kwargs["where"] = where
|
||||||
|
if where_document is not None:
|
||||||
|
kwargs["where_document"] = where_document
|
||||||
|
if limit is not None:
|
||||||
|
kwargs["limit"] = limit
|
||||||
|
if offset is not None:
|
||||||
|
kwargs["offset"] = offset
|
||||||
|
|
||||||
|
raw = self._collection.get(**kwargs)
|
||||||
|
out_ids = list(raw.get("ids") or [])
|
||||||
|
out_docs = list(raw.get("documents") or []) if spec.documents else []
|
||||||
|
out_metas = list(raw.get("metadatas") or []) if spec.metadatas else []
|
||||||
|
out_embeds = raw.get("embeddings") if spec.embeddings else None
|
||||||
|
|
||||||
|
# Pad doc/meta lists to match ids so downstream zipping is safe.
|
||||||
|
if spec.documents and len(out_docs) < len(out_ids):
|
||||||
|
out_docs = out_docs + [""] * (len(out_ids) - len(out_docs))
|
||||||
|
if spec.metadatas and len(out_metas) < len(out_ids):
|
||||||
|
out_metas = out_metas + [{}] * (len(out_ids) - len(out_metas))
|
||||||
|
|
||||||
|
return GetResult(
|
||||||
|
ids=out_ids,
|
||||||
|
documents=out_docs,
|
||||||
|
metadatas=out_metas,
|
||||||
|
embeddings=[list(v) for v in out_embeds] if out_embeds is not None else None,
|
||||||
|
)
|
||||||
|
|
||||||
|
def delete(self, *, ids=None, where=None):
|
||||||
|
_validate_where(where)
|
||||||
|
kwargs: dict[str, Any] = {}
|
||||||
|
if ids is not None:
|
||||||
|
kwargs["ids"] = ids
|
||||||
|
if where is not None:
|
||||||
|
kwargs["where"] = where
|
||||||
self._collection.delete(**kwargs)
|
self._collection.delete(**kwargs)
|
||||||
|
|
||||||
def count(self):
|
def count(self):
|
||||||
return self._collection.count()
|
return self._collection.count()
|
||||||
|
|
||||||
|
|
||||||
class ChromaBackend:
|
# ---------------------------------------------------------------------------
|
||||||
"""Factory for MemPalace's default ChromaDB backend."""
|
# Backend
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class ChromaBackend(BaseBackend):
|
||||||
|
"""MemPalace's default ChromaDB backend.
|
||||||
|
|
||||||
|
Maintains two caches:
|
||||||
|
|
||||||
|
* ``self._clients`` — ``palace_path -> PersistentClient`` for callers
|
||||||
|
using the ``PalaceRef`` / :meth:`get_collection` path.
|
||||||
|
* An inode+mtime freshness check absorbed from ``mcp_server._get_client``
|
||||||
|
(merged via #757) ensuring a palace rebuild on disk is detected on the
|
||||||
|
next :meth:`get_collection` call.
|
||||||
|
"""
|
||||||
|
|
||||||
|
name = "chroma"
|
||||||
|
capabilities = frozenset(
|
||||||
|
{
|
||||||
|
"supports_embeddings_in",
|
||||||
|
"supports_embeddings_passthrough",
|
||||||
|
"supports_embeddings_out",
|
||||||
|
"supports_metadata_filters",
|
||||||
|
"supports_contains_fast",
|
||||||
|
"local_mode",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
# Per-instance client cache: palace_path -> chromadb.PersistentClient
|
# palace_path -> PersistentClient
|
||||||
self._clients: dict = {}
|
self._clients: dict[str, Any] = {}
|
||||||
|
# palace_path -> (inode, mtime) of chroma.sqlite3 at cache time.
|
||||||
|
self._freshness: dict[str, tuple[int, float]] = {}
|
||||||
|
self._closed = False
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
# Internal helpers
|
# Internal helpers
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _db_stat(palace_path: str) -> tuple[int, float]:
|
||||||
|
"""Return ``(inode, mtime)`` of ``chroma.sqlite3`` or ``(0, 0.0)`` if absent."""
|
||||||
|
db_path = os.path.join(palace_path, "chroma.sqlite3")
|
||||||
|
try:
|
||||||
|
st = os.stat(db_path)
|
||||||
|
return (st.st_ino, st.st_mtime)
|
||||||
|
except OSError:
|
||||||
|
return (0, 0.0)
|
||||||
|
|
||||||
def _client(self, palace_path: str):
|
def _client(self, palace_path: str):
|
||||||
"""Return a cached PersistentClient for *palace_path*, creating one if needed."""
|
"""Return a cached ``PersistentClient``, rebuilding on inode/mtime change.
|
||||||
if palace_path not in self._clients:
|
|
||||||
|
Handles the palace-rebuild case (repair/nuke/purge) by invalidating the
|
||||||
|
cache when ``chroma.sqlite3`` changes on disk. Mirrors the semantics of
|
||||||
|
``mcp_server._get_client`` (merged via #757):
|
||||||
|
|
||||||
|
* DB file missing while we hold a cached client → drop the cache so we
|
||||||
|
do not serve stale data after a rebuild that has not yet re-created
|
||||||
|
the DB.
|
||||||
|
* Transition 0 → nonzero stat (DB created after cache) counts as a
|
||||||
|
change, so the cached client is replaced with one that sees the DB.
|
||||||
|
* FAT/exFAT filesystems return inode 0; we never fire inode comparisons
|
||||||
|
when either side is 0 (safe fallback) but still honor mtime.
|
||||||
|
* Mtime change uses an epsilon (0.01 s) to tolerate FS timestamp
|
||||||
|
granularity without thrashing.
|
||||||
|
"""
|
||||||
|
if self._closed:
|
||||||
|
from .base import BackendClosedError # late import avoids cycles at module load
|
||||||
|
|
||||||
|
raise BackendClosedError("ChromaBackend has been closed")
|
||||||
|
|
||||||
|
cached = self._clients.get(palace_path)
|
||||||
|
cached_inode, cached_mtime = self._freshness.get(palace_path, (0, 0.0))
|
||||||
|
current_inode, current_mtime = self._db_stat(palace_path)
|
||||||
|
|
||||||
|
db_path = os.path.join(palace_path, "chroma.sqlite3")
|
||||||
|
# DB was present when cache was built but is now missing → invalidate.
|
||||||
|
if cached is not None and not os.path.isfile(db_path):
|
||||||
|
self._clients.pop(palace_path, None)
|
||||||
|
self._freshness.pop(palace_path, None)
|
||||||
|
cached = None
|
||||||
|
cached_inode, cached_mtime = 0, 0.0
|
||||||
|
|
||||||
|
inode_changed = current_inode != 0 and cached_inode != 0 and current_inode != cached_inode
|
||||||
|
# Transition from no-stat (0.0) to a real stat counts as a change so we
|
||||||
|
# pick up a DB that was created after the cache was built.
|
||||||
|
mtime_appeared = cached_mtime == 0.0 and current_mtime != 0.0
|
||||||
|
mtime_changed = (
|
||||||
|
current_mtime != 0.0
|
||||||
|
and cached_mtime != 0.0
|
||||||
|
and abs(current_mtime - cached_mtime) > 0.01
|
||||||
|
)
|
||||||
|
|
||||||
|
if cached is None or inode_changed or mtime_changed or mtime_appeared:
|
||||||
_fix_blob_seq_ids(palace_path)
|
_fix_blob_seq_ids(palace_path)
|
||||||
self._clients[palace_path] = chromadb.PersistentClient(path=palace_path)
|
cached = chromadb.PersistentClient(path=palace_path)
|
||||||
return self._clients[palace_path]
|
self._clients[palace_path] = cached
|
||||||
|
# Re-stat after the client constructor runs: chromadb creates
|
||||||
|
# chroma.sqlite3 lazily, so the stat captured before the call
|
||||||
|
# may still be (0, 0.0) on first open.
|
||||||
|
self._freshness[palace_path] = self._db_stat(palace_path)
|
||||||
|
return cached
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
# Public static helpers (for callers that manage their own caching)
|
# Public static helpers (legacy; prefer :meth:`get_collection`)
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def make_client(palace_path: str):
|
def make_client(palace_path: str):
|
||||||
"""Create and return a fresh PersistentClient (fix BLOB seq_ids first).
|
"""Create a fresh ``PersistentClient`` (fixes BLOB seq_ids first).
|
||||||
|
|
||||||
Intended for long-lived callers (e.g. mcp_server) that keep their own
|
Deprecated-ish: exposed for legacy long-lived callers that manage their
|
||||||
inode/mtime-based client cache.
|
own client cache. New code should obtain a collection through
|
||||||
|
:meth:`get_collection` which manages caching internally.
|
||||||
"""
|
"""
|
||||||
_fix_blob_seq_ids(palace_path)
|
_fix_blob_seq_ids(palace_path)
|
||||||
return chromadb.PersistentClient(path=palace_path)
|
return chromadb.PersistentClient(path=palace_path)
|
||||||
@@ -109,12 +412,31 @@ class ChromaBackend:
|
|||||||
return chromadb.__version__
|
return chromadb.__version__
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
# Collection lifecycle
|
# BaseBackend surface
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
def get_collection(self, palace_path: str, collection_name: str, create: bool = False):
|
def get_collection(
|
||||||
|
self,
|
||||||
|
*args,
|
||||||
|
**kwargs,
|
||||||
|
) -> ChromaCollection:
|
||||||
|
"""Obtain a collection for a palace.
|
||||||
|
|
||||||
|
Supports two calling conventions during the RFC 001 transition:
|
||||||
|
|
||||||
|
* New (preferred): ``get_collection(palace=PalaceRef, collection_name=...,
|
||||||
|
create=False, options=None)``.
|
||||||
|
* Legacy: ``get_collection(palace_path, collection_name, create=False)``
|
||||||
|
— still used by callers not yet migrated.
|
||||||
|
"""
|
||||||
|
palace_ref, collection_name, create, options = _normalize_get_collection_args(args, kwargs)
|
||||||
|
|
||||||
|
palace_path = palace_ref.local_path
|
||||||
|
if palace_path is None:
|
||||||
|
raise PalaceNotFoundError("ChromaBackend requires PalaceRef.local_path")
|
||||||
|
|
||||||
if not create and not os.path.isdir(palace_path):
|
if not create and not os.path.isdir(palace_path):
|
||||||
raise FileNotFoundError(palace_path)
|
raise PalaceNotFoundError(palace_path)
|
||||||
|
|
||||||
if create:
|
if create:
|
||||||
os.makedirs(palace_path, exist_ok=True)
|
os.makedirs(palace_path, exist_ok=True)
|
||||||
@@ -124,29 +446,113 @@ class ChromaBackend:
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
client = self._client(palace_path)
|
client = self._client(palace_path)
|
||||||
|
hnsw_space = "cosine"
|
||||||
|
if options and isinstance(options, dict):
|
||||||
|
hnsw_space = options.get("hnsw_space", hnsw_space)
|
||||||
|
|
||||||
if create:
|
if create:
|
||||||
collection = client.get_or_create_collection(
|
collection = client.get_or_create_collection(
|
||||||
collection_name, metadata={"hnsw:space": "cosine"}
|
collection_name, metadata={"hnsw:space": hnsw_space}
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
collection = client.get_collection(collection_name)
|
collection = client.get_collection(collection_name)
|
||||||
return ChromaCollection(collection)
|
return ChromaCollection(collection)
|
||||||
|
|
||||||
def get_or_create_collection(
|
def close_palace(self, palace) -> None:
|
||||||
self, palace_path: str, collection_name: str
|
"""Drop cached handles for ``palace``. Accepts ``PalaceRef`` or legacy path str."""
|
||||||
) -> "ChromaCollection":
|
path = palace.local_path if isinstance(palace, PalaceRef) else palace
|
||||||
"""Shorthand for get_collection(..., create=True)."""
|
if path is None:
|
||||||
|
return
|
||||||
|
self._clients.pop(path, None)
|
||||||
|
self._freshness.pop(path, None)
|
||||||
|
|
||||||
|
def close(self) -> None:
|
||||||
|
self._clients.clear()
|
||||||
|
self._freshness.clear()
|
||||||
|
self._closed = True
|
||||||
|
|
||||||
|
def health(self, palace: Optional[PalaceRef] = None) -> HealthStatus:
|
||||||
|
if self._closed:
|
||||||
|
return HealthStatus.unhealthy("backend closed")
|
||||||
|
return HealthStatus.healthy()
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def detect(cls, path: str) -> bool:
|
||||||
|
return os.path.isfile(os.path.join(path, "chroma.sqlite3"))
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Legacy (pre-RFC 001) surface — retained while callers migrate.
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
def get_or_create_collection(self, palace_path: str, collection_name: str) -> ChromaCollection:
|
||||||
|
"""Legacy shim for ``get_collection(..., create=True)`` by path string."""
|
||||||
return self.get_collection(palace_path, collection_name, create=True)
|
return self.get_collection(palace_path, collection_name, create=True)
|
||||||
|
|
||||||
def delete_collection(self, palace_path: str, collection_name: str) -> None:
|
def delete_collection(self, palace_path: str, collection_name: str) -> None:
|
||||||
"""Delete *collection_name* from the palace at *palace_path*."""
|
"""Delete ``collection_name`` from the palace at ``palace_path``."""
|
||||||
self._client(palace_path).delete_collection(collection_name)
|
self._client(palace_path).delete_collection(collection_name)
|
||||||
|
|
||||||
def create_collection(
|
def create_collection(
|
||||||
self, palace_path: str, collection_name: str, hnsw_space: str = "cosine"
|
self, palace_path: str, collection_name: str, hnsw_space: str = "cosine"
|
||||||
) -> "ChromaCollection":
|
) -> ChromaCollection:
|
||||||
"""Create (not get-or-create) *collection_name* with cosine HNSW space."""
|
"""Create (not get-or-create) ``collection_name`` with the given HNSW space."""
|
||||||
collection = self._client(palace_path).create_collection(
|
collection = self._client(palace_path).create_collection(
|
||||||
collection_name, metadata={"hnsw:space": hnsw_space}
|
collection_name, metadata={"hnsw:space": hnsw_space}
|
||||||
)
|
)
|
||||||
return ChromaCollection(collection)
|
return ChromaCollection(collection)
|
||||||
|
|
||||||
|
|
||||||
|
def _normalize_get_collection_args(args, kwargs):
|
||||||
|
"""Unify legacy positional ``(palace_path, collection_name, create)`` calls
|
||||||
|
with the new kwargs-only ``(palace=PalaceRef, collection_name=..., create=...)``.
|
||||||
|
|
||||||
|
Returns ``(PalaceRef, collection_name, create, options)``.
|
||||||
|
"""
|
||||||
|
# New-style: palace= kwarg with a PalaceRef (spec path).
|
||||||
|
if "palace" in kwargs:
|
||||||
|
palace_ref = kwargs.pop("palace")
|
||||||
|
if not isinstance(palace_ref, PalaceRef):
|
||||||
|
raise TypeError("palace= must be a PalaceRef instance")
|
||||||
|
collection_name = kwargs.pop("collection_name")
|
||||||
|
create = kwargs.pop("create", False)
|
||||||
|
options = kwargs.pop("options", None)
|
||||||
|
if kwargs:
|
||||||
|
raise TypeError(f"unexpected kwargs: {sorted(kwargs)}")
|
||||||
|
if args:
|
||||||
|
raise TypeError("positional args not allowed with palace= kwarg")
|
||||||
|
return palace_ref, collection_name, create, options
|
||||||
|
|
||||||
|
# Legacy: first positional is a path string.
|
||||||
|
if args:
|
||||||
|
palace_path = args[0]
|
||||||
|
rest = list(args[1:])
|
||||||
|
collection_name = kwargs.pop("collection_name", None) or (rest.pop(0) if rest else None)
|
||||||
|
if collection_name is None:
|
||||||
|
raise TypeError("collection_name is required")
|
||||||
|
create = kwargs.pop("create", False)
|
||||||
|
if rest:
|
||||||
|
create = rest.pop(0)
|
||||||
|
if kwargs:
|
||||||
|
raise TypeError(f"unexpected kwargs: {sorted(kwargs)}")
|
||||||
|
return (
|
||||||
|
PalaceRef(id=palace_path, local_path=palace_path),
|
||||||
|
collection_name,
|
||||||
|
bool(create),
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Legacy kwargs-only (palace_path=..., collection_name=..., create=...)
|
||||||
|
if "palace_path" in kwargs:
|
||||||
|
palace_path = kwargs.pop("palace_path")
|
||||||
|
collection_name = kwargs.pop("collection_name")
|
||||||
|
create = kwargs.pop("create", False)
|
||||||
|
if kwargs:
|
||||||
|
raise TypeError(f"unexpected kwargs: {sorted(kwargs)}")
|
||||||
|
return (
|
||||||
|
PalaceRef(id=palace_path, local_path=palace_path),
|
||||||
|
collection_name,
|
||||||
|
bool(create),
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
|
||||||
|
raise TypeError("get_collection requires palace= or a positional palace_path")
|
||||||
|
|||||||
@@ -0,0 +1,189 @@
|
|||||||
|
"""Backend registry + entry-point discovery (RFC 001 §3).
|
||||||
|
|
||||||
|
Third-party backends ship as installable packages that declare a
|
||||||
|
``mempalace.backends`` entry point::
|
||||||
|
|
||||||
|
# pyproject.toml of mempalace-postgres
|
||||||
|
[project.entry-points."mempalace.backends"]
|
||||||
|
postgres = "mempalace_postgres:PostgresBackend"
|
||||||
|
|
||||||
|
MemPalace discovers them at process start. In-tree tests and local development
|
||||||
|
can register manually via :func:`register`. Explicit registration wins on
|
||||||
|
name conflict (matches RFC 001 §3.2).
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from importlib import metadata
|
||||||
|
from threading import Lock
|
||||||
|
from typing import Optional, Type
|
||||||
|
|
||||||
|
from .base import BaseBackend
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
_ENTRY_POINT_GROUP = "mempalace.backends"
|
||||||
|
|
||||||
|
_registry: dict[str, Type[BaseBackend]] = {}
|
||||||
|
_instances: dict[str, BaseBackend] = {}
|
||||||
|
_explicit: set[str] = set()
|
||||||
|
_discovered = False
|
||||||
|
_lock = Lock()
|
||||||
|
|
||||||
|
|
||||||
|
def register(name: str, backend_cls: Type[BaseBackend]) -> None:
|
||||||
|
"""Register ``backend_cls`` under ``name``.
|
||||||
|
|
||||||
|
Explicit registration wins over entry-point discovery on conflict
|
||||||
|
(RFC 001 §3.2).
|
||||||
|
"""
|
||||||
|
with _lock:
|
||||||
|
_registry[name] = backend_cls
|
||||||
|
_explicit.add(name)
|
||||||
|
# Invalidate any cached instance so the new class is used on next get.
|
||||||
|
_instances.pop(name, None)
|
||||||
|
|
||||||
|
|
||||||
|
def unregister(name: str) -> None:
|
||||||
|
"""Remove a backend registration (primarily for tests)."""
|
||||||
|
with _lock:
|
||||||
|
_registry.pop(name, None)
|
||||||
|
_explicit.discard(name)
|
||||||
|
_instances.pop(name, None)
|
||||||
|
|
||||||
|
|
||||||
|
def _discover_entry_points() -> None:
|
||||||
|
"""Load entry-point-declared backends once per process."""
|
||||||
|
global _discovered
|
||||||
|
if _discovered:
|
||||||
|
return
|
||||||
|
with _lock:
|
||||||
|
if _discovered:
|
||||||
|
return
|
||||||
|
try:
|
||||||
|
eps = metadata.entry_points()
|
||||||
|
# Py ≥ 3.10 returns an EntryPoints object; older versions returned a dict.
|
||||||
|
group = (
|
||||||
|
eps.select(group=_ENTRY_POINT_GROUP)
|
||||||
|
if hasattr(eps, "select")
|
||||||
|
else eps.get(_ENTRY_POINT_GROUP, [])
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
logger.exception("entry-point discovery for %s failed", _ENTRY_POINT_GROUP)
|
||||||
|
group = []
|
||||||
|
for ep in group:
|
||||||
|
if ep.name in _explicit:
|
||||||
|
continue # explicit registration wins
|
||||||
|
try:
|
||||||
|
cls = ep.load()
|
||||||
|
except Exception:
|
||||||
|
logger.exception("failed to load backend entry point %r", ep.name)
|
||||||
|
continue
|
||||||
|
if not isinstance(cls, type) or not issubclass(cls, BaseBackend):
|
||||||
|
logger.warning(
|
||||||
|
"entry point %r did not resolve to a BaseBackend subclass (got %r)",
|
||||||
|
ep.name,
|
||||||
|
cls,
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
_registry.setdefault(ep.name, cls)
|
||||||
|
_discovered = True
|
||||||
|
|
||||||
|
|
||||||
|
def available_backends() -> list[str]:
|
||||||
|
"""Return sorted list of all registered backend names."""
|
||||||
|
_discover_entry_points()
|
||||||
|
return sorted(_registry.keys())
|
||||||
|
|
||||||
|
|
||||||
|
def get_backend_class(name: str) -> Type[BaseBackend]:
|
||||||
|
"""Return the registered backend class for ``name``."""
|
||||||
|
_discover_entry_points()
|
||||||
|
try:
|
||||||
|
return _registry[name]
|
||||||
|
except KeyError as e:
|
||||||
|
raise KeyError(f"unknown backend {name!r}; available: {available_backends()}") from e
|
||||||
|
|
||||||
|
|
||||||
|
def get_backend(name: str) -> BaseBackend:
|
||||||
|
"""Return a long-lived instance of the named backend.
|
||||||
|
|
||||||
|
Instances are cached per-name; repeated calls return the same object.
|
||||||
|
Call :func:`reset_backends` in tests that need isolation.
|
||||||
|
"""
|
||||||
|
_discover_entry_points()
|
||||||
|
with _lock:
|
||||||
|
inst = _instances.get(name)
|
||||||
|
if inst is not None:
|
||||||
|
return inst
|
||||||
|
cls = _registry.get(name)
|
||||||
|
if cls is None:
|
||||||
|
raise KeyError(f"unknown backend {name!r}; available: {sorted(_registry.keys())}")
|
||||||
|
inst = cls()
|
||||||
|
_instances[name] = inst
|
||||||
|
return inst
|
||||||
|
|
||||||
|
|
||||||
|
def reset_backends() -> None:
|
||||||
|
"""Close and drop all cached backend instances (primarily for tests)."""
|
||||||
|
with _lock:
|
||||||
|
for inst in _instances.values():
|
||||||
|
try:
|
||||||
|
inst.close()
|
||||||
|
except Exception:
|
||||||
|
logger.exception("error closing backend during reset")
|
||||||
|
_instances.clear()
|
||||||
|
|
||||||
|
|
||||||
|
def resolve_backend_for_palace(
|
||||||
|
*,
|
||||||
|
explicit: Optional[str] = None,
|
||||||
|
config_value: Optional[str] = None,
|
||||||
|
env_value: Optional[str] = None,
|
||||||
|
palace_path: Optional[str] = None,
|
||||||
|
default: str = "chroma",
|
||||||
|
) -> str:
|
||||||
|
"""Resolve the backend name for a palace per RFC 001 §3.3 priority order.
|
||||||
|
|
||||||
|
1. Explicit kwarg / CLI flag
|
||||||
|
2. Per-palace config value
|
||||||
|
3. ``MEMPALACE_BACKEND`` env var
|
||||||
|
4. Auto-detect from on-disk artifacts (migration/upgrade path only)
|
||||||
|
5. Default (``chroma``)
|
||||||
|
|
||||||
|
Auto-detection is strictly a migration aid: it fires only when a local path
|
||||||
|
is presented, no earlier rule has chosen a backend, AND the path already
|
||||||
|
contains backend-identifiable artifacts. For new palaces, (5) wins.
|
||||||
|
"""
|
||||||
|
for candidate in (explicit, config_value, env_value):
|
||||||
|
if candidate:
|
||||||
|
return candidate
|
||||||
|
|
||||||
|
_discover_entry_points()
|
||||||
|
if palace_path:
|
||||||
|
for name, cls in _registry.items():
|
||||||
|
try:
|
||||||
|
if cls.detect(palace_path):
|
||||||
|
return name
|
||||||
|
except Exception:
|
||||||
|
logger.exception("detect() raised on backend %r", name)
|
||||||
|
continue
|
||||||
|
return default
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Built-in registration
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def _register_builtins() -> None:
|
||||||
|
"""Register chroma as the in-tree default."""
|
||||||
|
from .chroma import ChromaBackend
|
||||||
|
|
||||||
|
# Use setdefault semantics so a caller that pre-registered for tests wins.
|
||||||
|
if "chroma" not in _registry:
|
||||||
|
_registry["chroma"] = ChromaBackend
|
||||||
|
|
||||||
|
|
||||||
|
_register_builtins()
|
||||||
+12
-12
@@ -30,15 +30,16 @@ class SearchError(Exception):
|
|||||||
_TOKEN_RE = re.compile(r"\w{2,}", re.UNICODE)
|
_TOKEN_RE = re.compile(r"\w{2,}", re.UNICODE)
|
||||||
|
|
||||||
|
|
||||||
def _first_or_empty(results: dict, key: str) -> list:
|
def _first_or_empty(results, key: str) -> list:
|
||||||
"""Return the first inner list of a ChromaDB query result, or [].
|
"""Return the first inner list of a query result field, or [].
|
||||||
|
|
||||||
ChromaDB returns shapes like ``{"documents": [["a", "b"]], ...}`` for a
|
Accepts both the typed :class:`QueryResult` (attribute access) and the
|
||||||
successful query, but ``{"documents": [], ...}`` (empty outer list) when
|
pre-typed chroma dict shape; this polymorphism is retained so test mocks
|
||||||
the collection is empty or the filter excludes everything. Indexing
|
still work and callers mid-migration do not crash. Preserves the empty-
|
||||||
``[0]`` blindly raises IndexError in that case (issue #195).
|
collection semantics from issue #195: when no queries returned hits, the
|
||||||
|
outer list may be empty and indexing ``[0]`` would raise.
|
||||||
"""
|
"""
|
||||||
outer = results.get(key)
|
outer = getattr(results, key, None) if not isinstance(results, dict) else results.get(key)
|
||||||
if not outer:
|
if not outer:
|
||||||
return []
|
return []
|
||||||
return outer[0] or []
|
return outer[0] or []
|
||||||
@@ -209,7 +210,7 @@ def _expand_with_neighbors(drawers_col, matched_doc: str, matched_meta: dict, ra
|
|||||||
return {"text": matched_doc, "drawer_index": chunk_idx, "total_drawers": None}
|
return {"text": matched_doc, "drawer_index": chunk_idx, "total_drawers": None}
|
||||||
|
|
||||||
indexed_docs = []
|
indexed_docs = []
|
||||||
for doc, meta in zip(neighbors.get("documents") or [], neighbors.get("metadatas") or []):
|
for doc, meta in zip(neighbors.documents, neighbors.metadatas):
|
||||||
ci = meta.get("chunk_index")
|
ci = meta.get("chunk_index")
|
||||||
if isinstance(ci, int):
|
if isinstance(ci, int):
|
||||||
indexed_docs.append((ci, doc))
|
indexed_docs.append((ci, doc))
|
||||||
@@ -224,8 +225,7 @@ def _expand_with_neighbors(drawers_col, matched_doc: str, matched_meta: dict, ra
|
|||||||
total_drawers = None
|
total_drawers = None
|
||||||
try:
|
try:
|
||||||
all_meta = drawers_col.get(where={"source_file": src}, include=["metadatas"])
|
all_meta = drawers_col.get(where={"source_file": src}, include=["metadatas"])
|
||||||
ids = all_meta.get("ids") or []
|
total_drawers = len(all_meta.ids) if all_meta.ids else None
|
||||||
total_drawers = len(ids) if ids else None
|
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@@ -451,8 +451,8 @@ def search_memories(
|
|||||||
)
|
)
|
||||||
except Exception:
|
except Exception:
|
||||||
continue
|
continue
|
||||||
docs = source_drawers.get("documents") or []
|
docs = source_drawers.documents
|
||||||
metas_ = source_drawers.get("metadatas") or []
|
metas_ = source_drawers.metadatas
|
||||||
if len(docs) <= 1:
|
if len(docs) <= 1:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
|||||||
@@ -37,6 +37,9 @@ Repository = "https://github.com/MemPalace/mempalace"
|
|||||||
[project.scripts]
|
[project.scripts]
|
||||||
mempalace = "mempalace.cli:main"
|
mempalace = "mempalace.cli:main"
|
||||||
|
|
||||||
|
[project.entry-points."mempalace.backends"]
|
||||||
|
chroma = "mempalace.backends.chroma:ChromaBackend"
|
||||||
|
|
||||||
[project.optional-dependencies]
|
[project.optional-dependencies]
|
||||||
dev = ["pytest>=7.0", "pytest-cov>=4.0", "ruff>=0.4.0", "psutil>=5.9"]
|
dev = ["pytest>=7.0", "pytest-cov>=4.0", "ruff>=0.4.0", "psutil>=5.9"]
|
||||||
spellcheck = ["autocorrect>=2.0"]
|
spellcheck = ["autocorrect>=2.0"]
|
||||||
|
|||||||
+247
-15
@@ -3,12 +3,34 @@ import sqlite3
|
|||||||
import chromadb
|
import chromadb
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
from mempalace.backends import (
|
||||||
|
GetResult,
|
||||||
|
PalaceRef,
|
||||||
|
QueryResult,
|
||||||
|
UnsupportedFilterError,
|
||||||
|
available_backends,
|
||||||
|
get_backend,
|
||||||
|
)
|
||||||
from mempalace.backends.chroma import ChromaBackend, ChromaCollection, _fix_blob_seq_ids
|
from mempalace.backends.chroma import ChromaBackend, ChromaCollection, _fix_blob_seq_ids
|
||||||
|
|
||||||
|
|
||||||
class _FakeCollection:
|
class _FakeCollection:
|
||||||
def __init__(self):
|
"""Stand-in for a chromadb.Collection returning raw chroma-shaped dicts."""
|
||||||
|
|
||||||
|
def __init__(self, query_response=None, get_response=None, count_value=7):
|
||||||
self.calls = []
|
self.calls = []
|
||||||
|
self._query_response = query_response or {
|
||||||
|
"ids": [["a", "b"]],
|
||||||
|
"documents": [["da", "db"]],
|
||||||
|
"metadatas": [[{"wing": "w1"}, {"wing": "w2"}]],
|
||||||
|
"distances": [[0.1, 0.2]],
|
||||||
|
}
|
||||||
|
self._get_response = get_response or {
|
||||||
|
"ids": ["a"],
|
||||||
|
"documents": ["da"],
|
||||||
|
"metadatas": [{"wing": "w1"}],
|
||||||
|
}
|
||||||
|
self._count_value = count_value
|
||||||
|
|
||||||
def add(self, **kwargs):
|
def add(self, **kwargs):
|
||||||
self.calls.append(("add", kwargs))
|
self.calls.append(("add", kwargs))
|
||||||
@@ -16,41 +38,251 @@ class _FakeCollection:
|
|||||||
def upsert(self, **kwargs):
|
def upsert(self, **kwargs):
|
||||||
self.calls.append(("upsert", kwargs))
|
self.calls.append(("upsert", kwargs))
|
||||||
|
|
||||||
|
def update(self, **kwargs):
|
||||||
|
self.calls.append(("update", kwargs))
|
||||||
|
|
||||||
def query(self, **kwargs):
|
def query(self, **kwargs):
|
||||||
self.calls.append(("query", kwargs))
|
self.calls.append(("query", kwargs))
|
||||||
return {"kind": "query"}
|
return self._query_response
|
||||||
|
|
||||||
def get(self, **kwargs):
|
def get(self, **kwargs):
|
||||||
self.calls.append(("get", kwargs))
|
self.calls.append(("get", kwargs))
|
||||||
return {"kind": "get"}
|
return self._get_response
|
||||||
|
|
||||||
def delete(self, **kwargs):
|
def delete(self, **kwargs):
|
||||||
self.calls.append(("delete", kwargs))
|
self.calls.append(("delete", kwargs))
|
||||||
|
|
||||||
def count(self):
|
def count(self):
|
||||||
self.calls.append(("count", {}))
|
self.calls.append(("count", {}))
|
||||||
return 7
|
return self._count_value
|
||||||
|
|
||||||
|
|
||||||
def test_chroma_collection_delegates_methods():
|
def test_chroma_collection_returns_typed_query_result():
|
||||||
|
fake = _FakeCollection()
|
||||||
|
collection = ChromaCollection(fake)
|
||||||
|
|
||||||
|
result = collection.query(query_texts=["q"])
|
||||||
|
|
||||||
|
assert isinstance(result, QueryResult)
|
||||||
|
assert result.ids == [["a", "b"]]
|
||||||
|
assert result.documents == [["da", "db"]]
|
||||||
|
assert result.metadatas == [[{"wing": "w1"}, {"wing": "w2"}]]
|
||||||
|
assert result.distances == [[0.1, 0.2]]
|
||||||
|
assert result.embeddings is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_chroma_collection_returns_typed_get_result():
|
||||||
|
fake = _FakeCollection()
|
||||||
|
collection = ChromaCollection(fake)
|
||||||
|
|
||||||
|
result = collection.get(where={"wing": "w1"})
|
||||||
|
|
||||||
|
assert isinstance(result, GetResult)
|
||||||
|
assert result.ids == ["a"]
|
||||||
|
assert result.documents == ["da"]
|
||||||
|
assert result.metadatas == [{"wing": "w1"}]
|
||||||
|
|
||||||
|
|
||||||
|
def test_query_result_empty_preserves_outer_dimension():
|
||||||
|
empty = QueryResult.empty(num_queries=2)
|
||||||
|
assert empty.ids == [[], []]
|
||||||
|
assert empty.documents == [[], []]
|
||||||
|
assert empty.distances == [[], []]
|
||||||
|
assert empty.embeddings is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_typed_results_support_dict_compat_access():
|
||||||
|
"""Transitional compat shim per base.py — retained until callers migrate to attrs."""
|
||||||
|
result = GetResult(ids=["a"], documents=["da"], metadatas=[{"w": 1}])
|
||||||
|
assert result["ids"] == ["a"]
|
||||||
|
assert result.get("documents") == ["da"]
|
||||||
|
assert result.get("missing", "default") == "default"
|
||||||
|
assert "ids" in result
|
||||||
|
assert "missing" not in result
|
||||||
|
|
||||||
|
|
||||||
|
def test_chroma_collection_query_empty_result_preserves_outer_shape():
|
||||||
|
fake = _FakeCollection(
|
||||||
|
query_response={"ids": [], "documents": [], "metadatas": [], "distances": []}
|
||||||
|
)
|
||||||
|
collection = ChromaCollection(fake)
|
||||||
|
|
||||||
|
result = collection.query(query_texts=["q1", "q2"])
|
||||||
|
assert result.ids == [[], []]
|
||||||
|
assert result.documents == [[], []]
|
||||||
|
assert result.distances == [[], []]
|
||||||
|
|
||||||
|
|
||||||
|
def test_chroma_collection_rejects_unknown_where_operator():
|
||||||
|
fake = _FakeCollection()
|
||||||
|
collection = ChromaCollection(fake)
|
||||||
|
|
||||||
|
with pytest.raises(UnsupportedFilterError):
|
||||||
|
collection.query(query_texts=["q"], where={"$regex": "foo"})
|
||||||
|
|
||||||
|
|
||||||
|
def test_chroma_collection_delegates_writes():
|
||||||
fake = _FakeCollection()
|
fake = _FakeCollection()
|
||||||
collection = ChromaCollection(fake)
|
collection = ChromaCollection(fake)
|
||||||
|
|
||||||
collection.add(documents=["d"], ids=["1"], metadatas=[{"wing": "w"}])
|
collection.add(documents=["d"], ids=["1"], metadatas=[{"wing": "w"}])
|
||||||
collection.upsert(documents=["u"], ids=["2"], metadatas=[{"room": "r"}])
|
collection.upsert(documents=["u"], ids=["2"], metadatas=[{"room": "r"}])
|
||||||
assert collection.query(query_texts=["q"]) == {"kind": "query"}
|
|
||||||
assert collection.get(where={"wing": "w"}) == {"kind": "get"}
|
|
||||||
collection.delete(ids=["1"])
|
collection.delete(ids=["1"])
|
||||||
assert collection.count() == 7
|
assert collection.count() == 7
|
||||||
|
|
||||||
assert fake.calls == [
|
kinds = [call[0] for call in fake.calls]
|
||||||
("add", {"documents": ["d"], "ids": ["1"], "metadatas": [{"wing": "w"}]}),
|
assert kinds == ["add", "upsert", "delete", "count"]
|
||||||
("upsert", {"documents": ["u"], "ids": ["2"], "metadatas": [{"room": "r"}]}),
|
|
||||||
("query", {"query_texts": ["q"]}),
|
|
||||||
("get", {"where": {"wing": "w"}}),
|
def test_registry_exposes_chroma_by_default():
|
||||||
("delete", {"ids": ["1"]}),
|
names = available_backends()
|
||||||
("count", {}),
|
assert "chroma" in names
|
||||||
]
|
assert isinstance(get_backend("chroma"), ChromaBackend)
|
||||||
|
|
||||||
|
|
||||||
|
def test_registry_unknown_backend_raises():
|
||||||
|
with pytest.raises(KeyError):
|
||||||
|
get_backend("no-such-backend-exists")
|
||||||
|
|
||||||
|
|
||||||
|
def test_resolve_backend_priority_order(tmp_path):
|
||||||
|
from mempalace.backends import resolve_backend_for_palace
|
||||||
|
|
||||||
|
# explicit kwarg wins over everything
|
||||||
|
assert resolve_backend_for_palace(explicit="pg", config_value="lance") == "pg"
|
||||||
|
# config value wins over env / default
|
||||||
|
assert resolve_backend_for_palace(config_value="lance", env_value="qdrant") == "lance"
|
||||||
|
# env wins over default
|
||||||
|
assert resolve_backend_for_palace(env_value="qdrant", default="chroma") == "qdrant"
|
||||||
|
# falls back to default
|
||||||
|
assert resolve_backend_for_palace() == "chroma"
|
||||||
|
|
||||||
|
|
||||||
|
def test_chroma_detect_matches_palace_with_chroma_sqlite(tmp_path):
|
||||||
|
(tmp_path / "chroma.sqlite3").write_bytes(b"")
|
||||||
|
assert ChromaBackend.detect(str(tmp_path)) is True
|
||||||
|
assert ChromaBackend.detect(str(tmp_path.parent)) is False
|
||||||
|
|
||||||
|
|
||||||
|
def test_query_rejects_missing_input():
|
||||||
|
fake = _FakeCollection()
|
||||||
|
collection = ChromaCollection(fake)
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
collection.query()
|
||||||
|
|
||||||
|
|
||||||
|
def test_query_rejects_both_texts_and_embeddings():
|
||||||
|
fake = _FakeCollection()
|
||||||
|
collection = ChromaCollection(fake)
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
collection.query(query_texts=["q"], query_embeddings=[[0.1, 0.2]])
|
||||||
|
|
||||||
|
|
||||||
|
def test_query_rejects_empty_input_list():
|
||||||
|
fake = _FakeCollection()
|
||||||
|
collection = ChromaCollection(fake)
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
collection.query(query_texts=[])
|
||||||
|
|
||||||
|
|
||||||
|
def test_query_empty_preserves_embeddings_outer_shape_when_requested():
|
||||||
|
fake = _FakeCollection(
|
||||||
|
query_response={"ids": [], "documents": [], "metadatas": [], "distances": []}
|
||||||
|
)
|
||||||
|
collection = ChromaCollection(fake)
|
||||||
|
|
||||||
|
requested = collection.query(query_texts=["q1", "q2"], include=["documents", "embeddings"])
|
||||||
|
assert requested.embeddings == [[], []]
|
||||||
|
|
||||||
|
not_requested = collection.query(query_texts=["q1", "q2"], include=["documents"])
|
||||||
|
assert not_requested.embeddings is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_chroma_cache_invalidates_when_db_file_missing(tmp_path):
|
||||||
|
"""A palace rebuild that removes chroma.sqlite3 must drop the stale cache.
|
||||||
|
|
||||||
|
Primes backend._clients/_freshness directly with a sentinel rather than
|
||||||
|
opening a real ``PersistentClient``: on Windows the sqlite file handle
|
||||||
|
would still be live and ``Path.unlink`` would raise ``PermissionError``,
|
||||||
|
making the test unable to exercise the branch we care about. The decision
|
||||||
|
logic under test is pure (no chromadb calls before the branch), so a
|
||||||
|
sentinel is sufficient.
|
||||||
|
"""
|
||||||
|
backend = ChromaBackend()
|
||||||
|
palace_path = tmp_path / "palace"
|
||||||
|
palace_path.mkdir()
|
||||||
|
db_file = palace_path / "chroma.sqlite3"
|
||||||
|
db_file.write_bytes(b"") # any file is enough for _db_stat to see it
|
||||||
|
st = db_file.stat()
|
||||||
|
|
||||||
|
sentinel = object()
|
||||||
|
backend._clients[str(palace_path)] = sentinel
|
||||||
|
backend._freshness[str(palace_path)] = (st.st_ino, st.st_mtime)
|
||||||
|
|
||||||
|
# Simulate a rebuild mid-flight: chroma.sqlite3 goes away. Safe to unlink
|
||||||
|
# because nothing in this test is holding an OS handle on the file.
|
||||||
|
db_file.unlink()
|
||||||
|
|
||||||
|
prior_freshness = (st.st_ino, st.st_mtime)
|
||||||
|
new_client = backend._client(str(palace_path))
|
||||||
|
# Cache was replaced (not the sentinel) and freshness reflects the post-
|
||||||
|
# rebuild stat (chromadb re-creates chroma.sqlite3 during PersistentClient
|
||||||
|
# construction; _client re-stats after the constructor so freshness is
|
||||||
|
# not frozen at the pre-rebuild value). The stale cached sentinel would
|
||||||
|
# have served wrong data if returned.
|
||||||
|
assert new_client is not sentinel
|
||||||
|
assert backend._freshness[str(palace_path)] != prior_freshness
|
||||||
|
|
||||||
|
|
||||||
|
def test_chroma_cache_picks_up_db_created_after_first_open(tmp_path):
|
||||||
|
"""The 0 → nonzero stat transition invalidates a cache built before the DB existed."""
|
||||||
|
backend = ChromaBackend()
|
||||||
|
palace_path = tmp_path / "palace"
|
||||||
|
palace_path.mkdir()
|
||||||
|
|
||||||
|
# Seed an entry in the caches as if a prior _client() call had opened the
|
||||||
|
# palace when chroma.sqlite3 did not exist yet. Freshness (0, 0.0) is the
|
||||||
|
# signal that the DB was absent at cache time.
|
||||||
|
sentinel = object()
|
||||||
|
backend._clients[str(palace_path)] = sentinel
|
||||||
|
backend._freshness[str(palace_path)] = (0, 0.0)
|
||||||
|
|
||||||
|
# The DB file now appears (real chromadb would have created it by now).
|
||||||
|
# Use a real chromadb call so _fix_blob_seq_ids and PersistentClient succeed.
|
||||||
|
import chromadb as _chromadb
|
||||||
|
|
||||||
|
_chromadb.PersistentClient(path=str(palace_path)).get_or_create_collection("seed")
|
||||||
|
assert (palace_path / "chroma.sqlite3").is_file()
|
||||||
|
|
||||||
|
# Next _client() call must detect the 0 → nonzero transition and rebuild.
|
||||||
|
refreshed = backend._client(str(palace_path))
|
||||||
|
assert refreshed is not sentinel
|
||||||
|
assert backend._freshness[str(palace_path)] != (0, 0.0)
|
||||||
|
|
||||||
|
|
||||||
|
def test_base_collection_update_default_rejects_mismatched_lengths():
|
||||||
|
"""The ABC default update() raises ValueError rather than silently misaligning."""
|
||||||
|
from mempalace.backends.base import BaseCollection
|
||||||
|
|
||||||
|
collection = ChromaCollection(_FakeCollection())
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="documents length"):
|
||||||
|
BaseCollection.update(collection, ids=["1", "2"], documents=["only-one"])
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="metadatas length"):
|
||||||
|
BaseCollection.update(collection, ids=["1", "2"], metadatas=[{"k": 9}])
|
||||||
|
|
||||||
|
|
||||||
|
def test_chroma_backend_accepts_palace_ref_kwarg(tmp_path):
|
||||||
|
palace_path = tmp_path / "palace"
|
||||||
|
backend = ChromaBackend()
|
||||||
|
collection = backend.get_collection(
|
||||||
|
palace=PalaceRef(id=str(palace_path), local_path=str(palace_path)),
|
||||||
|
collection_name="mempalace_drawers",
|
||||||
|
create=True,
|
||||||
|
)
|
||||||
|
assert palace_path.is_dir()
|
||||||
|
assert isinstance(collection, ChromaCollection)
|
||||||
|
|
||||||
|
|
||||||
def test_chroma_backend_create_false_raises_without_creating_directory(tmp_path):
|
def test_chroma_backend_create_false_raises_without_creating_directory(tmp_path):
|
||||||
|
|||||||
Reference in New Issue
Block a user