Merge pull request #852 from MemPalace/release/v4-prep

refactor: route all chromadb access through ChromaBackend (v4 prep)
This commit is contained in:
Igor Lins e Silva
2026-04-14 00:51:10 -03:00
committed by GitHub
11 changed files with 215 additions and 189 deletions
+5
View File
@@ -27,6 +27,11 @@ class BaseCollection(ABC):
) -> None: ) -> None:
raise NotImplementedError raise NotImplementedError
@abstractmethod
def update(self, **kwargs: Any) -> None:
"""Update existing records. Must raise if any ID is missing."""
raise NotImplementedError
@abstractmethod @abstractmethod
def query(self, **kwargs: Any) -> Dict[str, Any]: def query(self, **kwargs: Any) -> Dict[str, Any]:
raise NotImplementedError raise NotImplementedError
+61 -2
View File
@@ -55,6 +55,9 @@ class ChromaCollection(BaseCollection):
def upsert(self, *, documents, ids, metadatas=None): def upsert(self, *, documents, ids, metadatas=None):
self._collection.upsert(documents=documents, ids=ids, metadatas=metadatas) self._collection.upsert(documents=documents, ids=ids, metadatas=metadatas)
def update(self, **kwargs):
self._collection.update(**kwargs)
def query(self, **kwargs): def query(self, **kwargs):
return self._collection.query(**kwargs) return self._collection.query(**kwargs)
@@ -71,6 +74,44 @@ class ChromaCollection(BaseCollection):
class ChromaBackend: class ChromaBackend:
"""Factory for MemPalace's default ChromaDB backend.""" """Factory for MemPalace's default ChromaDB backend."""
def __init__(self):
# Per-instance client cache: palace_path -> chromadb.PersistentClient
self._clients: dict = {}
# ------------------------------------------------------------------
# Internal helpers
# ------------------------------------------------------------------
def _client(self, palace_path: str):
"""Return a cached PersistentClient for *palace_path*, creating one if needed."""
if palace_path not in self._clients:
_fix_blob_seq_ids(palace_path)
self._clients[palace_path] = chromadb.PersistentClient(path=palace_path)
return self._clients[palace_path]
# ------------------------------------------------------------------
# Public static helpers (for callers that manage their own caching)
# ------------------------------------------------------------------
@staticmethod
def make_client(palace_path: str):
"""Create and return a fresh PersistentClient (fix BLOB seq_ids first).
Intended for long-lived callers (e.g. mcp_server) that keep their own
inode/mtime-based client cache.
"""
_fix_blob_seq_ids(palace_path)
return chromadb.PersistentClient(path=palace_path)
@staticmethod
def backend_version() -> str:
"""Return the installed chromadb package version string."""
return chromadb.__version__
# ------------------------------------------------------------------
# Collection lifecycle
# ------------------------------------------------------------------
def get_collection(self, palace_path: str, collection_name: str, create: bool = False): def get_collection(self, palace_path: str, collection_name: str, create: bool = False):
if not create and not os.path.isdir(palace_path): if not create and not os.path.isdir(palace_path):
raise FileNotFoundError(palace_path) raise FileNotFoundError(palace_path)
@@ -82,8 +123,7 @@ class ChromaBackend:
except (OSError, NotImplementedError): except (OSError, NotImplementedError):
pass pass
_fix_blob_seq_ids(palace_path) client = self._client(palace_path)
client = chromadb.PersistentClient(path=palace_path)
if create: if create:
collection = client.get_or_create_collection( collection = client.get_or_create_collection(
collection_name, metadata={"hnsw:space": "cosine"} collection_name, metadata={"hnsw:space": "cosine"}
@@ -91,3 +131,22 @@ class ChromaBackend:
else: else:
collection = client.get_collection(collection_name) collection = client.get_collection(collection_name)
return ChromaCollection(collection) return ChromaCollection(collection)
def get_or_create_collection(
self, palace_path: str, collection_name: str
) -> "ChromaCollection":
"""Shorthand for get_collection(..., create=True)."""
return self.get_collection(palace_path, collection_name, create=True)
def delete_collection(self, palace_path: str, collection_name: str) -> None:
"""Delete *collection_name* from the palace at *palace_path*."""
self._client(palace_path).delete_collection(collection_name)
def create_collection(
self, palace_path: str, collection_name: str, hnsw_space: str = "cosine"
) -> "ChromaCollection":
"""Create (not get-or-create) *collection_name* with cosine HNSW space."""
collection = self._client(palace_path).create_collection(
collection_name, metadata={"hnsw:space": hnsw_space}
)
return ChromaCollection(collection)
+10 -11
View File
@@ -172,8 +172,8 @@ def cmd_status(args):
def cmd_repair(args): def cmd_repair(args):
"""Rebuild palace vector index from SQLite metadata.""" """Rebuild palace vector index from SQLite metadata."""
import chromadb
import shutil import shutil
from .backends.chroma import ChromaBackend
from .migrate import confirm_destructive_action, contains_palace_database from .migrate import confirm_destructive_action, contains_palace_database
palace_path = os.path.abspath( palace_path = os.path.abspath(
@@ -193,10 +193,11 @@ def cmd_repair(args):
print(f"{'=' * 55}\n") print(f"{'=' * 55}\n")
print(f" Palace: {palace_path}") print(f" Palace: {palace_path}")
backend = ChromaBackend()
# Try to read existing drawers # Try to read existing drawers
try: try:
client = chromadb.PersistentClient(path=palace_path) col = backend.get_collection(palace_path, "mempalace_drawers")
col = client.get_collection("mempalace_drawers")
total = col.count() total = col.count()
print(f" Drawers found: {total}") print(f" Drawers found: {total}")
except Exception as e: except Exception as e:
@@ -243,8 +244,8 @@ def cmd_repair(args):
shutil.copytree(palace_path, backup_path) shutil.copytree(palace_path, backup_path)
print(" Rebuilding collection...") print(" Rebuilding collection...")
client.delete_collection("mempalace_drawers") backend.delete_collection(palace_path, "mempalace_drawers")
new_col = client.create_collection("mempalace_drawers", metadata={"hnsw:space": "cosine"}) new_col = backend.create_collection(palace_path, "mempalace_drawers")
filed = 0 filed = 0
for i in range(0, len(all_ids), batch_size): for i in range(0, len(all_ids), batch_size):
@@ -297,7 +298,7 @@ def cmd_mcp(args):
def cmd_compress(args): def cmd_compress(args):
"""Compress drawers in a wing using AAAK Dialect.""" """Compress drawers in a wing using AAAK Dialect."""
import chromadb from .backends.chroma import ChromaBackend
from .dialect import Dialect from .dialect import Dialect
palace_path = os.path.expanduser(args.palace) if args.palace else MempalaceConfig().palace_path palace_path = os.path.expanduser(args.palace) if args.palace else MempalaceConfig().palace_path
@@ -317,9 +318,9 @@ def cmd_compress(args):
dialect = Dialect() dialect = Dialect()
# Connect to palace # Connect to palace
backend = ChromaBackend()
try: try:
client = chromadb.PersistentClient(path=palace_path) col = backend.get_collection(palace_path, "mempalace_drawers")
col = client.get_collection("mempalace_drawers")
except Exception: except Exception:
print(f"\n No palace found at {palace_path}") print(f"\n No palace found at {palace_path}")
print(" Run: mempalace init <dir> then mempalace mine <dir>") print(" Run: mempalace init <dir> then mempalace mine <dir>")
@@ -394,9 +395,7 @@ def cmd_compress(args):
# Store compressed versions (unless dry-run) # Store compressed versions (unless dry-run)
if not args.dry_run: if not args.dry_run:
try: try:
comp_col = client.get_or_create_collection( comp_col = backend.get_or_create_collection(palace_path, "mempalace_compressed")
"mempalace_compressed", metadata={"hnsw:space": "cosine"}
)
for doc_id, compressed, meta, stats in compressed_entries: for doc_id, compressed, meta, stats in compressed_entries:
comp_meta = dict(meta) comp_meta = dict(meta)
comp_meta["compression_ratio"] = round(stats["size_ratio"], 1) comp_meta["compression_ratio"] = round(stats["size_ratio"], 1)
+3 -5
View File
@@ -27,7 +27,7 @@ import os
import time import time
from collections import defaultdict from collections import defaultdict
import chromadb from .backends.chroma import ChromaBackend
COLLECTION_NAME = "mempalace_drawers" COLLECTION_NAME = "mempalace_drawers"
@@ -130,8 +130,7 @@ def dedup_source_group(col, drawer_ids, threshold=DEFAULT_THRESHOLD, dry_run=Tru
def show_stats(palace_path=None): def show_stats(palace_path=None):
"""Show duplication statistics without making changes.""" """Show duplication statistics without making changes."""
palace_path = palace_path or _get_palace_path() palace_path = palace_path or _get_palace_path()
client = chromadb.PersistentClient(path=palace_path) col = ChromaBackend().get_collection(palace_path, COLLECTION_NAME)
col = client.get_collection(COLLECTION_NAME)
groups = get_source_groups(col) groups = get_source_groups(col)
@@ -163,8 +162,7 @@ def dedup_palace(
print(" MemPalace Deduplicator") print(" MemPalace Deduplicator")
print(f"{'=' * 55}") print(f"{'=' * 55}")
client = chromadb.PersistentClient(path=palace_path) col = ChromaBackend().get_collection(palace_path, COLLECTION_NAME)
col = client.get_collection(COLLECTION_NAME)
print(f" Palace: {palace_path}") print(f" Palace: {palace_path}")
print(f" Drawers: {col.count():,}") print(f" Drawers: {col.count():,}")
+7 -5
View File
@@ -32,7 +32,7 @@ from pathlib import Path
from .config import MempalaceConfig, sanitize_name, sanitize_content from .config import MempalaceConfig, sanitize_name, sanitize_content
from .version import __version__ from .version import __version__
import chromadb from .backends.chroma import ChromaBackend, ChromaCollection
from .query_sanitizer import sanitize_query from .query_sanitizer import sanitize_query
from .searcher import search_memories from .searcher import search_memories
from .palace_graph import ( from .palace_graph import (
@@ -177,7 +177,7 @@ def _get_client():
mtime_changed = current_mtime != 0.0 and abs(current_mtime - _palace_db_mtime) > 0.01 mtime_changed = current_mtime != 0.0 and abs(current_mtime - _palace_db_mtime) > 0.01
if _client_cache is None or inode_changed or mtime_changed: if _client_cache is None or inode_changed or mtime_changed:
_client_cache = chromadb.PersistentClient(path=_config.palace_path) _client_cache = ChromaBackend.make_client(_config.palace_path)
_collection_cache = None _collection_cache = None
_metadata_cache = None _metadata_cache = None
_metadata_cache_time = 0 _metadata_cache_time = 0
@@ -192,13 +192,15 @@ def _get_collection(create=False):
try: try:
client = _get_client() client = _get_client()
if create: if create:
_collection_cache = client.get_or_create_collection( _collection_cache = ChromaCollection(
_config.collection_name, metadata={"hnsw:space": "cosine"} client.get_or_create_collection(
_config.collection_name, metadata={"hnsw:space": "cosine"}
)
) )
_metadata_cache = None _metadata_cache = None
_metadata_cache_time = 0 _metadata_cache_time = 0
elif _collection_cache is None: elif _collection_cache is None:
_collection_cache = client.get_collection(_config.collection_name) _collection_cache = ChromaCollection(client.get_collection(_config.collection_name))
_metadata_cache = None _metadata_cache = None
_metadata_cache_time = 0 _metadata_cache_time = 0
return _collection_cache return _collection_cache
+9 -9
View File
@@ -134,7 +134,7 @@ def confirm_destructive_action(
def migrate(palace_path: str, dry_run: bool = False, confirm: bool = False): def migrate(palace_path: str, dry_run: bool = False, confirm: bool = False):
"""Migrate a palace to the currently installed ChromaDB version.""" """Migrate a palace to the currently installed ChromaDB version."""
import chromadb from .backends.chroma import ChromaBackend
palace_path = os.path.abspath(os.path.expanduser(palace_path)) palace_path = os.path.abspath(os.path.expanduser(palace_path))
db_path = os.path.join(palace_path, "chroma.sqlite3") db_path = os.path.join(palace_path, "chroma.sqlite3")
@@ -152,19 +152,19 @@ def migrate(palace_path: str, dry_run: bool = False, confirm: bool = False):
# Detect version # Detect version
source_version = detect_chromadb_version(db_path) source_version = detect_chromadb_version(db_path)
target_version = ChromaBackend.backend_version()
print(f" Source: ChromaDB {source_version}") print(f" Source: ChromaDB {source_version}")
print(f" Target: ChromaDB {chromadb.__version__}") print(f" Target: ChromaDB {target_version}")
# Try reading with current chromadb first # Try reading with current chromadb first
try: try:
client = chromadb.PersistentClient(path=palace_path) col = ChromaBackend().get_collection(palace_path, "mempalace_drawers")
col = client.get_collection("mempalace_drawers")
count = col.count() count = col.count()
print(f"\n Palace is already readable by chromadb {chromadb.__version__}.") print(f"\n Palace is already readable by chromadb {target_version}.")
print(f" {count} drawers found. No migration needed.") print(f" {count} drawers found. No migration needed.")
return True return True
except Exception: except Exception:
print(f"\n Palace is NOT readable by chromadb {chromadb.__version__}.") print(f"\n Palace is NOT readable by chromadb {target_version}.")
print(" Extracting from SQLite directly...") print(" Extracting from SQLite directly...")
# Extract all drawers via raw SQL # Extract all drawers via raw SQL
@@ -208,8 +208,8 @@ def migrate(palace_path: str, dry_run: bool = False, confirm: bool = False):
temp_palace = tempfile.mkdtemp(prefix="mempalace_migrate_") temp_palace = tempfile.mkdtemp(prefix="mempalace_migrate_")
print(f" Creating fresh palace in {temp_palace}...") print(f" Creating fresh palace in {temp_palace}...")
client = chromadb.PersistentClient(path=temp_palace) fresh_backend = ChromaBackend()
col = client.get_or_create_collection("mempalace_drawers", metadata={"hnsw:space": "cosine"}) col = fresh_backend.get_or_create_collection(temp_palace, "mempalace_drawers")
# Re-import in batches # Re-import in batches
batch_size = 500 batch_size = 500
@@ -227,7 +227,7 @@ def migrate(palace_path: str, dry_run: bool = False, confirm: bool = False):
# Verify before swapping # Verify before swapping
final_count = col.count() final_count = col.count()
del col del col
del client del fresh_backend
# Swap: remove old palace, move new one into place # Swap: remove old palace, move new one into place
print(" Swapping old palace for migrated version...") print(" Swapping old palace for migrated version...")
+7 -9
View File
@@ -32,7 +32,7 @@ import os
import shutil import shutil
import time import time
import chromadb from .backends.chroma import ChromaBackend
COLLECTION_NAME = "mempalace_drawers" COLLECTION_NAME = "mempalace_drawers"
@@ -90,8 +90,7 @@ def scan_palace(palace_path=None, only_wing=None):
print(f"\n Palace: {palace_path}") print(f"\n Palace: {palace_path}")
print(" Loading...") print(" Loading...")
client = chromadb.PersistentClient(path=palace_path) col = ChromaBackend().get_collection(palace_path, COLLECTION_NAME)
col = client.get_collection(COLLECTION_NAME)
where = {"wing": only_wing} if only_wing else None where = {"wing": only_wing} if only_wing else None
total = col.count() total = col.count()
@@ -174,8 +173,7 @@ def prune_corrupt(palace_path=None, confirm=False):
print(" Re-run with --confirm to actually delete.") print(" Re-run with --confirm to actually delete.")
return return
client = chromadb.PersistentClient(path=palace_path) col = ChromaBackend().get_collection(palace_path, COLLECTION_NAME)
col = client.get_collection(COLLECTION_NAME)
before = col.count() before = col.count()
print(f" Collection size before: {before:,}") print(f" Collection size before: {before:,}")
@@ -222,9 +220,9 @@ def rebuild_index(palace_path=None):
print(f"{'=' * 55}\n") print(f"{'=' * 55}\n")
print(f" Palace: {palace_path}") print(f" Palace: {palace_path}")
client = chromadb.PersistentClient(path=palace_path) backend = ChromaBackend()
try: try:
col = client.get_collection(COLLECTION_NAME) col = backend.get_collection(palace_path, COLLECTION_NAME)
total = col.count() total = col.count()
except Exception as e: except Exception as e:
print(f" Error reading palace: {e}") print(f" Error reading palace: {e}")
@@ -264,8 +262,8 @@ def rebuild_index(palace_path=None):
# Rebuild with correct HNSW settings # Rebuild with correct HNSW settings
print(" Rebuilding collection with hnsw:space=cosine...") print(" Rebuilding collection with hnsw:space=cosine...")
client.delete_collection(COLLECTION_NAME) backend.delete_collection(palace_path, COLLECTION_NAME)
new_col = client.create_collection(COLLECTION_NAME, metadata={"hnsw:space": "cosine"}) new_col = backend.create_collection(palace_path, COLLECTION_NAME)
filed = 0 filed = 0
for i in range(0, len(all_ids), batch_size): for i in range(0, len(all_ids), batch_size):
+41 -65
View File
@@ -412,12 +412,21 @@ def test_main_compress_dispatches():
# ── cmd_repair ───────────────────────────────────────────────────────── # ── cmd_repair ─────────────────────────────────────────────────────────
def _mock_backend_for(col=None, new_col=None):
"""Build a mock ChromaBackend whose get_collection/create_collection return *col* / *new_col*."""
mock_backend = MagicMock()
if col is not None:
mock_backend.get_collection.return_value = col
if new_col is not None:
mock_backend.create_collection.return_value = new_col
return mock_backend
@patch("mempalace.cli.MempalaceConfig") @patch("mempalace.cli.MempalaceConfig")
def test_cmd_repair_no_palace(mock_config_cls, tmp_path, capsys): def test_cmd_repair_no_palace(mock_config_cls, tmp_path, capsys):
mock_config_cls.return_value.palace_path = str(tmp_path / "nonexistent") mock_config_cls.return_value.palace_path = str(tmp_path / "nonexistent")
args = argparse.Namespace(palace=None) args = argparse.Namespace(palace=None)
mock_chromadb = MagicMock() with patch("mempalace.backends.chroma.ChromaBackend"):
with patch.dict("sys.modules", {"chromadb": mock_chromadb}):
cmd_repair(args) cmd_repair(args)
out = capsys.readouterr().out out = capsys.readouterr().out
assert "No palace found" in out assert "No palace found" in out
@@ -429,8 +438,7 @@ def test_cmd_repair_requires_palace_database(mock_config_cls, tmp_path, capsys):
palace_dir.mkdir() palace_dir.mkdir()
mock_config_cls.return_value.palace_path = str(palace_dir) mock_config_cls.return_value.palace_path = str(palace_dir)
args = argparse.Namespace(palace=None) args = argparse.Namespace(palace=None)
mock_chromadb = MagicMock() with patch("mempalace.backends.chroma.ChromaBackend"):
with patch.dict("sys.modules", {"chromadb": mock_chromadb}):
cmd_repair(args) cmd_repair(args)
out = capsys.readouterr().out out = capsys.readouterr().out
assert "No palace database found" in out assert "No palace database found" in out
@@ -443,11 +451,9 @@ def test_cmd_repair_error_reading(mock_config_cls, tmp_path, capsys):
(palace_dir / "chroma.sqlite3").write_text("db") (palace_dir / "chroma.sqlite3").write_text("db")
mock_config_cls.return_value.palace_path = str(palace_dir) mock_config_cls.return_value.palace_path = str(palace_dir)
args = argparse.Namespace(palace=None) args = argparse.Namespace(palace=None)
mock_chromadb = MagicMock() mock_backend = MagicMock()
mock_client = MagicMock() mock_backend.get_collection.side_effect = Exception("corrupt db")
mock_client.get_collection.side_effect = Exception("corrupt db") with patch("mempalace.backends.chroma.ChromaBackend", return_value=mock_backend):
mock_chromadb.PersistentClient.return_value = mock_client
with patch.dict("sys.modules", {"chromadb": mock_chromadb}):
cmd_repair(args) cmd_repair(args)
out = capsys.readouterr().out out = capsys.readouterr().out
assert "Error reading palace" in out assert "Error reading palace" in out
@@ -460,13 +466,10 @@ def test_cmd_repair_zero_drawers(mock_config_cls, tmp_path, capsys):
(palace_dir / "chroma.sqlite3").write_text("db") (palace_dir / "chroma.sqlite3").write_text("db")
mock_config_cls.return_value.palace_path = str(palace_dir) mock_config_cls.return_value.palace_path = str(palace_dir)
args = argparse.Namespace(palace=None) args = argparse.Namespace(palace=None)
mock_chromadb = MagicMock()
mock_col = MagicMock() mock_col = MagicMock()
mock_col.count.return_value = 0 mock_col.count.return_value = 0
mock_client = MagicMock() mock_backend = _mock_backend_for(col=mock_col)
mock_client.get_collection.return_value = mock_col with patch("mempalace.backends.chroma.ChromaBackend", return_value=mock_backend):
mock_chromadb.PersistentClient.return_value = mock_client
with patch.dict("sys.modules", {"chromadb": mock_chromadb}):
cmd_repair(args) cmd_repair(args)
out = capsys.readouterr().out out = capsys.readouterr().out
assert "Nothing to repair" in out assert "Nothing to repair" in out
@@ -479,7 +482,6 @@ def test_cmd_repair_success(mock_config_cls, tmp_path, capsys):
(palace_dir / "chroma.sqlite3").write_text("db") (palace_dir / "chroma.sqlite3").write_text("db")
mock_config_cls.return_value.palace_path = str(palace_dir) mock_config_cls.return_value.palace_path = str(palace_dir)
args = argparse.Namespace(palace=None, yes=True) args = argparse.Namespace(palace=None, yes=True)
mock_chromadb = MagicMock()
mock_col = MagicMock() mock_col = MagicMock()
mock_col.count.return_value = 2 mock_col.count.return_value = 2
mock_col.get.return_value = { mock_col.get.return_value = {
@@ -487,12 +489,9 @@ def test_cmd_repair_success(mock_config_cls, tmp_path, capsys):
"documents": ["doc1", "doc2"], "documents": ["doc1", "doc2"],
"metadatas": [{"wing": "a"}, {"wing": "b"}], "metadatas": [{"wing": "a"}, {"wing": "b"}],
} }
mock_client = MagicMock()
mock_client.get_collection.return_value = mock_col
mock_new_col = MagicMock() mock_new_col = MagicMock()
mock_client.create_collection.return_value = mock_new_col mock_backend = _mock_backend_for(col=mock_col, new_col=mock_new_col)
mock_chromadb.PersistentClient.return_value = mock_client with patch("mempalace.backends.chroma.ChromaBackend", return_value=mock_backend):
with patch.dict("sys.modules", {"chromadb": mock_chromadb}):
cmd_repair(args) cmd_repair(args)
out = capsys.readouterr().out out = capsys.readouterr().out
assert "Repair complete" in out assert "Repair complete" in out
@@ -506,20 +505,17 @@ def test_cmd_repair_aborts_without_confirmation(mock_config_cls, tmp_path, capsy
(palace_dir / "chroma.sqlite3").write_text("db") (palace_dir / "chroma.sqlite3").write_text("db")
mock_config_cls.return_value.palace_path = str(palace_dir) mock_config_cls.return_value.palace_path = str(palace_dir)
args = argparse.Namespace(palace=None) args = argparse.Namespace(palace=None)
mock_chromadb = MagicMock()
mock_col = MagicMock() mock_col = MagicMock()
mock_col.count.return_value = 1 mock_col.count.return_value = 1
mock_client = MagicMock() mock_backend = _mock_backend_for(col=mock_col)
mock_client.get_collection.return_value = mock_col
mock_chromadb.PersistentClient.return_value = mock_client
with ( with (
patch.dict("sys.modules", {"chromadb": mock_chromadb}), patch("mempalace.backends.chroma.ChromaBackend", return_value=mock_backend),
patch("builtins.input", return_value="n"), patch("builtins.input", return_value="n"),
): ):
cmd_repair(args) cmd_repair(args)
out = capsys.readouterr().out out = capsys.readouterr().out
assert "Aborted." in out assert "Aborted." in out
mock_client.create_collection.assert_not_called() mock_backend.create_collection.assert_not_called()
# ── cmd_compress ─────────────────────────────────────────────────────── # ── cmd_compress ───────────────────────────────────────────────────────
@@ -529,10 +525,10 @@ def test_cmd_repair_aborts_without_confirmation(mock_config_cls, tmp_path, capsy
def test_cmd_compress_no_palace(mock_config_cls, capsys): def test_cmd_compress_no_palace(mock_config_cls, capsys):
mock_config_cls.return_value.palace_path = "/fake/palace" mock_config_cls.return_value.palace_path = "/fake/palace"
args = argparse.Namespace(palace=None, wing=None, dry_run=False, config=None) args = argparse.Namespace(palace=None, wing=None, dry_run=False, config=None)
mock_chromadb = MagicMock() mock_backend = MagicMock()
mock_chromadb.PersistentClient.side_effect = Exception("no palace") mock_backend.get_collection.side_effect = Exception("no palace")
with ( with (
patch.dict("sys.modules", {"chromadb": mock_chromadb}), patch("mempalace.backends.chroma.ChromaBackend", return_value=mock_backend),
pytest.raises(SystemExit), pytest.raises(SystemExit),
): ):
cmd_compress(args) cmd_compress(args)
@@ -542,13 +538,10 @@ def test_cmd_compress_no_palace(mock_config_cls, capsys):
def test_cmd_compress_no_drawers(mock_config_cls, capsys): def test_cmd_compress_no_drawers(mock_config_cls, capsys):
mock_config_cls.return_value.palace_path = "/fake/palace" mock_config_cls.return_value.palace_path = "/fake/palace"
args = argparse.Namespace(palace=None, wing="mywing", dry_run=False, config=None) args = argparse.Namespace(palace=None, wing="mywing", dry_run=False, config=None)
mock_chromadb = MagicMock()
mock_col = MagicMock() mock_col = MagicMock()
mock_col.get.return_value = {"documents": [], "metadatas": [], "ids": []} mock_col.get.return_value = {"documents": [], "metadatas": [], "ids": []}
mock_client = MagicMock() mock_backend = _mock_backend_for(col=mock_col)
mock_client.get_collection.return_value = mock_col with patch("mempalace.backends.chroma.ChromaBackend", return_value=mock_backend):
mock_chromadb.PersistentClient.return_value = mock_client
with patch.dict("sys.modules", {"chromadb": mock_chromadb}):
cmd_compress(args) cmd_compress(args)
out = capsys.readouterr().out out = capsys.readouterr().out
assert "No drawers found" in out assert "No drawers found" in out
@@ -567,7 +560,6 @@ def _make_mock_dialect_module(dialect_instance):
def test_cmd_compress_dry_run(mock_config_cls, capsys): def test_cmd_compress_dry_run(mock_config_cls, capsys):
mock_config_cls.return_value.palace_path = "/fake/palace" mock_config_cls.return_value.palace_path = "/fake/palace"
args = argparse.Namespace(palace=None, wing=None, dry_run=True, config=None) args = argparse.Namespace(palace=None, wing=None, dry_run=True, config=None)
mock_chromadb = MagicMock()
mock_col = MagicMock() mock_col = MagicMock()
mock_col.get.side_effect = [ mock_col.get.side_effect = [
{ {
@@ -577,9 +569,7 @@ def test_cmd_compress_dry_run(mock_config_cls, capsys):
}, },
{"documents": [], "metadatas": [], "ids": []}, {"documents": [], "metadatas": [], "ids": []},
] ]
mock_client = MagicMock() mock_backend = _mock_backend_for(col=mock_col)
mock_client.get_collection.return_value = mock_col
mock_chromadb.PersistentClient.return_value = mock_client
mock_dialect = MagicMock() mock_dialect = MagicMock()
mock_dialect.compress.return_value = "compressed" mock_dialect.compress.return_value = "compressed"
@@ -593,12 +583,9 @@ def test_cmd_compress_dry_run(mock_config_cls, capsys):
} }
mock_dialect_mod = _make_mock_dialect_module(mock_dialect) mock_dialect_mod = _make_mock_dialect_module(mock_dialect)
with patch.dict( with (
"sys.modules", patch("mempalace.backends.chroma.ChromaBackend", return_value=mock_backend),
{ patch.dict("sys.modules", {"mempalace.dialect": mock_dialect_mod}),
"chromadb": mock_chromadb,
"mempalace.dialect": mock_dialect_mod,
},
): ):
cmd_compress(args) cmd_compress(args)
out = capsys.readouterr().out out = capsys.readouterr().out
@@ -613,22 +600,16 @@ def test_cmd_compress_with_config(mock_config_cls, tmp_path, capsys):
config_file = tmp_path / "entities.json" config_file = tmp_path / "entities.json"
config_file.write_text('{"people": [], "projects": []}') config_file.write_text('{"people": [], "projects": []}')
args = argparse.Namespace(palace=None, wing=None, dry_run=True, config=str(config_file)) args = argparse.Namespace(palace=None, wing=None, dry_run=True, config=str(config_file))
mock_chromadb = MagicMock()
mock_col = MagicMock() mock_col = MagicMock()
mock_col.get.return_value = {"documents": [], "metadatas": [], "ids": []} mock_col.get.return_value = {"documents": [], "metadatas": [], "ids": []}
mock_client = MagicMock() mock_backend = _mock_backend_for(col=mock_col)
mock_client.get_collection.return_value = mock_col
mock_chromadb.PersistentClient.return_value = mock_client
mock_dialect = MagicMock() mock_dialect = MagicMock()
mock_dialect_mod = _make_mock_dialect_module(mock_dialect) mock_dialect_mod = _make_mock_dialect_module(mock_dialect)
with patch.dict( with (
"sys.modules", patch("mempalace.backends.chroma.ChromaBackend", return_value=mock_backend),
{ patch.dict("sys.modules", {"mempalace.dialect": mock_dialect_mod}),
"chromadb": mock_chromadb,
"mempalace.dialect": mock_dialect_mod,
},
): ):
cmd_compress(args) cmd_compress(args)
out = capsys.readouterr().out out = capsys.readouterr().out
@@ -640,7 +621,6 @@ def test_cmd_compress_stores_results(mock_config_cls, capsys):
"""Non-dry-run compress stores to mempalace_compressed collection.""" """Non-dry-run compress stores to mempalace_compressed collection."""
mock_config_cls.return_value.palace_path = "/fake/palace" mock_config_cls.return_value.palace_path = "/fake/palace"
args = argparse.Namespace(palace=None, wing=None, dry_run=False, config=None) args = argparse.Namespace(palace=None, wing=None, dry_run=False, config=None)
mock_chromadb = MagicMock()
mock_col = MagicMock() mock_col = MagicMock()
mock_col.get.side_effect = [ mock_col.get.side_effect = [
{ {
@@ -650,11 +630,10 @@ def test_cmd_compress_stores_results(mock_config_cls, capsys):
}, },
{"documents": [], "metadatas": [], "ids": []}, {"documents": [], "metadatas": [], "ids": []},
] ]
mock_client = MagicMock()
mock_client.get_collection.return_value = mock_col
mock_comp_col = MagicMock() mock_comp_col = MagicMock()
mock_client.get_or_create_collection.return_value = mock_comp_col mock_backend = MagicMock()
mock_chromadb.PersistentClient.return_value = mock_client mock_backend.get_collection.return_value = mock_col
mock_backend.get_or_create_collection.return_value = mock_comp_col
mock_dialect = MagicMock() mock_dialect = MagicMock()
mock_dialect.compress.return_value = "compressed" mock_dialect.compress.return_value = "compressed"
@@ -668,12 +647,9 @@ def test_cmd_compress_stores_results(mock_config_cls, capsys):
} }
mock_dialect_mod = _make_mock_dialect_module(mock_dialect) mock_dialect_mod = _make_mock_dialect_module(mock_dialect)
with patch.dict( with (
"sys.modules", patch("mempalace.backends.chroma.ChromaBackend", return_value=mock_backend),
{ patch.dict("sys.modules", {"mempalace.dialect": mock_dialect_mod}),
"chromadb": mock_chromadb,
"mempalace.dialect": mock_dialect_mod,
},
): ):
cmd_compress(args) cmd_compress(args)
out = capsys.readouterr().out out = capsys.readouterr().out
+19 -20
View File
@@ -198,8 +198,15 @@ def test_dedup_source_group_query_failure_keeps():
# ── show_stats ──────────────────────────────────────────────────────── # ── show_stats ────────────────────────────────────────────────────────
@patch("mempalace.dedup.chromadb") def _install_mock_backend(mock_backend_cls, collection):
def test_show_stats(mock_chromadb, tmp_path): mock_backend = MagicMock()
mock_backend.get_collection.return_value = collection
mock_backend_cls.return_value = mock_backend
return mock_backend
@patch("mempalace.dedup.ChromaBackend")
def test_show_stats(mock_backend_cls, tmp_path):
mock_col = MagicMock() mock_col = MagicMock()
mock_col.count.return_value = 5 mock_col.count.return_value = 5
mock_col.get.side_effect = [ mock_col.get.side_effect = [
@@ -215,9 +222,7 @@ def test_show_stats(mock_chromadb, tmp_path):
}, },
{"ids": []}, {"ids": []},
] ]
mock_client = MagicMock() _install_mock_backend(mock_backend_cls, mock_col)
mock_client.get_collection.return_value = mock_col
mock_chromadb.PersistentClient.return_value = mock_client
dedup.show_stats(palace_path=str(tmp_path)) # should not raise dedup.show_stats(palace_path=str(tmp_path)) # should not raise
@@ -227,13 +232,11 @@ def test_show_stats(mock_chromadb, tmp_path):
@patch("mempalace.dedup.dedup_source_group") @patch("mempalace.dedup.dedup_source_group")
@patch("mempalace.dedup.get_source_groups") @patch("mempalace.dedup.get_source_groups")
@patch("mempalace.dedup.chromadb") @patch("mempalace.dedup.ChromaBackend")
def test_dedup_palace_dry_run(mock_chromadb, mock_groups, mock_dedup_group, tmp_path): def test_dedup_palace_dry_run(mock_backend_cls, mock_groups, mock_dedup_group, tmp_path):
mock_col = MagicMock() mock_col = MagicMock()
mock_col.count.return_value = 10 mock_col.count.return_value = 10
mock_client = MagicMock() _install_mock_backend(mock_backend_cls, mock_col)
mock_client.get_collection.return_value = mock_col
mock_chromadb.PersistentClient.return_value = mock_client
mock_groups.return_value = {"a.txt": ["d1", "d2", "d3", "d4", "d5"]} mock_groups.return_value = {"a.txt": ["d1", "d2", "d3", "d4", "d5"]}
mock_dedup_group.return_value = (["d1", "d2", "d3"], ["d4", "d5"]) mock_dedup_group.return_value = (["d1", "d2", "d3"], ["d4", "d5"])
@@ -244,13 +247,11 @@ def test_dedup_palace_dry_run(mock_chromadb, mock_groups, mock_dedup_group, tmp_
@patch("mempalace.dedup.dedup_source_group") @patch("mempalace.dedup.dedup_source_group")
@patch("mempalace.dedup.get_source_groups") @patch("mempalace.dedup.get_source_groups")
@patch("mempalace.dedup.chromadb") @patch("mempalace.dedup.ChromaBackend")
def test_dedup_palace_with_wing(mock_chromadb, mock_groups, mock_dedup_group, tmp_path): def test_dedup_palace_with_wing(mock_backend_cls, mock_groups, mock_dedup_group, tmp_path):
mock_col = MagicMock() mock_col = MagicMock()
mock_col.count.return_value = 10 mock_col.count.return_value = 10
mock_client = MagicMock() _install_mock_backend(mock_backend_cls, mock_col)
mock_client.get_collection.return_value = mock_col
mock_chromadb.PersistentClient.return_value = mock_client
mock_groups.return_value = {} mock_groups.return_value = {}
dedup.dedup_palace(palace_path=str(tmp_path), wing="test_wing", dry_run=True) dedup.dedup_palace(palace_path=str(tmp_path), wing="test_wing", dry_run=True)
@@ -259,13 +260,11 @@ def test_dedup_palace_with_wing(mock_chromadb, mock_groups, mock_dedup_group, tm
@patch("mempalace.dedup.dedup_source_group") @patch("mempalace.dedup.dedup_source_group")
@patch("mempalace.dedup.get_source_groups") @patch("mempalace.dedup.get_source_groups")
@patch("mempalace.dedup.chromadb") @patch("mempalace.dedup.ChromaBackend")
def test_dedup_palace_no_groups(mock_chromadb, mock_groups, mock_dedup_group, tmp_path): def test_dedup_palace_no_groups(mock_backend_cls, mock_groups, mock_dedup_group, tmp_path):
mock_col = MagicMock() mock_col = MagicMock()
mock_col.count.return_value = 3 mock_col.count.return_value = 3
mock_client = MagicMock() _install_mock_backend(mock_backend_cls, mock_col)
mock_client.get_collection.return_value = mock_col
mock_chromadb.PersistentClient.return_value = mock_client
mock_groups.return_value = {} mock_groups.return_value = {}
dedup.dedup_palace(palace_path=str(tmp_path), dry_run=True) dedup.dedup_palace(palace_path=str(tmp_path), dry_run=True)
+52 -62
View File
@@ -66,22 +66,28 @@ def test_paginate_ids_offset_exception_fallback():
# ── scan_palace ─────────────────────────────────────────────────────── # ── scan_palace ───────────────────────────────────────────────────────
@patch("mempalace.repair.chromadb") def _install_mock_backend(mock_backend_cls, collection):
def test_scan_palace_no_ids(mock_chromadb, tmp_path): """Wire mock_backend_cls so ChromaBackend().get_collection(...) returns *collection*."""
mock_backend = MagicMock()
mock_backend.get_collection.return_value = collection
mock_backend_cls.return_value = mock_backend
return mock_backend
@patch("mempalace.repair.ChromaBackend")
def test_scan_palace_no_ids(mock_backend_cls, tmp_path):
mock_col = MagicMock() mock_col = MagicMock()
mock_col.count.return_value = 0 mock_col.count.return_value = 0
mock_col.get.return_value = {"ids": []} mock_col.get.return_value = {"ids": []}
mock_client = MagicMock() _install_mock_backend(mock_backend_cls, mock_col)
mock_client.get_collection.return_value = mock_col
mock_chromadb.PersistentClient.return_value = mock_client
good, bad = repair.scan_palace(palace_path=str(tmp_path)) good, bad = repair.scan_palace(palace_path=str(tmp_path))
assert good == set() assert good == set()
assert bad == set() assert bad == set()
@patch("mempalace.repair.chromadb") @patch("mempalace.repair.ChromaBackend")
def test_scan_palace_all_good(mock_chromadb, tmp_path): def test_scan_palace_all_good(mock_backend_cls, tmp_path):
mock_col = MagicMock() mock_col = MagicMock()
mock_col.count.return_value = 2 mock_col.count.return_value = 2
# _paginate_ids call # _paginate_ids call
@@ -89,9 +95,7 @@ def test_scan_palace_all_good(mock_chromadb, tmp_path):
{"ids": ["id1", "id2"]}, # paginate {"ids": ["id1", "id2"]}, # paginate
{"ids": ["id1", "id2"]}, # probe batch — both returned {"ids": ["id1", "id2"]}, # probe batch — both returned
] ]
mock_client = MagicMock() _install_mock_backend(mock_backend_cls, mock_col)
mock_client.get_collection.return_value = mock_col
mock_chromadb.PersistentClient.return_value = mock_client
good, bad = repair.scan_palace(palace_path=str(tmp_path)) good, bad = repair.scan_palace(palace_path=str(tmp_path))
assert "id1" in good assert "id1" in good
@@ -99,8 +103,8 @@ def test_scan_palace_all_good(mock_chromadb, tmp_path):
assert len(bad) == 0 assert len(bad) == 0
@patch("mempalace.repair.chromadb") @patch("mempalace.repair.ChromaBackend")
def test_scan_palace_with_bad_ids(mock_chromadb, tmp_path): def test_scan_palace_with_bad_ids(mock_backend_cls, tmp_path):
mock_col = MagicMock() mock_col = MagicMock()
mock_col.count.return_value = 2 mock_col.count.return_value = 2
@@ -117,26 +121,22 @@ def test_scan_palace_with_bad_ids(mock_chromadb, tmp_path):
raise Exception("batch fail") raise Exception("batch fail")
mock_col.get.side_effect = get_side_effect mock_col.get.side_effect = get_side_effect
mock_client = MagicMock() _install_mock_backend(mock_backend_cls, mock_col)
mock_client.get_collection.return_value = mock_col
mock_chromadb.PersistentClient.return_value = mock_client
good, bad = repair.scan_palace(palace_path=str(tmp_path)) good, bad = repair.scan_palace(palace_path=str(tmp_path))
assert "good1" in good assert "good1" in good
assert "bad1" in bad assert "bad1" in bad
@patch("mempalace.repair.chromadb") @patch("mempalace.repair.ChromaBackend")
def test_scan_palace_with_wing_filter(mock_chromadb, tmp_path): def test_scan_palace_with_wing_filter(mock_backend_cls, tmp_path):
mock_col = MagicMock() mock_col = MagicMock()
mock_col.count.return_value = 1 mock_col.count.return_value = 1
mock_col.get.side_effect = [ mock_col.get.side_effect = [
{"ids": ["id1"]}, # paginate {"ids": ["id1"]}, # paginate
{"ids": ["id1"]}, # probe {"ids": ["id1"]}, # probe
] ]
mock_client = MagicMock() _install_mock_backend(mock_backend_cls, mock_col)
mock_client.get_collection.return_value = mock_col
mock_chromadb.PersistentClient.return_value = mock_client
repair.scan_palace(palace_path=str(tmp_path), only_wing="test_wing") repair.scan_palace(palace_path=str(tmp_path), only_wing="test_wing")
# Verify where filter was passed # Verify where filter was passed
@@ -147,38 +147,36 @@ def test_scan_palace_with_wing_filter(mock_chromadb, tmp_path):
# ── prune_corrupt ───────────────────────────────────────────────────── # ── prune_corrupt ─────────────────────────────────────────────────────
@patch("mempalace.repair.chromadb") @patch("mempalace.repair.ChromaBackend")
def test_prune_corrupt_no_file(mock_chromadb, tmp_path): def test_prune_corrupt_no_file(mock_backend_cls, tmp_path):
# Should print message and return without error # Should print message and return without error
repair.prune_corrupt(palace_path=str(tmp_path)) repair.prune_corrupt(palace_path=str(tmp_path))
@patch("mempalace.repair.chromadb") @patch("mempalace.repair.ChromaBackend")
def test_prune_corrupt_dry_run(mock_chromadb, tmp_path): def test_prune_corrupt_dry_run(mock_backend_cls, tmp_path):
bad_file = tmp_path / "corrupt_ids.txt" bad_file = tmp_path / "corrupt_ids.txt"
bad_file.write_text("bad1\nbad2\n") bad_file.write_text("bad1\nbad2\n")
repair.prune_corrupt(palace_path=str(tmp_path), confirm=False) repair.prune_corrupt(palace_path=str(tmp_path), confirm=False)
# No chromadb calls in dry run # No backend calls in dry run
mock_chromadb.PersistentClient.assert_not_called() mock_backend_cls.assert_not_called()
@patch("mempalace.repair.chromadb") @patch("mempalace.repair.ChromaBackend")
def test_prune_corrupt_confirmed(mock_chromadb, tmp_path): def test_prune_corrupt_confirmed(mock_backend_cls, tmp_path):
bad_file = tmp_path / "corrupt_ids.txt" bad_file = tmp_path / "corrupt_ids.txt"
bad_file.write_text("bad1\nbad2\n") bad_file.write_text("bad1\nbad2\n")
mock_col = MagicMock() mock_col = MagicMock()
mock_col.count.side_effect = [10, 8] mock_col.count.side_effect = [10, 8]
mock_client = MagicMock() _install_mock_backend(mock_backend_cls, mock_col)
mock_client.get_collection.return_value = mock_col
mock_chromadb.PersistentClient.return_value = mock_client
repair.prune_corrupt(palace_path=str(tmp_path), confirm=True) repair.prune_corrupt(palace_path=str(tmp_path), confirm=True)
mock_col.delete.assert_called_once() mock_col.delete.assert_called_once()
@patch("mempalace.repair.chromadb") @patch("mempalace.repair.ChromaBackend")
def test_prune_corrupt_delete_failure_fallback(mock_chromadb, tmp_path): def test_prune_corrupt_delete_failure_fallback(mock_backend_cls, tmp_path):
bad_file = tmp_path / "corrupt_ids.txt" bad_file = tmp_path / "corrupt_ids.txt"
bad_file.write_text("bad1\nbad2\n") bad_file.write_text("bad1\nbad2\n")
@@ -186,9 +184,7 @@ def test_prune_corrupt_delete_failure_fallback(mock_chromadb, tmp_path):
mock_col.count.side_effect = [10, 8] mock_col.count.side_effect = [10, 8]
# Batch delete fails, per-id succeeds # Batch delete fails, per-id succeeds
mock_col.delete.side_effect = [Exception("batch fail"), None, None] mock_col.delete.side_effect = [Exception("batch fail"), None, None]
mock_client = MagicMock() _install_mock_backend(mock_backend_cls, mock_col)
mock_client.get_collection.return_value = mock_col
mock_chromadb.PersistentClient.return_value = mock_client
repair.prune_corrupt(palace_path=str(tmp_path), confirm=True) repair.prune_corrupt(palace_path=str(tmp_path), confirm=True)
assert mock_col.delete.call_count == 3 # 1 batch + 2 individual assert mock_col.delete.call_count == 3 # 1 batch + 2 individual
@@ -197,29 +193,27 @@ def test_prune_corrupt_delete_failure_fallback(mock_chromadb, tmp_path):
# ── rebuild_index ───────────────────────────────────────────────────── # ── rebuild_index ─────────────────────────────────────────────────────
@patch("mempalace.repair.chromadb") @patch("mempalace.repair.ChromaBackend")
def test_rebuild_index_no_palace(mock_chromadb, tmp_path): def test_rebuild_index_no_palace(mock_backend_cls, tmp_path):
nonexistent = str(tmp_path / "nope") nonexistent = str(tmp_path / "nope")
repair.rebuild_index(palace_path=nonexistent) repair.rebuild_index(palace_path=nonexistent)
mock_chromadb.PersistentClient.assert_not_called() mock_backend_cls.assert_not_called()
@patch("mempalace.repair.shutil") @patch("mempalace.repair.shutil")
@patch("mempalace.repair.chromadb") @patch("mempalace.repair.ChromaBackend")
def test_rebuild_index_empty_palace(mock_chromadb, mock_shutil, tmp_path): def test_rebuild_index_empty_palace(mock_backend_cls, mock_shutil, tmp_path):
mock_col = MagicMock() mock_col = MagicMock()
mock_col.count.return_value = 0 mock_col.count.return_value = 0
mock_client = MagicMock() mock_backend = _install_mock_backend(mock_backend_cls, mock_col)
mock_client.get_collection.return_value = mock_col
mock_chromadb.PersistentClient.return_value = mock_client
repair.rebuild_index(palace_path=str(tmp_path)) repair.rebuild_index(palace_path=str(tmp_path))
mock_client.delete_collection.assert_not_called() mock_backend.delete_collection.assert_not_called()
@patch("mempalace.repair.shutil") @patch("mempalace.repair.shutil")
@patch("mempalace.repair.chromadb") @patch("mempalace.repair.ChromaBackend")
def test_rebuild_index_success(mock_chromadb, mock_shutil, tmp_path): def test_rebuild_index_success(mock_backend_cls, mock_shutil, tmp_path):
# Create a fake sqlite file # Create a fake sqlite file
sqlite_path = tmp_path / "chroma.sqlite3" sqlite_path = tmp_path / "chroma.sqlite3"
sqlite_path.write_text("fake") sqlite_path.write_text("fake")
@@ -233,10 +227,8 @@ def test_rebuild_index_success(mock_chromadb, mock_shutil, tmp_path):
} }
mock_new_col = MagicMock() mock_new_col = MagicMock()
mock_client = MagicMock() mock_backend = _install_mock_backend(mock_backend_cls, mock_col)
mock_client.get_collection.return_value = mock_col mock_backend.create_collection.return_value = mock_new_col
mock_client.create_collection.return_value = mock_new_col
mock_chromadb.PersistentClient.return_value = mock_client
repair.rebuild_index(palace_path=str(tmp_path)) repair.rebuild_index(palace_path=str(tmp_path))
@@ -244,11 +236,9 @@ def test_rebuild_index_success(mock_chromadb, mock_shutil, tmp_path):
mock_shutil.copy2.assert_called_once() mock_shutil.copy2.assert_called_once()
assert "chroma.sqlite3" in str(mock_shutil.copy2.call_args) assert "chroma.sqlite3" in str(mock_shutil.copy2.call_args)
# Verify: deleted and recreated with cosine # Verify: deleted and recreated (cosine is the backend default)
mock_client.delete_collection.assert_called_once_with("mempalace_drawers") mock_backend.delete_collection.assert_called_once_with(str(tmp_path), "mempalace_drawers")
mock_client.create_collection.assert_called_once_with( mock_backend.create_collection.assert_called_once_with(str(tmp_path), "mempalace_drawers")
"mempalace_drawers", metadata={"hnsw:space": "cosine"}
)
# Verify: used upsert not add # Verify: used upsert not add
mock_new_col.upsert.assert_called_once() mock_new_col.upsert.assert_called_once()
@@ -256,11 +246,11 @@ def test_rebuild_index_success(mock_chromadb, mock_shutil, tmp_path):
@patch("mempalace.repair.shutil") @patch("mempalace.repair.shutil")
@patch("mempalace.repair.chromadb") @patch("mempalace.repair.ChromaBackend")
def test_rebuild_index_error_reading(mock_chromadb, mock_shutil, tmp_path): def test_rebuild_index_error_reading(mock_backend_cls, mock_shutil, tmp_path):
mock_client = MagicMock() mock_backend = MagicMock()
mock_client.get_collection.side_effect = Exception("corrupt") mock_backend.get_collection.side_effect = Exception("corrupt")
mock_chromadb.PersistentClient.return_value = mock_client mock_backend_cls.return_value = mock_backend
repair.rebuild_index(palace_path=str(tmp_path)) repair.rebuild_index(palace_path=str(tmp_path))
mock_client.delete_collection.assert_not_called() mock_backend.delete_collection.assert_not_called()
Generated
+1 -1
View File
@@ -1239,7 +1239,7 @@ dev = [
[package.metadata] [package.metadata]
requires-dist = [ requires-dist = [
{ name = "autocorrect", marker = "extra == 'spellcheck'", specifier = ">=2.0" }, { name = "autocorrect", marker = "extra == 'spellcheck'", specifier = ">=2.0" },
{ name = "chromadb", specifier = ">=0.5.0,<0.7" }, { name = "chromadb", specifier = ">=0.5.0" },
{ name = "psutil", marker = "extra == 'dev'", specifier = ">=5.9" }, { name = "psutil", marker = "extra == 'dev'", specifier = ">=5.9" },
{ name = "pytest", marker = "extra == 'dev'", specifier = ">=7.0" }, { name = "pytest", marker = "extra == 'dev'", specifier = ">=7.0" },
{ name = "pytest-cov", marker = "extra == 'dev'", specifier = ">=4.0" }, { name = "pytest-cov", marker = "extra == 'dev'", specifier = ">=4.0" },