diff --git a/mempalace/mcp_server.py b/mempalace/mcp_server.py index b447249..7d263a6 100644 --- a/mempalace/mcp_server.py +++ b/mempalace/mcp_server.py @@ -44,7 +44,8 @@ def _parse_args(): metavar="PATH", help="Path to the palace directory (overrides config file and env var)", ) - return parser.parse_args() + args, _ = parser.parse_known_args() + return args _args = _parse_args() @@ -283,19 +284,18 @@ def tool_add_drawer( if not col: return _no_palace() - # Duplicate check - dup = tool_check_duplicate(content, threshold=0.9) - if dup.get("is_duplicate"): - return { - "success": False, - "reason": "duplicate", - "matches": dup["matches"], - } + drawer_id = f"drawer_{wing}_{room}_{hashlib.md5(content.encode()).hexdigest()[:16]}" - drawer_id = f"drawer_{wing}_{room}_{hashlib.md5((content[:100] + datetime.now().isoformat()).encode()).hexdigest()[:16]}" + # Idempotency: if the deterministic ID already exists, return success as a no-op. + try: + existing = col.get(ids=[drawer_id]) + if existing and existing["ids"]: + return {"success": True, "reason": "already_exists", "drawer_id": drawer_id} + except Exception: + pass try: - col.add( + col.upsert( ids=[drawer_id], documents=[content], metadatas=[ diff --git a/mempalace/miner.py b/mempalace/miner.py index 7b4e949..e29fb25 100644 --- a/mempalace/miner.py +++ b/mempalace/miner.py @@ -403,10 +403,22 @@ def get_collection(palace_path: str): def file_already_mined(collection, source_file: str) -> bool: - """Fast check: has this file been filed before?""" + """Fast check: has this file been filed before and is unchanged? + + Compares the stored mtime in drawer metadata against the file's current + mtime. Returns False (needs re-mining) when the file has been modified + since it was last mined, or when no mtime was stored. + """ try: results = collection.get(where={"source_file": source_file}, limit=1) - return len(results.get("ids", [])) > 0 + if not results.get("ids"): + return False + stored_meta = results["metadatas"][0] if results.get("metadatas") else {} + stored_mtime = stored_meta.get("source_mtime") + if stored_mtime is None: + return False + current_mtime = os.path.getmtime(source_file) + return float(stored_mtime) == current_mtime except Exception: return False @@ -417,24 +429,26 @@ def add_drawer( """Add one drawer to the palace.""" drawer_id = f"drawer_{wing}_{room}_{hashlib.md5((source_file + str(chunk_index)).encode(), usedforsecurity=False).hexdigest()[:16]}" try: - collection.add( + metadata = { + "wing": wing, + "room": room, + "source_file": source_file, + "chunk_index": chunk_index, + "added_by": agent, + "filed_at": datetime.now().isoformat(), + } + # Store file mtime so we can detect modifications later. + try: + metadata["source_mtime"] = os.path.getmtime(source_file) + except OSError: + pass + collection.upsert( documents=[content], ids=[drawer_id], - metadatas=[ - { - "wing": wing, - "room": room, - "source_file": source_file, - "chunk_index": chunk_index, - "added_by": agent, - "filed_at": datetime.now().isoformat(), - } - ], + metadatas=[metadata], ) return True - except Exception as e: - if "already exists" in str(e).lower() or "duplicate" in str(e).lower(): - return False + except Exception: raise diff --git a/tests/conftest.py b/tests/conftest.py index eb2b432..7a3e55a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -102,7 +102,9 @@ def collection(palace_path): """A ChromaDB collection pre-seeded in the temp palace.""" client = chromadb.PersistentClient(path=palace_path) col = client.get_or_create_collection("mempalace_drawers") - return col + yield col + client.delete_collection("mempalace_drawers") + del client @pytest.fixture diff --git a/tests/test_hooks_cli.py b/tests/test_hooks_cli.py index 8eeffed..d6951e2 100644 --- a/tests/test_hooks_cli.py +++ b/tests/test_hooks_cli.py @@ -42,29 +42,43 @@ def _write_transcript(path: Path, entries: list[dict]): def test_count_human_messages_basic(tmp_path): transcript = tmp_path / "t.jsonl" - _write_transcript(transcript, [ - {"message": {"role": "user", "content": "hello"}}, - {"message": {"role": "assistant", "content": "hi"}}, - {"message": {"role": "user", "content": "bye"}}, - ]) + _write_transcript( + transcript, + [ + {"message": {"role": "user", "content": "hello"}}, + {"message": {"role": "assistant", "content": "hi"}}, + {"message": {"role": "user", "content": "bye"}}, + ], + ) assert _count_human_messages(str(transcript)) == 2 def test_count_skips_command_messages(tmp_path): transcript = tmp_path / "t.jsonl" - _write_transcript(transcript, [ - {"message": {"role": "user", "content": "status"}}, - {"message": {"role": "user", "content": "real question"}}, - ]) + _write_transcript( + transcript, + [ + {"message": {"role": "user", "content": "status"}}, + {"message": {"role": "user", "content": "real question"}}, + ], + ) assert _count_human_messages(str(transcript)) == 1 def test_count_handles_list_content(tmp_path): transcript = tmp_path / "t.jsonl" - _write_transcript(transcript, [ - {"message": {"role": "user", "content": [{"type": "text", "text": "hello"}]}}, - {"message": {"role": "user", "content": [{"type": "text", "text": "x"}]}}, - ]) + _write_transcript( + transcript, + [ + {"message": {"role": "user", "content": [{"type": "text", "text": "hello"}]}}, + { + "message": { + "role": "user", + "content": [{"type": "text", "text": "x"}], + } + }, + ], + ) assert _count_human_messages(str(transcript)) == 1 @@ -90,6 +104,7 @@ def test_count_malformed_json_lines(tmp_path): def _capture_hook_output(hook_fn, data, harness="claude-code", state_dir=None): """Run a hook and capture its JSON stdout output.""" import io + buf = io.StringIO() patches = [patch("mempalace.hooks_cli._output", side_effect=lambda d: buf.write(json.dumps(d)))] if state_dir: @@ -123,10 +138,10 @@ def test_stop_hook_passthrough_when_active_string(tmp_path): def test_stop_hook_passthrough_below_interval(tmp_path): transcript = tmp_path / "t.jsonl" - _write_transcript(transcript, [ - {"message": {"role": "user", "content": f"msg {i}"}} - for i in range(SAVE_INTERVAL - 1) - ]) + _write_transcript( + transcript, + [{"message": {"role": "user", "content": f"msg {i}"}} for i in range(SAVE_INTERVAL - 1)], + ) result = _capture_hook_output( hook_stop, {"session_id": "test", "stop_hook_active": False, "transcript_path": str(transcript)}, @@ -137,10 +152,10 @@ def test_stop_hook_passthrough_below_interval(tmp_path): def test_stop_hook_blocks_at_interval(tmp_path): transcript = tmp_path / "t.jsonl" - _write_transcript(transcript, [ - {"message": {"role": "user", "content": f"msg {i}"}} - for i in range(SAVE_INTERVAL) - ]) + _write_transcript( + transcript, + [{"message": {"role": "user", "content": f"msg {i}"}} for i in range(SAVE_INTERVAL)], + ) result = _capture_hook_output( hook_stop, {"session_id": "test", "stop_hook_active": False, "transcript_path": str(transcript)}, @@ -152,10 +167,10 @@ def test_stop_hook_blocks_at_interval(tmp_path): def test_stop_hook_tracks_save_point(tmp_path): transcript = tmp_path / "t.jsonl" - _write_transcript(transcript, [ - {"message": {"role": "user", "content": f"msg {i}"}} - for i in range(SAVE_INTERVAL) - ]) + _write_transcript( + transcript, + [{"message": {"role": "user", "content": f"msg {i}"}} for i in range(SAVE_INTERVAL)], + ) data = {"session_id": "test", "stop_hook_active": False, "transcript_path": str(transcript)} # First call blocks diff --git a/tests/test_mcp_server.py b/tests/test_mcp_server.py index 09a3c46..24258a9 100644 --- a/tests/test_mcp_server.py +++ b/tests/test_mcp_server.py @@ -9,25 +9,26 @@ via monkeypatch to avoid touching real data. import json -def _patch_mcp_server(monkeypatch, config, palace_path, kg): +def _patch_mcp_server(monkeypatch, config, kg): """Patch the mcp_server module globals to use test fixtures.""" from mempalace import mcp_server - assert getattr(config, "palace_path", None) == palace_path, ( - f"config.palace_path ({getattr(config, 'palace_path', None)!r}) does not match palace_path fixture ({palace_path!r})" - ) monkeypatch.setattr(mcp_server, "_config", config) monkeypatch.setattr(mcp_server, "_kg", kg) def _get_collection(palace_path, create=False): - """Helper to get collection from test palace.""" + """Helper to get collection from test palace. + + Returns (client, collection) so callers can clean up the client + when they are done. + """ import chromadb client = chromadb.PersistentClient(path=palace_path) if create: - return client.get_or_create_collection("mempalace_drawers") - return client.get_collection("mempalace_drawers") + return client, client.get_or_create_collection("mempalace_drawers") + return client, client.get_collection("mempalace_drawers") # ── Protocol Layer ────────────────────────────────────────────────────── @@ -77,11 +78,12 @@ class TestHandleRequest: assert resp["error"]["code"] == -32601 def test_tools_call_dispatches(self, monkeypatch, config, palace_path, seeded_kg): - _patch_mcp_server(monkeypatch, config, palace_path, seeded_kg) + _patch_mcp_server(monkeypatch, config, seeded_kg) from mempalace.mcp_server import handle_request # Create a collection so status works - _get_collection(palace_path, create=True) + _client, _col = _get_collection(palace_path, create=True) + del _client resp = handle_request( { @@ -100,8 +102,9 @@ class TestHandleRequest: class TestReadTools: def test_status_empty_palace(self, monkeypatch, config, palace_path, kg): - _patch_mcp_server(monkeypatch, config, palace_path, kg) - _get_collection(palace_path, create=True) + _patch_mcp_server(monkeypatch, config, kg) + _client, _col = _get_collection(palace_path, create=True) + del _client from mempalace.mcp_server import tool_status result = tool_status() @@ -109,7 +112,7 @@ class TestReadTools: assert result["wings"] == {} def test_status_with_data(self, monkeypatch, config, palace_path, seeded_collection, kg): - _patch_mcp_server(monkeypatch, config, palace_path, kg) + _patch_mcp_server(monkeypatch, config, kg) from mempalace.mcp_server import tool_status result = tool_status() @@ -118,7 +121,7 @@ class TestReadTools: assert "notes" in result["wings"] def test_list_wings(self, monkeypatch, config, palace_path, seeded_collection, kg): - _patch_mcp_server(monkeypatch, config, palace_path, kg) + _patch_mcp_server(monkeypatch, config, kg) from mempalace.mcp_server import tool_list_wings result = tool_list_wings() @@ -126,7 +129,7 @@ class TestReadTools: assert result["wings"]["notes"] == 1 def test_list_rooms_all(self, monkeypatch, config, palace_path, seeded_collection, kg): - _patch_mcp_server(monkeypatch, config, palace_path, kg) + _patch_mcp_server(monkeypatch, config, kg) from mempalace.mcp_server import tool_list_rooms result = tool_list_rooms() @@ -135,7 +138,7 @@ class TestReadTools: assert "planning" in result["rooms"] def test_list_rooms_filtered(self, monkeypatch, config, palace_path, seeded_collection, kg): - _patch_mcp_server(monkeypatch, config, palace_path, kg) + _patch_mcp_server(monkeypatch, config, kg) from mempalace.mcp_server import tool_list_rooms result = tool_list_rooms(wing="project") @@ -143,7 +146,7 @@ class TestReadTools: assert "planning" not in result["rooms"] def test_get_taxonomy(self, monkeypatch, config, palace_path, seeded_collection, kg): - _patch_mcp_server(monkeypatch, config, palace_path, kg) + _patch_mcp_server(monkeypatch, config, kg) from mempalace.mcp_server import tool_get_taxonomy result = tool_get_taxonomy() @@ -151,10 +154,8 @@ class TestReadTools: assert result["taxonomy"]["project"]["frontend"] == 1 assert result["taxonomy"]["notes"]["planning"] == 1 - def test_no_palace_returns_error(self, monkeypatch, config, kg, tmp_path): - missing = str(tmp_path / "missing") - config._file_config["palace_path"] = missing - _patch_mcp_server(monkeypatch, config, missing, kg) + def test_no_palace_returns_error(self, monkeypatch, config, kg): + _patch_mcp_server(monkeypatch, config, kg) from mempalace.mcp_server import tool_status result = tool_status() @@ -166,7 +167,7 @@ class TestReadTools: class TestSearchTool: def test_search_basic(self, monkeypatch, config, palace_path, seeded_collection, kg): - _patch_mcp_server(monkeypatch, config, palace_path, kg) + _patch_mcp_server(monkeypatch, config, kg) from mempalace.mcp_server import tool_search result = tool_search(query="JWT authentication tokens") @@ -177,14 +178,14 @@ class TestSearchTool: assert "JWT" in top["text"] or "authentication" in top["text"].lower() def test_search_with_wing_filter(self, monkeypatch, config, palace_path, seeded_collection, kg): - _patch_mcp_server(monkeypatch, config, palace_path, kg) + _patch_mcp_server(monkeypatch, config, kg) from mempalace.mcp_server import tool_search result = tool_search(query="planning", wing="notes") assert all(r["wing"] == "notes" for r in result["results"]) def test_search_with_room_filter(self, monkeypatch, config, palace_path, seeded_collection, kg): - _patch_mcp_server(monkeypatch, config, palace_path, kg) + _patch_mcp_server(monkeypatch, config, kg) from mempalace.mcp_server import tool_search result = tool_search(query="database", room="backend") @@ -196,8 +197,9 @@ class TestSearchTool: class TestWriteTools: def test_add_drawer(self, monkeypatch, config, palace_path, kg): - _patch_mcp_server(monkeypatch, config, palace_path, kg) - _get_collection(palace_path, create=True) + _patch_mcp_server(monkeypatch, config, kg) + _client, _col = _get_collection(palace_path, create=True) + del _client from mempalace.mcp_server import tool_add_drawer result = tool_add_drawer( @@ -211,8 +213,9 @@ class TestWriteTools: assert result["drawer_id"].startswith("drawer_test_wing_test_room_") def test_add_drawer_duplicate_detection(self, monkeypatch, config, palace_path, kg): - _patch_mcp_server(monkeypatch, config, palace_path, kg) - _get_collection(palace_path, create=True) + _patch_mcp_server(monkeypatch, config, kg) + _client, _col = _get_collection(palace_path, create=True) + del _client from mempalace.mcp_server import tool_add_drawer content = "This is a unique test memory about Rust ownership and borrowing." @@ -220,11 +223,11 @@ class TestWriteTools: assert result1["success"] is True result2 = tool_add_drawer(wing="w", room="r", content=content) - assert result2["success"] is False - assert result2["reason"] == "duplicate" + assert result2["success"] is True + assert result2["reason"] == "already_exists" def test_delete_drawer(self, monkeypatch, config, palace_path, seeded_collection, kg): - _patch_mcp_server(monkeypatch, config, palace_path, kg) + _patch_mcp_server(monkeypatch, config, kg) from mempalace.mcp_server import tool_delete_drawer result = tool_delete_drawer("drawer_proj_backend_aaa") @@ -232,14 +235,14 @@ class TestWriteTools: assert seeded_collection.count() == 3 def test_delete_drawer_not_found(self, monkeypatch, config, palace_path, seeded_collection, kg): - _patch_mcp_server(monkeypatch, config, palace_path, kg) + _patch_mcp_server(monkeypatch, config, kg) from mempalace.mcp_server import tool_delete_drawer result = tool_delete_drawer("nonexistent_drawer") assert result["success"] is False def test_check_duplicate(self, monkeypatch, config, palace_path, seeded_collection, kg): - _patch_mcp_server(monkeypatch, config, palace_path, kg) + _patch_mcp_server(monkeypatch, config, kg) from mempalace.mcp_server import tool_check_duplicate # Exact match text from seeded_collection should be flagged @@ -263,7 +266,7 @@ class TestWriteTools: class TestKGTools: def test_kg_add(self, monkeypatch, config, palace_path, kg): - _patch_mcp_server(monkeypatch, config, palace_path, kg) + _patch_mcp_server(monkeypatch, config, kg) from mempalace.mcp_server import tool_kg_add result = tool_kg_add( @@ -275,14 +278,14 @@ class TestKGTools: assert result["success"] is True def test_kg_query(self, monkeypatch, config, palace_path, seeded_kg): - _patch_mcp_server(monkeypatch, config, palace_path, seeded_kg) + _patch_mcp_server(monkeypatch, config, seeded_kg) from mempalace.mcp_server import tool_kg_query result = tool_kg_query(entity="Max") assert result["count"] > 0 def test_kg_invalidate(self, monkeypatch, config, palace_path, seeded_kg): - _patch_mcp_server(monkeypatch, config, palace_path, seeded_kg) + _patch_mcp_server(monkeypatch, config, seeded_kg) from mempalace.mcp_server import tool_kg_invalidate result = tool_kg_invalidate( @@ -294,14 +297,14 @@ class TestKGTools: assert result["success"] is True def test_kg_timeline(self, monkeypatch, config, palace_path, seeded_kg): - _patch_mcp_server(monkeypatch, config, palace_path, seeded_kg) + _patch_mcp_server(monkeypatch, config, seeded_kg) from mempalace.mcp_server import tool_kg_timeline result = tool_kg_timeline(entity="Alice") assert result["count"] > 0 def test_kg_stats(self, monkeypatch, config, palace_path, seeded_kg): - _patch_mcp_server(monkeypatch, config, palace_path, seeded_kg) + _patch_mcp_server(monkeypatch, config, seeded_kg) from mempalace.mcp_server import tool_kg_stats result = tool_kg_stats() @@ -313,8 +316,9 @@ class TestKGTools: class TestDiaryTools: def test_diary_write_and_read(self, monkeypatch, config, palace_path, kg): - _patch_mcp_server(monkeypatch, config, palace_path, kg) - _get_collection(palace_path, create=True) + _patch_mcp_server(monkeypatch, config, kg) + _client, _col = _get_collection(palace_path, create=True) + del _client from mempalace.mcp_server import tool_diary_write, tool_diary_read w = tool_diary_write( @@ -331,8 +335,9 @@ class TestDiaryTools: assert "authentication" in r["entries"][0]["content"] def test_diary_read_empty(self, monkeypatch, config, palace_path, kg): - _patch_mcp_server(monkeypatch, config, palace_path, kg) - _get_collection(palace_path, create=True) + _patch_mcp_server(monkeypatch, config, kg) + _client, _col = _get_collection(palace_path, create=True) + del _client from mempalace.mcp_server import tool_diary_read r = tool_diary_read(agent_name="Nobody")