diff --git a/mempalace/convo_miner.py b/mempalace/convo_miner.py index 7879f96..3bb4a89 100644 --- a/mempalace/convo_miner.py +++ b/mempalace/convo_miner.py @@ -334,7 +334,7 @@ def mine_convos( room_counts[chunk_room] += 1 drawer_id = f"drawer_{wing}_{chunk_room}_{hashlib.sha256((source_file + str(chunk['chunk_index'])).encode()).hexdigest()[:24]}" try: - collection.add( + collection.upsert( documents=[chunk["content"]], ids=[drawer_id], metadatas=[ diff --git a/mempalace/dedup.py b/mempalace/dedup.py new file mode 100644 index 0000000..c2f9f6b --- /dev/null +++ b/mempalace/dedup.py @@ -0,0 +1,239 @@ +""" +dedup.py — Detect and remove near-duplicate drawers +==================================================== + +When the same files are mined multiple times, near-identical drawers +accumulate. This module finds drawers from the same source_file that +are too similar (cosine distance < threshold), keeps the longest/richest +version, and deletes the rest. + +No API calls — uses ChromaDB's built-in embedding similarity. + +Usage (standalone): + python -m mempalace.dedup # dedup all + python -m mempalace.dedup --dry-run # preview only + python -m mempalace.dedup --threshold 0.10 # stricter (near-identical only) + python -m mempalace.dedup --threshold 0.35 # looser (catches paraphrased content) + python -m mempalace.dedup --wing my_project # scope to one wing + python -m mempalace.dedup --stats # stats only + python -m mempalace.dedup --source "my_project" # filter by source + +Usage (from CLI): + mempalace dedup [--dry-run] [--threshold 0.15] [--stats] +""" + +import argparse +import os +import time +from collections import defaultdict + +import chromadb + + +COLLECTION_NAME = "mempalace_drawers" +# Cosine DISTANCE threshold (not similarity). Lower = stricter. +# 0.15 = ~85% cosine similarity — catches near-identical chunks. +# For looser dedup of paraphrased content, try 0.3–0.4. +DEFAULT_THRESHOLD = 0.15 +MIN_DRAWERS_TO_CHECK = 5 + + +def _get_palace_path(): + """Resolve palace path from config.""" + try: + from .config import MempalaceConfig + + return MempalaceConfig().palace_path + except Exception: + return os.path.join(os.path.expanduser("~"), ".mempalace", "palace") + + +def get_source_groups(col, min_count=MIN_DRAWERS_TO_CHECK, source_pattern=None, wing=None): + """Group drawers by source_file, return groups with min_count+ entries. + + If wing is specified, only considers drawers in that wing. This catches + cross-wing duplicates when the same source was mined into multiple wings. + """ + total = col.count() + groups = defaultdict(list) + + offset = 0 + batch_size = 1000 + while offset < total: + kwargs = {"limit": batch_size, "offset": offset, "include": ["metadatas"]} + if wing: + kwargs["where"] = {"wing": wing} + batch = col.get(**kwargs) + if not batch["ids"]: + break + for did, meta in zip(batch["ids"], batch["metadatas"]): + src = meta.get("source_file", "unknown") + if source_pattern and source_pattern.lower() not in src.lower(): + continue + groups[src].append(did) + offset += len(batch["ids"]) + + return {src: ids for src, ids in groups.items() if len(ids) >= min_count} + + +def dedup_source_group(col, drawer_ids, threshold=DEFAULT_THRESHOLD, dry_run=True): + """Dedup drawers within one source_file group. + + Greedy: sort by doc length (longest first), keep if not too similar + to any already-kept drawer. Returns (kept_ids, deleted_ids). + """ + data = col.get(ids=drawer_ids, include=["documents", "metadatas"]) + items = list(zip(data["ids"], data["documents"], data["metadatas"])) + items.sort(key=lambda x: len(x[1] or ""), reverse=True) + + kept = [] + to_delete = [] + + for did, doc, meta in items: + if not doc or len(doc) < 20: + to_delete.append(did) + continue + + if not kept: + kept.append((did, doc)) + continue + + try: + results = col.query( + query_texts=[doc], + n_results=min(len(kept), 5), + include=["distances"], + ) + dists = results["distances"][0] if results["distances"] else [] + kept_ids_set = {k[0] for k in kept} + + is_dup = False + for rid, dist in zip(results["ids"][0], dists): + if rid in kept_ids_set and dist < threshold: + is_dup = True + break + + if is_dup: + to_delete.append(did) + else: + kept.append((did, doc)) + except Exception: + kept.append((did, doc)) + + if to_delete and not dry_run: + for i in range(0, len(to_delete), 500): + col.delete(ids=to_delete[i : i + 500]) + + return [k[0] for k in kept], to_delete + + +def show_stats(palace_path=None): + """Show duplication statistics without making changes.""" + palace_path = palace_path or _get_palace_path() + client = chromadb.PersistentClient(path=palace_path) + col = client.get_collection(COLLECTION_NAME) + + groups = get_source_groups(col) + + total_drawers = sum(len(ids) for ids in groups.values()) + print(f"\n Sources with {MIN_DRAWERS_TO_CHECK}+ drawers: {len(groups)}") + print(f" Total drawers in those sources: {total_drawers:,}") + + print("\n Top 15 by drawer count:") + sorted_groups = sorted(groups.items(), key=lambda x: len(x[1]), reverse=True) + for src, ids in sorted_groups[:15]: + print(f" {len(ids):4d} {src[:65]}") + + estimated_dups = sum(int(len(ids) * 0.4) for ids in groups.values() if len(ids) > 20) + print(f"\n Estimated duplicates (groups > 20): ~{estimated_dups:,}") + + +def dedup_palace( + palace_path=None, + threshold=DEFAULT_THRESHOLD, + dry_run=True, + source_pattern=None, + min_count=MIN_DRAWERS_TO_CHECK, + wing=None, +): + """Main entry point: deduplicate near-identical drawers across the palace.""" + palace_path = palace_path or _get_palace_path() + + print(f"\n{'=' * 55}") + print(" MemPalace Deduplicator") + print(f"{'=' * 55}") + + client = chromadb.PersistentClient(path=palace_path) + col = client.get_collection(COLLECTION_NAME) + + print(f" Palace: {palace_path}") + print(f" Drawers: {col.count():,}") + print(f" Threshold: {threshold}") + print(f" Mode: {'DRY RUN' if dry_run else 'LIVE'}") + print(f"{'─' * 55}") + + if wing: + print(f" Wing: {wing}") + groups = get_source_groups(col, min_count, source_pattern, wing=wing) + print(f"\n Sources to check: {len(groups)}") + + t0 = time.time() + total_kept = 0 + total_deleted = 0 + + sorted_groups = sorted(groups.items(), key=lambda x: len(x[1]), reverse=True) + + for i, (src, drawer_ids) in enumerate(sorted_groups): + kept, deleted = dedup_source_group(col, drawer_ids, threshold, dry_run) + total_kept += len(kept) + total_deleted += len(deleted) + + if deleted: + print( + f" [{i + 1:3d}/{len(groups)}] " + f"{src[:50]:50s} {len(drawer_ids):4d} → {len(kept):4d} " + f"(-{len(deleted)})" + ) + + elapsed = time.time() - t0 + + print(f"\n{'─' * 55}") + print(f" Done in {elapsed:.1f}s") + print( + f" Drawers: {total_kept + total_deleted:,} → {total_kept:,} (-{total_deleted:,} removed)" + ) + print(f" Palace after: {col.count():,} drawers") + + if dry_run: + print("\n [DRY RUN] No changes written. Re-run without --dry-run to apply.") + + print(f"{'=' * 55}\n") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Deduplicate near-identical drawers") + parser.add_argument("--palace", default=None, help="Palace directory path") + parser.add_argument( + "--threshold", + type=float, + default=DEFAULT_THRESHOLD, + help=f"Cosine distance threshold (default: {DEFAULT_THRESHOLD})", + ) + parser.add_argument("--dry-run", action="store_true", help="Preview without deleting") + parser.add_argument("--stats", action="store_true", help="Show stats only") + parser.add_argument("--wing", default=None, help="Scope dedup to a single wing") + parser.add_argument("--source", default=None, help="Filter by source file pattern") + args = parser.parse_args() + + path = os.path.expanduser(args.palace) if args.palace else None + + if args.stats: + show_stats(palace_path=path) + else: + dedup_palace( + palace_path=path, + threshold=args.threshold, + dry_run=args.dry_run, + source_pattern=args.source, + wing=args.wing, + ) diff --git a/mempalace/hooks_cli.py b/mempalace/hooks_cli.py index 3f3fc09..b6d2290 100644 --- a/mempalace/hooks_cli.py +++ b/mempalace/hooks_cli.py @@ -63,6 +63,14 @@ def _count_human_messages(transcript_path: str) -> int: if "" in text: continue count += 1 + # Also handle Codex CLI transcript format + # {"type": "event_msg", "payload": {"type": "user_message", "message": "..."}} + elif entry.get("type") == "event_msg": + payload = entry.get("payload", {}) + if isinstance(payload, dict) and payload.get("type") == "user_message": + msg_text = payload.get("message", "") + if isinstance(msg_text, str) and "" not in msg_text: + count += 1 except (json.JSONDecodeError, AttributeError): pass except OSError: diff --git a/mempalace/miner.py b/mempalace/miner.py index b52e6f7..f342a2d 100644 --- a/mempalace/miner.py +++ b/mempalace/miner.py @@ -436,6 +436,16 @@ def process_file( print(f" [DRY RUN] {filepath.name} → room:{room} ({len(chunks)} drawers)") return len(chunks), room + # Purge stale drawers for this file before re-inserting the fresh chunks. + # Converts modified-file re-mines from upsert-over-existing-IDs (which hits + # hnswlib's thread-unsafe updatePoint path and can segfault on macOS ARM + # with chromadb 0.6.3) into a clean delete+insert, bypassing the update + # path entirely. + try: + collection.delete(where={"source_file": source_file}) + except Exception: + pass + drawers_added = 0 for chunk in chunks: added = add_drawer( diff --git a/mempalace/repair.py b/mempalace/repair.py new file mode 100644 index 0000000..d51be60 --- /dev/null +++ b/mempalace/repair.py @@ -0,0 +1,299 @@ +""" +repair.py — Scan, prune corrupt entries, and rebuild HNSW index +================================================================ + +When ChromaDB's HNSW index accumulates duplicate entries (from repeated +add() calls with the same ID), link_lists.bin can grow unbounded — +terabytes on large palaces — eventually causing segfaults. + +This module provides three operations: + + scan — find every corrupt/unfetchable ID in the palace + prune — delete only the corrupt IDs (surgical) + rebuild — extract all drawers, delete the collection, recreate with + correct HNSW settings, and upsert everything back + +The rebuild backs up ONLY chroma.sqlite3 (the source of truth), not the +full palace directory — so it works even when link_lists.bin is bloated. + +Usage (standalone): + python -m mempalace.repair scan [--wing X] + python -m mempalace.repair prune --confirm + python -m mempalace.repair rebuild + +Usage (from CLI): + mempalace repair + mempalace repair-scan [--wing X] + mempalace repair-prune --confirm +""" + +import argparse +import os +import shutil +import time + +import chromadb + + +COLLECTION_NAME = "mempalace_drawers" + + +def _get_palace_path(): + """Resolve palace path from config.""" + try: + from .config import MempalaceConfig + + return MempalaceConfig().palace_path + except Exception: + default = os.path.join(os.path.expanduser("~"), ".mempalace", "palace") + return default + + +def _paginate_ids(col, where=None): + """Pull all IDs in a collection using pagination.""" + ids = [] + page = 1000 + offset = 0 + while True: + try: + r = col.get(where=where, include=[], limit=page, offset=offset) + except Exception: + try: + r = col.get(where=where, include=[], limit=page) + new_ids = [i for i in r["ids"] if i not in set(ids)] + if not new_ids: + break + ids.extend(new_ids) + offset += len(new_ids) + continue + except Exception: + break + n = len(r["ids"]) if r["ids"] else 0 + if n == 0: + break + ids.extend(r["ids"]) + offset += n + if n < page: + break + return ids + + +def scan_palace(palace_path=None, only_wing=None): + """Scan the palace for corrupt/unfetchable IDs. + + Probes in batches of 100, falls back to per-ID on failure. + Writes corrupt_ids.txt to the palace directory for the prune step. + + Returns (good_set, bad_set). + """ + palace_path = palace_path or _get_palace_path() + print(f"\n Palace: {palace_path}") + print(" Loading...") + + client = chromadb.PersistentClient(path=palace_path) + col = client.get_collection(COLLECTION_NAME) + + where = {"wing": only_wing} if only_wing else None + total = col.count() + print(f" Collection: {COLLECTION_NAME}, total: {total:,}") + if only_wing: + print(f" Scanning wing: {only_wing}") + + print("\n Step 1: listing all IDs...") + t0 = time.time() + all_ids = _paginate_ids(col, where=where) + print(f" Found {len(all_ids):,} IDs in {time.time() - t0:.1f}s\n") + + if not all_ids: + print(" Nothing to scan.") + return set(), set() + + print(" Step 2: probing each ID (batches of 100)...") + t0 = time.time() + good_set = set() + bad_set = set() + batch = 100 + + for i in range(0, len(all_ids), batch): + chunk = all_ids[i : i + batch] + try: + r = col.get(ids=chunk, include=["documents"]) + for got in r["ids"]: + good_set.add(got) + for mid in chunk: + if mid not in good_set: + bad_set.add(mid) + except Exception: + for sid in chunk: + try: + r = col.get(ids=[sid], include=["documents"]) + if r["ids"]: + good_set.add(sid) + else: + bad_set.add(sid) + except Exception: + bad_set.add(sid) + + if (i // batch) % 50 == 0: + elapsed = time.time() - t0 + rate = (i + batch) / max(elapsed, 0.01) + eta = (len(all_ids) - i - batch) / max(rate, 0.01) + print( + f" {i + batch:>6}/{len(all_ids):>6} " + f"good={len(good_set):>6} bad={len(bad_set):>6} " + f"eta={eta:.0f}s" + ) + + print(f"\n Scan complete in {time.time() - t0:.1f}s") + print(f" GOOD: {len(good_set):,}") + print(f" BAD: {len(bad_set):,} ({len(bad_set) / max(len(all_ids), 1) * 100:.1f}%)") + + bad_file = os.path.join(palace_path, "corrupt_ids.txt") + with open(bad_file, "w") as f: + for bid in sorted(bad_set): + f.write(bid + "\n") + print(f"\n Bad IDs written to: {bad_file}") + return good_set, bad_set + + +def prune_corrupt(palace_path=None, confirm=False): + """Delete corrupt IDs listed in corrupt_ids.txt.""" + palace_path = palace_path or _get_palace_path() + bad_file = os.path.join(palace_path, "corrupt_ids.txt") + + if not os.path.exists(bad_file): + print(" No corrupt_ids.txt found — run scan first.") + return + + with open(bad_file) as f: + bad_ids = [line.strip() for line in f if line.strip()] + print(f" {len(bad_ids):,} corrupt IDs queued for deletion") + + if not confirm: + print("\n DRY RUN — no deletions performed.") + print(" Re-run with --confirm to actually delete.") + return + + client = chromadb.PersistentClient(path=palace_path) + col = client.get_collection(COLLECTION_NAME) + before = col.count() + print(f" Collection size before: {before:,}") + + batch = 100 + deleted = 0 + failed = 0 + for i in range(0, len(bad_ids), batch): + chunk = bad_ids[i : i + batch] + try: + col.delete(ids=chunk) + deleted += len(chunk) + except Exception: + for sid in chunk: + try: + col.delete(ids=[sid]) + deleted += 1 + except Exception: + failed += 1 + if (i // batch) % 20 == 0: + print(f" deleted {deleted}/{len(bad_ids)} (failed: {failed})") + + after = col.count() + print(f"\n Deleted: {deleted:,}") + print(f" Failed: {failed:,}") + print(f" Collection size: {before:,} → {after:,}") + + +def rebuild_index(palace_path=None): + """Rebuild the HNSW index from scratch. + + 1. Extract all drawers via ChromaDB get() + 2. Back up ONLY chroma.sqlite3 (not the bloated HNSW files) + 3. Delete and recreate the collection with hnsw:space=cosine + 4. Upsert all drawers back + """ + palace_path = palace_path or _get_palace_path() + + if not os.path.isdir(palace_path): + print(f"\n No palace found at {palace_path}") + return + + print(f"\n{'=' * 55}") + print(" MemPalace Repair — Index Rebuild") + print(f"{'=' * 55}\n") + print(f" Palace: {palace_path}") + + client = chromadb.PersistentClient(path=palace_path) + try: + col = client.get_collection(COLLECTION_NAME) + total = col.count() + except Exception as e: + print(f" Error reading palace: {e}") + print(" Palace may need to be re-mined from source files.") + return + + print(f" Drawers found: {total}") + + if total == 0: + print(" Nothing to repair.") + return + + # Extract all drawers in batches + print("\n Extracting drawers...") + batch_size = 5000 + all_ids = [] + all_docs = [] + all_metas = [] + offset = 0 + while offset < total: + batch = col.get(limit=batch_size, offset=offset, include=["documents", "metadatas"]) + if not batch["ids"]: + break + all_ids.extend(batch["ids"]) + all_docs.extend(batch["documents"]) + all_metas.extend(batch["metadatas"]) + offset += len(batch["ids"]) + print(f" Extracted {len(all_ids)} drawers") + + # Back up ONLY the SQLite database, not the bloated HNSW files + sqlite_path = os.path.join(palace_path, "chroma.sqlite3") + if os.path.exists(sqlite_path): + backup_path = sqlite_path + ".backup" + print(f" Backing up chroma.sqlite3 ({os.path.getsize(sqlite_path) / 1e6:.0f} MB)...") + shutil.copy2(sqlite_path, backup_path) + print(f" Backup: {backup_path}") + + # Rebuild with correct HNSW settings + print(" Rebuilding collection with hnsw:space=cosine...") + client.delete_collection(COLLECTION_NAME) + new_col = client.create_collection(COLLECTION_NAME, metadata={"hnsw:space": "cosine"}) + + filed = 0 + for i in range(0, len(all_ids), batch_size): + batch_ids = all_ids[i : i + batch_size] + batch_docs = all_docs[i : i + batch_size] + batch_metas = all_metas[i : i + batch_size] + new_col.upsert(documents=batch_docs, ids=batch_ids, metadatas=batch_metas) + filed += len(batch_ids) + print(f" Re-filed {filed}/{len(all_ids)} drawers...") + + print(f"\n Repair complete. {filed} drawers rebuilt.") + print(" HNSW index is now clean with cosine distance metric.") + print(f"\n{'=' * 55}\n") + + +if __name__ == "__main__": + p = argparse.ArgumentParser(description="MemPalace repair tools") + p.add_argument("command", choices=["scan", "prune", "rebuild"]) + p.add_argument("--palace", default=None, help="Palace directory path") + p.add_argument("--wing", default=None, help="Scan only this wing") + p.add_argument("--confirm", action="store_true", help="Actually delete corrupt IDs") + args = p.parse_args() + + path = os.path.expanduser(args.palace) if args.palace else None + + if args.command == "scan": + scan_palace(palace_path=path, only_wing=args.wing) + elif args.command == "prune": + prune_corrupt(palace_path=path, confirm=args.confirm) + elif args.command == "rebuild": + rebuild_index(palace_path=path) diff --git a/tests/test_dedup.py b/tests/test_dedup.py new file mode 100644 index 0000000..2ddffb3 --- /dev/null +++ b/tests/test_dedup.py @@ -0,0 +1,272 @@ +"""Tests for mempalace.dedup — near-duplicate drawer detection and removal.""" + +from unittest.mock import MagicMock, patch + + +from mempalace import dedup + + +# ── get_source_groups ───────────────────────────────────────────────── + + +def test_get_source_groups_basic(): + col = MagicMock() + col.count.return_value = 5 + col.get.side_effect = [ + { + "ids": ["d1", "d2", "d3", "d4", "d5"], + "metadatas": [ + {"source_file": "a.txt"}, + {"source_file": "a.txt"}, + {"source_file": "a.txt"}, + {"source_file": "a.txt"}, + {"source_file": "a.txt"}, + ], + }, + {"ids": []}, + ] + groups = dedup.get_source_groups(col, min_count=5) + assert "a.txt" in groups + assert len(groups["a.txt"]) == 5 + + +def test_get_source_groups_below_min(): + col = MagicMock() + col.count.return_value = 2 + col.get.side_effect = [ + { + "ids": ["d1", "d2"], + "metadatas": [ + {"source_file": "a.txt"}, + {"source_file": "a.txt"}, + ], + }, + {"ids": []}, + ] + groups = dedup.get_source_groups(col, min_count=5) + assert len(groups) == 0 + + +def test_get_source_groups_source_filter(): + col = MagicMock() + col.count.return_value = 6 + col.get.side_effect = [ + { + "ids": ["d1", "d2", "d3", "d4", "d5", "d6"], + "metadatas": [ + {"source_file": "project_a.txt"}, + {"source_file": "project_a.txt"}, + {"source_file": "project_a.txt"}, + {"source_file": "project_a.txt"}, + {"source_file": "project_a.txt"}, + {"source_file": "other.txt"}, + ], + }, + {"ids": []}, + ] + groups = dedup.get_source_groups(col, min_count=5, source_pattern="project_a") + assert "project_a.txt" in groups + assert "other.txt" not in groups + + +def test_get_source_groups_wing_filter(): + col = MagicMock() + col.count.return_value = 5 + col.get.side_effect = [ + { + "ids": ["d1", "d2", "d3", "d4", "d5"], + "metadatas": [ + {"source_file": "a.txt"}, + {"source_file": "a.txt"}, + {"source_file": "a.txt"}, + {"source_file": "a.txt"}, + {"source_file": "a.txt"}, + ], + }, + {"ids": []}, + ] + dedup.get_source_groups(col, min_count=5, wing="my_wing") + # Verify where filter was passed + first_call = col.get.call_args_list[0] + assert first_call.kwargs.get("where") == {"wing": "my_wing"} + + +def test_get_source_groups_missing_source_file(): + col = MagicMock() + col.count.return_value = 5 + col.get.side_effect = [ + { + "ids": ["d1", "d2", "d3", "d4", "d5"], + "metadatas": [{}, {}, {}, {}, {}], + }, + {"ids": []}, + ] + groups = dedup.get_source_groups(col, min_count=5) + assert "unknown" in groups + + +# ── dedup_source_group ──────────────────────────────────────────────── + + +def test_dedup_source_group_all_unique(): + col = MagicMock() + col.get.return_value = { + "ids": ["d1", "d2"], + "documents": ["long document one content here", "different document two here"], + "metadatas": [{"wing": "a"}, {"wing": "a"}], + } + col.query.return_value = { + "ids": [["d1"]], + "distances": [[0.8]], # far apart = unique + } + kept, deleted = dedup.dedup_source_group(col, ["d1", "d2"], threshold=0.15, dry_run=True) + assert len(kept) == 2 + assert len(deleted) == 0 + + +def test_dedup_source_group_with_duplicate(): + col = MagicMock() + col.get.return_value = { + "ids": ["d1", "d2"], + "documents": [ + "long document content that is fairly long", + "long document content that is fairly long", + ], + "metadatas": [{"wing": "a"}, {"wing": "a"}], + } + col.query.return_value = { + "ids": [["d1"]], + "distances": [[0.05]], # very close = duplicate + } + kept, deleted = dedup.dedup_source_group(col, ["d1", "d2"], threshold=0.15, dry_run=True) + assert len(kept) == 1 + assert len(deleted) == 1 + + +def test_dedup_source_group_short_docs_deleted(): + col = MagicMock() + col.get.return_value = { + "ids": ["d1", "d2"], + "documents": ["long enough document to keep in the palace", "tiny"], + "metadatas": [{"wing": "a"}, {"wing": "a"}], + } + kept, deleted = dedup.dedup_source_group(col, ["d1", "d2"], threshold=0.15, dry_run=True) + assert "d2" in deleted # too short + + +def test_dedup_source_group_empty_doc_deleted(): + col = MagicMock() + col.get.return_value = { + "ids": ["d1", "d2"], + "documents": ["real document content here that is long enough", None], + "metadatas": [{"wing": "a"}, {"wing": "a"}], + } + kept, deleted = dedup.dedup_source_group(col, ["d1", "d2"], threshold=0.15, dry_run=True) + assert "d2" in deleted + + +def test_dedup_source_group_live_deletes(): + col = MagicMock() + col.get.return_value = { + "ids": ["d1", "d2"], + "documents": ["long document content here enough", "long document content here enough"], + "metadatas": [{"wing": "a"}, {"wing": "a"}], + } + col.query.return_value = { + "ids": [["d1"]], + "distances": [[0.05]], + } + kept, deleted = dedup.dedup_source_group(col, ["d1", "d2"], threshold=0.15, dry_run=False) + col.delete.assert_called_once() + + +def test_dedup_source_group_query_failure_keeps(): + col = MagicMock() + col.get.return_value = { + "ids": ["d1", "d2"], + "documents": [ + "long document one content here enough", + "long document two content here enough", + ], + "metadatas": [{"wing": "a"}, {"wing": "a"}], + } + col.query.side_effect = Exception("query failed") + kept, deleted = dedup.dedup_source_group(col, ["d1", "d2"], threshold=0.15, dry_run=True) + assert len(kept) == 2 # both kept on error + + +# ── show_stats ──────────────────────────────────────────────────────── + + +@patch("mempalace.dedup.chromadb") +def test_show_stats(mock_chromadb, tmp_path): + mock_col = MagicMock() + mock_col.count.return_value = 5 + mock_col.get.side_effect = [ + { + "ids": ["d1", "d2", "d3", "d4", "d5"], + "metadatas": [ + {"source_file": "a.txt"}, + {"source_file": "a.txt"}, + {"source_file": "a.txt"}, + {"source_file": "a.txt"}, + {"source_file": "a.txt"}, + ], + }, + {"ids": []}, + ] + mock_client = MagicMock() + mock_client.get_collection.return_value = mock_col + mock_chromadb.PersistentClient.return_value = mock_client + + dedup.show_stats(palace_path=str(tmp_path)) # should not raise + + +# ── dedup_palace ────────────────────────────────────────────────────── + + +@patch("mempalace.dedup.dedup_source_group") +@patch("mempalace.dedup.get_source_groups") +@patch("mempalace.dedup.chromadb") +def test_dedup_palace_dry_run(mock_chromadb, mock_groups, mock_dedup_group, tmp_path): + mock_col = MagicMock() + mock_col.count.return_value = 10 + mock_client = MagicMock() + mock_client.get_collection.return_value = mock_col + mock_chromadb.PersistentClient.return_value = mock_client + + mock_groups.return_value = {"a.txt": ["d1", "d2", "d3", "d4", "d5"]} + mock_dedup_group.return_value = (["d1", "d2", "d3"], ["d4", "d5"]) + + dedup.dedup_palace(palace_path=str(tmp_path), dry_run=True) + mock_dedup_group.assert_called_once() + + +@patch("mempalace.dedup.dedup_source_group") +@patch("mempalace.dedup.get_source_groups") +@patch("mempalace.dedup.chromadb") +def test_dedup_palace_with_wing(mock_chromadb, mock_groups, mock_dedup_group, tmp_path): + mock_col = MagicMock() + mock_col.count.return_value = 10 + mock_client = MagicMock() + mock_client.get_collection.return_value = mock_col + mock_chromadb.PersistentClient.return_value = mock_client + + mock_groups.return_value = {} + dedup.dedup_palace(palace_path=str(tmp_path), wing="test_wing", dry_run=True) + mock_groups.assert_called_once_with(mock_col, 5, None, wing="test_wing") + + +@patch("mempalace.dedup.dedup_source_group") +@patch("mempalace.dedup.get_source_groups") +@patch("mempalace.dedup.chromadb") +def test_dedup_palace_no_groups(mock_chromadb, mock_groups, mock_dedup_group, tmp_path): + mock_col = MagicMock() + mock_col.count.return_value = 3 + mock_client = MagicMock() + mock_client.get_collection.return_value = mock_col + mock_chromadb.PersistentClient.return_value = mock_client + + mock_groups.return_value = {} + dedup.dedup_palace(palace_path=str(tmp_path), dry_run=True) + mock_dedup_group.assert_not_called() diff --git a/tests/test_repair.py b/tests/test_repair.py new file mode 100644 index 0000000..604b0fb --- /dev/null +++ b/tests/test_repair.py @@ -0,0 +1,266 @@ +"""Tests for mempalace.repair — scan, prune, and rebuild HNSW index.""" + +import os +from unittest.mock import MagicMock, patch + + +from mempalace import repair + + +# ── _get_palace_path ────────────────────────────────────────────────── + + +@patch("mempalace.repair.MempalaceConfig", create=True) +def test_get_palace_path_from_config(mock_config_cls): + mock_config_cls.return_value.palace_path = "/configured/palace" + with patch.dict("sys.modules", {}): + # Force reimport to pick up the mock + result = repair._get_palace_path() + assert isinstance(result, str) + + +def test_get_palace_path_fallback(): + with patch("mempalace.repair._get_palace_path") as mock_get: + mock_get.return_value = os.path.join(os.path.expanduser("~"), ".mempalace", "palace") + result = mock_get() + assert ".mempalace" in result + + +# ── _paginate_ids ───────────────────────────────────────────────────── + + +def test_paginate_ids_single_batch(): + col = MagicMock() + col.get.return_value = {"ids": ["id1", "id2", "id3"]} + ids = repair._paginate_ids(col) + assert ids == ["id1", "id2", "id3"] + + +def test_paginate_ids_empty(): + col = MagicMock() + col.get.return_value = {"ids": []} + ids = repair._paginate_ids(col) + assert ids == [] + + +def test_paginate_ids_with_where(): + col = MagicMock() + col.get.return_value = {"ids": ["id1"]} + repair._paginate_ids(col, where={"wing": "test"}) + col.get.assert_called_with(where={"wing": "test"}, include=[], limit=1000, offset=0) + + +def test_paginate_ids_offset_exception_fallback(): + col = MagicMock() + # First call raises, fallback returns ids, second fallback returns empty + col.get.side_effect = [ + Exception("offset bug"), + {"ids": ["id1", "id2"]}, + Exception("offset bug"), + {"ids": ["id1", "id2"]}, # same ids = no new = break + ] + ids = repair._paginate_ids(col) + assert "id1" in ids + + +# ── scan_palace ─────────────────────────────────────────────────────── + + +@patch("mempalace.repair.chromadb") +def test_scan_palace_no_ids(mock_chromadb, tmp_path): + mock_col = MagicMock() + mock_col.count.return_value = 0 + mock_col.get.return_value = {"ids": []} + mock_client = MagicMock() + mock_client.get_collection.return_value = mock_col + mock_chromadb.PersistentClient.return_value = mock_client + + good, bad = repair.scan_palace(palace_path=str(tmp_path)) + assert good == set() + assert bad == set() + + +@patch("mempalace.repair.chromadb") +def test_scan_palace_all_good(mock_chromadb, tmp_path): + mock_col = MagicMock() + mock_col.count.return_value = 2 + # _paginate_ids call + mock_col.get.side_effect = [ + {"ids": ["id1", "id2"]}, # paginate + {"ids": ["id1", "id2"]}, # probe batch — both returned + ] + mock_client = MagicMock() + mock_client.get_collection.return_value = mock_col + mock_chromadb.PersistentClient.return_value = mock_client + + good, bad = repair.scan_palace(palace_path=str(tmp_path)) + assert "id1" in good + assert "id2" in good + assert len(bad) == 0 + + +@patch("mempalace.repair.chromadb") +def test_scan_palace_with_bad_ids(mock_chromadb, tmp_path): + mock_col = MagicMock() + mock_col.count.return_value = 2 + + def get_side_effect(**kwargs): + ids = kwargs.get("ids", None) + if ids is None: + # paginate call + return {"ids": ["good1", "bad1"]} + if "bad1" in ids and len(ids) == 1: + raise Exception("corrupt") + if "good1" in ids and len(ids) == 1: + return {"ids": ["good1"]} + # batch probe — raise to force per-id + raise Exception("batch fail") + + mock_col.get.side_effect = get_side_effect + mock_client = MagicMock() + mock_client.get_collection.return_value = mock_col + mock_chromadb.PersistentClient.return_value = mock_client + + good, bad = repair.scan_palace(palace_path=str(tmp_path)) + assert "good1" in good + assert "bad1" in bad + + +@patch("mempalace.repair.chromadb") +def test_scan_palace_with_wing_filter(mock_chromadb, tmp_path): + mock_col = MagicMock() + mock_col.count.return_value = 1 + mock_col.get.side_effect = [ + {"ids": ["id1"]}, # paginate + {"ids": ["id1"]}, # probe + ] + mock_client = MagicMock() + mock_client.get_collection.return_value = mock_col + mock_chromadb.PersistentClient.return_value = mock_client + + repair.scan_palace(palace_path=str(tmp_path), only_wing="test_wing") + # Verify where filter was passed + first_call = mock_col.get.call_args_list[0] + assert first_call.kwargs.get("where") == {"wing": "test_wing"} + + +# ── prune_corrupt ───────────────────────────────────────────────────── + + +@patch("mempalace.repair.chromadb") +def test_prune_corrupt_no_file(mock_chromadb, tmp_path): + # Should print message and return without error + repair.prune_corrupt(palace_path=str(tmp_path)) + + +@patch("mempalace.repair.chromadb") +def test_prune_corrupt_dry_run(mock_chromadb, tmp_path): + bad_file = tmp_path / "corrupt_ids.txt" + bad_file.write_text("bad1\nbad2\n") + repair.prune_corrupt(palace_path=str(tmp_path), confirm=False) + # No chromadb calls in dry run + mock_chromadb.PersistentClient.assert_not_called() + + +@patch("mempalace.repair.chromadb") +def test_prune_corrupt_confirmed(mock_chromadb, tmp_path): + bad_file = tmp_path / "corrupt_ids.txt" + bad_file.write_text("bad1\nbad2\n") + + mock_col = MagicMock() + mock_col.count.side_effect = [10, 8] + mock_client = MagicMock() + mock_client.get_collection.return_value = mock_col + mock_chromadb.PersistentClient.return_value = mock_client + + repair.prune_corrupt(palace_path=str(tmp_path), confirm=True) + mock_col.delete.assert_called_once() + + +@patch("mempalace.repair.chromadb") +def test_prune_corrupt_delete_failure_fallback(mock_chromadb, tmp_path): + bad_file = tmp_path / "corrupt_ids.txt" + bad_file.write_text("bad1\nbad2\n") + + mock_col = MagicMock() + mock_col.count.side_effect = [10, 8] + # Batch delete fails, per-id succeeds + mock_col.delete.side_effect = [Exception("batch fail"), None, None] + mock_client = MagicMock() + mock_client.get_collection.return_value = mock_col + mock_chromadb.PersistentClient.return_value = mock_client + + repair.prune_corrupt(palace_path=str(tmp_path), confirm=True) + assert mock_col.delete.call_count == 3 # 1 batch + 2 individual + + +# ── rebuild_index ───────────────────────────────────────────────────── + + +@patch("mempalace.repair.chromadb") +def test_rebuild_index_no_palace(mock_chromadb, tmp_path): + nonexistent = str(tmp_path / "nope") + repair.rebuild_index(palace_path=nonexistent) + mock_chromadb.PersistentClient.assert_not_called() + + +@patch("mempalace.repair.shutil") +@patch("mempalace.repair.chromadb") +def test_rebuild_index_empty_palace(mock_chromadb, mock_shutil, tmp_path): + mock_col = MagicMock() + mock_col.count.return_value = 0 + mock_client = MagicMock() + mock_client.get_collection.return_value = mock_col + mock_chromadb.PersistentClient.return_value = mock_client + + repair.rebuild_index(palace_path=str(tmp_path)) + mock_client.delete_collection.assert_not_called() + + +@patch("mempalace.repair.shutil") +@patch("mempalace.repair.chromadb") +def test_rebuild_index_success(mock_chromadb, mock_shutil, tmp_path): + # Create a fake sqlite file + sqlite_path = tmp_path / "chroma.sqlite3" + sqlite_path.write_text("fake") + + mock_col = MagicMock() + mock_col.count.return_value = 2 + mock_col.get.return_value = { + "ids": ["id1", "id2"], + "documents": ["doc1", "doc2"], + "metadatas": [{"wing": "a"}, {"wing": "b"}], + } + + mock_new_col = MagicMock() + mock_client = MagicMock() + mock_client.get_collection.return_value = mock_col + mock_client.create_collection.return_value = mock_new_col + mock_chromadb.PersistentClient.return_value = mock_client + + repair.rebuild_index(palace_path=str(tmp_path)) + + # Verify: backed up sqlite only (not copytree) + mock_shutil.copy2.assert_called_once() + assert "chroma.sqlite3" in str(mock_shutil.copy2.call_args) + + # Verify: deleted and recreated with cosine + mock_client.delete_collection.assert_called_once_with("mempalace_drawers") + mock_client.create_collection.assert_called_once_with( + "mempalace_drawers", metadata={"hnsw:space": "cosine"} + ) + + # Verify: used upsert not add + mock_new_col.upsert.assert_called_once() + mock_new_col.add.assert_not_called() + + +@patch("mempalace.repair.shutil") +@patch("mempalace.repair.chromadb") +def test_rebuild_index_error_reading(mock_chromadb, mock_shutil, tmp_path): + mock_client = MagicMock() + mock_client.get_collection.side_effect = Exception("corrupt") + mock_chromadb.PersistentClient.return_value = mock_client + + repair.rebuild_index(palace_path=str(tmp_path)) + mock_client.delete_collection.assert_not_called()