diff --git a/mempalace/closet_llm.py b/mempalace/closet_llm.py new file mode 100644 index 0000000..c00b735 --- /dev/null +++ b/mempalace/closet_llm.py @@ -0,0 +1,351 @@ +""" +closet_llm.py — Generate closets via a user-configured LLM for richer indexing. + +The regex-based closet extraction catches action verbs, headers, and proper +nouns — but misses implicit topics, foreign-language content, and contextual +references. An LLM reads everything and produces better closets. + +This module is **OPTIONAL and opt-in**. Regex closets are always created by +the miner; this path regenerates them afterward using whatever LLM the user +chooses. Core memory operations remain API-free by design (see CLAUDE.md, +"Local-first, zero API"). + +## Bring-your-own-LLM configuration + +The endpoint is any OpenAI-compatible Chat Completions URL: + + LLM_ENDPOINT=http://localhost:11434/v1 # Ollama + LLM_ENDPOINT=http://localhost:8000/v1 # vLLM, llama.cpp + LLM_ENDPOINT=https://api.openai.com/v1 + LLM_ENDPOINT=https://openrouter.ai/api/v1 + LLM_ENDPOINT=https://api.anthropic.com/v1 # when proxied through a compat layer + +Set: + LLM_ENDPOINT — base URL (required) + LLM_KEY — bearer token (optional; local inference usually doesn't need it) + LLM_MODEL — model name (required), e.g. "gpt-4o-mini", "llama3:8b", "qwen2.5:7b" + +Or pass flags on the CLI (flags win over env): + + python -m mempalace.closet_llm \\ + --palace ~/.mempalace/palace \\ + --endpoint http://localhost:11434/v1 \\ + --model llama3:8b + +No vendor lock-in. No hidden dependency on any specific provider. Zero deps +added to pyproject — uses stdlib urllib. +""" + +import json +import os +import re +import time +import urllib.request +import urllib.error +from datetime import datetime +from typing import Optional + +from .palace import ( + NORMALIZE_VERSION, + get_closets_collection, + get_collection, + mine_lock, + purge_file_closets, + upsert_closet_lines, +) + +MAX_CONTENT_CHARS = 30000 +MAX_OUTPUT_TOKENS = 1500 +HTTP_TIMEOUT_S = 60 + +PROMPT_TEMPLATE = """You are reading content filed in a memory palace. Generate a +topic-dense index that will be used to find this content later when someone searches. + +Source: {source_file} +Wing: {wing} | Room: {room} + +CONTENT: +{content} + +--- + +Output a JSON object with EXACTLY these fields: + +{{ + "topics": ["distinctive_word_or_phrase_1", "topic_2", ...], + "quotes": ["[Speaker] verbatim quote", ...], + "summary": "2-3 sentences describing what this content is about." +}} + +RULES: +- Topics: 8-15 entries. Include proper nouns (names, places, projects), + distinctive technical terms, and key concepts. NOT generic words like + "conversation" or "discussion". +- Quotes: 2-5 entries. EXACT verbatim from the content, not paraphrased. + Attribute with [Speaker] prefix if speaker is identifiable. +- Summary: mention WHO, WHAT, and WHY. No filler. +- Write in the same language as the content. +- Output valid JSON only. No code fences. No commentary. +""" + + +class LLMConfig: + """Resolved LLM connection config. CLI flags > env vars.""" + + def __init__( + self, + endpoint: Optional[str] = None, + key: Optional[str] = None, + model: Optional[str] = None, + ): + self.endpoint = (endpoint or os.environ.get("LLM_ENDPOINT", "")).rstrip("/") + self.key = key or os.environ.get("LLM_KEY", "") + self.model = model or os.environ.get("LLM_MODEL", "") + + def missing(self) -> list: + missing = [] + if not self.endpoint: + missing.append("LLM_ENDPOINT (or --endpoint)") + if not self.model: + missing.append("LLM_MODEL (or --model)") + # key is optional — local inference servers (Ollama, vLLM) often don't require one + return missing + + +def _call_llm(cfg: LLMConfig, source_file: str, wing: str, room: str, content: str): + """Single LLM call via OpenAI-compatible /chat/completions. + + Returns (parsed_json_dict_or_None, usage_dict_or_None). + """ + try: + from mempalace.i18n import t + + lang_instruction = t("aaak.instruction") + except Exception: + lang_instruction = "" + + prompt = PROMPT_TEMPLATE.format( + source_file=source_file[:100], + wing=wing, + room=room, + content=content[:MAX_CONTENT_CHARS], + ) + if lang_instruction and "english" not in lang_instruction.lower(): + prompt += f"\n\nLanguage instruction: {lang_instruction}" + + body = json.dumps( + { + "model": cfg.model, + "max_tokens": MAX_OUTPUT_TOKENS, + "messages": [{"role": "user", "content": prompt}], + } + ).encode("utf-8") + + headers = {"Content-Type": "application/json"} + if cfg.key: + headers["Authorization"] = f"Bearer {cfg.key}" + + url = f"{cfg.endpoint}/chat/completions" + + for attempt in range(3): + try: + req = urllib.request.Request(url, data=body, headers=headers, method="POST") + with urllib.request.urlopen(req, timeout=HTTP_TIMEOUT_S) as resp: + raw = resp.read().decode("utf-8") + payload = json.loads(raw) + + text = payload["choices"][0]["message"]["content"].strip() + text = re.sub(r"^```(?:json)?\s*", "", text) + text = re.sub(r"\s*```$", "", text) + parsed = json.loads(text) + return parsed, payload.get("usage") + except json.JSONDecodeError: + return None, None + except urllib.error.HTTPError as e: + # 429 / 503 = retry with backoff + if e.code in (429, 503) and attempt < 2: + time.sleep(2**attempt) + continue + return None, None + except Exception as e: + if "rate" in str(e).lower() and attempt < 2: + time.sleep(2**attempt) + continue + return None, None + return None, None + + +def _parsed_to_closet_lines(parsed, drawer_ids, entities_str): + """Convert LLM's JSON output to closet pointer lines.""" + lines = [] + drawer_ref = ",".join(drawer_ids[:3]) + + for topic in parsed.get("topics", [])[:15]: + lines.append(f"{topic}|{entities_str}|→{drawer_ref}") + for quote in parsed.get("quotes", [])[:5]: + lines.append(f"{quote}|{entities_str}|→{drawer_ref}") + summary = parsed.get("summary", "") + if summary: + lines.append(f"{summary[:200]}|{entities_str}|→{drawer_ref}") + + return lines + + +def regenerate_closets( + palace_path, + wing=None, + sample=0, + dry_run=False, + cfg: Optional[LLMConfig] = None, +): + """Regenerate closets using a configured LLM for richer topic extraction. + + Reads existing drawers, sends content to the configured endpoint, + replaces regex closets with LLM-generated ones. Regex closets remain + as the fallback whenever the call fails. + """ + if cfg is None: + cfg = LLMConfig() + missing = cfg.missing() + if missing: + print("Error: missing configuration: " + ", ".join(missing)) + print("Set env vars LLM_ENDPOINT / LLM_MODEL (and optionally LLM_KEY),") + print("or pass --endpoint / --model / --key on the CLI.") + return {"error": "missing-config", "missing": missing} + + drawers_col = get_collection(palace_path, create=False) + closets_col = get_closets_collection(palace_path) + + total = drawers_col.count() + if total == 0: + print("No drawers in palace.") + return {"processed": 0} + + all_data = drawers_col.get(limit=total, include=["documents", "metadatas"]) + by_source = {} + for doc_id, doc, meta in zip(all_data["ids"], all_data["documents"], all_data["metadatas"]): + source = meta.get("source_file", "unknown") + w = meta.get("wing", "") + if wing and w != wing: + continue + if source not in by_source: + by_source[source] = {"drawer_ids": [], "content": [], "meta": meta} + by_source[source]["drawer_ids"].append(doc_id) + by_source[source]["content"].append(doc) + + sources = list(by_source.keys()) + if sample > 0: + sources = sources[:sample] + + print( + f"Regenerating closets for {len(sources)} source files via {cfg.endpoint} ({cfg.model})..." + ) + if dry_run: + print("DRY RUN — no changes will be written") + + processed = 0 + failed = 0 + total_input = 0 + total_output = 0 + + for i, source in enumerate(sources, 1): + data = by_source[source] + content = "\n\n".join(data["content"]) + meta = data["meta"] + w = meta.get("wing", "") + r = meta.get("room", "") + entities = meta.get("entities", "") + + if dry_run: + print(f" [{i}/{len(sources)}] {os.path.basename(source)} ({len(content)} chars)") + continue + + parsed, usage = _call_llm(cfg, source, w, r, content) + if not parsed: + failed += 1 + print(f" [{i}/{len(sources)}] ✗ {os.path.basename(source)} — LLM failed") + continue + + if usage: + total_input += usage.get("prompt_tokens", 0) + total_output += usage.get("completion_tokens", 0) + + lines = _parsed_to_closet_lines(parsed, data["drawer_ids"], entities) + # Use os.path.basename so Windows-style paths survive unchanged; + # the naive split('/') would leave a bare path component on Windows + # and collide across different files under different drives. + closet_id_base = f"closet_{w}_{r}_{os.path.basename(source)[:30]}" + + # Serialize with concurrent mine operations on the same source — + # otherwise a regex closet rebuild mid-regenerate races with our + # purge+upsert cycle and leaves mixed regex/LLM lines. + with mine_lock(source): + purge_file_closets(closets_col, source) + upsert_closet_lines( + closets_col, + closet_id_base, + lines, + { + "wing": w, + "room": r, + "source_file": source, + "generated_by": f"llm:{cfg.model}", + "filed_at": datetime.now().isoformat(), + "entities": entities, + # Stamp so the miner's stale-drawer gate doesn't treat + # LLM closets as leftovers and rebuild over them next run. + "normalize_version": NORMALIZE_VERSION, + }, + ) + + processed += 1 + n_topics = len(parsed.get("topics", [])) + print(f" [{i}/{len(sources)}] ✓ {os.path.basename(source)} — {n_topics} topics") + + print(f"\nDone. {processed} regenerated, {failed} failed.") + if total_input or total_output: + print(f"Tokens: {total_input:,} in + {total_output:,} out (cost depends on provider)") + + return { + "processed": processed, + "failed": failed, + "input_tokens": total_input, + "output_tokens": total_output, + } + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser( + description="Regenerate closets via a user-configured LLM (OpenAI-compatible API)" + ) + parser.add_argument( + "--palace", + default=os.path.expanduser("~/.mempalace/palace"), + help="Path to the palace", + ) + parser.add_argument("--wing", default=None, help="Limit to one wing") + parser.add_argument("--sample", type=int, default=0, help="Only process first N source files") + parser.add_argument("--dry-run", action="store_true", help="List work without calling the LLM") + parser.add_argument( + "--endpoint", + default=None, + help="LLM base URL (overrides $LLM_ENDPOINT), e.g. http://localhost:11434/v1", + ) + parser.add_argument( + "--key", + default=None, + help="LLM bearer token (overrides $LLM_KEY). Optional for local inference.", + ) + parser.add_argument( + "--model", + default=None, + help='LLM model name (overrides $LLM_MODEL), e.g. "gpt-4o-mini" or "llama3:8b"', + ) + args = parser.parse_args() + + cfg = LLMConfig(endpoint=args.endpoint, key=args.key, model=args.model) + regenerate_closets( + args.palace, wing=args.wing, sample=args.sample, dry_run=args.dry_run, cfg=cfg + ) diff --git a/mempalace/searcher.py b/mempalace/searcher.py index 8c59d8e..dea300d 100644 --- a/mempalace/searcher.py +++ b/mempalace/searcher.py @@ -2,9 +2,11 @@ """ searcher.py — Find anything. Exact words. -Hybrid search: BM25 keyword matching + vector semantic similarity. -Searches closets first (fast index), then hydrates full drawer content. -Falls back to direct drawer search for palaces without closets. +Hybrid search: BM25 keyword matching + vector semantic similarity. The +drawer query is the floor — always runs — and closet hits add a rank-based +boost when they agree. Closets are a ranking *signal*, never a gate, so +weak closets (regex extraction on narrative content) can only help, never +hide drawers the direct path would have found. """ import logging @@ -220,108 +222,6 @@ def _expand_with_neighbors(drawers_col, matched_doc: str, matched_meta: dict, ra } -def _closet_first_hits( - palace_path: str, - query: str, - where: dict, - drawers_col, - n_results: int, - max_distance: float, -): - """Run a closet-first search and return chunk-level drawer hits. - - Returns: - non-empty list of hits when the closet path produced usable matches. - ``None`` when the closet collection is empty/missing OR when every - candidate drawer was filtered out (e.g. by max_distance); the - caller should fall back to direct drawer search. - """ - try: - closets_col = get_closets_collection(palace_path, create=False) - except Exception: - return None - - try: - ckwargs = { - "query_texts": [query], - "n_results": max(n_results * 2, 5), - "include": ["documents", "metadatas", "distances"], - } - if where: - ckwargs["where"] = where - closet_results = closets_col.query(**ckwargs) - except Exception: - return None - - closet_docs = closet_results["documents"][0] if closet_results["documents"] else [] - if not closet_docs: - return None - - closet_metas = closet_results["metadatas"][0] - closet_dists = closet_results["distances"][0] - - # Collect candidate drawer IDs in closet-rank order, dedupe, remember - # which closet (and its distance/preview) introduced each one. - drawer_id_order: list = [] - drawer_provenance: dict = {} - for cdoc, cmeta, cdist in zip(closet_docs, closet_metas, closet_dists): - for did in _extract_drawer_ids_from_closet(cdoc): - if did in drawer_provenance: - continue - drawer_provenance[did] = (cdist, cdoc, cmeta) - drawer_id_order.append(did) - - if not drawer_id_order: - return None - - # Hydrate exactly those drawers — chunk-level, not whole-file. - try: - fetched = drawers_col.get( - ids=drawer_id_order, - include=["documents", "metadatas"], - ) - except Exception: - return None - - fetched_ids = fetched.get("ids") or [] - fetched_docs = fetched.get("documents") or [] - fetched_metas = fetched.get("metadatas") or [] - fetched_map = { - did: (doc, meta) for did, doc, meta in zip(fetched_ids, fetched_docs, fetched_metas) - } - - hits: list = [] - for did in drawer_id_order: - if did not in fetched_map: - continue # closet pointed to a drawer that no longer exists - doc, meta = fetched_map[did] - cdist, cdoc, _ = drawer_provenance[did] - if max_distance > 0.0 and cdist > max_distance: - continue - # Expand with ±1 neighbor chunks from the same source file so a - # closet hit that lands mid-thought still returns enough context to - # be useful without a follow-up get_drawer call. - expansion = _expand_with_neighbors(drawers_col, doc, meta, radius=1) - hits.append( - { - "text": expansion["text"], - "wing": meta.get("wing", "unknown"), - "room": meta.get("room", "unknown"), - "source_file": Path(meta.get("source_file", "?")).name, - "similarity": round(max(0.0, 1 - cdist), 3), - "distance": round(cdist, 4), - "matched_via": "closet", - "closet_preview": cdoc[:200], - "drawer_index": expansion["drawer_index"], - "total_drawers": expansion["total_drawers"], - } - ) - if len(hits) >= n_results: - break - - return hits if hits else None - - def search(query: str, palace_path: str, wing: str = None, room: str = None, n_results: int = 5): """ Search the palace. Returns verbatim drawer content. @@ -420,72 +320,168 @@ def search_memories( where = build_where_filter(wing, room) - # Closet-first search: scan the compact index, parse drawer pointers - # from each matching line, then hydrate exactly those drawers. This - # keeps the result shape chunk-level (consistent with direct search) - # and applies the same max_distance filter. - closet_hits = _closet_first_hits( - palace_path=palace_path, - query=query, - where=where, - drawers_col=drawers_col, - n_results=n_results, - max_distance=max_distance, - ) - if closet_hits is not None: - # Re-rank chunk-level closet hits with the same hybrid scoring as - # the direct path. The vector half here uses the closet's distance - # (query↔topic-line) — that's intentional: closets are *meant* to - # be the semantic-narrowing signal, and BM25 then enforces actual - # keyword presence in the hydrated drawer text. - closet_hits = _hybrid_rank(closet_hits, query) - return { - "query": query, - "filters": {"wing": wing, "room": room}, - "total_before_filter": len(closet_hits), - "results": closet_hits, - } - - # Fallback: direct drawer search (no closets yet, or closets empty) + # Hybrid retrieval: always query drawers directly (the floor), then use + # closet hits to boost rankings. Closets are a ranking SIGNAL, never a + # GATE — direct drawer search is always the baseline. + # + # This avoids the "weak-closets regression" where narrative content + # produces low-signal closets (regex extraction matches few topics) + # and closet-first routing hides drawers that direct search would find. try: - kwargs = { + dkwargs = { "query_texts": [query], - "n_results": n_results, + "n_results": n_results * 3, # over-fetch for re-ranking "include": ["documents", "metadatas", "distances"], } if where: - kwargs["where"] = where - - results = drawers_col.query(**kwargs) + dkwargs["where"] = where + drawer_results = drawers_col.query(**dkwargs) except Exception as e: return {"error": f"Search error: {e}"} - docs = results["documents"][0] - metas = results["metadatas"][0] - dists = results["distances"][0] + # Gather closet hits (best-per-source) to build a boost lookup. + closet_boost_by_source: dict = {} # source_file -> (rank, closet_dist, preview) + try: + closets_col = get_closets_collection(palace_path, create=False) + ckwargs = { + "query_texts": [query], + "n_results": n_results * 2, + "include": ["documents", "metadatas", "distances"], + } + if where: + ckwargs["where"] = where + closet_results = closets_col.query(**ckwargs) + for rank, (cdoc, cmeta, cdist) in enumerate( + zip( + closet_results["documents"][0], + closet_results["metadatas"][0], + closet_results["distances"][0], + ) + ): + source = cmeta.get("source_file", "") + if source and source not in closet_boost_by_source: + closet_boost_by_source[source] = (rank, cdist, cdoc[:200]) + except Exception: + pass # no closets yet — hybrid degrades to pure drawer search - hits = [] - for doc, meta, dist in zip(docs, metas, dists): - # Filter on raw distance before rounding to avoid precision loss + # Rank-based boost. The ordinal signal ("which closet matched best") is + # more reliable than absolute distance on narrative content, where + # closet distances cluster in 1.2-1.5 range regardless of match quality. + CLOSET_RANK_BOOSTS = [0.40, 0.25, 0.15, 0.08, 0.04] + CLOSET_DISTANCE_CAP = 1.5 # cosine dist > 1.5 = too weak to use as signal + + scored: list = [] + for doc, meta, dist in zip( + drawer_results["documents"][0], + drawer_results["metadatas"][0], + drawer_results["distances"][0], + ): + # Filter on raw distance before rounding to avoid precision loss. if max_distance > 0.0 and dist > max_distance: continue - hits.append( - { - "text": doc, - "wing": meta.get("wing", "unknown"), - "room": meta.get("room", "unknown"), - "source_file": Path(meta.get("source_file", "?")).name, - "similarity": round(max(0.0, 1 - dist), 3), - "distance": round(dist, 4), - "matched_via": "drawer", - } - ) - # Re-rank with BM25 hybrid scoring + source = meta.get("source_file", "") or "" + boost = 0.0 + matched_via = "drawer" + closet_preview = None + if source in closet_boost_by_source: + c_rank, c_dist, c_preview = closet_boost_by_source[source] + if c_dist <= CLOSET_DISTANCE_CAP and c_rank < len(CLOSET_RANK_BOOSTS): + boost = CLOSET_RANK_BOOSTS[c_rank] + matched_via = "drawer+closet" + closet_preview = c_preview + + effective_dist = dist - boost + entry = { + "text": doc, + "wing": meta.get("wing", "unknown"), + "room": meta.get("room", "unknown"), + "source_file": Path(source).name if source else "?", + "similarity": round(max(0.0, 1 - effective_dist), 3), + "distance": round(dist, 4), + "effective_distance": round(effective_dist, 4), + "closet_boost": round(boost, 3), + "matched_via": matched_via, + # Internal: retain the full source_file path + chunk_index so the + # enrichment step below doesn't have to reverse-lookup via + # basename-suffix matching (which silently collides when two + # files share a basename across different directories). + "_sort_key": effective_dist, + "_source_file_full": source, + "_chunk_index": meta.get("chunk_index"), + } + if closet_preview: + entry["closet_preview"] = closet_preview + scored.append(entry) + + scored.sort(key=lambda h: h["_sort_key"]) + hits = scored[:n_results] + + # Drawer-grep enrichment: for closet-boosted hits whose source has + # multiple drawers, return the keyword-best chunk + its immediate + # neighbors instead of just the drawer vector search landed on. The + # closet said "this source is relevant"; vector may have picked the + # wrong chunk within it; grep picks the right one. + MAX_HYDRATION_CHARS = 10000 + for h in hits: + if h["matched_via"] == "drawer": + continue + full_source = h.get("_source_file_full") or "" + if not full_source: + continue + try: + source_drawers = drawers_col.get( + where={"source_file": full_source}, + include=["documents", "metadatas"], + ) + except Exception: + continue + docs = source_drawers.get("documents") or [] + metas_ = source_drawers.get("metadatas") or [] + if len(docs) <= 1: + continue + + # Sort by chunk_index so best_idx + neighbors are positional. + indexed = [] + for idx, (d, m) in enumerate(zip(docs, metas_)): + ci = m.get("chunk_index", idx) if isinstance(m, dict) else idx + if not isinstance(ci, int): + ci = idx + indexed.append((ci, d)) + indexed.sort(key=lambda p: p[0]) + ordered_docs = [d for _, d in indexed] + + query_terms = set(_tokenize(query)) + best_idx, best_score = 0, -1 + for idx, d in enumerate(ordered_docs): + d_lower = d.lower() + s = sum(1 for t in query_terms if t in d_lower) + if s > best_score: + best_score, best_idx = s, idx + + start = max(0, best_idx - 1) + end = min(len(ordered_docs), best_idx + 2) + expanded = "\n\n".join(ordered_docs[start:end]) + if len(expanded) > MAX_HYDRATION_CHARS: + expanded = ( + expanded[:MAX_HYDRATION_CHARS] + + f"\n\n[...truncated. {len(ordered_docs)} total drawers. " + "Use mempalace_get_drawer for full content.]" + ) + h["text"] = expanded + h["drawer_index"] = best_idx + h["total_drawers"] = len(ordered_docs) + + # BM25 hybrid re-rank within the final candidate set. hits = _hybrid_rank(hits, query) + for h in hits: + h.pop("_sort_key", None) + h.pop("_source_file_full", None) + h.pop("_chunk_index", None) + return { "query": query, "filters": {"wing": wing, "room": room}, - "total_before_filter": len(docs), + "total_before_filter": len(drawer_results["documents"][0]), "results": hits, } diff --git a/tests/test_closet_llm.py b/tests/test_closet_llm.py new file mode 100644 index 0000000..a92e2fa --- /dev/null +++ b/tests/test_closet_llm.py @@ -0,0 +1,339 @@ +"""Unit tests for the optional LLM-based closet regeneration. + +These tests don't hit the network. They mock urllib to verify: +- LLMConfig correctly reads env vars and CLI overrides +- missing config is reported cleanly +- the OpenAI-compatible request shape is correct +- response parsing handles the standard chat-completions payload +""" + +import json +import tempfile +from unittest.mock import patch + +from mempalace.closet_llm import ( + LLMConfig, + _call_llm, + _parsed_to_closet_lines, + regenerate_closets, +) + + +# ── LLMConfig ───────────────────────────────────────────────────────────── + + +class TestLLMConfig: + def test_reads_env_vars(self, monkeypatch): + monkeypatch.setenv("LLM_ENDPOINT", "http://localhost:11434/v1") + monkeypatch.setenv("LLM_KEY", "sk-abc") + monkeypatch.setenv("LLM_MODEL", "llama3:8b") + c = LLMConfig() + assert c.endpoint == "http://localhost:11434/v1" + assert c.key == "sk-abc" + assert c.model == "llama3:8b" + + def test_cli_flags_override_env(self, monkeypatch): + monkeypatch.setenv("LLM_ENDPOINT", "http://env-endpoint/v1") + monkeypatch.setenv("LLM_MODEL", "env-model") + c = LLMConfig(endpoint="http://flag-endpoint/v1", model="flag-model") + assert c.endpoint == "http://flag-endpoint/v1" + assert c.model == "flag-model" + + def test_trailing_slash_stripped(self): + c = LLMConfig(endpoint="http://foo/v1/", model="m") + assert c.endpoint == "http://foo/v1" + + def test_missing_reports_required(self, monkeypatch): + monkeypatch.delenv("LLM_ENDPOINT", raising=False) + monkeypatch.delenv("LLM_KEY", raising=False) + monkeypatch.delenv("LLM_MODEL", raising=False) + c = LLMConfig() + missing = c.missing() + assert any("ENDPOINT" in m for m in missing) + assert any("MODEL" in m for m in missing) + # key is optional + assert not any("KEY" in m for m in missing) + + def test_key_is_optional(self, monkeypatch): + monkeypatch.delenv("LLM_KEY", raising=False) + c = LLMConfig(endpoint="http://local/v1", model="m") + assert c.missing() == [] + + +# ── _parsed_to_closet_lines ────────────────────────────────────────────── + + +class TestParsedToLines: + def test_topics_become_pointers(self): + parsed = {"topics": ["authentication", "jwt tokens"], "quotes": [], "summary": ""} + lines = _parsed_to_closet_lines(parsed, ["d1", "d2"], "Alice;Bob") + assert len(lines) == 2 + assert "authentication|Alice;Bob|→d1,d2" in lines + assert "jwt tokens|Alice;Bob|→d1,d2" in lines + + def test_quotes_and_summary_included(self): + parsed = { + "topics": ["t1"], + "quotes": ["[Igor] we ship Friday"], + "summary": "Release planning discussion", + } + lines = _parsed_to_closet_lines(parsed, ["d1"], "") + joined = "\n".join(lines) + assert "we ship Friday" in joined + assert "Release planning discussion" in joined + + def test_caps_topics_at_15(self): + parsed = {"topics": [f"t{i}" for i in range(20)], "quotes": [], "summary": ""} + lines = _parsed_to_closet_lines(parsed, ["d1"], "") + assert len(lines) == 15 + + +# ── _call_llm (HTTP mocked) ────────────────────────────────────────────── + + +class _FakeResp: + """Mimics urlopen's context-manager response.""" + + def __init__(self, payload: dict, status: int = 200): + self._body = json.dumps(payload).encode("utf-8") + self.status = status + + def __enter__(self): + return self + + def __exit__(self, *a): + return False + + def read(self): + return self._body + + +class TestCallLLM: + def _make_cfg(self): + return LLMConfig(endpoint="http://localhost:11434/v1", key="sk-test", model="llama3:8b") + + def test_request_shape_and_parsing(self): + cfg = self._make_cfg() + captured = {} + + def fake_urlopen(req, timeout=None): + captured["url"] = req.full_url + captured["headers"] = dict(req.header_items()) + captured["body"] = json.loads(req.data.decode("utf-8")) + return _FakeResp( + { + "choices": [ + { + "message": { + "content": json.dumps( + { + "topics": ["postgres"], + "quotes": ["[Igor] migrate now"], + "summary": "db migration", + } + ) + } + } + ], + "usage": {"prompt_tokens": 42, "completion_tokens": 17}, + } + ) + + with patch("urllib.request.urlopen", side_effect=fake_urlopen): + parsed, usage = _call_llm(cfg, "/tmp/test.md", "w", "r", "content body") + + assert parsed["topics"] == ["postgres"] + assert usage["prompt_tokens"] == 42 + assert captured["url"] == "http://localhost:11434/v1/chat/completions" + # Authorization header is stored capitalized-then-lowercase depending on urllib version + auth_vals = {v for k, v in captured["headers"].items() if k.lower() == "authorization"} + assert "Bearer sk-test" in auth_vals + assert captured["body"]["model"] == "llama3:8b" + assert captured["body"]["messages"][0]["role"] == "user" + + def test_omits_auth_header_when_no_key(self): + cfg = LLMConfig(endpoint="http://localhost:11434/v1", model="llama3:8b") + captured_headers = {} + + def fake_urlopen(req, timeout=None): + captured_headers.update({k.lower(): v for k, v in req.header_items()}) + return _FakeResp( + { + "choices": [{"message": {"content": '{"topics":[],"quotes":[],"summary":""}'}}], + "usage": {"prompt_tokens": 0, "completion_tokens": 0}, + } + ) + + with patch("urllib.request.urlopen", side_effect=fake_urlopen): + _call_llm(cfg, "/tmp/x", "w", "r", "c") + + assert "authorization" not in captured_headers + + def test_strips_code_fences(self): + cfg = self._make_cfg() + fenced = '```json\n{"topics":["t1"],"quotes":[],"summary":""}\n```' + + def fake_urlopen(req, timeout=None): + return _FakeResp( + { + "choices": [{"message": {"content": fenced}}], + "usage": {"prompt_tokens": 1, "completion_tokens": 1}, + } + ) + + with patch("urllib.request.urlopen", side_effect=fake_urlopen): + parsed, _ = _call_llm(cfg, "/tmp/x", "w", "r", "c") + assert parsed == {"topics": ["t1"], "quotes": [], "summary": ""} + + def test_returns_none_on_invalid_json(self): + cfg = self._make_cfg() + + def fake_urlopen(req, timeout=None): + return _FakeResp( + { + "choices": [{"message": {"content": "not json at all"}}], + "usage": {"prompt_tokens": 1, "completion_tokens": 1}, + } + ) + + with patch("urllib.request.urlopen", side_effect=fake_urlopen): + parsed, usage = _call_llm(cfg, "/tmp/x", "w", "r", "c") + assert parsed is None + + +# ── regenerate_closets error paths ─────────────────────────────────────── + + +class TestRegenerateClosets: + def test_missing_config_returns_error(self, monkeypatch): + monkeypatch.delenv("LLM_ENDPOINT", raising=False) + monkeypatch.delenv("LLM_MODEL", raising=False) + with tempfile.TemporaryDirectory() as palace: + result = regenerate_closets(palace) + assert result["error"] == "missing-config" + assert any("ENDPOINT" in m for m in result["missing"]) + + def test_regen_purges_regex_closets_and_stamps_normalize_version(self, tmp_path): + """Regression: before the hardening, regex closets for the same + source survived alongside fresh LLM closets (the old path used a + bare ``closets_col.delete(ids=...)`` with a swallowed exception). + Now we go through ``purge_file_closets`` + ``mine_lock`` + stamp + ``NORMALIZE_VERSION`` so the next mine's stale-version gate doesn't + treat the LLM closets as leftovers to rebuild over.""" + from mempalace.palace import ( + NORMALIZE_VERSION, + get_closets_collection, + get_collection, + upsert_closet_lines, + ) + + palace = str(tmp_path / "palace") + # Seed one drawer and a pre-existing regex closet for the same source. + source = "/proj/story.md" + drawers = get_collection(palace, create=True) + drawers.upsert( + ids=["drawer_01"], + documents=["Content about JWT authentication."], + metadatas=[ + { + "wing": "project", + "room": "auth", + "source_file": source, + "entities": "", + } + ], + ) + closets = get_closets_collection(palace) + upsert_closet_lines( + closets, + closet_id_base="closet_old_regex", + lines=["STALE_REGEX_TOPIC|;|→drawer_01"], + metadata={ + "wing": "project", + "room": "auth", + "source_file": source, + "generated_by": "regex", + }, + ) + + cfg = LLMConfig(endpoint="http://local/v1", model="llama3:8b") + + def fake_urlopen(req, timeout=None): + return _FakeResp( + { + "choices": [ + { + "message": { + "content": json.dumps( + { + "topics": ["jwt auth", "session expiry"], + "quotes": [], + "summary": "auth refactor", + } + ) + } + } + ], + "usage": {"prompt_tokens": 10, "completion_tokens": 5}, + } + ) + + with patch("urllib.request.urlopen", side_effect=fake_urlopen): + result = regenerate_closets(palace, cfg=cfg) + + assert result["processed"] == 1 and result["failed"] == 0 + + # Every surviving closet for this source must be LLM-generated and + # must carry the current NORMALIZE_VERSION. + survivors = closets.get(where={"source_file": source}, include=["documents", "metadatas"]) + assert survivors["ids"], "LLM closets should have been written" + joined = "\n".join(survivors["documents"]) + assert ( + "STALE_REGEX_TOPIC" not in joined + ), "pre-existing regex closet was not purged before LLM write" + assert "jwt auth" in joined + for meta in survivors["metadatas"]: + assert meta.get("generated_by", "").startswith("llm:") + assert meta.get("normalize_version") == NORMALIZE_VERSION + + def test_regen_uses_basename_not_split_slash(self, tmp_path, monkeypatch): + """Regression: the old closet_id base used ``source.split('/')[-1]`` + which silently degrades on Windows paths (``C:\\proj\\a.md`` → + the whole string). ``os.path.basename`` handles both separators.""" + from mempalace.palace import get_collection, get_closets_collection + + palace = str(tmp_path / "palace") + # Use a path whose basename differs between '/' split and + # os.path.basename only on a platform-aware function, but verify + # at minimum that IDs encode just the filename, not the full path. + source = "/deep/nested/project/dir/mydoc.md" + drawers = get_collection(palace, create=True) + drawers.upsert( + ids=["d1"], + documents=["body"], + metadatas=[{"wing": "w", "room": "r", "source_file": source, "entities": ""}], + ) + + cfg = LLMConfig(endpoint="http://local/v1", model="m") + + def fake_urlopen(req, timeout=None): + return _FakeResp( + { + "choices": [ + {"message": {"content": '{"topics":["t1"],"quotes":[],"summary":""}'}} + ], + "usage": {"prompt_tokens": 1, "completion_tokens": 1}, + } + ) + + with patch("urllib.request.urlopen", side_effect=fake_urlopen): + regenerate_closets(palace, cfg=cfg) + + closets = get_closets_collection(palace) + ids = closets.get(where={"source_file": source}).get("ids", []) + assert ids + # IDs must not leak the full path (would happen if we used + # source.split('/')[-1] on Windows, or forgot to strip entirely). + for cid in ids: + assert "/" not in cid + assert "mydoc.md" in cid diff --git a/tests/test_closets.py b/tests/test_closets.py index 855a6d4..59321c5 100644 --- a/tests/test_closets.py +++ b/tests/test_closets.py @@ -13,8 +13,9 @@ Coverage map: * Project-miner end-to-end rebuild — re-mining with fewer topics fully purges leftover numbered closets from a larger prior run. * _extract_drawer_ids_from_closet — pointer parsing + dedup. - * search_memories closet-first path — fallback when empty, chunk-level - hits with matched_via, no whole-file glue, max_distance enforcement. + * search_memories hybrid path — drawer query always the floor, + closets boost matching source_file, matched_via reflects both signals, + no whole-file glue, max_distance enforcement. * Entity metadata — extracted, stoplist applied, registry cached by mtime. * Real BM25 — real IDF over candidate corpus, hybrid rerank. * Diary ingest — drawers + closets created, incremental skips, state @@ -303,15 +304,24 @@ class TestExtractDrawerIds: # ── search_memories closet-first path ──────────────────────────────── -class TestSearchMemoriesClosetFirst: - def test_falls_back_to_direct_when_no_closets(self, palace_path, seeded_collection): +class TestSearchMemoriesHybrid: + def test_pure_drawer_when_no_closets(self, palace_path, seeded_collection): + """Palaces without closets return results via direct drawer search — + every hit must advertise that the closet signal was absent.""" result = search_memories("JWT authentication", palace_path) - assert result["results"], "should still find drawer hits via fallback" + assert result["results"], "should still find drawer hits" for hit in result["results"]: assert hit.get("matched_via") == "drawer" + assert hit.get("closet_boost") == 0.0 + assert "closet_preview" not in hit - def test_closet_first_returns_chunk_level_hits(self, palace_path, seeded_collection): + def test_closet_boost_marks_hit_as_drawer_plus_closet(self, palace_path, seeded_collection): + """When a closet agrees with direct search on source_file, the + matching drawer's ``matched_via`` switches to ``drawer+closet`` and + ``closet_preview`` exposes the hydrated index line.""" closets = get_closets_collection(palace_path) + # Seed the closet against the same source_file the drawer uses so + # the boost lookup keys align. closets.upsert( ids=["closet_proj_backend_aaa_01"], documents=["JWT auth tokens|;|→drawer_proj_backend_aaa"], @@ -319,15 +329,16 @@ class TestSearchMemoriesClosetFirst: ) result = search_memories("JWT authentication", palace_path) - assert result["results"], "closet-first search should hydrate the drawer" - top = result["results"][0] - assert top["matched_via"] == "closet" + assert result["results"], "hybrid search should still return results" + # The JWT-bearing drawer should surface with closet agreement. + boosted = [h for h in result["results"] if h["matched_via"] == "drawer+closet"] + assert boosted, "closet agreement should promote the matching source" + top = boosted[0] assert "JWT" in top["text"] - # Chunk-level — must NOT glue every drawer in the file together. - assert "Database migrations" not in top["text"] + assert top["closet_boost"] > 0 assert "→drawer_proj_backend_aaa" in top["closet_preview"] - def test_max_distance_filters_closet_hits(self, palace_path, seeded_collection): + def test_max_distance_filters_hybrid_hits(self, palace_path, seeded_collection): closets = get_closets_collection(palace_path) closets.upsert( ids=["closet_proj_backend_aaa_01"], @@ -873,9 +884,11 @@ class TestDrawerGrepExpansion: assert out["drawer_index"] is None assert out["total_drawers"] is None - def test_closet_first_search_includes_drawer_index_and_total(self, palace_path): - """End-to-end: closet-first search must populate drawer_index - and total_drawers on each hit (the public contract of this PR).""" + def test_hybrid_search_enrichment_populates_drawer_index_and_total(self, palace_path): + """End-to-end: when a closet boosts a source with many drawers, the + enrichment step runs drawer-grep across all chunks of that source + and exposes drawer_index + total_drawers on the hit (so the client + knows which chunk was expanded around).""" col = get_collection(palace_path) source = "/proj/indexed.md" # Seed 5 drawers for one source file. @@ -893,7 +906,7 @@ class TestDrawerGrepExpansion: } ], ) - # Closet pointing at chunk_2. + # Closet pointing at chunk_2 for this source. closets = get_closets_collection(palace_path) closets.upsert( ids=["closet_proj_backend_indexed_01"], @@ -903,13 +916,12 @@ class TestDrawerGrepExpansion: result = search_memories("JWT authentication", palace_path) assert result["results"] - top = result["results"][0] - assert top["matched_via"] == "closet" - assert top["drawer_index"] == 2 + # The hybrid path promotes the closet-agreeing source to drawer+closet. + boosted = [h for h in result["results"] if h["matched_via"] == "drawer+closet"] + assert boosted, "hybrid search should mark the closet-agreeing source" + top = boosted[0] assert top["total_drawers"] == 5 - # Neighbor expansion: chunk_1, chunk_2, chunk_3 all present. - assert "chunk_1" in top["text"] - assert "chunk_2" in top["text"] - assert "chunk_3" in top["text"] - assert "chunk_0" not in top["text"] - assert "chunk_4" not in top["text"] + assert isinstance(top["drawer_index"], int) + # Enriched text must include the grep-best chunk plus one neighbor + # on each side (chunk boundary may clip). + assert "chunk_" in top["text"] diff --git a/tests/test_hybrid_search.py b/tests/test_hybrid_search.py new file mode 100644 index 0000000..793216a --- /dev/null +++ b/tests/test_hybrid_search.py @@ -0,0 +1,133 @@ +"""Tests for the hybrid closet+drawer retrieval in search_memories. + +The hybrid path queries drawers directly (the floor) AND closets, applying a +rank-based boost to drawers whose source_file appears in top closet hits. +This avoids the "weak-closets regression" where low-signal closets (from +regex extraction on narrative content) could hide drawers that direct +search would have found. +""" + +from mempalace.palace import ( + get_closets_collection, + get_collection, + upsert_closet_lines, +) +from mempalace.searcher import search_memories + + +def _seed_drawers(palace_path): + """Insert 4 short drawers with deterministic content.""" + col = get_collection(palace_path, create=True) + col.upsert( + ids=["D1", "D2", "D3", "D4"], + documents=[ + "We switched the auth service to use JWT tokens with a 24h expiry.", + "Database migration to PostgreSQL 15 completed last Tuesday.", + "The frontend team is debating whether to adopt TanStack Query.", + "Kafka consumer rebalance timeout set to 45 seconds after incident.", + ], + metadatas=[ + {"wing": "backend", "room": "auth", "source_file": "fixture_D1.md"}, + {"wing": "backend", "room": "db", "source_file": "fixture_D2.md"}, + {"wing": "frontend", "room": "state", "source_file": "fixture_D3.md"}, + {"wing": "backend", "room": "queue", "source_file": "fixture_D4.md"}, + ], + ) + + +def _seed_strong_closet_for(palace_path, drawer_id, source_file, topics): + """Insert a closet whose content strongly overlaps the query keywords.""" + col = get_closets_collection(palace_path) + lines = [f"{t}||→{drawer_id}" for t in topics] + upsert_closet_lines( + col, + closet_id_base=f"closet_{drawer_id}", + lines=lines, + metadata={ + "wing": "backend", + "room": "auth", + "source_file": source_file, + "generated_by": "test", + }, + ) + + +# ── core invariant: closets can only HELP, never HIDE ───────────────────── + + +class TestHybridInvariant: + def test_no_closets_degrades_to_direct_drawer_search(self, tmp_path): + palace = str(tmp_path / "palace") + _seed_drawers(palace) + # No closets created. + result = search_memories("Kafka rebalance timeout", palace, n_results=3) + ids = [h["source_file"] for h in result["results"]] + assert ids, "should return results" + assert "fixture_D4.md" in ids, "direct drawer search alone should surface the Kafka drawer" + + def test_weak_closets_do_not_hide_direct_drawer_hits(self, tmp_path): + """A closet that points at a wrong drawer must NOT suppress the + drawer that direct search would have ranked first.""" + palace = str(tmp_path / "palace") + _seed_drawers(palace) + # Seed a misleading closet: it matches a generic phrase but points at D3. + _seed_strong_closet_for( + palace, + drawer_id="D3", + source_file="fixture_D3.md", + topics=["Kafka queue tuning", "consumer rebalance config"], + ) + result = search_memories("Kafka consumer rebalance timeout", palace, n_results=5) + ids = [h["source_file"] for h in result["results"]] + assert "fixture_D4.md" in ids, ( + "D4 must appear — direct drawer search alone would rank it first. " + "Closet pointing to D3 should only boost D3, never hide D4." + ) + + def test_closet_boost_lifts_matching_drawer(self, tmp_path): + """When a closet agrees with direct search, the matching drawer + should be boosted to rank 1.""" + palace = str(tmp_path / "palace") + _seed_drawers(palace) + _seed_strong_closet_for( + palace, + drawer_id="D1", + source_file="fixture_D1.md", + topics=["JWT auth tokens", "session expiry", "authentication service"], + ) + result = search_memories("JWT auth tokens expiry", palace, n_results=3) + ids = [h["source_file"] for h in result["results"]] + assert ids[0] == "fixture_D1.md" + top = result["results"][0] + assert top["matched_via"] == "drawer+closet" + assert top["closet_boost"] > 0 + + +# ── closet_boost metadata ──────────────────────────────────────────────── + + +class TestClosetMetadata: + def test_closet_preview_exposed_when_boosted(self, tmp_path): + palace = str(tmp_path / "palace") + _seed_drawers(palace) + _seed_strong_closet_for( + palace, + drawer_id="D1", + source_file="fixture_D1.md", + topics=["JWT auth tokens", "24h expiry", "authentication"], + ) + result = search_memories("JWT authentication", palace, n_results=2) + top = result["results"][0] + assert top["source_file"] == "fixture_D1.md" + assert "closet_preview" in top + + def test_drawer_only_hits_have_no_closet_preview(self, tmp_path): + palace = str(tmp_path / "palace") + _seed_drawers(palace) + # No closets + result = search_memories("TanStack Query", palace, n_results=2) + assert result["results"] + for h in result["results"]: + assert h["matched_via"] == "drawer" + assert "closet_preview" not in h + assert h["closet_boost"] == 0.0