diff --git a/mempalace/closet_llm.py b/mempalace/closet_llm.py new file mode 100644 index 0000000..c00b735 --- /dev/null +++ b/mempalace/closet_llm.py @@ -0,0 +1,351 @@ +""" +closet_llm.py — Generate closets via a user-configured LLM for richer indexing. + +The regex-based closet extraction catches action verbs, headers, and proper +nouns — but misses implicit topics, foreign-language content, and contextual +references. An LLM reads everything and produces better closets. + +This module is **OPTIONAL and opt-in**. Regex closets are always created by +the miner; this path regenerates them afterward using whatever LLM the user +chooses. Core memory operations remain API-free by design (see CLAUDE.md, +"Local-first, zero API"). + +## Bring-your-own-LLM configuration + +The endpoint is any OpenAI-compatible Chat Completions URL: + + LLM_ENDPOINT=http://localhost:11434/v1 # Ollama + LLM_ENDPOINT=http://localhost:8000/v1 # vLLM, llama.cpp + LLM_ENDPOINT=https://api.openai.com/v1 + LLM_ENDPOINT=https://openrouter.ai/api/v1 + LLM_ENDPOINT=https://api.anthropic.com/v1 # when proxied through a compat layer + +Set: + LLM_ENDPOINT — base URL (required) + LLM_KEY — bearer token (optional; local inference usually doesn't need it) + LLM_MODEL — model name (required), e.g. "gpt-4o-mini", "llama3:8b", "qwen2.5:7b" + +Or pass flags on the CLI (flags win over env): + + python -m mempalace.closet_llm \\ + --palace ~/.mempalace/palace \\ + --endpoint http://localhost:11434/v1 \\ + --model llama3:8b + +No vendor lock-in. No hidden dependency on any specific provider. Zero deps +added to pyproject — uses stdlib urllib. 
class LLMConfig:
    """Resolved LLM connection config. CLI flags > env vars.

    Each setting comes from the explicit argument when that argument is
    non-blank, otherwise from the matching environment variable
    (``LLM_ENDPOINT``, ``LLM_KEY``, ``LLM_MODEL``).  Values are stripped of
    surrounding whitespace so a padded or newline-terminated value (easy to
    get from a shell export or a ``.env`` file) is neither used verbatim nor
    mistaken for "configured" when it is actually blank.

    Attributes:
        endpoint: Base URL with any trailing ``/`` removed; ``""`` if unset.
        key: Bearer token; ``""`` if unset (the key is optional).
        model: Model name; ``""`` if unset.
    """

    def __init__(
        self,
        endpoint: Optional[str] = None,
        key: Optional[str] = None,
        model: Optional[str] = None,
    ):
        self.endpoint = self._resolve(endpoint, "LLM_ENDPOINT").rstrip("/")
        self.key = self._resolve(key, "LLM_KEY")
        self.model = self._resolve(model, "LLM_MODEL")

    @staticmethod
    def _resolve(explicit: Optional[str], env_name: str) -> str:
        """Return the first non-blank value among ``explicit`` and ``$env_name``."""
        for candidate in (explicit, os.environ.get(env_name)):
            value = (candidate or "").strip()
            if value:
                return value
        return ""

    def missing(self) -> list:
        """Return human-readable names of the required settings that are unset.

        An empty list means the configuration is usable.
        """
        missing = []
        if not self.endpoint:
            missing.append("LLM_ENDPOINT (or --endpoint)")
        if not self.model:
            missing.append("LLM_MODEL (or --model)")
        # key is optional — local inference servers (Ollama, vLLM) often don't require one
        return missing
def _clean_closet_text(value) -> str:
    """Flatten one LLM-provided value into a single pointer-line field.

    Returns ``""`` for ``None`` and for nested containers.  Runs of
    whitespace (including newlines) collapse to one space, and ``|`` is
    replaced with ``/`` because the pointer line built by
    ``_parsed_to_closet_lines`` uses ``|`` as its field separator.
    """
    if value is None or isinstance(value, (dict, list, tuple)):
        return ""
    return " ".join(str(value).replace("|", "/").split())


def _parsed_to_closet_lines(parsed, drawer_ids, entities_str):
    """Convert the LLM's JSON output to closet pointer lines.

    Each line has the form ``<text>|<entities_str>|→<drawer refs>`` where the
    drawer refs are the first three ``drawer_ids`` joined by commas.  Up to
    15 topics, then up to 5 quotes, then the summary (cut to 200 characters)
    are emitted, in that order.

    LLM output is untrusted: a non-dict ``parsed`` yields ``[]``; a field
    that is missing, ``null`` or of the wrong type is skipped; a bare string
    where a list was requested is treated as a one-item list; blank items
    are dropped before the per-field cap is applied.
    """
    if not isinstance(parsed, dict):
        return []

    drawer_ref = ",".join(drawer_ids[:3])
    suffix = f"|{entities_str}|→{drawer_ref}"

    def _entries(field, limit):
        raw = parsed.get(field)
        if isinstance(raw, str):
            # A lone string would otherwise be iterated character by character.
            raw = [raw]
        elif not isinstance(raw, (list, tuple)):
            return []
        cleaned = (_clean_closet_text(item) for item in raw)
        return [text for text in cleaned if text][:limit]

    lines = [f"{topic}{suffix}" for topic in _entries("topics", 15)]
    lines.extend(f"{quote}{suffix}" for quote in _entries("quotes", 5))

    summary = _clean_closet_text(parsed.get("summary"))
    if summary:
        lines.append(f"{summary[:200]}{suffix}")

    return lines


def regenerate_closets(
    palace_path,
    wing=None,
    sample=0,
    dry_run=False,
    cfg: "Optional[LLMConfig]" = None,
):
    """Regenerate closets using a configured LLM for richer topic extraction.

    Reads existing drawers, groups them by ``source_file``, sends each
    group's joined content to the configured endpoint and replaces that
    source's closets with the LLM-generated lines.  Existing closets are
    left untouched whenever the call fails *or* yields no usable lines, so
    the regex closets remain the fallback.

    Args:
        palace_path: Palace location handed to the collection getters.
        wing: When set, only drawers whose ``wing`` metadata equals it are used.
        sample: When > 0, only the first ``sample`` source files are processed.
        dry_run: List the work without calling the LLM or writing anything.
        cfg: Connection settings; a default ``LLMConfig()`` is built if omitted.

    Returns:
        ``{"error": "missing-config", "missing": [...]}`` when required
        settings are absent, ``{"processed": 0}`` when the palace has no
        drawers, otherwise a dict with ``processed``, ``failed``,
        ``input_tokens`` and ``output_tokens``.
    """
    if cfg is None:
        cfg = LLMConfig()
    missing = cfg.missing()
    if missing:
        print("Error: missing configuration: " + ", ".join(missing))
        print("Set env vars LLM_ENDPOINT / LLM_MODEL (and optionally LLM_KEY),")
        print("or pass --endpoint / --model / --key on the CLI.")
        return {"error": "missing-config", "missing": missing}

    drawers_col = get_collection(palace_path, create=False)
    closets_col = get_closets_collection(palace_path)

    total = drawers_col.count()
    if total == 0:
        print("No drawers in palace.")
        return {"processed": 0}

    all_data = drawers_col.get(limit=total, include=["documents", "metadatas"])
    by_source = {}
    for doc_id, doc, meta in zip(all_data["ids"], all_data["documents"], all_data["metadatas"]):
        # A drawer stored without metadata/document must not abort the whole run.
        meta = meta or {}
        if wing and meta.get("wing", "") != wing:
            continue
        source = meta.get("source_file", "unknown")
        group = by_source.setdefault(source, {"drawer_ids": [], "content": [], "meta": meta})
        group["drawer_ids"].append(doc_id)
        group["content"].append(doc or "")

    sources = list(by_source)
    if sample > 0:
        sources = sources[:sample]

    print(
        f"Regenerating closets for {len(sources)} source files via {cfg.endpoint} ({cfg.model})..."
    )
    if dry_run:
        print("DRY RUN — no changes will be written")

    processed = 0
    failed = 0
    total_input = 0
    total_output = 0

    for i, source in enumerate(sources, 1):
        data = by_source[source]
        content = "\n\n".join(data["content"])
        meta = data["meta"]
        w = meta.get("wing", "")
        r = meta.get("room", "")
        entities = meta.get("entities", "")

        if dry_run:
            print(f" [{i}/{len(sources)}] {os.path.basename(source)} ({len(content)} chars)")
            continue

        parsed, usage = _call_llm(cfg, source, w, r, content)

        # Tokens were spent whether or not the answer turns out to be usable.
        if isinstance(usage, dict):
            total_input += usage.get("prompt_tokens") or 0
            total_output += usage.get("completion_tokens") or 0

        # Valid JSON that is not an object (list, string, number) is a failure,
        # not something to call .get() on.
        if not parsed or not isinstance(parsed, dict):
            failed += 1
            print(f" [{i}/{len(sources)}] ✗ {os.path.basename(source)} — LLM failed")
            continue

        lines = _parsed_to_closet_lines(parsed, data["drawer_ids"], entities)
        if not lines:
            # Purging with nothing to write would delete the regex closets
            # and leave this source with no closets at all.
            failed += 1
            print(
                f" [{i}/{len(sources)}] ✗ {os.path.basename(source)}"
                " — no usable closet lines, keeping existing closets"
            )
            continue

        # Use os.path.basename so Windows-style paths survive unchanged;
        # the naive split('/') would leave a bare path component on Windows
        # and collide across different files under different drives.
        closet_id_base = f"closet_{w}_{r}_{os.path.basename(source)[:30]}"

        # Serialize with concurrent mine operations on the same source —
        # otherwise a regex closet rebuild mid-regenerate races with our
        # purge+upsert cycle and leaves mixed regex/LLM lines.
        with mine_lock(source):
            purge_file_closets(closets_col, source)
            upsert_closet_lines(
                closets_col,
                closet_id_base,
                lines,
                {
                    "wing": w,
                    "room": r,
                    "source_file": source,
                    "generated_by": f"llm:{cfg.model}",
                    "filed_at": datetime.now().isoformat(),
                    "entities": entities,
                    # Stamp so the miner's stale-drawer gate doesn't treat
                    # LLM closets as leftovers and rebuild over them next run.
                    "normalize_version": NORMALIZE_VERSION,
                },
            )

        processed += 1
        topics = parsed.get("topics")
        n_topics = len(topics) if isinstance(topics, (list, tuple)) else 0
        print(f" [{i}/{len(sources)}] ✓ {os.path.basename(source)} — {n_topics} topics")

    print(f"\nDone. {processed} regenerated, {failed} failed.")
    if total_input or total_output:
        print(f"Tokens: {total_input:,} in + {total_output:,} out (cost depends on provider)")

    return {
        "processed": processed,
        "failed": failed,
        "input_tokens": total_input,
        "output_tokens": total_output,
    }
+- A re-ingest fully purges the prior day's closets before rebuilding so a + shorter day never leaves orphans behind. +- Only new entries are processed by default (tracks entry count in a state + file under ``~/.mempalace/state/`` — never inside the user's diary dir). +- Per-file ``mine_lock`` so concurrent ingest from two terminals can't race. +- Entities extracted and stamped on metadata for filterable search. + +Usage: + python -m mempalace.diary_ingest --dir ~/daily_summaries --palace ~/.mempalace/palace + python -m mempalace.diary_ingest --dir ~/daily_summaries --palace ~/.mempalace/palace --force +""" + +import hashlib +import json +import os +import re +from datetime import datetime, timezone +from pathlib import Path + +from .miner import _extract_entities_for_metadata +from .palace import ( + build_closet_lines, + get_closets_collection, + get_collection, + mine_lock, + purge_file_closets, + upsert_closet_lines, +) + +DIARY_ENTRY_RE = re.compile(r"^## .+", re.MULTILINE) + + +def _state_file_for(palace_path: str, diary_dir: Path) -> Path: + """Return the per-(palace, diary-dir) state-file path under ~/.mempalace/state. + + Keyed by sha256 of (palace_path, diary_dir) so multiple diary folders + pointing at the same palace each get an independent state file. The + state file is *never* written inside the user's diary directory. 
def _split_entries(text):
    """Split diary text into ``(header, body)`` pairs, one per ``## `` entry.

    Anything before the first header is discarded.  Both the header and the
    body are stripped of surrounding whitespace.
    """
    headers = DIARY_ENTRY_RE.findall(text)
    # split() on a group-less pattern yields len(headers) + 1 parts; part 0 is
    # the preamble and part k + 1 is the body that follows header k.
    bodies = DIARY_ENTRY_RE.split(text)[1:]
    return [(header.strip(), body.strip()) for header, body in zip(headers, bodies)]


def _diary_drawer_id(wing: str, date_str: str) -> str:
    """Stable, wing-scoped drawer ID. Two diaries (e.g. 'work' vs 'personal')
    sharing the same date never collide."""
    suffix = hashlib.sha256(f"{wing}|{date_str}".encode()).hexdigest()[:24]
    return f"drawer_diary_{suffix}"


def _diary_closet_id_base(wing: str, date_str: str) -> str:
    """Stable, wing-scoped base ID for a day's closets (same hash input as
    ``_diary_drawer_id``, different prefix)."""
    suffix = hashlib.sha256(f"{wing}|{date_str}".encode()).hexdigest()[:24]
    return f"closet_diary_{suffix}"


def ingest_diaries(
    diary_dir,
    palace_path,
    wing="diary",
    force=False,
):
    """Ingest daily summary files into the palace.

    Each ``*.md`` file whose name starts with ``YYYY-MM-DD`` gets ONE drawer
    keyed by ``(wing, date)`` holding the full text, plus closets built from
    its ``## `` entries.

    Change detection compares a SHA-256 of the file text with the digest
    saved in the state file, so an edit that keeps the length the same is
    still noticed.  State entries written before the digest existed fall
    back to the old size comparison and are given a digest on this run.

    Closet handling for a changed day:

    * entries were appended  -> only the new entries are turned into closets;
    * ``force=True``, or the day now has no more entries than last time
      (entries were edited in place or removed) -> the day's closets are
      purged and rebuilt from every entry, so nothing stale is left behind.

    Args:
        diary_dir: Directory containing the daily ``*.md`` files.
        palace_path: Palace location handed to the collection getters.
        wing: Wing name stamped on drawers/closets and used in the IDs.
        force: Ignore saved state and rebuild every day from scratch.

    Returns:
        ``{"days_updated": int, "closets_created": int}``.
    """
    diary_dir = Path(diary_dir).expanduser().resolve()
    if not diary_dir.exists():
        print(f"Diary directory not found: {diary_dir}")
        return {"days_updated": 0, "closets_created": 0}

    diary_files = sorted(diary_dir.glob("*.md"))
    if not diary_files:
        print(f"No .md files in {diary_dir}")
        return {"days_updated": 0, "closets_created": 0}

    state_file = _state_file_for(str(palace_path), diary_dir)
    state: dict = {}
    if not force and state_file.exists():
        try:
            loaded = json.loads(state_file.read_text(encoding="utf-8"))
        except (OSError, ValueError):
            # Unreadable or corrupt state: start over rather than abort.
            loaded = {}
        if isinstance(loaded, dict):
            state = loaded

    drawers_col = get_collection(palace_path)
    closets_col = get_closets_collection(palace_path)

    days_updated = 0
    closets_created = 0

    for diary_path in diary_files:
        text = diary_path.read_text(encoding="utf-8", errors="replace")
        if len(text.strip()) < 50:
            continue

        date_match = re.match(r"(\d{4}-\d{2}-\d{2})", diary_path.stem)
        if not date_match:
            continue
        date_str = date_match.group(1)

        # Skip if content hasn't changed
        state_key = f"{wing}|{diary_path.name}"
        prev = state.get(state_key)
        if not isinstance(prev, dict):
            prev = {}
        curr_size = len(text)
        curr_digest = hashlib.sha256(text.encode("utf-8")).hexdigest()
        prev_digest = prev.get("sha256")
        if prev_digest is not None:
            unchanged = prev_digest == curr_digest
        else:
            # Legacy state entry (size only): keep the old comparison.
            unchanged = prev.get("size", 0) == curr_size
        if unchanged and not force:
            if prev and prev_digest is None:
                # Upgrade the legacy entry so later same-length edits are seen.
                state[state_key] = {**prev, "sha256": curr_digest}
            continue

        now_iso = datetime.now(timezone.utc).isoformat()
        drawer_id = _diary_drawer_id(wing, date_str)
        entities = _extract_entities_for_metadata(text)
        source_file = str(diary_path)

        # Serialize per source — two terminals running ingest at once must
        # not interleave the upsert + closet-rebuild.
        with mine_lock(source_file):
            drawer_meta = {
                "date": date_str,
                "wing": wing,
                "room": "daily",
                "source_file": source_file,
                "source_session": "daily_diary",
                "filed_at": now_iso,
            }
            if entities:
                drawer_meta["entities"] = entities
            drawers_col.upsert(
                documents=[text],
                ids=[drawer_id],
                metadatas=[drawer_meta],
            )

            entries = _split_entries(text)
            prev_entry_count = prev.get("entry_count", 0)
            # The text changed but nothing was appended: existing entries were
            # edited or removed, so the old closets can no longer be trusted.
            rebuild = force or (prev_entry_count > 0 and len(entries) <= prev_entry_count)
            new_entries = entries if rebuild else entries[prev_entry_count:]

            if rebuild:
                # Purge even when the rebuild produces no lines; otherwise a
                # day that lost its entries would keep its old closets.
                purge_file_closets(closets_col, source_file)

            all_lines = []
            for header, body in new_entries:
                entry_text = f"{header}\n{body}"
                all_lines.extend(
                    build_closet_lines(source_file, [drawer_id], entry_text, wing, "daily")
                )

            if all_lines:
                closet_id_base = _diary_closet_id_base(wing, date_str)
                closet_meta = {
                    "date": date_str,
                    "wing": wing,
                    "room": "daily",
                    "source_file": source_file,
                    "filed_at": now_iso,
                }
                if entities:
                    closet_meta["entities"] = entities
                # NOTE(review): in append mode this writes only the new
                # entries' lines under the same closet_id_base — confirm that
                # upsert_closet_lines does not overwrite the closets written
                # for the earlier entries.
                n = upsert_closet_lines(closets_col, closet_id_base, all_lines, closet_meta)
                closets_created += n

        state[state_key] = {
            "size": curr_size,
            "sha256": curr_digest,
            "entry_count": len(entries),
            "ingested_at": now_iso,
        }
        days_updated += 1

    state_file.write_text(json.dumps(state, indent=2), encoding="utf-8")
    if days_updated:
        print(f"Diary: {days_updated} days updated, {closets_created} new closets")

    return {"days_updated": days_updated, "closets_created": closets_created}
def check_text(text: str, palace_path: str = None, config=None) -> list:
    """Return a list of issues detected in ``text``.

    Empty list means "no contradictions found" — absence of evidence, not
    evidence of absence. The detector is deliberately conservative:
    every issue is anchored to a specific KG fact or registry entry.

    Args:
        text: Text to check. Empty or whitespace-only text returns ``[]``
            immediately, before any configuration or registry is loaded.
        palace_path: Palace directory passed to the KG check. When omitted
            it is read from ``config.palace_path``.
        config: Optional config object; a ``MempalaceConfig`` is created
            only when ``palace_path`` is omitted and no config was given.

    Returns:
        Issue dicts from the entity-confusion check followed by those from
        the KG-contradiction check.
    """
    # Nothing to check: bail out before touching config, registry or KG.
    if not text or not text.strip():
        return []

    if palace_path is None:
        if config is None:
            # Deferred import: only needed when the caller gave us no path.
            from .config import MempalaceConfig

            config = MempalaceConfig()
        palace_path = config.palace_path

    # NOTE(review): ``or {}`` guards against the loader returning None when
    # no registry exists — confirm against _load_known_entities_raw.
    entity_names_raw = _load_known_entities_raw() or {}

    issues: list = []
    issues.extend(_check_entity_confusion(text, entity_names_raw))
    issues.extend(_check_kg_contradictions(text, palace_path))

    return issues
def _extract_claims(text: str) -> list:
    """Return structured (subject, predicate, object) claims found in ``text``.

    The two supported surface forms are "X is Y's Z" and "X's Z is Y",
    both of which resolve to the triple ``(X, Z, Y)`` — ``X`` has role
    ``Z`` with respect to ``Y``. Entity names keep their original case;
    the predicate is lower-cased. Each claim also carries the matched
    text under ``"span"``.
    """
    claims: list = []
    for index, pattern in enumerate(_RELATIONSHIP_PATTERNS):
        for match in pattern.finditer(text):
            if index == 0:
                # "Bob is Alice's brother"
                subject, possessor, role = match.groups()
            else:
                # "Alice's brother is Bob"
                possessor, role, subject = match.groups()
            claims.append(
                {
                    "subject": subject,
                    "predicate": role.lower(),
                    "object": possessor,
                    "span": match.group(0),
                }
            )
    return claims


def _issues_for_claim(claim: dict, facts: list, today_iso: str) -> list:
    """Compare one parsed claim with the KG facts returned for its subject.

    Args:
        claim: Dict produced by ``_extract_claims``.
        facts: Fact dicts for the claim's subject; the keys read here are
            ``predicate``, ``object``, ``current`` and ``valid_to``.
        today_iso: Today's date as ``YYYY-MM-DD``; compared as a string
            with ``valid_to``.

    Returns:
        ``relationship_mismatch`` issues first, then ``stale_fact`` issues.
    """
    subject = claim["subject"]
    claim_pred = claim["predicate"]
    claim_obj = claim["object"]
    issues: list = []

    # Mismatch: current KG fact about the same (subject, object) pair but
    # with a different predicate.
    confirmed_current = False
    for fact in facts:
        if not fact.get("current"):
            continue
        if not _objects_match(fact.get("object"), claim_obj):
            continue
        kg_pred = (fact.get("predicate") or "").lower()
        if kg_pred == claim_pred:
            confirmed_current = True
        elif kg_pred:
            issues.append(
                {
                    "type": "relationship_mismatch",
                    "detail": (
                        f"Text says '{claim['span']}' but KG records "
                        f"{subject} {kg_pred} {fact.get('object')}"
                    ),
                    "entity": subject,
                    "claim": {
                        "predicate": claim_pred,
                        "object": claim_obj,
                    },
                    "kg_fact": {
                        "predicate": kg_pred,
                        "object": fact.get("object"),
                    },
                }
            )

    # A fact that was closed and later re-opened is current again; an old
    # closed row for the same triple must not be reported as stale.
    if confirmed_current:
        return issues

    # Stale fact: exact match on (subject, predicate, object) but KG
    # closed the window in the past.
    for fact in facts:
        if fact.get("current"):
            continue
        kg_pred = (fact.get("predicate") or "").lower()
        if kg_pred != claim_pred:
            continue
        if not _objects_match(fact.get("object"), claim_obj):
            continue
        valid_to = fact.get("valid_to")
        if valid_to and str(valid_to) < today_iso:
            issues.append(
                {
                    "type": "stale_fact",
                    "detail": (
                        f"Text says '{claim['span']}' but KG marks "
                        f"this fact closed on {valid_to}"
                    ),
                    "entity": subject,
                    "valid_to": valid_to,
                }
            )

    return issues


def _check_kg_contradictions(text: str, palace_path: str) -> list:
    """Compare each claim in ``text`` against the KG.

    For every claim ``(subject, predicate, object)`` parsed from the
    text, look up the subject's outgoing KG triples:

    * ``relationship_mismatch`` fires when the KG records a current fact
      about the same ``(subject, object)`` pair but with a *different*
      predicate — e.g. text says "brother" but KG says "husband".
    * ``stale_fact`` fires when the KG has the exact ``(subject,
      predicate, object)`` triple closed with a ``valid_to`` in the past
      and no current fact for that same triple.

    Returns ``[]`` when the text contains no claims or the KG cannot be
    opened; a claim whose lookup raises is skipped.
    """
    claims = _extract_claims(text)
    if not claims:
        return []

    try:
        from .knowledge_graph import KnowledgeGraph

        # KG lives alongside the palace collection; mcp_server uses the
        # same convention (see _kg init). Pass ``db_path`` — the previous
        # code passed a nonexistent ``palace_path`` kwarg which raised
        # TypeError, silently swallowed by the outer except and rendered
        # the entire KG-check path dead.
        kg = KnowledgeGraph(db_path=os.path.join(palace_path, "knowledge_graph.sqlite3"))
    except Exception:
        # KG unavailable (brand-new palace, corrupted DB, etc.) — skip.
        return []

    # Computed once so every claim is judged against the same "today".
    today_iso = datetime.now(timezone.utc).date().isoformat()

    issues: list = []
    for claim in claims:
        try:
            facts = kg.query_entity(claim["subject"], direction="outgoing")
        except Exception:
            continue
        if not facts:
            continue
        issues.extend(_issues_for_claim(claim, facts, today_iso))

    return issues


def _objects_match(kg_obj, claim_obj: str) -> bool:
    """Case- and whitespace-insensitive comparison of a KG object with a
    claimed object; ``False`` when either side is missing."""
    if kg_obj is None or not claim_obj:
        return False
    return str(kg_obj).strip().lower() == claim_obj.strip().lower()


# ── Levenshtein helper (tight iterative version) ─────────────────────


def _edit_distance(s1: str, s2: str) -> int:
    """Levenshtein distance. O(len(s1) * len(s2)) time, O(len(s2)) space."""
    # Keep the shorter string as the row so the two rolling rows stay small.
    if len(s1) < len(s2):
        s1, s2 = s2, s1
    if not s2:
        return len(s1)
    prev = list(range(len(s2) + 1))
    for i, c1 in enumerate(s1):
        curr = [i + 1]
        for j, c2 in enumerate(s2):
            curr.append(
                min(
                    prev[j + 1] + 1,  # deletion
                    curr[j] + 1,  # insertion
                    prev[j] + (0 if c1 == c2 else 1),  # substitution / match
                )
            )
        prev = curr
    return prev[-1]
use --stdin.") + + found = check_text(text_in, palace_path=args.palace) + if found: + print(json.dumps(found, indent=2)) + sys.exit(1) + print("No contradictions found.") diff --git a/mempalace/mcp_server.py b/mempalace/mcp_server.py index 33933ff..31be8a4 100644 --- a/mempalace/mcp_server.py +++ b/mempalace/mcp_server.py @@ -35,7 +35,15 @@ from .version import __version__ import chromadb from .query_sanitizer import sanitize_query from .searcher import search_memories -from .palace_graph import traverse, find_tunnels, graph_stats +from .palace_graph import ( + traverse, + find_tunnels, + graph_stats, + create_tunnel, + list_tunnels, + delete_tunnel, + follow_tunnels, +) from .knowledge_graph import KnowledgeGraph @@ -496,6 +504,66 @@ def tool_graph_stats(): return graph_stats(col=col) +def tool_create_tunnel( + source_wing: str, + source_room: str, + target_wing: str, + target_room: str, + label: str = "", + source_drawer_id: str = None, + target_drawer_id: str = None, +): + """Create an explicit cross-wing tunnel between two palace locations. + + Use when you notice content in one project relates to another project. + Example: an API design discussion in project_api connects to the + database schema in project_database. 
+ """ + try: + source_wing = sanitize_name(source_wing, "source_wing") + source_room = sanitize_name(source_room, "source_room") + target_wing = sanitize_name(target_wing, "target_wing") + target_room = sanitize_name(target_room, "target_room") + except ValueError as e: + return {"error": str(e)} + return create_tunnel( + source_wing, + source_room, + target_wing, + target_room, + label=label, + source_drawer_id=source_drawer_id, + target_drawer_id=target_drawer_id, + ) + + +def tool_list_tunnels(wing: str = None): + """List all explicit cross-wing tunnels, optionally filtered by wing.""" + try: + wing = _sanitize_optional_name(wing, "wing") + except ValueError as e: + return {"error": str(e)} + return list_tunnels(wing) + + +def tool_delete_tunnel(tunnel_id: str): + """Delete an explicit tunnel by its ID.""" + if not tunnel_id or not isinstance(tunnel_id, str): + return {"error": "tunnel_id is required"} + return delete_tunnel(tunnel_id) + + +def tool_follow_tunnels(wing: str, room: str): + """Follow explicit tunnels from a room to see connected drawers in other wings.""" + try: + wing = sanitize_name(wing, "wing") + room = sanitize_name(room, "room") + except ValueError as e: + return {"error": str(e)} + col = _get_collection() + return follow_tunnels(wing, room, col=col) + + # ==================== WRITE TOOLS ==================== @@ -1184,6 +1252,65 @@ TOOLS = { "input_schema": {"type": "object", "properties": {}}, "handler": tool_graph_stats, }, + "mempalace_create_tunnel": { + "description": "Create a cross-wing tunnel linking two palace locations. 
Use when content in one project relates to another — e.g., an API design in project_api connects to a database schema in project_database.", + "input_schema": { + "type": "object", + "properties": { + "source_wing": {"type": "string", "description": "Wing of the source"}, + "source_room": {"type": "string", "description": "Room in the source wing"}, + "target_wing": {"type": "string", "description": "Wing of the target"}, + "target_room": {"type": "string", "description": "Room in the target wing"}, + "label": {"type": "string", "description": "Description of the connection"}, + "source_drawer_id": { + "type": "string", + "description": "Optional specific drawer ID", + }, + "target_drawer_id": { + "type": "string", + "description": "Optional specific drawer ID", + }, + }, + "required": ["source_wing", "source_room", "target_wing", "target_room"], + }, + "handler": tool_create_tunnel, + }, + "mempalace_list_tunnels": { + "description": "List all explicit cross-wing tunnels. Optionally filter by wing.", + "input_schema": { + "type": "object", + "properties": { + "wing": { + "type": "string", + "description": "Filter tunnels by wing (shows tunnels where wing is source or target)", + }, + }, + }, + "handler": tool_list_tunnels, + }, + "mempalace_delete_tunnel": { + "description": "Delete an explicit tunnel by its ID.", + "input_schema": { + "type": "object", + "properties": { + "tunnel_id": {"type": "string", "description": "Tunnel ID to delete"}, + }, + "required": ["tunnel_id"], + }, + "handler": tool_delete_tunnel, + }, + "mempalace_follow_tunnels": { + "description": "Follow tunnels from a room to see what it connects to in other wings. 
_ENTITY_REGISTRY_PATH = os.path.join(os.path.expanduser("~"), ".mempalace", "known_entities.json")
_ENTITY_REGISTRY_CACHE: dict = {"mtime": None, "names": frozenset(), "raw": {}}
_ENTITY_EXTRACT_WINDOW = 5000  # chars of content scanned for capitalized words
_ENTITY_METADATA_LIMIT = 25  # max entities packed into the metadata field


def _refresh_known_entities_cache() -> None:
    """Reload the known-entity registry into the module cache when it changed.

    The registry file at ``_ENTITY_REGISTRY_PATH`` is re-read only when its
    mtime differs from the cached one. Shared by ``_load_known_entities``
    (flat set) and ``_load_known_entities_raw`` (category dict) so both
    views come from a single mtime-gated disk read.

    A missing file resets the cache to empty. An unreadable or malformed
    file is cached as empty for that mtime, so it is not re-parsed on
    every call.
    """
    import json

    cache = _ENTITY_REGISTRY_CACHE
    try:
        mtime = os.path.getmtime(_ENTITY_REGISTRY_PATH)
    except OSError:
        # Registry is gone (or never existed): drop anything cached earlier.
        if cache["mtime"] is not None:
            cache["mtime"] = None
            cache["names"] = frozenset()
            cache["raw"] = {}
        return

    if cache["mtime"] == mtime:
        return

    try:
        with open(_ENTITY_REGISTRY_PATH, "r", encoding="utf-8") as f:
            data = json.load(f)
    except (OSError, ValueError, RecursionError):
        # ValueError covers JSONDecodeError and UnicodeDecodeError.
        data = None

    names: set = set()
    raw: dict = {}
    if isinstance(data, dict):
        raw = data
        for cat in data.values():
            if isinstance(cat, list):
                names.update(str(n) for n in cat if n)
            elif isinstance(cat, dict):
                # Dict-shaped categories contribute their keys as names.
                names.update(str(k) for k in cat if k)

    cache["mtime"] = mtime
    cache["names"] = frozenset(names)
    cache["raw"] = raw


def _load_known_entities() -> frozenset:
    """Return the flat set of every known entity name across all categories.

    Cached by mtime; invalidated when the registry file changes.
    """
    _refresh_known_entities_cache()
    return _ENTITY_REGISTRY_CACHE["names"]


def _load_known_entities_raw() -> dict:
    """Return the full category view of the registry.

    The shape is whatever the registry file holds, typically
    ``{"category": ["Name1", ...], ...}``. Cached by mtime.

    The result is a deep copy. A shallow copy would still share the inner
    category lists with the cache, so a caller appending to one of them
    would silently corrupt every later read.
    """
    import copy

    _refresh_known_entities_cache()
    return copy.deepcopy(_ENTITY_REGISTRY_CACHE["raw"])
+ + Returns semicolon-separated string suitable for ChromaDB metadata + filtering. The list is truncated to ``_ENTITY_METADATA_LIMIT`` entries + *before* joining so a name is never cut in half. + """ + import re + + from .palace import _ENTITY_STOPLIST + + matched: set = set() + + known = _load_known_entities() + for name in known: + if re.search(r"(?= 2 and len(w) > 2: + matched.add(w) + + if not matched: + return "" + # Truncate the *list*, not the joined string — never split a name. + capped = sorted(matched)[:_ENTITY_METADATA_LIMIT] + return ";".join(capped) + + def add_drawer( collection, wing: str, room: str, content: str, source_file: str, chunk_index: int, agent: str ): @@ -398,6 +508,10 @@ def add_drawer( metadata["source_mtime"] = os.path.getmtime(source_file) except OSError: pass + # Tag with entity names for filterable search + entities = _extract_entities_for_metadata(content) + if entities: + metadata["entities"] = entities collection.upsert( documents=[content], ids=[drawer_id], @@ -490,20 +604,19 @@ def process_file( closet_id_base = ( f"closet_{wing}_{room}_{hashlib.sha256(source_file.encode()).hexdigest()[:24]}" ) + entities = _extract_entities_for_metadata(content) + closet_meta = { + "wing": wing, + "room": room, + "source_file": source_file, + "drawer_count": drawers_added, + "filed_at": datetime.now().isoformat(), + "normalize_version": NORMALIZE_VERSION, + } + if entities: + closet_meta["entities"] = entities purge_file_closets(closets_col, source_file) - upsert_closet_lines( - closets_col, - closet_id_base, - closet_lines, - { - "wing": wing, - "room": room, - "source_file": source_file, - "drawer_count": drawers_added, - "filed_at": datetime.now().isoformat(), - "normalize_version": NORMALIZE_VERSION, - }, - ) + upsert_closet_lines(closets_col, closet_id_base, closet_lines, closet_meta) return drawers_added, room diff --git a/mempalace/palace_graph.py b/mempalace/palace_graph.py index 5e2e72e..71cad89 100644 --- a/mempalace/palace_graph.py 
# =============================================================================
# EXPLICIT TUNNELS — agent-created cross-wing links
# =============================================================================
# Passive tunnels are discovered from shared room names across wings.
# Explicit tunnels are created by agents when they notice a connection
# between two specific drawers or rooms in different wings/projects.
#
# Stored as a JSON file at ~/.mempalace/tunnels.json so they persist
# across palace rebuilds (not in ChromaDB which can be recreated).


_TUNNEL_FILE = os.path.join(os.path.expanduser("~"), ".mempalace", "tunnels.json")


def _load_tunnels():
    """Load explicit tunnels from disk.

    Returns an empty list if the file is missing, unreadable, not valid
    JSON, or does not hold a JSON array.
    """
    try:
        with open(_TUNNEL_FILE, "r", encoding="utf-8") as f:
            data = json.load(f)
    except (OSError, ValueError, RecursionError):
        # ValueError covers JSONDecodeError and UnicodeDecodeError.
        return []
    return data if isinstance(data, list) else []


def _save_tunnels(tunnels):
    """Persist explicit tunnels atomically.

    The payload is serialized first, written to ``tunnels.json.tmp`` and
    then moved into place with ``os.replace``, so a crash mid-write cannot
    leave a partial tunnels.json. If serialization or the write fails, the
    temp file is removed and the previous tunnels.json is left untouched.

    Raises:
        TypeError / ValueError: if ``tunnels`` is not JSON-serializable.
        OSError: if the file cannot be written or replaced.
    """
    # Serialize before touching the disk: a bad record must not create files.
    payload = json.dumps(tunnels, indent=2)

    parent = os.path.dirname(_TUNNEL_FILE)
    if parent:
        os.makedirs(parent, exist_ok=True)

    tmp_path = _TUNNEL_FILE + ".tmp"
    try:
        with open(tmp_path, "w", encoding="utf-8") as f:
            f.write(payload)
            f.flush()
            try:
                os.fsync(f.fileno())
            except OSError:
                # Not all filesystems (or Windows file handles) support fsync — tolerate.
                pass
        os.replace(tmp_path, _TUNNEL_FILE)
    except BaseException:
        # Never leave a stale temp file behind; re-raise the original error.
        try:
            os.remove(tmp_path)
        except OSError:
            pass
        raise


def _endpoint_key(wing: str, room: str) -> str:
    """Return the ``wing/room`` key that identifies one tunnel endpoint."""
    return f"{wing}/{room}"


def _canonical_tunnel_id(
    source_wing: str, source_room: str, target_wing: str, target_room: str
) -> str:
    """Compute a symmetric tunnel ID.

    Tunnels are conceptually undirected — "auth relates to users" is the
    same connection as "users relates to auth". The two endpoint keys are
    sorted before hashing so ``create_tunnel(A, B)`` and
    ``create_tunnel(B, A)`` resolve to the same ID and dedup into one
    record. The ID is the first 16 hex characters of a SHA-256 digest.
    """
    src = _endpoint_key(source_wing, source_room)
    tgt = _endpoint_key(target_wing, target_room)
    a, b = sorted((src, tgt))
    return hashlib.sha256(f"{a}↔{b}".encode()).hexdigest()[:16]


def _require_name(value: str, field: str) -> str:
    """Return ``value`` stripped, rejecting empty or non-string identifiers.

    Raises:
        ValueError: if ``value`` is not a string or is blank.
    """
    if not isinstance(value, str) or not value.strip():
        raise ValueError(f"{field} must be a non-empty string")
    return value.strip()


def _tunnel_endpoints(tunnel):
    """Return ``(source, target)`` dicts of a stored record, or None if malformed.

    tunnels.json lives in the user's home directory and can be hand-edited;
    readers use this to skip broken records instead of crashing on them.
    """
    if not isinstance(tunnel, dict):
        return None
    src = tunnel.get("source")
    tgt = tunnel.get("target")
    if isinstance(src, dict) and isinstance(tgt, dict):
        return src, tgt
    return None


def create_tunnel(
    source_wing: str,
    source_room: str,
    target_wing: str,
    target_room: str,
    label: str = "",
    source_drawer_id: str = None,
    target_drawer_id: str = None,
):
    """Create an explicit (symmetric) tunnel between two locations in the palace.

    Tunnels are undirected: ``create_tunnel(A, B)`` and ``create_tunnel(B, A)``
    resolve to the same canonical ID. A second call with the same endpoints
    replaces the stored record (label, endpoint order and any drawer IDs
    passed in this call) rather than creating a duplicate; the original
    ``created_at`` is kept and ``updated_at`` is added.

    The ``source`` / ``target`` fields on the returned dict preserve the
    argument order the caller used, so callers can display it directionally
    if they like. The ID and dedup are symmetric.

    Args:
        source_wing: Wing of the source (e.g., "project_api").
        source_room: Room in the source wing.
        target_wing: Wing of the target (e.g., "project_database").
        target_room: Room in the target wing.
        label: Description of the connection.
        source_drawer_id: Optional specific drawer ID.
        target_drawer_id: Optional specific drawer ID.

    Returns:
        The stored tunnel dict.

    Raises:
        ValueError: if any wing or room is empty or non-string.
    """
    source_wing = _require_name(source_wing, "source_wing")
    source_room = _require_name(source_room, "source_room")
    target_wing = _require_name(target_wing, "target_wing")
    target_room = _require_name(target_room, "target_room")

    tunnel_id = _canonical_tunnel_id(source_wing, source_room, target_wing, target_room)

    tunnel = {
        "id": tunnel_id,
        "source": {"wing": source_wing, "room": source_room},
        "target": {"wing": target_wing, "room": target_room},
        "label": label,
        "created_at": datetime.now(timezone.utc).isoformat(),
    }
    if source_drawer_id:
        tunnel["source"]["drawer_id"] = source_drawer_id
    if target_drawer_id:
        tunnel["target"]["drawer_id"] = target_drawer_id

    # Serialize the load → mutate → save cycle. Without this, two concurrent
    # create_tunnel calls can both read the same snapshot and the later
    # writer silently drops the earlier writer's tunnel.
    with mine_lock(_TUNNEL_FILE):
        tunnels = _load_tunnels()
        for existing in tunnels:
            # Non-dict entries are malformed records; skip them when matching.
            if isinstance(existing, dict) and existing.get("id") == tunnel_id:
                # Preserve original creation timestamp on label updates.
                tunnel["created_at"] = existing.get("created_at", tunnel["created_at"])
                tunnel["updated_at"] = datetime.now(timezone.utc).isoformat()
                existing.clear()
                existing.update(tunnel)
                _save_tunnels(tunnels)
                return existing
        tunnels.append(tunnel)
        _save_tunnels(tunnels)
        return tunnel


def list_tunnels(wing: str = None):
    """List all explicit tunnels, optionally filtered by wing.

    With ``wing`` set, returns tunnels where it appears as either source
    or target (tunnels are symmetric, so either endpoint is a valid filter
    match). Malformed records never match a wing filter. Without ``wing``
    the stored list is returned as-is.
    """
    tunnels = _load_tunnels()
    if not wing:
        return tunnels

    matches = []
    for t in tunnels:
        endpoints = _tunnel_endpoints(t)
        if endpoints is None:
            continue
        src, tgt = endpoints
        if src.get("wing") == wing or tgt.get("wing") == wing:
            matches.append(t)
    return matches


def delete_tunnel(tunnel_id: str):
    """Delete an explicit tunnel by ID.

    Returns ``{"deleted": tunnel_id}`` whether or not a tunnel with that
    ID existed.
    """
    with mine_lock(_TUNNEL_FILE):
        tunnels = _load_tunnels()
        tunnels = [t for t in tunnels if not (isinstance(t, dict) and t.get("id") == tunnel_id)]
        _save_tunnels(tunnels)
    return {"deleted": tunnel_id}


def follow_tunnels(wing: str, room: str, col=None, config=None):
    """Follow explicit tunnels from a room — returns connected locations.

    Given a location (wing/room), finds all tunnels leading from or to it.
    When ``col`` is provided, drawers referenced by those tunnels are
    fetched and the first 300 characters of each are attached as
    ``drawer_preview``.

    Args:
        wing: Wing of the starting location.
        room: Room of the starting location.
        col: Optional drawer collection exposing ``get(ids=..., include=...)``.
        config: Unused; accepted for signature compatibility.

    Returns:
        A list of connection dicts with ``direction`` ("outgoing" when the
        location is the tunnel's stored source, "incoming" when it is the
        target), ``connected_wing``, ``connected_room``, ``label``,
        ``drawer_id`` and ``tunnel_id``.
    """
    connections = []
    for t in _load_tunnels():
        endpoints = _tunnel_endpoints(t)
        if endpoints is None:
            continue  # malformed record — ignore rather than crash the lookup
        src, tgt = endpoints

        if src.get("wing") == wing and src.get("room") == room:
            direction, other = "outgoing", tgt
        elif tgt.get("wing") == wing and tgt.get("room") == room:
            direction, other = "incoming", src
        else:
            continue

        connections.append(
            {
                "direction": direction,
                "connected_wing": other.get("wing"),
                "connected_room": other.get("room"),
                "label": t.get("label", ""),
                "drawer_id": other.get("drawer_id"),
                "tunnel_id": t.get("id"),
            }
        )

    # If we have a collection, fetch drawer content for connected items.
    if col is not None and connections:
        # De-duplicate while keeping order: several tunnels may share a drawer.
        drawer_ids = list(dict.fromkeys(c["drawer_id"] for c in connections if c.get("drawer_id")))
        if drawer_ids:
            try:
                results = col.get(ids=drawer_ids, include=["documents", "metadatas"])
                drawer_map = dict(zip(results["ids"], results["documents"]))
            except Exception:
                # Previews are best-effort; the connections are still returned.
                drawer_map = {}
            for c in connections:
                doc = drawer_map.get(c.get("drawer_id"))
                if isinstance(doc, str):
                    c["drawer_preview"] = doc[:300]

    return connections
_TOKEN_RE = re.compile(r"\w{2,}", re.UNICODE)


def _tokenize(text: str) -> list:
    """Lowercase ``text`` and return its word tokens of length ≥ 2.

    A token is a run of ``\\w`` characters (letters, digits, underscore).
    ``None`` or empty input yields an empty list, so a document the vector
    store returned without text scores zero instead of raising.
    """
    if not text:
        return []
    return _TOKEN_RE.findall(text.lower())


def _bm25_scores(
    query: str,
    documents: list,
    k1: float = 1.5,
    b: float = 0.75,
) -> list:
    """Compute Okapi-BM25 scores for ``query`` against each document.

    IDF is computed over the *provided corpus* using the smoothed formula
    ``log((N - df + 0.5) / (df + 0.5) + 1)``, which is always
    non-negative. This is well-defined for re-ranking a small candidate
    set returned by vector retrieval — IDF then reflects how
    discriminative each query term is *within the candidates*.

    Parameters mirror Okapi-BM25 conventions:
        k1 — term-frequency saturation (1.2-2.0 typical, 1.5 default)
        b  — length normalization (0.0 = none, 1.0 = full, 0.75 default)

    Returns a list of float scores in the same order as ``documents``.
    Empty or ``None`` documents score 0.0; an empty query or corpus yields
    all zeros.
    """
    n_docs = len(documents)
    query_terms = set(_tokenize(query))
    if not query_terms or n_docs == 0:
        return [0.0] * n_docs

    tokenized = [_tokenize(d) for d in documents]
    doc_lens = [len(toks) for toks in tokenized]
    if not any(doc_lens):
        return [0.0] * n_docs
    avgdl = sum(doc_lens) / n_docs  # > 0: at least one document has tokens

    # Document frequency: how many docs contain each query term?
    df = {term: 0 for term in query_terms}
    for toks in tokenized:
        for term in set(toks) & query_terms:
            df[term] += 1

    idf = {term: math.log((n_docs - df[term] + 0.5) / (df[term] + 0.5) + 1) for term in query_terms}

    scores = []
    for toks, dl in zip(tokenized, doc_lens):
        if dl == 0:
            scores.append(0.0)
            continue
        tf: dict = {}
        for t in toks:
            if t in query_terms:
                tf[t] = tf.get(t, 0) + 1
        # Length normalization depends only on the document, not the term.
        norm = k1 * (1 - b + b * dl / avgdl)
        score = 0.0
        for term, freq in tf.items():
            score += idf[term] * (freq * (k1 + 1)) / (freq + norm)
        scores.append(score)
    return scores


def _hybrid_rank(
    results: list,
    query: str,
    vector_weight: float = 0.6,
    bm25_weight: float = 0.4,
) -> list:
    """Re-rank ``results`` by a weighted sum of vector similarity and BM25.

    * Vector similarity is the absolute value ``max(0, 1 - distance)``
      read from each result's ``distance`` key (missing or ``None`` counts
      as 1.0, i.e. similarity 0). Being absolute rather than
      relative-to-max, adding or removing a candidate does not change the
      other candidates' vector term.
    * BM25 is computed over the candidates' ``text`` and divided by the
      best BM25 score in the set, so it lies in ``[0, 1]`` and the two
      weights are commensurable. (BM25 here is never negative, so this
      equals min-max scaling with a floor of 0.)

    Mutates each result dict to add ``bm25_score`` (the raw score rounded
    to 3 places) and reorders the list in place. Returns the same list for
    convenience.
    """
    # NOTE(review): ranking reads ``distance``, not ``effective_distance``,
    # so a closet boost applied by the caller does not influence this final
    # order — confirm that is intended.
    if not results:
        return results

    docs = [r.get("text") or "" for r in results]
    bm25_raw = _bm25_scores(query, docs)
    max_bm25 = max(bm25_raw)
    bm25_norm = [s / max_bm25 for s in bm25_raw] if max_bm25 > 0 else [0.0] * len(bm25_raw)

    scored = []
    for r, raw, norm in zip(results, bm25_raw, bm25_norm):
        dist = r.get("distance")
        if dist is None:
            dist = 1.0
        vec_sim = max(0.0, 1.0 - dist)
        r["bm25_score"] = round(raw, 3)
        scored.append((vector_weight * vec_sim + bm25_weight * norm, r))

    # Stable sort: candidates with equal scores keep their incoming order.
    scored.sort(key=lambda pair: pair[0], reverse=True)
    results[:] = [r for _, r in scored]
    return results


def _expand_with_neighbors(drawers_col, matched_doc: str, matched_meta: dict, radius: int = 1):
    """Expand a matched drawer with its ±radius sibling chunks in the same source file.

    Motivation — "drawer-grep context" feature: a closet hit returns one
    drawer, but the chunk boundary may clip mid-thought (e.g., the matched
    chunk says "here's a breakdown:" and the actual breakdown lives in the
    next chunk). Fetching the small neighborhood around the match gives
    callers enough context without forcing a follow-up ``get_drawer`` call.

    Returns a dict with:
        ``text``            combined chunks in chunk_index order
        ``drawer_index``    the matched chunk's index in the source file
        ``total_drawers``   total drawer count for the source file (or None)

    On any collection failure or missing metadata, falls back to returning
    the matched drawer alone so search never breaks because neighbor
    expansion failed. Neighbor rows without text or metadata are skipped.
    """
    meta = matched_meta or {}
    src = meta.get("source_file")
    chunk_idx = meta.get("chunk_index")
    matched_only = {"text": matched_doc, "drawer_index": chunk_idx, "total_drawers": None}
    if not src or not isinstance(chunk_idx, int):
        return matched_only

    target_indexes = [chunk_idx + offset for offset in range(-radius, radius + 1)]
    try:
        neighbors = drawers_col.get(
            where={
                "$and": [
                    {"source_file": src},
                    {"chunk_index": {"$in": target_indexes}},
                ]
            },
            include=["documents", "metadatas"],
        )
    except Exception:
        # Deliberate best-effort: any backend error degrades to the match alone.
        return matched_only

    indexed_docs = []
    for doc, nmeta in zip(neighbors.get("documents") or [], neighbors.get("metadatas") or []):
        # A row missing its text or metadata cannot be placed; skip it
        # rather than let join()/get() raise outside the try above.
        if not isinstance(doc, str) or not isinstance(nmeta, dict):
            continue
        ci = nmeta.get("chunk_index")
        if isinstance(ci, int):
            indexed_docs.append((ci, doc))
    indexed_docs.sort(key=lambda pair: pair[0])

    if indexed_docs:
        combined_text = "\n\n".join(doc for _, doc in indexed_docs)
    else:
        combined_text = matched_doc

    # total_drawers lookup: metadata-only fetch of every drawer of the source file.
    total_drawers = None
    try:
        all_meta = drawers_col.get(where={"source_file": src}, include=["metadatas"])
        ids = all_meta.get("ids") or []
        total_drawers = len(ids) if ids else None
    except Exception:
        pass  # the count is optional; leave it as None

    return {
        "text": combined_text,
        "drawer_index": chunk_idx,
        "total_drawers": total_drawers,
    }
Closets are a ranking SIGNAL, never a + # GATE — direct drawer search is always the baseline. + # + # This avoids the "weak-closets regression" where narrative content + # produces low-signal closets (regex extraction matches few topics) + # and closet-first routing hides drawers that direct search would find. try: - kwargs = { + dkwargs = { "query_texts": [query], - "n_results": n_results, + "n_results": n_results * 3, # over-fetch for re-ranking "include": ["documents", "metadatas", "distances"], } if where: - kwargs["where"] = where - - results = drawers_col.query(**kwargs) + dkwargs["where"] = where + drawer_results = drawers_col.query(**dkwargs) except Exception as e: return {"error": f"Search error: {e}"} - docs = results["documents"][0] - metas = results["metadatas"][0] - dists = results["distances"][0] + # Gather closet hits (best-per-source) to build a boost lookup. + closet_boost_by_source: dict = {} # source_file -> (rank, closet_dist, preview) + try: + closets_col = get_closets_collection(palace_path, create=False) + ckwargs = { + "query_texts": [query], + "n_results": n_results * 2, + "include": ["documents", "metadatas", "distances"], + } + if where: + ckwargs["where"] = where + closet_results = closets_col.query(**ckwargs) + for rank, (cdoc, cmeta, cdist) in enumerate( + zip( + closet_results["documents"][0], + closet_results["metadatas"][0], + closet_results["distances"][0], + ) + ): + source = cmeta.get("source_file", "") + if source and source not in closet_boost_by_source: + closet_boost_by_source[source] = (rank, cdist, cdoc[:200]) + except Exception: + pass # no closets yet — hybrid degrades to pure drawer search - hits = [] - for doc, meta, dist in zip(docs, metas, dists): - # Filter on raw distance before rounding to avoid precision loss + # Rank-based boost. 
The ordinal signal ("which closet matched best") is + # more reliable than absolute distance on narrative content, where + # closet distances cluster in 1.2-1.5 range regardless of match quality. + CLOSET_RANK_BOOSTS = [0.40, 0.25, 0.15, 0.08, 0.04] + CLOSET_DISTANCE_CAP = 1.5 # cosine dist > 1.5 = too weak to use as signal + + scored: list = [] + for doc, meta, dist in zip( + drawer_results["documents"][0], + drawer_results["metadatas"][0], + drawer_results["distances"][0], + ): + # Filter on raw distance before rounding to avoid precision loss. if max_distance > 0.0 and dist > max_distance: continue - hits.append( - { - "text": doc, - "wing": meta.get("wing", "unknown"), - "room": meta.get("room", "unknown"), - "source_file": Path(meta.get("source_file", "?")).name, - "similarity": round(max(0.0, 1 - dist), 3), - "distance": round(dist, 4), - "matched_via": "drawer", - } - ) + + source = meta.get("source_file", "") or "" + boost = 0.0 + matched_via = "drawer" + closet_preview = None + if source in closet_boost_by_source: + c_rank, c_dist, c_preview = closet_boost_by_source[source] + if c_dist <= CLOSET_DISTANCE_CAP and c_rank < len(CLOSET_RANK_BOOSTS): + boost = CLOSET_RANK_BOOSTS[c_rank] + matched_via = "drawer+closet" + closet_preview = c_preview + + effective_dist = dist - boost + entry = { + "text": doc, + "wing": meta.get("wing", "unknown"), + "room": meta.get("room", "unknown"), + "source_file": Path(source).name if source else "?", + "similarity": round(max(0.0, 1 - effective_dist), 3), + "distance": round(dist, 4), + "effective_distance": round(effective_dist, 4), + "closet_boost": round(boost, 3), + "matched_via": matched_via, + # Internal: retain the full source_file path + chunk_index so the + # enrichment step below doesn't have to reverse-lookup via + # basename-suffix matching (which silently collides when two + # files share a basename across different directories). 
+ "_sort_key": effective_dist, + "_source_file_full": source, + "_chunk_index": meta.get("chunk_index"), + } + if closet_preview: + entry["closet_preview"] = closet_preview + scored.append(entry) + + scored.sort(key=lambda h: h["_sort_key"]) + hits = scored[:n_results] + + # Drawer-grep enrichment: for closet-boosted hits whose source has + # multiple drawers, return the keyword-best chunk + its immediate + # neighbors instead of just the drawer vector search landed on. The + # closet said "this source is relevant"; vector may have picked the + # wrong chunk within it; grep picks the right one. + MAX_HYDRATION_CHARS = 10000 + for h in hits: + if h["matched_via"] == "drawer": + continue + full_source = h.get("_source_file_full") or "" + if not full_source: + continue + try: + source_drawers = drawers_col.get( + where={"source_file": full_source}, + include=["documents", "metadatas"], + ) + except Exception: + continue + docs = source_drawers.get("documents") or [] + metas_ = source_drawers.get("metadatas") or [] + if len(docs) <= 1: + continue + + # Sort by chunk_index so best_idx + neighbors are positional. + indexed = [] + for idx, (d, m) in enumerate(zip(docs, metas_)): + ci = m.get("chunk_index", idx) if isinstance(m, dict) else idx + if not isinstance(ci, int): + ci = idx + indexed.append((ci, d)) + indexed.sort(key=lambda p: p[0]) + ordered_docs = [d for _, d in indexed] + + query_terms = set(_tokenize(query)) + best_idx, best_score = 0, -1 + for idx, d in enumerate(ordered_docs): + d_lower = d.lower() + s = sum(1 for t in query_terms if t in d_lower) + if s > best_score: + best_score, best_idx = s, idx + + start = max(0, best_idx - 1) + end = min(len(ordered_docs), best_idx + 2) + expanded = "\n\n".join(ordered_docs[start:end]) + if len(expanded) > MAX_HYDRATION_CHARS: + expanded = ( + expanded[:MAX_HYDRATION_CHARS] + + f"\n\n[...truncated. {len(ordered_docs)} total drawers. 
" + "Use mempalace_get_drawer for full content.]" + ) + h["text"] = expanded + h["drawer_index"] = best_idx + h["total_drawers"] = len(ordered_docs) + + # BM25 hybrid re-rank within the final candidate set. + hits = _hybrid_rank(hits, query) + for h in hits: + h.pop("_sort_key", None) + h.pop("_source_file_full", None) + h.pop("_chunk_index", None) return { "query": query, "filters": {"wing": wing, "room": room}, - "total_before_filter": len(docs), + "total_before_filter": len(drawer_results["documents"][0]), "results": hits, } diff --git a/tests/test_closet_llm.py b/tests/test_closet_llm.py new file mode 100644 index 0000000..a92e2fa --- /dev/null +++ b/tests/test_closet_llm.py @@ -0,0 +1,339 @@ +"""Unit tests for the optional LLM-based closet regeneration. + +These tests don't hit the network. They mock urllib to verify: +- LLMConfig correctly reads env vars and CLI overrides +- missing config is reported cleanly +- the OpenAI-compatible request shape is correct +- response parsing handles the standard chat-completions payload +""" + +import json +import tempfile +from unittest.mock import patch + +from mempalace.closet_llm import ( + LLMConfig, + _call_llm, + _parsed_to_closet_lines, + regenerate_closets, +) + + +# ── LLMConfig ───────────────────────────────────────────────────────────── + + +class TestLLMConfig: + def test_reads_env_vars(self, monkeypatch): + monkeypatch.setenv("LLM_ENDPOINT", "http://localhost:11434/v1") + monkeypatch.setenv("LLM_KEY", "sk-abc") + monkeypatch.setenv("LLM_MODEL", "llama3:8b") + c = LLMConfig() + assert c.endpoint == "http://localhost:11434/v1" + assert c.key == "sk-abc" + assert c.model == "llama3:8b" + + def test_cli_flags_override_env(self, monkeypatch): + monkeypatch.setenv("LLM_ENDPOINT", "http://env-endpoint/v1") + monkeypatch.setenv("LLM_MODEL", "env-model") + c = LLMConfig(endpoint="http://flag-endpoint/v1", model="flag-model") + assert c.endpoint == "http://flag-endpoint/v1" + assert c.model == "flag-model" + + def 
test_trailing_slash_stripped(self): + c = LLMConfig(endpoint="http://foo/v1/", model="m") + assert c.endpoint == "http://foo/v1" + + def test_missing_reports_required(self, monkeypatch): + monkeypatch.delenv("LLM_ENDPOINT", raising=False) + monkeypatch.delenv("LLM_KEY", raising=False) + monkeypatch.delenv("LLM_MODEL", raising=False) + c = LLMConfig() + missing = c.missing() + assert any("ENDPOINT" in m for m in missing) + assert any("MODEL" in m for m in missing) + # key is optional + assert not any("KEY" in m for m in missing) + + def test_key_is_optional(self, monkeypatch): + monkeypatch.delenv("LLM_KEY", raising=False) + c = LLMConfig(endpoint="http://local/v1", model="m") + assert c.missing() == [] + + +# ── _parsed_to_closet_lines ────────────────────────────────────────────── + + +class TestParsedToLines: + def test_topics_become_pointers(self): + parsed = {"topics": ["authentication", "jwt tokens"], "quotes": [], "summary": ""} + lines = _parsed_to_closet_lines(parsed, ["d1", "d2"], "Alice;Bob") + assert len(lines) == 2 + assert "authentication|Alice;Bob|→d1,d2" in lines + assert "jwt tokens|Alice;Bob|→d1,d2" in lines + + def test_quotes_and_summary_included(self): + parsed = { + "topics": ["t1"], + "quotes": ["[Igor] we ship Friday"], + "summary": "Release planning discussion", + } + lines = _parsed_to_closet_lines(parsed, ["d1"], "") + joined = "\n".join(lines) + assert "we ship Friday" in joined + assert "Release planning discussion" in joined + + def test_caps_topics_at_15(self): + parsed = {"topics": [f"t{i}" for i in range(20)], "quotes": [], "summary": ""} + lines = _parsed_to_closet_lines(parsed, ["d1"], "") + assert len(lines) == 15 + + +# ── _call_llm (HTTP mocked) ────────────────────────────────────────────── + + +class _FakeResp: + """Mimics urlopen's context-manager response.""" + + def __init__(self, payload: dict, status: int = 200): + self._body = json.dumps(payload).encode("utf-8") + self.status = status + + def __enter__(self): + return 
self + + def __exit__(self, *a): + return False + + def read(self): + return self._body + + +class TestCallLLM: + def _make_cfg(self): + return LLMConfig(endpoint="http://localhost:11434/v1", key="sk-test", model="llama3:8b") + + def test_request_shape_and_parsing(self): + cfg = self._make_cfg() + captured = {} + + def fake_urlopen(req, timeout=None): + captured["url"] = req.full_url + captured["headers"] = dict(req.header_items()) + captured["body"] = json.loads(req.data.decode("utf-8")) + return _FakeResp( + { + "choices": [ + { + "message": { + "content": json.dumps( + { + "topics": ["postgres"], + "quotes": ["[Igor] migrate now"], + "summary": "db migration", + } + ) + } + } + ], + "usage": {"prompt_tokens": 42, "completion_tokens": 17}, + } + ) + + with patch("urllib.request.urlopen", side_effect=fake_urlopen): + parsed, usage = _call_llm(cfg, "/tmp/test.md", "w", "r", "content body") + + assert parsed["topics"] == ["postgres"] + assert usage["prompt_tokens"] == 42 + assert captured["url"] == "http://localhost:11434/v1/chat/completions" + # Authorization header is stored capitalized-then-lowercase depending on urllib version + auth_vals = {v for k, v in captured["headers"].items() if k.lower() == "authorization"} + assert "Bearer sk-test" in auth_vals + assert captured["body"]["model"] == "llama3:8b" + assert captured["body"]["messages"][0]["role"] == "user" + + def test_omits_auth_header_when_no_key(self): + cfg = LLMConfig(endpoint="http://localhost:11434/v1", model="llama3:8b") + captured_headers = {} + + def fake_urlopen(req, timeout=None): + captured_headers.update({k.lower(): v for k, v in req.header_items()}) + return _FakeResp( + { + "choices": [{"message": {"content": '{"topics":[],"quotes":[],"summary":""}'}}], + "usage": {"prompt_tokens": 0, "completion_tokens": 0}, + } + ) + + with patch("urllib.request.urlopen", side_effect=fake_urlopen): + _call_llm(cfg, "/tmp/x", "w", "r", "c") + + assert "authorization" not in captured_headers + + def 
test_strips_code_fences(self): + cfg = self._make_cfg() + fenced = '```json\n{"topics":["t1"],"quotes":[],"summary":""}\n```' + + def fake_urlopen(req, timeout=None): + return _FakeResp( + { + "choices": [{"message": {"content": fenced}}], + "usage": {"prompt_tokens": 1, "completion_tokens": 1}, + } + ) + + with patch("urllib.request.urlopen", side_effect=fake_urlopen): + parsed, _ = _call_llm(cfg, "/tmp/x", "w", "r", "c") + assert parsed == {"topics": ["t1"], "quotes": [], "summary": ""} + + def test_returns_none_on_invalid_json(self): + cfg = self._make_cfg() + + def fake_urlopen(req, timeout=None): + return _FakeResp( + { + "choices": [{"message": {"content": "not json at all"}}], + "usage": {"prompt_tokens": 1, "completion_tokens": 1}, + } + ) + + with patch("urllib.request.urlopen", side_effect=fake_urlopen): + parsed, usage = _call_llm(cfg, "/tmp/x", "w", "r", "c") + assert parsed is None + + +# ── regenerate_closets error paths ─────────────────────────────────────── + + +class TestRegenerateClosets: + def test_missing_config_returns_error(self, monkeypatch): + monkeypatch.delenv("LLM_ENDPOINT", raising=False) + monkeypatch.delenv("LLM_MODEL", raising=False) + with tempfile.TemporaryDirectory() as palace: + result = regenerate_closets(palace) + assert result["error"] == "missing-config" + assert any("ENDPOINT" in m for m in result["missing"]) + + def test_regen_purges_regex_closets_and_stamps_normalize_version(self, tmp_path): + """Regression: before the hardening, regex closets for the same + source survived alongside fresh LLM closets (the old path used a + bare ``closets_col.delete(ids=...)`` with a swallowed exception). 
+ Now we go through ``purge_file_closets`` + ``mine_lock`` + stamp + ``NORMALIZE_VERSION`` so the next mine's stale-version gate doesn't + treat the LLM closets as leftovers to rebuild over.""" + from mempalace.palace import ( + NORMALIZE_VERSION, + get_closets_collection, + get_collection, + upsert_closet_lines, + ) + + palace = str(tmp_path / "palace") + # Seed one drawer and a pre-existing regex closet for the same source. + source = "/proj/story.md" + drawers = get_collection(palace, create=True) + drawers.upsert( + ids=["drawer_01"], + documents=["Content about JWT authentication."], + metadatas=[ + { + "wing": "project", + "room": "auth", + "source_file": source, + "entities": "", + } + ], + ) + closets = get_closets_collection(palace) + upsert_closet_lines( + closets, + closet_id_base="closet_old_regex", + lines=["STALE_REGEX_TOPIC|;|→drawer_01"], + metadata={ + "wing": "project", + "room": "auth", + "source_file": source, + "generated_by": "regex", + }, + ) + + cfg = LLMConfig(endpoint="http://local/v1", model="llama3:8b") + + def fake_urlopen(req, timeout=None): + return _FakeResp( + { + "choices": [ + { + "message": { + "content": json.dumps( + { + "topics": ["jwt auth", "session expiry"], + "quotes": [], + "summary": "auth refactor", + } + ) + } + } + ], + "usage": {"prompt_tokens": 10, "completion_tokens": 5}, + } + ) + + with patch("urllib.request.urlopen", side_effect=fake_urlopen): + result = regenerate_closets(palace, cfg=cfg) + + assert result["processed"] == 1 and result["failed"] == 0 + + # Every surviving closet for this source must be LLM-generated and + # must carry the current NORMALIZE_VERSION. 
+ survivors = closets.get(where={"source_file": source}, include=["documents", "metadatas"]) + assert survivors["ids"], "LLM closets should have been written" + joined = "\n".join(survivors["documents"]) + assert ( + "STALE_REGEX_TOPIC" not in joined + ), "pre-existing regex closet was not purged before LLM write" + assert "jwt auth" in joined + for meta in survivors["metadatas"]: + assert meta.get("generated_by", "").startswith("llm:") + assert meta.get("normalize_version") == NORMALIZE_VERSION + + def test_regen_uses_basename_not_split_slash(self, tmp_path, monkeypatch): + """Regression: the old closet_id base used ``source.split('/')[-1]`` + which silently degrades on Windows paths (``C:\\proj\\a.md`` → + the whole string). ``os.path.basename`` handles both separators.""" + from mempalace.palace import get_collection, get_closets_collection + + palace = str(tmp_path / "palace") + # Use a path whose basename differs between '/' split and + # os.path.basename only on a platform-aware function, but verify + # at minimum that IDs encode just the filename, not the full path. + source = "/deep/nested/project/dir/mydoc.md" + drawers = get_collection(palace, create=True) + drawers.upsert( + ids=["d1"], + documents=["body"], + metadatas=[{"wing": "w", "room": "r", "source_file": source, "entities": ""}], + ) + + cfg = LLMConfig(endpoint="http://local/v1", model="m") + + def fake_urlopen(req, timeout=None): + return _FakeResp( + { + "choices": [ + {"message": {"content": '{"topics":["t1"],"quotes":[],"summary":""}'}} + ], + "usage": {"prompt_tokens": 1, "completion_tokens": 1}, + } + ) + + with patch("urllib.request.urlopen", side_effect=fake_urlopen): + regenerate_closets(palace, cfg=cfg) + + closets = get_closets_collection(palace) + ids = closets.get(where={"source_file": source}).get("ids", []) + assert ids + # IDs must not leak the full path (would happen if we used + # source.split('/')[-1] on Windows, or forgot to strip entirely). 
+ for cid in ids: + assert "/" not in cid + assert "mydoc.md" in cid diff --git a/tests/test_closets.py b/tests/test_closets.py index 2946bae..976086d 100644 --- a/tests/test_closets.py +++ b/tests/test_closets.py @@ -1,35 +1,154 @@ """ -test_closets.py — Tests for the closet (searchable index) layer. +test_closets.py — Tests for the closet (searchable index) layer and the +features that ride on top of it: mine_lock serialization, entity metadata, +hybrid BM25+vector search, and diary ingest. -Covers: - * build_closet_lines — pointer-line shape, entity extraction, stoplist, - quote/header pickup, and the "always emit one line" guarantee. - * upsert_closet_lines — pure overwrite (no append), char-limit packing, - atomic-line guarantee. - * purge_file_closets — wipes prior closets so a re-mine starts clean. - * The end-to-end rebuild: re-mining a file fully replaces its closets, - including when the prior run produced more numbered closets. - * search_memories closet-first path — returns chunk-level hits parsed - from `→drawer_ids` pointers, falls back when closets are empty, - respects max_distance. +Coverage map: + * mine_lock — acquire/release, blocks concurrent acquisition. + * build_closet_lines — pointer-line shape, header pickup, entity stoplist + (regression for "When/After/The"), real-name survival, fallback line. + * upsert_closet_lines — pure overwrite (regression for the append bug), + char-limit packing without splitting a line. + * purge_file_closets — scoped to source_file. + * Project-miner end-to-end rebuild — re-mining with fewer topics fully + purges leftover numbered closets from a larger prior run. + * _extract_drawer_ids_from_closet — pointer parsing + dedup. + * search_memories hybrid path — drawer query always the floor, + closets boost matching source_file, matched_via reflects both signals, + no whole-file glue, max_distance enforcement. + * Entity metadata — extracted, stoplist applied, registry cached by mtime. 
+ * Real BM25 — real IDF over candidate corpus, hybrid rerank. + * Diary ingest — drawers + closets created, incremental skips, state + file lives outside the diary dir, wing-prefixed drawer IDs prevent + cross-diary collisions, force=True purges leftover closets. """ -from mempalace.miner import mine +import json +import multiprocessing +import os +import tempfile +import threading +import time + +import yaml + +from mempalace.miner import ( + _extract_entities_for_metadata, + _load_known_entities, + mine, +) from mempalace.palace import ( CLOSET_CHAR_LIMIT, build_closet_lines, get_closets_collection, + get_collection, + mine_lock, purge_file_closets, upsert_closet_lines, ) -from mempalace.searcher import _extract_drawer_ids_from_closet, search_memories +from mempalace.palace_graph import ( + create_tunnel, + delete_tunnel, + follow_tunnels, + list_tunnels, +) +from mempalace.searcher import ( + _bm25_scores, + _expand_with_neighbors, + _extract_drawer_ids_from_closet, + _hybrid_rank, + search_memories, +) + + +# ── mine_lock ──────────────────────────────────────────────────────────── + + +def _lock_worker(target: str, name: str, hold_seconds: float, log_path: str) -> None: + """Worker for multiprocessing-spawn concurrency test. Writes its + critical-section enter/exit timestamps to ``log_path`` so the test + can verify the sections did not overlap in time.""" + import time as _time + + from mempalace.palace import mine_lock as _mine_lock + + with _mine_lock(target): + t_enter = _time.time() + _time.sleep(hold_seconds) + t_exit = _time.time() + # Append atomically so concurrent writers don't stomp each other. 
+ with open(log_path, "a") as f: + f.write(f"{name} {t_enter} {t_exit}\n") + f.flush() + + +class TestMineLock: + def test_lock_acquires_and_releases(self, tmp_path): + target = str(tmp_path / "lock_target.txt") + with mine_lock(target): + lock_dir = os.path.expanduser("~/.mempalace/locks") + assert os.path.isdir(lock_dir) + # Re-acquire after release should succeed instantly. + start = time.time() + with mine_lock(target): + pass + assert time.time() - start < 1.0 + + def test_lock_blocks_concurrent_access(self, tmp_path): + """The lock's contract is inter-*process* (multi-agent), not + inter-thread. Use multiprocessing so the test reflects the real + use case and is portable: on macOS/BSD ``fcntl.flock`` is + per-process, so two threads would both acquire — a thread-based + test would flake there even when the lock is correct. + + Verify mutual exclusion by the effect the critical section + actually has — each worker records its enter/exit timestamps + under the lock, and the test asserts the two intervals do not + overlap. This is robust to spawn-overhead timing, unlike + "second worker waited at least N seconds" which flakes when CI + spawn latency eats into the hold window. + """ + target = str(tmp_path / "concurrent_lock.txt") + log_path = str(tmp_path / "critical_section.log") + # Spawn so the same code path runs on every OS (macOS 3.8+ and + # Windows already default to spawn; Linux is fork by default). + ctx = multiprocessing.get_context("spawn") + + # Each worker holds the lock for HOLD seconds. With real mutual + # exclusion, the two [enter, exit] intervals must be disjoint. 
+ HOLD = 0.3 + p1 = ctx.Process(target=_lock_worker, args=(target, "a", HOLD, log_path)) + p2 = ctx.Process(target=_lock_worker, args=(target, "b", HOLD, log_path)) + p1.start() + p2.start() + p1.join(timeout=30) + p2.join(timeout=30) + + assert p1.exitcode == 0, f"p1 exited non-zero: {p1.exitcode}" + assert p2.exitcode == 0, f"p2 exited non-zero: {p2.exitcode}" + + # Parse the log: " ". + intervals = [] + with open(log_path) as f: + for line in f: + parts = line.strip().split() + if len(parts) == 3: + intervals.append((parts[0], float(parts[1]), float(parts[2]))) + assert len(intervals) == 2, f"expected two critical sections, got {intervals}" + + # Sort by entry time and verify the second entry is after the first exit. + intervals.sort(key=lambda iv: iv[1]) + (_, enter_a, exit_a), (_, enter_b, exit_b) = intervals + assert ( + enter_a < exit_a <= enter_b < exit_b + ), f"critical sections overlapped — lock failed to serialize: {intervals}" # ── build_closet_lines ───────────────────────────────────────────────── class TestBuildClosetLines: - def test_emits_pointer_line_shape(self, tmp_path): + def test_emits_pointer_line_shape(self): content = ( "# Auth rewrite\n\n" "Decided we need to migrate to passkeys. " @@ -59,7 +178,7 @@ class TestBuildClosetLines: def test_entity_stoplist_filters_sentence_starters(self): # "When", "After", "The" repeat 3+ times — old code would index them - # as entities. New code's stoplist drops them. + # as entities. Stoplist drops them. content = ( "When the pipeline ran, the result was good. " "When the user logged in, the token was issued. " @@ -68,7 +187,6 @@ class TestBuildClosetLines: "The new flow is stable. The audit cleared." 
) lines = build_closet_lines("/x.md", ["d1"], content, "w", "r") - # Entities sit between the two pipes entity_segments = [line.split("|")[1] for line in lines] for seg in entity_segments: tokens = set(seg.split(";")) if seg else set() @@ -83,13 +201,11 @@ class TestBuildClosetLines: "Igor and Milla shipped together." ) lines = build_closet_lines("/x.md", ["d1"], content, "w", "r") - entity_segments = [line.split("|")[1] for line in lines] - joined_entities = ";".join(entity_segments) + joined_entities = ";".join(line.split("|")[1] for line in lines) assert "Igor" in joined_entities assert "Milla" in joined_entities def test_emits_fallback_line_when_nothing_extractable(self): - # No headers, no action verbs, no quotes, no repeated capitalized words content = "lorem ipsum dolor sit amet consectetur adipiscing elit" lines = build_closet_lines("/x/notes.txt", ["d1"], content, "wing", "room") assert len(lines) == 1 @@ -111,11 +227,9 @@ class TestUpsertClosetLines: base = "closet_test_room_abc" meta = {"wing": "test", "room": "room", "source_file": "/x.md"} - # First mine — three short lines. upsert_closet_lines(col, base, ["alpha|;|→d1", "beta|;|→d2", "gamma|;|→d3"], meta) first = col.get(ids=[f"{base}_01"]) assert "alpha" in first["documents"][0] - assert "beta" in first["documents"][0] # Second mine — entirely different lines. Must replace, not append. upsert_closet_lines(col, base, ["delta|;|→d4", "epsilon|;|→d5"], meta) @@ -131,18 +245,15 @@ class TestUpsertClosetLines: base = "closet_pack_room_def" meta = {"wing": "test", "room": "room", "source_file": "/y.md"} - # Build lines that approach but never exceed the limit. 
line = "x" * 600 # well under CLOSET_CHAR_LIMIT n_written = upsert_closet_lines(col, base, [line, line, line, line], meta) - # 4 lines @ 600+1 chars = 2404 — should pack into 2 closets (≤1500 each) + # 4 lines @ 601 chars each = 2404 — should pack into 2 closets assert n_written == 2 for i in range(1, n_written + 1): doc = col.get(ids=[f"{base}_{i:02d}"])["documents"][0] - # Every line is intact (never split mid-line) for chunk in doc.split("\n"): assert len(chunk) == 600, f"line was truncated in closet {i}" - # Closet stays under the cap assert len(doc) <= CLOSET_CHAR_LIMIT @@ -161,19 +272,16 @@ class TestPurgeFileClosets: ], ) purge_file_closets(col, "/drop.md") - remaining_ids = set(col.get()["ids"]) assert "closet_a_01" in remaining_ids assert "closet_b_01" not in remaining_ids -# ── End-to-end rebuild via the project miner ────────────────────────── +# ── project miner: closet rebuild end-to-end ────────────────────────── class TestMinerClosetRebuild: def test_remine_replaces_closets_completely(self, tmp_path): - import yaml - project = tmp_path / "proj" project.mkdir() (project / "mempalace.yaml").write_text( @@ -193,16 +301,11 @@ class TestMinerClosetRebuild: first_ids = set(first_pass["ids"]) assert any("topic 0" in (d or "").lower() for d in first_pass["documents"]) - # Touch mtime so file_already_mined doesn't short-circuit, and - # rewrite with fewer topics (so the rebuild produces fewer closets - # than the first run). - import os - import time - + # Touch mtime + shrink content so the rebuild produces fewer closets. 
target.write_text("# Only Topic Now\n" + ("short body " * 5)) new_mtime = os.path.getmtime(target) + 60 os.utime(target, (new_mtime, new_mtime)) - time.sleep(0.01) # ensure mtime delta is visible + time.sleep(0.01) mine(str(project), str(palace), wing_override="proj", agent="test") @@ -231,19 +334,11 @@ class TestExtractDrawerIds: def test_parses_multiple_pointers_per_line(self): line = "topic|ent|→drawer_a,drawer_b,drawer_c" - assert _extract_drawer_ids_from_closet(line) == [ - "drawer_a", - "drawer_b", - "drawer_c", - ] + assert _extract_drawer_ids_from_closet(line) == ["drawer_a", "drawer_b", "drawer_c"] def test_dedupes_across_lines(self): doc = "one|;|→drawer_a,drawer_b\ntwo|;|→drawer_b,drawer_c" - assert _extract_drawer_ids_from_closet(doc) == [ - "drawer_a", - "drawer_b", - "drawer_c", - ] + assert _extract_drawer_ids_from_closet(doc) == ["drawer_a", "drawer_b", "drawer_c"] def test_empty_doc_returns_empty(self): assert _extract_drawer_ids_from_closet("") == [] @@ -253,64 +348,627 @@ class TestExtractDrawerIds: # ── search_memories closet-first path ──────────────────────────────── -class TestSearchMemoriesClosetFirst: - def test_falls_back_to_direct_when_no_closets(self, palace_path, seeded_collection): - # seeded_collection populates only mempalace_drawers, not closets. 
+class TestSearchMemoriesHybrid: + def test_pure_drawer_when_no_closets(self, palace_path, seeded_collection): + """Palaces without closets return results via direct drawer search — + every hit must advertise that the closet signal was absent.""" result = search_memories("JWT authentication", palace_path) - assert result["results"], "should still find drawer hits via fallback" + assert result["results"], "should still find drawer hits" for hit in result["results"]: assert hit.get("matched_via") == "drawer" + assert hit.get("closet_boost") == 0.0 + assert "closet_preview" not in hit - def test_closet_first_returns_chunk_level_hits(self, palace_path, seeded_collection): - # Build a closet that points at the JWT drawer specifically. + def test_closet_boost_marks_hit_as_drawer_plus_closet(self, palace_path, seeded_collection): + """When a closet agrees with direct search on source_file, the + matching drawer's ``matched_via`` switches to ``drawer+closet`` and + ``closet_preview`` exposes the hydrated index line.""" closets = get_closets_collection(palace_path) + # Seed the closet against the same source_file the drawer uses so + # the boost lookup keys align. closets.upsert( ids=["closet_proj_backend_aaa_01"], documents=["JWT auth tokens|;|→drawer_proj_backend_aaa"], - metadatas=[ - { - "wing": "project", - "room": "backend", - "source_file": "auth.py", - } - ], + metadatas=[{"wing": "project", "room": "backend", "source_file": "auth.py"}], ) result = search_memories("JWT authentication", palace_path) - assert result["results"], "closet-first search should hydrate the drawer" - top = result["results"][0] - assert top["matched_via"] == "closet" - # Must be the chunk-level drawer text, not a concatenation of every - # drawer in the file. + assert result["results"], "hybrid search should still return results" + # The JWT-bearing drawer should surface with closet agreement. 
+ boosted = [h for h in result["results"] if h["matched_via"] == "drawer+closet"] + assert boosted, "closet agreement should promote the matching source" + top = boosted[0] assert "JWT" in top["text"] - assert ( - "Database migrations" not in top["text"] - ), "closet path should not glue unrelated drawers together" - assert "closet_preview" in top + assert top["closet_boost"] > 0 assert "→drawer_proj_backend_aaa" in top["closet_preview"] - def test_max_distance_filters_closet_hits(self, palace_path, seeded_collection): + def test_max_distance_filters_hybrid_hits(self, palace_path, seeded_collection): closets = get_closets_collection(palace_path) closets.upsert( ids=["closet_proj_backend_aaa_01"], documents=["JWT auth tokens|;|→drawer_proj_backend_aaa"], - metadatas=[ - { - "wing": "project", - "room": "backend", - "source_file": "auth.py", - } - ], + metadatas=[{"wing": "project", "room": "backend", "source_file": "auth.py"}], ) - - # max_distance=0.001 is essentially "must match exactly". The closet - # path should reject everything and the caller falls back to direct - # search (which also filters with the same threshold). result = search_memories( "completely unrelated query about quantum gardening", palace_path, max_distance=0.001, ) - # Either no results, or every result respected the threshold. for hit in result["results"]: assert hit["distance"] <= 0.001 + + +# ── entity metadata ────────────────────────────────────────────────── + + +class TestEntityMetadata: + def test_extracts_capitalized_names(self): + text = "Ben reviewed the code. Ben approved it. Igor flagged two issues. Igor fixed them." + entities = _extract_entities_for_metadata(text) + assert "Ben" in entities + assert "Igor" in entities + + def test_empty_for_no_entities(self): + text = "this is all lowercase with no proper nouns at all" + assert _extract_entities_for_metadata(text) == "" + + def test_semicolon_separated(self): + text = "Alice and Bob met Charlie. Alice said hello. Bob agreed. 
Charlie laughed." + entities = _extract_entities_for_metadata(text) + assert ";" in entities + + def test_stoplist_filters_sentence_starters(self): + # Same regression as the closet entity test — "When/After/The" must + # not become entities just because they're capitalized 2+ times. + text = ( + "When the build broke, the team paged. " + "When the fix landed, the alarm cleared. " + "After the rollback, the queue drained. " + "After the deploy, the latency normalized." + ) + entities = _extract_entities_for_metadata(text) + tokens = set(entities.split(";")) if entities else set() + assert "When" not in tokens + assert "After" not in tokens + assert "The" not in tokens + + def test_capped_list_never_truncates_a_name(self): + # 30 distinct repeated proper nouns — extraction should cap the list + # before joining so a name never gets cut in half. + # Use morphologically distinct stems so the [A-Z][a-z]+ regex sees + # each as its own token. + names = [ + "Anna", + "Brian", + "Carol", + "David", + "Elena", + "Frank", + "Grace", + "Harold", + "Iris", + "Julian", + "Kira", + "Liam", + "Maya", + "Noah", + "Oscar", + "Penny", + "Quinn", + "Rosa", + "Sergei", + "Tara", + "Umar", + "Vera", + "Walter", + "Xander", + "Yvonne", + "Zachary", + "Amelia", + "Boris", + "Clara", + "Dmitri", + ] + text = " ".join(f"{n} met {n}." for n in names) + entities = _extract_entities_for_metadata(text) + extracted = [n for n in entities.split(";") if n] + assert extracted, "should have extracted some entities" + for name in extracted: + assert name in names, f"truncation produced a partial token: {name!r}" + + def test_known_registry_is_cached_by_mtime(self, monkeypatch, tmp_path): + # Point the registry at a temp file we control, exercise the cache. 
+ registry = tmp_path / "known_entities.json" + registry.write_text(json.dumps({"people": ["Zelda"]})) + from mempalace import miner + + monkeypatch.setattr(miner, "_ENTITY_REGISTRY_PATH", str(registry)) + miner._ENTITY_REGISTRY_CACHE["mtime"] = None + miner._ENTITY_REGISTRY_CACHE["names"] = frozenset() + + first = _load_known_entities() + assert "Zelda" in first + + # Second call without changing mtime: must reuse cache, not re-read. + read_count = {"n": 0} + original_open = open + + def counting_open(path, *a, **kw): + if str(path) == str(registry): + read_count["n"] += 1 + return original_open(path, *a, **kw) + + monkeypatch.setattr("builtins.open", counting_open) + _load_known_entities() + assert read_count["n"] == 0, "registry should not be re-read when mtime unchanged" + + # Bump mtime → cache must invalidate. + new_mtime = os.path.getmtime(registry) + 5 + os.utime(registry, (new_mtime, new_mtime)) + registry.write_text(json.dumps({"people": ["Zelda", "Link"]})) + os.utime(registry, (new_mtime, new_mtime)) + names = _load_known_entities() + assert "Link" in names + + +# ── BM25 hybrid search (real IDF over candidate corpus) ────────────── + + +class TestBM25: + def test_scores_positive_for_matching_doc(self): + scores = _bm25_scores( + "database migration", + ["We migrated the database to Postgres.", "unrelated cookery tips"], + ) + assert scores[0] > 0 + assert scores[1] == 0.0 + + def test_scores_zero_when_no_overlap(self): + scores = _bm25_scores("quantum physics", ["We built a web app in React"]) + assert scores == [0.0] + + def test_idf_downweights_terms_present_in_every_doc(self): + # "database" appears in every candidate → low IDF → low contribution. + # "vacuum" is unique to one → high IDF → that doc dominates. 
+ scores = _bm25_scores( + "database vacuum", + [ + "database backup nightly schedule", + "database vacuum scheduled weekly", + "database failover plan", + ], + ) + assert scores[1] == max(scores), "doc with the rare query term should win on IDF" + + def test_empty_inputs_return_zeros(self): + assert _bm25_scores("", ["hello world"]) == [0.0] + assert _bm25_scores("query here", []) == [] + assert _bm25_scores("query", [""]) == [0.0] + + def test_hybrid_rank_promotes_keyword_match(self): + results = [ + {"text": "database schema design for Postgres", "distance": 0.5}, + {"text": "unrelated topic about cooking", "distance": 0.3}, + ] + ranked = _hybrid_rank(results, "database Postgres schema") + # The keyword-rich result outranks the closer-vector but irrelevant one. + assert "database" in ranked[0]["text"] + # bm25_score field is exposed for debugging. + assert "bm25_score" in ranked[0] + # No internal scoring leak. + assert "_hybrid_score" not in ranked[0] + + def test_hybrid_rank_absolute_normalization(self): + # Adding a much-worse result to the candidate set must NOT reshuffle + # the top two — proves we're using absolute (1 - dist) and not + # dist / max_dist normalization. 
+ base = [ + {"text": "alpha alpha alpha", "distance": 0.1}, + {"text": "beta beta beta", "distance": 0.4}, + ] + ranked_short = _hybrid_rank([dict(r) for r in base], "alpha") + with_outlier = base + [{"text": "gamma gamma gamma", "distance": 1.9}] + ranked_long = _hybrid_rank([dict(r) for r in with_outlier], "alpha") + assert ranked_short[0]["text"] == ranked_long[0]["text"] + assert ranked_short[1]["text"] == ranked_long[1]["text"] + + +# ── diary ingest ───────────────────────────────────────────────────── + + +class TestDiaryIngest: + def test_ingest_creates_drawers_and_closets(self, tmp_path): + diary_dir = tmp_path / "diaries" + diary_dir.mkdir() + (diary_dir / "2026-04-13.md").write_text( + "# 2026-04-13\n\n## 10:00 PDT — Test\n\nBuilt the auth system.\n" + ) + palace_dir = tmp_path / "palace" + + from mempalace.diary_ingest import ingest_diaries + + result = ingest_diaries(str(diary_dir), str(palace_dir), force=True) + assert result["days_updated"] >= 1 + assert get_collection(str(palace_dir)).count() >= 1 + + def test_ingest_skips_unchanged_on_second_run(self, tmp_path): + diary_dir = tmp_path / "diaries" + diary_dir.mkdir() + (diary_dir / "2026-04-13.md").write_text( + "# 2026-04-13\n\n## 10:00 — Test\n\nContent here that's long enough.\n" + ) + palace_dir = tmp_path / "palace" + + from mempalace.diary_ingest import ingest_diaries + + ingest_diaries(str(diary_dir), str(palace_dir), force=True) + result = ingest_diaries(str(diary_dir), str(palace_dir)) + assert result["days_updated"] == 0 + + def test_state_file_lives_outside_diary_dir(self, tmp_path): + # Regression: the original implementation wrote + # ``.diary_ingest_state.json`` *inside* the user's diary directory, + # polluting their content folder. State must live under + # ``~/.mempalace/state/`` instead. 
+ diary_dir = tmp_path / "diaries" + diary_dir.mkdir() + (diary_dir / "2026-04-13.md").write_text( + "# 2026-04-13\n\n## 10:00 — Test\n\nBody content here long enough.\n" + ) + palace_dir = tmp_path / "palace" + + from mempalace.diary_ingest import _state_file_for, ingest_diaries + + ingest_diaries(str(diary_dir), str(palace_dir), force=True) + + # No state file inside the user's diary dir. + for entry in diary_dir.iterdir(): + assert ( + "diary_ingest" not in entry.name + ), f"state file leaked into user diary dir: {entry}" + + # State file does exist under ~/.mempalace/state/. + state_path = _state_file_for(str(palace_dir), diary_dir.resolve()) + assert state_path.exists() + # Platform-neutral path check: compare parents rather than a hardcoded + # separator string that would fail on Windows (``\.mempalace\state\``). + assert state_path.parent.name == "state" + assert state_path.parent.parent.name == ".mempalace" + + def test_wing_prefixed_drawer_id_prevents_cross_diary_collision(self, tmp_path): + # Regression: the original implementation used + # ``drawer_diary_{date_str}`` regardless of wing — two diaries with + # the same date in different wings would clobber each other. + date_md = "# 2026-04-13\n\n## 10:00 — entry\n\nThis is the day's content.\n" + + # Two separate diary dirs, ingested into the same palace under + # different wings. Each must produce a distinct drawer. 
+ personal_dir = tmp_path / "personal" + personal_dir.mkdir() + (personal_dir / "2026-04-13.md").write_text(date_md + "Personal-only marker.\n") + + work_dir = tmp_path / "work" + work_dir.mkdir() + (work_dir / "2026-04-13.md").write_text(date_md + "Work-only marker.\n") + + palace_dir = tmp_path / "palace" + + from mempalace.diary_ingest import _diary_drawer_id, ingest_diaries + + ingest_diaries(str(personal_dir), str(palace_dir), wing="personal", force=True) + ingest_diaries(str(work_dir), str(palace_dir), wing="work", force=True) + + col = get_collection(str(palace_dir)) + personal_id = _diary_drawer_id("personal", "2026-04-13") + work_id = _diary_drawer_id("work", "2026-04-13") + assert personal_id != work_id + + personal = col.get(ids=[personal_id]) + work = col.get(ids=[work_id]) + assert personal["ids"] == [personal_id] + assert work["ids"] == [work_id] + assert "Personal-only marker." in personal["documents"][0] + assert "Work-only marker." in work["documents"][0] + + +# ── cross-wing tunnels ─────────────────────────────────────────────── + + +class TestTunnels: + """Tunnels are explicit cross-wing connections stored in + ``~/.mempalace/tunnels.json``. 
Each test points the module-level + ``_TUNNEL_FILE`` at a fresh tmp file so tests don't cross-contaminate + or touch the user's real tunnels.""" + + def setup_method(self): + import mempalace.palace_graph as pg + + self._orig = pg._TUNNEL_FILE + self._tmpdir = tempfile.mkdtemp() + pg._TUNNEL_FILE = os.path.join(self._tmpdir, "tunnels.json") + + def teardown_method(self): + import mempalace.palace_graph as pg + + pg._TUNNEL_FILE = self._orig + import shutil + + shutil.rmtree(self._tmpdir, ignore_errors=True) + + def test_create_tunnel(self): + t = create_tunnel("wing_api", "auth", "wing_db", "users", label="auth uses users table") + assert t["id"] + assert t["source"]["wing"] == "wing_api" + assert t["source"]["room"] == "auth" + assert t["target"]["wing"] == "wing_db" + assert t["target"]["room"] == "users" + assert t["label"] == "auth uses users table" + + def test_list_tunnels_with_and_without_filter(self): + create_tunnel("wing_a", "room1", "wing_b", "room2") + create_tunnel("wing_a", "room3", "wing_c", "room4") + assert len(list_tunnels()) == 2 + # Filtering by a wing that appears on either endpoint. + assert len(list_tunnels("wing_a")) == 2 + assert len(list_tunnels("wing_c")) == 1 + assert len(list_tunnels("wing_nonexistent")) == 0 + + def test_delete_tunnel(self): + t = create_tunnel("wing_x", "r1", "wing_y", "r2") + delete_tunnel(t["id"]) + assert list_tunnels() == [] + + def test_dedup_same_endpoints_updates_label(self): + create_tunnel("wing_a", "r1", "wing_b", "r2", label="first") + create_tunnel("wing_a", "r1", "wing_b", "r2", label="updated") + tunnels = list_tunnels() + assert len(tunnels) == 1 + assert tunnels[0]["label"] == "updated" + + def test_follow_tunnels_returns_connected_endpoints(self): + create_tunnel("wing_api", "auth", "wing_db", "users") + create_tunnel("wing_api", "auth", "wing_frontend", "login") + # Unrelated tunnel that must not surface. 
+ create_tunnel("wing_other", "notes", "wing_misc", "scratch") + + connections = follow_tunnels("wing_api", "auth") + assert len(connections) == 2 + wings = {c["connected_wing"] for c in connections} + assert wings == {"wing_db", "wing_frontend"} + + # ── regression: symmetry, durability, validation, concurrency ───── + + def test_tunnel_is_symmetric(self): + """Regression: tunnels are undirected. create(A, B) and create(B, A) + must resolve to the same canonical ID and dedupe into one record — + the second call updates the label instead of creating a dupe.""" + first = create_tunnel("wing_a", "r1", "wing_b", "r2", label="forward") + second = create_tunnel("wing_b", "r2", "wing_a", "r1", label="reversed") + assert first["id"] == second["id"] + assert len(list_tunnels()) == 1 + assert list_tunnels()[0]["label"] == "reversed" + + def test_follow_tunnels_works_from_either_endpoint(self): + """Symmetric: you can follow_tunnels from either end of the link.""" + create_tunnel("wing_api", "auth", "wing_db", "users", label="auth uses users") + from_source = follow_tunnels("wing_api", "auth") + from_target = follow_tunnels("wing_db", "users") + assert len(from_source) == 1 + assert len(from_target) == 1 + assert from_source[0]["connected_wing"] == "wing_db" + assert from_target[0]["connected_wing"] == "wing_api" + # Both surfaces should carry the same label. 
+ assert from_source[0]["label"] == "auth uses users" + assert from_target[0]["label"] == "auth uses users" + + def test_empty_endpoint_fields_rejected(self): + """Regression: create_tunnel must reject empty strings on any + endpoint field so the JSON store can't grow phantom tunnels.""" + import pytest + + for args in [ + ("", "r1", "wing", "r2"), + ("wing", "", "wing", "r2"), + ("wing", "r1", "", "r2"), + ("wing", "r1", "wing", ""), + (" ", "r1", "wing", "r2"), # whitespace-only also rejected + ]: + with pytest.raises(ValueError): + create_tunnel(*args) + + def test_corrupt_tunnel_file_does_not_lose_new_writes(self): + """A truncated/corrupt tunnels.json (crash mid-write on a system + without atomic rename) must not leak into subsequent reads — the + file should be treated as empty and a fresh create_tunnel should + persist cleanly.""" + import mempalace.palace_graph as pg + + # Simulate a crash that left a truncated file behind. + with open(pg._TUNNEL_FILE, "w") as f: + f.write("{not valid json") + + # Load should return [] rather than raising. + assert list_tunnels() == [] + + # A subsequent create must persist (atomic write replaces the corrupt file). + t = create_tunnel("wing_a", "r1", "wing_b", "r2") + assert list_tunnels() == [t] + + def test_atomic_write_leaves_no_stray_tmp_file(self): + """Regression: _save_tunnels uses write-then-os.replace. After a + successful create, there must be no leftover ``tunnels.json.tmp``.""" + import mempalace.palace_graph as pg + + create_tunnel("wing_a", "r1", "wing_b", "r2") + assert os.path.exists(pg._TUNNEL_FILE) + assert not os.path.exists(pg._TUNNEL_FILE + ".tmp") + + def test_concurrent_creates_preserve_all_tunnels(self): + """Regression: two concurrent create_tunnel calls must not clobber + each other. 
Without the mine_lock around load+save, the later + writer's snapshot would overwrite the earlier writer's tunnel.""" + barrier = threading.Barrier(5) + errors: list = [] + + def worker(i): + try: + barrier.wait(timeout=2) + create_tunnel(f"wing_{i}", "r", "wing_shared", "hub") + except Exception as e: + errors.append(e) + + threads = [threading.Thread(target=worker, args=(i,)) for i in range(5)] + for t in threads: + t.start() + for t in threads: + t.join() + + assert not errors, f"worker raised: {errors}" + tunnels = list_tunnels() + assert len(tunnels) == 5, ( + f"expected 5 concurrent tunnels, got {len(tunnels)} — " "write race dropped some" + ) + + def test_created_at_is_timezone_aware(self): + """Regression: created_at must be tz-aware UTC, not naive.""" + t = create_tunnel("wing_a", "r1", "wing_b", "r2") + # ISO format with tz offset contains '+' or 'Z'. + assert t["created_at"].endswith("+00:00") or t["created_at"].endswith("Z") + + +# ── drawer-grep neighbor expansion ──────────────────────────────────── +# +# When a closet hit lands on a drawer whose chunk boundary clips a thought +# (matched chunk says "here's a breakdown:" and the breakdown lives in the +# next chunk), the closet path now expands to ±1 neighbor chunks from the +# same source file. These tests pin that behavior end-to-end and at the +# helper level. 
+ + +class TestDrawerGrepExpansion: + def _seed_source_file(self, palace_path, source: str, n_chunks: int): + """Helper: put N sequential drawers for a single source file into + the palace and return the drawer IDs keyed by chunk_index.""" + col = get_collection(palace_path) + ids = [f"drawer_test_room_{source.replace('/', '_')}_{i:03d}" for i in range(n_chunks)] + docs = [f"chunk_{i} content about topic alpha" for i in range(n_chunks)] + metas = [ + { + "wing": "test", + "room": "room", + "source_file": source, + "chunk_index": i, + "filed_at": "2026-04-13T00:00:00", + } + for i in range(n_chunks) + ] + col.upsert(ids=ids, documents=docs, metadatas=metas) + return col, {i: ids[i] for i in range(n_chunks)} + + def test_expand_returns_matched_plus_neighbors(self, palace_path): + col, by_idx = self._seed_source_file(palace_path, "/proj/doc.md", n_chunks=5) + matched_meta = {"source_file": "/proj/doc.md", "chunk_index": 2} + matched_doc = "chunk_2 content about topic alpha" + + out = _expand_with_neighbors(col, matched_doc, matched_meta, radius=1) + assert out["drawer_index"] == 2 + assert out["total_drawers"] == 5 + # Expect chunks 1, 2, 3 joined in chunk_index order. + text = out["text"] + assert "chunk_1" in text + assert "chunk_2" in text + assert "chunk_3" in text + # No leakage of non-neighbors. + assert "chunk_0" not in text + assert "chunk_4" not in text + # Ordering preserved — chunk_1 before chunk_2 before chunk_3. 
+ assert text.index("chunk_1") < text.index("chunk_2") < text.index("chunk_3") + + def test_expand_at_start_of_file_only_has_next_neighbor(self, palace_path): + col, _ = self._seed_source_file(palace_path, "/proj/edge_start.md", n_chunks=3) + out = _expand_with_neighbors( + col, + "chunk_0 content", + {"source_file": "/proj/edge_start.md", "chunk_index": 0}, + ) + assert out["drawer_index"] == 0 + assert out["total_drawers"] == 3 + assert "chunk_0" in out["text"] + assert "chunk_1" in out["text"] + # No chunk_-1 could exist; the expansion must not invent one. + assert "chunk_-1" not in out["text"] + + def test_expand_at_end_of_file_only_has_prev_neighbor(self, palace_path): + col, _ = self._seed_source_file(palace_path, "/proj/edge_end.md", n_chunks=3) + out = _expand_with_neighbors( + col, + "chunk_2 content", + {"source_file": "/proj/edge_end.md", "chunk_index": 2}, + ) + assert out["drawer_index"] == 2 + assert out["total_drawers"] == 3 + assert "chunk_1" in out["text"] + assert "chunk_2" in out["text"] + # No chunk_3 exists. + assert "chunk_3" not in out["text"] + + def test_expand_single_drawer_file_returns_just_matched(self, palace_path): + col, _ = self._seed_source_file(palace_path, "/proj/lone.md", n_chunks=1) + out = _expand_with_neighbors( + col, + "chunk_0 content", + {"source_file": "/proj/lone.md", "chunk_index": 0}, + ) + assert out["drawer_index"] == 0 + assert out["total_drawers"] == 1 + assert out["text"] == "chunk_0 content about topic alpha" + + def test_expand_falls_back_when_metadata_missing(self, palace_path): + col = get_collection(palace_path) + # No source_file / chunk_index in meta — degrade gracefully. 
+ out = _expand_with_neighbors(col, "matched doc", {}) + assert out["text"] == "matched doc" + assert out["drawer_index"] is None + assert out["total_drawers"] is None + + def test_hybrid_search_enrichment_populates_drawer_index_and_total(self, palace_path): + """End-to-end: when a closet boosts a source with many drawers, the + enrichment step runs drawer-grep across all chunks of that source + and exposes drawer_index + total_drawers on the hit (so the client + knows which chunk was expanded around).""" + col = get_collection(palace_path) + source = "/proj/indexed.md" + # Seed 5 drawers for one source file. + for i in range(5): + col.upsert( + ids=[f"drawer_proj_backend_indexed_{i:03d}"], + documents=[f"chunk_{i} talks about JWT authentication flow"], + metadatas=[ + { + "wing": "project", + "room": "backend", + "source_file": source, + "chunk_index": i, + "filed_at": "2026-04-13T00:00:00", + } + ], + ) + # Closet pointing at chunk_2 for this source. + closets = get_closets_collection(palace_path) + closets.upsert( + ids=["closet_proj_backend_indexed_01"], + documents=["JWT auth|;|→drawer_proj_backend_indexed_002"], + metadatas=[{"wing": "project", "room": "backend", "source_file": source}], + ) + + result = search_memories("JWT authentication", palace_path) + assert result["results"] + # The hybrid path promotes the closet-agreeing source to drawer+closet. + boosted = [h for h in result["results"] if h["matched_via"] == "drawer+closet"] + assert boosted, "hybrid search should mark the closet-agreeing source" + top = boosted[0] + assert top["total_drawers"] == 5 + assert isinstance(top["drawer_index"], int) + # Enriched text must include the grep-best chunk plus one neighbor + # on each side (chunk boundary may clip). 
+ assert "chunk_" in top["text"] diff --git a/tests/test_fact_checker.py b/tests/test_fact_checker.py new file mode 100644 index 0000000..5b34a40 --- /dev/null +++ b/tests/test_fact_checker.py @@ -0,0 +1,288 @@ +""" +test_fact_checker.py — Regression + integration tests for fact_checker. + +Covers every detection path + the three bugs the original PR silently +hid behind ``except Exception: pass``: + + * ``kg.query()`` doesn't exist — code must use ``query_entity``. + * ``KnowledgeGraph(palace_path=...)`` is not a valid kwarg — code + must pass ``db_path``. + * O(n²) edit-distance over the full registry — must filter to names + actually mentioned in the text. + +Also pins the three feature contracts: + * similar_name — "Mila" vs "Milla" in a registry with both. + * relationship_mismatch — "Bob is Alice's brother" vs KG "husband". + * stale_fact — claim matches a triple whose valid_to is in the past. +""" + +from __future__ import annotations + +import json +from unittest.mock import MagicMock, patch + +import pytest + +from mempalace.fact_checker import ( + _check_entity_confusion, + _edit_distance, + _extract_claims, + _flatten_names, + check_text, +) +from mempalace.knowledge_graph import KnowledgeGraph + + +# ── claim extraction ───────────────────────────────────────────────── + + +class TestExtractClaims: + def test_parses_x_is_ys_z(self): + claims = _extract_claims("Bob is Alice's brother") + assert len(claims) == 1 + assert claims[0] == { + "subject": "Bob", + "predicate": "brother", + "object": "Alice", + "span": "Bob is Alice's brother", + } + + def test_parses_xs_z_is_y(self): + claims = _extract_claims("Alice's brother is Bob") + assert len(claims) == 1 + assert claims[0]["subject"] == "Bob" + assert claims[0]["predicate"] == "brother" + assert claims[0]["object"] == "Alice" + + def test_ignores_sentences_without_possessive_role(self): + assert _extract_claims("Bob drove to the store today") == [] + assert _extract_claims("Just some prose without 
relationships") == [] + + def test_multiple_claims_in_one_text(self): + claims = _extract_claims("Bob is Alice's brother. Carol is Dave's sister.") + subjects = {c["subject"] for c in claims} + assert subjects == {"Bob", "Carol"} + + +# ── entity confusion ───────────────────────────────────────────────── + + +class TestEntityConfusion: + def test_flags_near_name_when_only_one_mentioned(self): + registry = {"people": ["Milla", "Mila"]} + issues = _check_entity_confusion("I spoke with Mila today.", registry) + # "Mila" mentioned, "Milla" not — registry has both at edit-distance 1, + # flag the possible confusion. + assert len(issues) == 1 + assert issues[0]["type"] == "similar_name" + assert set(issues[0]["names"]) == {"Mila", "Milla"} + assert issues[0]["distance"] == 1 + + def test_no_false_positive_when_both_names_mentioned(self): + """Regression: a text discussing both Mila and Milla is fine — + the user clearly knows they're different. Don't nag.""" + registry = {"people": ["Milla", "Mila"]} + issues = _check_entity_confusion("Mila and Milla met for lunch.", registry) + assert issues == [] + + def test_no_issues_when_registry_empty(self): + assert _check_entity_confusion("Bob said hi", {}) == [] + assert _check_entity_confusion("Bob said hi", {"people": []}) == [] + + def test_no_issues_when_no_mentioned_names(self): + registry = {"people": ["Zelda", "Link", "Sheik"]} + assert _check_entity_confusion("nothing relevant here", registry) == [] + + def test_registry_dict_shape_is_supported(self): + # Some registries store {"people": {"Alice": {...meta}}}; we still + # need to surface the keys as candidate names. 
+ registry = {"people": {"Milla": {"role": "creator"}, "Mila": {}}} + issues = _check_entity_confusion("I messaged Mila yesterday", registry) + assert any("Milla" in (i["names"] or []) for i in issues) + + +class TestEditDistance: + def test_basic_distances(self): + assert _edit_distance("kitten", "sitting") == 3 + assert _edit_distance("mila", "milla") == 1 + assert _edit_distance("abc", "abc") == 0 + + def test_empty_strings(self): + assert _edit_distance("", "") == 0 + assert _edit_distance("abc", "") == 3 + assert _edit_distance("", "abc") == 3 + + def test_performance_bounded_by_mentioned_names(self): + """Regression: an earlier implementation did O(n²) pairwise + edit-distance over every registry entry on every check_text call. + With 500 names and zero mentions, the call must return in a blink + because no edit-distance comparison should even start.""" + import time + + # 500 synthetic names, none of which appear in the text. + registry = {"people": [f"Zelda{i:03d}" for i in range(500)]} + text = "completely irrelevant prose with no registered names at all" + + start = time.perf_counter() + issues = _check_entity_confusion(text, registry) + elapsed = time.perf_counter() - start + + assert issues == [] + # Even an unoptimized implementation should beat this by orders + # of magnitude once we've filtered to mentioned names (which is + # 0 here) — if it's still doing O(n²), we'll blow past.
+ assert elapsed < 0.2, f"entity confusion took {elapsed:.3f}s on empty mentions" + + +# ── _flatten_names helper ──────────────────────────────────────────── + + +class TestFlattenNames: + def test_handles_list_categories(self): + assert _flatten_names({"people": ["Ada", "Bob"]}) == {"Ada", "Bob"} + + def test_handles_dict_categories(self): + assert _flatten_names({"people": {"Ada": {}, "Bob": {}}}) == {"Ada", "Bob"} + + def test_skips_falsy_entries(self): + assert _flatten_names({"people": ["Ada", "", None, "Bob"]}) == {"Ada", "Bob"} + + +# ── KG integration (uses a real tmp SQLite palace) ─────────────────── + + +@pytest.fixture +def palace_with_kg(tmp_path): + """Palace directory with a real, initially empty KG; tests seed their own triples. + + The KG file lives at ``{palace}/knowledge_graph.sqlite3`` — same + convention used by the MCP server. Fact-checker must find it via + that path, not via a bogus ``palace_path`` kwarg. + """ + palace = tmp_path / "palace" + palace.mkdir() + db = str(palace / "knowledge_graph.sqlite3") + kg = KnowledgeGraph(db_path=db) + yield palace, kg + + +class TestKGContradictions: + def test_kg_init_uses_db_path_not_palace_path_kwarg(self): + """Regression: the original code passed ``palace_path=`` to a + constructor whose only kwarg is ``db_path``. That raised + TypeError — silently swallowed — and the KG path became dead + code. This test pins the correct call signature.""" + # Simply construct via the correct signature; raising means the + # KG constructor has changed in a way that fact_checker must too. + kg = KnowledgeGraph(db_path=":memory:") + # query_entity must exist (this is the method fact_checker calls). + assert callable(getattr(kg, "query_entity", None)) + # The API that fact_checker used to call does NOT exist.
+ assert not hasattr(kg, "query") + + def test_relationship_mismatch_detected(self, palace_with_kg): + """The feature's headline example: text says brother, KG says husband.""" + palace, kg = palace_with_kg + kg.add_triple("Bob", "husband_of", "Alice", valid_from="2020-01-01") + + issues = check_text("Bob is Alice's husband_of", str(palace)) + # Exact-predicate + same object → no mismatch. + assert all(i["type"] != "relationship_mismatch" for i in issues) + + issues = check_text("Bob is Alice's brother", str(palace)) + mismatches = [i for i in issues if i["type"] == "relationship_mismatch"] + assert mismatches, "should flag text/KG mismatch for same (subject, object)" + m = mismatches[0] + assert m["entity"] == "Bob" + assert m["claim"]["predicate"] == "brother" + assert m["kg_fact"]["predicate"] == "husband_of" + + def test_no_false_positive_when_kg_has_no_facts_about_subject(self, palace_with_kg): + palace, _ = palace_with_kg + # KG is empty → no mismatch should fire. + assert check_text("Bob is Alice's brother", str(palace)) == [] + + def test_stale_fact_detected(self, palace_with_kg): + palace, kg = palace_with_kg + # An old relationship that was superseded in 2023. Using a + # possessive-shape claim so the narrow claim-extraction regex + # actually reaches the stale-fact branch. 
+ kg.add_triple( + "Bob", + "brother", + "Alice", + valid_from="2010-01-01", + valid_to="2023-06-01", + ) + issues = check_text("Bob is Alice's brother", str(palace)) + stale = [i for i in issues if i["type"] == "stale_fact"] + assert stale, "should flag closed-window fact as stale" + assert stale[0]["entity"] == "Bob" + assert stale[0]["valid_to"].startswith("2023") + + def test_current_fact_same_triple_is_not_flagged(self, palace_with_kg): + palace, kg = palace_with_kg + kg.add_triple("Bob", "brother", "Alice", valid_from="2010-01-01") + issues = check_text("Bob is Alice's brother", str(palace)) + assert issues == [] + + def test_missing_palace_does_not_crash(self, tmp_path): + """Brand-new palace (no KG file yet) — check_text must return [] + rather than raising or hanging.""" + nonexistent = str(tmp_path / "never_created") + assert check_text("Bob is Alice's brother", nonexistent) == [] + + +# ── end-to-end check_text contract ─────────────────────────────────── + + +class TestCheckTextContract: + def test_empty_text_returns_empty_list(self, tmp_path): + assert check_text("", str(tmp_path / "palace")) == [] + + def test_registry_confusion_path_isolated_from_kg(self, tmp_path, monkeypatch): + """If the registry file is present but the KG is missing, the + similar-name path must still fire. Prior implementations had + such entangled state that one failure killed both paths.""" + # Bypass the real registry by pointing cache at a temp file. 
+ registry = tmp_path / "known_entities.json" + registry.write_text(json.dumps({"people": ["Milla", "Mila"]})) + from mempalace import miner + + monkeypatch.setattr(miner, "_ENTITY_REGISTRY_PATH", str(registry)) + miner._ENTITY_REGISTRY_CACHE.update({"mtime": None, "names": frozenset(), "raw": {}}) + + issues = check_text("Chatted with Mila.", str(tmp_path / "nonexistent_palace")) + assert any(i["type"] == "similar_name" for i in issues) + + +# ── CLI ────────────────────────────────────────────────────────────── + + +class TestCLI: + def test_exits_nonzero_when_issues_found(self, tmp_path, monkeypatch, capsys): + """The CLI exit code is how shell scripts / hooks know to act — + pin it explicitly.""" + registry = tmp_path / "known_entities.json" + registry.write_text(json.dumps({"people": ["Milla", "Mila"]})) + from mempalace import fact_checker, miner + + monkeypatch.setattr(miner, "_ENTITY_REGISTRY_PATH", str(registry)) + miner._ENTITY_REGISTRY_CACHE.update({"mtime": None, "names": frozenset(), "raw": {}}) + + # Simulate argv: "Mila said hi" + monkeypatch.setattr( + "sys.argv", + ["fact_checker", "Mila said hi", "--palace", str(tmp_path / "palace")], + ) + with pytest.raises(SystemExit) as excinfo: + # Re-exec the __main__ block via runpy. + import runpy + + runpy.run_module("mempalace.fact_checker", run_name="__main__") + # Issues found → exit code 1. + assert excinfo.value.code == 1 + out = capsys.readouterr().out + assert "similar_name" in out + # Silence unused import warning. + _ = (MagicMock, patch, fact_checker) diff --git a/tests/test_hybrid_search.py b/tests/test_hybrid_search.py new file mode 100644 index 0000000..793216a --- /dev/null +++ b/tests/test_hybrid_search.py @@ -0,0 +1,133 @@ +"""Tests for the hybrid closet+drawer retrieval in search_memories. + +The hybrid path queries drawers directly (the floor) AND closets, applying a +rank-based boost to drawers whose source_file appears in top closet hits. 
+This avoids the "weak-closets regression" where low-signal closets (from +regex extraction on narrative content) could hide drawers that direct +search would have found. +""" + +from mempalace.palace import ( + get_closets_collection, + get_collection, + upsert_closet_lines, +) +from mempalace.searcher import search_memories + + +def _seed_drawers(palace_path): + """Insert 4 short drawers with deterministic content.""" + col = get_collection(palace_path, create=True) + col.upsert( + ids=["D1", "D2", "D3", "D4"], + documents=[ + "We switched the auth service to use JWT tokens with a 24h expiry.", + "Database migration to PostgreSQL 15 completed last Tuesday.", + "The frontend team is debating whether to adopt TanStack Query.", + "Kafka consumer rebalance timeout set to 45 seconds after incident.", + ], + metadatas=[ + {"wing": "backend", "room": "auth", "source_file": "fixture_D1.md"}, + {"wing": "backend", "room": "db", "source_file": "fixture_D2.md"}, + {"wing": "frontend", "room": "state", "source_file": "fixture_D3.md"}, + {"wing": "backend", "room": "queue", "source_file": "fixture_D4.md"}, + ], + ) + + +def _seed_strong_closet_for(palace_path, drawer_id, source_file, topics): + """Insert a closet whose content strongly overlaps the query keywords.""" + col = get_closets_collection(palace_path) + lines = [f"{t}||→{drawer_id}" for t in topics] + upsert_closet_lines( + col, + closet_id_base=f"closet_{drawer_id}", + lines=lines, + metadata={ + "wing": "backend", + "room": "auth", + "source_file": source_file, + "generated_by": "test", + }, + ) + + +# ── core invariant: closets can only HELP, never HIDE ───────────────────── + + +class TestHybridInvariant: + def test_no_closets_degrades_to_direct_drawer_search(self, tmp_path): + palace = str(tmp_path / "palace") + _seed_drawers(palace) + # No closets created. 
+ result = search_memories("Kafka rebalance timeout", palace, n_results=3) + ids = [h["source_file"] for h in result["results"]] + assert ids, "should return results" + assert "fixture_D4.md" in ids, "direct drawer search alone should surface the Kafka drawer" + + def test_weak_closets_do_not_hide_direct_drawer_hits(self, tmp_path): + """A closet that points at a wrong drawer must NOT suppress the + drawer that direct search would have ranked first.""" + palace = str(tmp_path / "palace") + _seed_drawers(palace) + # Seed a misleading closet: it matches a generic phrase but points at D3. + _seed_strong_closet_for( + palace, + drawer_id="D3", + source_file="fixture_D3.md", + topics=["Kafka queue tuning", "consumer rebalance config"], + ) + result = search_memories("Kafka consumer rebalance timeout", palace, n_results=5) + ids = [h["source_file"] for h in result["results"]] + assert "fixture_D4.md" in ids, ( + "D4 must appear — direct drawer search alone would rank it first. " + "Closet pointing to D3 should only boost D3, never hide D4." 
+ ) + + def test_closet_boost_lifts_matching_drawer(self, tmp_path): + """When a closet agrees with direct search, the matching drawer + should be boosted to rank 1.""" + palace = str(tmp_path / "palace") + _seed_drawers(palace) + _seed_strong_closet_for( + palace, + drawer_id="D1", + source_file="fixture_D1.md", + topics=["JWT auth tokens", "session expiry", "authentication service"], + ) + result = search_memories("JWT auth tokens expiry", palace, n_results=3) + ids = [h["source_file"] for h in result["results"]] + assert ids[0] == "fixture_D1.md" + top = result["results"][0] + assert top["matched_via"] == "drawer+closet" + assert top["closet_boost"] > 0 + + +# ── closet_boost metadata ──────────────────────────────────────────────── + + +class TestClosetMetadata: + def test_closet_preview_exposed_when_boosted(self, tmp_path): + palace = str(tmp_path / "palace") + _seed_drawers(palace) + _seed_strong_closet_for( + palace, + drawer_id="D1", + source_file="fixture_D1.md", + topics=["JWT auth tokens", "24h expiry", "authentication"], + ) + result = search_memories("JWT authentication", palace, n_results=2) + top = result["results"][0] + assert top["source_file"] == "fixture_D1.md" + assert "closet_preview" in top + + def test_drawer_only_hits_have_no_closet_preview(self, tmp_path): + palace = str(tmp_path / "palace") + _seed_drawers(palace) + # No closets + result = search_memories("TanStack Query", palace, n_results=2) + assert result["results"] + for h in result["results"]: + assert h["matched_via"] == "drawer" + assert "closet_preview" not in h + assert h["closet_boost"] == 0.0