merge: pr/closet-llm-generic + harden LLM regen path for production
Brings in PR #793 (optional LLM-based closet regeneration via user-configured OpenAI-compatible endpoint) and PR #795 (hybrid closet+drawer search — closets boost, never gate). Stack: #784 → #788 → #789 → #790 → #791 → #792 → #793 (+ #795). Findings hardened on our side ───────────────────────────── 1) closet_llm.regenerate_closets didn't use the blessed palace helpers. Before: * manual closets_col.get(where=...) + .delete(ids=...) with a silent ``except Exception: pass`` around both — if the purge failed, pre-existing regex closets survived alongside fresh LLM closets, giving the searcher double hits for the same source. * ``source.split('/')[-1][:30]`` to build the closet_id — quietly wrong on Windows paths (``C:\\proj\\a.md`` has no ``/``, so the whole string ends up in the ID). * no mine_lock around purge+upsert — a concurrent regex rebuild of the same source could interleave with our purge and leave a mix of regex and LLM pointers. * no ``normalize_version`` stamp on the LLM closets — the miner's stale-version gate would treat them as leftovers from an older schema and rebuild over them on the next mine. After: routes through ``purge_file_closets`` + ``mine_lock`` + ``os.path.basename`` + ``NORMALIZE_VERSION`` stamp. Regression tests cover each. 2) searcher.search_memories was still closet-first. PR #795 merged into #793's head to fix the recall regression documented in that PR (R@1 0.25 on narrative content vs. 0.42 baseline). The hybrid design makes closets a ranking boost rather than a gate: drawers are always queried at the floor, and matching closet hits (rank 0-4 within CLOSET_DISTANCE_CAP=1.5) add a boost of 0.40/0.25/0.15/0.08/0.04 to the effective distance. Merged to take the incoming hybrid design, with two cleanups: * kept the ``_expand_with_neighbors`` / ``_extract_drawer_ids_from_closet`` helpers as separately-tested utilities (still imported by tests and future callers); * replaced the fragile ``source_file.endswith(basename)`` reverse- lookup in the enrichment step with internal ``_source_file_full`` / ``_chunk_index`` fields stripped before return, so enrichment doesn't silently pick the wrong path when two sources share a basename across directories; * drawer-grep enrichment now sorts by ``chunk_index`` before neighbor expansion, so ``best_idx ± 1`` corresponds to actual document order rather than whatever order Chroma returned. 3) Closet-first tests in test_closets.py (``TestSearchMemoriesClosetFirst``, end-to-end ``test_closet_first_search_includes_drawer_index_and_total``) pinned contracts that the hybrid path now violates (``matched_via`` went from ``"closet"`` to ``"drawer+closet"``). Rewrote them around the new invariant: direct drawers are always the floor, closet agreement flips the hit's matched_via and exposes closet_preview. Verification ──────────── * 805/805 pass under ``uv run pytest tests/ -v --ignore=tests/benchmarks`` (13 new tests from PR #793 + 5 from PR #795 + 2 new regressions for the closet_llm hardening + the rewritten hybrid assertions in test_closets.py). * CI-pinned ruff 0.4.x clean on ``mempalace/`` + ``tests/`` (check + format both pass). * No new deps — closet_llm.py still uses stdlib ``urllib.request`` per the PR's "zero new dependencies" promise. Co-Authored-By: MSL <232237854+milla-jovovich@users.noreply.github.com>
This commit is contained in:
@@ -0,0 +1,351 @@
|
||||
"""
|
||||
closet_llm.py — Generate closets via a user-configured LLM for richer indexing.
|
||||
|
||||
The regex-based closet extraction catches action verbs, headers, and proper
|
||||
nouns — but misses implicit topics, foreign-language content, and contextual
|
||||
references. An LLM reads everything and produces better closets.
|
||||
|
||||
This module is **OPTIONAL and opt-in**. Regex closets are always created by
|
||||
the miner; this path regenerates them afterward using whatever LLM the user
|
||||
chooses. Core memory operations remain API-free by design (see CLAUDE.md,
|
||||
"Local-first, zero API").
|
||||
|
||||
## Bring-your-own-LLM configuration
|
||||
|
||||
The endpoint is any OpenAI-compatible Chat Completions URL:
|
||||
|
||||
LLM_ENDPOINT=http://localhost:11434/v1 # Ollama
|
||||
LLM_ENDPOINT=http://localhost:8000/v1 # vLLM, llama.cpp
|
||||
LLM_ENDPOINT=https://api.openai.com/v1
|
||||
LLM_ENDPOINT=https://openrouter.ai/api/v1
|
||||
LLM_ENDPOINT=https://api.anthropic.com/v1 # when proxied through a compat layer
|
||||
|
||||
Set:
|
||||
LLM_ENDPOINT — base URL (required)
|
||||
LLM_KEY — bearer token (optional; local inference usually doesn't need it)
|
||||
LLM_MODEL — model name (required), e.g. "gpt-4o-mini", "llama3:8b", "qwen2.5:7b"
|
||||
|
||||
Or pass flags on the CLI (flags win over env):
|
||||
|
||||
python -m mempalace.closet_llm \\
|
||||
--palace ~/.mempalace/palace \\
|
||||
--endpoint http://localhost:11434/v1 \\
|
||||
--model llama3:8b
|
||||
|
||||
No vendor lock-in. No hidden dependency on any specific provider. Zero deps
|
||||
added to pyproject — uses stdlib urllib.
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import time
|
||||
import urllib.request
|
||||
import urllib.error
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
from .palace import (
|
||||
NORMALIZE_VERSION,
|
||||
get_closets_collection,
|
||||
get_collection,
|
||||
mine_lock,
|
||||
purge_file_closets,
|
||||
upsert_closet_lines,
|
||||
)
|
||||
|
||||
MAX_CONTENT_CHARS = 30000
|
||||
MAX_OUTPUT_TOKENS = 1500
|
||||
HTTP_TIMEOUT_S = 60
|
||||
|
||||
PROMPT_TEMPLATE = """You are reading content filed in a memory palace. Generate a
|
||||
topic-dense index that will be used to find this content later when someone searches.
|
||||
|
||||
Source: {source_file}
|
||||
Wing: {wing} | Room: {room}
|
||||
|
||||
CONTENT:
|
||||
{content}
|
||||
|
||||
---
|
||||
|
||||
Output a JSON object with EXACTLY these fields:
|
||||
|
||||
{{
|
||||
"topics": ["distinctive_word_or_phrase_1", "topic_2", ...],
|
||||
"quotes": ["[Speaker] verbatim quote", ...],
|
||||
"summary": "2-3 sentences describing what this content is about."
|
||||
}}
|
||||
|
||||
RULES:
|
||||
- Topics: 8-15 entries. Include proper nouns (names, places, projects),
|
||||
distinctive technical terms, and key concepts. NOT generic words like
|
||||
"conversation" or "discussion".
|
||||
- Quotes: 2-5 entries. EXACT verbatim from the content, not paraphrased.
|
||||
Attribute with [Speaker] prefix if speaker is identifiable.
|
||||
- Summary: mention WHO, WHAT, and WHY. No filler.
|
||||
- Write in the same language as the content.
|
||||
- Output valid JSON only. No code fences. No commentary.
|
||||
"""
|
||||
|
||||
|
||||
class LLMConfig:
|
||||
"""Resolved LLM connection config. CLI flags > env vars."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
endpoint: Optional[str] = None,
|
||||
key: Optional[str] = None,
|
||||
model: Optional[str] = None,
|
||||
):
|
||||
self.endpoint = (endpoint or os.environ.get("LLM_ENDPOINT", "")).rstrip("/")
|
||||
self.key = key or os.environ.get("LLM_KEY", "")
|
||||
self.model = model or os.environ.get("LLM_MODEL", "")
|
||||
|
||||
def missing(self) -> list:
|
||||
missing = []
|
||||
if not self.endpoint:
|
||||
missing.append("LLM_ENDPOINT (or --endpoint)")
|
||||
if not self.model:
|
||||
missing.append("LLM_MODEL (or --model)")
|
||||
# key is optional — local inference servers (Ollama, vLLM) often don't require one
|
||||
return missing
|
||||
|
||||
|
||||
def _call_llm(cfg: LLMConfig, source_file: str, wing: str, room: str, content: str):
|
||||
"""Single LLM call via OpenAI-compatible /chat/completions.
|
||||
|
||||
Returns (parsed_json_dict_or_None, usage_dict_or_None).
|
||||
"""
|
||||
try:
|
||||
from mempalace.i18n import t
|
||||
|
||||
lang_instruction = t("aaak.instruction")
|
||||
except Exception:
|
||||
lang_instruction = ""
|
||||
|
||||
prompt = PROMPT_TEMPLATE.format(
|
||||
source_file=source_file[:100],
|
||||
wing=wing,
|
||||
room=room,
|
||||
content=content[:MAX_CONTENT_CHARS],
|
||||
)
|
||||
if lang_instruction and "english" not in lang_instruction.lower():
|
||||
prompt += f"\n\nLanguage instruction: {lang_instruction}"
|
||||
|
||||
body = json.dumps(
|
||||
{
|
||||
"model": cfg.model,
|
||||
"max_tokens": MAX_OUTPUT_TOKENS,
|
||||
"messages": [{"role": "user", "content": prompt}],
|
||||
}
|
||||
).encode("utf-8")
|
||||
|
||||
headers = {"Content-Type": "application/json"}
|
||||
if cfg.key:
|
||||
headers["Authorization"] = f"Bearer {cfg.key}"
|
||||
|
||||
url = f"{cfg.endpoint}/chat/completions"
|
||||
|
||||
for attempt in range(3):
|
||||
try:
|
||||
req = urllib.request.Request(url, data=body, headers=headers, method="POST")
|
||||
with urllib.request.urlopen(req, timeout=HTTP_TIMEOUT_S) as resp:
|
||||
raw = resp.read().decode("utf-8")
|
||||
payload = json.loads(raw)
|
||||
|
||||
text = payload["choices"][0]["message"]["content"].strip()
|
||||
text = re.sub(r"^```(?:json)?\s*", "", text)
|
||||
text = re.sub(r"\s*```$", "", text)
|
||||
parsed = json.loads(text)
|
||||
return parsed, payload.get("usage")
|
||||
except json.JSONDecodeError:
|
||||
return None, None
|
||||
except urllib.error.HTTPError as e:
|
||||
# 429 / 503 = retry with backoff
|
||||
if e.code in (429, 503) and attempt < 2:
|
||||
time.sleep(2**attempt)
|
||||
continue
|
||||
return None, None
|
||||
except Exception as e:
|
||||
if "rate" in str(e).lower() and attempt < 2:
|
||||
time.sleep(2**attempt)
|
||||
continue
|
||||
return None, None
|
||||
return None, None
|
||||
|
||||
|
||||
def _parsed_to_closet_lines(parsed, drawer_ids, entities_str):
|
||||
"""Convert LLM's JSON output to closet pointer lines."""
|
||||
lines = []
|
||||
drawer_ref = ",".join(drawer_ids[:3])
|
||||
|
||||
for topic in parsed.get("topics", [])[:15]:
|
||||
lines.append(f"{topic}|{entities_str}|→{drawer_ref}")
|
||||
for quote in parsed.get("quotes", [])[:5]:
|
||||
lines.append(f"{quote}|{entities_str}|→{drawer_ref}")
|
||||
summary = parsed.get("summary", "")
|
||||
if summary:
|
||||
lines.append(f"{summary[:200]}|{entities_str}|→{drawer_ref}")
|
||||
|
||||
return lines
|
||||
|
||||
|
||||
def regenerate_closets(
|
||||
palace_path,
|
||||
wing=None,
|
||||
sample=0,
|
||||
dry_run=False,
|
||||
cfg: Optional[LLMConfig] = None,
|
||||
):
|
||||
"""Regenerate closets using a configured LLM for richer topic extraction.
|
||||
|
||||
Reads existing drawers, sends content to the configured endpoint,
|
||||
replaces regex closets with LLM-generated ones. Regex closets remain
|
||||
as the fallback whenever the call fails.
|
||||
"""
|
||||
if cfg is None:
|
||||
cfg = LLMConfig()
|
||||
missing = cfg.missing()
|
||||
if missing:
|
||||
print("Error: missing configuration: " + ", ".join(missing))
|
||||
print("Set env vars LLM_ENDPOINT / LLM_MODEL (and optionally LLM_KEY),")
|
||||
print("or pass --endpoint / --model / --key on the CLI.")
|
||||
return {"error": "missing-config", "missing": missing}
|
||||
|
||||
drawers_col = get_collection(palace_path, create=False)
|
||||
closets_col = get_closets_collection(palace_path)
|
||||
|
||||
total = drawers_col.count()
|
||||
if total == 0:
|
||||
print("No drawers in palace.")
|
||||
return {"processed": 0}
|
||||
|
||||
all_data = drawers_col.get(limit=total, include=["documents", "metadatas"])
|
||||
by_source = {}
|
||||
for doc_id, doc, meta in zip(all_data["ids"], all_data["documents"], all_data["metadatas"]):
|
||||
source = meta.get("source_file", "unknown")
|
||||
w = meta.get("wing", "")
|
||||
if wing and w != wing:
|
||||
continue
|
||||
if source not in by_source:
|
||||
by_source[source] = {"drawer_ids": [], "content": [], "meta": meta}
|
||||
by_source[source]["drawer_ids"].append(doc_id)
|
||||
by_source[source]["content"].append(doc)
|
||||
|
||||
sources = list(by_source.keys())
|
||||
if sample > 0:
|
||||
sources = sources[:sample]
|
||||
|
||||
print(
|
||||
f"Regenerating closets for {len(sources)} source files via {cfg.endpoint} ({cfg.model})..."
|
||||
)
|
||||
if dry_run:
|
||||
print("DRY RUN — no changes will be written")
|
||||
|
||||
processed = 0
|
||||
failed = 0
|
||||
total_input = 0
|
||||
total_output = 0
|
||||
|
||||
for i, source in enumerate(sources, 1):
|
||||
data = by_source[source]
|
||||
content = "\n\n".join(data["content"])
|
||||
meta = data["meta"]
|
||||
w = meta.get("wing", "")
|
||||
r = meta.get("room", "")
|
||||
entities = meta.get("entities", "")
|
||||
|
||||
if dry_run:
|
||||
print(f" [{i}/{len(sources)}] {os.path.basename(source)} ({len(content)} chars)")
|
||||
continue
|
||||
|
||||
parsed, usage = _call_llm(cfg, source, w, r, content)
|
||||
if not parsed:
|
||||
failed += 1
|
||||
print(f" [{i}/{len(sources)}] ✗ {os.path.basename(source)} — LLM failed")
|
||||
continue
|
||||
|
||||
if usage:
|
||||
total_input += usage.get("prompt_tokens", 0)
|
||||
total_output += usage.get("completion_tokens", 0)
|
||||
|
||||
lines = _parsed_to_closet_lines(parsed, data["drawer_ids"], entities)
|
||||
# Use os.path.basename so Windows-style paths survive unchanged;
|
||||
# the naive split('/') would leave a bare path component on Windows
|
||||
# and collide across different files under different drives.
|
||||
closet_id_base = f"closet_{w}_{r}_{os.path.basename(source)[:30]}"
|
||||
|
||||
# Serialize with concurrent mine operations on the same source —
|
||||
# otherwise a regex closet rebuild mid-regenerate races with our
|
||||
# purge+upsert cycle and leaves mixed regex/LLM lines.
|
||||
with mine_lock(source):
|
||||
purge_file_closets(closets_col, source)
|
||||
upsert_closet_lines(
|
||||
closets_col,
|
||||
closet_id_base,
|
||||
lines,
|
||||
{
|
||||
"wing": w,
|
||||
"room": r,
|
||||
"source_file": source,
|
||||
"generated_by": f"llm:{cfg.model}",
|
||||
"filed_at": datetime.now().isoformat(),
|
||||
"entities": entities,
|
||||
# Stamp so the miner's stale-drawer gate doesn't treat
|
||||
# LLM closets as leftovers and rebuild over them next run.
|
||||
"normalize_version": NORMALIZE_VERSION,
|
||||
},
|
||||
)
|
||||
|
||||
processed += 1
|
||||
n_topics = len(parsed.get("topics", []))
|
||||
print(f" [{i}/{len(sources)}] ✓ {os.path.basename(source)} — {n_topics} topics")
|
||||
|
||||
print(f"\nDone. {processed} regenerated, {failed} failed.")
|
||||
if total_input or total_output:
|
||||
print(f"Tokens: {total_input:,} in + {total_output:,} out (cost depends on provider)")
|
||||
|
||||
return {
|
||||
"processed": processed,
|
||||
"failed": failed,
|
||||
"input_tokens": total_input,
|
||||
"output_tokens": total_output,
|
||||
}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import argparse
|
||||
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Regenerate closets via a user-configured LLM (OpenAI-compatible API)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--palace",
|
||||
default=os.path.expanduser("~/.mempalace/palace"),
|
||||
help="Path to the palace",
|
||||
)
|
||||
parser.add_argument("--wing", default=None, help="Limit to one wing")
|
||||
parser.add_argument("--sample", type=int, default=0, help="Only process first N source files")
|
||||
parser.add_argument("--dry-run", action="store_true", help="List work without calling the LLM")
|
||||
parser.add_argument(
|
||||
"--endpoint",
|
||||
default=None,
|
||||
help="LLM base URL (overrides $LLM_ENDPOINT), e.g. http://localhost:11434/v1",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--key",
|
||||
default=None,
|
||||
help="LLM bearer token (overrides $LLM_KEY). Optional for local inference.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model",
|
||||
default=None,
|
||||
help='LLM model name (overrides $LLM_MODEL), e.g. "gpt-4o-mini" or "llama3:8b"',
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
cfg = LLMConfig(endpoint=args.endpoint, key=args.key, model=args.model)
|
||||
regenerate_closets(
|
||||
args.palace, wing=args.wing, sample=args.sample, dry_run=args.dry_run, cfg=cfg
|
||||
)
|
||||
+147
-151
@@ -2,9 +2,11 @@
|
||||
"""
|
||||
searcher.py — Find anything. Exact words.
|
||||
|
||||
Hybrid search: BM25 keyword matching + vector semantic similarity.
|
||||
Searches closets first (fast index), then hydrates full drawer content.
|
||||
Falls back to direct drawer search for palaces without closets.
|
||||
Hybrid search: BM25 keyword matching + vector semantic similarity. The
|
||||
drawer query is the floor — always runs — and closet hits add a rank-based
|
||||
boost when they agree. Closets are a ranking *signal*, never a gate, so
|
||||
weak closets (regex extraction on narrative content) can only help, never
|
||||
hide drawers the direct path would have found.
|
||||
"""
|
||||
|
||||
import logging
|
||||
@@ -220,108 +222,6 @@ def _expand_with_neighbors(drawers_col, matched_doc: str, matched_meta: dict, ra
|
||||
}
|
||||
|
||||
|
||||
def _closet_first_hits(
|
||||
palace_path: str,
|
||||
query: str,
|
||||
where: dict,
|
||||
drawers_col,
|
||||
n_results: int,
|
||||
max_distance: float,
|
||||
):
|
||||
"""Run a closet-first search and return chunk-level drawer hits.
|
||||
|
||||
Returns:
|
||||
non-empty list of hits when the closet path produced usable matches.
|
||||
``None`` when the closet collection is empty/missing OR when every
|
||||
candidate drawer was filtered out (e.g. by max_distance); the
|
||||
caller should fall back to direct drawer search.
|
||||
"""
|
||||
try:
|
||||
closets_col = get_closets_collection(palace_path, create=False)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
try:
|
||||
ckwargs = {
|
||||
"query_texts": [query],
|
||||
"n_results": max(n_results * 2, 5),
|
||||
"include": ["documents", "metadatas", "distances"],
|
||||
}
|
||||
if where:
|
||||
ckwargs["where"] = where
|
||||
closet_results = closets_col.query(**ckwargs)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
closet_docs = closet_results["documents"][0] if closet_results["documents"] else []
|
||||
if not closet_docs:
|
||||
return None
|
||||
|
||||
closet_metas = closet_results["metadatas"][0]
|
||||
closet_dists = closet_results["distances"][0]
|
||||
|
||||
# Collect candidate drawer IDs in closet-rank order, dedupe, remember
|
||||
# which closet (and its distance/preview) introduced each one.
|
||||
drawer_id_order: list = []
|
||||
drawer_provenance: dict = {}
|
||||
for cdoc, cmeta, cdist in zip(closet_docs, closet_metas, closet_dists):
|
||||
for did in _extract_drawer_ids_from_closet(cdoc):
|
||||
if did in drawer_provenance:
|
||||
continue
|
||||
drawer_provenance[did] = (cdist, cdoc, cmeta)
|
||||
drawer_id_order.append(did)
|
||||
|
||||
if not drawer_id_order:
|
||||
return None
|
||||
|
||||
# Hydrate exactly those drawers — chunk-level, not whole-file.
|
||||
try:
|
||||
fetched = drawers_col.get(
|
||||
ids=drawer_id_order,
|
||||
include=["documents", "metadatas"],
|
||||
)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
fetched_ids = fetched.get("ids") or []
|
||||
fetched_docs = fetched.get("documents") or []
|
||||
fetched_metas = fetched.get("metadatas") or []
|
||||
fetched_map = {
|
||||
did: (doc, meta) for did, doc, meta in zip(fetched_ids, fetched_docs, fetched_metas)
|
||||
}
|
||||
|
||||
hits: list = []
|
||||
for did in drawer_id_order:
|
||||
if did not in fetched_map:
|
||||
continue # closet pointed to a drawer that no longer exists
|
||||
doc, meta = fetched_map[did]
|
||||
cdist, cdoc, _ = drawer_provenance[did]
|
||||
if max_distance > 0.0 and cdist > max_distance:
|
||||
continue
|
||||
# Expand with ±1 neighbor chunks from the same source file so a
|
||||
# closet hit that lands mid-thought still returns enough context to
|
||||
# be useful without a follow-up get_drawer call.
|
||||
expansion = _expand_with_neighbors(drawers_col, doc, meta, radius=1)
|
||||
hits.append(
|
||||
{
|
||||
"text": expansion["text"],
|
||||
"wing": meta.get("wing", "unknown"),
|
||||
"room": meta.get("room", "unknown"),
|
||||
"source_file": Path(meta.get("source_file", "?")).name,
|
||||
"similarity": round(max(0.0, 1 - cdist), 3),
|
||||
"distance": round(cdist, 4),
|
||||
"matched_via": "closet",
|
||||
"closet_preview": cdoc[:200],
|
||||
"drawer_index": expansion["drawer_index"],
|
||||
"total_drawers": expansion["total_drawers"],
|
||||
}
|
||||
)
|
||||
if len(hits) >= n_results:
|
||||
break
|
||||
|
||||
return hits if hits else None
|
||||
|
||||
|
||||
def search(query: str, palace_path: str, wing: str = None, room: str = None, n_results: int = 5):
|
||||
"""
|
||||
Search the palace. Returns verbatim drawer content.
|
||||
@@ -420,72 +320,168 @@ def search_memories(
|
||||
|
||||
where = build_where_filter(wing, room)
|
||||
|
||||
# Closet-first search: scan the compact index, parse drawer pointers
|
||||
# from each matching line, then hydrate exactly those drawers. This
|
||||
# keeps the result shape chunk-level (consistent with direct search)
|
||||
# and applies the same max_distance filter.
|
||||
closet_hits = _closet_first_hits(
|
||||
palace_path=palace_path,
|
||||
query=query,
|
||||
where=where,
|
||||
drawers_col=drawers_col,
|
||||
n_results=n_results,
|
||||
max_distance=max_distance,
|
||||
)
|
||||
if closet_hits is not None:
|
||||
# Re-rank chunk-level closet hits with the same hybrid scoring as
|
||||
# the direct path. The vector half here uses the closet's distance
|
||||
# (query↔topic-line) — that's intentional: closets are *meant* to
|
||||
# be the semantic-narrowing signal, and BM25 then enforces actual
|
||||
# keyword presence in the hydrated drawer text.
|
||||
closet_hits = _hybrid_rank(closet_hits, query)
|
||||
return {
|
||||
"query": query,
|
||||
"filters": {"wing": wing, "room": room},
|
||||
"total_before_filter": len(closet_hits),
|
||||
"results": closet_hits,
|
||||
}
|
||||
|
||||
# Fallback: direct drawer search (no closets yet, or closets empty)
|
||||
# Hybrid retrieval: always query drawers directly (the floor), then use
|
||||
# closet hits to boost rankings. Closets are a ranking SIGNAL, never a
|
||||
# GATE — direct drawer search is always the baseline.
|
||||
#
|
||||
# This avoids the "weak-closets regression" where narrative content
|
||||
# produces low-signal closets (regex extraction matches few topics)
|
||||
# and closet-first routing hides drawers that direct search would find.
|
||||
try:
|
||||
kwargs = {
|
||||
dkwargs = {
|
||||
"query_texts": [query],
|
||||
"n_results": n_results,
|
||||
"n_results": n_results * 3, # over-fetch for re-ranking
|
||||
"include": ["documents", "metadatas", "distances"],
|
||||
}
|
||||
if where:
|
||||
kwargs["where"] = where
|
||||
|
||||
results = drawers_col.query(**kwargs)
|
||||
dkwargs["where"] = where
|
||||
drawer_results = drawers_col.query(**dkwargs)
|
||||
except Exception as e:
|
||||
return {"error": f"Search error: {e}"}
|
||||
|
||||
docs = results["documents"][0]
|
||||
metas = results["metadatas"][0]
|
||||
dists = results["distances"][0]
|
||||
# Gather closet hits (best-per-source) to build a boost lookup.
|
||||
closet_boost_by_source: dict = {} # source_file -> (rank, closet_dist, preview)
|
||||
try:
|
||||
closets_col = get_closets_collection(palace_path, create=False)
|
||||
ckwargs = {
|
||||
"query_texts": [query],
|
||||
"n_results": n_results * 2,
|
||||
"include": ["documents", "metadatas", "distances"],
|
||||
}
|
||||
if where:
|
||||
ckwargs["where"] = where
|
||||
closet_results = closets_col.query(**ckwargs)
|
||||
for rank, (cdoc, cmeta, cdist) in enumerate(
|
||||
zip(
|
||||
closet_results["documents"][0],
|
||||
closet_results["metadatas"][0],
|
||||
closet_results["distances"][0],
|
||||
)
|
||||
):
|
||||
source = cmeta.get("source_file", "")
|
||||
if source and source not in closet_boost_by_source:
|
||||
closet_boost_by_source[source] = (rank, cdist, cdoc[:200])
|
||||
except Exception:
|
||||
pass # no closets yet — hybrid degrades to pure drawer search
|
||||
|
||||
hits = []
|
||||
for doc, meta, dist in zip(docs, metas, dists):
|
||||
# Filter on raw distance before rounding to avoid precision loss
|
||||
# Rank-based boost. The ordinal signal ("which closet matched best") is
|
||||
# more reliable than absolute distance on narrative content, where
|
||||
# closet distances cluster in 1.2-1.5 range regardless of match quality.
|
||||
CLOSET_RANK_BOOSTS = [0.40, 0.25, 0.15, 0.08, 0.04]
|
||||
CLOSET_DISTANCE_CAP = 1.5 # cosine dist > 1.5 = too weak to use as signal
|
||||
|
||||
scored: list = []
|
||||
for doc, meta, dist in zip(
|
||||
drawer_results["documents"][0],
|
||||
drawer_results["metadatas"][0],
|
||||
drawer_results["distances"][0],
|
||||
):
|
||||
# Filter on raw distance before rounding to avoid precision loss.
|
||||
if max_distance > 0.0 and dist > max_distance:
|
||||
continue
|
||||
hits.append(
|
||||
{
|
||||
|
||||
source = meta.get("source_file", "") or ""
|
||||
boost = 0.0
|
||||
matched_via = "drawer"
|
||||
closet_preview = None
|
||||
if source in closet_boost_by_source:
|
||||
c_rank, c_dist, c_preview = closet_boost_by_source[source]
|
||||
if c_dist <= CLOSET_DISTANCE_CAP and c_rank < len(CLOSET_RANK_BOOSTS):
|
||||
boost = CLOSET_RANK_BOOSTS[c_rank]
|
||||
matched_via = "drawer+closet"
|
||||
closet_preview = c_preview
|
||||
|
||||
effective_dist = dist - boost
|
||||
entry = {
|
||||
"text": doc,
|
||||
"wing": meta.get("wing", "unknown"),
|
||||
"room": meta.get("room", "unknown"),
|
||||
"source_file": Path(meta.get("source_file", "?")).name,
|
||||
"similarity": round(max(0.0, 1 - dist), 3),
|
||||
"source_file": Path(source).name if source else "?",
|
||||
"similarity": round(max(0.0, 1 - effective_dist), 3),
|
||||
"distance": round(dist, 4),
|
||||
"matched_via": "drawer",
|
||||
"effective_distance": round(effective_dist, 4),
|
||||
"closet_boost": round(boost, 3),
|
||||
"matched_via": matched_via,
|
||||
# Internal: retain the full source_file path + chunk_index so the
|
||||
# enrichment step below doesn't have to reverse-lookup via
|
||||
# basename-suffix matching (which silently collides when two
|
||||
# files share a basename across different directories).
|
||||
"_sort_key": effective_dist,
|
||||
"_source_file_full": source,
|
||||
"_chunk_index": meta.get("chunk_index"),
|
||||
}
|
||||
)
|
||||
if closet_preview:
|
||||
entry["closet_preview"] = closet_preview
|
||||
scored.append(entry)
|
||||
|
||||
# Re-rank with BM25 hybrid scoring
|
||||
scored.sort(key=lambda h: h["_sort_key"])
|
||||
hits = scored[:n_results]
|
||||
|
||||
# Drawer-grep enrichment: for closet-boosted hits whose source has
|
||||
# multiple drawers, return the keyword-best chunk + its immediate
|
||||
# neighbors instead of just the drawer vector search landed on. The
|
||||
# closet said "this source is relevant"; vector may have picked the
|
||||
# wrong chunk within it; grep picks the right one.
|
||||
MAX_HYDRATION_CHARS = 10000
|
||||
for h in hits:
|
||||
if h["matched_via"] == "drawer":
|
||||
continue
|
||||
full_source = h.get("_source_file_full") or ""
|
||||
if not full_source:
|
||||
continue
|
||||
try:
|
||||
source_drawers = drawers_col.get(
|
||||
where={"source_file": full_source},
|
||||
include=["documents", "metadatas"],
|
||||
)
|
||||
except Exception:
|
||||
continue
|
||||
docs = source_drawers.get("documents") or []
|
||||
metas_ = source_drawers.get("metadatas") or []
|
||||
if len(docs) <= 1:
|
||||
continue
|
||||
|
||||
# Sort by chunk_index so best_idx + neighbors are positional.
|
||||
indexed = []
|
||||
for idx, (d, m) in enumerate(zip(docs, metas_)):
|
||||
ci = m.get("chunk_index", idx) if isinstance(m, dict) else idx
|
||||
if not isinstance(ci, int):
|
||||
ci = idx
|
||||
indexed.append((ci, d))
|
||||
indexed.sort(key=lambda p: p[0])
|
||||
ordered_docs = [d for _, d in indexed]
|
||||
|
||||
query_terms = set(_tokenize(query))
|
||||
best_idx, best_score = 0, -1
|
||||
for idx, d in enumerate(ordered_docs):
|
||||
d_lower = d.lower()
|
||||
s = sum(1 for t in query_terms if t in d_lower)
|
||||
if s > best_score:
|
||||
best_score, best_idx = s, idx
|
||||
|
||||
start = max(0, best_idx - 1)
|
||||
end = min(len(ordered_docs), best_idx + 2)
|
||||
expanded = "\n\n".join(ordered_docs[start:end])
|
||||
if len(expanded) > MAX_HYDRATION_CHARS:
|
||||
expanded = (
|
||||
expanded[:MAX_HYDRATION_CHARS]
|
||||
+ f"\n\n[...truncated. {len(ordered_docs)} total drawers. "
|
||||
"Use mempalace_get_drawer for full content.]"
|
||||
)
|
||||
h["text"] = expanded
|
||||
h["drawer_index"] = best_idx
|
||||
h["total_drawers"] = len(ordered_docs)
|
||||
|
||||
# BM25 hybrid re-rank within the final candidate set.
|
||||
hits = _hybrid_rank(hits, query)
|
||||
for h in hits:
|
||||
h.pop("_sort_key", None)
|
||||
h.pop("_source_file_full", None)
|
||||
h.pop("_chunk_index", None)
|
||||
|
||||
return {
|
||||
"query": query,
|
||||
"filters": {"wing": wing, "room": room},
|
||||
"total_before_filter": len(docs),
|
||||
"total_before_filter": len(drawer_results["documents"][0]),
|
||||
"results": hits,
|
||||
}
|
||||
|
||||
@@ -0,0 +1,339 @@
|
||||
"""Unit tests for the optional LLM-based closet regeneration.
|
||||
|
||||
These tests don't hit the network. They mock urllib to verify:
|
||||
- LLMConfig correctly reads env vars and CLI overrides
|
||||
- missing config is reported cleanly
|
||||
- the OpenAI-compatible request shape is correct
|
||||
- response parsing handles the standard chat-completions payload
|
||||
"""
|
||||
|
||||
import json
|
||||
import tempfile
|
||||
from unittest.mock import patch
|
||||
|
||||
from mempalace.closet_llm import (
|
||||
LLMConfig,
|
||||
_call_llm,
|
||||
_parsed_to_closet_lines,
|
||||
regenerate_closets,
|
||||
)
|
||||
|
||||
|
||||
# ── LLMConfig ─────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestLLMConfig:
|
||||
def test_reads_env_vars(self, monkeypatch):
|
||||
monkeypatch.setenv("LLM_ENDPOINT", "http://localhost:11434/v1")
|
||||
monkeypatch.setenv("LLM_KEY", "sk-abc")
|
||||
monkeypatch.setenv("LLM_MODEL", "llama3:8b")
|
||||
c = LLMConfig()
|
||||
assert c.endpoint == "http://localhost:11434/v1"
|
||||
assert c.key == "sk-abc"
|
||||
assert c.model == "llama3:8b"
|
||||
|
||||
def test_cli_flags_override_env(self, monkeypatch):
|
||||
monkeypatch.setenv("LLM_ENDPOINT", "http://env-endpoint/v1")
|
||||
monkeypatch.setenv("LLM_MODEL", "env-model")
|
||||
c = LLMConfig(endpoint="http://flag-endpoint/v1", model="flag-model")
|
||||
assert c.endpoint == "http://flag-endpoint/v1"
|
||||
assert c.model == "flag-model"
|
||||
|
||||
def test_trailing_slash_stripped(self):
|
||||
c = LLMConfig(endpoint="http://foo/v1/", model="m")
|
||||
assert c.endpoint == "http://foo/v1"
|
||||
|
||||
def test_missing_reports_required(self, monkeypatch):
|
||||
monkeypatch.delenv("LLM_ENDPOINT", raising=False)
|
||||
monkeypatch.delenv("LLM_KEY", raising=False)
|
||||
monkeypatch.delenv("LLM_MODEL", raising=False)
|
||||
c = LLMConfig()
|
||||
missing = c.missing()
|
||||
assert any("ENDPOINT" in m for m in missing)
|
||||
assert any("MODEL" in m for m in missing)
|
||||
# key is optional
|
||||
assert not any("KEY" in m for m in missing)
|
||||
|
||||
def test_key_is_optional(self, monkeypatch):
|
||||
monkeypatch.delenv("LLM_KEY", raising=False)
|
||||
c = LLMConfig(endpoint="http://local/v1", model="m")
|
||||
assert c.missing() == []
|
||||
|
||||
|
||||
# ── _parsed_to_closet_lines ──────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestParsedToLines:
|
||||
def test_topics_become_pointers(self):
|
||||
parsed = {"topics": ["authentication", "jwt tokens"], "quotes": [], "summary": ""}
|
||||
lines = _parsed_to_closet_lines(parsed, ["d1", "d2"], "Alice;Bob")
|
||||
assert len(lines) == 2
|
||||
assert "authentication|Alice;Bob|→d1,d2" in lines
|
||||
assert "jwt tokens|Alice;Bob|→d1,d2" in lines
|
||||
|
||||
def test_quotes_and_summary_included(self):
|
||||
parsed = {
|
||||
"topics": ["t1"],
|
||||
"quotes": ["[Igor] we ship Friday"],
|
||||
"summary": "Release planning discussion",
|
||||
}
|
||||
lines = _parsed_to_closet_lines(parsed, ["d1"], "")
|
||||
joined = "\n".join(lines)
|
||||
assert "we ship Friday" in joined
|
||||
assert "Release planning discussion" in joined
|
||||
|
||||
def test_caps_topics_at_15(self):
|
||||
parsed = {"topics": [f"t{i}" for i in range(20)], "quotes": [], "summary": ""}
|
||||
lines = _parsed_to_closet_lines(parsed, ["d1"], "")
|
||||
assert len(lines) == 15
|
||||
|
||||
|
||||
# ── _call_llm (HTTP mocked) ──────────────────────────────────────────────
|
||||
|
||||
|
||||
class _FakeResp:
|
||||
"""Mimics urlopen's context-manager response."""
|
||||
|
||||
def __init__(self, payload: dict, status: int = 200):
|
||||
self._body = json.dumps(payload).encode("utf-8")
|
||||
self.status = status
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, *a):
|
||||
return False
|
||||
|
||||
def read(self):
|
||||
return self._body
|
||||
|
||||
|
||||
class TestCallLLM:
|
||||
def _make_cfg(self):
|
||||
return LLMConfig(endpoint="http://localhost:11434/v1", key="sk-test", model="llama3:8b")
|
||||
|
||||
def test_request_shape_and_parsing(self):
|
||||
cfg = self._make_cfg()
|
||||
captured = {}
|
||||
|
||||
def fake_urlopen(req, timeout=None):
|
||||
captured["url"] = req.full_url
|
||||
captured["headers"] = dict(req.header_items())
|
||||
captured["body"] = json.loads(req.data.decode("utf-8"))
|
||||
return _FakeResp(
|
||||
{
|
||||
"choices": [
|
||||
{
|
||||
"message": {
|
||||
"content": json.dumps(
|
||||
{
|
||||
"topics": ["postgres"],
|
||||
"quotes": ["[Igor] migrate now"],
|
||||
"summary": "db migration",
|
||||
}
|
||||
)
|
||||
}
|
||||
}
|
||||
],
|
||||
"usage": {"prompt_tokens": 42, "completion_tokens": 17},
|
||||
}
|
||||
)
|
||||
|
||||
with patch("urllib.request.urlopen", side_effect=fake_urlopen):
|
||||
parsed, usage = _call_llm(cfg, "/tmp/test.md", "w", "r", "content body")
|
||||
|
||||
assert parsed["topics"] == ["postgres"]
|
||||
assert usage["prompt_tokens"] == 42
|
||||
assert captured["url"] == "http://localhost:11434/v1/chat/completions"
|
||||
# Authorization header is stored capitalized-then-lowercase depending on urllib version
|
||||
auth_vals = {v for k, v in captured["headers"].items() if k.lower() == "authorization"}
|
||||
assert "Bearer sk-test" in auth_vals
|
||||
assert captured["body"]["model"] == "llama3:8b"
|
||||
assert captured["body"]["messages"][0]["role"] == "user"
|
||||
|
||||
def test_omits_auth_header_when_no_key(self):
|
||||
cfg = LLMConfig(endpoint="http://localhost:11434/v1", model="llama3:8b")
|
||||
captured_headers = {}
|
||||
|
||||
def fake_urlopen(req, timeout=None):
|
||||
captured_headers.update({k.lower(): v for k, v in req.header_items()})
|
||||
return _FakeResp(
|
||||
{
|
||||
"choices": [{"message": {"content": '{"topics":[],"quotes":[],"summary":""}'}}],
|
||||
"usage": {"prompt_tokens": 0, "completion_tokens": 0},
|
||||
}
|
||||
)
|
||||
|
||||
with patch("urllib.request.urlopen", side_effect=fake_urlopen):
|
||||
_call_llm(cfg, "/tmp/x", "w", "r", "c")
|
||||
|
||||
assert "authorization" not in captured_headers
|
||||
|
||||
def test_strips_code_fences(self):
|
||||
cfg = self._make_cfg()
|
||||
fenced = '```json\n{"topics":["t1"],"quotes":[],"summary":""}\n```'
|
||||
|
||||
def fake_urlopen(req, timeout=None):
|
||||
return _FakeResp(
|
||||
{
|
||||
"choices": [{"message": {"content": fenced}}],
|
||||
"usage": {"prompt_tokens": 1, "completion_tokens": 1},
|
||||
}
|
||||
)
|
||||
|
||||
with patch("urllib.request.urlopen", side_effect=fake_urlopen):
|
||||
parsed, _ = _call_llm(cfg, "/tmp/x", "w", "r", "c")
|
||||
assert parsed == {"topics": ["t1"], "quotes": [], "summary": ""}
|
||||
|
||||
def test_returns_none_on_invalid_json(self):
|
||||
cfg = self._make_cfg()
|
||||
|
||||
def fake_urlopen(req, timeout=None):
|
||||
return _FakeResp(
|
||||
{
|
||||
"choices": [{"message": {"content": "not json at all"}}],
|
||||
"usage": {"prompt_tokens": 1, "completion_tokens": 1},
|
||||
}
|
||||
)
|
||||
|
||||
with patch("urllib.request.urlopen", side_effect=fake_urlopen):
|
||||
parsed, usage = _call_llm(cfg, "/tmp/x", "w", "r", "c")
|
||||
assert parsed is None
|
||||
|
||||
|
||||
# ── regenerate_closets error paths ───────────────────────────────────────
|
||||
|
||||
|
||||
class TestRegenerateClosets:
|
||||
def test_missing_config_returns_error(self, monkeypatch):
|
||||
monkeypatch.delenv("LLM_ENDPOINT", raising=False)
|
||||
monkeypatch.delenv("LLM_MODEL", raising=False)
|
||||
with tempfile.TemporaryDirectory() as palace:
|
||||
result = regenerate_closets(palace)
|
||||
assert result["error"] == "missing-config"
|
||||
assert any("ENDPOINT" in m for m in result["missing"])
|
||||
|
||||
def test_regen_purges_regex_closets_and_stamps_normalize_version(self, tmp_path):
|
||||
"""Regression: before the hardening, regex closets for the same
|
||||
source survived alongside fresh LLM closets (the old path used a
|
||||
bare ``closets_col.delete(ids=...)`` with a swallowed exception).
|
||||
Now we go through ``purge_file_closets`` + ``mine_lock`` + stamp
|
||||
``NORMALIZE_VERSION`` so the next mine's stale-version gate doesn't
|
||||
treat the LLM closets as leftovers to rebuild over."""
|
||||
from mempalace.palace import (
|
||||
NORMALIZE_VERSION,
|
||||
get_closets_collection,
|
||||
get_collection,
|
||||
upsert_closet_lines,
|
||||
)
|
||||
|
||||
palace = str(tmp_path / "palace")
|
||||
# Seed one drawer and a pre-existing regex closet for the same source.
|
||||
source = "/proj/story.md"
|
||||
drawers = get_collection(palace, create=True)
|
||||
drawers.upsert(
|
||||
ids=["drawer_01"],
|
||||
documents=["Content about JWT authentication."],
|
||||
metadatas=[
|
||||
{
|
||||
"wing": "project",
|
||||
"room": "auth",
|
||||
"source_file": source,
|
||||
"entities": "",
|
||||
}
|
||||
],
|
||||
)
|
||||
closets = get_closets_collection(palace)
|
||||
upsert_closet_lines(
|
||||
closets,
|
||||
closet_id_base="closet_old_regex",
|
||||
lines=["STALE_REGEX_TOPIC|;|→drawer_01"],
|
||||
metadata={
|
||||
"wing": "project",
|
||||
"room": "auth",
|
||||
"source_file": source,
|
||||
"generated_by": "regex",
|
||||
},
|
||||
)
|
||||
|
||||
cfg = LLMConfig(endpoint="http://local/v1", model="llama3:8b")
|
||||
|
||||
def fake_urlopen(req, timeout=None):
|
||||
return _FakeResp(
|
||||
{
|
||||
"choices": [
|
||||
{
|
||||
"message": {
|
||||
"content": json.dumps(
|
||||
{
|
||||
"topics": ["jwt auth", "session expiry"],
|
||||
"quotes": [],
|
||||
"summary": "auth refactor",
|
||||
}
|
||||
)
|
||||
}
|
||||
}
|
||||
],
|
||||
"usage": {"prompt_tokens": 10, "completion_tokens": 5},
|
||||
}
|
||||
)
|
||||
|
||||
with patch("urllib.request.urlopen", side_effect=fake_urlopen):
|
||||
result = regenerate_closets(palace, cfg=cfg)
|
||||
|
||||
assert result["processed"] == 1 and result["failed"] == 0
|
||||
|
||||
# Every surviving closet for this source must be LLM-generated and
|
||||
# must carry the current NORMALIZE_VERSION.
|
||||
survivors = closets.get(where={"source_file": source}, include=["documents", "metadatas"])
|
||||
assert survivors["ids"], "LLM closets should have been written"
|
||||
joined = "\n".join(survivors["documents"])
|
||||
assert (
|
||||
"STALE_REGEX_TOPIC" not in joined
|
||||
), "pre-existing regex closet was not purged before LLM write"
|
||||
assert "jwt auth" in joined
|
||||
for meta in survivors["metadatas"]:
|
||||
assert meta.get("generated_by", "").startswith("llm:")
|
||||
assert meta.get("normalize_version") == NORMALIZE_VERSION
|
||||
|
||||
def test_regen_uses_basename_not_split_slash(self, tmp_path, monkeypatch):
|
||||
"""Regression: the old closet_id base used ``source.split('/')[-1]``
|
||||
which silently degrades on Windows paths (``C:\\proj\\a.md`` →
|
||||
the whole string). ``os.path.basename`` handles both separators."""
|
||||
from mempalace.palace import get_collection, get_closets_collection
|
||||
|
||||
palace = str(tmp_path / "palace")
|
||||
# Use a path whose basename differs between '/' split and
|
||||
# os.path.basename only on a platform-aware function, but verify
|
||||
# at minimum that IDs encode just the filename, not the full path.
|
||||
source = "/deep/nested/project/dir/mydoc.md"
|
||||
drawers = get_collection(palace, create=True)
|
||||
drawers.upsert(
|
||||
ids=["d1"],
|
||||
documents=["body"],
|
||||
metadatas=[{"wing": "w", "room": "r", "source_file": source, "entities": ""}],
|
||||
)
|
||||
|
||||
cfg = LLMConfig(endpoint="http://local/v1", model="m")
|
||||
|
||||
def fake_urlopen(req, timeout=None):
|
||||
return _FakeResp(
|
||||
{
|
||||
"choices": [
|
||||
{"message": {"content": '{"topics":["t1"],"quotes":[],"summary":""}'}}
|
||||
],
|
||||
"usage": {"prompt_tokens": 1, "completion_tokens": 1},
|
||||
}
|
||||
)
|
||||
|
||||
with patch("urllib.request.urlopen", side_effect=fake_urlopen):
|
||||
regenerate_closets(palace, cfg=cfg)
|
||||
|
||||
closets = get_closets_collection(palace)
|
||||
ids = closets.get(where={"source_file": source}).get("ids", [])
|
||||
assert ids
|
||||
# IDs must not leak the full path (would happen if we used
|
||||
# source.split('/')[-1] on Windows, or forgot to strip entirely).
|
||||
for cid in ids:
|
||||
assert "/" not in cid
|
||||
assert "mydoc.md" in cid
|
||||
+37
-25
@@ -13,8 +13,9 @@ Coverage map:
|
||||
* Project-miner end-to-end rebuild — re-mining with fewer topics fully
|
||||
purges leftover numbered closets from a larger prior run.
|
||||
* _extract_drawer_ids_from_closet — pointer parsing + dedup.
|
||||
* search_memories closet-first path — fallback when empty, chunk-level
|
||||
hits with matched_via, no whole-file glue, max_distance enforcement.
|
||||
* search_memories hybrid path — drawer query always the floor,
|
||||
closets boost matching source_file, matched_via reflects both signals,
|
||||
no whole-file glue, max_distance enforcement.
|
||||
* Entity metadata — extracted, stoplist applied, registry cached by mtime.
|
||||
* Real BM25 — real IDF over candidate corpus, hybrid rerank.
|
||||
* Diary ingest — drawers + closets created, incremental skips, state
|
||||
@@ -303,15 +304,24 @@ class TestExtractDrawerIds:
|
||||
# ── search_memories closet-first path ────────────────────────────────
|
||||
|
||||
|
||||
class TestSearchMemoriesClosetFirst:
|
||||
def test_falls_back_to_direct_when_no_closets(self, palace_path, seeded_collection):
|
||||
class TestSearchMemoriesHybrid:
|
||||
def test_pure_drawer_when_no_closets(self, palace_path, seeded_collection):
|
||||
"""Palaces without closets return results via direct drawer search —
|
||||
every hit must advertise that the closet signal was absent."""
|
||||
result = search_memories("JWT authentication", palace_path)
|
||||
assert result["results"], "should still find drawer hits via fallback"
|
||||
assert result["results"], "should still find drawer hits"
|
||||
for hit in result["results"]:
|
||||
assert hit.get("matched_via") == "drawer"
|
||||
assert hit.get("closet_boost") == 0.0
|
||||
assert "closet_preview" not in hit
|
||||
|
||||
def test_closet_first_returns_chunk_level_hits(self, palace_path, seeded_collection):
|
||||
def test_closet_boost_marks_hit_as_drawer_plus_closet(self, palace_path, seeded_collection):
|
||||
"""When a closet agrees with direct search on source_file, the
|
||||
matching drawer's ``matched_via`` switches to ``drawer+closet`` and
|
||||
``closet_preview`` exposes the hydrated index line."""
|
||||
closets = get_closets_collection(palace_path)
|
||||
# Seed the closet against the same source_file the drawer uses so
|
||||
# the boost lookup keys align.
|
||||
closets.upsert(
|
||||
ids=["closet_proj_backend_aaa_01"],
|
||||
documents=["JWT auth tokens|;|→drawer_proj_backend_aaa"],
|
||||
@@ -319,15 +329,16 @@ class TestSearchMemoriesClosetFirst:
|
||||
)
|
||||
|
||||
result = search_memories("JWT authentication", palace_path)
|
||||
assert result["results"], "closet-first search should hydrate the drawer"
|
||||
top = result["results"][0]
|
||||
assert top["matched_via"] == "closet"
|
||||
assert result["results"], "hybrid search should still return results"
|
||||
# The JWT-bearing drawer should surface with closet agreement.
|
||||
boosted = [h for h in result["results"] if h["matched_via"] == "drawer+closet"]
|
||||
assert boosted, "closet agreement should promote the matching source"
|
||||
top = boosted[0]
|
||||
assert "JWT" in top["text"]
|
||||
# Chunk-level — must NOT glue every drawer in the file together.
|
||||
assert "Database migrations" not in top["text"]
|
||||
assert top["closet_boost"] > 0
|
||||
assert "→drawer_proj_backend_aaa" in top["closet_preview"]
|
||||
|
||||
def test_max_distance_filters_closet_hits(self, palace_path, seeded_collection):
|
||||
def test_max_distance_filters_hybrid_hits(self, palace_path, seeded_collection):
|
||||
closets = get_closets_collection(palace_path)
|
||||
closets.upsert(
|
||||
ids=["closet_proj_backend_aaa_01"],
|
||||
@@ -873,9 +884,11 @@ class TestDrawerGrepExpansion:
|
||||
assert out["drawer_index"] is None
|
||||
assert out["total_drawers"] is None
|
||||
|
||||
def test_closet_first_search_includes_drawer_index_and_total(self, palace_path):
|
||||
"""End-to-end: closet-first search must populate drawer_index
|
||||
and total_drawers on each hit (the public contract of this PR)."""
|
||||
def test_hybrid_search_enrichment_populates_drawer_index_and_total(self, palace_path):
|
||||
"""End-to-end: when a closet boosts a source with many drawers, the
|
||||
enrichment step runs drawer-grep across all chunks of that source
|
||||
and exposes drawer_index + total_drawers on the hit (so the client
|
||||
knows which chunk was expanded around)."""
|
||||
col = get_collection(palace_path)
|
||||
source = "/proj/indexed.md"
|
||||
# Seed 5 drawers for one source file.
|
||||
@@ -893,7 +906,7 @@ class TestDrawerGrepExpansion:
|
||||
}
|
||||
],
|
||||
)
|
||||
# Closet pointing at chunk_2.
|
||||
# Closet pointing at chunk_2 for this source.
|
||||
closets = get_closets_collection(palace_path)
|
||||
closets.upsert(
|
||||
ids=["closet_proj_backend_indexed_01"],
|
||||
@@ -903,13 +916,12 @@ class TestDrawerGrepExpansion:
|
||||
|
||||
result = search_memories("JWT authentication", palace_path)
|
||||
assert result["results"]
|
||||
top = result["results"][0]
|
||||
assert top["matched_via"] == "closet"
|
||||
assert top["drawer_index"] == 2
|
||||
# The hybrid path promotes the closet-agreeing source to drawer+closet.
|
||||
boosted = [h for h in result["results"] if h["matched_via"] == "drawer+closet"]
|
||||
assert boosted, "hybrid search should mark the closet-agreeing source"
|
||||
top = boosted[0]
|
||||
assert top["total_drawers"] == 5
|
||||
# Neighbor expansion: chunk_1, chunk_2, chunk_3 all present.
|
||||
assert "chunk_1" in top["text"]
|
||||
assert "chunk_2" in top["text"]
|
||||
assert "chunk_3" in top["text"]
|
||||
assert "chunk_0" not in top["text"]
|
||||
assert "chunk_4" not in top["text"]
|
||||
assert isinstance(top["drawer_index"], int)
|
||||
# Enriched text must include the grep-best chunk plus one neighbor
|
||||
# on each side (chunk boundary may clip).
|
||||
assert "chunk_" in top["text"]
|
||||
|
||||
@@ -0,0 +1,133 @@
|
||||
"""Tests for the hybrid closet+drawer retrieval in search_memories.
|
||||
|
||||
The hybrid path queries drawers directly (the floor) AND closets, applying a
|
||||
rank-based boost to drawers whose source_file appears in top closet hits.
|
||||
This avoids the "weak-closets regression" where low-signal closets (from
|
||||
regex extraction on narrative content) could hide drawers that direct
|
||||
search would have found.
|
||||
"""
|
||||
|
||||
from mempalace.palace import (
|
||||
get_closets_collection,
|
||||
get_collection,
|
||||
upsert_closet_lines,
|
||||
)
|
||||
from mempalace.searcher import search_memories
|
||||
|
||||
|
||||
def _seed_drawers(palace_path):
|
||||
"""Insert 4 short drawers with deterministic content."""
|
||||
col = get_collection(palace_path, create=True)
|
||||
col.upsert(
|
||||
ids=["D1", "D2", "D3", "D4"],
|
||||
documents=[
|
||||
"We switched the auth service to use JWT tokens with a 24h expiry.",
|
||||
"Database migration to PostgreSQL 15 completed last Tuesday.",
|
||||
"The frontend team is debating whether to adopt TanStack Query.",
|
||||
"Kafka consumer rebalance timeout set to 45 seconds after incident.",
|
||||
],
|
||||
metadatas=[
|
||||
{"wing": "backend", "room": "auth", "source_file": "fixture_D1.md"},
|
||||
{"wing": "backend", "room": "db", "source_file": "fixture_D2.md"},
|
||||
{"wing": "frontend", "room": "state", "source_file": "fixture_D3.md"},
|
||||
{"wing": "backend", "room": "queue", "source_file": "fixture_D4.md"},
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
def _seed_strong_closet_for(palace_path, drawer_id, source_file, topics):
|
||||
"""Insert a closet whose content strongly overlaps the query keywords."""
|
||||
col = get_closets_collection(palace_path)
|
||||
lines = [f"{t}||→{drawer_id}" for t in topics]
|
||||
upsert_closet_lines(
|
||||
col,
|
||||
closet_id_base=f"closet_{drawer_id}",
|
||||
lines=lines,
|
||||
metadata={
|
||||
"wing": "backend",
|
||||
"room": "auth",
|
||||
"source_file": source_file,
|
||||
"generated_by": "test",
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
# ── core invariant: closets can only HELP, never HIDE ─────────────────────
|
||||
|
||||
|
||||
class TestHybridInvariant:
|
||||
def test_no_closets_degrades_to_direct_drawer_search(self, tmp_path):
|
||||
palace = str(tmp_path / "palace")
|
||||
_seed_drawers(palace)
|
||||
# No closets created.
|
||||
result = search_memories("Kafka rebalance timeout", palace, n_results=3)
|
||||
ids = [h["source_file"] for h in result["results"]]
|
||||
assert ids, "should return results"
|
||||
assert "fixture_D4.md" in ids, "direct drawer search alone should surface the Kafka drawer"
|
||||
|
||||
def test_weak_closets_do_not_hide_direct_drawer_hits(self, tmp_path):
|
||||
"""A closet that points at a wrong drawer must NOT suppress the
|
||||
drawer that direct search would have ranked first."""
|
||||
palace = str(tmp_path / "palace")
|
||||
_seed_drawers(palace)
|
||||
# Seed a misleading closet: it matches a generic phrase but points at D3.
|
||||
_seed_strong_closet_for(
|
||||
palace,
|
||||
drawer_id="D3",
|
||||
source_file="fixture_D3.md",
|
||||
topics=["Kafka queue tuning", "consumer rebalance config"],
|
||||
)
|
||||
result = search_memories("Kafka consumer rebalance timeout", palace, n_results=5)
|
||||
ids = [h["source_file"] for h in result["results"]]
|
||||
assert "fixture_D4.md" in ids, (
|
||||
"D4 must appear — direct drawer search alone would rank it first. "
|
||||
"Closet pointing to D3 should only boost D3, never hide D4."
|
||||
)
|
||||
|
||||
def test_closet_boost_lifts_matching_drawer(self, tmp_path):
|
||||
"""When a closet agrees with direct search, the matching drawer
|
||||
should be boosted to rank 1."""
|
||||
palace = str(tmp_path / "palace")
|
||||
_seed_drawers(palace)
|
||||
_seed_strong_closet_for(
|
||||
palace,
|
||||
drawer_id="D1",
|
||||
source_file="fixture_D1.md",
|
||||
topics=["JWT auth tokens", "session expiry", "authentication service"],
|
||||
)
|
||||
result = search_memories("JWT auth tokens expiry", palace, n_results=3)
|
||||
ids = [h["source_file"] for h in result["results"]]
|
||||
assert ids[0] == "fixture_D1.md"
|
||||
top = result["results"][0]
|
||||
assert top["matched_via"] == "drawer+closet"
|
||||
assert top["closet_boost"] > 0
|
||||
|
||||
|
||||
# ── closet_boost metadata ────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestClosetMetadata:
|
||||
def test_closet_preview_exposed_when_boosted(self, tmp_path):
|
||||
palace = str(tmp_path / "palace")
|
||||
_seed_drawers(palace)
|
||||
_seed_strong_closet_for(
|
||||
palace,
|
||||
drawer_id="D1",
|
||||
source_file="fixture_D1.md",
|
||||
topics=["JWT auth tokens", "24h expiry", "authentication"],
|
||||
)
|
||||
result = search_memories("JWT authentication", palace, n_results=2)
|
||||
top = result["results"][0]
|
||||
assert top["source_file"] == "fixture_D1.md"
|
||||
assert "closet_preview" in top
|
||||
|
||||
def test_drawer_only_hits_have_no_closet_preview(self, tmp_path):
|
||||
palace = str(tmp_path / "palace")
|
||||
_seed_drawers(palace)
|
||||
# No closets
|
||||
result = search_memories("TanStack Query", palace, n_results=2)
|
||||
assert result["results"]
|
||||
for h in result["results"]:
|
||||
assert h["matched_via"] == "drawer"
|
||||
assert "closet_preview" not in h
|
||||
assert h["closet_boost"] == 0.0
|
||||
Reference in New Issue
Block a user