diff --git a/mempalace/backends/__init__.py b/mempalace/backends/__init__.py
new file mode 100644
index 0000000..cb5f14d
--- /dev/null
+++ b/mempalace/backends/__init__.py
@@ -0,0 +1,6 @@
+"""Storage backend implementations for MemPalace."""
+
+from .base import BaseCollection
+from .chroma import ChromaBackend, ChromaCollection
+
+__all__ = ["BaseCollection", "ChromaBackend", "ChromaCollection"]
diff --git a/mempalace/backends/base.py b/mempalace/backends/base.py
new file mode 100644
index 0000000..4685f51
--- /dev/null
+++ b/mempalace/backends/base.py
@@ -0,0 +1,44 @@
+"""Abstract collection interface for MemPalace storage backends."""
+
+from abc import ABC, abstractmethod
+from typing import Any, Dict, List, Optional
+
+
+class BaseCollection(ABC):
+ """Smallest collection contract the rest of MemPalace relies on."""
+
+ @abstractmethod
+ def add(
+ self,
+ *,
+ documents: List[str],
+ ids: List[str],
+ metadatas: Optional[List[Dict[str, Any]]] = None,
+ ) -> None:
+ raise NotImplementedError
+
+ @abstractmethod
+ def upsert(
+ self,
+ *,
+ documents: List[str],
+ ids: List[str],
+ metadatas: Optional[List[Dict[str, Any]]] = None,
+ ) -> None:
+ raise NotImplementedError
+
+ @abstractmethod
+ def query(self, **kwargs: Any) -> Dict[str, Any]:
+ raise NotImplementedError
+
+ @abstractmethod
+ def get(self, **kwargs: Any) -> Dict[str, Any]:
+ raise NotImplementedError
+
+ @abstractmethod
+ def delete(self, **kwargs: Any) -> None:
+ raise NotImplementedError
+
+ @abstractmethod
+ def count(self) -> int:
+ raise NotImplementedError
diff --git a/mempalace/backends/chroma.py b/mempalace/backends/chroma.py
new file mode 100644
index 0000000..0ac7501
--- /dev/null
+++ b/mempalace/backends/chroma.py
@@ -0,0 +1,54 @@
+"""ChromaDB-backed MemPalace collection adapter."""
+
+import os
+
+import chromadb
+
+from .base import BaseCollection
+
+
+class ChromaCollection(BaseCollection):
+ """Thin adapter over a ChromaDB collection."""
+
+ def __init__(self, collection):
+ self._collection = collection
+
+ def add(self, *, documents, ids, metadatas=None):
+ self._collection.add(documents=documents, ids=ids, metadatas=metadatas)
+
+ def upsert(self, *, documents, ids, metadatas=None):
+ self._collection.upsert(documents=documents, ids=ids, metadatas=metadatas)
+
+ def query(self, **kwargs):
+ return self._collection.query(**kwargs)
+
+ def get(self, **kwargs):
+ return self._collection.get(**kwargs)
+
+ def delete(self, **kwargs):
+ self._collection.delete(**kwargs)
+
+ def count(self):
+ return self._collection.count()
+
+
+class ChromaBackend:
+ """Factory for MemPalace's default ChromaDB backend."""
+
+ def get_collection(self, palace_path: str, collection_name: str, create: bool = False):
+ if not create and not os.path.isdir(palace_path):
+ raise FileNotFoundError(palace_path)
+
+ if create:
+ os.makedirs(palace_path, exist_ok=True)
+ try:
+ os.chmod(palace_path, 0o700)
+ except (OSError, NotImplementedError):
+ pass
+
+ client = chromadb.PersistentClient(path=palace_path)
+ if create:
+ collection = client.get_or_create_collection(collection_name)
+ else:
+ collection = client.get_collection(collection_name)
+ return ChromaCollection(collection)
diff --git a/mempalace/layers.py b/mempalace/layers.py
index 6abb99b..651818d 100644
--- a/mempalace/layers.py
+++ b/mempalace/layers.py
@@ -21,9 +21,8 @@ import sys
from pathlib import Path
from collections import defaultdict
-import chromadb
-
from .config import MempalaceConfig
+from .palace import get_collection as _get_collection
# ---------------------------------------------------------------------------
@@ -91,8 +90,7 @@ class Layer1:
def generate(self) -> str:
"""Pull top drawers from ChromaDB and format as compact L1 text."""
try:
- client = chromadb.PersistentClient(path=self.palace_path)
- col = client.get_collection("mempalace_drawers")
+ col = _get_collection(self.palace_path, create=False)
except Exception:
return "## L1 — No palace found. Run: mempalace mine
"
@@ -196,8 +194,7 @@ class Layer2:
def retrieve(self, wing: str = None, room: str = None, n_results: int = 10) -> str:
"""Retrieve drawers filtered by wing and/or room."""
try:
- client = chromadb.PersistentClient(path=self.palace_path)
- col = client.get_collection("mempalace_drawers")
+ col = _get_collection(self.palace_path, create=False)
except Exception:
return "No palace found."
@@ -260,8 +257,7 @@ class Layer3:
def search(self, query: str, wing: str = None, room: str = None, n_results: int = 5) -> str:
"""Semantic search, returns compact result text."""
try:
- client = chromadb.PersistentClient(path=self.palace_path)
- col = client.get_collection("mempalace_drawers")
+ col = _get_collection(self.palace_path, create=False)
except Exception:
return "No palace found."
@@ -316,8 +312,7 @@ class Layer3:
) -> list:
"""Return raw dicts instead of formatted text."""
try:
- client = chromadb.PersistentClient(path=self.palace_path)
- col = client.get_collection("mempalace_drawers")
+ col = _get_collection(self.palace_path, create=False)
except Exception:
return []
@@ -437,8 +432,7 @@ class MemoryStack:
# Count drawers
try:
- client = chromadb.PersistentClient(path=self.palace_path)
- col = client.get_collection("mempalace_drawers")
+ col = _get_collection(self.palace_path, create=False)
count = col.count()
result["total_drawers"] = count
except Exception:
diff --git a/mempalace/mcp_server.py b/mempalace/mcp_server.py
index fce6f35..bf8a281 100644
--- a/mempalace/mcp_server.py
+++ b/mempalace/mcp_server.py
@@ -28,10 +28,10 @@ from pathlib import Path
from .config import MempalaceConfig, sanitize_name, sanitize_content
from .version import __version__
+from .palace import get_collection as _get_collection_from_palace
from .query_sanitizer import sanitize_query
from .searcher import search_memories
from .palace_graph import traverse, find_tunnels, graph_stats
-import chromadb
from .knowledge_graph import KnowledgeGraph
@@ -64,7 +64,6 @@ else:
_kg = KnowledgeGraph()
-_client_cache = None
_collection_cache = None
@@ -101,27 +100,16 @@ def _wal_log(operation: str, params: dict, result: dict = None):
logger.error(f"WAL write failed: {e}")
-_client_cache = None
-_collection_cache = None
-
-
-def _get_client():
- """Return a singleton ChromaDB PersistentClient."""
- global _client_cache
- if _client_cache is None:
- _client_cache = chromadb.PersistentClient(path=_config.palace_path)
- return _client_cache
-
-
def _get_collection(create=False):
- """Return the ChromaDB collection, caching the client between calls."""
+ """Return the configured collection, caching the wrapper between calls."""
global _collection_cache
try:
- client = _get_client()
- if create:
- _collection_cache = client.get_or_create_collection(_config.collection_name)
- elif _collection_cache is None:
- _collection_cache = client.get_collection(_config.collection_name)
+ if create or _collection_cache is None:
+ _collection_cache = _get_collection_from_palace(
+ _config.palace_path,
+ collection_name=_config.collection_name,
+ create=create,
+ )
return _collection_cache
except Exception:
return None
diff --git a/mempalace/miner.py b/mempalace/miner.py
index f342a2d..112de68 100644
--- a/mempalace/miner.py
+++ b/mempalace/miner.py
@@ -15,8 +15,6 @@ from pathlib import Path
from datetime import datetime
from collections import defaultdict
-import chromadb
-
from .palace import SKIP_DIRS, get_collection, file_already_mined
READABLE_EXTENSIONS = {
@@ -625,8 +623,7 @@ def mine(
def status(palace_path: str):
"""Show what's been filed in the palace."""
try:
- client = chromadb.PersistentClient(path=palace_path)
- col = client.get_collection("mempalace_drawers")
+ col = get_collection(palace_path, create=False)
except Exception:
print(f"\n No palace found at {palace_path}")
print(" Run: mempalace init then mempalace mine ")
diff --git a/mempalace/palace.py b/mempalace/palace.py
index 6ddf190..f01a64a 100644
--- a/mempalace/palace.py
+++ b/mempalace/palace.py
@@ -1,11 +1,12 @@
"""
palace.py — Shared palace operations.
-Consolidates ChromaDB access patterns used by both miners and the MCP server.
+Consolidates collection access patterns used by both miners and the MCP server.
"""
import os
-import chromadb
+
+from .backends.chroma import ChromaBackend
SKIP_DIRS = {
".git",
@@ -33,19 +34,20 @@ SKIP_DIRS = {
"target",
}
+_DEFAULT_BACKEND = ChromaBackend()
-def get_collection(palace_path: str, collection_name: str = "mempalace_drawers"):
- """Get or create the palace ChromaDB collection."""
- os.makedirs(palace_path, exist_ok=True)
- try:
- os.chmod(palace_path, 0o700)
- except (OSError, NotImplementedError):
- pass
- client = chromadb.PersistentClient(path=palace_path)
- try:
- return client.get_collection(collection_name)
- except Exception:
- return client.create_collection(collection_name)
+
+def get_collection(
+ palace_path: str,
+ collection_name: str = "mempalace_drawers",
+ create: bool = True,
+):
+ """Get the palace collection through the backend layer."""
+ return _DEFAULT_BACKEND.get_collection(
+ palace_path,
+ collection_name=collection_name,
+ create=create,
+ )
def file_already_mined(collection, source_file: str, check_mtime: bool = False) -> bool:
diff --git a/mempalace/palace_graph.py b/mempalace/palace_graph.py
index e4fda93..5e2e72e 100644
--- a/mempalace/palace_graph.py
+++ b/mempalace/palace_graph.py
@@ -16,16 +16,19 @@ No external graph DB needed — built from ChromaDB metadata.
"""
from collections import defaultdict, Counter
-from .config import MempalaceConfig
-import chromadb
+from .config import MempalaceConfig
+from .palace import get_collection as _get_palace_collection
def _get_collection(config=None):
config = config or MempalaceConfig()
try:
- client = chromadb.PersistentClient(path=config.palace_path)
- return client.get_collection(config.collection_name)
+ return _get_palace_collection(
+ config.palace_path,
+ collection_name=config.collection_name,
+ create=False,
+ )
except Exception:
return None
diff --git a/mempalace/searcher.py b/mempalace/searcher.py
index 163abd8..307820f 100644
--- a/mempalace/searcher.py
+++ b/mempalace/searcher.py
@@ -9,7 +9,7 @@ Returns verbatim text — the actual words, never summaries.
import logging
from pathlib import Path
-import chromadb
+from .palace import get_collection
logger = logging.getLogger("mempalace_mcp")
@@ -24,8 +24,7 @@ def search(query: str, palace_path: str, wing: str = None, room: str = None, n_r
Optionally filter by wing (project) or room (aspect).
"""
try:
- client = chromadb.PersistentClient(path=palace_path)
- col = client.get_collection("mempalace_drawers")
+ col = get_collection(palace_path, create=False)
except Exception:
print(f"\n No palace found at {palace_path}")
print(" Run: mempalace init then mempalace mine ")
@@ -98,8 +97,7 @@ def search_memories(
Used by the MCP server and other callers that need data.
"""
try:
- client = chromadb.PersistentClient(path=palace_path)
- col = client.get_collection("mempalace_drawers")
+ col = get_collection(palace_path, create=False)
except Exception as e:
logger.error("No palace found at %s: %s", palace_path, e)
return {
diff --git a/tests/test_backends.py b/tests/test_backends.py
new file mode 100644
index 0000000..4a97ef1
--- /dev/null
+++ b/tests/test_backends.py
@@ -0,0 +1,80 @@
+import chromadb
+import pytest
+
+from mempalace.backends.chroma import ChromaBackend, ChromaCollection
+
+
+class _FakeCollection:
+ def __init__(self):
+ self.calls = []
+
+ def add(self, **kwargs):
+ self.calls.append(("add", kwargs))
+
+ def upsert(self, **kwargs):
+ self.calls.append(("upsert", kwargs))
+
+ def query(self, **kwargs):
+ self.calls.append(("query", kwargs))
+ return {"kind": "query"}
+
+ def get(self, **kwargs):
+ self.calls.append(("get", kwargs))
+ return {"kind": "get"}
+
+ def delete(self, **kwargs):
+ self.calls.append(("delete", kwargs))
+
+ def count(self):
+ self.calls.append(("count", {}))
+ return 7
+
+
+def test_chroma_collection_delegates_methods():
+ fake = _FakeCollection()
+ collection = ChromaCollection(fake)
+
+ collection.add(documents=["d"], ids=["1"], metadatas=[{"wing": "w"}])
+ collection.upsert(documents=["u"], ids=["2"], metadatas=[{"room": "r"}])
+ assert collection.query(query_texts=["q"]) == {"kind": "query"}
+ assert collection.get(where={"wing": "w"}) == {"kind": "get"}
+ collection.delete(ids=["1"])
+ assert collection.count() == 7
+
+ assert fake.calls == [
+ ("add", {"documents": ["d"], "ids": ["1"], "metadatas": [{"wing": "w"}]}),
+ ("upsert", {"documents": ["u"], "ids": ["2"], "metadatas": [{"room": "r"}]}),
+ ("query", {"query_texts": ["q"]}),
+ ("get", {"where": {"wing": "w"}}),
+ ("delete", {"ids": ["1"]}),
+ ("count", {}),
+ ]
+
+
+def test_chroma_backend_create_false_raises_without_creating_directory(tmp_path):
+ palace_path = tmp_path / "missing-palace"
+
+ with pytest.raises(FileNotFoundError):
+ ChromaBackend().get_collection(
+ str(palace_path),
+ collection_name="mempalace_drawers",
+ create=False,
+ )
+
+ assert not palace_path.exists()
+
+
+def test_chroma_backend_create_true_creates_directory_and_collection(tmp_path):
+ palace_path = tmp_path / "palace"
+
+ collection = ChromaBackend().get_collection(
+ str(palace_path),
+ collection_name="mempalace_drawers",
+ create=True,
+ )
+
+ assert palace_path.is_dir()
+ assert isinstance(collection, ChromaCollection)
+
+ client = chromadb.PersistentClient(path=str(palace_path))
+ client.get_collection("mempalace_drawers")
diff --git a/tests/test_layers.py b/tests/test_layers.py
index 46b60e9..575183f 100644
--- a/tests/test_layers.py
+++ b/tests/test_layers.py
@@ -71,16 +71,14 @@ def test_layer0_default_path():
def _mock_chromadb_for_layer(docs, metas, monkeypatch=None):
- """Return a mock PersistentClient whose collection.get returns docs/metas."""
+ """Return a mock collection whose get() returns docs/metas."""
mock_col = MagicMock()
# First batch returns data, second batch returns empty (end of pagination)
mock_col.get.side_effect = [
{"documents": docs, "metadatas": metas},
{"documents": [], "metadatas": []},
]
- mock_client = MagicMock()
- mock_client.get_collection.return_value = mock_col
- return mock_client
+ return mock_col
def test_layer1_no_palace():
@@ -101,11 +99,11 @@ def test_layer1_generates_essential_story():
{"room": "decisions", "source_file": "meeting.txt", "importance": 5},
{"room": "architecture", "source_file": "design.txt", "importance": 4},
]
- mock_client = _mock_chromadb_for_layer(docs, metas)
+ mock_col = _mock_chromadb_for_layer(docs, metas)
with (
patch("mempalace.layers.MempalaceConfig") as mock_cfg,
- patch("mempalace.layers.chromadb.PersistentClient", return_value=mock_client),
+ patch("mempalace.layers._get_collection", return_value=mock_col),
):
mock_cfg.return_value.palace_path = "/fake"
layer = Layer1(palace_path="/fake")
@@ -118,12 +116,9 @@ def test_layer1_generates_essential_story():
def test_layer1_empty_palace():
mock_col = MagicMock()
mock_col.get.return_value = {"documents": [], "metadatas": []}
- mock_client = MagicMock()
- mock_client.get_collection.return_value = mock_col
-
with (
patch("mempalace.layers.MempalaceConfig") as mock_cfg,
- patch("mempalace.layers.chromadb.PersistentClient", return_value=mock_client),
+ patch("mempalace.layers._get_collection", return_value=mock_col),
):
mock_cfg.return_value.palace_path = "/fake"
layer = Layer1(palace_path="/fake")
@@ -135,11 +130,11 @@ def test_layer1_empty_palace():
def test_layer1_with_wing_filter():
docs = ["Memory about project X"]
metas = [{"room": "general", "source_file": "x.txt", "importance": 3}]
- mock_client = _mock_chromadb_for_layer(docs, metas)
+ mock_col = _mock_chromadb_for_layer(docs, metas)
with (
patch("mempalace.layers.MempalaceConfig") as mock_cfg,
- patch("mempalace.layers.chromadb.PersistentClient", return_value=mock_client),
+ patch("mempalace.layers._get_collection", return_value=mock_col),
):
mock_cfg.return_value.palace_path = "/fake"
layer = Layer1(palace_path="/fake", wing="project_x")
@@ -147,18 +142,18 @@ def test_layer1_with_wing_filter():
assert "ESSENTIAL STORY" in result
# Verify wing filter was passed
- call_kwargs = mock_client.get_collection.return_value.get.call_args_list[0][1]
+ call_kwargs = mock_col.get.call_args_list[0][1]
assert call_kwargs.get("where") == {"wing": "project_x"}
def test_layer1_truncates_long_snippets():
docs = ["A" * 300]
metas = [{"room": "general", "source_file": "long.txt"}]
- mock_client = _mock_chromadb_for_layer(docs, metas)
+ mock_col = _mock_chromadb_for_layer(docs, metas)
with (
patch("mempalace.layers.MempalaceConfig") as mock_cfg,
- patch("mempalace.layers.chromadb.PersistentClient", return_value=mock_client),
+ patch("mempalace.layers._get_collection", return_value=mock_col),
):
mock_cfg.return_value.palace_path = "/fake"
layer = Layer1(palace_path="/fake")
@@ -171,11 +166,11 @@ def test_layer1_respects_max_chars():
"""L1 stops adding entries once MAX_CHARS is reached."""
docs = [f"Memory number {i} with substantial content padding here" for i in range(30)]
metas = [{"room": "general", "source_file": f"f{i}.txt", "importance": 5} for i in range(30)]
- mock_client = _mock_chromadb_for_layer(docs, metas)
+ mock_col = _mock_chromadb_for_layer(docs, metas)
with (
patch("mempalace.layers.MempalaceConfig") as mock_cfg,
- patch("mempalace.layers.chromadb.PersistentClient", return_value=mock_client),
+ patch("mempalace.layers._get_collection", return_value=mock_col),
):
mock_cfg.return_value.palace_path = "/fake"
layer = Layer1(palace_path="/fake")
@@ -193,11 +188,11 @@ def test_layer1_importance_from_various_keys():
{"room": "r", "weight": 1},
{"room": "r"}, # no weight key, defaults to 3
]
- mock_client = _mock_chromadb_for_layer(docs, metas)
+ mock_col = _mock_chromadb_for_layer(docs, metas)
with (
patch("mempalace.layers.MempalaceConfig") as mock_cfg,
- patch("mempalace.layers.chromadb.PersistentClient", return_value=mock_client),
+ patch("mempalace.layers._get_collection", return_value=mock_col),
):
mock_cfg.return_value.palace_path = "/fake"
layer = Layer1(palace_path="/fake")
@@ -213,12 +208,9 @@ def test_layer1_batch_exception_breaks():
{"documents": ["doc1"], "metadatas": [{"room": "r"}]},
RuntimeError("batch error"),
]
- mock_client = MagicMock()
- mock_client.get_collection.return_value = mock_col
-
with (
patch("mempalace.layers.MempalaceConfig") as mock_cfg,
- patch("mempalace.layers.chromadb.PersistentClient", return_value=mock_client),
+ patch("mempalace.layers._get_collection", return_value=mock_col),
):
mock_cfg.return_value.palace_path = "/fake"
layer = Layer1(palace_path="/fake")
@@ -244,12 +236,9 @@ def test_layer2_retrieve_with_wing():
"documents": ["Some memory about the project"],
"metadatas": [{"room": "backend", "source_file": "notes.txt"}],
}
- mock_client = MagicMock()
- mock_client.get_collection.return_value = mock_col
-
with (
patch("mempalace.layers.MempalaceConfig") as mock_cfg,
- patch("mempalace.layers.chromadb.PersistentClient", return_value=mock_client),
+ patch("mempalace.layers._get_collection", return_value=mock_col),
):
mock_cfg.return_value.palace_path = "/fake"
layer = Layer2(palace_path="/fake")
@@ -265,12 +254,9 @@ def test_layer2_retrieve_with_room():
"documents": ["Backend architecture notes"],
"metadatas": [{"room": "architecture", "source_file": "arch.txt"}],
}
- mock_client = MagicMock()
- mock_client.get_collection.return_value = mock_col
-
with (
patch("mempalace.layers.MempalaceConfig") as mock_cfg,
- patch("mempalace.layers.chromadb.PersistentClient", return_value=mock_client),
+ patch("mempalace.layers._get_collection", return_value=mock_col),
):
mock_cfg.return_value.palace_path = "/fake"
layer = Layer2(palace_path="/fake")
@@ -285,12 +271,9 @@ def test_layer2_retrieve_wing_and_room():
"documents": ["Filtered result"],
"metadatas": [{"room": "backend", "source_file": "x.txt"}],
}
- mock_client = MagicMock()
- mock_client.get_collection.return_value = mock_col
-
with (
patch("mempalace.layers.MempalaceConfig") as mock_cfg,
- patch("mempalace.layers.chromadb.PersistentClient", return_value=mock_client),
+ patch("mempalace.layers._get_collection", return_value=mock_col),
):
mock_cfg.return_value.palace_path = "/fake"
layer = Layer2(palace_path="/fake")
@@ -304,12 +287,9 @@ def test_layer2_retrieve_wing_and_room():
def test_layer2_retrieve_empty():
mock_col = MagicMock()
mock_col.get.return_value = {"documents": [], "metadatas": []}
- mock_client = MagicMock()
- mock_client.get_collection.return_value = mock_col
-
with (
patch("mempalace.layers.MempalaceConfig") as mock_cfg,
- patch("mempalace.layers.chromadb.PersistentClient", return_value=mock_client),
+ patch("mempalace.layers._get_collection", return_value=mock_col),
):
mock_cfg.return_value.palace_path = "/fake"
layer = Layer2(palace_path="/fake")
@@ -321,12 +301,9 @@ def test_layer2_retrieve_empty():
def test_layer2_retrieve_no_filter():
mock_col = MagicMock()
mock_col.get.return_value = {"documents": [], "metadatas": []}
- mock_client = MagicMock()
- mock_client.get_collection.return_value = mock_col
-
with (
patch("mempalace.layers.MempalaceConfig") as mock_cfg,
- patch("mempalace.layers.chromadb.PersistentClient", return_value=mock_client),
+ patch("mempalace.layers._get_collection", return_value=mock_col),
):
mock_cfg.return_value.palace_path = "/fake"
layer = Layer2(palace_path="/fake")
@@ -340,12 +317,9 @@ def test_layer2_retrieve_no_filter():
def test_layer2_retrieve_error():
mock_col = MagicMock()
mock_col.get.side_effect = RuntimeError("db error")
- mock_client = MagicMock()
- mock_client.get_collection.return_value = mock_col
-
with (
patch("mempalace.layers.MempalaceConfig") as mock_cfg,
- patch("mempalace.layers.chromadb.PersistentClient", return_value=mock_client),
+ patch("mempalace.layers._get_collection", return_value=mock_col),
):
mock_cfg.return_value.palace_path = "/fake"
layer = Layer2(palace_path="/fake")
@@ -360,12 +334,9 @@ def test_layer2_truncates_long_snippets():
"documents": ["B" * 400],
"metadatas": [{"room": "r", "source_file": "s.txt"}],
}
- mock_client = MagicMock()
- mock_client.get_collection.return_value = mock_col
-
with (
patch("mempalace.layers.MempalaceConfig") as mock_cfg,
- patch("mempalace.layers.chromadb.PersistentClient", return_value=mock_client),
+ patch("mempalace.layers._get_collection", return_value=mock_col),
):
mock_cfg.return_value.palace_path = "/fake"
layer = Layer2(palace_path="/fake")
@@ -408,12 +379,9 @@ def test_layer3_search_with_results():
[{"wing": "project", "room": "backend", "source_file": "notes.txt"}],
[0.2],
)
- mock_client = MagicMock()
- mock_client.get_collection.return_value = mock_col
-
with (
patch("mempalace.layers.MempalaceConfig") as mock_cfg,
- patch("mempalace.layers.chromadb.PersistentClient", return_value=mock_client),
+ patch("mempalace.layers._get_collection", return_value=mock_col),
):
mock_cfg.return_value.palace_path = "/fake"
layer = Layer3(palace_path="/fake")
@@ -427,12 +395,9 @@ def test_layer3_search_with_results():
def test_layer3_search_no_results():
mock_col = MagicMock()
mock_col.query.return_value = _mock_query_results([], [], [])
- mock_client = MagicMock()
- mock_client.get_collection.return_value = mock_col
-
with (
patch("mempalace.layers.MempalaceConfig") as mock_cfg,
- patch("mempalace.layers.chromadb.PersistentClient", return_value=mock_client),
+ patch("mempalace.layers._get_collection", return_value=mock_col),
):
mock_cfg.return_value.palace_path = "/fake"
layer = Layer3(palace_path="/fake")
@@ -448,12 +413,9 @@ def test_layer3_search_with_wing_filter():
[{"wing": "proj", "room": "r"}],
[0.1],
)
- mock_client = MagicMock()
- mock_client.get_collection.return_value = mock_col
-
with (
patch("mempalace.layers.MempalaceConfig") as mock_cfg,
- patch("mempalace.layers.chromadb.PersistentClient", return_value=mock_client),
+ patch("mempalace.layers._get_collection", return_value=mock_col),
):
mock_cfg.return_value.palace_path = "/fake"
layer = Layer3(palace_path="/fake")
@@ -470,12 +432,9 @@ def test_layer3_search_with_room_filter():
[{"wing": "w", "room": "backend"}],
[0.1],
)
- mock_client = MagicMock()
- mock_client.get_collection.return_value = mock_col
-
with (
patch("mempalace.layers.MempalaceConfig") as mock_cfg,
- patch("mempalace.layers.chromadb.PersistentClient", return_value=mock_client),
+ patch("mempalace.layers._get_collection", return_value=mock_col),
):
mock_cfg.return_value.palace_path = "/fake"
layer = Layer3(palace_path="/fake")
@@ -492,12 +451,9 @@ def test_layer3_search_with_wing_and_room():
[{"wing": "proj", "room": "backend"}],
[0.1],
)
- mock_client = MagicMock()
- mock_client.get_collection.return_value = mock_col
-
with (
patch("mempalace.layers.MempalaceConfig") as mock_cfg,
- patch("mempalace.layers.chromadb.PersistentClient", return_value=mock_client),
+ patch("mempalace.layers._get_collection", return_value=mock_col),
):
mock_cfg.return_value.palace_path = "/fake"
layer = Layer3(palace_path="/fake")
@@ -510,12 +466,9 @@ def test_layer3_search_with_wing_and_room():
def test_layer3_search_error():
mock_col = MagicMock()
mock_col.query.side_effect = RuntimeError("search failed")
- mock_client = MagicMock()
- mock_client.get_collection.return_value = mock_col
-
with (
patch("mempalace.layers.MempalaceConfig") as mock_cfg,
- patch("mempalace.layers.chromadb.PersistentClient", return_value=mock_client),
+ patch("mempalace.layers._get_collection", return_value=mock_col),
):
mock_cfg.return_value.palace_path = "/fake"
layer = Layer3(palace_path="/fake")
@@ -531,12 +484,9 @@ def test_layer3_search_truncates_long_docs():
[{"wing": "w", "room": "r", "source_file": "s.txt"}],
[0.1],
)
- mock_client = MagicMock()
- mock_client.get_collection.return_value = mock_col
-
with (
patch("mempalace.layers.MempalaceConfig") as mock_cfg,
- patch("mempalace.layers.chromadb.PersistentClient", return_value=mock_client),
+ patch("mempalace.layers._get_collection", return_value=mock_col),
):
mock_cfg.return_value.palace_path = "/fake"
layer = Layer3(palace_path="/fake")
@@ -552,12 +502,9 @@ def test_layer3_search_raw_returns_dicts():
[{"wing": "proj", "room": "backend", "source_file": "f.txt"}],
[0.3],
)
- mock_client = MagicMock()
- mock_client.get_collection.return_value = mock_col
-
with (
patch("mempalace.layers.MempalaceConfig") as mock_cfg,
- patch("mempalace.layers.chromadb.PersistentClient", return_value=mock_client),
+ patch("mempalace.layers._get_collection", return_value=mock_col),
):
mock_cfg.return_value.palace_path = "/fake"
layer = Layer3(palace_path="/fake")
@@ -577,12 +524,9 @@ def test_layer3_search_raw_with_filters():
[{"wing": "w", "room": "r"}],
[0.1],
)
- mock_client = MagicMock()
- mock_client.get_collection.return_value = mock_col
-
with (
patch("mempalace.layers.MempalaceConfig") as mock_cfg,
- patch("mempalace.layers.chromadb.PersistentClient", return_value=mock_client),
+ patch("mempalace.layers._get_collection", return_value=mock_col),
):
mock_cfg.return_value.palace_path = "/fake"
layer = Layer3(palace_path="/fake")
@@ -595,12 +539,9 @@ def test_layer3_search_raw_with_filters():
def test_layer3_search_raw_error():
mock_col = MagicMock()
mock_col.query.side_effect = RuntimeError("fail")
- mock_client = MagicMock()
- mock_client.get_collection.return_value = mock_col
-
with (
patch("mempalace.layers.MempalaceConfig") as mock_cfg,
- patch("mempalace.layers.chromadb.PersistentClient", return_value=mock_client),
+ patch("mempalace.layers._get_collection", return_value=mock_col),
):
mock_cfg.return_value.palace_path = "/fake"
layer = Layer3(palace_path="/fake")
@@ -701,12 +642,9 @@ def test_memory_stack_status_with_palace(tmp_path):
mock_col = MagicMock()
mock_col.count.return_value = 42
- mock_client = MagicMock()
- mock_client.get_collection.return_value = mock_col
-
with (
patch("mempalace.layers.MempalaceConfig") as mock_cfg,
- patch("mempalace.layers.chromadb.PersistentClient", return_value=mock_client),
+ patch("mempalace.layers._get_collection", return_value=mock_col),
):
mock_cfg.return_value.palace_path = "/fake"
stack = MemoryStack(
diff --git a/tests/test_miner.py b/tests/test_miner.py
index c013d7c..7851787 100644
--- a/tests/test_miner.py
+++ b/tests/test_miner.py
@@ -6,7 +6,7 @@ from pathlib import Path
import chromadb
import yaml
-from mempalace.miner import mine, scan_project
+from mempalace.miner import mine, scan_project, status
from mempalace.palace import file_already_mined
@@ -260,3 +260,13 @@ def test_file_already_mined_check_mtime():
# Release ChromaDB file handles before cleanup (required on Windows)
del col, client
shutil.rmtree(tmpdir, ignore_errors=True)
+
+
+def test_status_missing_palace_does_not_create_empty_collection(tmp_path, capsys):
+ palace_path = tmp_path / "missing-palace"
+
+ status(str(palace_path))
+
+ out = capsys.readouterr().out
+ assert "No palace found" in out
+ assert not palace_path.exists()
diff --git a/tests/test_searcher.py b/tests/test_searcher.py
index 94f22b4..244fbf3 100644
--- a/tests/test_searcher.py
+++ b/tests/test_searcher.py
@@ -56,10 +56,8 @@ class TestSearchMemories:
"""search_memories returns error dict when query raises."""
mock_col = MagicMock()
mock_col.query.side_effect = RuntimeError("query failed")
- mock_client = MagicMock()
- mock_client.get_collection.return_value = mock_col
- with patch("mempalace.searcher.chromadb.PersistentClient", return_value=mock_client):
+ with patch("mempalace.searcher.get_collection", return_value=mock_col):
result = search_memories("test", "/fake/path")
assert "error" in result
assert "query failed" in result["error"]
@@ -111,10 +109,8 @@ class TestSearchCLI:
"""search raises SearchError when query fails."""
mock_col = MagicMock()
mock_col.query.side_effect = RuntimeError("boom")
- mock_client = MagicMock()
- mock_client.get_collection.return_value = mock_col
- with patch("mempalace.searcher.chromadb.PersistentClient", return_value=mock_client):
+ with patch("mempalace.searcher.get_collection", return_value=mock_col):
with pytest.raises(SearchError, match="Search error"):
search("test", "/fake/path")