diff --git a/mempalace/cli.py b/mempalace/cli.py index 740de96..ac00283 100644 --- a/mempalace/cli.py +++ b/mempalace/cli.py @@ -663,8 +663,10 @@ def cmd_repair(args): check_extraction_safety, ) + config = MempalaceConfig() + collection_name = config.collection_name palace_path = os.path.abspath( - os.path.expanduser(args.palace) if args.palace else MempalaceConfig().palace_path + os.path.expanduser(args.palace) if args.palace else config.palace_path ) if getattr(args, "mode", "legacy") == "max-seq-id": @@ -749,7 +751,7 @@ def cmd_repair(args): # Try to read existing drawers try: - col = backend.get_collection(palace_path, "mempalace_drawers") + col = backend.get_collection(palace_path, collection_name) total = col.count() print(f" Drawers found: {total}") except Exception as e: @@ -784,6 +786,7 @@ def cmd_repair(args): palace_path, len(all_ids), confirm_truncation_ok=getattr(args, "confirm_truncation_ok", False), + collection_name=collection_name, ) except TruncationDetected as e: print(e.message) @@ -810,6 +813,7 @@ def cmd_repair(args): all_docs, all_metas, batch_size, + collection_name=collection_name, progress=print, ) except RebuildCollectionError as e: diff --git a/mempalace/config.py b/mempalace/config.py index 2252a49..fd32a17 100644 --- a/mempalace/config.py +++ b/mempalace/config.py @@ -7,6 +7,7 @@ Priority: env vars > config file (~/.mempalace/config.json) > defaults import json import os import re +from functools import lru_cache from pathlib import Path @@ -127,6 +128,13 @@ def sanitize_content(value: str, max_length: int = 100_000) -> str: DEFAULT_PALACE_PATH = os.path.expanduser("~/.mempalace/palace") DEFAULT_COLLECTION_NAME = "mempalace_drawers" + +@lru_cache(maxsize=1) +def get_configured_collection_name() -> str: + """Return the configured drawer collection name without repeated config-file reads.""" + return MempalaceConfig().collection_name + + DEFAULT_TOPIC_WINGS = [ "emotions", "consciousness", diff --git a/mempalace/mcp_server.py b/mempalace/mcp_server.py index bbb9c93..521cb07 100644 --- a/mempalace/mcp_server.py +++ b/mempalace/mcp_server.py @@ -193,7 +193,7 @@ def _refresh_vector_disabled_flag() -> None: """ global _vector_disabled, _vector_disabled_reason, _vector_capacity_status try: - info = hnsw_capacity_status(_config.palace_path, "mempalace_drawers") + info = hnsw_capacity_status(_config.palace_path, _config.collection_name) except Exception: logger.debug("HNSW capacity probe raised", exc_info=True) return @@ -490,6 +490,7 @@ def _tool_status_via_sqlite() -> dict: db_path = os.path.join(_config.palace_path, "chroma.sqlite3") if not os.path.isfile(db_path): return _no_palace() + collection_name = _config.collection_name wings: dict = {} rooms: dict = {} @@ -503,8 +504,9 @@ def _tool_status_via_sqlite() -> dict: FROM embeddings e JOIN segments s ON e.segment_id = s.id JOIN collections c ON s.collection = c.id - WHERE c.name = 'mempalace_drawers' - """ + WHERE c.name = ? + """, + (collection_name,), ).fetchone() total = int(row[0]) if row and row[0] is not None else 0 for key, target in (("wing", wings), ("room", rooms)): @@ -515,12 +517,12 @@ def _tool_status_via_sqlite() -> dict: JOIN embeddings e ON em.id = e.id JOIN segments s ON e.segment_id = s.id JOIN collections c ON s.collection = c.id - WHERE c.name = 'mempalace_drawers' + WHERE c.name = ? AND em.key = ? AND em.string_value IS NOT NULL GROUP BY em.string_value """, - (key,), + (collection_name, key), ): target[value] = count finally: @@ -720,6 +722,7 @@ def tool_search( n_results=limit, max_distance=dist, vector_disabled=_vector_disabled, + collection_name=_config.collection_name, ) if _vector_disabled: result["vector_disabled"] = True @@ -922,8 +925,8 @@ def tool_add_drawer( # Idempotency: if the deterministic ID already exists, return success as a no-op. try: - existing = col.get(ids=[drawer_id]) - if existing and existing["ids"]: + existing = col.get(ids=[drawer_id], include=[]) + if existing.ids: return {"success": True, "reason": "already_exists", "drawer_id": drawer_id} except Exception: logger.debug("Idempotency pre-check failed for %s", drawer_id, exc_info=True) @@ -943,6 +946,12 @@ def tool_add_drawer( } ], ) + inserted = col.get(ids=[drawer_id], include=[]) + if not inserted.ids: + raise RuntimeError( + "Drawer write was acknowledged but the new ID is not readable. " + "The palace index may be stale; run reconnect or repair." + ) _metadata_cache = None logger.info(f"Filed drawer: {drawer_id} → {wing}/{room}") return {"success": True, "drawer_id": drawer_id, "wing": wing, "room": room} @@ -1506,6 +1515,30 @@ def tool_reconnect(): _palace_db_mtime, \ _vector_disabled, \ _vector_disabled_reason + from . import palace as palace_module + + close_errors = [] + try: + palace_module._DEFAULT_BACKEND.close_palace(_config.palace_path) + except Exception as exc: + logger.debug("Failed to close shared palace backend during reconnect", exc_info=True) + close_errors.append(f"backend close_palace failed: {exc}") + try: + from chromadb.api.client import SharedSystemClient + + clear_system_cache = getattr(SharedSystemClient, "clear_system_cache", None) + if callable(clear_system_cache): + clear_system_cache() + else: + logger.debug( + "SharedSystemClient.clear_system_cache is unavailable; skipping shared Chroma cache clear during reconnect" + ) + except Exception as exc: + logger.debug( + "Failed to clear Chroma shared system cache during reconnect", + exc_info=True, + ) + close_errors.append(f"shared Chroma cache clear failed: {exc}") _client_cache = None _collection_cache = None _palace_db_inode = 0 @@ -1527,12 +1560,24 @@ def tool_reconnect(): try: col = _get_collection() if col is None: - return { + result = { "success": False, "message": "No palace found after reconnect", "drawers": 0, "vector_disabled": _vector_disabled, } + if close_errors: + result["error"] = "; ".join(close_errors) + return result + if close_errors: + return { + "success": False, + "message": "Reconnect reopened the palace but failed to fully reset cached handles", + "drawers": col.count(), + "vector_disabled": _vector_disabled, + "vector_disabled_reason": _vector_disabled_reason, + "error": "; ".join(close_errors), + } return { "success": True, "message": "Reconnected to palace", diff --git a/mempalace/palace.py b/mempalace/palace.py index e5f6411..dee5c8f 100644 --- a/mempalace/palace.py +++ b/mempalace/palace.py @@ -10,6 +10,7 @@ import logging import os import re import threading +from typing import Optional from .backends.chroma import ChromaBackend @@ -56,10 +57,14 @@ NORMALIZE_VERSION = 2 def get_collection( palace_path: str, - collection_name: str = "mempalace_drawers", + collection_name: Optional[str] = None, create: bool = True, ): """Get the palace collection through the backend layer.""" + if collection_name is None: + from .config import get_configured_collection_name + + collection_name = get_configured_collection_name() return _DEFAULT_BACKEND.get_collection( palace_path, collection_name=collection_name, diff --git a/mempalace/repair.py b/mempalace/repair.py index 34d165c..b47bcd6 100644 --- a/mempalace/repair.py +++ b/mempalace/repair.py @@ -181,10 +181,12 @@ def _rebuild_collection_via_temp( all_docs, all_metas, batch_size: int, + collection_name: Optional[str] = None, progress=print, ) -> int: expected = len(all_ids) - temp_name = REPAIR_TEMP_COLLECTION + collection_name = collection_name or _drawers_collection_name() + temp_name = f"{collection_name}__repair_tmp" live_replaced = False try: @@ -203,9 +205,9 @@ def _rebuild_collection_via_temp( _verify_collection_count(temp_col, expected, "temporary rebuild") progress(" Rebuilding live collection...") - backend.delete_collection(palace_path, COLLECTION_NAME) + backend.delete_collection(palace_path, collection_name) live_replaced = True - new_col = backend.create_collection(palace_path, COLLECTION_NAME) + new_col = backend.create_collection(palace_path, collection_name) rebuilt = 0 for i in range(0, expected, batch_size): @@ -230,7 +232,7 @@ def _rebuild_collection_via_temp( raise RebuildCollectionError(str(exc), live_replaced=live_replaced) from exc -def scan_palace(palace_path=None, only_wing=None): +def scan_palace(palace_path=None, only_wing=None, collection_name: Optional[str] = None): """Scan the palace for corrupt/unfetchable IDs. Probes in batches of 100, falls back to per-ID on failure. @@ -239,14 +241,15 @@ def scan_palace(palace_path=None, only_wing=None): Returns (good_set, bad_set). """ palace_path = palace_path or _get_palace_path() + collection_name = collection_name or _drawers_collection_name() print(f"\n Palace: {palace_path}") print(" Loading...") - col = ChromaBackend().get_collection(palace_path, COLLECTION_NAME) + col = ChromaBackend().get_collection(palace_path, collection_name) where = {"wing": only_wing} if only_wing else None total = col.count() - print(f" Collection: {COLLECTION_NAME}, total: {total:,}") + print(f" Collection: {collection_name}, total: {total:,}") if only_wing: print(f" Scanning wing: {only_wing}") @@ -307,9 +310,10 @@ def scan_palace(palace_path=None, only_wing=None): return good_set, bad_set -def prune_corrupt(palace_path=None, confirm=False): +def prune_corrupt(palace_path=None, confirm=False, collection_name: Optional[str] = None): """Delete corrupt IDs listed in corrupt_ids.txt.""" palace_path = palace_path or _get_palace_path() + collection_name = collection_name or _drawers_collection_name() bad_file = os.path.join(palace_path, "corrupt_ids.txt") if not os.path.exists(bad_file): @@ -325,7 +329,7 @@ def prune_corrupt(palace_path=None, confirm=False): print(" Re-run with --confirm to actually delete.") return - col = ChromaBackend().get_collection(palace_path, COLLECTION_NAME) + col = ChromaBackend().get_collection(palace_path, collection_name) before = col.count() print(f" Collection size before: {before:,}") @@ -379,7 +383,10 @@ class TruncationDetected(Exception): def check_extraction_safety( - palace_path: str, extracted: int, confirm_truncation_ok: bool = False + palace_path: str, + extracted: int, + confirm_truncation_ok: bool = False, + collection_name: Optional[str] = None, ) -> None: """Cross-check that ``extracted`` matches the SQLite ground truth. @@ -401,7 +408,8 @@ def check_extraction_safety( if confirm_truncation_ok: return - sqlite_count = sqlite_drawer_count(palace_path) + collection_name = collection_name or _drawers_collection_name() + sqlite_count = sqlite_drawer_count(palace_path, collection_name) cap_signal = extracted == CHROMADB_DEFAULT_GET_LIMIT if sqlite_count is not None and sqlite_count > extracted: @@ -437,7 +445,7 @@ def check_extraction_safety( raise TruncationDetected(message, sqlite_count, extracted) -def sqlite_drawer_count(palace_path: str) -> "int | None": +def sqlite_drawer_count(palace_path: str, collection_name: Optional[str] = None) -> "int | None": """Count rows in ``chroma.sqlite3.embeddings`` for the drawers collection. Used as an independent ground-truth check against the chromadb @@ -449,6 +457,7 @@ def sqlite_drawer_count(palace_path: str) -> "int | None": drift, missing tables, locked file). Callers treat ``None`` as "unknown" and fall back to the cap-detection check. """ + collection_name = collection_name or _drawers_collection_name() sqlite_path = os.path.join(palace_path, "chroma.sqlite3") if not os.path.exists(sqlite_path): return None @@ -465,7 +474,7 @@ def sqlite_drawer_count(palace_path: str) -> "int | None": JOIN collections c ON s.collection = c.id WHERE c.name = ? """, - (COLLECTION_NAME,), + (collection_name,), ).fetchone() return int(row[0]) if row and row[0] is not None else None finally: @@ -477,7 +486,11 @@ def sqlite_drawer_count(palace_path: str) -> "int | None": return None -def rebuild_index(palace_path=None, confirm_truncation_ok: bool = False): +def rebuild_index( + palace_path=None, + confirm_truncation_ok: bool = False, + collection_name: Optional[str] = None, +): """Rebuild the HNSW index from scratch. 1. Extract all drawers via ChromaDB get() @@ -492,6 +505,7 @@ def rebuild_index(palace_path=None, confirm_truncation_ok: bool = False): (typically only a concern for palaces sized at exactly 10 000 rows). """ palace_path = palace_path or _get_palace_path() + collection_name = collection_name or _drawers_collection_name() if not os.path.isdir(palace_path): print(f"\n No palace found at {palace_path}") @@ -504,7 +518,7 @@ def rebuild_index(palace_path=None, confirm_truncation_ok: bool = False): backend = ChromaBackend() try: - col = backend.get_collection(palace_path, COLLECTION_NAME) + col = backend.get_collection(palace_path, collection_name) total = col.count() except Exception as e: print(f" Error reading palace: {e}") @@ -528,7 +542,12 @@ def rebuild_index(palace_path=None, confirm_truncation_ok: bool = False): # short of the SQLite ground truth (or when extraction == chromadb # default get() cap and the SQLite check couldn't run). try: - check_extraction_safety(palace_path, len(all_ids), confirm_truncation_ok) + check_extraction_safety( + palace_path, + len(all_ids), + confirm_truncation_ok, + collection_name=collection_name, + ) except TruncationDetected as e: print(e.message) return @@ -551,6 +570,7 @@ def rebuild_index(palace_path=None, confirm_truncation_ok: bool = False): all_docs, all_metas, batch_size, + collection_name=collection_name, progress=print, ) except RebuildCollectionError as e: @@ -560,7 +580,7 @@ def rebuild_index(palace_path=None, confirm_truncation_ok: bool = False): print(f" Restoring from backup: {backup_path}") try: _close_chroma_handles(palace_path, backend=backend) - _delete_collection_if_exists(backend, palace_path, COLLECTION_NAME) + _delete_collection_if_exists(backend, palace_path, collection_name) shutil.copy2(backup_path, sqlite_path) print(" Backup restored. Palace is back to pre-repair state.") except Exception as restore_error: @@ -950,7 +970,7 @@ def rebuild_from_sqlite( backend.close() -def status(palace_path=None) -> dict: +def status(palace_path=None, collection_name: Optional[str] = None) -> dict: """Read-only health check: compare sqlite vs HNSW element counts. Catches the #1222 failure mode where chromadb's HNSW segment freezes @@ -968,6 +988,7 @@ def status(palace_path=None) -> dict: ``status="unknown"`` when no palace exists at the given path. """ palace_path = palace_path or _get_palace_path() + collection_name = collection_name or _drawers_collection_name() print(f"\n{'=' * 55}") print(" MemPalace Repair — Status") print(f"{'=' * 55}\n") @@ -977,8 +998,8 @@ def status(palace_path=None) -> dict: print(" No palace found.\n") return {"status": "unknown", "message": "no palace at path"} - drawers = hnsw_capacity_status(palace_path, "mempalace_drawers") - closets = hnsw_capacity_status(palace_path, "mempalace_closets") + drawers = hnsw_capacity_status(palace_path, collection_name) + closets = hnsw_capacity_status(palace_path, CLOSETS_COLLECTION_NAME) for label, info in (("drawers", drawers), ("closets", closets)): print(f"\n [{label}]") diff --git a/mempalace/searcher.py b/mempalace/searcher.py index b318d99..f644fda 100644 --- a/mempalace/searcher.py +++ b/mempalace/searcher.py @@ -382,6 +382,7 @@ def _bm25_only_via_sqlite( n_results: int = 5, max_candidates: int = 500, _include_internal: bool = False, + collection_name: str = None, ) -> dict: """BM25-only search reading drawers directly from chroma.sqlite3. @@ -405,6 +406,10 @@ def _bm25_only_via_sqlite( "error": "No palace found", "hint": "Run: mempalace init && mempalace mine ", } + if collection_name is None: + from .config import get_configured_collection_name + + collection_name = get_configured_collection_name() def _metadata_filter_sql(row_id_expr: str) -> tuple[str, list[str]]: clauses = [] @@ -441,35 +446,43 @@ def _bm25_only_via_sqlite( # shorter than 3 chars (trigram tokenizer can't match them). tokens = [t for t in _tokenize(query) if len(t) >= 3] candidate_ids: list[int] = [] + use_recency_fallback = not tokens if tokens: fts_query = " OR ".join(tokens) filter_sql, filter_params = _metadata_filter_sql("embedding_fulltext_search.rowid") try: rows = conn.execute( f""" - SELECT rowid + SELECT embedding_fulltext_search.rowid FROM embedding_fulltext_search + JOIN embeddings e ON e.id = embedding_fulltext_search.rowid + JOIN segments s ON e.segment_id = s.id + JOIN collections c ON s.collection = c.id WHERE embedding_fulltext_search MATCH ? + AND c.name = ? {filter_sql} LIMIT ? """, - (fts_query, *filter_params, max_candidates), + (fts_query, collection_name, *filter_params, max_candidates), ).fetchall() candidate_ids = [r[0] for r in rows] except sqlite3.Error: # FTS5 tokenizer mismatch or syntax error — fall through # to the recency-window selector below. logger.debug("FTS5 MATCH failed; using recency fallback", exc_info=True) + use_recency_fallback = True - if not candidate_ids: - # No FTS hits (or no usable tokens) — pull the most recent - # rows for the drawers segment so we can BM25-rank something - # rather than return empty-handed. Wrapped in try/except - # because the schema may differ on legacy palaces (older - # chromadb without ``created_at``, missing ``segments`` - # rows after partial restore, etc.); on schema mismatch we - # fall back to ordering by primary-key id and finally to an - # empty result rather than letting search raise. + if not candidate_ids and use_recency_fallback: + # No usable FTS tokens, or FTS itself failed — pull the most + # recent rows for the drawers segment so we can BM25-rank + # something rather than return empty-handed. A clean FTS miss + # must stay empty, especially after wing/room filtering, because + # recency fallback would return unrelated scoped drawers. + # Wrapped in try/except because the schema may differ on legacy + # palaces (older chromadb without ``created_at``, missing + # ``segments`` rows after partial restore, etc.); on schema + # mismatch we fall back to ordering by primary-key id and finally + # to an empty result rather than letting search raise. try: filter_sql, filter_params = _metadata_filter_sql("e.id") rows = conn.execute( @@ -478,12 +491,12 @@ def _bm25_only_via_sqlite( FROM embeddings e JOIN segments s ON e.segment_id = s.id JOIN collections c ON s.collection = c.id - WHERE c.name = 'mempalace_drawers' + WHERE c.name = ? {filter_sql} ORDER BY e.created_at DESC LIMIT ? """, - (*filter_params, max_candidates), + (collection_name, *filter_params, max_candidates), ).fetchall() candidate_ids = [r[0] for r in rows] except sqlite3.Error: @@ -499,12 +512,12 @@ def _bm25_only_via_sqlite( FROM embeddings e JOIN segments s ON e.segment_id = s.id JOIN collections c ON s.collection = c.id - WHERE c.name = 'mempalace_drawers' + WHERE c.name = ? {filter_sql} ORDER BY e.id DESC LIMIT ? """, - (*filter_params, max_candidates), + (collection_name, *filter_params, max_candidates), ).fetchall() candidate_ids = [r[0] for r in rows] except sqlite3.Error: @@ -720,6 +733,7 @@ def search_memories( max_distance: float = 0.0, vector_disabled: bool = False, candidate_strategy: str = "vector", + collection_name: str = None, ) -> dict: """Programmatic search — returns a dict instead of printing. @@ -770,10 +784,11 @@ def search_memories( wing=wing, room=room, n_results=n_results, + collection_name=collection_name, ) try: - drawers_col = get_collection(palace_path, create=False) + drawers_col = get_collection(palace_path, collection_name=collection_name, create=False) except Exception as e: logger.error("No palace found at %s: %s", palace_path, e) return { diff --git a/tests/test_backends.py b/tests/test_backends.py index 4cd9480..06625fc 100644 --- a/tests/test_backends.py +++ b/tests/test_backends.py @@ -1194,3 +1194,26 @@ def test_chroma_backend_requarantines_after_inode_replacement(tmp_path, monkeypa ("invalid", str(palace)), ("stale", str(palace)), ] + + +def test_palace_get_collection_uses_configured_collection_name(monkeypatch): + from mempalace import palace + + captured = {} + + def fake_get_collection(palace_path, collection_name=None, create=False): + captured["palace_path"] = palace_path + captured["collection_name"] = collection_name + captured["create"] = create + return object() + + monkeypatch.setattr(palace._DEFAULT_BACKEND, "get_collection", fake_get_collection) + monkeypatch.setattr("mempalace.config.get_configured_collection_name", lambda: "custom_drawers") + + palace.get_collection("/palace", create=False) + + assert captured == { + "palace_path": "/palace", + "collection_name": "custom_drawers", + "create": False, + } diff --git a/tests/test_cli.py b/tests/test_cli.py index de00664..0b61b0c 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -776,6 +776,7 @@ def test_cmd_repair_error_reading(mock_config_cls, tmp_path, capsys): palace_dir.mkdir() (palace_dir / "chroma.sqlite3").write_text("db") mock_config_cls.return_value.palace_path = str(palace_dir) + mock_config_cls.return_value.collection_name = "mempalace_drawers" args = argparse.Namespace(palace=None) mock_backend = MagicMock() mock_backend.get_collection.side_effect = Exception("corrupt db") @@ -791,6 +792,7 @@ def test_cmd_repair_zero_drawers(mock_config_cls, tmp_path, capsys): palace_dir.mkdir() (palace_dir / "chroma.sqlite3").write_text("db") mock_config_cls.return_value.palace_path = str(palace_dir) + mock_config_cls.return_value.collection_name = "mempalace_drawers" args = argparse.Namespace(palace=None) mock_col = MagicMock() mock_col.count.return_value = 0 @@ -807,6 +809,7 @@ def test_cmd_repair_success(mock_config_cls, tmp_path, capsys): palace_dir.mkdir() (palace_dir / "chroma.sqlite3").write_text("db") mock_config_cls.return_value.palace_path = str(palace_dir) + mock_config_cls.return_value.collection_name = "mempalace_drawers" args = argparse.Namespace(palace=None, yes=True) mock_col = MagicMock() mock_col.count.return_value = 2 @@ -836,12 +839,52 @@ def test_cmd_repair_success(mock_config_cls, tmp_path, capsys): mock_new_col.add.assert_not_called() +@patch("mempalace.cli.MempalaceConfig") +def test_cmd_repair_uses_configured_collection(mock_config_cls, tmp_path, capsys): + palace_dir = tmp_path / "palace" + palace_dir.mkdir() + (palace_dir / "chroma.sqlite3").write_text("db") + mock_config_cls.return_value.palace_path = str(palace_dir) + mock_config_cls.return_value.collection_name = "custom_drawers" + args = argparse.Namespace(palace=None, yes=True) + mock_col = MagicMock() + mock_col.count.return_value = 2 + mock_col.get.return_value = { + "ids": ["id1", "id2"], + "documents": ["doc1", "doc2"], + "metadatas": [{"wing": "a"}, {"wing": "b"}], + } + mock_temp_col = MagicMock() + mock_temp_col.count.return_value = 2 + mock_new_col = MagicMock() + mock_new_col.count.return_value = 2 + mock_backend = _mock_backend_for(col=mock_col, new_col=mock_new_col) + mock_backend.create_collection.side_effect = [mock_temp_col, mock_new_col] + + with patch("mempalace.backends.chroma.ChromaBackend", return_value=mock_backend): + cmd_repair(args) + + out = capsys.readouterr().out + assert "Repair complete" in out + mock_backend.get_collection.assert_called_once_with(str(palace_dir), "custom_drawers") + assert mock_backend.create_collection.call_args_list == [ + call(str(palace_dir), "custom_drawers__repair_tmp"), + call(str(palace_dir), "custom_drawers"), + ] + assert mock_backend.delete_collection.call_args_list == [ + call(str(palace_dir), "custom_drawers__repair_tmp"), + call(str(palace_dir), "custom_drawers"), + call(str(palace_dir), "custom_drawers__repair_tmp"), + ] + + @patch("mempalace.cli.MempalaceConfig") def test_cmd_repair_restores_backup_on_live_rebuild_failure(mock_config_cls, tmp_path, capsys): palace_dir = tmp_path / "palace" palace_dir.mkdir() (palace_dir / "chroma.sqlite3").write_text("db") mock_config_cls.return_value.palace_path = str(palace_dir) + mock_config_cls.return_value.collection_name = "mempalace_drawers" args = argparse.Namespace(palace=None, yes=True) mock_col = MagicMock() mock_col.count.return_value = 2 @@ -875,6 +918,7 @@ def test_cmd_repair_aborts_without_confirmation(mock_config_cls, tmp_path, capsy palace_dir.mkdir() (palace_dir / "chroma.sqlite3").write_text("db") mock_config_cls.return_value.palace_path = str(palace_dir) + mock_config_cls.return_value.collection_name = "mempalace_drawers" args = argparse.Namespace(palace=None) mock_col = MagicMock() mock_col.count.return_value = 1 diff --git a/tests/test_hnsw_capacity.py b/tests/test_hnsw_capacity.py index 912def8..53775b0 100644 --- a/tests/test_hnsw_capacity.py +++ b/tests/test_hnsw_capacity.py @@ -260,6 +260,7 @@ def test_mcp_probe_does_not_disable_vectors_for_unflushed_metadata(tmp_path, mon class _Cfg: palace_path = str(tmp_path) + collection_name = "mempalace_drawers" monkeypatch.setattr(mcp_server, "_config", _Cfg()) monkeypatch.setattr(mcp_server, "_vector_disabled", True) @@ -625,6 +626,7 @@ def test_tool_status_via_sqlite_returns_breakdown(palace_with_drawers, monkeypat # MempalaceConfig. class _Cfg: palace_path = str(palace_with_drawers) + collection_name = "mempalace_drawers" monkeypatch.setattr(mcp_server, "_config", _Cfg()) monkeypatch.setattr(mcp_server, "_vector_disabled", True) diff --git a/tests/test_mcp_server.py b/tests/test_mcp_server.py index ae20bf3..1f47192 100644 --- a/tests/test_mcp_server.py +++ b/tests/test_mcp_server.py @@ -484,6 +484,26 @@ class TestWriteTools: assert result2["success"] is True assert result2["reason"] == "already_exists" + def test_add_drawer_fails_when_readback_misses(self, monkeypatch, config, kg): + _patch_mcp_server(monkeypatch, config, kg) + from mempalace import mcp_server + + class _FakeGetResult: + ids = [] + + class _FakeCol: + def get(self, **kwargs): + return _FakeGetResult() + + def upsert(self, **kwargs): + return None + + monkeypatch.setattr(mcp_server, "_get_collection", lambda create=False: _FakeCol()) + + result = mcp_server.tool_add_drawer("w", "r", "content") + assert result["success"] is False + assert "not readable" in result["error"] + def test_add_drawer_shared_header_no_collision(self, monkeypatch, config, palace_path, kg): """Documents sharing a >100-char header must get distinct IDs (full-content hash).""" _patch_mcp_server(monkeypatch, config, kg) @@ -1158,6 +1178,25 @@ class TestCacheInvalidation: assert "Reconnected" in result["message"] assert isinstance(result["drawers"], int) + def test_reconnect_closes_shared_backend(self, monkeypatch, config, kg): + _patch_mcp_server(monkeypatch, config, kg) + from unittest.mock import MagicMock + + from mempalace import mcp_server, palace + + close_palace = MagicMock() + monkeypatch.setattr(palace._DEFAULT_BACKEND, "close_palace", close_palace) + + class _FakeCol: + def count(self): + return 7 + + monkeypatch.setattr(mcp_server, "_get_collection", lambda create=False: _FakeCol()) + + result = mcp_server.tool_reconnect() + assert result["success"] is True + close_palace.assert_called_once_with(config.palace_path) + def test_get_collection_create_true_avoids_get_or_create_on_reopen( self, monkeypatch, config, palace_path, kg ): diff --git a/tests/test_repair.py b/tests/test_repair.py index a60836a..8e9f95b 100644 --- a/tests/test_repair.py +++ b/tests/test_repair.py @@ -28,6 +28,16 @@ def test_get_palace_path_fallback(): assert ".mempalace" in result +def test_get_collection_name_from_config(): + from mempalace.config import get_configured_collection_name + + get_configured_collection_name.cache_clear() + with patch("mempalace.config.MempalaceConfig") as mock_config_cls: + mock_config_cls.return_value.collection_name = "custom_drawers" + assert repair._drawers_collection_name() == "custom_drawers" + get_configured_collection_name.cache_clear() + + # ── _paginate_ids ───────────────────────────────────────────────────── @@ -330,6 +340,21 @@ def test_check_extraction_safety_passes_when_counts_match(tmp_path): repair.check_extraction_safety(str(tmp_path), 500) +def test_check_extraction_safety_uses_configured_collection(tmp_path): + with patch("mempalace.repair.sqlite_drawer_count", return_value=500) as count: + repair.check_extraction_safety(str(tmp_path), 500, collection_name="custom_drawers") + count.assert_called_once_with(str(tmp_path), "custom_drawers") + + +def test_check_extraction_safety_default_uses_configured_collection(tmp_path): + with ( + patch("mempalace.repair._drawers_collection_name", return_value="custom_drawers"), + patch("mempalace.repair.sqlite_drawer_count", return_value=500) as count, + ): + repair.check_extraction_safety(str(tmp_path), 500) + count.assert_called_once_with(str(tmp_path), "custom_drawers") + + def test_check_extraction_safety_passes_when_sqlite_unreadable_and_under_cap(tmp_path): """SQLite check fails (None) but extraction is well under the cap → safe.""" with patch("mempalace.repair.sqlite_drawer_count", return_value=None): @@ -384,6 +409,73 @@ def test_sqlite_drawer_count_returns_none_on_unreadable_schema(tmp_path): assert repair.sqlite_drawer_count(str(tmp_path)) is None +@patch("mempalace.repair.shutil") +@patch("mempalace.repair.ChromaBackend") +def test_rebuild_index_default_uses_configured_collection(mock_backend_cls, mock_shutil, tmp_path): + sqlite_path = tmp_path / "chroma.sqlite3" + sqlite_path.write_text("fake") + mock_col = MagicMock() + mock_col.count.return_value = 2 + mock_col.get.return_value = { + "ids": ["id1", "id2"], + "documents": ["doc1", "doc2"], + "metadatas": [{"wing": "a"}, {"wing": "b"}], + } + mock_temp_col = MagicMock() + mock_temp_col.count.return_value = 2 + mock_new_col = MagicMock() + mock_new_col.count.return_value = 2 + mock_backend = _install_mock_backend(mock_backend_cls, mock_col) + mock_backend.create_collection.side_effect = [mock_temp_col, mock_new_col] + + with ( + patch("mempalace.repair._drawers_collection_name", return_value="custom_drawers"), + patch("mempalace.repair.sqlite_drawer_count", return_value=2) as count, + ): + repair.rebuild_index(palace_path=str(tmp_path)) + + mock_backend.get_collection.assert_called_once_with(str(tmp_path), "custom_drawers") + count.assert_called_once_with(str(tmp_path), "custom_drawers") + assert mock_backend.create_collection.call_args_list == [ + call(str(tmp_path), "custom_drawers__repair_tmp"), + call(str(tmp_path), "custom_drawers"), + ] + assert mock_backend.delete_collection.call_args_list == [ + call(str(tmp_path), "custom_drawers__repair_tmp"), + call(str(tmp_path), "custom_drawers"), + call(str(tmp_path), "custom_drawers__repair_tmp"), + ] + + +def test_status_default_uses_configured_drawer_collection(tmp_path): + with ( + patch("mempalace.repair._drawers_collection_name", return_value="custom_drawers"), + patch("mempalace.repair.hnsw_capacity_status") as capacity_status, + ): + capacity_status.side_effect = [ + { + "sqlite_count": 1, + "hnsw_count": 1, + "divergence": 0, + "diverged": False, + "status": "ok", + "message": "", + }, + { + "sqlite_count": 0, + "hnsw_count": 0, + "divergence": 0, + "diverged": False, + "status": "ok", + "message": "", + }, + ] + repair.status(palace_path=str(tmp_path)) + + assert capacity_status.call_args_list[0].args == (str(tmp_path), "custom_drawers") + assert capacity_status.call_args_list[1].args == (str(tmp_path), "mempalace_closets") + + @patch("mempalace.repair.shutil") @patch("mempalace.repair.ChromaBackend") def test_rebuild_index_aborts_on_truncation_signal(mock_backend_cls, mock_shutil, tmp_path): diff --git a/tests/test_searcher.py b/tests/test_searcher.py index 4f0b4c0..f4d46a0 100644 --- a/tests/test_searcher.py +++ b/tests/test_searcher.py @@ -84,6 +84,24 @@ class TestSearchMemories: assert "error" in result assert "query failed" in result["error"] + def test_search_memories_vector_path_uses_explicit_collection_name(self): + mock_col = MagicMock() + mock_col.query.return_value = { + "documents": [[]], + "metadatas": [[]], + "distances": [[]], + "ids": [[]], + } + + with patch("mempalace.searcher.get_collection", return_value=mock_col) as get_collection: + search_memories("test", "/fake/path", collection_name="custom_drawers") + + get_collection.assert_called_once_with( + "/fake/path", + collection_name="custom_drawers", + create=False, + ) + def test_search_memories_filters_in_result(self, palace_path, seeded_collection): result = search_memories("test", palace_path, wing="project", room="backend") assert result["filters"]["wing"] == "project" @@ -102,7 +120,7 @@ class TestSearchMemories: "ids": [["d1", "d2"]], } - def mock_get_collection(path, create=False): + def mock_get_collection(path, collection_name=None, create=False): # First call: drawers. Second call: closets — raise so hybrid # degrades to pure drawer search (the catch block covers it). if not hasattr(mock_get_collection, "_called"):