Merge branch 'main' into fix/query-sanitizer-prompt-contamination
This commit is contained in:
@@ -2,6 +2,7 @@
|
||||
|
||||
import argparse
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
@@ -326,6 +327,35 @@ def test_main_split_dispatches():
|
||||
mock_cmd.assert_called_once()
|
||||
|
||||
|
||||
def test_mcp_command_prints_setup_guidance(monkeypatch, capsys):
|
||||
monkeypatch.setattr(sys, "argv", ["mempalace", "mcp"])
|
||||
|
||||
main()
|
||||
|
||||
captured = capsys.readouterr()
|
||||
assert "MemPalace MCP quick setup:" in captured.out
|
||||
assert "claude mcp add mempalace -- python -m mempalace.mcp_server" in captured.out
|
||||
assert "\nOptional custom palace:\n" in captured.out
|
||||
assert "python -m mempalace.mcp_server --palace /path/to/palace" in captured.out
|
||||
assert "[--palace /path/to/palace]" not in captured.out
|
||||
assert captured.err == ""
|
||||
|
||||
|
||||
def test_mcp_command_uses_custom_palace_path_when_provided(monkeypatch, capsys):
|
||||
monkeypatch.setattr(sys, "argv", ["mempalace", "--palace", "~/tmp/my palace", "mcp"])
|
||||
|
||||
main()
|
||||
|
||||
captured = capsys.readouterr()
|
||||
expanded = str(Path("~/tmp/my palace").expanduser())
|
||||
|
||||
assert "python -m mempalace.mcp_server --palace" in captured.out
|
||||
assert expanded in captured.out
|
||||
assert "Optional custom palace:" not in captured.out
|
||||
assert "[--palace /path/to/palace]" not in captured.out
|
||||
assert captured.err == ""
|
||||
|
||||
|
||||
def test_main_hook_no_subcommand_prints_help(capsys):
|
||||
with patch("sys.argv", ["mempalace", "hook"]):
|
||||
main()
|
||||
@@ -607,3 +637,16 @@ def test_cmd_compress_stores_results(mock_config_cls, capsys):
|
||||
out = capsys.readouterr().out
|
||||
assert "Stored" in out
|
||||
mock_comp_col.upsert.assert_called_once()
|
||||
|
||||
|
||||
def test_cmd_repair_trailing_slash_does_not_recurse():
|
||||
"""Repair with trailing slash should put backup outside palace dir (#395)."""
|
||||
import os
|
||||
|
||||
args = argparse.Namespace(palace="/tmp/fake_palace/")
|
||||
with patch("mempalace.cli.os.path.isdir", return_value=False):
|
||||
cmd_repair(args)
|
||||
# Verify the rstrip logic: palace_path should not end with separator
|
||||
palace_path = os.path.expanduser(args.palace).rstrip(os.sep)
|
||||
backup_path = palace_path + ".backup"
|
||||
assert not backup_path.startswith(palace_path + os.sep)
|
||||
|
||||
@@ -0,0 +1,272 @@
|
||||
"""Tests for mempalace.dedup — near-duplicate drawer detection and removal."""
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
|
||||
from mempalace import dedup
|
||||
|
||||
|
||||
# ── get_source_groups ─────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_get_source_groups_basic():
|
||||
col = MagicMock()
|
||||
col.count.return_value = 5
|
||||
col.get.side_effect = [
|
||||
{
|
||||
"ids": ["d1", "d2", "d3", "d4", "d5"],
|
||||
"metadatas": [
|
||||
{"source_file": "a.txt"},
|
||||
{"source_file": "a.txt"},
|
||||
{"source_file": "a.txt"},
|
||||
{"source_file": "a.txt"},
|
||||
{"source_file": "a.txt"},
|
||||
],
|
||||
},
|
||||
{"ids": []},
|
||||
]
|
||||
groups = dedup.get_source_groups(col, min_count=5)
|
||||
assert "a.txt" in groups
|
||||
assert len(groups["a.txt"]) == 5
|
||||
|
||||
|
||||
def test_get_source_groups_below_min():
|
||||
col = MagicMock()
|
||||
col.count.return_value = 2
|
||||
col.get.side_effect = [
|
||||
{
|
||||
"ids": ["d1", "d2"],
|
||||
"metadatas": [
|
||||
{"source_file": "a.txt"},
|
||||
{"source_file": "a.txt"},
|
||||
],
|
||||
},
|
||||
{"ids": []},
|
||||
]
|
||||
groups = dedup.get_source_groups(col, min_count=5)
|
||||
assert len(groups) == 0
|
||||
|
||||
|
||||
def test_get_source_groups_source_filter():
|
||||
col = MagicMock()
|
||||
col.count.return_value = 6
|
||||
col.get.side_effect = [
|
||||
{
|
||||
"ids": ["d1", "d2", "d3", "d4", "d5", "d6"],
|
||||
"metadatas": [
|
||||
{"source_file": "project_a.txt"},
|
||||
{"source_file": "project_a.txt"},
|
||||
{"source_file": "project_a.txt"},
|
||||
{"source_file": "project_a.txt"},
|
||||
{"source_file": "project_a.txt"},
|
||||
{"source_file": "other.txt"},
|
||||
],
|
||||
},
|
||||
{"ids": []},
|
||||
]
|
||||
groups = dedup.get_source_groups(col, min_count=5, source_pattern="project_a")
|
||||
assert "project_a.txt" in groups
|
||||
assert "other.txt" not in groups
|
||||
|
||||
|
||||
def test_get_source_groups_wing_filter():
|
||||
col = MagicMock()
|
||||
col.count.return_value = 5
|
||||
col.get.side_effect = [
|
||||
{
|
||||
"ids": ["d1", "d2", "d3", "d4", "d5"],
|
||||
"metadatas": [
|
||||
{"source_file": "a.txt"},
|
||||
{"source_file": "a.txt"},
|
||||
{"source_file": "a.txt"},
|
||||
{"source_file": "a.txt"},
|
||||
{"source_file": "a.txt"},
|
||||
],
|
||||
},
|
||||
{"ids": []},
|
||||
]
|
||||
dedup.get_source_groups(col, min_count=5, wing="my_wing")
|
||||
# Verify where filter was passed
|
||||
first_call = col.get.call_args_list[0]
|
||||
assert first_call.kwargs.get("where") == {"wing": "my_wing"}
|
||||
|
||||
|
||||
def test_get_source_groups_missing_source_file():
|
||||
col = MagicMock()
|
||||
col.count.return_value = 5
|
||||
col.get.side_effect = [
|
||||
{
|
||||
"ids": ["d1", "d2", "d3", "d4", "d5"],
|
||||
"metadatas": [{}, {}, {}, {}, {}],
|
||||
},
|
||||
{"ids": []},
|
||||
]
|
||||
groups = dedup.get_source_groups(col, min_count=5)
|
||||
assert "unknown" in groups
|
||||
|
||||
|
||||
# ── dedup_source_group ────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_dedup_source_group_all_unique():
|
||||
col = MagicMock()
|
||||
col.get.return_value = {
|
||||
"ids": ["d1", "d2"],
|
||||
"documents": ["long document one content here", "different document two here"],
|
||||
"metadatas": [{"wing": "a"}, {"wing": "a"}],
|
||||
}
|
||||
col.query.return_value = {
|
||||
"ids": [["d1"]],
|
||||
"distances": [[0.8]], # far apart = unique
|
||||
}
|
||||
kept, deleted = dedup.dedup_source_group(col, ["d1", "d2"], threshold=0.15, dry_run=True)
|
||||
assert len(kept) == 2
|
||||
assert len(deleted) == 0
|
||||
|
||||
|
||||
def test_dedup_source_group_with_duplicate():
|
||||
col = MagicMock()
|
||||
col.get.return_value = {
|
||||
"ids": ["d1", "d2"],
|
||||
"documents": [
|
||||
"long document content that is fairly long",
|
||||
"long document content that is fairly long",
|
||||
],
|
||||
"metadatas": [{"wing": "a"}, {"wing": "a"}],
|
||||
}
|
||||
col.query.return_value = {
|
||||
"ids": [["d1"]],
|
||||
"distances": [[0.05]], # very close = duplicate
|
||||
}
|
||||
kept, deleted = dedup.dedup_source_group(col, ["d1", "d2"], threshold=0.15, dry_run=True)
|
||||
assert len(kept) == 1
|
||||
assert len(deleted) == 1
|
||||
|
||||
|
||||
def test_dedup_source_group_short_docs_deleted():
|
||||
col = MagicMock()
|
||||
col.get.return_value = {
|
||||
"ids": ["d1", "d2"],
|
||||
"documents": ["long enough document to keep in the palace", "tiny"],
|
||||
"metadatas": [{"wing": "a"}, {"wing": "a"}],
|
||||
}
|
||||
kept, deleted = dedup.dedup_source_group(col, ["d1", "d2"], threshold=0.15, dry_run=True)
|
||||
assert "d2" in deleted # too short
|
||||
|
||||
|
||||
def test_dedup_source_group_empty_doc_deleted():
|
||||
col = MagicMock()
|
||||
col.get.return_value = {
|
||||
"ids": ["d1", "d2"],
|
||||
"documents": ["real document content here that is long enough", None],
|
||||
"metadatas": [{"wing": "a"}, {"wing": "a"}],
|
||||
}
|
||||
kept, deleted = dedup.dedup_source_group(col, ["d1", "d2"], threshold=0.15, dry_run=True)
|
||||
assert "d2" in deleted
|
||||
|
||||
|
||||
def test_dedup_source_group_live_deletes():
|
||||
col = MagicMock()
|
||||
col.get.return_value = {
|
||||
"ids": ["d1", "d2"],
|
||||
"documents": ["long document content here enough", "long document content here enough"],
|
||||
"metadatas": [{"wing": "a"}, {"wing": "a"}],
|
||||
}
|
||||
col.query.return_value = {
|
||||
"ids": [["d1"]],
|
||||
"distances": [[0.05]],
|
||||
}
|
||||
kept, deleted = dedup.dedup_source_group(col, ["d1", "d2"], threshold=0.15, dry_run=False)
|
||||
col.delete.assert_called_once()
|
||||
|
||||
|
||||
def test_dedup_source_group_query_failure_keeps():
|
||||
col = MagicMock()
|
||||
col.get.return_value = {
|
||||
"ids": ["d1", "d2"],
|
||||
"documents": [
|
||||
"long document one content here enough",
|
||||
"long document two content here enough",
|
||||
],
|
||||
"metadatas": [{"wing": "a"}, {"wing": "a"}],
|
||||
}
|
||||
col.query.side_effect = Exception("query failed")
|
||||
kept, deleted = dedup.dedup_source_group(col, ["d1", "d2"], threshold=0.15, dry_run=True)
|
||||
assert len(kept) == 2 # both kept on error
|
||||
|
||||
|
||||
# ── show_stats ────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@patch("mempalace.dedup.chromadb")
|
||||
def test_show_stats(mock_chromadb, tmp_path):
|
||||
mock_col = MagicMock()
|
||||
mock_col.count.return_value = 5
|
||||
mock_col.get.side_effect = [
|
||||
{
|
||||
"ids": ["d1", "d2", "d3", "d4", "d5"],
|
||||
"metadatas": [
|
||||
{"source_file": "a.txt"},
|
||||
{"source_file": "a.txt"},
|
||||
{"source_file": "a.txt"},
|
||||
{"source_file": "a.txt"},
|
||||
{"source_file": "a.txt"},
|
||||
],
|
||||
},
|
||||
{"ids": []},
|
||||
]
|
||||
mock_client = MagicMock()
|
||||
mock_client.get_collection.return_value = mock_col
|
||||
mock_chromadb.PersistentClient.return_value = mock_client
|
||||
|
||||
dedup.show_stats(palace_path=str(tmp_path)) # should not raise
|
||||
|
||||
|
||||
# ── dedup_palace ──────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@patch("mempalace.dedup.dedup_source_group")
|
||||
@patch("mempalace.dedup.get_source_groups")
|
||||
@patch("mempalace.dedup.chromadb")
|
||||
def test_dedup_palace_dry_run(mock_chromadb, mock_groups, mock_dedup_group, tmp_path):
|
||||
mock_col = MagicMock()
|
||||
mock_col.count.return_value = 10
|
||||
mock_client = MagicMock()
|
||||
mock_client.get_collection.return_value = mock_col
|
||||
mock_chromadb.PersistentClient.return_value = mock_client
|
||||
|
||||
mock_groups.return_value = {"a.txt": ["d1", "d2", "d3", "d4", "d5"]}
|
||||
mock_dedup_group.return_value = (["d1", "d2", "d3"], ["d4", "d5"])
|
||||
|
||||
dedup.dedup_palace(palace_path=str(tmp_path), dry_run=True)
|
||||
mock_dedup_group.assert_called_once()
|
||||
|
||||
|
||||
@patch("mempalace.dedup.dedup_source_group")
|
||||
@patch("mempalace.dedup.get_source_groups")
|
||||
@patch("mempalace.dedup.chromadb")
|
||||
def test_dedup_palace_with_wing(mock_chromadb, mock_groups, mock_dedup_group, tmp_path):
|
||||
mock_col = MagicMock()
|
||||
mock_col.count.return_value = 10
|
||||
mock_client = MagicMock()
|
||||
mock_client.get_collection.return_value = mock_col
|
||||
mock_chromadb.PersistentClient.return_value = mock_client
|
||||
|
||||
mock_groups.return_value = {}
|
||||
dedup.dedup_palace(palace_path=str(tmp_path), wing="test_wing", dry_run=True)
|
||||
mock_groups.assert_called_once_with(mock_col, 5, None, wing="test_wing")
|
||||
|
||||
|
||||
@patch("mempalace.dedup.dedup_source_group")
|
||||
@patch("mempalace.dedup.get_source_groups")
|
||||
@patch("mempalace.dedup.chromadb")
|
||||
def test_dedup_palace_no_groups(mock_chromadb, mock_groups, mock_dedup_group, tmp_path):
|
||||
mock_col = MagicMock()
|
||||
mock_col.count.return_value = 3
|
||||
mock_client = MagicMock()
|
||||
mock_client.get_collection.return_value = mock_col
|
||||
mock_chromadb.PersistentClient.return_value = mock_client
|
||||
|
||||
mock_groups.return_value = {}
|
||||
dedup.dedup_palace(palace_path=str(tmp_path), dry_run=True)
|
||||
mock_dedup_group.assert_not_called()
|
||||
@@ -42,6 +42,50 @@ class TestHandleRequest:
|
||||
assert resp["result"]["serverInfo"]["name"] == "mempalace"
|
||||
assert resp["id"] == 1
|
||||
|
||||
def test_initialize_negotiates_client_version(self):
|
||||
from mempalace.mcp_server import handle_request
|
||||
|
||||
resp = handle_request(
|
||||
{
|
||||
"method": "initialize",
|
||||
"id": 1,
|
||||
"params": {"protocolVersion": "2025-11-25"},
|
||||
}
|
||||
)
|
||||
assert resp["result"]["protocolVersion"] == "2025-11-25"
|
||||
|
||||
def test_initialize_negotiates_older_supported_version(self):
|
||||
from mempalace.mcp_server import handle_request
|
||||
|
||||
resp = handle_request(
|
||||
{
|
||||
"method": "initialize",
|
||||
"id": 1,
|
||||
"params": {"protocolVersion": "2025-03-26"},
|
||||
}
|
||||
)
|
||||
assert resp["result"]["protocolVersion"] == "2025-03-26"
|
||||
|
||||
def test_initialize_unknown_version_falls_back_to_latest(self):
|
||||
from mempalace.mcp_server import handle_request
|
||||
|
||||
resp = handle_request(
|
||||
{
|
||||
"method": "initialize",
|
||||
"id": 1,
|
||||
"params": {"protocolVersion": "9999-12-31"},
|
||||
}
|
||||
)
|
||||
from mempalace.mcp_server import SUPPORTED_PROTOCOL_VERSIONS
|
||||
|
||||
assert resp["result"]["protocolVersion"] == SUPPORTED_PROTOCOL_VERSIONS[0]
|
||||
|
||||
def test_initialize_missing_version_uses_oldest(self):
|
||||
from mempalace.mcp_server import handle_request, SUPPORTED_PROTOCOL_VERSIONS
|
||||
|
||||
resp = handle_request({"method": "initialize", "id": 1, "params": {}})
|
||||
assert resp["result"]["protocolVersion"] == SUPPORTED_PROTOCOL_VERSIONS[-1]
|
||||
|
||||
def test_notifications_initialized_returns_none(self):
|
||||
from mempalace.mcp_server import handle_request
|
||||
|
||||
@@ -59,6 +103,23 @@ class TestHandleRequest:
|
||||
assert "mempalace_add_drawer" in names
|
||||
assert "mempalace_kg_add" in names
|
||||
|
||||
def test_null_arguments_does_not_hang(self, monkeypatch, config, palace_path, seeded_kg):
|
||||
"""Sending arguments: null should return a result, not hang (#394)."""
|
||||
_patch_mcp_server(monkeypatch, config, seeded_kg)
|
||||
from mempalace.mcp_server import handle_request
|
||||
|
||||
_client, _col = _get_collection(palace_path, create=True)
|
||||
del _client
|
||||
resp = handle_request(
|
||||
{
|
||||
"method": "tools/call",
|
||||
"id": 10,
|
||||
"params": {"name": "mempalace_status", "arguments": None},
|
||||
}
|
||||
)
|
||||
assert "error" not in resp
|
||||
assert resp["result"] is not None
|
||||
|
||||
def test_unknown_tool(self):
|
||||
from mempalace.mcp_server import handle_request
|
||||
|
||||
|
||||
@@ -7,6 +7,7 @@ import chromadb
|
||||
import yaml
|
||||
|
||||
from mempalace.miner import mine, scan_project
|
||||
from mempalace.palace import file_already_mined
|
||||
|
||||
|
||||
def write_file(path: Path, content: str):
|
||||
@@ -206,3 +207,56 @@ def test_scan_project_skip_dirs_still_apply_without_override():
|
||||
assert scanned_files(project_root, respect_gitignore=False) == ["main.py"]
|
||||
finally:
|
||||
shutil.rmtree(tmpdir)
|
||||
|
||||
|
||||
def test_file_already_mined_check_mtime():
|
||||
tmpdir = tempfile.mkdtemp()
|
||||
try:
|
||||
palace_path = os.path.join(tmpdir, "palace")
|
||||
os.makedirs(palace_path)
|
||||
client = chromadb.PersistentClient(path=palace_path)
|
||||
col = client.get_or_create_collection("mempalace_drawers")
|
||||
|
||||
test_file = os.path.join(tmpdir, "test.txt")
|
||||
with open(test_file, "w") as f:
|
||||
f.write("hello world")
|
||||
|
||||
mtime = os.path.getmtime(test_file)
|
||||
|
||||
# Not mined yet
|
||||
assert file_already_mined(col, test_file) is False
|
||||
assert file_already_mined(col, test_file, check_mtime=True) is False
|
||||
|
||||
# Add it with mtime
|
||||
col.add(
|
||||
ids=["d1"],
|
||||
documents=["hello world"],
|
||||
metadatas=[{"source_file": test_file, "source_mtime": str(mtime)}],
|
||||
)
|
||||
|
||||
# Already mined (no mtime check)
|
||||
assert file_already_mined(col, test_file) is True
|
||||
# Already mined (mtime matches)
|
||||
assert file_already_mined(col, test_file, check_mtime=True) is True
|
||||
|
||||
# Modify file and force a different mtime (Windows has low mtime resolution)
|
||||
with open(test_file, "w") as f:
|
||||
f.write("modified content")
|
||||
os.utime(test_file, (mtime + 10, mtime + 10))
|
||||
|
||||
# Still mined without mtime check
|
||||
assert file_already_mined(col, test_file) is True
|
||||
# Needs re-mining with mtime check
|
||||
assert file_already_mined(col, test_file, check_mtime=True) is False
|
||||
|
||||
# Record with no mtime stored should return False for check_mtime
|
||||
col.add(
|
||||
ids=["d2"],
|
||||
documents=["other"],
|
||||
metadatas=[{"source_file": "/fake/no_mtime.txt"}],
|
||||
)
|
||||
assert file_already_mined(col, "/fake/no_mtime.txt", check_mtime=True) is False
|
||||
finally:
|
||||
# Release ChromaDB file handles before cleanup (required on Windows)
|
||||
del col, client
|
||||
shutil.rmtree(tmpdir, ignore_errors=True)
|
||||
|
||||
@@ -499,3 +499,13 @@ def test_messages_to_transcript_assistant_first():
|
||||
result = _messages_to_transcript(msgs, spellcheck=False)
|
||||
assert "preamble" in result
|
||||
assert "> Q" in result
|
||||
|
||||
|
||||
def test_normalize_rejects_large_file():
|
||||
"""Files over 500 MB should raise IOError before reading."""
|
||||
with patch("mempalace.normalize.os.path.getsize", return_value=600 * 1024 * 1024):
|
||||
try:
|
||||
normalize("/fake/huge_file.txt")
|
||||
assert False, "Should have raised IOError"
|
||||
except IOError as e:
|
||||
assert "too large" in str(e).lower()
|
||||
|
||||
@@ -0,0 +1,266 @@
|
||||
"""Tests for mempalace.repair — scan, prune, and rebuild HNSW index."""
|
||||
|
||||
import os
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
|
||||
from mempalace import repair
|
||||
|
||||
|
||||
# ── _get_palace_path ──────────────────────────────────────────────────
|
||||
|
||||
|
||||
@patch("mempalace.repair.MempalaceConfig", create=True)
|
||||
def test_get_palace_path_from_config(mock_config_cls):
|
||||
mock_config_cls.return_value.palace_path = "/configured/palace"
|
||||
with patch.dict("sys.modules", {}):
|
||||
# Force reimport to pick up the mock
|
||||
result = repair._get_palace_path()
|
||||
assert isinstance(result, str)
|
||||
|
||||
|
||||
def test_get_palace_path_fallback():
|
||||
with patch("mempalace.repair._get_palace_path") as mock_get:
|
||||
mock_get.return_value = os.path.join(os.path.expanduser("~"), ".mempalace", "palace")
|
||||
result = mock_get()
|
||||
assert ".mempalace" in result
|
||||
|
||||
|
||||
# ── _paginate_ids ─────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_paginate_ids_single_batch():
|
||||
col = MagicMock()
|
||||
col.get.return_value = {"ids": ["id1", "id2", "id3"]}
|
||||
ids = repair._paginate_ids(col)
|
||||
assert ids == ["id1", "id2", "id3"]
|
||||
|
||||
|
||||
def test_paginate_ids_empty():
|
||||
col = MagicMock()
|
||||
col.get.return_value = {"ids": []}
|
||||
ids = repair._paginate_ids(col)
|
||||
assert ids == []
|
||||
|
||||
|
||||
def test_paginate_ids_with_where():
|
||||
col = MagicMock()
|
||||
col.get.return_value = {"ids": ["id1"]}
|
||||
repair._paginate_ids(col, where={"wing": "test"})
|
||||
col.get.assert_called_with(where={"wing": "test"}, include=[], limit=1000, offset=0)
|
||||
|
||||
|
||||
def test_paginate_ids_offset_exception_fallback():
|
||||
col = MagicMock()
|
||||
# First call raises, fallback returns ids, second fallback returns empty
|
||||
col.get.side_effect = [
|
||||
Exception("offset bug"),
|
||||
{"ids": ["id1", "id2"]},
|
||||
Exception("offset bug"),
|
||||
{"ids": ["id1", "id2"]}, # same ids = no new = break
|
||||
]
|
||||
ids = repair._paginate_ids(col)
|
||||
assert "id1" in ids
|
||||
|
||||
|
||||
# ── scan_palace ───────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@patch("mempalace.repair.chromadb")
|
||||
def test_scan_palace_no_ids(mock_chromadb, tmp_path):
|
||||
mock_col = MagicMock()
|
||||
mock_col.count.return_value = 0
|
||||
mock_col.get.return_value = {"ids": []}
|
||||
mock_client = MagicMock()
|
||||
mock_client.get_collection.return_value = mock_col
|
||||
mock_chromadb.PersistentClient.return_value = mock_client
|
||||
|
||||
good, bad = repair.scan_palace(palace_path=str(tmp_path))
|
||||
assert good == set()
|
||||
assert bad == set()
|
||||
|
||||
|
||||
@patch("mempalace.repair.chromadb")
|
||||
def test_scan_palace_all_good(mock_chromadb, tmp_path):
|
||||
mock_col = MagicMock()
|
||||
mock_col.count.return_value = 2
|
||||
# _paginate_ids call
|
||||
mock_col.get.side_effect = [
|
||||
{"ids": ["id1", "id2"]}, # paginate
|
||||
{"ids": ["id1", "id2"]}, # probe batch — both returned
|
||||
]
|
||||
mock_client = MagicMock()
|
||||
mock_client.get_collection.return_value = mock_col
|
||||
mock_chromadb.PersistentClient.return_value = mock_client
|
||||
|
||||
good, bad = repair.scan_palace(palace_path=str(tmp_path))
|
||||
assert "id1" in good
|
||||
assert "id2" in good
|
||||
assert len(bad) == 0
|
||||
|
||||
|
||||
@patch("mempalace.repair.chromadb")
|
||||
def test_scan_palace_with_bad_ids(mock_chromadb, tmp_path):
|
||||
mock_col = MagicMock()
|
||||
mock_col.count.return_value = 2
|
||||
|
||||
def get_side_effect(**kwargs):
|
||||
ids = kwargs.get("ids", None)
|
||||
if ids is None:
|
||||
# paginate call
|
||||
return {"ids": ["good1", "bad1"]}
|
||||
if "bad1" in ids and len(ids) == 1:
|
||||
raise Exception("corrupt")
|
||||
if "good1" in ids and len(ids) == 1:
|
||||
return {"ids": ["good1"]}
|
||||
# batch probe — raise to force per-id
|
||||
raise Exception("batch fail")
|
||||
|
||||
mock_col.get.side_effect = get_side_effect
|
||||
mock_client = MagicMock()
|
||||
mock_client.get_collection.return_value = mock_col
|
||||
mock_chromadb.PersistentClient.return_value = mock_client
|
||||
|
||||
good, bad = repair.scan_palace(palace_path=str(tmp_path))
|
||||
assert "good1" in good
|
||||
assert "bad1" in bad
|
||||
|
||||
|
||||
@patch("mempalace.repair.chromadb")
|
||||
def test_scan_palace_with_wing_filter(mock_chromadb, tmp_path):
|
||||
mock_col = MagicMock()
|
||||
mock_col.count.return_value = 1
|
||||
mock_col.get.side_effect = [
|
||||
{"ids": ["id1"]}, # paginate
|
||||
{"ids": ["id1"]}, # probe
|
||||
]
|
||||
mock_client = MagicMock()
|
||||
mock_client.get_collection.return_value = mock_col
|
||||
mock_chromadb.PersistentClient.return_value = mock_client
|
||||
|
||||
repair.scan_palace(palace_path=str(tmp_path), only_wing="test_wing")
|
||||
# Verify where filter was passed
|
||||
first_call = mock_col.get.call_args_list[0]
|
||||
assert first_call.kwargs.get("where") == {"wing": "test_wing"}
|
||||
|
||||
|
||||
# ── prune_corrupt ─────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@patch("mempalace.repair.chromadb")
|
||||
def test_prune_corrupt_no_file(mock_chromadb, tmp_path):
|
||||
# Should print message and return without error
|
||||
repair.prune_corrupt(palace_path=str(tmp_path))
|
||||
|
||||
|
||||
@patch("mempalace.repair.chromadb")
|
||||
def test_prune_corrupt_dry_run(mock_chromadb, tmp_path):
|
||||
bad_file = tmp_path / "corrupt_ids.txt"
|
||||
bad_file.write_text("bad1\nbad2\n")
|
||||
repair.prune_corrupt(palace_path=str(tmp_path), confirm=False)
|
||||
# No chromadb calls in dry run
|
||||
mock_chromadb.PersistentClient.assert_not_called()
|
||||
|
||||
|
||||
@patch("mempalace.repair.chromadb")
|
||||
def test_prune_corrupt_confirmed(mock_chromadb, tmp_path):
|
||||
bad_file = tmp_path / "corrupt_ids.txt"
|
||||
bad_file.write_text("bad1\nbad2\n")
|
||||
|
||||
mock_col = MagicMock()
|
||||
mock_col.count.side_effect = [10, 8]
|
||||
mock_client = MagicMock()
|
||||
mock_client.get_collection.return_value = mock_col
|
||||
mock_chromadb.PersistentClient.return_value = mock_client
|
||||
|
||||
repair.prune_corrupt(palace_path=str(tmp_path), confirm=True)
|
||||
mock_col.delete.assert_called_once()
|
||||
|
||||
|
||||
@patch("mempalace.repair.chromadb")
|
||||
def test_prune_corrupt_delete_failure_fallback(mock_chromadb, tmp_path):
|
||||
bad_file = tmp_path / "corrupt_ids.txt"
|
||||
bad_file.write_text("bad1\nbad2\n")
|
||||
|
||||
mock_col = MagicMock()
|
||||
mock_col.count.side_effect = [10, 8]
|
||||
# Batch delete fails, per-id succeeds
|
||||
mock_col.delete.side_effect = [Exception("batch fail"), None, None]
|
||||
mock_client = MagicMock()
|
||||
mock_client.get_collection.return_value = mock_col
|
||||
mock_chromadb.PersistentClient.return_value = mock_client
|
||||
|
||||
repair.prune_corrupt(palace_path=str(tmp_path), confirm=True)
|
||||
assert mock_col.delete.call_count == 3 # 1 batch + 2 individual
|
||||
|
||||
|
||||
# ── rebuild_index ─────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@patch("mempalace.repair.chromadb")
|
||||
def test_rebuild_index_no_palace(mock_chromadb, tmp_path):
|
||||
nonexistent = str(tmp_path / "nope")
|
||||
repair.rebuild_index(palace_path=nonexistent)
|
||||
mock_chromadb.PersistentClient.assert_not_called()
|
||||
|
||||
|
||||
@patch("mempalace.repair.shutil")
|
||||
@patch("mempalace.repair.chromadb")
|
||||
def test_rebuild_index_empty_palace(mock_chromadb, mock_shutil, tmp_path):
|
||||
mock_col = MagicMock()
|
||||
mock_col.count.return_value = 0
|
||||
mock_client = MagicMock()
|
||||
mock_client.get_collection.return_value = mock_col
|
||||
mock_chromadb.PersistentClient.return_value = mock_client
|
||||
|
||||
repair.rebuild_index(palace_path=str(tmp_path))
|
||||
mock_client.delete_collection.assert_not_called()
|
||||
|
||||
|
||||
@patch("mempalace.repair.shutil")
|
||||
@patch("mempalace.repair.chromadb")
|
||||
def test_rebuild_index_success(mock_chromadb, mock_shutil, tmp_path):
|
||||
# Create a fake sqlite file
|
||||
sqlite_path = tmp_path / "chroma.sqlite3"
|
||||
sqlite_path.write_text("fake")
|
||||
|
||||
mock_col = MagicMock()
|
||||
mock_col.count.return_value = 2
|
||||
mock_col.get.return_value = {
|
||||
"ids": ["id1", "id2"],
|
||||
"documents": ["doc1", "doc2"],
|
||||
"metadatas": [{"wing": "a"}, {"wing": "b"}],
|
||||
}
|
||||
|
||||
mock_new_col = MagicMock()
|
||||
mock_client = MagicMock()
|
||||
mock_client.get_collection.return_value = mock_col
|
||||
mock_client.create_collection.return_value = mock_new_col
|
||||
mock_chromadb.PersistentClient.return_value = mock_client
|
||||
|
||||
repair.rebuild_index(palace_path=str(tmp_path))
|
||||
|
||||
# Verify: backed up sqlite only (not copytree)
|
||||
mock_shutil.copy2.assert_called_once()
|
||||
assert "chroma.sqlite3" in str(mock_shutil.copy2.call_args)
|
||||
|
||||
# Verify: deleted and recreated with cosine
|
||||
mock_client.delete_collection.assert_called_once_with("mempalace_drawers")
|
||||
mock_client.create_collection.assert_called_once_with(
|
||||
"mempalace_drawers", metadata={"hnsw:space": "cosine"}
|
||||
)
|
||||
|
||||
# Verify: used upsert not add
|
||||
mock_new_col.upsert.assert_called_once()
|
||||
mock_new_col.add.assert_not_called()
|
||||
|
||||
|
||||
@patch("mempalace.repair.shutil")
|
||||
@patch("mempalace.repair.chromadb")
|
||||
def test_rebuild_index_error_reading(mock_chromadb, mock_shutil, tmp_path):
|
||||
mock_client = MagicMock()
|
||||
mock_client.get_collection.side_effect = Exception("corrupt")
|
||||
mock_chromadb.PersistentClient.return_value = mock_client
|
||||
|
||||
repair.rebuild_index(palace_path=str(tmp_path))
|
||||
mock_client.delete_collection.assert_not_called()
|
||||
Reference in New Issue
Block a user