diff --git a/mempalace/mcp_server.py b/mempalace/mcp_server.py index 2169255..81c9b50 100644 --- a/mempalace/mcp_server.py +++ b/mempalace/mcp_server.py @@ -39,13 +39,21 @@ logger = logging.getLogger("mempalace_mcp") _config = MempalaceConfig() +_client_cache = None +_collection_cache = None + + def _get_collection(create=False): - """Return the ChromaDB collection, or None on failure.""" + """Return the ChromaDB collection, caching the client between calls.""" + global _client_cache, _collection_cache try: - client = chromadb.PersistentClient(path=_config.palace_path) + if _client_cache is None: + _client_cache = chromadb.PersistentClient(path=_config.palace_path) if create: - return client.get_or_create_collection(_config.collection_name) - return client.get_collection(_config.collection_name) + _collection_cache = _client_cache.get_or_create_collection(_config.collection_name) + elif _collection_cache is None: + _collection_cache = _client_cache.get_collection(_config.collection_name) + return _collection_cache except Exception: return None diff --git a/tests/conftest.py b/tests/conftest.py index 22b5e42..eb2b432 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -34,6 +34,24 @@ from mempalace.config import MempalaceConfig # noqa: E402 from mempalace.knowledge_graph import KnowledgeGraph # noqa: E402 +@pytest.fixture(autouse=True) +def _reset_mcp_cache(): + """Reset the MCP server's cached ChromaDB client/collection between tests.""" + + def _clear_cache(): + try: + from mempalace import mcp_server + + mcp_server._client_cache = None + mcp_server._collection_cache = None + except (ImportError, AttributeError): + pass + + _clear_cache() + yield + _clear_cache() + + @pytest.fixture(scope="session", autouse=True) def _isolate_home(): """Ensure HOME points to a temp dir for the entire test session. diff --git a/tests/test_mcp_server.py b/tests/test_mcp_server.py index cf37a27..09a3c46 100644 --- a/tests/test_mcp_server.py +++ b/tests/test_mcp_server.py @@ -151,9 +151,10 @@ class TestReadTools: assert result["taxonomy"]["project"]["frontend"] == 1 assert result["taxonomy"]["notes"]["planning"] == 1 - def test_no_palace_returns_error(self, monkeypatch, config, kg): - config._file_config["palace_path"] = "/nonexistent/path" - _patch_mcp_server(monkeypatch, config, "/nonexistent/path", kg) + def test_no_palace_returns_error(self, monkeypatch, config, kg, tmp_path): + missing = str(tmp_path / "missing") + config._file_config["palace_path"] = missing + _patch_mcp_server(monkeypatch, config, missing, kg) from mempalace.mcp_server import tool_status result = tool_status() diff --git a/tests/test_searcher.py b/tests/test_searcher.py index 44a05aa..1c2687d 100644 --- a/tests/test_searcher.py +++ b/tests/test_searcher.py @@ -30,8 +30,8 @@ class TestSearchMemories: result = search_memories("code", palace_path, n_results=2) assert len(result["results"]) <= 2 - def test_no_palace_returns_error(self): - result = search_memories("anything", "/nonexistent/path") + def test_no_palace_returns_error(self, tmp_path): + result = search_memories("anything", str(tmp_path / "missing")) assert "error" in result def test_result_fields(self, palace_path, seeded_collection):