Merge branch 'main' into fix/issue-347-codex-hook-message-counting

This commit is contained in:
Ben Sigman
2026-04-10 09:25:58 -07:00
committed by GitHub
6 changed files with 1087 additions and 1 deletions
+1 -1
View File
@@ -334,7 +334,7 @@ def mine_convos(
room_counts[chunk_room] += 1 room_counts[chunk_room] += 1
drawer_id = f"drawer_{wing}_{chunk_room}_{hashlib.sha256((source_file + str(chunk['chunk_index'])).encode()).hexdigest()[:24]}" drawer_id = f"drawer_{wing}_{chunk_room}_{hashlib.sha256((source_file + str(chunk['chunk_index'])).encode()).hexdigest()[:24]}"
try: try:
collection.add( collection.upsert(
documents=[chunk["content"]], documents=[chunk["content"]],
ids=[drawer_id], ids=[drawer_id],
metadatas=[ metadatas=[
+239
View File
@@ -0,0 +1,239 @@
"""
dedup.py — Detect and remove near-duplicate drawers
====================================================
When the same files are mined multiple times, near-identical drawers
accumulate. This module finds drawers from the same source_file that
are too similar (cosine distance < threshold), keeps the longest/richest
version, and deletes the rest.
No API calls — uses ChromaDB's built-in embedding similarity.
Usage (standalone):
python -m mempalace.dedup # dedup all
python -m mempalace.dedup --dry-run # preview only
python -m mempalace.dedup --threshold 0.10 # stricter (near-identical only)
python -m mempalace.dedup --threshold 0.35 # looser (catches paraphrased content)
python -m mempalace.dedup --wing my_project # scope to one wing
python -m mempalace.dedup --stats # stats only
python -m mempalace.dedup --source "my_project" # filter by source
Usage (from CLI):
mempalace dedup [--dry-run] [--threshold 0.15] [--stats]
"""
import argparse
import os
import time
from collections import defaultdict
import chromadb
# Name of the ChromaDB collection that holds the drawers.
COLLECTION_NAME = "mempalace_drawers"
# Cosine DISTANCE threshold (not similarity). Lower = stricter.
# 0.15 = ~85% cosine similarity — catches near-identical chunks.
# For looser dedup of paraphrased content, try 0.3–0.4.
DEFAULT_THRESHOLD = 0.15
# Default minimum group size: sources with fewer drawers are not checked.
MIN_DRAWERS_TO_CHECK = 5
def _get_palace_path():
    """Return the palace directory to operate on.

    The configured path wins; if the config module cannot be imported or
    fails while loading, fall back to ``~/.mempalace/palace``.
    """
    try:
        from .config import MempalaceConfig

        configured = MempalaceConfig().palace_path
    except Exception:
        home = os.path.expanduser("~")
        return os.path.join(home, ".mempalace", "palace")
    return configured
def get_source_groups(col, min_count=MIN_DRAWERS_TO_CHECK, source_pattern=None, wing=None):
    """Group drawer IDs by their ``source_file`` metadata.

    Pages through the collection (metadata only) and buckets every drawer ID
    under its ``source_file``. Drawers with no metadata, or with a missing or
    empty ``source_file``, are bucketed under ``"unknown"``.

    Args:
        col: ChromaDB collection (anything exposing ``count()`` and ``get()``).
        min_count: Smallest group size worth returning.
        source_pattern: Optional case-insensitive substring; only sources
            containing it are kept.
        wing: Optional wing name. When given, only drawers in that wing are
            read. When omitted, drawers from every wing are grouped together,
            so a source mined into several wings lands in a single group.

    Returns:
        dict mapping source_file -> list of drawer IDs, restricted to groups
        with at least ``min_count`` entries.
    """
    total = col.count()
    # Lower-case the pattern once instead of once per drawer.
    needle = source_pattern.lower() if source_pattern else None
    groups = defaultdict(list)
    offset = 0
    batch_size = 1000
    while offset < total:
        kwargs = {"limit": batch_size, "offset": offset, "include": ["metadatas"]}
        if wing:
            kwargs["where"] = {"wing": wing}
        batch = col.get(**kwargs)
        ids = batch["ids"]
        if not ids:
            break
        metadatas = batch.get("metadatas") or [None] * len(ids)
        for did, meta in zip(ids, metadatas):
            # A drawer stored without metadata comes back as None, and
            # source_file itself may be absent or None.
            src = (meta or {}).get("source_file") or "unknown"
            if needle and needle not in src.lower():
                continue
            groups[src].append(did)
        offset += len(ids)
    return {src: ids for src, ids in groups.items() if len(ids) >= min_count}
def dedup_source_group(
    col, drawer_ids, threshold=DEFAULT_THRESHOLD, dry_run=True, max_neighbors=25
):
    """Dedup drawers within one source_file group.

    Greedy: drawers are visited longest document first; each one is kept
    unless a drawer that was already kept lies within ``threshold`` cosine
    distance of it. Drawers with an empty document or fewer than 20
    characters are always scheduled for deletion.

    Args:
        col: ChromaDB collection holding the drawers.
        drawer_ids: IDs of the drawers in this group.
        threshold: Cosine distance below which two drawers are duplicates.
        dry_run: When True nothing is deleted; the result is only reported.
        max_neighbors: Upper bound on neighbours fetched per similarity
            query. The drawer being tested is itself stored in the collection
            and comes back as its own nearest neighbour, and not-yet-processed
            duplicates rank alongside the kept ones, so the query has to ask
            for more than ``len(kept)`` results to see the kept drawers.

    Returns:
        ``(kept_ids, deleted_ids)`` — two lists of drawer IDs.
    """
    data = col.get(ids=drawer_ids, include=["documents", "metadatas"])
    items = list(zip(data["ids"], data["documents"], data["metadatas"]))
    # Longest first, so the richest version of a chunk is the one that survives.
    items.sort(key=lambda item: len(item[1] or ""), reverse=True)

    kept = []
    kept_lookup = set()  # mirrors ``kept`` for O(1) membership tests
    to_delete = []
    n_results = max(1, min(len(items), max_neighbors))

    for did, doc, meta in items:
        if not doc or len(doc) < 20:
            to_delete.append(did)
            continue
        if not kept:
            kept.append(did)
            kept_lookup.add(did)
            continue

        query_kwargs = {
            "query_texts": [doc],
            "n_results": n_results,
            "include": ["distances"],
        }
        # Scope the search to this source file when we can, so drawers from
        # unrelated files cannot crowd the kept ones out of the result list.
        src = (meta or {}).get("source_file")
        if isinstance(src, str) and src:
            query_kwargs["where"] = {"source_file": src}

        try:
            results = col.query(**query_kwargs)
            dists = results["distances"][0] if results["distances"] else []
            is_dup = any(
                rid in kept_lookup and dist < threshold
                for rid, dist in zip(results["ids"][0], dists)
            )
        except Exception:
            # Best effort: a drawer whose similarity could not be checked is
            # kept rather than deleted.
            is_dup = False

        if is_dup:
            to_delete.append(did)
        else:
            kept.append(did)
            kept_lookup.add(did)

    if to_delete and not dry_run:
        # Delete in slices of 500 IDs per call.
        for i in range(0, len(to_delete), 500):
            col.delete(ids=to_delete[i : i + 500])
    return kept, to_delete
def show_stats(palace_path=None):
    """Print duplication statistics for the palace; nothing is modified.

    Reports how many sources have enough drawers to be checked, the fifteen
    largest sources, and a rough duplicate estimate (40% of every group with
    more than 20 drawers).
    """
    palace_path = palace_path or _get_palace_path()
    col = chromadb.PersistentClient(path=palace_path).get_collection(COLLECTION_NAME)
    groups = get_source_groups(col)

    sizes = [len(ids) for ids in groups.values()]
    print(f"\n Sources with {MIN_DRAWERS_TO_CHECK}+ drawers: {len(groups)}")
    print(f" Total drawers in those sources: {sum(sizes):,}")

    print("\n Top 15 by drawer count:")
    ranked = sorted(groups.items(), key=lambda pair: len(pair[1]), reverse=True)
    for src, ids in ranked[:15]:
        print(f" {len(ids):4d} {src[:65]}")

    estimated_dups = sum(int(size * 0.4) for size in sizes if size > 20)
    print(f"\n Estimated duplicates (groups > 20): ~{estimated_dups:,}")
def dedup_palace(
    palace_path=None,
    threshold=DEFAULT_THRESHOLD,
    dry_run=True,
    source_pattern=None,
    min_count=MIN_DRAWERS_TO_CHECK,
    wing=None,
):
    """Main entry point: deduplicate near-identical drawers across the palace.

    Args:
        palace_path: Palace directory; resolved from config when omitted.
        threshold: Cosine distance below which two drawers are duplicates.
        dry_run: When True, report what would be removed without deleting.
        source_pattern: Optional case-insensitive substring filter on
            source_file.
        min_count: Skip sources with fewer drawers than this.
        wing: Optional wing to restrict the run to.

    Prints a per-source line for every group that lost drawers, followed by
    a summary. Returns None.
    """
    palace_path = palace_path or _get_palace_path()
    print(f"\n{'=' * 55}")
    print(" MemPalace Deduplicator")
    print(f"{'=' * 55}")
    client = chromadb.PersistentClient(path=palace_path)
    col = client.get_collection(COLLECTION_NAME)
    print(f" Palace: {palace_path}")
    print(f" Drawers: {col.count():,}")
    print(f" Threshold: {threshold}")
    print(f" Mode: {'DRY RUN' if dry_run else 'LIVE'}")
    # The rule character had been lost, leaving '' * 55 (an empty line).
    print(f"{'─' * 55}")
    if wing:
        print(f" Wing: {wing}")

    groups = get_source_groups(col, min_count, source_pattern, wing=wing)
    print(f"\n Sources to check: {len(groups)}")

    t0 = time.time()
    total_kept = 0
    total_deleted = 0
    # Largest groups first.
    sorted_groups = sorted(groups.items(), key=lambda x: len(x[1]), reverse=True)
    for position, (src, drawer_ids) in enumerate(sorted_groups, start=1):
        kept, deleted = dedup_source_group(col, drawer_ids, threshold, dry_run)
        total_kept += len(kept)
        total_deleted += len(deleted)
        if deleted:
            print(
                f" [{position:3d}/{len(groups)}] "
                f"{src[:50]:50s} {len(drawer_ids):4d} → {len(kept):4d} "
                f"(-{len(deleted)})"
            )
    elapsed = time.time() - t0

    print(f"\n{'─' * 55}")
    print(f" Done in {elapsed:.1f}s")
    print(
        f" Drawers: {total_kept + total_deleted:,} → {total_kept:,} "
        f"(-{total_deleted:,} removed)"
    )
    print(f" Palace after: {col.count():,} drawers")
    if dry_run:
        print("\n [DRY RUN] No changes written. Re-run without --dry-run to apply.")
    print(f"{'=' * 55}\n")
if __name__ == "__main__":
    # Standalone entry point: `python -m mempalace.dedup [options]`.
    cli = argparse.ArgumentParser(description="Deduplicate near-identical drawers")
    cli.add_argument("--palace", default=None, help="Palace directory path")
    cli.add_argument(
        "--threshold",
        type=float,
        default=DEFAULT_THRESHOLD,
        help=f"Cosine distance threshold (default: {DEFAULT_THRESHOLD})",
    )
    cli.add_argument("--dry-run", action="store_true", help="Preview without deleting")
    cli.add_argument("--stats", action="store_true", help="Show stats only")
    cli.add_argument("--wing", default=None, help="Scope dedup to a single wing")
    cli.add_argument("--source", default=None, help="Filter by source file pattern")
    opts = cli.parse_args()

    # Expand "~" only when a path was actually supplied.
    palace = os.path.expanduser(opts.palace) if opts.palace else None

    if not opts.stats:
        dedup_palace(
            palace_path=palace,
            threshold=opts.threshold,
            dry_run=opts.dry_run,
            source_pattern=opts.source,
            wing=opts.wing,
        )
    else:
        show_stats(palace_path=palace)
+10
View File
@@ -436,6 +436,16 @@ def process_file(
print(f" [DRY RUN] {filepath.name} → room:{room} ({len(chunks)} drawers)") print(f" [DRY RUN] {filepath.name} → room:{room} ({len(chunks)} drawers)")
return len(chunks), room return len(chunks), room
# Purge stale drawers for this file before re-inserting the fresh chunks.
# Converts modified-file re-mines from upsert-over-existing-IDs (which hits
# hnswlib's thread-unsafe updatePoint path and can segfault on macOS ARM
# with chromadb 0.6.3) into a clean delete+insert, bypassing the update
# path entirely.
try:
collection.delete(where={"source_file": source_file})
except Exception:
pass
drawers_added = 0 drawers_added = 0
for chunk in chunks: for chunk in chunks:
added = add_drawer( added = add_drawer(
+299
View File
@@ -0,0 +1,299 @@
"""
repair.py — Scan, prune corrupt entries, and rebuild HNSW index
================================================================
When ChromaDB's HNSW index accumulates duplicate entries (from repeated
add() calls with the same ID), link_lists.bin can grow unbounded —
terabytes on large palaces — eventually causing segfaults.
This module provides three operations:
scan — find every corrupt/unfetchable ID in the palace
prune — delete only the corrupt IDs (surgical)
rebuild — extract all drawers, delete the collection, recreate with
correct HNSW settings, and upsert everything back
The rebuild backs up ONLY chroma.sqlite3 (the source of truth), not the
full palace directory — so it works even when link_lists.bin is bloated.
Usage (standalone):
python -m mempalace.repair scan [--wing X]
python -m mempalace.repair prune --confirm
python -m mempalace.repair rebuild
Usage (from CLI):
mempalace repair
mempalace repair-scan [--wing X]
mempalace repair-prune --confirm
"""
import argparse
import os
import shutil
import time
import chromadb
# Name of the ChromaDB collection that scan, prune, and rebuild operate on.
COLLECTION_NAME = "mempalace_drawers"
def _get_palace_path():
    """Return the configured palace directory, or ~/.mempalace/palace.

    Any failure while importing or instantiating the config falls back to
    the default location under the user's home directory.
    """
    try:
        from .config import MempalaceConfig

        configured = MempalaceConfig().palace_path
    except Exception:
        return os.path.join(os.path.expanduser("~"), ".mempalace", "palace")
    return configured
def _paginate_ids(col, where=None):
    """Pull all IDs in a collection using pagination.

    Pages through ``col.get`` 1000 IDs at a time. If an offset-based call
    raises, the same request is retried without an offset and only IDs not
    collected so far are added; the loop stops as soon as that retry yields
    nothing new or raises as well.

    Args:
        col: Collection exposing ``get(where=, include=, limit=, offset=)``.
        where: Optional metadata filter passed through to ``get``.

    Returns:
        List of IDs in the order they were returned.
    """
    ids = []
    seen = set()  # mirrors ``ids`` so the fallback filter is O(1) per ID
    page = 1000
    offset = 0
    while True:
        try:
            r = col.get(where=where, include=[], limit=page, offset=offset)
        except Exception:
            # Offset paging failed: retry without an offset and keep only
            # what we have not collected yet.
            try:
                r = col.get(where=where, include=[], limit=page)
                new_ids = [i for i in r["ids"] if i not in seen]
            except Exception:
                break
            if not new_ids:
                break
            ids.extend(new_ids)
            seen.update(new_ids)
            offset += len(new_ids)
            continue

        batch_ids = r["ids"] or []
        if not batch_ids:
            break
        ids.extend(batch_ids)
        seen.update(batch_ids)
        offset += len(batch_ids)
        if len(batch_ids) < page:
            # A short page means the end of the collection was reached.
            break
    return ids
def scan_palace(palace_path=None, only_wing=None):
    """Scan the palace for corrupt/unfetchable IDs.

    Lists every ID (optionally restricted to one wing), then probes them in
    batches of 100; when a batch fetch raises, each ID in it is probed on its
    own. An ID is "bad" when fetching it raises or returns nothing.

    Writes corrupt_ids.txt to the palace directory for the prune step.

    Args:
        palace_path: Palace directory; resolved from config when omitted.
        only_wing: Optional wing name to restrict the scan to.

    Returns:
        ``(good_set, bad_set)`` — two sets of IDs.
    """
    palace_path = palace_path or _get_palace_path()
    print(f"\n Palace: {palace_path}")
    print(" Loading...")
    client = chromadb.PersistentClient(path=palace_path)
    col = client.get_collection(COLLECTION_NAME)
    where = {"wing": only_wing} if only_wing else None
    total = col.count()
    print(f" Collection: {COLLECTION_NAME}, total: {total:,}")
    if only_wing:
        print(f" Scanning wing: {only_wing}")

    print("\n Step 1: listing all IDs...")
    t0 = time.time()
    all_ids = _paginate_ids(col, where=where)
    print(f" Found {len(all_ids):,} IDs in {time.time() - t0:.1f}s\n")
    if not all_ids:
        print(" Nothing to scan.")
        return set(), set()

    print(" Step 2: probing each ID (batches of 100)...")
    t0 = time.time()
    good_set = set()
    bad_set = set()
    batch = 100
    n_ids = len(all_ids)
    for i in range(0, n_ids, batch):
        chunk = all_ids[i : i + batch]
        try:
            r = col.get(ids=chunk, include=["documents"])
            good_set.update(r["ids"])
            # Anything requested but not returned is unfetchable.
            bad_set.update(mid for mid in chunk if mid not in good_set)
        except Exception:
            # The batch as a whole failed; probe one ID at a time to find
            # which ones are actually broken.
            for sid in chunk:
                try:
                    r = col.get(ids=[sid], include=["documents"])
                    if r["ids"]:
                        good_set.add(sid)
                    else:
                        bad_set.add(sid)
                except Exception:
                    bad_set.add(sid)

        if (i // batch) % 50 == 0:
            # Clamp to the real total so the final (short) batch does not
            # report more IDs than exist or a negative ETA.
            done = min(i + batch, n_ids)
            elapsed = time.time() - t0
            rate = done / max(elapsed, 0.01)
            eta = (n_ids - done) / max(rate, 0.01)
            print(
                f" {done:>6}/{n_ids:>6} "
                f"good={len(good_set):>6} bad={len(bad_set):>6} "
                f"eta={eta:.0f}s"
            )

    print(f"\n Scan complete in {time.time() - t0:.1f}s")
    print(f" GOOD: {len(good_set):,}")
    print(f" BAD: {len(bad_set):,} ({len(bad_set) / max(n_ids, 1) * 100:.1f}%)")

    bad_file = os.path.join(palace_path, "corrupt_ids.txt")
    with open(bad_file, "w") as f:
        f.writelines(bid + "\n" for bid in sorted(bad_set))
    print(f"\n Bad IDs written to: {bad_file}")
    return good_set, bad_set
def prune_corrupt(palace_path=None, confirm=False):
    """Delete corrupt IDs listed in corrupt_ids.txt.

    Reads the ID list written by the scan step from the palace directory.
    Without ``confirm`` this only reports how many IDs are queued. With
    ``confirm`` the IDs are deleted in batches of 100, falling back to
    one-by-one deletion for any batch that raises.

    Args:
        palace_path: Palace directory; resolved from config when omitted.
        confirm: Must be True for any deletion to happen.

    Returns:
        None. Progress and totals are printed.
    """
    palace_path = palace_path or _get_palace_path()
    bad_file = os.path.join(palace_path, "corrupt_ids.txt")
    if not os.path.exists(bad_file):
        print(" No corrupt_ids.txt found — run scan first.")
        return

    with open(bad_file) as f:
        bad_ids = [line.strip() for line in f if line.strip()]
    print(f" {len(bad_ids):,} corrupt IDs queued for deletion")

    if not confirm:
        print("\n DRY RUN — no deletions performed.")
        print(" Re-run with --confirm to actually delete.")
        return

    client = chromadb.PersistentClient(path=palace_path)
    col = client.get_collection(COLLECTION_NAME)
    before = col.count()
    print(f" Collection size before: {before:,}")

    batch = 100
    deleted = 0
    failed = 0
    for i in range(0, len(bad_ids), batch):
        chunk = bad_ids[i : i + batch]
        try:
            col.delete(ids=chunk)
            deleted += len(chunk)
        except Exception:
            # One bad ID can fail the whole batch; retry individually so the
            # rest still get removed.
            for sid in chunk:
                try:
                    col.delete(ids=[sid])
                    deleted += 1
                except Exception:
                    failed += 1
        if (i // batch) % 20 == 0:
            print(f" deleted {deleted}/{len(bad_ids)} (failed: {failed})")

    after = col.count()
    print(f"\n Deleted: {deleted:,}")
    print(f" Failed: {failed:,}")
    # The arrow between the two sizes had been lost, fusing them into one number.
    print(f" Collection size: {before:,} → {after:,}")
def rebuild_index(palace_path=None):
    """Rebuild the HNSW index from scratch.

    Steps:
      1. Read every drawer (ids, documents, metadatas) via ``get()``.
      2. Copy chroma.sqlite3 to chroma.sqlite3.backup (the HNSW files are
         not copied).
      3. Drop the collection and recreate it with hnsw:space=cosine.
      4. Upsert the extracted drawers back in batches.

    Returns None; returns early when the palace directory is missing, the
    collection cannot be read, or it is empty.
    """
    palace_path = palace_path or _get_palace_path()
    if not os.path.isdir(palace_path):
        print(f"\n No palace found at {palace_path}")
        return

    banner = "=" * 55
    print(f"\n{banner}")
    print(" MemPalace Repair — Index Rebuild")
    print(f"{banner}\n")
    print(f" Palace: {palace_path}")

    client = chromadb.PersistentClient(path=palace_path)
    try:
        col = client.get_collection(COLLECTION_NAME)
        total = col.count()
    except Exception as e:
        print(f" Error reading palace: {e}")
        print(" Palace may need to be re-mined from source files.")
        return

    print(f" Drawers found: {total}")
    if total == 0:
        print(" Nothing to repair.")
        return

    # Extract all drawers in batches
    print("\n Extracting drawers...")
    batch_size = 5000
    ids, docs, metas = [], [], []
    cursor = 0
    while cursor < total:
        page = col.get(limit=batch_size, offset=cursor, include=["documents", "metadatas"])
        page_ids = page["ids"]
        if not page_ids:
            break
        ids.extend(page_ids)
        docs.extend(page["documents"])
        metas.extend(page["metadatas"])
        cursor += len(page_ids)
    print(f" Extracted {len(ids)} drawers")

    # Back up ONLY the SQLite database, not the bloated HNSW files
    sqlite_path = os.path.join(palace_path, "chroma.sqlite3")
    if os.path.exists(sqlite_path):
        backup_path = sqlite_path + ".backup"
        size_mb = os.path.getsize(sqlite_path) / 1e6
        print(f" Backing up chroma.sqlite3 ({size_mb:.0f} MB)...")
        shutil.copy2(sqlite_path, backup_path)
        print(f" Backup: {backup_path}")

    # Rebuild with correct HNSW settings
    print(" Rebuilding collection with hnsw:space=cosine...")
    client.delete_collection(COLLECTION_NAME)
    new_col = client.create_collection(COLLECTION_NAME, metadata={"hnsw:space": "cosine"})

    filed = 0
    for start in range(0, len(ids), batch_size):
        stop = start + batch_size
        id_slice = ids[start:stop]
        new_col.upsert(
            documents=docs[start:stop],
            ids=id_slice,
            metadatas=metas[start:stop],
        )
        filed += len(id_slice)
        print(f" Re-filed {filed}/{len(ids)} drawers...")

    print(f"\n Repair complete. {filed} drawers rebuilt.")
    print(" HNSW index is now clean with cosine distance metric.")
    print(f"\n{banner}\n")
if __name__ == "__main__":
    # Standalone entry point: `python -m mempalace.repair <command> [options]`.
    cli = argparse.ArgumentParser(description="MemPalace repair tools")
    cli.add_argument("command", choices=["scan", "prune", "rebuild"])
    cli.add_argument("--palace", default=None, help="Palace directory path")
    cli.add_argument("--wing", default=None, help="Scan only this wing")
    cli.add_argument("--confirm", action="store_true", help="Actually delete corrupt IDs")
    opts = cli.parse_args()

    # Expand "~" only when a path was actually supplied.
    palace = os.path.expanduser(opts.palace) if opts.palace else None

    # argparse's `choices` guarantees the command is one of these keys.
    handlers = {
        "scan": lambda: scan_palace(palace_path=palace, only_wing=opts.wing),
        "prune": lambda: prune_corrupt(palace_path=palace, confirm=opts.confirm),
        "rebuild": lambda: rebuild_index(palace_path=palace),
    }
    handlers[opts.command]()
+272
View File
@@ -0,0 +1,272 @@
"""Tests for mempalace.dedup — near-duplicate drawer detection and removal."""
from unittest.mock import MagicMock, patch
from mempalace import dedup
# ── get_source_groups ─────────────────────────────────────────────────
def test_get_source_groups_basic():
    """Five drawers sharing one source_file form a single group of five."""
    drawer_ids = ["d1", "d2", "d3", "d4", "d5"]
    page = {
        "ids": drawer_ids,
        "metadatas": [{"source_file": "a.txt"} for _ in drawer_ids],
    }
    col = MagicMock()
    col.count.return_value = 5
    col.get.side_effect = [page, {"ids": []}]

    groups = dedup.get_source_groups(col, min_count=5)

    assert "a.txt" in groups
    assert len(groups["a.txt"]) == 5
def test_get_source_groups_below_min():
    """A source with fewer drawers than min_count is left out."""
    page = {
        "ids": ["d1", "d2"],
        "metadatas": [{"source_file": "a.txt"}, {"source_file": "a.txt"}],
    }
    col = MagicMock()
    col.count.return_value = 2
    col.get.side_effect = [page, {"ids": []}]

    groups = dedup.get_source_groups(col, min_count=5)

    assert len(groups) == 0
def test_get_source_groups_source_filter():
    """source_pattern keeps matching sources and drops the rest."""
    sources = ["project_a.txt"] * 5 + ["other.txt"]
    page = {
        "ids": ["d1", "d2", "d3", "d4", "d5", "d6"],
        "metadatas": [{"source_file": name} for name in sources],
    }
    col = MagicMock()
    col.count.return_value = 6
    col.get.side_effect = [page, {"ids": []}]

    groups = dedup.get_source_groups(col, min_count=5, source_pattern="project_a")

    assert "project_a.txt" in groups
    assert "other.txt" not in groups
def test_get_source_groups_wing_filter():
    """Passing wing adds a matching `where` filter to the collection query."""
    drawer_ids = ["d1", "d2", "d3", "d4", "d5"]
    page = {
        "ids": drawer_ids,
        "metadatas": [{"source_file": "a.txt"} for _ in drawer_ids],
    }
    col = MagicMock()
    col.count.return_value = 5
    col.get.side_effect = [page, {"ids": []}]

    dedup.get_source_groups(col, min_count=5, wing="my_wing")

    # The very first get() must carry the wing filter.
    first_get = col.get.call_args_list[0]
    assert first_get.kwargs.get("where") == {"wing": "my_wing"}
def test_get_source_groups_missing_source_file():
    """Drawers without a source_file are grouped under "unknown"."""
    drawer_ids = ["d1", "d2", "d3", "d4", "d5"]
    page = {"ids": drawer_ids, "metadatas": [{} for _ in drawer_ids]}
    col = MagicMock()
    col.count.return_value = 5
    col.get.side_effect = [page, {"ids": []}]

    groups = dedup.get_source_groups(col, min_count=5)

    assert "unknown" in groups
# ── dedup_source_group ────────────────────────────────────────────────
def test_dedup_source_group_all_unique():
    """Drawers that are far apart are both kept."""
    fetched = {
        "ids": ["d1", "d2"],
        "documents": ["long document one content here", "different document two here"],
        "metadatas": [{"wing": "a"}, {"wing": "a"}],
    }
    far_neighbor = {"ids": [["d1"]], "distances": [[0.8]]}  # far apart = unique
    col = MagicMock()
    col.get.return_value = fetched
    col.query.return_value = far_neighbor

    kept, deleted = dedup.dedup_source_group(col, ["d1", "d2"], threshold=0.15, dry_run=True)

    assert len(kept) == 2
    assert len(deleted) == 0
def test_dedup_source_group_with_duplicate():
    """A drawer within the threshold of a kept one is marked for deletion."""
    text = "long document content that is fairly long"
    fetched = {
        "ids": ["d1", "d2"],
        "documents": [text, text],
        "metadatas": [{"wing": "a"}, {"wing": "a"}],
    }
    close_neighbor = {"ids": [["d1"]], "distances": [[0.05]]}  # very close = duplicate
    col = MagicMock()
    col.get.return_value = fetched
    col.query.return_value = close_neighbor

    kept, deleted = dedup.dedup_source_group(col, ["d1", "d2"], threshold=0.15, dry_run=True)

    assert len(kept) == 1
    assert len(deleted) == 1
def test_dedup_source_group_short_docs_deleted():
    """Documents under the minimum length are scheduled for deletion."""
    fetched = {
        "ids": ["d1", "d2"],
        "documents": ["long enough document to keep in the palace", "tiny"],
        "metadatas": [{"wing": "a"}, {"wing": "a"}],
    }
    col = MagicMock()
    col.get.return_value = fetched

    kept, deleted = dedup.dedup_source_group(col, ["d1", "d2"], threshold=0.15, dry_run=True)

    assert "d2" in deleted  # too short
def test_dedup_source_group_empty_doc_deleted():
    """A drawer whose document is None is scheduled for deletion."""
    fetched = {
        "ids": ["d1", "d2"],
        "documents": ["real document content here that is long enough", None],
        "metadatas": [{"wing": "a"}, {"wing": "a"}],
    }
    col = MagicMock()
    col.get.return_value = fetched

    kept, deleted = dedup.dedup_source_group(col, ["d1", "d2"], threshold=0.15, dry_run=True)

    assert "d2" in deleted
def test_dedup_source_group_live_deletes():
    """With dry_run=False the duplicates are actually deleted."""
    text = "long document content here enough"
    col = MagicMock()
    col.get.return_value = {
        "ids": ["d1", "d2"],
        "documents": [text, text],
        "metadatas": [{"wing": "a"}, {"wing": "a"}],
    }
    col.query.return_value = {"ids": [["d1"]], "distances": [[0.05]]}

    kept, deleted = dedup.dedup_source_group(col, ["d1", "d2"], threshold=0.15, dry_run=False)

    col.delete.assert_called_once()
def test_dedup_source_group_query_failure_keeps():
    """If the similarity query raises, the drawer is kept rather than deleted."""
    fetched = {
        "ids": ["d1", "d2"],
        "documents": [
            "long document one content here enough",
            "long document two content here enough",
        ],
        "metadatas": [{"wing": "a"}, {"wing": "a"}],
    }
    col = MagicMock()
    col.get.return_value = fetched
    col.query.side_effect = Exception("query failed")

    kept, deleted = dedup.dedup_source_group(col, ["d1", "d2"], threshold=0.15, dry_run=True)

    assert len(kept) == 2  # both kept on error
# ── show_stats ────────────────────────────────────────────────────────
@patch("mempalace.dedup.chromadb")
def test_show_stats(mock_chromadb, tmp_path):
    """show_stats runs end to end against a mocked collection."""
    drawer_ids = ["d1", "d2", "d3", "d4", "d5"]
    page = {
        "ids": drawer_ids,
        "metadatas": [{"source_file": "a.txt"} for _ in drawer_ids],
    }
    mock_col = MagicMock()
    mock_col.count.return_value = 5
    mock_col.get.side_effect = [page, {"ids": []}]
    mock_client = MagicMock()
    mock_client.get_collection.return_value = mock_col
    mock_chromadb.PersistentClient.return_value = mock_client

    dedup.show_stats(palace_path=str(tmp_path))  # should not raise
# ── dedup_palace ──────────────────────────────────────────────────────
@patch("mempalace.dedup.dedup_source_group")
@patch("mempalace.dedup.get_source_groups")
@patch("mempalace.dedup.chromadb")
def test_dedup_palace_dry_run(mock_chromadb, mock_groups, mock_dedup_group, tmp_path):
    """One source group results in exactly one dedup_source_group call."""
    mock_col = MagicMock()
    mock_col.count.return_value = 10
    mock_client = MagicMock()
    mock_client.get_collection.return_value = mock_col
    mock_chromadb.PersistentClient.return_value = mock_client

    all_ids = ["d1", "d2", "d3", "d4", "d5"]
    mock_groups.return_value = {"a.txt": all_ids}
    mock_dedup_group.return_value = (all_ids[:3], all_ids[3:])

    dedup.dedup_palace(palace_path=str(tmp_path), dry_run=True)

    mock_dedup_group.assert_called_once()
@patch("mempalace.dedup.dedup_source_group")
@patch("mempalace.dedup.get_source_groups")
@patch("mempalace.dedup.chromadb")
def test_dedup_palace_with_wing(mock_chromadb, mock_groups, mock_dedup_group, tmp_path):
    """The wing argument is forwarded to get_source_groups."""
    mock_col = MagicMock()
    mock_col.count.return_value = 10
    mock_client = MagicMock()
    mock_client.get_collection.return_value = mock_col
    mock_chromadb.PersistentClient.return_value = mock_client
    mock_groups.return_value = {}

    dedup.dedup_palace(palace_path=str(tmp_path), wing="test_wing", dry_run=True)

    mock_groups.assert_called_once_with(mock_col, 5, None, wing="test_wing")
@patch("mempalace.dedup.dedup_source_group")
@patch("mempalace.dedup.get_source_groups")
@patch("mempalace.dedup.chromadb")
def test_dedup_palace_no_groups(mock_chromadb, mock_groups, mock_dedup_group, tmp_path):
    """With no source groups, dedup_source_group is never called."""
    mock_col = MagicMock()
    mock_col.count.return_value = 3
    mock_client = MagicMock()
    mock_client.get_collection.return_value = mock_col
    mock_chromadb.PersistentClient.return_value = mock_client
    mock_groups.return_value = {}

    dedup.dedup_palace(palace_path=str(tmp_path), dry_run=True)

    mock_dedup_group.assert_not_called()
+266
View File
@@ -0,0 +1,266 @@
"""Tests for mempalace.repair — scan, prune, and rebuild HNSW index."""
import os
from unittest.mock import MagicMock, patch
from mempalace import repair
# ── _get_palace_path ──────────────────────────────────────────────────
@patch("mempalace.config.MempalaceConfig")
def test_get_palace_path_from_config(mock_config_cls):
    """The configured palace_path is returned when the config loads.

    _get_palace_path imports MempalaceConfig from the config module at call
    time, so the class has to be patched there; patching a name on
    mempalace.repair has no effect on the function.
    """
    mock_config_cls.return_value.palace_path = "/configured/palace"

    assert repair._get_palace_path() == "/configured/palace"
def test_get_palace_path_fallback():
    """When the config cannot be built, the default ~/.mempalace/palace is used.

    The real function is exercised; only the config class is made to fail.
    """
    expected = os.path.join(os.path.expanduser("~"), ".mempalace", "palace")
    with patch(
        "mempalace.config.MempalaceConfig",
        side_effect=RuntimeError("config unavailable"),
    ):
        result = repair._get_palace_path()

    assert result == expected
# ── _paginate_ids ─────────────────────────────────────────────────────
def test_paginate_ids_single_batch():
    """A single short page is returned as-is."""
    expected = ["id1", "id2", "id3"]
    col = MagicMock()
    col.get.return_value = {"ids": ["id1", "id2", "id3"]}

    assert repair._paginate_ids(col) == expected
def test_paginate_ids_empty():
    """An empty collection yields an empty list."""
    col = MagicMock()
    col.get.return_value = {"ids": []}

    result = repair._paginate_ids(col)

    assert result == []
def test_paginate_ids_with_where():
    """The where filter is passed straight through to col.get()."""
    wing_filter = {"wing": "test"}
    col = MagicMock()
    col.get.return_value = {"ids": ["id1"]}

    repair._paginate_ids(col, where=wing_filter)

    col.get.assert_called_with(where={"wing": "test"}, include=[], limit=1000, offset=0)
def test_paginate_ids_offset_exception_fallback():
    """When offset paging raises, the offset-free retry still collects IDs."""
    first_page = {"ids": ["id1", "id2"]}
    repeat_page = {"ids": ["id1", "id2"]}  # same ids = no new = break
    col = MagicMock()
    # Offset call raises, fallback returns ids; second round repeats and
    # the fallback brings nothing new.
    col.get.side_effect = [
        Exception("offset bug"),
        first_page,
        Exception("offset bug"),
        repeat_page,
    ]

    ids = repair._paginate_ids(col)

    assert "id1" in ids
# ── scan_palace ───────────────────────────────────────────────────────
@patch("mempalace.repair.chromadb")
def test_scan_palace_no_ids(mock_chromadb, tmp_path):
    """An empty collection produces empty good and bad sets."""
    mock_col = MagicMock()
    mock_col.count.return_value = 0
    mock_col.get.return_value = {"ids": []}
    mock_client = MagicMock()
    mock_client.get_collection.return_value = mock_col
    mock_chromadb.PersistentClient.return_value = mock_client

    result = repair.scan_palace(palace_path=str(tmp_path))

    good, bad = result
    assert good == set()
    assert bad == set()
@patch("mempalace.repair.chromadb")
def test_scan_palace_all_good(mock_chromadb, tmp_path):
    """Every listed ID that the probe returns is reported as good."""
    listing = {"ids": ["id1", "id2"]}  # _paginate_ids call
    probe = {"ids": ["id1", "id2"]}  # probe batch — both returned
    mock_col = MagicMock()
    mock_col.count.return_value = 2
    mock_col.get.side_effect = [listing, probe]
    mock_client = MagicMock()
    mock_client.get_collection.return_value = mock_col
    mock_chromadb.PersistentClient.return_value = mock_client

    good, bad = repair.scan_palace(palace_path=str(tmp_path))

    assert "id1" in good
    assert "id2" in good
    assert len(bad) == 0
@patch("mempalace.repair.chromadb")
def test_scan_palace_with_bad_ids(mock_chromadb, tmp_path):
    """A failing batch falls back to per-ID probes that separate good from bad."""

    def fake_get(**kwargs):
        requested = kwargs.get("ids", None)
        if requested is None:
            # paginate call
            return {"ids": ["good1", "bad1"]}
        if len(requested) == 1:
            if "bad1" in requested:
                raise Exception("corrupt")
            if "good1" in requested:
                return {"ids": ["good1"]}
        # batch probe — raise to force per-id
        raise Exception("batch fail")

    mock_col = MagicMock()
    mock_col.count.return_value = 2
    mock_col.get.side_effect = fake_get
    mock_client = MagicMock()
    mock_client.get_collection.return_value = mock_col
    mock_chromadb.PersistentClient.return_value = mock_client

    good, bad = repair.scan_palace(palace_path=str(tmp_path))

    assert "good1" in good
    assert "bad1" in bad
@patch("mempalace.repair.chromadb")
def test_scan_palace_with_wing_filter(mock_chromadb, tmp_path):
    """only_wing becomes a `where` filter on the ID listing call."""
    listing = {"ids": ["id1"]}  # paginate
    probe = {"ids": ["id1"]}  # probe
    mock_col = MagicMock()
    mock_col.count.return_value = 1
    mock_col.get.side_effect = [listing, probe]
    mock_client = MagicMock()
    mock_client.get_collection.return_value = mock_col
    mock_chromadb.PersistentClient.return_value = mock_client

    repair.scan_palace(palace_path=str(tmp_path), only_wing="test_wing")

    # The listing call is the first get(); it must carry the wing filter.
    listing_call = mock_col.get.call_args_list[0]
    assert listing_call.kwargs.get("where") == {"wing": "test_wing"}
# ── prune_corrupt ─────────────────────────────────────────────────────
@patch("mempalace.repair.chromadb")
def test_prune_corrupt_no_file(mock_chromadb, tmp_path):
    """Without corrupt_ids.txt, prune returns early and never opens the palace."""
    result = repair.prune_corrupt(palace_path=str(tmp_path))

    assert result is None
    # The early return happens before any client is created.
    mock_chromadb.PersistentClient.assert_not_called()
@patch("mempalace.repair.chromadb")
def test_prune_corrupt_dry_run(mock_chromadb, tmp_path):
    """Without confirm=True the palace is never opened."""
    (tmp_path / "corrupt_ids.txt").write_text("bad1\nbad2\n")

    repair.prune_corrupt(palace_path=str(tmp_path), confirm=False)

    # No chromadb calls in dry run
    mock_chromadb.PersistentClient.assert_not_called()
@patch("mempalace.repair.chromadb")
def test_prune_corrupt_confirmed(mock_chromadb, tmp_path):
    """With confirm=True the queued IDs go out in a single batch delete."""
    (tmp_path / "corrupt_ids.txt").write_text("bad1\nbad2\n")
    mock_col = MagicMock()
    mock_col.count.side_effect = [10, 8]
    mock_client = MagicMock()
    mock_client.get_collection.return_value = mock_col
    mock_chromadb.PersistentClient.return_value = mock_client

    repair.prune_corrupt(palace_path=str(tmp_path), confirm=True)

    mock_col.delete.assert_called_once()
@patch("mempalace.repair.chromadb")
def test_prune_corrupt_delete_failure_fallback(mock_chromadb, tmp_path):
    """A failed batch delete is retried one ID at a time."""
    (tmp_path / "corrupt_ids.txt").write_text("bad1\nbad2\n")
    mock_col = MagicMock()
    mock_col.count.side_effect = [10, 8]
    # Batch delete fails, per-id succeeds
    mock_col.delete.side_effect = [Exception("batch fail"), None, None]
    mock_client = MagicMock()
    mock_client.get_collection.return_value = mock_col
    mock_chromadb.PersistentClient.return_value = mock_client

    repair.prune_corrupt(palace_path=str(tmp_path), confirm=True)

    expected_calls = 3  # 1 batch + 2 individual
    assert mock_col.delete.call_count == expected_calls
# ── rebuild_index ─────────────────────────────────────────────────────
@patch("mempalace.repair.chromadb")
def test_rebuild_index_no_palace(mock_chromadb, tmp_path):
    """A missing palace directory is reported without opening a client."""
    missing_dir = tmp_path / "nope"

    repair.rebuild_index(palace_path=str(missing_dir))

    mock_chromadb.PersistentClient.assert_not_called()
@patch("mempalace.repair.shutil")
@patch("mempalace.repair.chromadb")
def test_rebuild_index_empty_palace(mock_chromadb, mock_shutil, tmp_path):
    """An empty collection is left alone — nothing is dropped."""
    mock_col = MagicMock()
    mock_col.count.return_value = 0
    mock_client = MagicMock()
    mock_client.get_collection.return_value = mock_col
    mock_chromadb.PersistentClient.return_value = mock_client

    repair.rebuild_index(palace_path=str(tmp_path))

    mock_client.delete_collection.assert_not_called()
@patch("mempalace.repair.shutil")
@patch("mempalace.repair.chromadb")
def test_rebuild_index_success(mock_chromadb, mock_shutil, tmp_path):
    """Happy path: back up sqlite, recreate the collection, upsert everything."""
    # Create a fake sqlite file
    (tmp_path / "chroma.sqlite3").write_text("fake")

    extracted = {
        "ids": ["id1", "id2"],
        "documents": ["doc1", "doc2"],
        "metadatas": [{"wing": "a"}, {"wing": "b"}],
    }
    mock_col = MagicMock()
    mock_col.count.return_value = 2
    mock_col.get.return_value = extracted
    mock_new_col = MagicMock()
    mock_client = MagicMock()
    mock_client.get_collection.return_value = mock_col
    mock_client.create_collection.return_value = mock_new_col
    mock_chromadb.PersistentClient.return_value = mock_client

    repair.rebuild_index(palace_path=str(tmp_path))

    # Verify: backed up sqlite only (not copytree)
    mock_shutil.copy2.assert_called_once()
    assert "chroma.sqlite3" in str(mock_shutil.copy2.call_args)
    # Verify: deleted and recreated with cosine
    mock_client.delete_collection.assert_called_once_with("mempalace_drawers")
    mock_client.create_collection.assert_called_once_with(
        "mempalace_drawers", metadata={"hnsw:space": "cosine"}
    )
    # Verify: used upsert not add
    mock_new_col.upsert.assert_called_once()
    mock_new_col.add.assert_not_called()
@patch("mempalace.repair.shutil")
@patch("mempalace.repair.chromadb")
def test_rebuild_index_error_reading(mock_chromadb, mock_shutil, tmp_path):
    """If the collection cannot be read, nothing is deleted."""
    mock_client = MagicMock()
    mock_client.get_collection.side_effect = Exception("corrupt")
    mock_chromadb.PersistentClient.return_value = mock_client

    repair.rebuild_index(palace_path=str(tmp_path))

    mock_client.delete_collection.assert_not_called()