diff --git a/mempalace/backends/__init__.py b/mempalace/backends/__init__.py new file mode 100644 index 0000000..cb5f14d --- /dev/null +++ b/mempalace/backends/__init__.py @@ -0,0 +1,6 @@ +"""Storage backend implementations for MemPalace.""" + +from .base import BaseCollection +from .chroma import ChromaBackend, ChromaCollection + +__all__ = ["BaseCollection", "ChromaBackend", "ChromaCollection"] diff --git a/mempalace/backends/base.py b/mempalace/backends/base.py new file mode 100644 index 0000000..4685f51 --- /dev/null +++ b/mempalace/backends/base.py @@ -0,0 +1,44 @@ +"""Abstract collection interface for MemPalace storage backends.""" + +from abc import ABC, abstractmethod +from typing import Any, Dict, List, Optional + + +class BaseCollection(ABC): + """Smallest collection contract the rest of MemPalace relies on.""" + + @abstractmethod + def add( + self, + *, + documents: List[str], + ids: List[str], + metadatas: Optional[List[Dict[str, Any]]] = None, + ) -> None: + raise NotImplementedError + + @abstractmethod + def upsert( + self, + *, + documents: List[str], + ids: List[str], + metadatas: Optional[List[Dict[str, Any]]] = None, + ) -> None: + raise NotImplementedError + + @abstractmethod + def query(self, **kwargs: Any) -> Dict[str, Any]: + raise NotImplementedError + + @abstractmethod + def get(self, **kwargs: Any) -> Dict[str, Any]: + raise NotImplementedError + + @abstractmethod + def delete(self, **kwargs: Any) -> None: + raise NotImplementedError + + @abstractmethod + def count(self) -> int: + raise NotImplementedError diff --git a/mempalace/backends/chroma.py b/mempalace/backends/chroma.py new file mode 100644 index 0000000..0ac7501 --- /dev/null +++ b/mempalace/backends/chroma.py @@ -0,0 +1,54 @@ +"""ChromaDB-backed MemPalace collection adapter.""" + +import os + +import chromadb + +from .base import BaseCollection + + +class ChromaCollection(BaseCollection): + """Thin adapter over a ChromaDB collection.""" + + def __init__(self, collection): + self._collection = collection + + def add(self, *, documents, ids, metadatas=None): + self._collection.add(documents=documents, ids=ids, metadatas=metadatas) + + def upsert(self, *, documents, ids, metadatas=None): + self._collection.upsert(documents=documents, ids=ids, metadatas=metadatas) + + def query(self, **kwargs): + return self._collection.query(**kwargs) + + def get(self, **kwargs): + return self._collection.get(**kwargs) + + def delete(self, **kwargs): + self._collection.delete(**kwargs) + + def count(self): + return self._collection.count() + + +class ChromaBackend: + """Factory for MemPalace's default ChromaDB backend.""" + + def get_collection(self, palace_path: str, collection_name: str, create: bool = False): + if not create and not os.path.isdir(palace_path): + raise FileNotFoundError(palace_path) + + if create: + os.makedirs(palace_path, exist_ok=True) + try: + os.chmod(palace_path, 0o700) + except (OSError, NotImplementedError): + pass + + client = chromadb.PersistentClient(path=palace_path) + if create: + collection = client.get_or_create_collection(collection_name) + else: + collection = client.get_collection(collection_name) + return ChromaCollection(collection) diff --git a/mempalace/layers.py b/mempalace/layers.py index 6abb99b..651818d 100644 --- a/mempalace/layers.py +++ b/mempalace/layers.py @@ -21,9 +21,8 @@ import sys from pathlib import Path from collections import defaultdict -import chromadb - from .config import MempalaceConfig +from .palace import get_collection as _get_collection # --------------------------------------------------------------------------- @@ -91,8 +90,7 @@ class Layer1: def generate(self) -> str: """Pull top drawers from ChromaDB and format as compact L1 text.""" try: - client = chromadb.PersistentClient(path=self.palace_path) - col = client.get_collection("mempalace_drawers") + col = _get_collection(self.palace_path, create=False) except Exception: return "## L1 — No palace found. Run: mempalace mine " @@ -196,8 +194,7 @@ class Layer2: def retrieve(self, wing: str = None, room: str = None, n_results: int = 10) -> str: """Retrieve drawers filtered by wing and/or room.""" try: - client = chromadb.PersistentClient(path=self.palace_path) - col = client.get_collection("mempalace_drawers") + col = _get_collection(self.palace_path, create=False) except Exception: return "No palace found." @@ -260,8 +257,7 @@ class Layer3: def search(self, query: str, wing: str = None, room: str = None, n_results: int = 5) -> str: """Semantic search, returns compact result text.""" try: - client = chromadb.PersistentClient(path=self.palace_path) - col = client.get_collection("mempalace_drawers") + col = _get_collection(self.palace_path, create=False) except Exception: return "No palace found." @@ -316,8 +312,7 @@ class Layer3: ) -> list: """Return raw dicts instead of formatted text.""" try: - client = chromadb.PersistentClient(path=self.palace_path) - col = client.get_collection("mempalace_drawers") + col = _get_collection(self.palace_path, create=False) except Exception: return [] @@ -437,8 +432,7 @@ class MemoryStack: # Count drawers try: - client = chromadb.PersistentClient(path=self.palace_path) - col = client.get_collection("mempalace_drawers") + col = _get_collection(self.palace_path, create=False) count = col.count() result["total_drawers"] = count except Exception: diff --git a/mempalace/mcp_server.py b/mempalace/mcp_server.py index fce6f35..bf8a281 100644 --- a/mempalace/mcp_server.py +++ b/mempalace/mcp_server.py @@ -28,10 +28,10 @@ from pathlib import Path from .config import MempalaceConfig, sanitize_name, sanitize_content from .version import __version__ +from .palace import get_collection as _get_collection_from_palace from .query_sanitizer import sanitize_query from .searcher import search_memories from .palace_graph import traverse, find_tunnels, graph_stats -import chromadb from .knowledge_graph import KnowledgeGraph @@ -64,7 +64,6 @@ else: _kg = KnowledgeGraph() -_client_cache = None _collection_cache = None @@ -101,27 +100,16 @@ def _wal_log(operation: str, params: dict, result: dict = None): logger.error(f"WAL write failed: {e}") -_client_cache = None -_collection_cache = None - - -def _get_client(): - """Return a singleton ChromaDB PersistentClient.""" - global _client_cache - if _client_cache is None: - _client_cache = chromadb.PersistentClient(path=_config.palace_path) - return _client_cache - - def _get_collection(create=False): - """Return the ChromaDB collection, caching the client between calls.""" + """Return the configured collection, caching the wrapper between calls.""" global _collection_cache try: - client = _get_client() - if create: - _collection_cache = client.get_or_create_collection(_config.collection_name) - elif _collection_cache is None: - _collection_cache = client.get_collection(_config.collection_name) + if create or _collection_cache is None: + _collection_cache = _get_collection_from_palace( + _config.palace_path, + collection_name=_config.collection_name, + create=create, + ) return _collection_cache except Exception: return None diff --git a/mempalace/miner.py b/mempalace/miner.py index f342a2d..112de68 100644 --- a/mempalace/miner.py +++ b/mempalace/miner.py @@ -15,8 +15,6 @@ from pathlib import Path from datetime import datetime from collections import defaultdict -import chromadb - from .palace import SKIP_DIRS, get_collection, file_already_mined READABLE_EXTENSIONS = { @@ -625,8 +623,7 @@ def mine( def status(palace_path: str): """Show what's been filed in the palace.""" try: - client = chromadb.PersistentClient(path=palace_path) - col = client.get_collection("mempalace_drawers") + col = get_collection(palace_path, create=False) except Exception: print(f"\n No palace found at {palace_path}") print(" Run: mempalace init then mempalace mine ") diff --git a/mempalace/palace.py b/mempalace/palace.py index 6ddf190..f01a64a 100644 --- a/mempalace/palace.py +++ b/mempalace/palace.py @@ -1,11 +1,12 @@ """ palace.py — Shared palace operations. -Consolidates ChromaDB access patterns used by both miners and the MCP server. +Consolidates collection access patterns used by both miners and the MCP server. """ import os -import chromadb + +from .backends.chroma import ChromaBackend SKIP_DIRS = { ".git", @@ -33,19 +34,20 @@ SKIP_DIRS = { "target", } +_DEFAULT_BACKEND = ChromaBackend() -def get_collection(palace_path: str, collection_name: str = "mempalace_drawers"): - """Get or create the palace ChromaDB collection.""" - os.makedirs(palace_path, exist_ok=True) - try: - os.chmod(palace_path, 0o700) - except (OSError, NotImplementedError): - pass - client = chromadb.PersistentClient(path=palace_path) - try: - return client.get_collection(collection_name) - except Exception: - return client.create_collection(collection_name) + +def get_collection( + palace_path: str, + collection_name: str = "mempalace_drawers", + create: bool = True, +): + """Get the palace collection through the backend layer.""" + return _DEFAULT_BACKEND.get_collection( + palace_path, + collection_name=collection_name, + create=create, + ) def file_already_mined(collection, source_file: str, check_mtime: bool = False) -> bool: diff --git a/mempalace/palace_graph.py b/mempalace/palace_graph.py index e4fda93..5e2e72e 100644 --- a/mempalace/palace_graph.py +++ b/mempalace/palace_graph.py @@ -16,16 +16,19 @@ No external graph DB needed — built from ChromaDB metadata. """ from collections import defaultdict, Counter -from .config import MempalaceConfig -import chromadb +from .config import MempalaceConfig +from .palace import get_collection as _get_palace_collection def _get_collection(config=None): config = config or MempalaceConfig() try: - client = chromadb.PersistentClient(path=config.palace_path) - return client.get_collection(config.collection_name) + return _get_palace_collection( + config.palace_path, + collection_name=config.collection_name, + create=False, + ) except Exception: return None diff --git a/mempalace/searcher.py b/mempalace/searcher.py index 163abd8..307820f 100644 --- a/mempalace/searcher.py +++ b/mempalace/searcher.py @@ -9,7 +9,7 @@ Returns verbatim text — the actual words, never summaries. import logging from pathlib import Path -import chromadb +from .palace import get_collection logger = logging.getLogger("mempalace_mcp") @@ -24,8 +24,7 @@ def search(query: str, palace_path: str, wing: str = None, room: str = None, n_r Optionally filter by wing (project) or room (aspect). """ try: - client = chromadb.PersistentClient(path=palace_path) - col = client.get_collection("mempalace_drawers") + col = get_collection(palace_path, create=False) except Exception: print(f"\n No palace found at {palace_path}") print(" Run: mempalace init then mempalace mine ") @@ -98,8 +97,7 @@ def search_memories( Used by the MCP server and other callers that need data. """ try: - client = chromadb.PersistentClient(path=palace_path) - col = client.get_collection("mempalace_drawers") + col = get_collection(palace_path, create=False) except Exception as e: logger.error("No palace found at %s: %s", palace_path, e) return { diff --git a/tests/test_backends.py b/tests/test_backends.py new file mode 100644 index 0000000..4a97ef1 --- /dev/null +++ b/tests/test_backends.py @@ -0,0 +1,80 @@ +import chromadb +import pytest + +from mempalace.backends.chroma import ChromaBackend, ChromaCollection + + +class _FakeCollection: + def __init__(self): + self.calls = [] + + def add(self, **kwargs): + self.calls.append(("add", kwargs)) + + def upsert(self, **kwargs): + self.calls.append(("upsert", kwargs)) + + def query(self, **kwargs): + self.calls.append(("query", kwargs)) + return {"kind": "query"} + + def get(self, **kwargs): + self.calls.append(("get", kwargs)) + return {"kind": "get"} + + def delete(self, **kwargs): + self.calls.append(("delete", kwargs)) + + def count(self): + self.calls.append(("count", {})) + return 7 + + +def test_chroma_collection_delegates_methods(): + fake = _FakeCollection() + collection = ChromaCollection(fake) + + collection.add(documents=["d"], ids=["1"], metadatas=[{"wing": "w"}]) + collection.upsert(documents=["u"], ids=["2"], metadatas=[{"room": "r"}]) + assert collection.query(query_texts=["q"]) == {"kind": "query"} + assert collection.get(where={"wing": "w"}) == {"kind": "get"} + collection.delete(ids=["1"]) + assert collection.count() == 7 + + assert fake.calls == [ + ("add", {"documents": ["d"], "ids": ["1"], "metadatas": [{"wing": "w"}]}), + ("upsert", {"documents": ["u"], "ids": ["2"], "metadatas": [{"room": "r"}]}), + ("query", {"query_texts": ["q"]}), + ("get", {"where": {"wing": "w"}}), + ("delete", {"ids": ["1"]}), + ("count", {}), + ] + + +def test_chroma_backend_create_false_raises_without_creating_directory(tmp_path): + palace_path = tmp_path / "missing-palace" + + with pytest.raises(FileNotFoundError): + ChromaBackend().get_collection( + str(palace_path), + collection_name="mempalace_drawers", + create=False, + ) + + assert not palace_path.exists() + + +def test_chroma_backend_create_true_creates_directory_and_collection(tmp_path): + palace_path = tmp_path / "palace" + + collection = ChromaBackend().get_collection( + str(palace_path), + collection_name="mempalace_drawers", + create=True, + ) + + assert palace_path.is_dir() + assert isinstance(collection, ChromaCollection) + + client = chromadb.PersistentClient(path=str(palace_path)) + client.get_collection("mempalace_drawers") diff --git a/tests/test_layers.py b/tests/test_layers.py index 46b60e9..575183f 100644 --- a/tests/test_layers.py +++ b/tests/test_layers.py @@ -71,16 +71,14 @@ def test_layer0_default_path(): def _mock_chromadb_for_layer(docs, metas, monkeypatch=None): - """Return a mock PersistentClient whose collection.get returns docs/metas.""" + """Return a mock collection whose get() returns docs/metas.""" mock_col = MagicMock() # First batch returns data, second batch returns empty (end of pagination) mock_col.get.side_effect = [ {"documents": docs, "metadatas": metas}, {"documents": [], "metadatas": []}, ] - mock_client = MagicMock() - mock_client.get_collection.return_value = mock_col - return mock_client + return mock_col def test_layer1_no_palace(): @@ -101,11 +99,11 @@ def test_layer1_generates_essential_story(): {"room": "decisions", "source_file": "meeting.txt", "importance": 5}, {"room": "architecture", "source_file": "design.txt", "importance": 4}, ] - mock_client = _mock_chromadb_for_layer(docs, metas) + mock_col = _mock_chromadb_for_layer(docs, metas) with ( patch("mempalace.layers.MempalaceConfig") as mock_cfg, - patch("mempalace.layers.chromadb.PersistentClient", return_value=mock_client), + patch("mempalace.layers._get_collection", return_value=mock_col), ): mock_cfg.return_value.palace_path = "/fake" layer = Layer1(palace_path="/fake") @@ -118,12 +116,9 @@ def test_layer1_generates_essential_story(): def test_layer1_empty_palace(): mock_col = MagicMock() mock_col.get.return_value = {"documents": [], "metadatas": []} - mock_client = MagicMock() - mock_client.get_collection.return_value = mock_col - with ( patch("mempalace.layers.MempalaceConfig") as mock_cfg, - patch("mempalace.layers.chromadb.PersistentClient", return_value=mock_client), + patch("mempalace.layers._get_collection", return_value=mock_col), ): mock_cfg.return_value.palace_path = "/fake" layer = Layer1(palace_path="/fake") @@ -135,11 +130,11 @@ def test_layer1_empty_palace(): def test_layer1_with_wing_filter(): docs = ["Memory about project X"] metas = [{"room": "general", "source_file": "x.txt", "importance": 3}] - mock_client = _mock_chromadb_for_layer(docs, metas) + mock_col = _mock_chromadb_for_layer(docs, metas) with ( patch("mempalace.layers.MempalaceConfig") as mock_cfg, - patch("mempalace.layers.chromadb.PersistentClient", return_value=mock_client), + patch("mempalace.layers._get_collection", return_value=mock_col), ): mock_cfg.return_value.palace_path = "/fake" layer = Layer1(palace_path="/fake", wing="project_x") @@ -147,18 +142,18 @@ def test_layer1_with_wing_filter(): assert "ESSENTIAL STORY" in result # Verify wing filter was passed - call_kwargs = mock_client.get_collection.return_value.get.call_args_list[0][1] + call_kwargs = mock_col.get.call_args_list[0][1] assert call_kwargs.get("where") == {"wing": "project_x"} def test_layer1_truncates_long_snippets(): docs = ["A" * 300] metas = [{"room": "general", "source_file": "long.txt"}] - mock_client = _mock_chromadb_for_layer(docs, metas) + mock_col = _mock_chromadb_for_layer(docs, metas) with ( patch("mempalace.layers.MempalaceConfig") as mock_cfg, - patch("mempalace.layers.chromadb.PersistentClient", return_value=mock_client), + patch("mempalace.layers._get_collection", return_value=mock_col), ): mock_cfg.return_value.palace_path = "/fake" layer = Layer1(palace_path="/fake") @@ -171,11 +166,11 @@ def test_layer1_respects_max_chars(): """L1 stops adding entries once MAX_CHARS is reached.""" docs = [f"Memory number {i} with substantial content padding here" for i in range(30)] metas = [{"room": "general", "source_file": f"f{i}.txt", "importance": 5} for i in range(30)] - mock_client = _mock_chromadb_for_layer(docs, metas) + mock_col = _mock_chromadb_for_layer(docs, metas) with ( patch("mempalace.layers.MempalaceConfig") as mock_cfg, - patch("mempalace.layers.chromadb.PersistentClient", return_value=mock_client), + patch("mempalace.layers._get_collection", return_value=mock_col), ): mock_cfg.return_value.palace_path = "/fake" layer = Layer1(palace_path="/fake") @@ -193,11 +188,11 @@ def test_layer1_importance_from_various_keys(): {"room": "r", "weight": 1}, {"room": "r"}, # no weight key, defaults to 3 ] - mock_client = _mock_chromadb_for_layer(docs, metas) + mock_col = _mock_chromadb_for_layer(docs, metas) with ( patch("mempalace.layers.MempalaceConfig") as mock_cfg, - patch("mempalace.layers.chromadb.PersistentClient", return_value=mock_client), + patch("mempalace.layers._get_collection", return_value=mock_col), ): mock_cfg.return_value.palace_path = "/fake" layer = Layer1(palace_path="/fake") @@ -213,12 +208,9 @@ def test_layer1_batch_exception_breaks(): {"documents": ["doc1"], "metadatas": [{"room": "r"}]}, RuntimeError("batch error"), ] - mock_client = MagicMock() - mock_client.get_collection.return_value = mock_col - with ( patch("mempalace.layers.MempalaceConfig") as mock_cfg, - patch("mempalace.layers.chromadb.PersistentClient", return_value=mock_client), + patch("mempalace.layers._get_collection", return_value=mock_col), ): mock_cfg.return_value.palace_path = "/fake" layer = Layer1(palace_path="/fake") @@ -244,12 +236,9 @@ def test_layer2_retrieve_with_wing(): "documents": ["Some memory about the project"], "metadatas": [{"room": "backend", "source_file": "notes.txt"}], } - mock_client = MagicMock() - mock_client.get_collection.return_value = mock_col - with ( patch("mempalace.layers.MempalaceConfig") as mock_cfg, - patch("mempalace.layers.chromadb.PersistentClient", return_value=mock_client), + patch("mempalace.layers._get_collection", return_value=mock_col), ): mock_cfg.return_value.palace_path = "/fake" layer = Layer2(palace_path="/fake") @@ -265,12 +254,9 @@ def test_layer2_retrieve_with_room(): "documents": ["Backend architecture notes"], "metadatas": [{"room": "architecture", "source_file": "arch.txt"}], } - mock_client = MagicMock() - mock_client.get_collection.return_value = mock_col - with ( patch("mempalace.layers.MempalaceConfig") as mock_cfg, - patch("mempalace.layers.chromadb.PersistentClient", return_value=mock_client), + patch("mempalace.layers._get_collection", return_value=mock_col), ): mock_cfg.return_value.palace_path = "/fake" layer = Layer2(palace_path="/fake") @@ -285,12 +271,9 @@ def test_layer2_retrieve_wing_and_room(): "documents": ["Filtered result"], "metadatas": [{"room": "backend", "source_file": "x.txt"}], } - mock_client = MagicMock() - mock_client.get_collection.return_value = mock_col - with ( patch("mempalace.layers.MempalaceConfig") as mock_cfg, - patch("mempalace.layers.chromadb.PersistentClient", return_value=mock_client), + patch("mempalace.layers._get_collection", return_value=mock_col), ): mock_cfg.return_value.palace_path = "/fake" layer = Layer2(palace_path="/fake") @@ -304,12 +287,9 @@ def test_layer2_retrieve_wing_and_room(): def test_layer2_retrieve_empty(): mock_col = MagicMock() mock_col.get.return_value = {"documents": [], "metadatas": []} - mock_client = MagicMock() - mock_client.get_collection.return_value = mock_col - with ( patch("mempalace.layers.MempalaceConfig") as mock_cfg, - patch("mempalace.layers.chromadb.PersistentClient", return_value=mock_client), + patch("mempalace.layers._get_collection", return_value=mock_col), ): mock_cfg.return_value.palace_path = "/fake" layer = Layer2(palace_path="/fake") @@ -321,12 +301,9 @@ def test_layer2_retrieve_empty(): def test_layer2_retrieve_no_filter(): mock_col = MagicMock() mock_col.get.return_value = {"documents": [], "metadatas": []} - mock_client = MagicMock() - mock_client.get_collection.return_value = mock_col - with ( patch("mempalace.layers.MempalaceConfig") as mock_cfg, - patch("mempalace.layers.chromadb.PersistentClient", return_value=mock_client), + patch("mempalace.layers._get_collection", return_value=mock_col), ): mock_cfg.return_value.palace_path = "/fake" layer = Layer2(palace_path="/fake") @@ -340,12 +317,9 @@ def test_layer2_retrieve_no_filter(): def test_layer2_retrieve_error(): mock_col = MagicMock() mock_col.get.side_effect = RuntimeError("db error") - mock_client = MagicMock() - mock_client.get_collection.return_value = mock_col - with ( patch("mempalace.layers.MempalaceConfig") as mock_cfg, - patch("mempalace.layers.chromadb.PersistentClient", return_value=mock_client), + patch("mempalace.layers._get_collection", return_value=mock_col), ): mock_cfg.return_value.palace_path = "/fake" layer = Layer2(palace_path="/fake") @@ -360,12 +334,9 @@ def test_layer2_truncates_long_snippets(): "documents": ["B" * 400], "metadatas": [{"room": "r", "source_file": "s.txt"}], } - mock_client = MagicMock() - mock_client.get_collection.return_value = mock_col - with ( patch("mempalace.layers.MempalaceConfig") as mock_cfg, - patch("mempalace.layers.chromadb.PersistentClient", return_value=mock_client), + patch("mempalace.layers._get_collection", return_value=mock_col), ): mock_cfg.return_value.palace_path = "/fake" layer = Layer2(palace_path="/fake") @@ -408,12 +379,9 @@ def test_layer3_search_with_results(): [{"wing": "project", "room": "backend", "source_file": "notes.txt"}], [0.2], ) - mock_client = MagicMock() - mock_client.get_collection.return_value = mock_col - with ( patch("mempalace.layers.MempalaceConfig") as mock_cfg, - patch("mempalace.layers.chromadb.PersistentClient", return_value=mock_client), + patch("mempalace.layers._get_collection", return_value=mock_col), ): mock_cfg.return_value.palace_path = "/fake" layer = Layer3(palace_path="/fake") @@ -427,12 +395,9 @@ def test_layer3_search_with_results(): def test_layer3_search_no_results(): mock_col = MagicMock() mock_col.query.return_value = _mock_query_results([], [], []) - mock_client = MagicMock() - mock_client.get_collection.return_value = mock_col - with ( patch("mempalace.layers.MempalaceConfig") as mock_cfg, - patch("mempalace.layers.chromadb.PersistentClient", return_value=mock_client), + patch("mempalace.layers._get_collection", return_value=mock_col), ): mock_cfg.return_value.palace_path = "/fake" layer = Layer3(palace_path="/fake") @@ -448,12 +413,9 @@ def test_layer3_search_with_wing_filter(): [{"wing": "proj", "room": "r"}], [0.1], ) - mock_client = MagicMock() - mock_client.get_collection.return_value = mock_col - with ( patch("mempalace.layers.MempalaceConfig") as mock_cfg, - patch("mempalace.layers.chromadb.PersistentClient", return_value=mock_client), + patch("mempalace.layers._get_collection", return_value=mock_col), ): mock_cfg.return_value.palace_path = "/fake" layer = Layer3(palace_path="/fake") @@ -470,12 +432,9 @@ def test_layer3_search_with_room_filter(): [{"wing": "w", "room": "backend"}], [0.1], ) - mock_client = MagicMock() - mock_client.get_collection.return_value = mock_col - with ( patch("mempalace.layers.MempalaceConfig") as mock_cfg, - patch("mempalace.layers.chromadb.PersistentClient", return_value=mock_client), + patch("mempalace.layers._get_collection", return_value=mock_col), ): mock_cfg.return_value.palace_path = "/fake" layer = Layer3(palace_path="/fake") @@ -492,12 +451,9 @@ def test_layer3_search_with_wing_and_room(): [{"wing": "proj", "room": "backend"}], [0.1], ) - mock_client = MagicMock() - mock_client.get_collection.return_value = mock_col - with ( patch("mempalace.layers.MempalaceConfig") as mock_cfg, - patch("mempalace.layers.chromadb.PersistentClient", return_value=mock_client), + patch("mempalace.layers._get_collection", return_value=mock_col), ): mock_cfg.return_value.palace_path = "/fake" layer = Layer3(palace_path="/fake") @@ -510,12 +466,9 @@ def test_layer3_search_with_wing_and_room(): def test_layer3_search_error(): mock_col = MagicMock() mock_col.query.side_effect = RuntimeError("search failed") - mock_client = MagicMock() - mock_client.get_collection.return_value = mock_col - with ( patch("mempalace.layers.MempalaceConfig") as mock_cfg, - patch("mempalace.layers.chromadb.PersistentClient", return_value=mock_client), + patch("mempalace.layers._get_collection", return_value=mock_col), ): mock_cfg.return_value.palace_path = "/fake" layer = Layer3(palace_path="/fake") @@ -531,12 +484,9 @@ def test_layer3_search_truncates_long_docs(): [{"wing": "w", "room": "r", "source_file": "s.txt"}], [0.1], ) - mock_client = MagicMock() - mock_client.get_collection.return_value = mock_col - with ( patch("mempalace.layers.MempalaceConfig") as mock_cfg, - patch("mempalace.layers.chromadb.PersistentClient", return_value=mock_client), + patch("mempalace.layers._get_collection", return_value=mock_col), ): mock_cfg.return_value.palace_path = "/fake" layer = Layer3(palace_path="/fake") @@ -552,12 +502,9 @@ def test_layer3_search_raw_returns_dicts(): [{"wing": "proj", "room": "backend", "source_file": "f.txt"}], [0.3], ) - mock_client = MagicMock() - mock_client.get_collection.return_value = mock_col - with ( patch("mempalace.layers.MempalaceConfig") as mock_cfg, - patch("mempalace.layers.chromadb.PersistentClient", return_value=mock_client), + patch("mempalace.layers._get_collection", return_value=mock_col), ): mock_cfg.return_value.palace_path = "/fake" layer = Layer3(palace_path="/fake") @@ -577,12 +524,9 @@ def test_layer3_search_raw_with_filters(): [{"wing": "w", "room": "r"}], [0.1], ) - mock_client = MagicMock() - mock_client.get_collection.return_value = mock_col - with ( patch("mempalace.layers.MempalaceConfig") as mock_cfg, - patch("mempalace.layers.chromadb.PersistentClient", return_value=mock_client), + patch("mempalace.layers._get_collection", return_value=mock_col), ): mock_cfg.return_value.palace_path = "/fake" layer = Layer3(palace_path="/fake") @@ -595,12 +539,9 @@ def test_layer3_search_raw_with_filters(): def test_layer3_search_raw_error(): mock_col = MagicMock() mock_col.query.side_effect = RuntimeError("fail") - mock_client = MagicMock() - mock_client.get_collection.return_value = mock_col - with ( patch("mempalace.layers.MempalaceConfig") as mock_cfg, - patch("mempalace.layers.chromadb.PersistentClient", return_value=mock_client), + patch("mempalace.layers._get_collection", return_value=mock_col), ): mock_cfg.return_value.palace_path = "/fake" layer = Layer3(palace_path="/fake") @@ -701,12 +642,9 @@ def test_memory_stack_status_with_palace(tmp_path): mock_col = MagicMock() mock_col.count.return_value = 42 - mock_client = MagicMock() - mock_client.get_collection.return_value = mock_col - with ( patch("mempalace.layers.MempalaceConfig") as mock_cfg, - patch("mempalace.layers.chromadb.PersistentClient", return_value=mock_client), + patch("mempalace.layers._get_collection", return_value=mock_col), ): mock_cfg.return_value.palace_path = "/fake" stack = MemoryStack( diff --git a/tests/test_miner.py b/tests/test_miner.py index c013d7c..7851787 100644 --- a/tests/test_miner.py +++ b/tests/test_miner.py @@ -6,7 +6,7 @@ from pathlib import Path import chromadb import yaml -from mempalace.miner import mine, scan_project +from mempalace.miner import mine, scan_project, status from mempalace.palace import file_already_mined @@ -260,3 +260,13 @@ def test_file_already_mined_check_mtime(): # Release ChromaDB file handles before cleanup (required on Windows) del col, client shutil.rmtree(tmpdir, ignore_errors=True) + + +def test_status_missing_palace_does_not_create_empty_collection(tmp_path, capsys): + palace_path = tmp_path / "missing-palace" + + status(str(palace_path)) + + out = capsys.readouterr().out + assert "No palace found" in out + assert not palace_path.exists() diff --git a/tests/test_searcher.py b/tests/test_searcher.py index 94f22b4..244fbf3 100644 --- a/tests/test_searcher.py +++ b/tests/test_searcher.py @@ -56,10 +56,8 @@ class TestSearchMemories: """search_memories returns error dict when query raises.""" mock_col = MagicMock() mock_col.query.side_effect = RuntimeError("query failed") - mock_client = MagicMock() - mock_client.get_collection.return_value = mock_col - with patch("mempalace.searcher.chromadb.PersistentClient", return_value=mock_client): + with patch("mempalace.searcher.get_collection", return_value=mock_col): result = search_memories("test", "/fake/path") assert "error" in result assert "query failed" in result["error"] @@ -111,10 +109,8 @@ class TestSearchCLI: """search raises SearchError when query fails.""" mock_col = MagicMock() mock_col.query.side_effect = RuntimeError("boom") - mock_client = MagicMock() - mock_client.get_collection.return_value = mock_col - with patch("mempalace.searcher.chromadb.PersistentClient", return_value=mock_client): + with patch("mempalace.searcher.get_collection", return_value=mock_col): with pytest.raises(SearchError, match="Search error"): search("test", "/fake/path")