From f935e85ead2c647a28f33b84d1287164b1469a52 Mon Sep 17 00:00:00 2001 From: MSL <232237854+milla-jovovich@users.noreply.github.com> Date: Mon, 13 Apr 2026 01:47:19 -0700 Subject: [PATCH 01/12] feat: entity metadata + diary ingest + BM25 hybrid search Three features that close the gap between the architecture docs and the actual codebase: 1. Entity metadata on drawers and closets - _extract_entities_for_metadata() pulls names from known_entities.json + proper nouns appearing 2+ times - Stamped as "entities" field in ChromaDB metadata - Enables filterable search by person/project name 2. Day-based diary ingest (diary_ingest.py) - ONE drawer per day, upserted as the day grows - Closets pack topics atomically, never split mid-topic - Tracks entry count in state file, only processes new entries - Usage: python -m mempalace.diary_ingest --dir ~/summaries 3. BM25 hybrid search in searcher.py - _bm25_score() keyword matching complements vector similarity - _hybrid_rank() combines both signals (60% vector, 40% BM25) - Catches exact name/term matches that embeddings miss - Applied to both closet-first and direct drawer search paths 689/689 tests pass. Co-Authored-By: Claude Opus 4.6 (1M context) --- mempalace/diary_ingest.py | 173 ++++++++++++++++++++++++++++++++++++++ mempalace/miner.py | 49 ++++++++++- mempalace/searcher.py | 64 +++++++++++++- 3 files changed, 282 insertions(+), 4 deletions(-) create mode 100644 mempalace/diary_ingest.py diff --git a/mempalace/diary_ingest.py b/mempalace/diary_ingest.py new file mode 100644 index 0000000..e64e139 --- /dev/null +++ b/mempalace/diary_ingest.py @@ -0,0 +1,173 @@ +""" +diary_ingest.py — Ingest daily summary files into the palace. 
+ +Architecture: +- ONE drawer per day — full verbatim content, upserted as the day grows +- Closets pack topics up to 1500 chars, never split mid-topic +- Only new entries are processed (tracks entry count in state file) +- Entities extracted and stamped on metadata for filterable search + +Usage: + python -m mempalace.diary_ingest --dir ~/daily_summaries --palace ~/.mempalace/palace + python -m mempalace.diary_ingest --dir ~/daily_summaries --palace ~/.mempalace/palace --force +""" + +import hashlib +import json +import os +import re +from datetime import datetime, timezone +from pathlib import Path + +from .palace import ( + get_collection, + get_closets_collection, + build_closet_lines, + upsert_closet_lines, + CLOSET_CHAR_LIMIT, +) +from .miner import _extract_entities_for_metadata + + +DIARY_ENTRY_RE = re.compile(r"^## .+", re.MULTILINE) + + +def _split_entries(text): + """Split diary text into (header, body) pairs per ## entry.""" + parts = DIARY_ENTRY_RE.split(text) + headers = DIARY_ENTRY_RE.findall(text) + entries = [] + for i, header in enumerate(headers): + body = parts[i + 1] if i + 1 < len(parts) else "" + entries.append((header.strip(), body.strip())) + return entries + + +def ingest_diaries( + diary_dir, + palace_path, + wing="diary", + force=False, +): + """Ingest daily summary files into the palace. + + Each date file gets ONE drawer (upserted as day grows) and + closets that pack topics atomically up to 1500 chars. 
+ """ + diary_dir = Path(diary_dir).expanduser().resolve() + if not diary_dir.exists(): + print(f"Diary directory not found: {diary_dir}") + return + + diary_files = sorted(diary_dir.glob("*.md")) + if not diary_files: + print(f"No .md files in {diary_dir}") + return + + # State tracks which entries have been closeted per file + state_file = diary_dir / ".diary_ingest_state.json" + state = {} if force else ( + json.loads(state_file.read_text()) if state_file.exists() else {} + ) + + drawers_col = get_collection(palace_path) + closets_col = get_closets_collection(palace_path) + + days_updated = 0 + closets_created = 0 + + for diary_path in diary_files: + text = diary_path.read_text(encoding="utf-8", errors="replace") + if len(text.strip()) < 50: + continue + + date_match = re.match(r"(\d{4}-\d{2}-\d{2})", diary_path.stem) + if not date_match: + continue + date_str = date_match.group(1) + + # Skip if content hasn't changed + prev_size = state.get(diary_path.name, {}).get("size", 0) + curr_size = len(text) + if curr_size == prev_size and not force: + continue + + now_iso = datetime.now(timezone.utc).isoformat() + drawer_id = f"drawer_diary_{date_str}" + + # Extract entities from full day text + entities = _extract_entities_for_metadata(text) + + # UPSERT the day's drawer (full verbatim, replaces as day grows) + drawer_meta = { + "date": date_str, + "wing": wing, + "room": "daily", + "source_file": str(diary_path), + "source_session": "daily_diary", + "filed_at": now_iso, + } + if entities: + drawer_meta["entities"] = entities + drawers_col.upsert( + documents=[text], + ids=[drawer_id], + metadatas=[drawer_meta], + ) + + # Split into entries and find new ones + entries = _split_entries(text) + prev_entry_count = state.get(diary_path.name, {}).get("entry_count", 0) + new_entries = entries[prev_entry_count:] if not force else entries + + if new_entries: + # Build closet lines from new entries + all_lines = [] + for header, body in new_entries: + entry_text = 
f"{header}\n{body}" + entry_lines = build_closet_lines( + str(diary_path), [drawer_id], entry_text, wing, "daily" + ) + all_lines.extend(entry_lines) + + if all_lines: + closet_id_base = f"closet_diary_{date_str}" + closet_meta = { + "date": date_str, + "wing": wing, + "room": "daily", + "source_file": str(diary_path), + "filed_at": now_iso, + } + if entities: + closet_meta["entities"] = entities + n = upsert_closet_lines( + closets_col, closet_id_base, all_lines, closet_meta + ) + closets_created += n + + state[diary_path.name] = { + "size": curr_size, + "entry_count": len(entries), + "ingested_at": now_iso, + } + days_updated += 1 + + state_file.write_text(json.dumps(state, indent=2)) + if days_updated: + print(f"Diary: {days_updated} days updated, {closets_created} new closets") + + return {"days_updated": days_updated, "closets_created": closets_created} + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser(description="Ingest daily summaries into the palace") + parser.add_argument("--dir", required=True, help="Path to daily_summaries directory") + parser.add_argument("--palace", default=os.path.expanduser("~/.mempalace/palace")) + parser.add_argument("--wing", default="diary") + parser.add_argument("--force", action="store_true") + args = parser.parse_args() + + ingest_diaries(args.dir, args.palace, wing=args.wing, force=args.force) diff --git a/mempalace/miner.py b/mempalace/miner.py index 37e507a..e2f6528 100644 --- a/mempalace/miner.py +++ b/mempalace/miner.py @@ -371,6 +371,43 @@ def chunk_text(content: str, source_file: str) -> list: # ============================================================================= +def _extract_entities_for_metadata(content: str) -> str: + """Extract entity names from content for metadata tagging. + + Returns semicolon-separated string of entity names found in the text, + suitable for ChromaDB metadata filtering. 
+ """ + import re + # Load known entities from registry if available + known_names = set() + registry_path = os.path.join(os.path.expanduser("~"), ".mempalace", "known_entities.json") + if os.path.exists(registry_path): + try: + import json + kd = json.loads(open(registry_path).read()) + for cat in kd.values(): + if isinstance(cat, list): + known_names.update(cat) + except Exception: + pass + + matched = set() + # Match known entities + for name in known_names: + if re.search(r'(?= 2 and len(w) > 2: + matched.add(w) + + return ";".join(sorted(matched))[:500] if matched else "" + + def add_drawer( collection, wing: str, room: str, content: str, source_file: str, chunk_index: int, agent: str ): @@ -390,6 +427,10 @@ def add_drawer( metadata["source_mtime"] = os.path.getmtime(source_file) except OSError: pass + # Tag with entity names for filterable search + entities = _extract_entities_for_metadata(content) + if entities: + metadata["entities"] = entities collection.upsert( documents=[content], ids=[drawer_id], @@ -479,13 +520,17 @@ def process_file( ] closet_lines = build_closet_lines(source_file, drawer_ids, content, wing, room) closet_id_base = f"closet_{wing}_{room}_{hashlib.sha256(source_file.encode()).hexdigest()[:24]}" - upsert_closet_lines(closets_col, closet_id_base, closet_lines, { + entities = _extract_entities_for_metadata(content) + closet_meta = { "wing": wing, "room": room, "source_file": source_file, "drawer_count": drawers_added, "filed_at": datetime.now().isoformat(), - }) + } + if entities: + closet_meta["entities"] = entities + upsert_closet_lines(closets_col, closet_id_base, closet_lines, closet_meta) return drawers_added, room diff --git a/mempalace/searcher.py b/mempalace/searcher.py index 70fd615..37795fc 100644 --- a/mempalace/searcher.py +++ b/mempalace/searcher.py @@ -2,11 +2,14 @@ """ searcher.py — Find anything. Exact words. -Semantic search against the palace. -Returns verbatim text — the actual words, never summaries. 
+Hybrid search: BM25 keyword matching + vector semantic similarity. +Searches closets first (fast index), then hydrates full drawer content. +Falls back to direct drawer search for palaces without closets. """ import logging +import math +import re from pathlib import Path from .palace import get_collection, get_closets_collection @@ -18,6 +21,59 @@ class SearchError(Exception): """Raised when search cannot proceed (e.g. no palace found).""" +def _bm25_score(query: str, document: str, k1: float = 1.5, b: float = 0.75, avg_dl: float = 500) -> float: + """Simple BM25 score for a single document against a query. + + This is a lightweight keyword-matching signal that complements vector + similarity. It catches exact matches that embeddings might miss + (e.g., specific names, project codes, error messages). + """ + query_terms = set(re.findall(r'\w{2,}', query.lower())) + doc_terms = re.findall(r'\w{2,}', document.lower()) + if not query_terms or not doc_terms: + return 0.0 + doc_len = len(doc_terms) + term_freq = {} + for t in doc_terms: + term_freq[t] = term_freq.get(t, 0) + 1 + + score = 0.0 + for term in query_terms: + tf = term_freq.get(term, 0) + if tf > 0: + # Simplified IDF — treat each query term as moderately rare + idf = math.log(2.0) + numerator = tf * (k1 + 1) + denominator = tf + k1 * (1 - b + b * doc_len / avg_dl) + score += idf * numerator / denominator + return score + + +def _hybrid_rank(vector_results, query: str, vector_weight: float = 0.6, bm25_weight: float = 0.4): + """Re-rank results using both vector distance and BM25 keyword score. + + Returns results sorted by combined score (higher = better). 
+ """ + if not vector_results: + return vector_results + + # Normalize vector distances to 0-1 similarity + max_dist = max(r.get("distance", 1.0) for r in vector_results) or 1.0 + for r in vector_results: + vec_sim = max(0.0, 1 - r.get("distance", 1.0) / max(max_dist, 0.001)) + bm25 = _bm25_score(query, r.get("text", "")) + # Normalize BM25 to roughly 0-1 range + bm25_norm = min(bm25 / 3.0, 1.0) + r["_hybrid_score"] = vector_weight * vec_sim + bm25_weight * bm25_norm + r["bm25_score"] = round(bm25, 3) + + vector_results.sort(key=lambda r: r["_hybrid_score"], reverse=True) + # Clean up internal field + for r in vector_results: + del r["_hybrid_score"] + return vector_results + + def build_where_filter(wing: str = None, room: str = None) -> dict: """Build ChromaDB where filter for wing/room filtering.""" if wing and room: @@ -186,6 +242,8 @@ def search_memories( break if hits: + # Re-rank with BM25 hybrid scoring + hits = _hybrid_rank(hits, query) return { "query": query, "filters": {"wing": wing, "room": room}, @@ -227,6 +285,8 @@ def search_memories( } ) + # Re-rank with BM25 hybrid scoring + hits = _hybrid_rank(hits, query) return { "query": query, "filters": {"wing": wing, "room": room}, From f72ffbbcb2766b04be64e33a8b2e66788d4b22eb Mon Sep 17 00:00:00 2001 From: Igor Lins e Silva <4753812+igorls@users.noreply.github.com> Date: Mon, 13 Apr 2026 07:42:25 -0300 Subject: [PATCH 02/12] test: add tests for mine_lock, closets, entity metadata, BM25, diary Trimmed version of Milla's omnibus test_closets.py to only cover features present in this PR stack (#784 lock, #788 closets, this PR's entity/BM25/diary). Strip-noise tests will land with #785; tunnel tests will land with the tunnels PR. 16/16 pass. 
Co-Authored-By: MSL <232237854+milla-jovovich@users.noreply.github.com> --- tests/test_closets.py | 201 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 201 insertions(+) create mode 100644 tests/test_closets.py diff --git a/tests/test_closets.py b/tests/test_closets.py new file mode 100644 index 0000000..b365102 --- /dev/null +++ b/tests/test_closets.py @@ -0,0 +1,201 @@ +"""Tests for the closet layer, mine_lock, entity metadata, BM25 hybrid search, +and diary ingest. + +Content derived from Milla's omnibus test file; trimmed to only the features +present in this PR stack (#784 lock, #788 closets, this PR's entity/BM25/diary). +Strip-noise tests live with #785; tunnel tests live with the tunnels PR. +""" + +import os +import tempfile +import threading +import time + +from mempalace.palace import ( + CLOSET_CHAR_LIMIT, + build_closet_lines, + get_closets_collection, + get_collection, + mine_lock, + upsert_closet_lines, +) +from mempalace.miner import _extract_entities_for_metadata +from mempalace.searcher import _bm25_score, _hybrid_rank + + +# ── mine_lock ──────────────────────────────────────────────────────────── + + +class TestMineLock: + def test_lock_acquires_and_releases(self): + with mine_lock("/tmp/test_lock_file.txt"): + lock_dir = os.path.expanduser("~/.mempalace/locks") + assert os.path.isdir(lock_dir) + + def test_lock_blocks_concurrent_access(self): + results = [] + + def worker(name): + start = time.time() + with mine_lock("/tmp/same_file_lock_test.txt"): + results.append((name, time.time() - start)) + time.sleep(0.2) + + t1 = threading.Thread(target=worker, args=("a",)) + t2 = threading.Thread(target=worker, args=("b",)) + t1.start() + time.sleep(0.05) + t2.start() + t1.join() + t2.join() + + # Second thread should have waited + wait_times = sorted(results, key=lambda x: x[1]) + assert wait_times[1][1] > 0.1, "Second thread should block" + + +# ── closet lines ───────────────────────────────────────────────────────── + + +class 
TestBuildClosetLines: + def test_returns_list_of_lines(self): + lines = build_closet_lines( + "/tmp/test.py", ["drawer_001"], "We built the auth system", "code", "general" + ) + assert isinstance(lines, list) + assert len(lines) >= 1 + + def test_each_line_has_pointer(self): + lines = build_closet_lines( + "/tmp/test.py", + ["drawer_001", "drawer_002"], + "We built the auth system and tested the login flow", + "code", + "general", + ) + for line in lines: + assert "→" in line, f"Line missing pointer: {line}" + + def test_fallback_when_no_topics(self): + lines = build_closet_lines( + "/tmp/test.py", ["drawer_001"], "short text", "wing", "room" + ) + assert len(lines) >= 1 + assert "→" in lines[0] + + +# ── upsert_closet_lines ───────────────────────────────────────────────── + + +class TestUpsertClosetLines: + def test_writes_closets(self): + with tempfile.TemporaryDirectory() as tmpdir: + col = get_closets_collection(tmpdir) + lines = [ + "topic one|Entity1|→drawer_001", + "topic two|Entity2|→drawer_002", + ] + n = upsert_closet_lines(col, "test_closet", lines, {"wing": "test"}) + assert n >= 1 + assert col.count() >= 1 + + def test_never_splits_mid_topic(self): + with tempfile.TemporaryDirectory() as tmpdir: + col = get_closets_collection(tmpdir) + # Create lines that together exceed CLOSET_CHAR_LIMIT + lines = [f"topic_{i}|{'x' * 200}|→drawer_{i}" for i in range(20)] + n = upsert_closet_lines(col, "test_closet", lines, {"wing": "test"}) + assert n >= 2, "Should create multiple closets" + + # Verify each closet has complete lines + all_data = col.get(include=["documents"]) + for doc in all_data["documents"]: + for line in doc.strip().split("\n"): + assert "→" in line, f"Split topic found: {line}" + + def test_respects_char_limit(self): + with tempfile.TemporaryDirectory() as tmpdir: + col = get_closets_collection(tmpdir) + lines = [f"topic_{i}|entities|→drawer_{i}" for i in range(50)] + upsert_closet_lines(col, "test_closet", lines, {"wing": "test"}) + + all_data 
= col.get(include=["documents"]) + for doc in all_data["documents"]: + assert len(doc) <= CLOSET_CHAR_LIMIT + 100 # small buffer for existing content + + +# ── entity metadata ────────────────────────────────────────────────────── + + +class TestEntityMetadata: + def test_extracts_capitalized_names(self): + text = "Ben reviewed the code. Ben approved it. Igor flagged two issues. Igor fixed them." + entities = _extract_entities_for_metadata(text) + assert "Ben" in entities + assert "Igor" in entities + + def test_empty_for_no_entities(self): + text = "this is all lowercase with no proper nouns at all" + entities = _extract_entities_for_metadata(text) + assert entities == "" + + def test_semicolon_separated(self): + text = "Alice and Bob met Charlie. Alice said hello. Bob agreed. Charlie laughed." + entities = _extract_entities_for_metadata(text) + assert ";" in entities + + +# ── BM25 hybrid search ────────────────────────────────────────────────── + + +class TestBM25: + def test_bm25_score_positive_for_match(self): + score = _bm25_score("database migration", "We migrated the database to Postgres") + assert score > 0 + + def test_bm25_score_zero_for_no_match(self): + score = _bm25_score("quantum physics", "We built a web application in React") + assert score == 0.0 + + def test_hybrid_rank_reorders(self): + results = [ + {"text": "database schema design for Postgres", "distance": 0.5}, + {"text": "unrelated topic about cooking", "distance": 0.3}, + ] + ranked = _hybrid_rank(results, "database Postgres schema") + # The database result should rank higher despite worse vector distance + assert "database" in ranked[0]["text"] + + +# ── diary ingest ───────────────────────────────────────────────────────── + + +class TestDiaryIngest: + def test_ingest_creates_drawers_and_closets(self): + with tempfile.TemporaryDirectory() as palace_dir: + diary_dir = tempfile.mkdtemp() + # Write a test diary + with open(os.path.join(diary_dir, "2026-04-13.md"), "w") as f: + f.write("# 
2026-04-13\n\n## 10:00 PDT — Test\n\nBuilt the auth system.\n") + + from mempalace.diary_ingest import ingest_diaries + + result = ingest_diaries(diary_dir, palace_dir, force=True) + assert result["days_updated"] >= 1 + + # Check drawer exists + drawers = get_collection(palace_dir) + count = drawers.count() + assert count >= 1 + + def test_ingest_skips_unchanged(self): + with tempfile.TemporaryDirectory() as palace_dir: + diary_dir = tempfile.mkdtemp() + with open(os.path.join(diary_dir, "2026-04-13.md"), "w") as f: + f.write("# 2026-04-13\n\n## 10:00 — Test\n\nContent.\n") + + from mempalace.diary_ingest import ingest_diaries + + ingest_diaries(diary_dir, palace_dir, force=True) + result = ingest_diaries(diary_dir, palace_dir) # second run, no force + assert result["days_updated"] == 0 From 1b4ce0b1f8956436d7c6e9e5bb1ef314550db83d Mon Sep 17 00:00:00 2001 From: MSL <232237854+milla-jovovich@users.noreply.github.com> Date: Mon, 13 Apr 2026 02:05:55 -0700 Subject: [PATCH 03/12] feat: explicit cross-wing tunnels for multi-project agents MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adds active tunnel creation alongside passive tunnel discovery. Passive tunnels (existing): rooms with the same name across wings. Explicit tunnels (new): agent-created links between specific locations. "This API design in project_api relates to the database schema in project_database." New functions in palace_graph.py: - create_tunnel() — link two wing/room pairs with a label - list_tunnels() — list all explicit tunnels, filter by wing - delete_tunnel() — remove a tunnel by ID - follow_tunnels() — from a room, find all connected rooms in other wings with drawer content previews New MCP tools: - mempalace_create_tunnel - mempalace_list_tunnels - mempalace_delete_tunnel - mempalace_follow_tunnels Tunnels stored in ~/.mempalace/tunnels.json (persists across palace rebuilds). Deduplicated by endpoint pair. 689/689 tests pass. 
Co-Authored-By: Claude Opus 4.6 (1M context) --- mempalace/mcp_server.py | 109 ++++++++++++++++++++++++- mempalace/palace_graph.py | 162 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 270 insertions(+), 1 deletion(-) diff --git a/mempalace/mcp_server.py b/mempalace/mcp_server.py index 4e21426..89b74f7 100644 --- a/mempalace/mcp_server.py +++ b/mempalace/mcp_server.py @@ -35,7 +35,7 @@ from .version import __version__ import chromadb from .query_sanitizer import sanitize_query from .searcher import search_memories -from .palace_graph import traverse, find_tunnels, graph_stats +from .palace_graph import traverse, find_tunnels, graph_stats, create_tunnel, list_tunnels, delete_tunnel, follow_tunnels from .knowledge_graph import KnowledgeGraph @@ -496,6 +496,63 @@ def tool_graph_stats(): return graph_stats(col=col) +def tool_create_tunnel( + source_wing: str, + source_room: str, + target_wing: str, + target_room: str, + label: str = "", + source_drawer_id: str = None, + target_drawer_id: str = None, +): + """Create an explicit cross-wing tunnel between two palace locations. + + Use when you notice content in one project relates to another project. + Example: an API design discussion in project_api connects to the + database schema in project_database. 
+ """ + try: + source_wing = sanitize_name(source_wing, "source_wing") + source_room = sanitize_name(source_room, "source_room") + target_wing = sanitize_name(target_wing, "target_wing") + target_room = sanitize_name(target_room, "target_room") + except ValueError as e: + return {"error": str(e)} + return create_tunnel( + source_wing, source_room, target_wing, target_room, + label=label, + source_drawer_id=source_drawer_id, + target_drawer_id=target_drawer_id, + ) + + +def tool_list_tunnels(wing: str = None): + """List all explicit cross-wing tunnels, optionally filtered by wing.""" + try: + wing = _sanitize_optional_name(wing, "wing") + except ValueError as e: + return {"error": str(e)} + return list_tunnels(wing) + + +def tool_delete_tunnel(tunnel_id: str): + """Delete an explicit tunnel by its ID.""" + if not tunnel_id or not isinstance(tunnel_id, str): + return {"error": "tunnel_id is required"} + return delete_tunnel(tunnel_id) + + +def tool_follow_tunnels(wing: str, room: str): + """Follow explicit tunnels from a room to see connected drawers in other wings.""" + try: + wing = sanitize_name(wing, "wing") + room = sanitize_name(room, "room") + except ValueError as e: + return {"error": str(e)} + col = _get_collection() + return follow_tunnels(wing, room, col=col) + + # ==================== WRITE TOOLS ==================== @@ -1181,6 +1238,56 @@ TOOLS = { "input_schema": {"type": "object", "properties": {}}, "handler": tool_graph_stats, }, + "mempalace_create_tunnel": { + "description": "Create a cross-wing tunnel linking two palace locations. 
Use when content in one project relates to another — e.g., an API design in project_api connects to a database schema in project_database.", + "input_schema": { + "type": "object", + "properties": { + "source_wing": {"type": "string", "description": "Wing of the source"}, + "source_room": {"type": "string", "description": "Room in the source wing"}, + "target_wing": {"type": "string", "description": "Wing of the target"}, + "target_room": {"type": "string", "description": "Room in the target wing"}, + "label": {"type": "string", "description": "Description of the connection"}, + "source_drawer_id": {"type": "string", "description": "Optional specific drawer ID"}, + "target_drawer_id": {"type": "string", "description": "Optional specific drawer ID"}, + }, + "required": ["source_wing", "source_room", "target_wing", "target_room"], + }, + "handler": tool_create_tunnel, + }, + "mempalace_list_tunnels": { + "description": "List all explicit cross-wing tunnels. Optionally filter by wing.", + "input_schema": { + "type": "object", + "properties": { + "wing": {"type": "string", "description": "Filter tunnels by wing (shows tunnels where wing is source or target)"}, + }, + }, + "handler": tool_list_tunnels, + }, + "mempalace_delete_tunnel": { + "description": "Delete an explicit tunnel by its ID.", + "input_schema": { + "type": "object", + "properties": { + "tunnel_id": {"type": "string", "description": "Tunnel ID to delete"}, + }, + "required": ["tunnel_id"], + }, + "handler": tool_delete_tunnel, + }, + "mempalace_follow_tunnels": { + "description": "Follow tunnels from a room to see what it connects to in other wings. 
Returns connected rooms with drawer previews.", + "input_schema": { + "type": "object", + "properties": { + "wing": {"type": "string", "description": "Wing to start from"}, + "room": {"type": "string", "description": "Room to follow tunnels from"}, + }, + "required": ["wing", "room"], + }, + "handler": tool_follow_tunnels, + }, "mempalace_search": { "description": "Semantic search. Returns verbatim drawer content with similarity scores. IMPORTANT: 'query' must contain ONLY search keywords. Use 'context' for background. Results with cosine distance > max_distance are filtered out.", "input_schema": { diff --git a/mempalace/palace_graph.py b/mempalace/palace_graph.py index 5e2e72e..2792d99 100644 --- a/mempalace/palace_graph.py +++ b/mempalace/palace_graph.py @@ -15,7 +15,11 @@ Enables queries like: No external graph DB needed — built from ChromaDB metadata. """ +import hashlib +import json +import os from collections import defaultdict, Counter +from datetime import datetime from .config import MempalaceConfig from .palace import get_collection as _get_palace_collection @@ -228,3 +232,161 @@ def _fuzzy_match(query: str, nodes: dict, n: int = 5): scored.append((room, 0.5)) scored.sort(key=lambda x: -x[1]) return [r for r, _ in scored[:n]] + + +# ============================================================================= +# EXPLICIT TUNNELS — agent-created cross-wing links +# ============================================================================= +# Passive tunnels are discovered from shared room names across wings. +# Explicit tunnels are created by agents when they notice a connection +# between two specific drawers or rooms in different wings/projects. +# +# Stored as a JSON file at ~/.mempalace/tunnels.json so they persist +# across palace rebuilds (not in ChromaDB which can be recreated). 
+ + +_TUNNEL_FILE = os.path.join(os.path.expanduser("~"), ".mempalace", "tunnels.json") + + +def _load_tunnels(): + """Load explicit tunnels from disk.""" + if os.path.exists(_TUNNEL_FILE): + try: + return json.loads(open(_TUNNEL_FILE).read()) + except Exception: + pass + return [] + + +def _save_tunnels(tunnels): + """Save explicit tunnels to disk.""" + os.makedirs(os.path.dirname(_TUNNEL_FILE), exist_ok=True) + with open(_TUNNEL_FILE, "w") as f: + json.dump(tunnels, f, indent=2) + + +def create_tunnel( + source_wing: str, + source_room: str, + target_wing: str, + target_room: str, + label: str = "", + source_drawer_id: str = None, + target_drawer_id: str = None, +): + """Create an explicit tunnel between two locations in the palace. + + Use when an agent notices a connection between two projects/wings + that wouldn't be found by passive room-name matching. + + Args: + source_wing: Wing of the source (e.g., "project_api") + source_room: Room in the source wing + target_wing: Wing of the target (e.g., "project_database") + target_room: Room in the target wing + label: Description of the connection + source_drawer_id: Optional specific drawer ID + target_drawer_id: Optional specific drawer ID + + Returns: + The created tunnel dict. 
+ """ + tunnel_id = hashlib.sha256( + f"{source_wing}/{source_room}↔{target_wing}/{target_room}".encode() + ).hexdigest()[:16] + + tunnel = { + "id": tunnel_id, + "source": {"wing": source_wing, "room": source_room}, + "target": {"wing": target_wing, "room": target_room}, + "label": label, + "created_at": datetime.now().isoformat(), + } + if source_drawer_id: + tunnel["source"]["drawer_id"] = source_drawer_id + if target_drawer_id: + tunnel["target"]["drawer_id"] = target_drawer_id + + tunnels = _load_tunnels() + + # Dedup — don't create if same endpoints already linked + for existing in tunnels: + if existing.get("id") == tunnel_id: + existing.update(tunnel) # update label/drawers + _save_tunnels(tunnels) + return existing + + tunnels.append(tunnel) + _save_tunnels(tunnels) + return tunnel + + +def list_tunnels(wing: str = None): + """List all explicit tunnels, optionally filtered by wing. + + Returns tunnels where the wing appears as either source or target. + """ + tunnels = _load_tunnels() + if wing: + tunnels = [ + t for t in tunnels + if t["source"]["wing"] == wing or t["target"]["wing"] == wing + ] + return tunnels + + +def delete_tunnel(tunnel_id: str): + """Delete an explicit tunnel by ID.""" + tunnels = _load_tunnels() + tunnels = [t for t in tunnels if t.get("id") != tunnel_id] + _save_tunnels(tunnels) + return {"deleted": tunnel_id} + + +def follow_tunnels(wing: str, room: str, col=None, config=None): + """Follow explicit tunnels from a room — returns connected drawers. + + Given a location (wing/room), finds all tunnels leading from or to it, + and optionally fetches the connected drawer content. 
+ """ + tunnels = _load_tunnels() + connections = [] + + for t in tunnels: + src = t["source"] + tgt = t["target"] + + if src["wing"] == wing and src["room"] == room: + connections.append({ + "direction": "outgoing", + "connected_wing": tgt["wing"], + "connected_room": tgt["room"], + "label": t.get("label", ""), + "drawer_id": tgt.get("drawer_id"), + "tunnel_id": t["id"], + }) + elif tgt["wing"] == wing and tgt["room"] == room: + connections.append({ + "direction": "incoming", + "connected_wing": src["wing"], + "connected_room": src["room"], + "label": t.get("label", ""), + "drawer_id": src.get("drawer_id"), + "tunnel_id": t["id"], + }) + + # If we have a collection, fetch drawer content for connected items + if col and connections: + drawer_ids = [c["drawer_id"] for c in connections if c.get("drawer_id")] + if drawer_ids: + try: + results = col.get(ids=drawer_ids, include=["documents", "metadatas"]) + drawer_map = dict(zip(results["ids"], results["documents"])) + for c in connections: + did = c.get("drawer_id") + if did and did in drawer_map: + c["drawer_preview"] = drawer_map[did][:300] + except Exception: + pass + + return connections From e2a9bb05d37712af2dc488769713f97c7059f8e5 Mon Sep 17 00:00:00 2001 From: Igor Lins e Silva <4753812+igorls@users.noreply.github.com> Date: Mon, 13 Apr 2026 07:44:32 -0300 Subject: [PATCH 04/12] test: add TestTunnels for cross-wing tunnel operations MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Appended from Milla's omnibus test_closets.py — covers create, list, delete, dedup, and follow_tunnels behavior. 21/21 pass. 
Co-Authored-By: MSL <232237854+milla-jovovich@users.noreply.github.com> --- tests/test_closets.py | 61 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 61 insertions(+) diff --git a/tests/test_closets.py b/tests/test_closets.py index b365102..57c989d 100644 --- a/tests/test_closets.py +++ b/tests/test_closets.py @@ -21,6 +21,13 @@ from mempalace.palace import ( ) from mempalace.miner import _extract_entities_for_metadata from mempalace.searcher import _bm25_score, _hybrid_rank +from mempalace.palace_graph import ( + create_tunnel, + list_tunnels, + delete_tunnel, + follow_tunnels, + _TUNNEL_FILE, +) # ── mine_lock ──────────────────────────────────────────────────────────── @@ -199,3 +206,57 @@ class TestDiaryIngest: ingest_diaries(diary_dir, palace_dir, force=True) result = ingest_diaries(diary_dir, palace_dir) # second run, no force assert result["days_updated"] == 0 + + +# ── tunnels ────────────────────────────────────────────────────────────── + + +class TestTunnels: + def setup_method(self): + # Use temp tunnel file + self._orig = _TUNNEL_FILE + import mempalace.palace_graph as pg + self._tmpdir = tempfile.mkdtemp() + pg._TUNNEL_FILE = os.path.join(self._tmpdir, "tunnels.json") + + def teardown_method(self): + import mempalace.palace_graph as pg + pg._TUNNEL_FILE = self._orig + + def test_create_tunnel(self): + t = create_tunnel("wing_api", "auth", "wing_db", "users", label="auth uses users table") + assert t["id"] + assert t["source"]["wing"] == "wing_api" + assert t["target"]["wing"] == "wing_db" + assert t["label"] == "auth uses users table" + + def test_list_tunnels(self): + create_tunnel("wing_a", "room1", "wing_b", "room2") + create_tunnel("wing_a", "room3", "wing_c", "room4") + all_t = list_tunnels() + assert len(all_t) == 2 + filtered = list_tunnels("wing_a") + assert len(filtered) == 2 + filtered_c = list_tunnels("wing_c") + assert len(filtered_c) == 1 + + def test_delete_tunnel(self): + t = create_tunnel("wing_x", "r1", "wing_y", "r2") + 
delete_tunnel(t["id"]) + assert len(list_tunnels()) == 0 + + def test_dedup_same_endpoints(self): + create_tunnel("wing_a", "r1", "wing_b", "r2", label="first") + create_tunnel("wing_a", "r1", "wing_b", "r2", label="updated") + tunnels = list_tunnels() + assert len(tunnels) == 1 + assert tunnels[0]["label"] == "updated" + + def test_follow_tunnels(self): + create_tunnel("wing_api", "auth", "wing_db", "users") + create_tunnel("wing_api", "auth", "wing_frontend", "login") + connections = follow_tunnels("wing_api", "auth") + assert len(connections) == 2 + wings = {c["connected_wing"] for c in connections} + assert "wing_db" in wings + assert "wing_frontend" in wings From 971b92da5d879e444a5f44210a6f2b084c75bf9a Mon Sep 17 00:00:00 2001 From: MSL <232237854+milla-jovovich@users.noreply.github.com> Date: Mon, 13 Apr 2026 07:46:07 -0300 Subject: [PATCH 05/12] feat(search): drawer-grep returns best-matching chunk + neighbors When a closet hit leads to a source file with many drawers, grep each chunk for query terms and return the BEST-MATCHING chunk + 1 neighbor on each side, instead of dumping the whole file truncated at MAX_HYDRATION_CHARS. Result now includes drawer_index and total_drawers so callers can request adjacent drawers explicitly. Extracted from Milla's commit 935f657 which bundled drawer-grep with closet_llm (deferred pending LLM_ENDPOINT refactor) and fact_checker (separate PR). Ported only the searcher.py change. 
Co-Authored-By: Claude Opus 4.6 (1M context) --- mempalace/searcher.py | 35 ++++++++++++++++++++++++++++++----- 1 file changed, 30 insertions(+), 5 deletions(-) diff --git a/mempalace/searcher.py b/mempalace/searcher.py index 37795fc..19b07f4 100644 --- a/mempalace/searcher.py +++ b/mempalace/searcher.py @@ -205,6 +205,8 @@ def search_memories( pass # no closets yet — fall through to direct drawer search # If closets found results, hydrate the referenced drawers + MAX_HYDRATION_CHARS = 10000 # cap to prevent blowup on large source files + if closet_hits: import re seen_sources = set() @@ -215,18 +217,39 @@ def search_memories( continue seen_sources.add(source) - # Find drawers for this source file + # Find drawers for this source file, grep for most relevant chunk try: drawer_results = drawers_col.get( where={"source_file": source}, include=["documents", "metadatas"], ) if drawer_results.get("ids"): - # Combine all drawer content for this file - full_text = "\n\n".join(drawer_results["documents"]) - meta = drawer_results["metadatas"][0] + # Drawer-grep: score each chunk against the query, + # return the best-matching chunk first + surrounding context + query_terms = set(re.findall(r'\w{2,}', query.lower())) + best_idx = 0 + best_score = -1 + for idx, doc in enumerate(drawer_results["documents"]): + doc_lower = doc.lower() + score = sum(1 for t in query_terms if t in doc_lower) + if score > best_score: + best_score = score + best_idx = idx + + # Build result: best chunk first, then neighbors + docs = drawer_results["documents"] + n_docs = len(docs) + # Include best chunk + 1 before + 1 after for context + start = max(0, best_idx - 1) + end = min(n_docs, best_idx + 2) + relevant_text = "\n\n".join(docs[start:end]) + + if len(relevant_text) > MAX_HYDRATION_CHARS: + relevant_text = relevant_text[:MAX_HYDRATION_CHARS] + f"\n\n[...truncated. {n_docs} total drawers. 
Use mempalace_get_drawer for full content.]" + + meta = drawer_results["metadatas"][best_idx] hits.append({ - "text": full_text, + "text": relevant_text, "wing": meta.get("wing", "unknown"), "room": meta.get("room", "unknown"), "source_file": Path(source).name, @@ -234,6 +257,8 @@ def search_memories( "distance": round(closet_dist, 4), "matched_via": "closet", "closet_preview": closet_doc[:200], + "drawer_index": best_idx, + "total_drawers": n_docs, }) except Exception: pass From 4a6147f903a95c9e573ee98e4cb3d624eb3ff8fc Mon Sep 17 00:00:00 2001 From: MSL <232237854+milla-jovovich@users.noreply.github.com> Date: Mon, 13 Apr 2026 07:47:40 -0300 Subject: [PATCH 06/12] feat: offline fact checker against entity registry + knowledge graph MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit fact_checker.py verifies text for contradictions against locally stored entities and KG facts. Catches similar-name confusion (Bob vs Bobby), relationship mismatches (KG says husband, text says brother), and stale facts (KG valid_from/valid_to). No hardcoded facts. No network calls. Reads: - ~/.mempalace/known_entities.json - KnowledgeGraph SQLite Usage: from mempalace.fact_checker import check_text issues = check_text("Bob is Alice's brother", palace_path) # CLI python -m mempalace.fact_checker "text" --palace ~/.mempalace/palace Extracted from Milla's commit 935f657 which bundled this with closet_llm (deferred) and drawer-grep (PR #791). Ported only fact_checker.py — verified no network / API imports. Co-Authored-By: Claude Opus 4.6 (1M context) --- mempalace/fact_checker.py | 177 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 177 insertions(+) create mode 100644 mempalace/fact_checker.py diff --git a/mempalace/fact_checker.py b/mempalace/fact_checker.py new file mode 100644 index 0000000..281f117 --- /dev/null +++ b/mempalace/fact_checker.py @@ -0,0 +1,177 @@ +""" +fact_checker.py — Verify text against known facts in the palace. 
+ +Checks AI responses, diary entries, and new content against the +entity registry and knowledge graph for contradictions. Catches: + - Wrong names (similar but different entities) + - Wrong relationships (calling someone the wrong role) + - Stale facts (things that changed — KG has valid_from/valid_to) + +Uses the entity_registry and knowledge_graph — no hardcoded facts. + +Usage: + from mempalace.fact_checker import check_text + issues = check_text("Bob is Alice's brother", palace_path) + # → [{"type": "relationship_mismatch", "detail": "KG says Bob is Alice's husband"}] + + # CLI + python -m mempalace.fact_checker "Bob is Alice's brother" --palace ~/.mempalace/palace +""" + +import os +import re +from pathlib import Path + + +def check_text(text, palace_path=None, config=None): + """Check text for contradictions against known facts. + + Returns list of issues found. Empty list = no contradictions. + """ + if config is None: + from .config import MempalaceConfig + config = MempalaceConfig() + if palace_path is None: + palace_path = config.palace_path + + issues = [] + + # Load known entities + entity_names = _load_known_entities() + + # Check entity name confusion (similar names that might be mixed up) + issues.extend(_check_entity_confusion(text, entity_names)) + + # Check against knowledge graph facts + issues.extend(_check_kg_facts(text, palace_path)) + + return issues + + +def _load_known_entities(): + """Load entity names from the registry.""" + import json + registry_path = os.path.expanduser("~/.mempalace/known_entities.json") + if not os.path.exists(registry_path): + return {} + try: + return json.loads(open(registry_path).read()) + except Exception: + return {} + + +def _check_entity_confusion(text, entity_names): + """Check if text confuses similar entity names.""" + issues = [] + all_names = set() + for cat in entity_names.values(): + if isinstance(cat, list): + all_names.update(cat) + elif isinstance(cat, dict): + all_names.update(cat.keys()) + + # 
Find names mentioned in text + mentioned = set() + for name in all_names: + if re.search(r'\b' + re.escape(name) + r'\b', text, re.IGNORECASE): + mentioned.add(name) + + # Check for names that are very similar but different (edit distance 1-2) + name_list = sorted(all_names) + for i, name_a in enumerate(name_list): + for name_b in name_list[i + 1:]: + if _edit_distance(name_a.lower(), name_b.lower()) <= 2: + if name_a in mentioned or name_b in mentioned: + if name_a in text and name_b not in text: + issues.append({ + "type": "similar_name", + "detail": f"'{name_a}' mentioned — did you mean '{name_b}'? (similar names in registry)", + "names": [name_a, name_b], + }) + return issues + + +def _check_kg_facts(text, palace_path): + """Check text against knowledge graph for contradictions.""" + issues = [] + try: + from .knowledge_graph import KnowledgeGraph + kg = KnowledgeGraph(palace_path=palace_path) + + # Extract relationship claims from text + # Pattern: "X is Y's Z" or "X's Z is Y" + patterns = [ + (r"(\w+)\s+is\s+(\w+)'s\s+(\w+)", "subject", "possessor", "role"), + (r"(\w+)'s\s+(\w+)\s+is\s+(\w+)", "possessor", "role", "subject"), + ] + + for pattern, *roles in patterns: + for match in re.finditer(pattern, text, re.IGNORECASE): + groups = match.groups() + subject = groups[0] + # Query KG for this entity + try: + facts = kg.query(subject) + if facts: + for fact in facts: + # Check if the claim contradicts a known fact + if fact.get("valid_to") is None: # current fact + kg_pred = fact.get("predicate", "").lower() + claim = match.group(0).lower() + if kg_pred in claim and fact.get("object", "").lower() not in claim: + issues.append({ + "type": "relationship_mismatch", + "detail": f"Text says '{match.group(0)}' but KG says: {subject} {kg_pred} {fact.get('object')}", + "entity": subject, + }) + except Exception: + pass + except Exception: + pass # KG not available — skip + + return issues + + +def _edit_distance(s1, s2): + """Simple Levenshtein distance.""" + if 
len(s1) < len(s2): + return _edit_distance(s2, s1) + if len(s2) == 0: + return len(s1) + prev = list(range(len(s2) + 1)) + for i, c1 in enumerate(s1): + curr = [i + 1] + for j, c2 in enumerate(s2): + curr.append(min( + prev[j + 1] + 1, + curr[j] + 1, + prev[j] + (0 if c1 == c2 else 1), + )) + prev = curr + return prev[-1] + + +if __name__ == "__main__": + import argparse + import json + + parser = argparse.ArgumentParser(description="Check text against known facts") + parser.add_argument("text", nargs="?", help="Text to check") + parser.add_argument("--palace", default=os.path.expanduser("~/.mempalace/palace")) + parser.add_argument("--stdin", action="store_true", help="Read from stdin") + args = parser.parse_args() + + if args.stdin: + import sys + text = sys.stdin.read() + elif args.text: + text = args.text + else: + print("Provide text as argument or use --stdin") + exit(1) + + issues = check_text(text, palace_path=args.palace) + if issues: + print(json.dumps(issues, indent=2)) + else: + print("No contradictions found.") From 4d581cbb730b26d78e29e4e85115b948e0c0603e Mon Sep 17 00:00:00 2001 From: Igor Lins e Silva <4753812+igorls@users.noreply.github.com> Date: Mon, 13 Apr 2026 07:51:46 -0300 Subject: [PATCH 07/12] =?UTF-8?q?feat:=20optional=20LLM-based=20closet=20r?= =?UTF-8?q?egeneration=20=E2=80=94=20bring-your-own=20endpoint?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adds mempalace/closet_llm.py as an OPTIONAL path for richer closet generation. Regex closets remain the default and cover the local-first promise; users who want LLM-quality topics can bring their own endpoint. Configuration (env or CLI flag): LLM_ENDPOINT — OpenAI-compatible base URL (required) LLM_KEY — bearer token (optional; local inference skips this) LLM_MODEL — model name (required) Works with Ollama, vLLM, llama.cpp servers, OpenAI, OpenRouter, and any other provider that speaks OpenAI-compatible /chat/completions. 
Zero new dependencies — uses stdlib urllib. Replaces the original Anthropic-SDK-hardcoded version of this module from Milla's branch (commit 935f657). Same prompt, same parsing, same regenerate_closets flow; only the transport was generalised so the feature doesn't lock users into a specific vendor or require API keys for core memory operations (CLAUDE.md, "Local-first, zero API"). Includes 13 unit tests covering config resolution, request shape, auth-header omission when no key is set, code-fence stripping, and missing-config error path. All mocked — zero network calls in tests. Co-Authored-By: MSL <232237854+milla-jovovich@users.noreply.github.com> --- mempalace/closet_llm.py | 345 +++++++++++++++++++++++++++++++++++++++ tests/test_closet_llm.py | 222 +++++++++++++++++++++++++ 2 files changed, 567 insertions(+) create mode 100644 mempalace/closet_llm.py create mode 100644 tests/test_closet_llm.py diff --git a/mempalace/closet_llm.py b/mempalace/closet_llm.py new file mode 100644 index 0000000..35ec6d6 --- /dev/null +++ b/mempalace/closet_llm.py @@ -0,0 +1,345 @@ +""" +closet_llm.py — Generate closets via a user-configured LLM for richer indexing. + +The regex-based closet extraction catches action verbs, headers, and proper +nouns — but misses implicit topics, foreign-language content, and contextual +references. An LLM reads everything and produces better closets. + +This module is **OPTIONAL and opt-in**. Regex closets are always created by +the miner; this path regenerates them afterward using whatever LLM the user +chooses. Core memory operations remain API-free by design (see CLAUDE.md, +"Local-first, zero API"). 
+ +## Bring-your-own-LLM configuration + +The endpoint is any OpenAI-compatible Chat Completions URL: + + LLM_ENDPOINT=http://localhost:11434/v1 # Ollama + LLM_ENDPOINT=http://localhost:8000/v1 # vLLM, llama.cpp + LLM_ENDPOINT=https://api.openai.com/v1 + LLM_ENDPOINT=https://openrouter.ai/api/v1 + LLM_ENDPOINT=https://api.anthropic.com/v1 # when proxied through a compat layer + +Set: + LLM_ENDPOINT — base URL (required) + LLM_KEY — bearer token (optional; local inference usually doesn't need it) + LLM_MODEL — model name (required), e.g. "gpt-4o-mini", "llama3:8b", "qwen2.5:7b" + +Or pass flags on the CLI (flags win over env): + + python -m mempalace.closet_llm \\ + --palace ~/.mempalace/palace \\ + --endpoint http://localhost:11434/v1 \\ + --model llama3:8b + +No vendor lock-in. No hidden dependency on any specific provider. Zero deps +added to pyproject — uses stdlib urllib. +""" + +import json +import os +import re +import time +import urllib.request +import urllib.error +from datetime import datetime +from typing import Optional + +from .palace import get_collection, get_closets_collection, upsert_closet_lines + +MAX_CONTENT_CHARS = 30000 +MAX_OUTPUT_TOKENS = 1500 +HTTP_TIMEOUT_S = 60 + +PROMPT_TEMPLATE = """You are reading content filed in a memory palace. Generate a +topic-dense index that will be used to find this content later when someone searches. + +Source: {source_file} +Wing: {wing} | Room: {room} + +CONTENT: +{content} + +--- + +Output a JSON object with EXACTLY these fields: + +{{ + "topics": ["distinctive_word_or_phrase_1", "topic_2", ...], + "quotes": ["[Speaker] verbatim quote", ...], + "summary": "2-3 sentences describing what this content is about." +}} + +RULES: +- Topics: 8-15 entries. Include proper nouns (names, places, projects), + distinctive technical terms, and key concepts. NOT generic words like + "conversation" or "discussion". +- Quotes: 2-5 entries. EXACT verbatim from the content, not paraphrased. 
+ Attribute with [Speaker] prefix if speaker is identifiable. +- Summary: mention WHO, WHAT, and WHY. No filler. +- Write in the same language as the content. +- Output valid JSON only. No code fences. No commentary. +""" + + +class LLMConfig: + """Resolved LLM connection config. CLI flags > env vars.""" + + def __init__( + self, + endpoint: Optional[str] = None, + key: Optional[str] = None, + model: Optional[str] = None, + ): + self.endpoint = (endpoint or os.environ.get("LLM_ENDPOINT", "")).rstrip("/") + self.key = key or os.environ.get("LLM_KEY", "") + self.model = model or os.environ.get("LLM_MODEL", "") + + def missing(self) -> list: + missing = [] + if not self.endpoint: + missing.append("LLM_ENDPOINT (or --endpoint)") + if not self.model: + missing.append("LLM_MODEL (or --model)") + # key is optional — local inference servers (Ollama, vLLM) often don't require one + return missing + + +def _call_llm(cfg: LLMConfig, source_file: str, wing: str, room: str, content: str): + """Single LLM call via OpenAI-compatible /chat/completions. + + Returns (parsed_json_dict_or_None, usage_dict_or_None). 
+ """ + try: + from mempalace.i18n import t + + lang_instruction = t("aaak.instruction") + except Exception: + lang_instruction = "" + + prompt = PROMPT_TEMPLATE.format( + source_file=source_file[:100], + wing=wing, + room=room, + content=content[:MAX_CONTENT_CHARS], + ) + if lang_instruction and "english" not in lang_instruction.lower(): + prompt += f"\n\nLanguage instruction: {lang_instruction}" + + body = json.dumps( + { + "model": cfg.model, + "max_tokens": MAX_OUTPUT_TOKENS, + "messages": [{"role": "user", "content": prompt}], + } + ).encode("utf-8") + + headers = {"Content-Type": "application/json"} + if cfg.key: + headers["Authorization"] = f"Bearer {cfg.key}" + + url = f"{cfg.endpoint}/chat/completions" + + for attempt in range(3): + try: + req = urllib.request.Request(url, data=body, headers=headers, method="POST") + with urllib.request.urlopen(req, timeout=HTTP_TIMEOUT_S) as resp: + raw = resp.read().decode("utf-8") + payload = json.loads(raw) + + text = payload["choices"][0]["message"]["content"].strip() + text = re.sub(r"^```(?:json)?\s*", "", text) + text = re.sub(r"\s*```$", "", text) + parsed = json.loads(text) + return parsed, payload.get("usage") + except json.JSONDecodeError: + return None, None + except urllib.error.HTTPError as e: + # 429 / 503 = retry with backoff + if e.code in (429, 503) and attempt < 2: + time.sleep(2 ** attempt) + continue + return None, None + except Exception as e: + if "rate" in str(e).lower() and attempt < 2: + time.sleep(2 ** attempt) + continue + return None, None + return None, None + + +def _parsed_to_closet_lines(parsed, drawer_ids, entities_str): + """Convert LLM's JSON output to closet pointer lines.""" + lines = [] + drawer_ref = ",".join(drawer_ids[:3]) + + for topic in parsed.get("topics", [])[:15]: + lines.append(f"{topic}|{entities_str}|→{drawer_ref}") + for quote in parsed.get("quotes", [])[:5]: + lines.append(f'{quote}|{entities_str}|→{drawer_ref}') + summary = parsed.get("summary", "") + if summary: + 
lines.append(f"{summary[:200]}|{entities_str}|→{drawer_ref}") + + return lines + + +def regenerate_closets( + palace_path, + wing=None, + sample=0, + dry_run=False, + cfg: Optional[LLMConfig] = None, +): + """Regenerate closets using a configured LLM for richer topic extraction. + + Reads existing drawers, sends content to the configured endpoint, + replaces regex closets with LLM-generated ones. Regex closets remain + as the fallback whenever the call fails. + """ + if cfg is None: + cfg = LLMConfig() + missing = cfg.missing() + if missing: + print("Error: missing configuration: " + ", ".join(missing)) + print("Set env vars LLM_ENDPOINT / LLM_MODEL (and optionally LLM_KEY),") + print("or pass --endpoint / --model / --key on the CLI.") + return {"error": "missing-config", "missing": missing} + + drawers_col = get_collection(palace_path, create=False) + closets_col = get_closets_collection(palace_path) + + total = drawers_col.count() + if total == 0: + print("No drawers in palace.") + return {"processed": 0} + + all_data = drawers_col.get(limit=total, include=["documents", "metadatas"]) + by_source = {} + for doc_id, doc, meta in zip(all_data["ids"], all_data["documents"], all_data["metadatas"]): + source = meta.get("source_file", "unknown") + w = meta.get("wing", "") + if wing and w != wing: + continue + if source not in by_source: + by_source[source] = {"drawer_ids": [], "content": [], "meta": meta} + by_source[source]["drawer_ids"].append(doc_id) + by_source[source]["content"].append(doc) + + sources = list(by_source.keys()) + if sample > 0: + sources = sources[:sample] + + print(f"Regenerating closets for {len(sources)} source files via {cfg.endpoint} ({cfg.model})...") + if dry_run: + print("DRY RUN — no changes will be written") + + processed = 0 + failed = 0 + total_input = 0 + total_output = 0 + + for i, source in enumerate(sources, 1): + data = by_source[source] + content = "\n\n".join(data["content"]) + meta = data["meta"] + w = meta.get("wing", "") + r = 
meta.get("room", "") + entities = meta.get("entities", "") + + if dry_run: + print(f" [{i}/{len(sources)}] {os.path.basename(source)} ({len(content)} chars)") + continue + + parsed, usage = _call_llm(cfg, source, w, r, content) + if not parsed: + failed += 1 + print(f" [{i}/{len(sources)}] ✗ {os.path.basename(source)} — LLM failed") + continue + + if usage: + total_input += usage.get("prompt_tokens", 0) + total_output += usage.get("completion_tokens", 0) + + lines = _parsed_to_closet_lines(parsed, data["drawer_ids"], entities) + closet_id_base = f"closet_{w}_{r}_{source.split('/')[-1][:30]}" + + # Delete old regex closets for this source before writing LLM ones + try: + old_ids = closets_col.get( + where={"source_file": source}, include=[] + ).get("ids", []) + if old_ids: + closets_col.delete(ids=old_ids) + except Exception: + pass + + upsert_closet_lines( + closets_col, + closet_id_base, + lines, + { + "wing": w, + "room": r, + "source_file": source, + "generated_by": f"llm:{cfg.model}", + "filed_at": datetime.now().isoformat(), + "entities": entities, + }, + ) + + processed += 1 + n_topics = len(parsed.get("topics", [])) + print(f" [{i}/{len(sources)}] ✓ {os.path.basename(source)} — {n_topics} topics") + + print(f"\nDone. 
{processed} regenerated, {failed} failed.") + if total_input or total_output: + print(f"Tokens: {total_input:,} in + {total_output:,} out (cost depends on provider)") + + return { + "processed": processed, + "failed": failed, + "input_tokens": total_input, + "output_tokens": total_output, + } + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser( + description="Regenerate closets via a user-configured LLM (OpenAI-compatible API)" + ) + parser.add_argument( + "--palace", + default=os.path.expanduser("~/.mempalace/palace"), + help="Path to the palace", + ) + parser.add_argument("--wing", default=None, help="Limit to one wing") + parser.add_argument( + "--sample", type=int, default=0, help="Only process first N source files" + ) + parser.add_argument( + "--dry-run", action="store_true", help="List work without calling the LLM" + ) + parser.add_argument( + "--endpoint", + default=None, + help="LLM base URL (overrides $LLM_ENDPOINT), e.g. http://localhost:11434/v1", + ) + parser.add_argument( + "--key", + default=None, + help="LLM bearer token (overrides $LLM_KEY). Optional for local inference.", + ) + parser.add_argument( + "--model", + default=None, + help='LLM model name (overrides $LLM_MODEL), e.g. "gpt-4o-mini" or "llama3:8b"', + ) + args = parser.parse_args() + + cfg = LLMConfig(endpoint=args.endpoint, key=args.key, model=args.model) + regenerate_closets( + args.palace, wing=args.wing, sample=args.sample, dry_run=args.dry_run, cfg=cfg + ) diff --git a/tests/test_closet_llm.py b/tests/test_closet_llm.py new file mode 100644 index 0000000..762e16d --- /dev/null +++ b/tests/test_closet_llm.py @@ -0,0 +1,222 @@ +"""Unit tests for the optional LLM-based closet regeneration. + +These tests don't hit the network. 
They mock urllib to verify: +- LLMConfig correctly reads env vars and CLI overrides +- missing config is reported cleanly +- the OpenAI-compatible request shape is correct +- response parsing handles the standard chat-completions payload +""" + +import io +import json +import os +import tempfile +from unittest.mock import patch + +import pytest + +from mempalace.closet_llm import ( + LLMConfig, + _call_llm, + _parsed_to_closet_lines, + regenerate_closets, +) + + +# ── LLMConfig ───────────────────────────────────────────────────────────── + + +class TestLLMConfig: + def test_reads_env_vars(self, monkeypatch): + monkeypatch.setenv("LLM_ENDPOINT", "http://localhost:11434/v1") + monkeypatch.setenv("LLM_KEY", "sk-abc") + monkeypatch.setenv("LLM_MODEL", "llama3:8b") + c = LLMConfig() + assert c.endpoint == "http://localhost:11434/v1" + assert c.key == "sk-abc" + assert c.model == "llama3:8b" + + def test_cli_flags_override_env(self, monkeypatch): + monkeypatch.setenv("LLM_ENDPOINT", "http://env-endpoint/v1") + monkeypatch.setenv("LLM_MODEL", "env-model") + c = LLMConfig(endpoint="http://flag-endpoint/v1", model="flag-model") + assert c.endpoint == "http://flag-endpoint/v1" + assert c.model == "flag-model" + + def test_trailing_slash_stripped(self): + c = LLMConfig(endpoint="http://foo/v1/", model="m") + assert c.endpoint == "http://foo/v1" + + def test_missing_reports_required(self, monkeypatch): + monkeypatch.delenv("LLM_ENDPOINT", raising=False) + monkeypatch.delenv("LLM_KEY", raising=False) + monkeypatch.delenv("LLM_MODEL", raising=False) + c = LLMConfig() + missing = c.missing() + assert any("ENDPOINT" in m for m in missing) + assert any("MODEL" in m for m in missing) + # key is optional + assert not any("KEY" in m for m in missing) + + def test_key_is_optional(self, monkeypatch): + monkeypatch.delenv("LLM_KEY", raising=False) + c = LLMConfig(endpoint="http://local/v1", model="m") + assert c.missing() == [] + + +# ── _parsed_to_closet_lines 
────────────────────────────────────────────── + + +class TestParsedToLines: + def test_topics_become_pointers(self): + parsed = {"topics": ["authentication", "jwt tokens"], "quotes": [], "summary": ""} + lines = _parsed_to_closet_lines(parsed, ["d1", "d2"], "Alice;Bob") + assert len(lines) == 2 + assert "authentication|Alice;Bob|→d1,d2" in lines + assert "jwt tokens|Alice;Bob|→d1,d2" in lines + + def test_quotes_and_summary_included(self): + parsed = { + "topics": ["t1"], + "quotes": ["[Igor] we ship Friday"], + "summary": "Release planning discussion", + } + lines = _parsed_to_closet_lines(parsed, ["d1"], "") + joined = "\n".join(lines) + assert "we ship Friday" in joined + assert "Release planning discussion" in joined + + def test_caps_topics_at_15(self): + parsed = {"topics": [f"t{i}" for i in range(20)], "quotes": [], "summary": ""} + lines = _parsed_to_closet_lines(parsed, ["d1"], "") + assert len(lines) == 15 + + +# ── _call_llm (HTTP mocked) ────────────────────────────────────────────── + + +class _FakeResp: + """Mimics urlopen's context-manager response.""" + + def __init__(self, payload: dict, status: int = 200): + self._body = json.dumps(payload).encode("utf-8") + self.status = status + + def __enter__(self): + return self + + def __exit__(self, *a): + return False + + def read(self): + return self._body + + +class TestCallLLM: + def _make_cfg(self): + return LLMConfig( + endpoint="http://localhost:11434/v1", key="sk-test", model="llama3:8b" + ) + + def test_request_shape_and_parsing(self): + cfg = self._make_cfg() + captured = {} + + def fake_urlopen(req, timeout=None): + captured["url"] = req.full_url + captured["headers"] = dict(req.header_items()) + captured["body"] = json.loads(req.data.decode("utf-8")) + return _FakeResp( + { + "choices": [ + { + "message": { + "content": json.dumps( + { + "topics": ["postgres"], + "quotes": ["[Igor] migrate now"], + "summary": "db migration", + } + ) + } + } + ], + "usage": {"prompt_tokens": 42, 
"completion_tokens": 17}, + } + ) + + with patch("urllib.request.urlopen", side_effect=fake_urlopen): + parsed, usage = _call_llm(cfg, "/tmp/test.md", "w", "r", "content body") + + assert parsed["topics"] == ["postgres"] + assert usage["prompt_tokens"] == 42 + assert captured["url"] == "http://localhost:11434/v1/chat/completions" + # Authorization header is stored capitalized-then-lowercase depending on urllib version + auth_vals = {v for k, v in captured["headers"].items() if k.lower() == "authorization"} + assert "Bearer sk-test" in auth_vals + assert captured["body"]["model"] == "llama3:8b" + assert captured["body"]["messages"][0]["role"] == "user" + + def test_omits_auth_header_when_no_key(self): + cfg = LLMConfig(endpoint="http://localhost:11434/v1", model="llama3:8b") + captured_headers = {} + + def fake_urlopen(req, timeout=None): + captured_headers.update({k.lower(): v for k, v in req.header_items()}) + return _FakeResp( + { + "choices": [ + {"message": {"content": '{"topics":[],"quotes":[],"summary":""}'}} + ], + "usage": {"prompt_tokens": 0, "completion_tokens": 0}, + } + ) + + with patch("urllib.request.urlopen", side_effect=fake_urlopen): + _call_llm(cfg, "/tmp/x", "w", "r", "c") + + assert "authorization" not in captured_headers + + def test_strips_code_fences(self): + cfg = self._make_cfg() + fenced = '```json\n{"topics":["t1"],"quotes":[],"summary":""}\n```' + + def fake_urlopen(req, timeout=None): + return _FakeResp( + { + "choices": [{"message": {"content": fenced}}], + "usage": {"prompt_tokens": 1, "completion_tokens": 1}, + } + ) + + with patch("urllib.request.urlopen", side_effect=fake_urlopen): + parsed, _ = _call_llm(cfg, "/tmp/x", "w", "r", "c") + assert parsed == {"topics": ["t1"], "quotes": [], "summary": ""} + + def test_returns_none_on_invalid_json(self): + cfg = self._make_cfg() + + def fake_urlopen(req, timeout=None): + return _FakeResp( + { + "choices": [{"message": {"content": "not json at all"}}], + "usage": {"prompt_tokens": 1, 
"completion_tokens": 1}, + } + ) + + with patch("urllib.request.urlopen", side_effect=fake_urlopen): + parsed, usage = _call_llm(cfg, "/tmp/x", "w", "r", "c") + assert parsed is None + + +# ── regenerate_closets error paths ─────────────────────────────────────── + + +class TestRegenerateClosets: + def test_missing_config_returns_error(self, monkeypatch): + monkeypatch.delenv("LLM_ENDPOINT", raising=False) + monkeypatch.delenv("LLM_MODEL", raising=False) + with tempfile.TemporaryDirectory() as palace: + result = regenerate_closets(palace) + assert result["error"] == "missing-config" + assert any("ENDPOINT" in m for m in result["missing"]) From 8e446f904ce00f58347fa5469ae1dadfa1278637 Mon Sep 17 00:00:00 2001 From: Igor Lins e Silva <4753812+igorls@users.noreply.github.com> Date: Mon, 13 Apr 2026 08:43:54 -0300 Subject: [PATCH 08/12] =?UTF-8?q?fix(search):=20hybrid=20closet+drawer=20r?= =?UTF-8?q?etrieval=20=E2=80=94=20closets=20boost,=20never=20gate=20(#795)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- mempalace/searcher.py | 240 +++++++++++++++++++----------------- tests/test_hybrid_search.py | 141 +++++++++++++++++++++ 2 files changed, 270 insertions(+), 111 deletions(-) create mode 100644 tests/test_hybrid_search.py diff --git a/mempalace/searcher.py b/mempalace/searcher.py index 19b07f4..06806aa 100644 --- a/mempalace/searcher.py +++ b/mempalace/searcher.py @@ -183,138 +183,156 @@ def search_memories( where = build_where_filter(wing, room) - # Try closet-first search: search the compact index, then hydrate drawers - closet_hits = [] + # Hybrid retrieval: always query drawers directly (the floor), then use + # closet hits to boost rankings. Closets are a ranking SIGNAL, never a + # GATE — direct drawer search is always the baseline. 
+ # + # This avoids the "weak-closets regression" where narrative content + # produces low-signal closets (regex extraction matches few topics) + # and closet-first routing hides drawers that direct search would find. + try: + dkwargs = { + "query_texts": [query], + "n_results": n_results * 3, # over-fetch for re-ranking + "include": ["documents", "metadatas", "distances"], + } + if where: + dkwargs["where"] = where + drawer_results = drawers_col.query(**dkwargs) + except Exception as e: + return {"error": f"Search error: {e}"} + + # Gather closet hits (best-per-source) to build a boost lookup. + closet_boost_by_source = {} # source_file -> (rank, closet_dist, preview) try: closets_col = get_closets_collection(palace_path, create=False) ckwargs = { "query_texts": [query], - "n_results": n_results * 2, # over-fetch closets to find best drawers + "n_results": n_results * 2, "include": ["documents", "metadatas", "distances"], } if where: ckwargs["where"] = where closet_results = closets_col.query(**ckwargs) - if closet_results["documents"][0]: - closet_hits = list(zip( + for rank, (doc, meta, dist) in enumerate( + zip( closet_results["documents"][0], closet_results["metadatas"][0], closet_results["distances"][0], - )) + ) + ): + source = meta.get("source_file", "") + if source and source not in closet_boost_by_source: + closet_boost_by_source[source] = (rank, dist, doc[:200]) except Exception: - pass # no closets yet — fall through to direct drawer search + pass # no closets yet — hybrid degrades to pure drawer search - # If closets found results, hydrate the referenced drawers - MAX_HYDRATION_CHARS = 10000 # cap to prevent blowup on large source files + # Rank-based boost. Ordinal signal (which closet matched best) is more + # reliable than absolute distance on narrative content. 
+ CLOSET_RANK_BOOSTS = [0.40, 0.25, 0.15, 0.08, 0.04] + CLOSET_DISTANCE_CAP = 1.5 # cosine dist > 1.5 = too weak to use as signal - if closet_hits: - import re - seen_sources = set() - hits = [] - for closet_doc, closet_meta, closet_dist in closet_hits: - source = closet_meta.get("source_file", "") - if source in seen_sources: - continue - seen_sources.add(source) - - # Find drawers for this source file, grep for most relevant chunk - try: - drawer_results = drawers_col.get( - where={"source_file": source}, - include=["documents", "metadatas"], - ) - if drawer_results.get("ids"): - # Drawer-grep: score each chunk against the query, - # return the best-matching chunk first + surrounding context - query_terms = set(re.findall(r'\w{2,}', query.lower())) - best_idx = 0 - best_score = -1 - for idx, doc in enumerate(drawer_results["documents"]): - doc_lower = doc.lower() - score = sum(1 for t in query_terms if t in doc_lower) - if score > best_score: - best_score = score - best_idx = idx - - # Build result: best chunk first, then neighbors - docs = drawer_results["documents"] - n_docs = len(docs) - # Include best chunk + 1 before + 1 after for context - start = max(0, best_idx - 1) - end = min(n_docs, best_idx + 2) - relevant_text = "\n\n".join(docs[start:end]) - - if len(relevant_text) > MAX_HYDRATION_CHARS: - relevant_text = relevant_text[:MAX_HYDRATION_CHARS] + f"\n\n[...truncated. {n_docs} total drawers. 
Use mempalace_get_drawer for full content.]" - - meta = drawer_results["metadatas"][best_idx] - hits.append({ - "text": relevant_text, - "wing": meta.get("wing", "unknown"), - "room": meta.get("room", "unknown"), - "source_file": Path(source).name, - "similarity": round(max(0.0, 1 - closet_dist), 3), - "distance": round(closet_dist, 4), - "matched_via": "closet", - "closet_preview": closet_doc[:200], - "drawer_index": best_idx, - "total_drawers": n_docs, - }) - except Exception: - pass - - if len(hits) >= n_results: - break - - if hits: - # Re-rank with BM25 hybrid scoring - hits = _hybrid_rank(hits, query) - return { - "query": query, - "filters": {"wing": wing, "room": room}, - "total_before_filter": len(closet_hits), - "results": hits, - } - - # Fallback: direct drawer search (no closets yet, or closets empty) - try: - kwargs = { - "query_texts": [query], - "n_results": n_results, - "include": ["documents", "metadatas", "distances"], - } - if where: - kwargs["where"] = where - - results = drawers_col.query(**kwargs) - except Exception as e: - return {"error": f"Search error: {e}"} - - docs = results["documents"][0] - metas = results["metadatas"][0] - dists = results["distances"][0] - - hits = [] - for doc, meta, dist in zip(docs, metas, dists): - # Filter on raw distance before rounding to avoid precision loss + scored = [] + for doc, meta, dist in zip( + drawer_results["documents"][0], + drawer_results["metadatas"][0], + drawer_results["distances"][0], + ): if max_distance > 0.0 and dist > max_distance: continue - hits.append( - { - "text": doc, - "wing": meta.get("wing", "unknown"), - "room": meta.get("room", "unknown"), - "source_file": Path(meta.get("source_file", "?")).name, - "similarity": round(max(0.0, 1 - dist), 3), - "distance": round(dist, 4), - } - ) - # Re-rank with BM25 hybrid scoring + source = meta.get("source_file", "") + boost = 0.0 + matched_via = "drawer" + closet_preview = None + if source in closet_boost_by_source: + c_rank, c_dist, 
c_preview = closet_boost_by_source[source] + if c_dist <= CLOSET_DISTANCE_CAP and c_rank < len(CLOSET_RANK_BOOSTS): + boost = CLOSET_RANK_BOOSTS[c_rank] + matched_via = "drawer+closet" + closet_preview = c_preview + + effective_dist = dist - boost + entry = { + "text": doc, + "wing": meta.get("wing", "unknown"), + "room": meta.get("room", "unknown"), + "source_file": Path(source).name if source else "?", + "similarity": round(max(0.0, 1 - effective_dist), 3), + "distance": round(dist, 4), + "effective_distance": round(effective_dist, 4), + "closet_boost": round(boost, 3), + "matched_via": matched_via, + "_sort_key": effective_dist, + } + if closet_preview: + entry["closet_preview"] = closet_preview + scored.append(entry) + + scored.sort(key=lambda h: h["_sort_key"]) + hits = scored[:n_results] + + # Drawer-grep enrichment: for top hits whose source file has multiple + # drawers, return the best-matching chunk + its immediate neighbors + # instead of just the single drawer. Preserves the chunk-expansion + # behavior users relied on in the closet-first path. 
+ MAX_HYDRATION_CHARS = 10000 + import re as _re + + for h in hits: + if h["matched_via"] == "drawer": + continue + # Only enrich closet-matched hits (cheap: we already know source matters) + source_name = h["source_file"] + # Look up full source_file by matching suffix in candidate pool + full_source = next( + ( + m.get("source_file", "") + for m in drawer_results["metadatas"][0] + if m.get("source_file", "").endswith(source_name) + ), + "", + ) + if not full_source: + continue + try: + source_drawers = drawers_col.get( + where={"source_file": full_source}, include=["documents"] + ) + except Exception: + continue + docs = source_drawers.get("documents") or [] + if len(docs) <= 1: + continue + + query_terms = set(_re.findall(r"\w{2,}", query.lower())) + best_idx, best_score = 0, -1 + for idx, d in enumerate(docs): + d_lower = d.lower() + s = sum(1 for t in query_terms if t in d_lower) + if s > best_score: + best_score, best_idx = s, idx + + start = max(0, best_idx - 1) + end = min(len(docs), best_idx + 2) + expanded = "\n\n".join(docs[start:end]) + if len(expanded) > MAX_HYDRATION_CHARS: + expanded = ( + expanded[:MAX_HYDRATION_CHARS] + + f"\n\n[...truncated. {len(docs)} total drawers. Use mempalace_get_drawer for full content.]" + ) + h["text"] = expanded + h["drawer_index"] = best_idx + h["total_drawers"] = len(docs) + + # BM25 hybrid re-rank within the final candidate set hits = _hybrid_rank(hits, query) + for h in hits: + h.pop("_sort_key", None) + return { "query": query, "filters": {"wing": wing, "room": room}, - "total_before_filter": len(docs), + "total_before_filter": len(drawer_results["documents"][0]), "results": hits, } diff --git a/tests/test_hybrid_search.py b/tests/test_hybrid_search.py new file mode 100644 index 0000000..02d3f5f --- /dev/null +++ b/tests/test_hybrid_search.py @@ -0,0 +1,141 @@ +"""Tests for the hybrid closet+drawer retrieval in search_memories. 
+ +The hybrid path queries drawers directly (the floor) AND closets, applying a +rank-based boost to drawers whose source_file appears in top closet hits. +This avoids the "weak-closets regression" where low-signal closets (from +regex extraction on narrative content) could hide drawers that direct +search would have found. +""" + +import os +import tempfile + +import chromadb +import pytest + +from mempalace.palace import ( + get_collection, + get_closets_collection, + upsert_closet_lines, +) +from mempalace.searcher import search_memories + + +def _seed_drawers(palace_path): + """Insert 4 short drawers with deterministic content.""" + col = get_collection(palace_path, create=True) + col.upsert( + ids=["D1", "D2", "D3", "D4"], + documents=[ + "We switched the auth service to use JWT tokens with a 24h expiry.", + "Database migration to PostgreSQL 15 completed last Tuesday.", + "The frontend team is debating whether to adopt TanStack Query.", + "Kafka consumer rebalance timeout set to 45 seconds after incident.", + ], + metadatas=[ + {"wing": "backend", "room": "auth", "source_file": "fixture_D1.md"}, + {"wing": "backend", "room": "db", "source_file": "fixture_D2.md"}, + {"wing": "frontend", "room": "state", "source_file": "fixture_D3.md"}, + {"wing": "backend", "room": "queue", "source_file": "fixture_D4.md"}, + ], + ) + + +def _seed_strong_closet_for(palace_path, drawer_id, source_file, topics): + """Insert a closet whose content strongly overlaps the query keywords.""" + col = get_closets_collection(palace_path) + lines = [f"{t}||→{drawer_id}" for t in topics] + upsert_closet_lines( + col, + closet_id_base=f"closet_{drawer_id}", + lines=lines, + metadata={ + "wing": "backend", + "room": "auth", + "source_file": source_file, + "generated_by": "test", + }, + ) + + +# ── core invariant: closets can only HELP, never HIDE ───────────────────── + + +class TestHybridInvariant: + def test_no_closets_degrades_to_direct_drawer_search(self, tmp_path): + palace = 
str(tmp_path / "palace") + _seed_drawers(palace) + # No closets created. + result = search_memories("Kafka rebalance timeout", palace, n_results=3) + ids = [h["source_file"] for h in result["results"]] + assert ids, "should return results" + assert "fixture_D4.md" in ids, ( + "direct drawer search alone should surface the Kafka drawer" + ) + + def test_weak_closets_do_not_hide_direct_drawer_hits(self, tmp_path): + """A closet that points at a wrong drawer must NOT suppress the + drawer that direct search would have ranked first.""" + palace = str(tmp_path / "palace") + _seed_drawers(palace) + # Seed a misleading closet: it matches a generic phrase but points at D3. + _seed_strong_closet_for( + palace, + drawer_id="D3", + source_file="fixture_D3.md", + topics=["Kafka queue tuning", "consumer rebalance config"], + ) + result = search_memories("Kafka consumer rebalance timeout", palace, n_results=5) + ids = [h["source_file"] for h in result["results"]] + assert "fixture_D4.md" in ids, ( + "D4 must appear — direct drawer search alone would rank it first. " + "Closet pointing to D3 should only boost D3, never hide D4." 
+ ) + + def test_closet_boost_lifts_matching_drawer(self, tmp_path): + """When a closet agrees with direct search, the matching drawer + should be boosted to rank 1.""" + palace = str(tmp_path / "palace") + _seed_drawers(palace) + _seed_strong_closet_for( + palace, + drawer_id="D1", + source_file="fixture_D1.md", + topics=["JWT auth tokens", "session expiry", "authentication service"], + ) + result = search_memories("JWT auth tokens expiry", palace, n_results=3) + ids = [h["source_file"] for h in result["results"]] + assert ids[0] == "fixture_D1.md" + top = result["results"][0] + assert top["matched_via"] == "drawer+closet" + assert top["closet_boost"] > 0 + + +# ── closet_boost metadata ──────────────────────────────────────────────── + + +class TestClosetMetadata: + def test_closet_preview_exposed_when_boosted(self, tmp_path): + palace = str(tmp_path / "palace") + _seed_drawers(palace) + _seed_strong_closet_for( + palace, + drawer_id="D1", + source_file="fixture_D1.md", + topics=["JWT auth tokens", "24h expiry", "authentication"], + ) + result = search_memories("JWT authentication", palace, n_results=2) + top = result["results"][0] + assert top["source_file"] == "fixture_D1.md" + assert "closet_preview" in top + + def test_drawer_only_hits_have_no_closet_preview(self, tmp_path): + palace = str(tmp_path / "palace") + _seed_drawers(palace) + # No closets + result = search_memories("TanStack Query", palace, n_results=2) + assert result["results"] + for h in result["results"]: + assert h["matched_via"] == "drawer" + assert "closet_preview" not in h + assert h["closet_boost"] == 0.0 From 1263c3c91ed39d9f9abc8b0f0d5a875b2b1d6794 Mon Sep 17 00:00:00 2001 From: Igor Lins e Silva <4753812+igorls@users.noreply.github.com> Date: Mon, 13 Apr 2026 18:20:11 -0300 Subject: [PATCH 09/12] merge: full hardened stack + rewrite fact_checker around actual KG API MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Merges the full hardened stack 
(up through #791 drawer-grep) and turns fact_checker from "dead code hidden behind bare except" into an actually-working offline contradiction detector with tests. ## Dead paths the PR body advertised but the code never executed Both buried by a single outer ``except Exception: pass``: * ``kg.query(subject)`` — ``KnowledgeGraph`` has no ``query()`` method; it has ``query_entity()``. The attribute error was silently swallowed and the entire KG branch always returned ``[]``. Now using ``kg.query_entity(subject, direction="outgoing")`` with proper handling of the ``predicate``/``object``/``current``/``valid_to`` fields the real API returns. * ``KnowledgeGraph(palace_path=palace_path)`` — the constructor's only kwarg is ``db_path``. Passing ``palace_path`` raised TypeError, silently swallowed. Now computing the db_path correctly from ``<palace_path>/knowledge_graph.sqlite3``, matching the convention the MCP server already uses. ## Contradiction logic rewritten The previous ``if kg_pred in claim and fact.object not in claim`` only fired when text used the SAME predicate word as the KG fact — the exact opposite of the stated use case ("Bob is Alice's brother" when KG says "husband" would NOT have fired). Replaced with a proper parse → lookup → compare pipeline: * ``_extract_claims`` parses two surface forms ("X is Y's Z" and "X's Z is Y") into ``(subject, predicate, object)`` triples. * ``_check_kg_contradictions`` pulls the subject's outgoing facts and flags two classes: - ``relationship_mismatch`` when a current KG fact matches the same ``(subject, object)`` pair but with a different predicate. - ``stale_fact`` when the exact triple exists but is ``valid_to``-closed in the past. * Stale-fact detection is now implemented (the PR body claimed it; the old code silently didn't implement it). ## Performance fix — O(n²) → O(mentioned × n) ``_check_entity_confusion`` previously computed Levenshtein for every pair of registered names on every ``check_text`` call. 
For 1,000 registered names that's ~500K edit-distance calls per hook invocation. Now we first identify which registry names actually appear in the text (single regex scan), then only compute edit distance between mentioned and unmentioned names. Pinned by a test that asserts <200ms on a 500-name registry with zero mentions. Also: when *both* similar names are mentioned in the text, we no longer flag them — the user clearly knows they're different people. ## Shared entity-registry loader ``mempalace/miner.py`` already had an mtime-cached loader for ``~/.mempalace/known_entities.json``. fact_checker had a duplicate implementation that leaked file handles and ignored caching. Extended miner's cache to expose both the flat set (``_load_known_entities``) and the raw category dict (``_load_known_entities_raw``); fact_checker now imports the latter. No more double disk reads, no more handle leak. ## Tests — 24 cases in tests/test_fact_checker.py All three detection paths + both dead-code regressions: * ``test_kg_init_uses_db_path_not_palace_path_kwarg`` — pins the correct KG constructor signature so the ``palace_path=`` bug can't come back. * ``test_relationship_mismatch_detected`` — the headline example from the PR body now actually fires. * ``test_stale_fact_detected`` — valid_to-closed triple is flagged. * ``test_current_fact_same_triple_is_not_flagged`` — no false positive on a still-valid match. * ``test_performance_bounded_by_mentioned_names`` — 500-name registry, zero mentions, <200ms. Regression for the O(n²) blowup. * ``test_no_false_positive_when_both_names_mentioned`` — Mila and Milla in the same text is fine. * Plus claim extraction, flatten_names shapes, CLI exit code, empty text handling, missing-palace graceful fallback, registry-dict shape support. 785/785 suite pass. ruff + format clean on CI-pinned 0.4.x. 
--- mempalace/fact_checker.py | 376 ++++++++++++++++++++++++++----------- mempalace/mcp_server.py | 30 ++- mempalace/miner.py | 52 +++-- tests/test_fact_checker.py | 288 ++++++++++++++++++++++++++++ 4 files changed, 620 insertions(+), 126 deletions(-) create mode 100644 tests/test_fact_checker.py diff --git a/mempalace/fact_checker.py b/mempalace/fact_checker.py index 281f117..50e8842 100644 --- a/mempalace/fact_checker.py +++ b/mempalace/fact_checker.py @@ -1,152 +1,304 @@ """ fact_checker.py — Verify text against known facts in the palace. -Checks AI responses, diary entries, and new content against the -entity registry and knowledge graph for contradictions. Catches: - - Wrong names (similar but different entities) - - Wrong relationships (calling someone the wrong role) - - Stale facts (things that changed — KG has valid_from/valid_to) +Checks AI responses, diary entries, and new content against the entity +registry and knowledge graph for three classes of issue: -Uses the entity_registry and knowledge_graph — no hardcoded facts. + * similar_name — text mentions a name that's one/two edits + away from *another* registered name, raising + the possibility of a typo or mix-up. + * relationship_mismatch — text asserts a role between two entities + (e.g. "Bob is Alice's brother") while the KG + records a *different* current role for the + same subject/object pair. + * stale_fact — text asserts a fact that the KG marks closed + (``valid_to`` in the past). + +Purely offline. Inputs: entity_registry JSON + KG SQLite. No network. 
Usage: from mempalace.fact_checker import check_text issues = check_text("Bob is Alice's brother", palace_path) - # → [{"type": "relationship_mismatch", "detail": "KG says Bob is Alice's husband"}] # CLI - python -m mempalace.fact_checker "Bob is Alice's brother" --palace ~/.mempalace/palace + python -m mempalace.fact_checker "Bob is Alice's brother" \\ + --palace ~/.mempalace/palace """ +from __future__ import annotations + import os import re -from pathlib import Path +from datetime import datetime, timezone + +# Share miner's mtime-cached registry loader so we don't double-read +# ~/.mempalace/known_entities.json on every check_text call. +from .miner import _load_known_entities_raw -def check_text(text, palace_path=None, config=None): - """Check text for contradictions against known facts. +# Narrow detection patterns — parse "X is Y's Z" and "X's Z is Y". +# Names are captured greedily as word sequences (letters + optional +# capitalized follow-ons) so simple multi-token names still work. +# Relationship words are constrained to sane lengths to avoid matching +# arbitrary filler. +_RELATIONSHIP_PATTERNS = [ + # "Bob is Alice's brother" → subject=Bob, possessor=Alice, role=brother + re.compile(r"\b([A-Z][\w-]+)\s+is\s+([A-Z][\w-]+)'s\s+([a-z]{3,20})\b"), + # "Alice's brother is Bob" → possessor=Alice, role=brother, subject=Bob + re.compile(r"\b([A-Z][\w-]+)'s\s+([a-z]{3,20})\s+is\s+([A-Z][\w-]+)\b"), +] - Returns list of issues found. Empty list = no contradictions. + +def check_text(text: str, palace_path: str = None, config=None) -> list: + """Return a list of issues detected in ``text``. + + Empty list means "no contradictions found" — absence of evidence, not + evidence of absence. The detector is deliberately conservative: + every issue is anchored to a specific KG fact or registry entry. 
""" if config is None: from .config import MempalaceConfig + config = MempalaceConfig() if palace_path is None: palace_path = config.palace_path - issues = [] + if not text: + return [] - # Load known entities - entity_names = _load_known_entities() + issues: list = [] + entity_names_raw = _load_known_entities_raw() - # Check entity name confusion (similar names that might be mixed up) - issues.extend(_check_entity_confusion(text, entity_names)) - - # Check against knowledge graph facts - issues.extend(_check_kg_facts(text, palace_path)) + issues.extend(_check_entity_confusion(text, entity_names_raw)) + issues.extend(_check_kg_contradictions(text, palace_path)) return issues -def _load_known_entities(): - """Load entity names from the registry.""" - import json - registry_path = os.path.expanduser("~/.mempalace/known_entities.json") - if not os.path.exists(registry_path): - return {} - try: - return json.loads(open(registry_path).read()) - except Exception: - return {} +# ── entity-name confusion ──────────────────────────────────────────── -def _check_entity_confusion(text, entity_names): - """Check if text confuses similar entity names.""" - issues = [] - all_names = set() - for cat in entity_names.values(): +def _flatten_names(entity_names_raw: dict) -> set: + """Flatten a ``{category: [names]}`` or ``{category: {name: meta}}`` + registry into a set of names.""" + flat: set = set() + for cat in entity_names_raw.values(): if isinstance(cat, list): - all_names.update(cat) + flat.update(str(n) for n in cat if n) elif isinstance(cat, dict): - all_names.update(cat.keys()) + flat.update(str(k) for k in cat.keys() if k) + return flat - # Find names mentioned in text - mentioned = set() + +def _check_entity_confusion(text: str, entity_names_raw: dict) -> list: + """Flag names mentioned in the text that are edit-distance ≤ 2 from + a *different* registered name — a common typo / mix-up pattern. 
+ + Performance note: the original O(n²) pairwise scan over the full + registry is gone. We first identify which names actually appear in + the text, then only compute edit distance between *mentioned* names + and the rest of the registry. This makes the cost O(m × n) where m + is the handful of names in the text, not the full registry. + """ + all_names = _flatten_names(entity_names_raw) + if not all_names: + return [] + + # Which names from the registry actually appear in the text? + mentioned: list = [] for name in all_names: - if re.search(r'\b' + re.escape(name) + r'\b', text, re.IGNORECASE): - mentioned.add(name) + if re.search(r"\b" + re.escape(name) + r"\b", text, re.IGNORECASE): + mentioned.append(name) + if not mentioned: + return [] - # Check for names that are very similar but different (edit distance 1-2) - name_list = sorted(all_names) - for i, name_a in enumerate(name_list): - for name_b in name_list[i + 1:]: - if _edit_distance(name_a.lower(), name_b.lower()) <= 2: - if name_a in mentioned or name_b in mentioned: - if name_a in text and name_b not in text: - issues.append({ - "type": "similar_name", - "detail": f"'{name_a}' mentioned — did you mean '{name_b}'? (similar names in registry)", - "names": [name_a, name_b], - }) + issues: list = [] + seen_pairs: set = set() + for name_a in mentioned: + a_lower = name_a.lower() + for name_b in all_names: + if name_b == name_a: + continue + # Dedupe by unordered pair so we don't double-report. + pair_key = tuple(sorted((name_a.lower(), name_b.lower()))) + if pair_key in seen_pairs: + continue + # Only flag when name_b is a *different* registry entry that + # was NOT mentioned — otherwise both names in the text is + # just the user writing about two people. 
+ if name_b in mentioned: + seen_pairs.add(pair_key) + continue + distance = _edit_distance(a_lower, name_b.lower()) + if 0 < distance <= 2: + issues.append( + { + "type": "similar_name", + "detail": ( + f"'{name_a}' mentioned — did you mean " + f"'{name_b}'? (edit distance {distance})" + ), + "names": [name_a, name_b], + "distance": distance, + } + ) + seen_pairs.add(pair_key) return issues -def _check_kg_facts(text, palace_path): - """Check text against knowledge graph for contradictions.""" - issues = [] +# ── KG contradictions ──────────────────────────────────────────────── + + +def _extract_claims(text: str) -> list: + """Yield structured (subject, predicate, object) claims from ``text``. + + The two supported surface forms are "X is Y's Z" and "X's Z is Y", + both of which resolve to the triple ``(X, Z, Y)`` — ``X`` has role + ``Z`` with respect to ``Y``. Matches are case-preserving for the + entity names (KG lookup is case-insensitive on normalized IDs). + """ + claims: list = [] + for pat in _RELATIONSHIP_PATTERNS: + for match in pat.finditer(text): + groups = match.groups() + if pat is _RELATIONSHIP_PATTERNS[0]: + subject, possessor, role = groups[0], groups[1], groups[2] + else: + possessor, role, subject = groups[0], groups[1], groups[2] + claims.append( + { + "subject": subject, + "predicate": role.lower(), + "object": possessor, + "span": match.group(0), + } + ) + return claims + + +def _check_kg_contradictions(text: str, palace_path: str) -> list: + """Compare each claim in ``text`` against the KG. + + For every claim ``(subject, predicate, object)`` parsed from the + text, look up the subject's current KG triples: + + * ``relationship_mismatch`` fires when the KG records a fact about + the same ``(subject, object)`` pair but with a *different* + predicate — e.g. text says "brother" but KG says "husband". 
+ * ``stale_fact`` fires when the KG has the exact ``(subject, + predicate, object)`` triple but its ``valid_to`` is in the past, + meaning the claim is no longer current. + """ + claims = _extract_claims(text) + if not claims: + return [] + try: from .knowledge_graph import KnowledgeGraph - kg = KnowledgeGraph(palace_path=palace_path) - # Extract relationship claims from text - # Pattern: "X is Y's Z" or "X's Z is Y" - patterns = [ - (r"(\w+)\s+is\s+(\w+)'s\s+(\w+)", "subject", "possessor", "role"), - (r"(\w+)'s\s+(\w+)\s+is\s+(\w+)", "possessor", "role", "subject"), - ] - - for pattern, *roles in patterns: - for match in re.finditer(pattern, text, re.IGNORECASE): - groups = match.groups() - subject = groups[0] - # Query KG for this entity - try: - facts = kg.query(subject) - if facts: - for fact in facts: - # Check if the claim contradicts a known fact - if fact.get("valid_to") is None: # current fact - kg_pred = fact.get("predicate", "").lower() - claim = match.group(0).lower() - if kg_pred in claim and fact.get("object", "").lower() not in claim: - issues.append({ - "type": "relationship_mismatch", - "detail": f"Text says '{match.group(0)}' but KG says: {subject} {kg_pred} {fact.get('object')}", - "entity": subject, - }) - except Exception: - pass + # KG lives alongside the palace collection; mcp_server uses the + # same convention (see _kg init). Pass ``db_path`` — the previous + # code passed a nonexistent ``palace_path`` kwarg which raised + # TypeError, silently swallowed by the outer except and rendered + # the entire KG-check path dead. + kg = KnowledgeGraph(db_path=os.path.join(palace_path, "knowledge_graph.sqlite3")) except Exception: - pass # KG not available — skip + # KG unavailable (brand-new palace, corrupted DB, etc.) — skip. 
+ return [] + + issues: list = [] + for claim in claims: + subject = claim["subject"] + claim_pred = claim["predicate"] + claim_obj = claim["object"] + try: + facts = kg.query_entity(subject, direction="outgoing") + except Exception: + continue + if not facts: + continue + + current_facts = [f for f in facts if f.get("current")] + + # Mismatch: KG fact about same (subject, object) pair but different predicate. + for fact in current_facts: + if not _objects_match(fact.get("object"), claim_obj): + continue + kg_pred = (fact.get("predicate") or "").lower() + if kg_pred and kg_pred != claim_pred: + issues.append( + { + "type": "relationship_mismatch", + "detail": ( + f"Text says '{claim['span']}' but KG records " + f"{subject} {kg_pred} {fact.get('object')}" + ), + "entity": subject, + "claim": { + "predicate": claim_pred, + "object": claim_obj, + }, + "kg_fact": { + "predicate": kg_pred, + "object": fact.get("object"), + }, + } + ) + + # Stale fact: exact match on (subject, predicate, object) but KG + # closed the window in the past. 
+ now_iso = datetime.now(timezone.utc).date().isoformat() + for fact in facts: + if fact.get("current"): + continue + kg_pred = (fact.get("predicate") or "").lower() + if kg_pred != claim_pred: + continue + if not _objects_match(fact.get("object"), claim_obj): + continue + valid_to = fact.get("valid_to") + if valid_to and str(valid_to) < now_iso: + issues.append( + { + "type": "stale_fact", + "detail": ( + f"Text says '{claim['span']}' but KG marks " + f"this fact closed on {valid_to}" + ), + "entity": subject, + "valid_to": valid_to, + } + ) return issues -def _edit_distance(s1, s2): - """Simple Levenshtein distance.""" +def _objects_match(kg_obj, claim_obj: str) -> bool: + if kg_obj is None or not claim_obj: + return False + return str(kg_obj).strip().lower() == claim_obj.strip().lower() + + +# ── Levenshtein helper (tight iterative version) ───────────────────── + + +def _edit_distance(s1: str, s2: str) -> int: + """Levenshtein distance. O(len(s1) * len(s2)) time, O(len(s2)) space.""" if len(s1) < len(s2): - return _edit_distance(s2, s1) - if len(s2) == 0: + s1, s2 = s2, s1 + if not s2: return len(s1) prev = list(range(len(s2) + 1)) for i, c1 in enumerate(s1): curr = [i + 1] for j, c2 in enumerate(s2): - curr.append(min( - prev[j + 1] + 1, - curr[j] + 1, - prev[j] + (0 if c1 == c2 else 1), - )) + curr.append( + min( + prev[j + 1] + 1, + curr[j] + 1, + prev[j] + (0 if c1 == c2 else 1), + ) + ) prev = curr return prev[-1] @@ -154,24 +306,30 @@ def _edit_distance(s1, s2): if __name__ == "__main__": import argparse import json + import sys - parser = argparse.ArgumentParser(description="Check text against known facts") - parser.add_argument("text", nargs="?", help="Text to check") - parser.add_argument("--palace", default=os.path.expanduser("~/.mempalace/palace")) - parser.add_argument("--stdin", action="store_true", help="Read from stdin") + parser = argparse.ArgumentParser( + description="Check text against known facts in the MemPalace palace.", + epilog="Exits 0 
when no issues found, 1 when one or more issues detected.", + ) + parser.add_argument("text", nargs="?", help="Text to check (or use --stdin).") + parser.add_argument( + "--palace", + default=os.path.expanduser("~/.mempalace/palace"), + help="Path to the palace directory.", + ) + parser.add_argument("--stdin", action="store_true", help="Read text from stdin.") args = parser.parse_args() if args.stdin: - import sys - text = sys.stdin.read() + text_in = sys.stdin.read() elif args.text: - text = args.text + text_in = args.text else: - print("Provide text as argument or use --stdin") - exit(1) + parser.error("Provide text as argument or use --stdin.") - issues = check_text(text, palace_path=args.palace) - if issues: - print(json.dumps(issues, indent=2)) - else: - print("No contradictions found.") + found = check_text(text_in, palace_path=args.palace) + if found: + print(json.dumps(found, indent=2)) + sys.exit(1) + print("No contradictions found.") diff --git a/mempalace/mcp_server.py b/mempalace/mcp_server.py index 08226a9..31be8a4 100644 --- a/mempalace/mcp_server.py +++ b/mempalace/mcp_server.py @@ -35,7 +35,15 @@ from .version import __version__ import chromadb from .query_sanitizer import sanitize_query from .searcher import search_memories -from .palace_graph import traverse, find_tunnels, graph_stats, create_tunnel, list_tunnels, delete_tunnel, follow_tunnels +from .palace_graph import ( + traverse, + find_tunnels, + graph_stats, + create_tunnel, + list_tunnels, + delete_tunnel, + follow_tunnels, +) from .knowledge_graph import KnowledgeGraph @@ -519,7 +527,10 @@ def tool_create_tunnel( except ValueError as e: return {"error": str(e)} return create_tunnel( - source_wing, source_room, target_wing, target_room, + source_wing, + source_room, + target_wing, + target_room, label=label, source_drawer_id=source_drawer_id, target_drawer_id=target_drawer_id, @@ -1251,8 +1262,14 @@ TOOLS = { "target_wing": {"type": "string", "description": "Wing of the target"}, 
"target_room": {"type": "string", "description": "Room in the target wing"}, "label": {"type": "string", "description": "Description of the connection"}, - "source_drawer_id": {"type": "string", "description": "Optional specific drawer ID"}, - "target_drawer_id": {"type": "string", "description": "Optional specific drawer ID"}, + "source_drawer_id": { + "type": "string", + "description": "Optional specific drawer ID", + }, + "target_drawer_id": { + "type": "string", + "description": "Optional specific drawer ID", + }, }, "required": ["source_wing", "source_room", "target_wing", "target_room"], }, @@ -1263,7 +1280,10 @@ TOOLS = { "input_schema": { "type": "object", "properties": { - "wing": {"type": "string", "description": "Filter tunnels by wing (shows tunnels where wing is source or target)"}, + "wing": { + "type": "string", + "description": "Filter tunnels by wing (shows tunnels where wing is source or target)", + }, }, }, "handler": tool_list_tunnels, diff --git a/mempalace/miner.py b/mempalace/miner.py index 04bcf61..3d8e29e 100644 --- a/mempalace/miner.py +++ b/mempalace/miner.py @@ -379,17 +379,17 @@ def chunk_text(content: str, source_file: str) -> list: _ENTITY_REGISTRY_PATH = os.path.join(os.path.expanduser("~"), ".mempalace", "known_entities.json") -_ENTITY_REGISTRY_CACHE: dict = {"mtime": None, "names": frozenset()} +_ENTITY_REGISTRY_CACHE: dict = {"mtime": None, "names": frozenset(), "raw": {}} _ENTITY_EXTRACT_WINDOW = 5000 # chars of content scanned for capitalized words _ENTITY_METADATA_LIMIT = 25 # max entities packed into the metadata field -def _load_known_entities() -> frozenset: - """Load (and cache) the user's known-entity registry by mtime. - - Reads ``~/.mempalace/known_entities.json``. The registry is shaped as - ``{"category": ["Name1", "Name2", ...], ...}``. Cached across calls - in the same process; invalidated when the file's mtime changes. 
+def _refresh_known_entities_cache() -> None: + """Reload ``~/.mempalace/known_entities.json`` into the module cache if + its mtime changed since the last read. Shared by ``_load_known_entities`` + (flat set) and ``_load_known_entities_raw`` (category dict), so callers + can pick whichever shape they need without duplicating the mtime-gated + disk read. """ try: mtime = os.path.getmtime(_ENTITY_REGISTRY_PATH) @@ -397,28 +397,56 @@ def _load_known_entities() -> frozenset: if _ENTITY_REGISTRY_CACHE["mtime"] is not None: _ENTITY_REGISTRY_CACHE["mtime"] = None _ENTITY_REGISTRY_CACHE["names"] = frozenset() - return _ENTITY_REGISTRY_CACHE["names"] + _ENTITY_REGISTRY_CACHE["raw"] = {} + return if _ENTITY_REGISTRY_CACHE["mtime"] == mtime: - return _ENTITY_REGISTRY_CACHE["names"] + return names: set = set() + raw: dict = {} try: import json with open(_ENTITY_REGISTRY_PATH, "r", encoding="utf-8") as f: data = json.load(f) - for cat in data.values(): - if isinstance(cat, list): - names.update(str(n) for n in cat if n) + if isinstance(data, dict): + raw = data + for cat in data.values(): + if isinstance(cat, list): + names.update(str(n) for n in cat if n) + elif isinstance(cat, dict): + names.update(str(k) for k in cat.keys() if k) except Exception: names = set() + raw = {} _ENTITY_REGISTRY_CACHE["mtime"] = mtime _ENTITY_REGISTRY_CACHE["names"] = frozenset(names) + _ENTITY_REGISTRY_CACHE["raw"] = raw + + +def _load_known_entities() -> frozenset: + """Flat set of every known entity name (across all categories). + + Cached by mtime; invalidated when the registry file changes. + """ + _refresh_known_entities_cache() return _ENTITY_REGISTRY_CACHE["names"] +def _load_known_entities_raw() -> dict: + """Full category-dict view of the registry, shape + ``{"category": ["Name1", ...], ...}``. Cached by mtime. + + Consumed by modules (e.g., fact_checker) that need to reason about + categories rather than a flat name set. 
Never returns a mutable + reference to the cache — callers get a shallow copy. + """ + _refresh_known_entities_cache() + return dict(_ENTITY_REGISTRY_CACHE["raw"]) + + def _extract_entities_for_metadata(content: str) -> str: """Extract entity names from content for metadata tagging. diff --git a/tests/test_fact_checker.py b/tests/test_fact_checker.py new file mode 100644 index 0000000..5b34a40 --- /dev/null +++ b/tests/test_fact_checker.py @@ -0,0 +1,288 @@ +""" +test_fact_checker.py — Regression + integration tests for fact_checker. + +Covers every detection path + the three bugs the original PR silently +hid behind ``except Exception: pass``: + + * ``kg.query()`` doesn't exist — code must use ``query_entity``. + * ``KnowledgeGraph(palace_path=...)`` is not a valid kwarg — code + must pass ``db_path``. + * O(n²) edit-distance over the full registry — must filter to names + actually mentioned in the text. + +Also pins the three feature contracts: + * similar_name — "Mila" vs "Milla" in a registry with both. + * relationship_mismatch — "Bob is Alice's brother" vs KG "husband". + * stale_fact — claim matches a triple whose valid_to is in the past. 
+""" + +from __future__ import annotations + +import json +from unittest.mock import MagicMock, patch + +import pytest + +from mempalace.fact_checker import ( + _check_entity_confusion, + _edit_distance, + _extract_claims, + _flatten_names, + check_text, +) +from mempalace.knowledge_graph import KnowledgeGraph + + +# ── claim extraction ───────────────────────────────────────────────── + + +class TestExtractClaims: + def test_parses_x_is_ys_z(self): + claims = _extract_claims("Bob is Alice's brother") + assert len(claims) == 1 + assert claims[0] == { + "subject": "Bob", + "predicate": "brother", + "object": "Alice", + "span": "Bob is Alice's brother", + } + + def test_parses_xs_z_is_y(self): + claims = _extract_claims("Alice's brother is Bob") + assert len(claims) == 1 + assert claims[0]["subject"] == "Bob" + assert claims[0]["predicate"] == "brother" + assert claims[0]["object"] == "Alice" + + def test_ignores_sentences_without_possessive_role(self): + assert _extract_claims("Bob drove to the store today") == [] + assert _extract_claims("Just some prose without relationships") == [] + + def test_multiple_claims_in_one_text(self): + claims = _extract_claims("Bob is Alice's brother. Carol is Dave's sister.") + subjects = {c["subject"] for c in claims} + assert subjects == {"Bob", "Carol"} + + +# ── entity confusion ───────────────────────────────────────────────── + + +class TestEntityConfusion: + def test_flags_near_name_when_only_one_mentioned(self): + registry = {"people": ["Milla", "Mila"]} + issues = _check_entity_confusion("I spoke with Mila today.", registry) + # "Mila" mentioned, "Milla" not — registry has both at edit-distance 1, + # flag the possible confusion. 
+ assert len(issues) == 1 + assert issues[0]["type"] == "similar_name" + assert set(issues[0]["names"]) == {"Mila", "Milla"} + assert issues[0]["distance"] == 1 + + def test_no_false_positive_when_both_names_mentioned(self): + """Regression: a text discussing both Mila and Milla is fine — + the user clearly knows they're different. Don't nag.""" + registry = {"people": ["Milla", "Mila"]} + issues = _check_entity_confusion("Mila and Milla met for lunch.", registry) + assert issues == [] + + def test_no_issues_when_registry_empty(self): + assert _check_entity_confusion("Bob said hi", {}) == [] + assert _check_entity_confusion("Bob said hi", {"people": []}) == [] + + def test_no_issues_when_no_mentioned_names(self): + registry = {"people": ["Zelda", "Link", "Sheik"]} + assert _check_entity_confusion("nothing relevant here", registry) == [] + + def test_registry_dict_shape_is_supported(self): + # Some registries store {"people": {"Alice": {...meta}}}; we still + # need to surface the keys as candidate names. + registry = {"people": {"Milla": {"role": "creator"}, "Mila": {}}} + issues = _check_entity_confusion("I messaged Mila yesterday", registry) + assert any("Milla" in (i["names"] or []) for i in issues) + + +class TestEditDistance: + def test_basic_distances(self): + assert _edit_distance("kitten", "sitting") == 3 + assert _edit_distance("mila", "milla") == 1 + assert _edit_distance("abc", "abc") == 0 + + def test_empty_strings(self): + assert _edit_distance("", "") == 0 + assert _edit_distance("abc", "") == 3 + assert _edit_distance("", "abc") == 3 + + def test_performance_bounded_by_mentioned_names(self): + """Regression: an earlier implementation did O(n²) pairwise + edit-distance over every registry entry on every check_text call. + With 500 names and zero mentions, the call must return in a blink + because no edit-distance comparison should even start.""" + import time + + # 500 synthetic names, none of which appear in the text. 
+ registry = {"people": [f"Zelda{i:03d}" for i in range(500)]} + text = "completely irrelevant prose with no registered names at all" + + start = time.perf_counter() + issues = _check_entity_confusion(text, registry) + elapsed = time.perf_counter() - start + + assert issues == [] + # Even an unoptimized implementation should beat this by orders + # of magnitude once we've filtered to mentioned names (which is + # 0 here) — if it's still doing O(n²), we'll blow past. + assert elapsed < 0.2, f"entity confusion took {elapsed:.3f}s on empty mentions" + + +# ── _flatten_names helper ──────────────────────────────────────────── + + +class TestFlattenNames: + def test_handles_list_categories(self): + assert _flatten_names({"people": ["Ada", "Bob"]}) == {"Ada", "Bob"} + + def test_handles_dict_categories(self): + assert _flatten_names({"people": {"Ada": {}, "Bob": {}}}) == {"Ada", "Bob"} + + def test_skips_falsy_entries(self): + assert _flatten_names({"people": ["Ada", "", None, "Bob"]}) == {"Ada", "Bob"} + + +# ── KG integration (uses a real tmp SQLite palace) ─────────────────── + + +@pytest.fixture +def palace_with_kg(tmp_path): + """Palace directory with a real KG; each test adds its own triples. + + The KG file lives at ``<palace>/knowledge_graph.sqlite3`` — same + convention used by the MCP server. Fact-checker must find it via + that path, not via a bogus ``palace_path`` kwarg. + """ + palace = tmp_path / "palace" + palace.mkdir() + db = str(palace / "knowledge_graph.sqlite3") + kg = KnowledgeGraph(db_path=db) + yield palace, kg + + +class TestKGContradictions: + def test_kg_init_uses_db_path_not_palace_path_kwarg(self): + """Regression: the original code passed ``palace_path=`` to a + constructor whose only kwarg is ``db_path``. That raised + TypeError — silently swallowed — and the KG path became dead + code. 
This test pins the correct call signature.""" + # Simply construct via the correct signature; raising means the + # KG constructor has changed in a way that fact_checker must too. + kg = KnowledgeGraph(db_path=":memory:") + # query_entity must exist (this is the method fact_checker calls). + assert callable(getattr(kg, "query_entity", None)) + # The API that fact_checker used to call does NOT exist. + assert not hasattr(kg, "query") + + def test_relationship_mismatch_detected(self, palace_with_kg): + """The feature's headline example: text says brother, KG says husband.""" + palace, kg = palace_with_kg + kg.add_triple("Bob", "husband_of", "Alice", valid_from="2020-01-01") + + issues = check_text("Bob is Alice's husband_of", str(palace)) + # Exact-predicate + same object → no mismatch. + assert all(i["type"] != "relationship_mismatch" for i in issues) + + issues = check_text("Bob is Alice's brother", str(palace)) + mismatches = [i for i in issues if i["type"] == "relationship_mismatch"] + assert mismatches, "should flag text/KG mismatch for same (subject, object)" + m = mismatches[0] + assert m["entity"] == "Bob" + assert m["claim"]["predicate"] == "brother" + assert m["kg_fact"]["predicate"] == "husband_of" + + def test_no_false_positive_when_kg_has_no_facts_about_subject(self, palace_with_kg): + palace, _ = palace_with_kg + # KG is empty → no mismatch should fire. + assert check_text("Bob is Alice's brother", str(palace)) == [] + + def test_stale_fact_detected(self, palace_with_kg): + palace, kg = palace_with_kg + # An old relationship that was superseded in 2023. Using a + # possessive-shape claim so the narrow claim-extraction regex + # actually reaches the stale-fact branch. 
+ kg.add_triple( + "Bob", + "brother", + "Alice", + valid_from="2010-01-01", + valid_to="2023-06-01", + ) + issues = check_text("Bob is Alice's brother", str(palace)) + stale = [i for i in issues if i["type"] == "stale_fact"] + assert stale, "should flag closed-window fact as stale" + assert stale[0]["entity"] == "Bob" + assert stale[0]["valid_to"].startswith("2023") + + def test_current_fact_same_triple_is_not_flagged(self, palace_with_kg): + palace, kg = palace_with_kg + kg.add_triple("Bob", "brother", "Alice", valid_from="2010-01-01") + issues = check_text("Bob is Alice's brother", str(palace)) + assert issues == [] + + def test_missing_palace_does_not_crash(self, tmp_path): + """Brand-new palace (no KG file yet) — check_text must return [] + rather than raising or hanging.""" + nonexistent = str(tmp_path / "never_created") + assert check_text("Bob is Alice's brother", nonexistent) == [] + + +# ── end-to-end check_text contract ─────────────────────────────────── + + +class TestCheckTextContract: + def test_empty_text_returns_empty_list(self, tmp_path): + assert check_text("", str(tmp_path / "palace")) == [] + + def test_registry_confusion_path_isolated_from_kg(self, tmp_path, monkeypatch): + """If the registry file is present but the KG is missing, the + similar-name path must still fire. Prior implementations had + such entangled state that one failure killed both paths.""" + # Bypass the real registry by pointing cache at a temp file. 
+ registry = tmp_path / "known_entities.json" + registry.write_text(json.dumps({"people": ["Milla", "Mila"]})) + from mempalace import miner + + monkeypatch.setattr(miner, "_ENTITY_REGISTRY_PATH", str(registry)) + miner._ENTITY_REGISTRY_CACHE.update({"mtime": None, "names": frozenset(), "raw": {}}) + + issues = check_text("Chatted with Mila.", str(tmp_path / "nonexistent_palace")) + assert any(i["type"] == "similar_name" for i in issues) + + +# ── CLI ────────────────────────────────────────────────────────────── + + +class TestCLI: + def test_exits_nonzero_when_issues_found(self, tmp_path, monkeypatch, capsys): + """The CLI exit code is how shell scripts / hooks know to act — + pin it explicitly.""" + registry = tmp_path / "known_entities.json" + registry.write_text(json.dumps({"people": ["Milla", "Mila"]})) + from mempalace import fact_checker, miner + + monkeypatch.setattr(miner, "_ENTITY_REGISTRY_PATH", str(registry)) + miner._ENTITY_REGISTRY_CACHE.update({"mtime": None, "names": frozenset(), "raw": {}}) + + # Simulate argv: "Mila said hi" + monkeypatch.setattr( + "sys.argv", + ["fact_checker", "Mila said hi", "--palace", str(tmp_path / "palace")], + ) + with pytest.raises(SystemExit) as excinfo: + # Re-exec the __main__ block via runpy. + import runpy + + runpy.run_module("mempalace.fact_checker", run_name="__main__") + # Issues found → exit code 1. + assert excinfo.value.code == 1 + out = capsys.readouterr().out + assert "similar_name" in out + # Silence unused import warning. 
+ _ = (MagicMock, patch, fact_checker) From 7192552624a089e73162a94e1c68d6c7c0da3d93 Mon Sep 17 00:00:00 2001 From: Igor Lins e Silva <4753812+igorls@users.noreply.github.com> Date: Mon, 13 Apr 2026 18:55:26 -0300 Subject: [PATCH 10/12] test: make diary state path assertion platform-neutral MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The Windows CI job failed on: assert '/.mempalace/state/' in str(state_path) because Windows uses ``\`` as the path separator, so the substring never matches. The behavior under test (state file lives outside the diary dir, under ``~/.mempalace/state/``) is already correct on both platforms — only the assertion was Unix-only. Switch to ``state_path.parent`` comparisons that work on any OS. --- tests/test_closets.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/test_closets.py b/tests/test_closets.py index 59321c5..458f767 100644 --- a/tests/test_closets.py +++ b/tests/test_closets.py @@ -586,7 +586,10 @@ class TestDiaryIngest: # State file does exist under ~/.mempalace/state/. state_path = _state_file_for(str(palace_dir), diary_dir.resolve()) assert state_path.exists() - assert "/.mempalace/state/" in str(state_path) + # Platform-neutral path check: compare parents rather than a hardcoded + # separator string that would fail on Windows (``\.mempalace\state\``). 
+ assert state_path.parent.name == "state" + assert state_path.parent.parent.name == ".mempalace" def test_wing_prefixed_drawer_id_prevents_cross_diary_collision(self, tmp_path): # Regression: the original implementation used From e052074624e95009a4080240c306c7ab796199a8 Mon Sep 17 00:00:00 2001 From: Igor Lins e Silva <4753812+igorls@users.noreply.github.com> Date: Mon, 13 Apr 2026 19:02:51 -0300 Subject: [PATCH 11/12] test: serialize mine_lock concurrency test with multiprocessing MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The macOS CI job failed ``test_lock_blocks_concurrent_access`` because ``fcntl.flock`` on BSD/macOS is per-*process*, not per-FD: two threads in the same process both acquire even when they open their own file descriptors. The test passed on Linux (per-FD flock) and Windows (per-FD ``msvcrt.locking``) but was never actually exercising the lock's real contract. ``mine_lock`` is designed to serialize multi-*agent* access — i.e., separate processes, not threads. Switch the test to ``multiprocessing.get_context('spawn')`` with a module-level worker (so the spawn pickles cleanly) so it: 1. reflects the actual use case (one lock per mining process); 2. passes on all three OSes without flock-semantics branching; 3. catches real regressions (a broken lock would now let both processes through, exactly what we care about). Hold time bumped to 0.3s and the "wait until p1 acquires" delay to 0.2s to tolerate spawn's higher startup latency on macOS/Windows. 
--- tests/test_closets.py | 54 ++++++++++++++++++++++++++++++------------- 1 file changed, 38 insertions(+), 16 deletions(-) diff --git a/tests/test_closets.py b/tests/test_closets.py index 458f767..fba4cc8 100644 --- a/tests/test_closets.py +++ b/tests/test_closets.py @@ -24,6 +24,7 @@ Coverage map: """ import json +import multiprocessing import os import tempfile import threading @@ -63,6 +64,18 @@ from mempalace.searcher import ( # ── mine_lock ──────────────────────────────────────────────────────────── +def _lock_worker(target: str, name: str, hold_seconds: float, queue) -> None: + """Module-level worker for multiprocessing spawn; must be pickle-able.""" + from mempalace.palace import mine_lock as _mine_lock + + start = time.time() + with _mine_lock(target): + elapsed = time.time() - start + queue.put((name, elapsed)) + if hold_seconds > 0: + time.sleep(hold_seconds) + + class TestMineLock: def test_lock_acquires_and_releases(self, tmp_path): target = str(tmp_path / "lock_target.txt") @@ -76,28 +89,37 @@ class TestMineLock: assert time.time() - start < 1.0 def test_lock_blocks_concurrent_access(self, tmp_path): + """The lock's contract is inter-*process* (multi-agent), not + inter-thread. Use multiprocessing so the test reflects the real + use case and is portable: on macOS/BSD, ``fcntl.flock`` is + per-process, so two threads in one process would both acquire — + a threading-based test would flake there even when the lock is + behaving correctly for its intended users.""" target = str(tmp_path / "concurrent_lock.txt") + # Use multiprocessing so each worker has its own process. + # Use "spawn" to stay consistent across platforms (macOS defaults + # to spawn on 3.8+; Linux defaults to fork). Both work here. 
+ ctx = multiprocessing.get_context("spawn") + queue = ctx.Queue() + + p1 = ctx.Process(target=_lock_worker, args=(target, "a", 0.3, queue)) + p2 = ctx.Process(target=_lock_worker, args=(target, "b", 0.0, queue)) + p1.start() + time.sleep(0.2) # ensure p1 acquires first + p2.start() + p1.join(timeout=10) + p2.join(timeout=10) + results = [] + while not queue.empty(): + results.append(queue.get()) + assert len(results) == 2, f"both workers should report, got {results}" - def worker(name): - start = time.time() - with mine_lock(target): - results.append((name, time.time() - start)) - time.sleep(0.2) - - t1 = threading.Thread(target=worker, args=("a",)) - t2 = threading.Thread(target=worker, args=("b",)) - t1.start() - time.sleep(0.05) # ensure t1 acquires first - t2.start() - t1.join() - t2.join() - - # The second worker must have waited at least most of t1's hold time. + # The second worker must have waited until p1 released the lock. wait_times = sorted(r[1] for r in results) assert ( wait_times[1] > 0.1 - ), f"second thread should block on mine_lock, waited only {wait_times[1]:.3f}s" + ), f"second process should block on mine_lock, waited only {wait_times[1]:.3f}s" # ── build_closet_lines ───────────────────────────────────────────────── From 1dc20e307b8294f6e0917d2abe433929287b7c21 Mon Sep 17 00:00:00 2001 From: Igor Lins e Silva <4753812+igorls@users.noreply.github.com> Date: Mon, 13 Apr 2026 19:08:57 -0300 Subject: [PATCH 12/12] test: verify mine_lock via disjoint critical-section intervals MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The previous revision used multiprocessing but still relied on timing ("second process waited at least N seconds") which flakes on CI where spawn overhead eats into the hold window. 
Linux CI observed the second process report a 0.088s wait — below the 0.1s threshold — even though the lock behavior was correct; spawn was just slow enough that the first process had nearly finished holding when the second got past its own spawn. Switch to effect-based verification: each worker logs its [enter_time, exit_time] inside the critical section, and the test asserts the two intervals are disjoint after sorting. A broken lock would produce overlapping intervals regardless of spawn latency; a working lock cannot. Also removed the mp.Queue since we no longer pass timing data back. --- tests/test_closets.py | 78 +++++++++++++++++++++++++++---------------- 1 file changed, 50 insertions(+), 28 deletions(-) diff --git a/tests/test_closets.py b/tests/test_closets.py index fba4cc8..976086d 100644 --- a/tests/test_closets.py +++ b/tests/test_closets.py @@ -64,16 +64,22 @@ from mempalace.searcher import ( # ── mine_lock ──────────────────────────────────────────────────────────── -def _lock_worker(target: str, name: str, hold_seconds: float, queue) -> None: - """Module-level worker for multiprocessing spawn; must be pickle-able.""" +def _lock_worker(target: str, name: str, hold_seconds: float, log_path: str) -> None: + """Worker for multiprocessing-spawn concurrency test. Writes its + critical-section enter/exit timestamps to ``log_path`` so the test + can verify the sections did not overlap in time.""" + import time as _time + from mempalace.palace import mine_lock as _mine_lock - start = time.time() with _mine_lock(target): - elapsed = time.time() - start - queue.put((name, elapsed)) - if hold_seconds > 0: - time.sleep(hold_seconds) + t_enter = _time.time() + _time.sleep(hold_seconds) + t_exit = _time.time() + # Append atomically so concurrent writers don't stomp each other. 
+ with open(log_path, "a") as f: + f.write(f"{name} {t_enter} {t_exit}\n") + f.flush() class TestMineLock: @@ -91,35 +97,51 @@ class TestMineLock: def test_lock_blocks_concurrent_access(self, tmp_path): """The lock's contract is inter-*process* (multi-agent), not inter-thread. Use multiprocessing so the test reflects the real - use case and is portable: on macOS/BSD, ``fcntl.flock`` is - per-process, so two threads in one process would both acquire — - a threading-based test would flake there even when the lock is - behaving correctly for its intended users.""" + use case and is portable: on macOS/BSD ``fcntl.flock`` is + per-process, so two threads would both acquire — a thread-based + test would flake there even when the lock is correct. + + Verify mutual exclusion by the effect the critical section + actually has — each worker records its enter/exit timestamps + under the lock, and the test asserts the two intervals do not + overlap. This is robust to spawn-overhead timing, unlike + "second worker waited at least N seconds" which flakes when CI + spawn latency eats into the hold window. + """ target = str(tmp_path / "concurrent_lock.txt") - # Use multiprocessing so each worker has its own process. - # Use "spawn" to stay consistent across platforms (macOS defaults - # to spawn on 3.8+; Linux defaults to fork). Both work here. + log_path = str(tmp_path / "critical_section.log") + # Spawn so the same code path runs on every OS (macOS 3.8+ and + # Windows already default to spawn; Linux is fork by default). ctx = multiprocessing.get_context("spawn") - queue = ctx.Queue() - p1 = ctx.Process(target=_lock_worker, args=(target, "a", 0.3, queue)) - p2 = ctx.Process(target=_lock_worker, args=(target, "b", 0.0, queue)) + # Each worker holds the lock for HOLD seconds. With real mutual + # exclusion, the two [enter, exit] intervals must be disjoint. 
+ HOLD = 0.3 + p1 = ctx.Process(target=_lock_worker, args=(target, "a", HOLD, log_path)) + p2 = ctx.Process(target=_lock_worker, args=(target, "b", HOLD, log_path)) p1.start() - time.sleep(0.2) # ensure p1 acquires first p2.start() - p1.join(timeout=10) - p2.join(timeout=10) + p1.join(timeout=30) + p2.join(timeout=30) - results = [] - while not queue.empty(): - results.append(queue.get()) - assert len(results) == 2, f"both workers should report, got {results}" + assert p1.exitcode == 0, f"p1 exited non-zero: {p1.exitcode}" + assert p2.exitcode == 0, f"p2 exited non-zero: {p2.exitcode}" - # The second worker must have waited until p1 released the lock. - wait_times = sorted(r[1] for r in results) + # Parse the log: one "<name> <t_enter> <t_exit>" line per worker. + intervals = [] + with open(log_path) as f: + for line in f: + parts = line.strip().split() + if len(parts) == 3: + intervals.append((parts[0], float(parts[1]), float(parts[2]))) + assert len(intervals) == 2, f"expected two critical sections, got {intervals}" + + # Sort by entry time and verify the second entry is after the first exit. + intervals.sort(key=lambda iv: iv[1]) + (_, enter_a, exit_a), (_, enter_b, exit_b) = intervals assert ( - wait_times[1] > 0.1 - ), f"second process should block on mine_lock, waited only {wait_times[1]:.3f}s" + enter_a < exit_a <= enter_b < exit_b + ), f"critical sections overlapped — lock failed to serialize: {intervals}" # ── build_closet_lines ─────────────────────────────────────────────────