diff --git a/tests/conftest.py b/tests/conftest.py index fa08802..eb2b432 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -37,14 +37,19 @@ from mempalace.knowledge_graph import KnowledgeGraph # noqa: E402 @pytest.fixture(autouse=True) def _reset_mcp_cache(): """Reset the MCP server's cached ChromaDB client/collection between tests.""" - yield - try: - from mempalace import mcp_server - mcp_server._client_cache = None - mcp_server._collection_cache = None - except (ImportError, AttributeError): - pass + def _clear_cache(): + try: + from mempalace import mcp_server + + mcp_server._client_cache = None + mcp_server._collection_cache = None + except (ImportError, AttributeError): + pass + + _clear_cache() + yield + _clear_cache() @pytest.fixture(scope="session", autouse=True) diff --git a/tests/test_mcp_server.py b/tests/test_mcp_server.py index cf37a27..09a3c46 100644 --- a/tests/test_mcp_server.py +++ b/tests/test_mcp_server.py @@ -151,9 +151,10 @@ class TestReadTools: assert result["taxonomy"]["project"]["frontend"] == 1 assert result["taxonomy"]["notes"]["planning"] == 1 - def test_no_palace_returns_error(self, monkeypatch, config, kg): - config._file_config["palace_path"] = "/nonexistent/path" - _patch_mcp_server(monkeypatch, config, "/nonexistent/path", kg) + def test_no_palace_returns_error(self, monkeypatch, config, kg, tmp_path): + missing = str(tmp_path / "missing") + config._file_config["palace_path"] = missing + _patch_mcp_server(monkeypatch, config, missing, kg) from mempalace.mcp_server import tool_status result = tool_status() diff --git a/tests/test_searcher.py b/tests/test_searcher.py index 44a05aa..1c2687d 100644 --- a/tests/test_searcher.py +++ b/tests/test_searcher.py @@ -30,8 +30,8 @@ class TestSearchMemories: result = search_memories("code", palace_path, n_results=2) assert len(result["results"]) <= 2 - def test_no_palace_returns_error(self): - result = search_memories("anything", "/nonexistent/path") + def test_no_palace_returns_error(self, tmp_path): + result = search_memories("anything", str(tmp_path / "missing")) assert "error" in result def test_result_fields(self, palace_path, seeded_collection):