Merge pull request #829 from MemPalace/pr/fact-checker

feat(v3.3): land Milla's stacked closet/BM25/KG/LLM chain (#784-#795) on develop
This commit is contained in:
Igor Lins e Silva
2026-04-13 19:19:05 -03:00
committed by GitHub
11 changed files with 3193 additions and 229 deletions
+351
View File
@@ -0,0 +1,351 @@
"""
closet_llm.py — Generate closets via a user-configured LLM for richer indexing.
The regex-based closet extraction catches action verbs, headers, and proper
nouns — but misses implicit topics, foreign-language content, and contextual
references. An LLM reads everything and produces better closets.
This module is **OPTIONAL and opt-in**. Regex closets are always created by
the miner; this path regenerates them afterward using whatever LLM the user
chooses. Core memory operations remain API-free by design (see CLAUDE.md,
"Local-first, zero API").
## Bring-your-own-LLM configuration
The endpoint is any OpenAI-compatible Chat Completions URL:
LLM_ENDPOINT=http://localhost:11434/v1 # Ollama
LLM_ENDPOINT=http://localhost:8000/v1 # vLLM, llama.cpp
LLM_ENDPOINT=https://api.openai.com/v1
LLM_ENDPOINT=https://openrouter.ai/api/v1
LLM_ENDPOINT=https://api.anthropic.com/v1 # when proxied through a compat layer
Set:
LLM_ENDPOINT — base URL (required)
LLM_KEY — bearer token (optional; local inference usually doesn't need it)
LLM_MODEL — model name (required), e.g. "gpt-4o-mini", "llama3:8b", "qwen2.5:7b"
Or pass flags on the CLI (flags win over env):
python -m mempalace.closet_llm \\
--palace ~/.mempalace/palace \\
--endpoint http://localhost:11434/v1 \\
--model llama3:8b
No vendor lock-in. No hidden dependency on any specific provider. Zero deps
added to pyproject — uses stdlib urllib.
"""
import json
import os
import re
import time
import urllib.request
import urllib.error
from datetime import datetime
from typing import Optional
from .palace import (
NORMALIZE_VERSION,
get_closets_collection,
get_collection,
mine_lock,
purge_file_closets,
upsert_closet_lines,
)
# Upper bound on how many characters of drawer content are placed in the prompt.
MAX_CONTENT_CHARS = 30000
# Sent as ``max_tokens`` in the chat-completions request body.
MAX_OUTPUT_TOKENS = 1500
# Per-request timeout, in seconds, passed to ``urllib.request.urlopen``.
HTTP_TIMEOUT_S = 60
PROMPT_TEMPLATE = """You are reading content filed in a memory palace. Generate a
topic-dense index that will be used to find this content later when someone searches.
Source: {source_file}
Wing: {wing} | Room: {room}
CONTENT:
{content}
---
Output a JSON object with EXACTLY these fields:
{{
"topics": ["distinctive_word_or_phrase_1", "topic_2", ...],
"quotes": ["[Speaker] verbatim quote", ...],
"summary": "2-3 sentences describing what this content is about."
}}
RULES:
- Topics: 8-15 entries. Include proper nouns (names, places, projects),
distinctive technical terms, and key concepts. NOT generic words like
"conversation" or "discussion".
- Quotes: 2-5 entries. EXACT verbatim from the content, not paraphrased.
Attribute with [Speaker] prefix if speaker is identifiable.
- Summary: mention WHO, WHAT, and WHY. No filler.
- Write in the same language as the content.
- Output valid JSON only. No code fences. No commentary.
"""
class LLMConfig:
    """Connection settings for the LLM endpoint.

    Each value comes from the explicit argument when it is truthy and from
    the matching environment variable (``LLM_ENDPOINT``, ``LLM_KEY``,
    ``LLM_MODEL``) otherwise. Unset variables resolve to an empty string.
    """

    def __init__(
        self,
        endpoint: Optional[str] = None,
        key: Optional[str] = None,
        model: Optional[str] = None,
    ):
        env = os.environ
        resolved_endpoint = endpoint if endpoint else env.get("LLM_ENDPOINT", "")
        # Trailing slashes are dropped so a path can be appended with "/".
        self.endpoint = resolved_endpoint.rstrip("/")
        self.key = key if key else env.get("LLM_KEY", "")
        self.model = model if model else env.get("LLM_MODEL", "")

    def missing(self) -> list:
        """Return a human-readable label for each required setting that is empty."""
        required = (
            (self.endpoint, "LLM_ENDPOINT (or --endpoint)"),
            (self.model, "LLM_MODEL (or --model)"),
        )
        # The key is deliberately not required: local inference servers
        # (Ollama, vLLM) often don't need one.
        return [label for value, label in required if not value]
def _call_llm(cfg: LLMConfig, source_file: str, wing: str, room: str, content: str):
    """Make one closet-generation request to an OpenAI-compatible endpoint.

    Builds the prompt from ``PROMPT_TEMPLATE`` (source name truncated to 100
    chars, content to ``MAX_CONTENT_CHARS``), POSTs it to
    ``{cfg.endpoint}/chat/completions`` and parses the model's reply as JSON.

    Args:
        cfg: Resolved endpoint / key / model settings.
        source_file: Path of the source the content came from.
        wing: Wing name inserted into the prompt.
        room: Room name inserted into the prompt.
        content: Drawer text to index.

    Returns:
        ``(parsed, usage)`` where ``parsed`` is the JSON *object* produced by
        the model and ``usage`` is the response's ``usage`` field (``None``
        when the server omits it). ``(None, None)`` on any failure: HTTP
        error, malformed response, reply that is not valid JSON, or reply
        that is valid JSON but not an object.
    """
    try:
        from mempalace.i18n import t

        lang_instruction = t("aaak.instruction")
    except Exception:
        # i18n is best-effort here; without it the prompt is sent unchanged.
        lang_instruction = ""

    prompt = PROMPT_TEMPLATE.format(
        source_file=source_file[:100],
        wing=wing,
        room=room,
        content=content[:MAX_CONTENT_CHARS],
    )
    if lang_instruction and "english" not in lang_instruction.lower():
        prompt += f"\n\nLanguage instruction: {lang_instruction}"

    body = json.dumps(
        {
            "model": cfg.model,
            "max_tokens": MAX_OUTPUT_TOKENS,
            "messages": [{"role": "user", "content": prompt}],
        }
    ).encode("utf-8")
    headers = {"Content-Type": "application/json"}
    if cfg.key:
        headers["Authorization"] = f"Bearer {cfg.key}"
    url = f"{cfg.endpoint}/chat/completions"

    for attempt in range(3):
        try:
            req = urllib.request.Request(url, data=body, headers=headers, method="POST")
            with urllib.request.urlopen(req, timeout=HTTP_TIMEOUT_S) as resp:
                raw = resp.read().decode("utf-8")
            payload = json.loads(raw)
            text = payload["choices"][0]["message"]["content"].strip()
            # Models sometimes wrap the JSON in a code fence despite the prompt.
            text = re.sub(r"^```(?:json)?\s*", "", text)
            text = re.sub(r"\s*```$", "", text)
            parsed = json.loads(text)
            # A JSON list / string / number parses fine but is useless to
            # callers, which treat the result as a dict. Report it as a
            # failure instead of letting them crash on ``.get``.
            if not isinstance(parsed, dict):
                return None, None
            return parsed, payload.get("usage")
        except json.JSONDecodeError:
            return None, None
        except urllib.error.HTTPError as e:
            # 429 / 503 = retry with backoff
            if e.code in (429, 503) and attempt < 2:
                time.sleep(2**attempt)
                continue
            return None, None
        except Exception as e:
            # Some failures only mention rate limiting in their message.
            if "rate" in str(e).lower() and attempt < 2:
                time.sleep(2**attempt)
                continue
            return None, None
    return None, None
def _parsed_to_closet_lines(parsed, drawer_ids, entities_str):
"""Convert LLM's JSON output to closet pointer lines."""
lines = []
drawer_ref = ",".join(drawer_ids[:3])
for topic in parsed.get("topics", [])[:15]:
lines.append(f"{topic}|{entities_str}|→{drawer_ref}")
for quote in parsed.get("quotes", [])[:5]:
lines.append(f"{quote}|{entities_str}|→{drawer_ref}")
summary = parsed.get("summary", "")
if summary:
lines.append(f"{summary[:200]}|{entities_str}|→{drawer_ref}")
return lines
def regenerate_closets(
    palace_path,
    wing=None,
    sample=0,
    dry_run=False,
    cfg: Optional[LLMConfig] = None,
):
    """Regenerate closets using a configured LLM for richer topic extraction.

    Reads every drawer, groups them by ``source_file``, sends each source's
    combined content to the configured endpoint and replaces that source's
    existing closets with the LLM-generated lines. The existing closets are
    left untouched whenever the call fails *or* yields no usable lines, so
    a source is never stripped of its closets.

    Args:
        palace_path: Path of the palace to operate on.
        wing: When set, only drawers whose ``wing`` metadata equals it are used.
        sample: When > 0, only the first ``sample`` sources are processed.
        dry_run: List the work without calling the LLM or writing anything.
        cfg: Connection settings; built from the environment when omitted.

    Returns:
        ``{"error": "missing-config", "missing": [...]}`` when required
        settings are absent, ``{"processed": 0}`` for an empty palace,
        otherwise a dict with ``processed``, ``failed``, ``input_tokens``
        and ``output_tokens`` counts.
    """
    if cfg is None:
        cfg = LLMConfig()
    missing = cfg.missing()
    if missing:
        print("Error: missing configuration: " + ", ".join(missing))
        print("Set env vars LLM_ENDPOINT / LLM_MODEL (and optionally LLM_KEY),")
        print("or pass --endpoint / --model / --key on the CLI.")
        return {"error": "missing-config", "missing": missing}

    drawers_col = get_collection(palace_path, create=False)
    closets_col = get_closets_collection(palace_path)
    total = drawers_col.count()
    if total == 0:
        print("No drawers in palace.")
        return {"processed": 0}

    all_data = drawers_col.get(limit=total, include=["documents", "metadatas"])
    by_source = {}
    for doc_id, doc, meta in zip(all_data["ids"], all_data["documents"], all_data["metadatas"]):
        # Guard against a drawer that comes back without metadata or text,
        # which would otherwise crash ``.get`` here or the join below.
        meta = meta or {}
        source = meta.get("source_file", "unknown")
        if wing and meta.get("wing", "") != wing:
            continue
        # The first drawer seen for a source supplies the metadata used later.
        group = by_source.setdefault(source, {"drawer_ids": [], "content": [], "meta": meta})
        group["drawer_ids"].append(doc_id)
        group["content"].append(doc or "")

    sources = list(by_source.keys())
    if sample > 0:
        sources = sources[:sample]
    print(
        f"Regenerating closets for {len(sources)} source files via {cfg.endpoint} ({cfg.model})..."
    )
    if dry_run:
        print("DRY RUN — no changes will be written")

    processed = 0
    failed = 0
    total_input = 0
    total_output = 0
    for i, source in enumerate(sources, 1):
        data = by_source[source]
        content = "\n\n".join(data["content"])
        meta = data["meta"]
        w = meta.get("wing", "")
        r = meta.get("room", "")
        entities = meta.get("entities", "")
        if dry_run:
            print(f"  [{i}/{len(sources)}] {os.path.basename(source)} ({len(content)} chars)")
            continue

        parsed, usage = _call_llm(cfg, source, w, r, content)
        if not parsed or not isinstance(parsed, dict):
            failed += 1
            print(f"  [{i}/{len(sources)}] ✗ {os.path.basename(source)} — LLM failed")
            continue
        if usage:
            total_input += usage.get("prompt_tokens", 0)
            total_output += usage.get("completion_tokens", 0)

        lines = _parsed_to_closet_lines(parsed, data["drawer_ids"], entities)
        if not lines:
            # Purging now would delete the existing closets and write nothing
            # in their place. Keep them as the fallback and count a failure.
            failed += 1
            print(f"  [{i}/{len(sources)}] ✗ {os.path.basename(source)} — LLM failed")
            continue

        # Use os.path.basename so Windows-style paths survive unchanged;
        # the naive split('/') would leave a bare path component on Windows
        # and collide across different files under different drives.
        closet_id_base = f"closet_{w}_{r}_{os.path.basename(source)[:30]}"
        # Serialize with concurrent mine operations on the same source —
        # otherwise a regex closet rebuild mid-regenerate races with our
        # purge+upsert cycle and leaves mixed regex/LLM lines.
        with mine_lock(source):
            purge_file_closets(closets_col, source)
            upsert_closet_lines(
                closets_col,
                closet_id_base,
                lines,
                {
                    "wing": w,
                    "room": r,
                    "source_file": source,
                    "generated_by": f"llm:{cfg.model}",
                    "filed_at": datetime.now().isoformat(),
                    "entities": entities,
                    # Stamp so the miner's stale-drawer gate doesn't treat
                    # LLM closets as leftovers and rebuild over them next run.
                    "normalize_version": NORMALIZE_VERSION,
                },
            )
        processed += 1
        topics = parsed.get("topics")
        n_topics = len(topics) if isinstance(topics, list) else 0
        # The separator keeps the file name and the count from running together.
        print(f"  [{i}/{len(sources)}] ✓ {os.path.basename(source)} — {n_topics} topics")

    print(f"\nDone. {processed} regenerated, {failed} failed.")
    if total_input or total_output:
        print(f"Tokens: {total_input:,} in + {total_output:,} out (cost depends on provider)")
    return {
        "processed": processed,
        "failed": failed,
        "input_tokens": total_input,
        "output_tokens": total_output,
    }
if __name__ == "__main__":
    import argparse

    # CLI entry point: flags override the LLM_* environment variables.
    default_palace = os.path.expanduser("~/.mempalace/palace")
    cli = argparse.ArgumentParser(
        description="Regenerate closets via a user-configured LLM (OpenAI-compatible API)"
    )
    cli.add_argument("--palace", default=default_palace, help="Path to the palace")
    cli.add_argument("--wing", default=None, help="Limit to one wing")
    cli.add_argument("--sample", type=int, default=0, help="Only process first N source files")
    cli.add_argument("--dry-run", action="store_true", help="List work without calling the LLM")
    cli.add_argument(
        "--endpoint",
        default=None,
        help="LLM base URL (overrides $LLM_ENDPOINT), e.g. http://localhost:11434/v1",
    )
    cli.add_argument(
        "--key",
        default=None,
        help="LLM bearer token (overrides $LLM_KEY). Optional for local inference.",
    )
    cli.add_argument(
        "--model",
        default=None,
        help='LLM model name (overrides $LLM_MODEL), e.g. "gpt-4o-mini" or "llama3:8b"',
    )
    opts = cli.parse_args()

    regenerate_closets(
        opts.palace,
        wing=opts.wing,
        sample=opts.sample,
        dry_run=opts.dry_run,
        cfg=LLMConfig(endpoint=opts.endpoint, key=opts.key, model=opts.model),
    )
+209
View File
@@ -0,0 +1,209 @@
"""
diary_ingest.py — Ingest daily summary files into the palace.
Architecture:
- ONE drawer per (wing, day) — full verbatim content, upserted as the day grows.
- Closets pack topics up to CLOSET_CHAR_LIMIT, never split mid-topic.
- A re-ingest fully purges the prior day's closets before rebuilding so a
shorter day never leaves orphans behind.
- Only new entries are processed by default (tracks entry count in a state
file under ``~/.mempalace/state/`` — never inside the user's diary dir).
- Per-file ``mine_lock`` so concurrent ingest from two terminals can't race.
- Entities extracted and stamped on metadata for filterable search.
Usage:
python -m mempalace.diary_ingest --dir ~/daily_summaries --palace ~/.mempalace/palace
python -m mempalace.diary_ingest --dir ~/daily_summaries --palace ~/.mempalace/palace --force
"""
import hashlib
import json
import os
import re
from datetime import datetime, timezone
from pathlib import Path
from .miner import _extract_entities_for_metadata
from .palace import (
build_closet_lines,
get_closets_collection,
get_collection,
mine_lock,
purge_file_closets,
upsert_closet_lines,
)
# Matches a level-2 markdown heading line ("## ..."). With MULTILINE, "^"
# anchors at the start of every line, so each match starts one diary entry.
DIARY_ENTRY_RE = re.compile(r"^## .+", re.MULTILINE)
def _state_file_for(palace_path: str, diary_dir: Path) -> Path:
"""Return the per-(palace, diary-dir) state-file path under ~/.mempalace/state.
Keyed by sha256 of (palace_path, diary_dir) so multiple diary folders
pointing at the same palace each get an independent state file. The
state file is *never* written inside the user's diary directory.
"""
state_root = Path(os.path.expanduser("~")) / ".mempalace" / "state"
state_root.mkdir(parents=True, exist_ok=True)
key = hashlib.sha256(f"{palace_path}|{diary_dir}".encode()).hexdigest()[:24]
return state_root / f"diary_ingest_{key}.json"
def _split_entries(text):
    """Break diary text into ``(header, body)`` tuples, one per ``## `` heading.

    Both parts are whitespace-stripped. Text before the first heading is
    not returned.
    """
    headers = DIARY_ENTRY_RE.findall(text)
    # split() yields one more chunk than there are headings: chunk 0 is the
    # preamble and chunk k+1 is the body that follows heading k.
    bodies = DIARY_ENTRY_RE.split(text)[1:]
    return [(header.strip(), body.strip()) for header, body in zip(headers, bodies)]
def _diary_drawer_id(wing: str, date_str: str) -> str:
"""Stable, wing-scoped drawer ID. Two diaries (e.g. 'work' vs 'personal')
sharing the same date never collide."""
suffix = hashlib.sha256(f"{wing}|{date_str}".encode()).hexdigest()[:24]
return f"drawer_diary_{suffix}"
def _diary_closet_id_base(wing: str, date_str: str) -> str:
suffix = hashlib.sha256(f"{wing}|{date_str}".encode()).hexdigest()[:24]
return f"closet_diary_{suffix}"
def ingest_diaries(
    diary_dir,
    palace_path,
    wing="diary",
    force=False,
):
    """Ingest daily summary files into the palace.

    Each date file gets ONE drawer keyed by ``(wing, date)`` and closets that
    pack topics atomically up to ``CLOSET_CHAR_LIMIT``. ``force=True`` rebuilds
    every entry's closets from scratch (purging stale ones); the default
    incremental mode only processes entries appended since the last run.

    Args:
        diary_dir: Directory scanned (non-recursively) for ``*.md`` files.
            Only files whose stem starts with ``YYYY-MM-DD`` are ingested.
        palace_path: Palace to write drawers and closets into.
        wing: Wing stamped on the drawer/closet metadata and mixed into IDs.
        force: Ignore the saved state and reprocess every file and entry.

    Returns:
        ``{"days_updated": int, "closets_created": int}``.
    """
    diary_dir = Path(diary_dir).expanduser().resolve()
    if not diary_dir.exists():
        print(f"Diary directory not found: {diary_dir}")
        return {"days_updated": 0, "closets_created": 0}
    diary_files = sorted(diary_dir.glob("*.md"))
    if not diary_files:
        print(f"No .md files in {diary_dir}")
        return {"days_updated": 0, "closets_created": 0}

    state_file = _state_file_for(str(palace_path), diary_dir)
    if force or not state_file.exists():
        state: dict = {}
    else:
        try:
            state = json.loads(state_file.read_text())
        except Exception:
            # An unreadable or corrupt state file is treated as "no state".
            state = {}

    drawers_col = get_collection(palace_path)
    closets_col = get_closets_collection(palace_path)
    days_updated = 0
    closets_created = 0

    for diary_path in diary_files:
        text = diary_path.read_text(encoding="utf-8", errors="replace")
        # Files with fewer than 50 non-whitespace-trimmed chars are ignored.
        if len(text.strip()) < 50:
            continue
        date_match = re.match(r"(\d{4}-\d{2}-\d{2})", diary_path.stem)
        if not date_match:
            continue
        date_str = date_match.group(1)

        # Skip if content hasn't changed
        # NOTE(review): the check compares character counts only, so an edit
        # that leaves the length unchanged is skipped — confirm acceptable.
        state_key = f"{wing}|{diary_path.name}"
        prev_size = state.get(state_key, {}).get("size", 0)
        curr_size = len(text)
        if curr_size == prev_size and not force:
            continue

        now_iso = datetime.now(timezone.utc).isoformat()
        drawer_id = _diary_drawer_id(wing, date_str)
        entities = _extract_entities_for_metadata(text)
        source_file = str(diary_path)

        # Serialize per source — two terminals running ingest at once must
        # not interleave the upsert + closet-rebuild.
        with mine_lock(source_file):
            drawer_meta = {
                "date": date_str,
                "wing": wing,
                "room": "daily",
                "source_file": source_file,
                "source_session": "daily_diary",
                "filed_at": now_iso,
            }
            if entities:
                drawer_meta["entities"] = entities
            # The whole file text is stored as a single drawer document.
            drawers_col.upsert(
                documents=[text],
                ids=[drawer_id],
                metadatas=[drawer_meta],
            )

            entries = _split_entries(text)
            prev_entry_count = state.get(state_key, {}).get("entry_count", 0)
            # Incremental mode assumes entries are only ever appended: the
            # first ``prev_entry_count`` entries are not revisited.
            new_entries = entries if force else entries[prev_entry_count:]
            if new_entries:
                all_lines = []
                for header, body in new_entries:
                    entry_text = f"{header}\n{body}"
                    entry_lines = build_closet_lines(
                        source_file, [drawer_id], entry_text, wing, "daily"
                    )
                    all_lines.extend(entry_lines)
                if all_lines:
                    closet_id_base = _diary_closet_id_base(wing, date_str)
                    closet_meta = {
                        "date": date_str,
                        "wing": wing,
                        "room": "daily",
                        "source_file": source_file,
                        "filed_at": now_iso,
                    }
                    if entities:
                        closet_meta["entities"] = entities
                    # On a force rebuild, wipe any leftover numbered closets
                    # from a longer prior run before re-writing.
                    if force:
                        purge_file_closets(closets_col, source_file)
                    # NOTE(review): in incremental mode only the new entries'
                    # lines are written under the same closet_id_base —
                    # confirm upsert_closet_lines does not overwrite the
                    # closets built from earlier entries.
                    n = upsert_closet_lines(closets_col, closet_id_base, all_lines, closet_meta)
                    closets_created += n

        state[state_key] = {
            "size": curr_size,
            "entry_count": len(entries),
            "ingested_at": now_iso,
        }
        days_updated += 1

    # State is persisted once, after every file has been handled.
    state_file.write_text(json.dumps(state, indent=2))
    if days_updated:
        print(f"Diary: {days_updated} days updated, {closets_created} new closets")
    return {"days_updated": days_updated, "closets_created": closets_created}
if __name__ == "__main__":
    import argparse

    # CLI entry point: --dir is mandatory, everything else has a default.
    cli = argparse.ArgumentParser(description="Ingest daily summaries into the palace")
    cli.add_argument("--dir", required=True, help="Path to daily_summaries directory")
    cli.add_argument("--palace", default=os.path.expanduser("~/.mempalace/palace"))
    cli.add_argument("--wing", default="diary")
    cli.add_argument("--force", action="store_true")
    opts = cli.parse_args()

    ingest_diaries(opts.dir, opts.palace, wing=opts.wing, force=opts.force)
+335
View File
@@ -0,0 +1,335 @@
"""
fact_checker.py — Verify text against known facts in the palace.
Checks AI responses, diary entries, and new content against the entity
registry and knowledge graph for three classes of issue:
* similar_name — text mentions a name that's one/two edits
away from *another* registered name, raising
the possibility of a typo or mix-up.
* relationship_mismatch — text asserts a role between two entities
(e.g. "Bob is Alice's brother") while the KG
records a *different* current role for the
same subject/object pair.
* stale_fact — text asserts a fact that the KG marks closed
(``valid_to`` in the past).
Purely offline. Inputs: entity_registry JSON + KG SQLite. No network.
Usage:
from mempalace.fact_checker import check_text
issues = check_text("Bob is Alice's brother", palace_path)
# CLI
python -m mempalace.fact_checker "Bob is Alice's brother" \\
--palace ~/.mempalace/palace
"""
from __future__ import annotations
import os
import re
from datetime import datetime, timezone
# Share miner's mtime-cached registry loader so we don't double-read
# ~/.mempalace/known_entities.json on every check_text call.
from .miner import _load_known_entities_raw
# Narrow detection patterns — parse "X is Y's Z" and "X's Z is Y".
# Names are captured greedily as word sequences (letters + optional
# capitalized follow-ons) so simple multi-token names still work.
# Relationship words are constrained to sane lengths to avoid matching
# arbitrary filler.
# Order matters: _extract_claims tells the two surface forms apart by
# checking whether a pattern is the first element of this list.
_RELATIONSHIP_PATTERNS = [
    # "Bob is Alice's brother" → subject=Bob, possessor=Alice, role=brother
    re.compile(r"\b([A-Z][\w-]+)\s+is\s+([A-Z][\w-]+)'s\s+([a-z]{3,20})\b"),
    # "Alice's brother is Bob" → possessor=Alice, role=brother, subject=Bob
    re.compile(r"\b([A-Z][\w-]+)'s\s+([a-z]{3,20})\s+is\s+([A-Z][\w-]+)\b"),
]
def check_text(text: str, palace_path: str = None, config=None) -> list:
    """Return a list of issues detected in ``text``.

    Empty list means "no contradictions found" — absence of evidence, not
    evidence of absence. The detector is deliberately conservative:
    every issue is anchored to a specific KG fact or registry entry.

    Args:
        text: Text to check. Empty or ``None`` yields an empty list.
        palace_path: Palace directory holding the knowledge graph. When
            omitted it is read from ``config.palace_path``.
        config: Config object consulted only when ``palace_path`` is not
            given; a ``MempalaceConfig`` is created if this is also omitted.

    Returns:
        List of issue dicts from the entity-confusion and KG checks.
    """
    # Nothing to check — return before touching config or the registry.
    if not text:
        return []

    # The config is only needed to discover the palace location, so it is
    # neither imported nor constructed when the caller already supplied one.
    if palace_path is None:
        if config is None:
            from .config import MempalaceConfig

            config = MempalaceConfig()
        palace_path = config.palace_path

    issues: list = []
    entity_names_raw = _load_known_entities_raw()
    issues.extend(_check_entity_confusion(text, entity_names_raw))
    issues.extend(_check_kg_contradictions(text, palace_path))
    return issues
# ── entity-name confusion ────────────────────────────────────────────
def _flatten_names(entity_names_raw: dict) -> set:
"""Flatten a ``{category: [names]}`` or ``{category: {name: meta}}``
registry into a set of names."""
flat: set = set()
for cat in entity_names_raw.values():
if isinstance(cat, list):
flat.update(str(n) for n in cat if n)
elif isinstance(cat, dict):
flat.update(str(k) for k in cat.keys() if k)
return flat
def _check_entity_confusion(text: str, entity_names_raw: dict) -> list:
    """Flag names mentioned in the text that are edit-distance ≤ 2 from
    a *different* registered name — a common typo / mix-up pattern.

    Only names that actually occur in ``text`` (whole-word, case-insensitive)
    are compared against the registry, so the cost is O(m × n) where m is
    the number of mentioned names and n the registry size.

    Names are processed in sorted order so the returned issues come out in
    the same order on every run, independent of set iteration order.

    Args:
        text: Text to scan for registry names.
        entity_names_raw: Registry in ``{category: [names]}`` or
            ``{category: {name: meta}}`` form.

    Returns:
        List of ``similar_name`` issue dicts, one per unordered name pair.
    """
    all_names = sorted(_flatten_names(entity_names_raw))
    if not all_names:
        return []

    # Which names from the registry actually appear in the text?
    mentioned = [
        name
        for name in all_names
        if re.search(r"\b" + re.escape(name) + r"\b", text, re.IGNORECASE)
    ]
    if not mentioned:
        return []
    mentioned_set = set(mentioned)  # O(1) membership in the inner loop

    issues: list = []
    seen_pairs: set = set()
    for name_a in mentioned:
        a_lower = name_a.lower()
        for name_b in all_names:
            if name_b == name_a:
                continue
            b_lower = name_b.lower()
            # Dedupe by unordered pair so we don't double-report.
            pair_key = tuple(sorted((a_lower, b_lower)))
            if pair_key in seen_pairs:
                continue
            # Only flag when name_b is a *different* registry entry that
            # was NOT mentioned — otherwise both names in the text is
            # just the user writing about two people.
            if name_b in mentioned_set:
                seen_pairs.add(pair_key)
                continue
            distance = _edit_distance(a_lower, b_lower)
            if 0 < distance <= 2:
                issues.append(
                    {
                        "type": "similar_name",
                        "detail": (
                            f"'{name_a}' mentioned — did you mean "
                            f"'{name_b}'? (edit distance {distance})"
                        ),
                        "names": [name_a, name_b],
                        "distance": distance,
                    }
                )
                seen_pairs.add(pair_key)
    return issues
# ── KG contradictions ────────────────────────────────────────────────
def _extract_claims(text: str) -> list:
    """Return the relationship claims found in ``text`` as a list of dicts.

    Two surface forms are recognised. "X is Y's Z" produces subject ``X``,
    predicate ``Z``, object ``Y``; "Y's Z is X" produces the same triple.
    Each dict carries ``subject``, ``predicate`` (lower-cased), ``object``
    and ``span`` (the matched text). Entity names keep their original case.
    All matches of the first form come before matches of the second.
    """
    claims: list = []
    for index, pattern in enumerate(_RELATIONSHIP_PATTERNS):
        for hit in pattern.finditer(text):
            first, second, third = hit.groups()
            if index == 0:
                # "<subject> is <possessor>'s <role>"
                subject, possessor, role = first, second, third
            else:
                # "<possessor>'s <role> is <subject>"
                possessor, role, subject = first, second, third
            claim = {
                "subject": subject,
                "predicate": role.lower(),
                "object": possessor,
                "span": hit.group(0),
            }
            claims.append(claim)
    return claims
def _check_kg_contradictions(text: str, palace_path: str) -> list:
    """Compare each claim in ``text`` against the KG.

    For every claim ``(subject, predicate, object)`` parsed from the
    text, look up the subject's outgoing KG triples:

    * ``relationship_mismatch`` fires when the KG records a *current* fact
      about the same ``(subject, object)`` pair but with a *different*
      predicate — e.g. text says "brother" but KG says "husband".
    * ``stale_fact`` fires when the KG has the exact ``(subject,
      predicate, object)`` triple, it is not current, and its ``valid_to``
      is before today's UTC date.

    Args:
        text: Text to extract claims from.
        palace_path: Palace directory; the KG database is expected at
            ``<palace_path>/knowledge_graph.sqlite3``.

    Returns:
        List of issue dicts. Empty when there are no claims, the KG cannot
        be opened, or nothing contradicts.
    """
    claims = _extract_claims(text)
    if not claims:
        return []
    try:
        from .knowledge_graph import KnowledgeGraph

        # KG lives alongside the palace collection and must be opened via
        # ``db_path``; any other keyword raises and is swallowed below.
        kg = KnowledgeGraph(db_path=os.path.join(palace_path, "knowledge_graph.sqlite3"))
    except Exception:
        # KG unavailable (brand-new palace, corrupted DB, etc.) — skip.
        return []

    # Computed once so every claim in this call is judged against the same
    # date instead of re-reading the clock per claim.
    today_iso = datetime.now(timezone.utc).date().isoformat()

    issues: list = []
    for claim in claims:
        subject = claim["subject"]
        claim_pred = claim["predicate"]
        claim_obj = claim["object"]
        try:
            facts = kg.query_entity(subject, direction="outgoing")
        except Exception:
            # A failed lookup for one subject must not abort the others.
            continue
        if not facts:
            continue

        # Mismatch: KG fact about same (subject, object) pair but different predicate.
        for fact in facts:
            if not fact.get("current"):
                continue
            if not _objects_match(fact.get("object"), claim_obj):
                continue
            kg_pred = (fact.get("predicate") or "").lower()
            if kg_pred and kg_pred != claim_pred:
                issues.append(
                    {
                        "type": "relationship_mismatch",
                        "detail": (
                            f"Text says '{claim['span']}' but KG records "
                            f"{subject} {kg_pred} {fact.get('object')}"
                        ),
                        "entity": subject,
                        "claim": {
                            "predicate": claim_pred,
                            "object": claim_obj,
                        },
                        "kg_fact": {
                            "predicate": kg_pred,
                            "object": fact.get("object"),
                        },
                    }
                )

        # Stale fact: exact match on (subject, predicate, object) but KG
        # closed the window in the past.
        for fact in facts:
            if fact.get("current"):
                continue
            kg_pred = (fact.get("predicate") or "").lower()
            if kg_pred != claim_pred:
                continue
            if not _objects_match(fact.get("object"), claim_obj):
                continue
            valid_to = fact.get("valid_to")
            # String comparison against an ISO date.
            # NOTE(review): assumes valid_to is stored ISO-formatted — confirm.
            if valid_to and str(valid_to) < today_iso:
                issues.append(
                    {
                        "type": "stale_fact",
                        "detail": (
                            f"Text says '{claim['span']}' but KG marks "
                            f"this fact closed on {valid_to}"
                        ),
                        "entity": subject,
                        "valid_to": valid_to,
                    }
                )
    return issues
def _objects_match(kg_obj, claim_obj: str) -> bool:
if kg_obj is None or not claim_obj:
return False
return str(kg_obj).strip().lower() == claim_obj.strip().lower()
# ── Levenshtein helper (tight iterative version) ─────────────────────
def _edit_distance(s1: str, s2: str) -> int:
"""Levenshtein distance. O(len(s1) * len(s2)) time, O(len(s2)) space."""
if len(s1) < len(s2):
s1, s2 = s2, s1
if not s2:
return len(s1)
prev = list(range(len(s2) + 1))
for i, c1 in enumerate(s1):
curr = [i + 1]
for j, c2 in enumerate(s2):
curr.append(
min(
prev[j + 1] + 1,
curr[j] + 1,
prev[j] + (0 if c1 == c2 else 1),
)
)
prev = curr
return prev[-1]
if __name__ == "__main__":
    import argparse
    import json
    import sys

    # CLI entry point: text comes from the positional argument or stdin.
    cli = argparse.ArgumentParser(
        description="Check text against known facts in the MemPalace palace.",
        epilog="Exits 0 when no issues found, 1 when one or more issues detected.",
    )
    cli.add_argument("text", nargs="?", help="Text to check (or use --stdin).")
    cli.add_argument(
        "--palace",
        default=os.path.expanduser("~/.mempalace/palace"),
        help="Path to the palace directory.",
    )
    cli.add_argument("--stdin", action="store_true", help="Read text from stdin.")
    opts = cli.parse_args()

    if opts.stdin:
        text_in = sys.stdin.read()
    elif opts.text:
        text_in = opts.text
    else:
        # parser.error() prints usage and exits, so text_in is always bound below.
        cli.error("Provide text as argument or use --stdin.")

    found = check_text(text_in, palace_path=opts.palace)
    if not found:
        print("No contradictions found.")
    else:
        print(json.dumps(found, indent=2))
        sys.exit(1)
+128 -1
View File
@@ -35,7 +35,15 @@ from .version import __version__
import chromadb import chromadb
from .query_sanitizer import sanitize_query from .query_sanitizer import sanitize_query
from .searcher import search_memories from .searcher import search_memories
from .palace_graph import traverse, find_tunnels, graph_stats from .palace_graph import (
traverse,
find_tunnels,
graph_stats,
create_tunnel,
list_tunnels,
delete_tunnel,
follow_tunnels,
)
from .knowledge_graph import KnowledgeGraph from .knowledge_graph import KnowledgeGraph
@@ -496,6 +504,66 @@ def tool_graph_stats():
return graph_stats(col=col) return graph_stats(col=col)
def tool_create_tunnel(
    source_wing: str,
    source_room: str,
    target_wing: str,
    target_room: str,
    label: str = "",
    source_drawer_id: str = None,
    target_drawer_id: str = None,
):
    """Create an explicit cross-wing tunnel between two palace locations.

    Use when you notice content in one project relates to another project.
    Example: an API design discussion in project_api connects to the
    database schema in project_database.

    The four wing/room names are sanitized first; if any is rejected the
    result is ``{"error": <message>}`` and no tunnel is created. Otherwise
    the return value of ``create_tunnel`` is passed through.
    """
    try:
        endpoints = (
            sanitize_name(source_wing, "source_wing"),
            sanitize_name(source_room, "source_room"),
            sanitize_name(target_wing, "target_wing"),
            sanitize_name(target_room, "target_room"),
        )
    except ValueError as exc:
        return {"error": str(exc)}

    return create_tunnel(
        *endpoints,
        label=label,
        source_drawer_id=source_drawer_id,
        target_drawer_id=target_drawer_id,
    )
def tool_list_tunnels(wing: str = None):
    """List all explicit cross-wing tunnels, optionally filtered by wing.

    Returns ``{"error": <message>}`` when the wing name fails sanitization,
    otherwise whatever ``list_tunnels`` returns.
    """
    try:
        wing_filter = _sanitize_optional_name(wing, "wing")
    except ValueError as exc:
        return {"error": str(exc)}
    else:
        return list_tunnels(wing_filter)
def tool_delete_tunnel(tunnel_id: str):
    """Delete an explicit tunnel by its ID.

    A missing, empty or non-string ID yields ``{"error": "tunnel_id is
    required"}``; otherwise the result of ``delete_tunnel`` is returned.
    """
    if tunnel_id and isinstance(tunnel_id, str):
        return delete_tunnel(tunnel_id)
    return {"error": "tunnel_id is required"}
def tool_follow_tunnels(wing: str, room: str):
    """Follow explicit tunnels from a room to see connected drawers in other wings.

    Returns ``{"error": <message>}`` when either name fails sanitization,
    otherwise whatever ``follow_tunnels`` returns for the sanitized names.
    """
    try:
        clean_wing = sanitize_name(wing, "wing")
        clean_room = sanitize_name(room, "room")
    except ValueError as exc:
        return {"error": str(exc)}

    # The collection is only fetched once the names are known to be valid.
    return follow_tunnels(clean_wing, clean_room, col=_get_collection())
# ==================== WRITE TOOLS ==================== # ==================== WRITE TOOLS ====================
@@ -1184,6 +1252,65 @@ TOOLS = {
"input_schema": {"type": "object", "properties": {}}, "input_schema": {"type": "object", "properties": {}},
"handler": tool_graph_stats, "handler": tool_graph_stats,
}, },
"mempalace_create_tunnel": {
"description": "Create a cross-wing tunnel linking two palace locations. Use when content in one project relates to another — e.g., an API design in project_api connects to a database schema in project_database.",
"input_schema": {
"type": "object",
"properties": {
"source_wing": {"type": "string", "description": "Wing of the source"},
"source_room": {"type": "string", "description": "Room in the source wing"},
"target_wing": {"type": "string", "description": "Wing of the target"},
"target_room": {"type": "string", "description": "Room in the target wing"},
"label": {"type": "string", "description": "Description of the connection"},
"source_drawer_id": {
"type": "string",
"description": "Optional specific drawer ID",
},
"target_drawer_id": {
"type": "string",
"description": "Optional specific drawer ID",
},
},
"required": ["source_wing", "source_room", "target_wing", "target_room"],
},
"handler": tool_create_tunnel,
},
"mempalace_list_tunnels": {
"description": "List all explicit cross-wing tunnels. Optionally filter by wing.",
"input_schema": {
"type": "object",
"properties": {
"wing": {
"type": "string",
"description": "Filter tunnels by wing (shows tunnels where wing is source or target)",
},
},
},
"handler": tool_list_tunnels,
},
"mempalace_delete_tunnel": {
"description": "Delete an explicit tunnel by its ID.",
"input_schema": {
"type": "object",
"properties": {
"tunnel_id": {"type": "string", "description": "Tunnel ID to delete"},
},
"required": ["tunnel_id"],
},
"handler": tool_delete_tunnel,
},
"mempalace_follow_tunnels": {
"description": "Follow tunnels from a room to see what it connects to in other wings. Returns connected rooms with drawer previews.",
"input_schema": {
"type": "object",
"properties": {
"wing": {"type": "string", "description": "Wing to start from"},
"room": {"type": "string", "description": "Room to follow tunnels from"},
},
"required": ["wing", "room"],
},
"handler": tool_follow_tunnels,
},
"mempalace_search": { "mempalace_search": {
"description": "Semantic search. Returns verbatim drawer content with similarity scores. IMPORTANT: 'query' must contain ONLY search keywords. Use 'context' for background. Results with cosine distance > max_distance are filtered out.", "description": "Semantic search. Returns verbatim drawer content with similarity scores. IMPORTANT: 'query' must contain ONLY search keywords. Use 'context' for background. Results with cosine distance > max_distance are filtered out.",
"input_schema": { "input_schema": {
+126 -13
View File
@@ -378,6 +378,116 @@ def chunk_text(content: str, source_file: str) -> list:
# ============================================================================= # =============================================================================
# Location of the user's known-entities registry file.
_ENTITY_REGISTRY_PATH = os.path.join(os.path.expanduser("~"), ".mempalace", "known_entities.json")
# Module-level cache filled by _refresh_known_entities_cache: the file's
# mtime at the last load, the flattened name set, and the parsed dict.
_ENTITY_REGISTRY_CACHE: dict = {"mtime": None, "names": frozenset(), "raw": {}}
_ENTITY_EXTRACT_WINDOW = 5000  # chars of content scanned for capitalized words
_ENTITY_METADATA_LIMIT = 25  # max entities packed into the metadata field
def _refresh_known_entities_cache() -> None:
    """Bring the module-level entity cache in line with the registry file.

    The registry is re-read only when its mtime differs from the cached one.
    If the file cannot be stat'ed, a previously populated cache is emptied.
    If the file cannot be read or parsed, the new mtime is still recorded
    together with empty data, so a broken file is not re-parsed on every call.

    Names are collected from every top-level category: list categories
    contribute their truthy items, dict categories contribute their truthy
    keys; anything else is ignored. All names are stored as ``str``.
    """
    cache = _ENTITY_REGISTRY_CACHE
    try:
        current_mtime = os.path.getmtime(_ENTITY_REGISTRY_PATH)
    except OSError:
        # Registry missing/unreadable: drop whatever an earlier read cached.
        if cache["mtime"] is not None:
            cache.update(mtime=None, names=frozenset(), raw={})
        return

    if cache["mtime"] == current_mtime:
        return  # file unchanged since the last read

    collected: set = set()
    parsed: dict = {}
    try:
        import json

        with open(_ENTITY_REGISTRY_PATH, "r", encoding="utf-8") as handle:
            payload = json.load(handle)
        if isinstance(payload, dict):
            parsed = payload
            for category in payload.values():
                # Iterating a dict yields its keys, so one branch covers
                # both the list-of-names and the name->details shapes.
                if isinstance(category, (list, dict)):
                    collected.update(str(entry) for entry in category if entry)
    except Exception:
        # Best-effort: an unreadable or malformed registry means "no entities".
        collected, parsed = set(), {}

    cache.update(mtime=current_mtime, names=frozenset(collected), raw=parsed)
def _load_known_entities() -> frozenset:
    """Return every known entity name, flattened across all categories.

    The cache is refreshed first, so the result reflects the registry file
    as of its current mtime.
    """
    _refresh_known_entities_cache()
    flat_names = _ENTITY_REGISTRY_CACHE["names"]
    return flat_names
def _load_known_entities_raw() -> dict:
    """Return the registry as parsed from disk, keyed by category.

    The expected shape is ``{"category": ["Name1", ...], ...}``. Intended for
    consumers (e.g., fact_checker) that reason about categories rather than
    a flat name set. The cache is refreshed (mtime-gated) before reading.

    The returned dict is a *shallow* copy: adding or removing top-level keys
    does not affect the cache, but the nested category values are the same
    objects the cache holds, so they must not be mutated in place.
    """
    _refresh_known_entities_cache()
    cached_raw = _ENTITY_REGISTRY_CACHE["raw"]
    return cached_raw.copy()
def _extract_entities_for_metadata(content: str) -> str:
    """Build the semicolon-separated ``entities`` metadata value for a text.

    Two sources are merged:

    * names from the known-entity registry that occur anywhere in
      ``content`` as a whole token (not embedded in a longer word);
    * capitalized words (one uppercase letter followed by at least two
      lowercase letters) that appear two or more times within the first
      ``_ENTITY_EXTRACT_WINDOW`` characters and are not in the closet
      stoplist.

    Returns ``""`` when nothing matched. Otherwise the names are sorted and
    the *list* is capped at ``_ENTITY_METADATA_LIMIT`` entries before being
    joined, so no name is ever cut in half.
    """
    import re

    from .palace import _ENTITY_STOPLIST

    # Registry names: searched over the full content, not just the window.
    found = {
        name
        for name in _load_known_entities()
        if re.search(r"(?<!\w)" + re.escape(name) + r"(?!\w)", content)
    }

    # Repeated capitalized words within the leading window. The pattern
    # itself guarantees at least three characters per word.
    counts: dict = {}
    for word in re.findall(r"\b[A-Z][a-z]{2,}\b", content[:_ENTITY_EXTRACT_WINDOW]):
        if word not in _ENTITY_STOPLIST:
            counts[word] = counts.get(word, 0) + 1
    found.update(word for word, seen in counts.items() if seen >= 2)

    if not found:
        return ""
    return ";".join(sorted(found)[:_ENTITY_METADATA_LIMIT])
def add_drawer( def add_drawer(
collection, wing: str, room: str, content: str, source_file: str, chunk_index: int, agent: str collection, wing: str, room: str, content: str, source_file: str, chunk_index: int, agent: str
): ):
@@ -398,6 +508,10 @@ def add_drawer(
metadata["source_mtime"] = os.path.getmtime(source_file) metadata["source_mtime"] = os.path.getmtime(source_file)
except OSError: except OSError:
pass pass
# Tag with entity names for filterable search
entities = _extract_entities_for_metadata(content)
if entities:
metadata["entities"] = entities
collection.upsert( collection.upsert(
documents=[content], documents=[content],
ids=[drawer_id], ids=[drawer_id],
@@ -490,20 +604,19 @@ def process_file(
closet_id_base = ( closet_id_base = (
f"closet_{wing}_{room}_{hashlib.sha256(source_file.encode()).hexdigest()[:24]}" f"closet_{wing}_{room}_{hashlib.sha256(source_file.encode()).hexdigest()[:24]}"
) )
entities = _extract_entities_for_metadata(content)
closet_meta = {
"wing": wing,
"room": room,
"source_file": source_file,
"drawer_count": drawers_added,
"filed_at": datetime.now().isoformat(),
"normalize_version": NORMALIZE_VERSION,
}
if entities:
closet_meta["entities"] = entities
purge_file_closets(closets_col, source_file) purge_file_closets(closets_col, source_file)
upsert_closet_lines( upsert_closet_lines(closets_col, closet_id_base, closet_lines, closet_meta)
closets_col,
closet_id_base,
closet_lines,
{
"wing": wing,
"room": room,
"source_file": source_file,
"drawer_count": drawers_added,
"filed_at": datetime.now().isoformat(),
"normalize_version": NORMALIZE_VERSION,
},
)
return drawers_added, room return drawers_added, room
+230 -1
View File
@@ -15,10 +15,15 @@ Enables queries like:
No external graph DB needed — built from ChromaDB metadata. No external graph DB needed — built from ChromaDB metadata.
""" """
from collections import defaultdict, Counter import hashlib
import json
import os
from collections import Counter, defaultdict
from datetime import datetime, timezone
from .config import MempalaceConfig from .config import MempalaceConfig
from .palace import get_collection as _get_palace_collection from .palace import get_collection as _get_palace_collection
from .palace import mine_lock
def _get_collection(config=None): def _get_collection(config=None):
@@ -228,3 +233,227 @@ def _fuzzy_match(query: str, nodes: dict, n: int = 5):
scored.append((room, 0.5)) scored.append((room, 0.5))
scored.sort(key=lambda x: -x[1]) scored.sort(key=lambda x: -x[1])
return [r for r, _ in scored[:n]] return [r for r, _ in scored[:n]]
# =============================================================================
# EXPLICIT TUNNELS — agent-created cross-wing links
# =============================================================================
# Passive tunnels are discovered from shared room names across wings.
# Explicit tunnels are created by agents when they notice a connection
# between two specific drawers or rooms in different wings/projects.
#
# Stored as a JSON file at ~/.mempalace/tunnels.json so they persist
# across palace rebuilds (not in ChromaDB which can be recreated).
_TUNNEL_FILE = os.path.join(os.path.expanduser("~"), ".mempalace", "tunnels.json")


def _load_tunnels():
    """Read the explicit-tunnel list from ``_TUNNEL_FILE``.

    Returns ``[]`` when the file is missing, cannot be read or parsed, or
    holds JSON whose top-level value is not a list. Entries inside the list
    are returned as stored, without validation.
    """
    try:
        with open(_TUNNEL_FILE, "r", encoding="utf-8") as handle:
            stored = json.load(handle)
    except Exception:
        # A missing file lands here too, so no separate existence check.
        return []
    if isinstance(stored, list):
        return stored
    return []
def _save_tunnels(tunnels):
    """Persist explicit tunnels atomically.

    The list is serialized to ``tunnels.json.tmp`` and then moved over
    ``tunnels.json`` with ``os.replace``, so a crash mid-write can never
    leave a partial/empty tunnels.json that silently wipes every tunnel on
    the next read.

    If serialization or the final rename fails, the temporary file is
    removed before the error propagates, so a failed save does not leave a
    stale ``.tmp`` file next to the real one.

    Args:
        tunnels: JSON-serializable list of tunnel dicts.

    Raises:
        TypeError / ValueError: if ``tunnels`` is not JSON-serializable.
        OSError: if the directory or file cannot be written.
    """
    os.makedirs(os.path.dirname(_TUNNEL_FILE), exist_ok=True)
    tmp_path = _TUNNEL_FILE + ".tmp"
    try:
        with open(tmp_path, "w", encoding="utf-8") as f:
            json.dump(tunnels, f, indent=2)
            f.flush()
            try:
                os.fsync(f.fileno())
            except OSError:
                # Not all filesystems (or Windows file handles) support fsync — tolerate.
                pass
        # Rename only after the handle is closed.
        os.replace(tmp_path, _TUNNEL_FILE)
    except BaseException:
        # Failed save: discard the partial temp file, keep the old
        # tunnels.json untouched, and let the caller see the error.
        try:
            os.unlink(tmp_path)
        except OSError:
            pass
        raise
def _endpoint_key(wing: str, room: str) -> str:
    """Return the ``"<wing>/<room>"`` key identifying one tunnel endpoint."""
    return "{}/{}".format(wing, room)
def _canonical_tunnel_id(
    source_wing: str, source_room: str, target_wing: str, target_room: str
) -> str:
    """Return an order-independent ID for the tunnel between two endpoints.

    Tunnels are undirected, so the two endpoint keys are sorted before they
    are hashed: swapping source and target yields the same ID, which is
    what lets ``create_tunnel(A, B)`` and ``create_tunnel(B, A)`` dedup into
    a single record. The ID is the first 16 hex characters of the SHA-256
    of the two sorted keys.
    """
    endpoints = sorted(
        (
            _endpoint_key(source_wing, source_room),
            _endpoint_key(target_wing, target_room),
        )
    )
    # NOTE(review): the two keys are concatenated with no separator, so
    # distinct pairs can produce the same input (e.g. "a/bx"+"y/c" and
    # "a/b"+"xy/c"). Adding a separator would change every stored ID —
    # confirm a migration story before touching this.
    digest = hashlib.sha256("".join(endpoints).encode())
    return digest.hexdigest()[:16]
def _require_name(value: str, field: str) -> str:
    """Validate an endpoint identifier and return it with whitespace stripped.

    Args:
        value: Candidate wing or room name.
        field: Parameter name to mention in the error message.

    Raises:
        ValueError: if ``value`` is not a string or is empty/whitespace-only.
    """
    if isinstance(value, str):
        cleaned = value.strip()
        if cleaned:
            return cleaned
    raise ValueError(f"{field} must be a non-empty string")
def create_tunnel(
    source_wing: str,
    source_room: str,
    target_wing: str,
    target_room: str,
    label: str = "",
    source_drawer_id: str = None,
    target_drawer_id: str = None,
):
    """Create an explicit (symmetric) tunnel between two locations in the palace.

    Tunnels are undirected: ``create_tunnel(A, B)`` and ``create_tunnel(B, A)``
    resolve to the same canonical ID. A second call with the same endpoints
    updates the stored record rather than creating a duplicate: the label is
    replaced, ``created_at`` is preserved, ``updated_at`` is set, and a
    drawer ID is replaced only when a new one is provided — an endpoint
    whose drawer ID is omitted keeps the one already stored for it.

    The ``source`` / ``target`` fields on the returned dict preserve the
    argument order the caller used, so callers can display it directionally
    if they like. The ID and dedup are symmetric.

    Args:
        source_wing: Wing of the source (e.g., "project_api").
        source_room: Room in the source wing.
        target_wing: Wing of the target (e.g., "project_database").
        target_room: Room in the target wing.
        label: Description of the connection.
        source_drawer_id: Optional specific drawer ID.
        target_drawer_id: Optional specific drawer ID.

    Returns:
        The stored tunnel dict.

    Raises:
        ValueError: if any wing or room is empty or non-string.
    """
    source_wing = _require_name(source_wing, "source_wing")
    source_room = _require_name(source_room, "source_room")
    target_wing = _require_name(target_wing, "target_wing")
    target_room = _require_name(target_room, "target_room")

    tunnel_id = _canonical_tunnel_id(source_wing, source_room, target_wing, target_room)
    tunnel = {
        "id": tunnel_id,
        "source": {"wing": source_wing, "room": source_room},
        "target": {"wing": target_wing, "room": target_room},
        "label": label,
        "created_at": datetime.now(timezone.utc).isoformat(),
    }
    if source_drawer_id:
        tunnel["source"]["drawer_id"] = source_drawer_id
    if target_drawer_id:
        tunnel["target"]["drawer_id"] = target_drawer_id

    # Serialize the load → mutate → save cycle. Without this, two concurrent
    # create_tunnel calls can both read the same snapshot and the later
    # writer silently drops the earlier writer's tunnel.
    with mine_lock(_TUNNEL_FILE):
        tunnels = _load_tunnels()
        for existing in tunnels:
            if existing.get("id") != tunnel_id:
                continue
            # Preserve original creation timestamp on label updates.
            tunnel["created_at"] = existing.get("created_at", tunnel["created_at"])
            tunnel["updated_at"] = datetime.now(timezone.utc).isoformat()
            # Must run before clear(): a label-only update would otherwise
            # wipe drawer links recorded by an earlier call.
            _carry_over_drawer_ids(tunnel, existing)
            existing.clear()
            existing.update(tunnel)
            _save_tunnels(tunnels)
            return existing

        tunnels.append(tunnel)
        _save_tunnels(tunnels)
    return tunnel


def _carry_over_drawer_ids(tunnel: dict, existing: dict) -> None:
    """Copy stored drawer IDs onto endpoints of ``tunnel`` that lack one.

    The stored record may have the endpoints in the opposite order, so each
    new endpoint is paired with the stored endpoint that has the same
    wing/room. A drawer ID is copied only when the wing and room match
    exactly and the new endpoint did not receive its own drawer ID.
    """
    new_src, new_tgt = tunnel["source"], tunnel["target"]
    old_src = existing.get("source")
    old_tgt = existing.get("target")
    # Tolerate hand-edited records whose endpoints are not dicts.
    old_src = old_src if isinstance(old_src, dict) else {}
    old_tgt = old_tgt if isinstance(old_tgt, dict) else {}

    def _place(endpoint: dict) -> tuple:
        return (endpoint.get("wing"), endpoint.get("room"))

    if _place(old_src) != _place(new_src):
        # Stored record was created with the arguments the other way round.
        old_src, old_tgt = old_tgt, old_src

    for new_end, old_end in ((new_src, old_src), (new_tgt, old_tgt)):
        if "drawer_id" in new_end or _place(old_end) != _place(new_end):
            continue
        previous = old_end.get("drawer_id")
        if previous:
            new_end["drawer_id"] = previous
def list_tunnels(wing: str = None):
    """Return the stored explicit tunnels, optionally restricted to one wing.

    Args:
        wing: When truthy, keep only tunnels whose source *or* target wing
            equals it (tunnels are symmetric, so either endpoint counts).
            When falsy, every stored tunnel is returned.

    Returns:
        A list of tunnel dicts.
    """
    stored = _load_tunnels()
    if not wing:
        return stored

    selected = []
    for record in stored:
        if record["source"]["wing"] == wing or record["target"]["wing"] == wing:
            selected.append(record)
    return selected
def delete_tunnel(tunnel_id: str):
    """Remove the explicit tunnel with the given ID and persist the result.

    The load → filter → save cycle runs under ``mine_lock`` on the tunnel
    file. The list is saved and ``{"deleted": <id>}`` is returned whether or
    not a tunnel with that ID was actually present.
    """
    with mine_lock(_TUNNEL_FILE):
        remaining = []
        for record in _load_tunnels():
            if record.get("id") != tunnel_id:
                remaining.append(record)
        _save_tunnels(remaining)
    return {"deleted": tunnel_id}
def follow_tunnels(wing: str, room: str, col=None, config=None):
    """Follow explicit tunnels from a room — returns the connected locations.

    Every stored tunnel with ``wing``/``room`` as one endpoint yields a
    connection describing the *other* endpoint. A tunnel whose source is the
    given room is reported as ``"outgoing"``; otherwise, if its target is
    the given room, as ``"incoming"``.

    Args:
        wing: Wing of the starting location.
        room: Room of the starting location.
        col: Optional drawer collection. When given, connections that carry
            a drawer ID get a ``drawer_preview`` with the first 300
            characters of that drawer's document.
        config: Unused here; accepted for signature compatibility.

    Returns:
        A list of dicts with ``direction``, ``connected_wing``,
        ``connected_room``, ``label``, ``drawer_id``, ``tunnel_id`` and,
        when available, ``drawer_preview``. Preview lookup is best-effort:
        if the collection call fails, connections are returned without
        previews.
    """
    connections = []
    for t in _load_tunnels():
        src = t["source"]
        tgt = t["target"]
        if src["wing"] == wing and src["room"] == room:
            direction, other = "outgoing", tgt
        elif tgt["wing"] == wing and tgt["room"] == room:
            direction, other = "incoming", src
        else:
            continue
        connections.append(
            {
                "direction": direction,
                "connected_wing": other["wing"],
                "connected_room": other["room"],
                "label": t.get("label", ""),
                "drawer_id": other.get("drawer_id"),
                "tunnel_id": t["id"],
            }
        )

    # If we have a collection, fetch drawer content for connected items.
    if not (col and connections):
        return connections

    # Several tunnels may point at the same drawer. Ask for each ID once:
    # collection implementations that validate ID uniqueness reject a
    # request with duplicates, which would cost every preview.
    drawer_ids = list(
        dict.fromkeys(
            c["drawer_id"]
            for c in connections
            if c.get("drawer_id") and isinstance(c["drawer_id"], str)
        )
    )
    if not drawer_ids:
        return connections

    try:
        results = col.get(ids=drawer_ids, include=["documents", "metadatas"])
        drawer_map = dict(zip(results["ids"], results["documents"]))
    except Exception:
        # Best-effort enrichment: the connections are still useful alone.
        return connections

    for c in connections:
        did = c.get("drawer_id")
        doc = drawer_map.get(did) if isinstance(did, str) else None
        # Skip missing/None documents individually instead of letting one
        # bad entry abort the previews for every other connection.
        if isinstance(doc, str):
            c["drawer_preview"] = doc[:300]
    return connections
+314 -132
View File
@@ -2,11 +2,15 @@
""" """
searcher.py — Find anything. Exact words. searcher.py — Find anything. Exact words.
Semantic search against the palace. Hybrid search: BM25 keyword matching + vector semantic similarity. The
Returns verbatim text — the actual words, never summaries. drawer query is the floor — always runs — and closet hits add a rank-based
boost when they agree. Closets are a ranking *signal*, never a gate, so
weak closets (regex extraction on narrative content) can only help, never
hide drawers the direct path would have found.
""" """
import logging import logging
import math
import re import re
from pathlib import Path from pathlib import Path
@@ -23,6 +27,111 @@ class SearchError(Exception):
"""Raised when search cannot proceed (e.g. no palace found).""" """Raised when search cannot proceed (e.g. no palace found)."""
_TOKEN_RE = re.compile(r"\w{2,}", re.UNICODE)


def _tokenize(text: str) -> list:
    """Lowercase ``text`` and return its word tokens of length ≥ 2.

    A token is a maximal run of ``\\w`` characters (letters, digits and
    underscore), so punctuation and whitespace act as separators and
    single-character runs are dropped.
    """
    lowered = text.lower()
    return _TOKEN_RE.findall(lowered)
def _bm25_scores(
    query: str,
    documents: list,
    k1: float = 1.5,
    b: float = 0.75,
) -> list:
    """Score each document against ``query`` with Okapi-BM25.

    IDF is computed over the *provided* documents using the smoothed form
    ``log((N - df + 0.5) / (df + 0.5) + 1)``, which is never negative. That
    makes it suitable for re-ranking a small candidate set: IDF reflects how
    discriminative each query term is within the candidates themselves.

    Args:
        query: Free-text query; tokenized with ``_tokenize`` and deduplicated.
        documents: Document strings to score.
        k1: Term-frequency saturation (1.2-2.0 typical).
        b: Length normalization (0.0 = none, 1.0 = full).

    Returns:
        One float per document, in input order. All zeros when the query has
        no tokens, there are no documents, or every document is token-free.
    """
    n_docs = len(documents)
    terms = set(_tokenize(query))
    if n_docs == 0 or not terms:
        return [0.0] * n_docs

    token_lists = [_tokenize(doc) for doc in documents]
    lengths = [len(tokens) for tokens in token_lists]
    if not any(lengths):
        return [0.0] * n_docs
    avgdl = sum(lengths) / n_docs or 1.0

    # Document frequency per query term, then the smoothed IDF.
    present_terms = [terms.intersection(tokens) for tokens in token_lists]
    idf = {}
    for term in terms:
        doc_freq = sum(1 for present in present_terms if term in present)
        idf[term] = math.log((n_docs - doc_freq + 0.5) / (doc_freq + 0.5) + 1)

    scores = []
    for tokens, doc_len in zip(token_lists, lengths):
        if doc_len == 0:
            scores.append(0.0)
            continue

        term_freq: dict = {}
        for token in tokens:
            if token in terms:
                term_freq[token] = term_freq.get(token, 0) + 1

        # Length-dependent part of the denominator is constant per document.
        length_norm = k1 * (1 - b + b * doc_len / avgdl)
        total = 0.0
        for term, freq in term_freq.items():
            total += idf[term] * (freq * (k1 + 1)) / (freq + length_norm)
        scores.append(total)
    return scores
def _hybrid_rank(
    results: list,
    query: str,
    vector_weight: float = 0.6,
    bm25_weight: float = 0.4,
) -> list:
    """Reorder ``results`` by a weighted blend of vector similarity and BM25.

    * Vector similarity is the absolute value ``max(0, 1 - distance)`` taken
      from each result's ``distance`` (a missing distance counts as 1.0).
      Being absolute rather than relative to the best hit, it does not shift
      when candidates are added or removed.
    * BM25 is computed over the results' ``text`` fields and then divided by
      the highest BM25 score in the set, putting it in ``[0, 1]`` so the two
      weights are commensurable. If no result scores above zero, the BM25
      part contributes nothing.

    Each result dict gains a ``bm25_score`` key (raw score rounded to three
    decimals). The list is reordered in place, best first, with ties keeping
    their incoming order; the same list object is returned.
    """
    if not results:
        return results

    raw_scores = _bm25_scores(query, [item.get("text", "") for item in results])
    top = max(raw_scores) if raw_scores else 0.0
    if top > 0:
        scaled = [score / top for score in raw_scores]
    else:
        scaled = [0.0] * len(raw_scores)

    blended = []
    for item, raw, norm in zip(results, raw_scores, scaled):
        similarity = max(0.0, 1.0 - item.get("distance", 1.0))
        item["bm25_score"] = round(raw, 3)
        blended.append(vector_weight * similarity + bm25_weight * norm)

    # Stable sort on positions: equal blends keep their original order.
    order = sorted(range(len(blended)), key=lambda pos: blended[pos], reverse=True)
    results[:] = [results[pos] for pos in order]
    return results
def build_where_filter(wing: str = None, room: str = None) -> dict: def build_where_filter(wing: str = None, room: str = None) -> dict:
"""Build ChromaDB where filter for wing/room filtering.""" """Build ChromaDB where filter for wing/room filtering."""
if wing and room: if wing and room:
@@ -48,101 +157,70 @@ def _extract_drawer_ids_from_closet(closet_doc: str) -> list:
return list(seen.keys()) return list(seen.keys())
def _closet_first_hits( def _expand_with_neighbors(drawers_col, matched_doc: str, matched_meta: dict, radius: int = 1):
palace_path: str, """Expand a matched drawer with its ±radius sibling chunks in the same source file.
query: str,
where: dict,
drawers_col,
n_results: int,
max_distance: float,
):
"""Run a closet-first search and return chunk-level drawer hits.
Returns: Motivation — "drawer-grep context" feature: a closet hit returns one
non-empty list of hits when the closet path produced usable matches. drawer, but the chunk boundary may clip mid-thought (e.g., the matched
``None`` when the closet collection is empty/missing OR when every chunk says "here's a breakdown:" and the actual breakdown lives in the
candidate drawer was filtered out (e.g. by max_distance); the next chunk). Fetching the small neighborhood around the match gives
caller should fall back to direct drawer search. callers enough context without forcing a follow-up ``get_drawer`` call.
Returns a dict with:
``text`` combined chunks in chunk_index order
``drawer_index`` the matched chunk's index in the source file
``total_drawers`` total drawer count for the source file (or None)
On any ChromaDB failure or missing metadata, falls back to returning the
matched drawer alone so search never breaks because neighbor expansion
failed.
""" """
src = matched_meta.get("source_file")
chunk_idx = matched_meta.get("chunk_index")
if not src or not isinstance(chunk_idx, int):
return {"text": matched_doc, "drawer_index": chunk_idx, "total_drawers": None}
target_indexes = [chunk_idx + offset for offset in range(-radius, radius + 1)]
try: try:
closets_col = get_closets_collection(palace_path, create=False) neighbors = drawers_col.get(
except Exception: where={
return None "$and": [
{"source_file": src},
try: {"chunk_index": {"$in": target_indexes}},
ckwargs = { ]
"query_texts": [query], },
"n_results": max(n_results * 2, 5),
"include": ["documents", "metadatas", "distances"],
}
if where:
ckwargs["where"] = where
closet_results = closets_col.query(**ckwargs)
except Exception:
return None
closet_docs = closet_results["documents"][0] if closet_results["documents"] else []
if not closet_docs:
return None
closet_metas = closet_results["metadatas"][0]
closet_dists = closet_results["distances"][0]
# Collect candidate drawer IDs in closet-rank order, dedupe, remember
# which closet (and its distance/preview) introduced each one.
drawer_id_order: list = []
drawer_provenance: dict = {}
for cdoc, cmeta, cdist in zip(closet_docs, closet_metas, closet_dists):
for did in _extract_drawer_ids_from_closet(cdoc):
if did in drawer_provenance:
continue
drawer_provenance[did] = (cdist, cdoc, cmeta)
drawer_id_order.append(did)
if not drawer_id_order:
return None
# Hydrate exactly those drawers — chunk-level, not whole-file.
try:
fetched = drawers_col.get(
ids=drawer_id_order,
include=["documents", "metadatas"], include=["documents", "metadatas"],
) )
except Exception: except Exception:
return None return {"text": matched_doc, "drawer_index": chunk_idx, "total_drawers": None}
fetched_ids = fetched.get("ids") or [] indexed_docs = []
fetched_docs = fetched.get("documents") or [] for doc, meta in zip(neighbors.get("documents") or [], neighbors.get("metadatas") or []):
fetched_metas = fetched.get("metadatas") or [] ci = meta.get("chunk_index")
fetched_map = { if isinstance(ci, int):
did: (doc, meta) for did, doc, meta in zip(fetched_ids, fetched_docs, fetched_metas) indexed_docs.append((ci, doc))
indexed_docs.sort(key=lambda pair: pair[0])
if not indexed_docs:
combined_text = matched_doc
else:
combined_text = "\n\n".join(doc for _, doc in indexed_docs)
# Cheap total_drawers lookup: metadata-only scan of the source file.
total_drawers = None
try:
all_meta = drawers_col.get(where={"source_file": src}, include=["metadatas"])
ids = all_meta.get("ids") or []
total_drawers = len(ids) if ids else None
except Exception:
pass
return {
"text": combined_text,
"drawer_index": chunk_idx,
"total_drawers": total_drawers,
} }
hits: list = []
for did in drawer_id_order:
if did not in fetched_map:
continue # closet pointed to a drawer that no longer exists
doc, meta = fetched_map[did]
cdist, cdoc, _ = drawer_provenance[did]
if max_distance > 0.0 and cdist > max_distance:
continue
hits.append(
{
"text": doc,
"wing": meta.get("wing", "unknown"),
"room": meta.get("room", "unknown"),
"source_file": Path(meta.get("source_file", "?")).name,
"similarity": round(max(0.0, 1 - cdist), 3),
"distance": round(cdist, 4),
"matched_via": "closet",
"closet_preview": cdoc[:200],
}
)
if len(hits) >= n_results:
break
return hits if hits else None
def search(query: str, palace_path: str, wing: str = None, room: str = None, n_results: int = 5): def search(query: str, palace_path: str, wing: str = None, room: str = None, n_results: int = 5):
""" """
@@ -242,64 +320,168 @@ def search_memories(
where = build_where_filter(wing, room) where = build_where_filter(wing, room)
# Closet-first search: scan the compact index, parse drawer pointers # Hybrid retrieval: always query drawers directly (the floor), then use
# from each matching line, then hydrate exactly those drawers. This # closet hits to boost rankings. Closets are a ranking SIGNAL, never a
# keeps the result shape chunk-level (consistent with direct search) # GATE — direct drawer search is always the baseline.
# and applies the same max_distance filter. #
closet_hits = _closet_first_hits( # This avoids the "weak-closets regression" where narrative content
palace_path=palace_path, # produces low-signal closets (regex extraction matches few topics)
query=query, # and closet-first routing hides drawers that direct search would find.
where=where,
drawers_col=drawers_col,
n_results=n_results,
max_distance=max_distance,
)
if closet_hits is not None:
return {
"query": query,
"filters": {"wing": wing, "room": room},
"total_before_filter": len(closet_hits),
"results": closet_hits,
}
# Fallback: direct drawer search (no closets yet, or closets empty)
try: try:
kwargs = { dkwargs = {
"query_texts": [query], "query_texts": [query],
"n_results": n_results, "n_results": n_results * 3, # over-fetch for re-ranking
"include": ["documents", "metadatas", "distances"], "include": ["documents", "metadatas", "distances"],
} }
if where: if where:
kwargs["where"] = where dkwargs["where"] = where
drawer_results = drawers_col.query(**dkwargs)
results = drawers_col.query(**kwargs)
except Exception as e: except Exception as e:
return {"error": f"Search error: {e}"} return {"error": f"Search error: {e}"}
docs = results["documents"][0] # Gather closet hits (best-per-source) to build a boost lookup.
metas = results["metadatas"][0] closet_boost_by_source: dict = {} # source_file -> (rank, closet_dist, preview)
dists = results["distances"][0] try:
closets_col = get_closets_collection(palace_path, create=False)
ckwargs = {
"query_texts": [query],
"n_results": n_results * 2,
"include": ["documents", "metadatas", "distances"],
}
if where:
ckwargs["where"] = where
closet_results = closets_col.query(**ckwargs)
for rank, (cdoc, cmeta, cdist) in enumerate(
zip(
closet_results["documents"][0],
closet_results["metadatas"][0],
closet_results["distances"][0],
)
):
source = cmeta.get("source_file", "")
if source and source not in closet_boost_by_source:
closet_boost_by_source[source] = (rank, cdist, cdoc[:200])
except Exception:
pass # no closets yet — hybrid degrades to pure drawer search
hits = [] # Rank-based boost. The ordinal signal ("which closet matched best") is
for doc, meta, dist in zip(docs, metas, dists): # more reliable than absolute distance on narrative content, where
# Filter on raw distance before rounding to avoid precision loss # closet distances cluster in 1.2-1.5 range regardless of match quality.
CLOSET_RANK_BOOSTS = [0.40, 0.25, 0.15, 0.08, 0.04]
CLOSET_DISTANCE_CAP = 1.5 # cosine dist > 1.5 = too weak to use as signal
scored: list = []
for doc, meta, dist in zip(
drawer_results["documents"][0],
drawer_results["metadatas"][0],
drawer_results["distances"][0],
):
# Filter on raw distance before rounding to avoid precision loss.
if max_distance > 0.0 and dist > max_distance: if max_distance > 0.0 and dist > max_distance:
continue continue
hits.append(
{ source = meta.get("source_file", "") or ""
"text": doc, boost = 0.0
"wing": meta.get("wing", "unknown"), matched_via = "drawer"
"room": meta.get("room", "unknown"), closet_preview = None
"source_file": Path(meta.get("source_file", "?")).name, if source in closet_boost_by_source:
"similarity": round(max(0.0, 1 - dist), 3), c_rank, c_dist, c_preview = closet_boost_by_source[source]
"distance": round(dist, 4), if c_dist <= CLOSET_DISTANCE_CAP and c_rank < len(CLOSET_RANK_BOOSTS):
"matched_via": "drawer", boost = CLOSET_RANK_BOOSTS[c_rank]
} matched_via = "drawer+closet"
) closet_preview = c_preview
effective_dist = dist - boost
entry = {
"text": doc,
"wing": meta.get("wing", "unknown"),
"room": meta.get("room", "unknown"),
"source_file": Path(source).name if source else "?",
"similarity": round(max(0.0, 1 - effective_dist), 3),
"distance": round(dist, 4),
"effective_distance": round(effective_dist, 4),
"closet_boost": round(boost, 3),
"matched_via": matched_via,
# Internal: retain the full source_file path + chunk_index so the
# enrichment step below doesn't have to reverse-lookup via
# basename-suffix matching (which silently collides when two
# files share a basename across different directories).
"_sort_key": effective_dist,
"_source_file_full": source,
"_chunk_index": meta.get("chunk_index"),
}
if closet_preview:
entry["closet_preview"] = closet_preview
scored.append(entry)
scored.sort(key=lambda h: h["_sort_key"])
hits = scored[:n_results]
# Drawer-grep enrichment: for closet-boosted hits whose source has
# multiple drawers, return the keyword-best chunk + its immediate
# neighbors instead of just the drawer vector search landed on. The
# closet said "this source is relevant"; vector may have picked the
# wrong chunk within it; grep picks the right one.
MAX_HYDRATION_CHARS = 10000
for h in hits:
if h["matched_via"] == "drawer":
continue
full_source = h.get("_source_file_full") or ""
if not full_source:
continue
try:
source_drawers = drawers_col.get(
where={"source_file": full_source},
include=["documents", "metadatas"],
)
except Exception:
continue
docs = source_drawers.get("documents") or []
metas_ = source_drawers.get("metadatas") or []
if len(docs) <= 1:
continue
# Sort by chunk_index so best_idx + neighbors are positional.
indexed = []
for idx, (d, m) in enumerate(zip(docs, metas_)):
ci = m.get("chunk_index", idx) if isinstance(m, dict) else idx
if not isinstance(ci, int):
ci = idx
indexed.append((ci, d))
indexed.sort(key=lambda p: p[0])
ordered_docs = [d for _, d in indexed]
query_terms = set(_tokenize(query))
best_idx, best_score = 0, -1
for idx, d in enumerate(ordered_docs):
d_lower = d.lower()
s = sum(1 for t in query_terms if t in d_lower)
if s > best_score:
best_score, best_idx = s, idx
start = max(0, best_idx - 1)
end = min(len(ordered_docs), best_idx + 2)
expanded = "\n\n".join(ordered_docs[start:end])
if len(expanded) > MAX_HYDRATION_CHARS:
expanded = (
expanded[:MAX_HYDRATION_CHARS]
+ f"\n\n[...truncated. {len(ordered_docs)} total drawers. "
"Use mempalace_get_drawer for full content.]"
)
h["text"] = expanded
h["drawer_index"] = best_idx
h["total_drawers"] = len(ordered_docs)
# BM25 hybrid re-rank within the final candidate set.
hits = _hybrid_rank(hits, query)
for h in hits:
h.pop("_sort_key", None)
h.pop("_source_file_full", None)
h.pop("_chunk_index", None)
return { return {
"query": query, "query": query,
"filters": {"wing": wing, "room": room}, "filters": {"wing": wing, "room": room},
"total_before_filter": len(docs), "total_before_filter": len(drawer_results["documents"][0]),
"results": hits, "results": hits,
} }
+339
View File
@@ -0,0 +1,339 @@
"""Unit tests for the optional LLM-based closet regeneration.
These tests don't hit the network. They mock urllib to verify:
- LLMConfig correctly reads env vars and CLI overrides
- missing config is reported cleanly
- the OpenAI-compatible request shape is correct
- response parsing handles the standard chat-completions payload
"""
import json
import tempfile
from unittest.mock import patch
from mempalace.closet_llm import (
LLMConfig,
_call_llm,
_parsed_to_closet_lines,
regenerate_closets,
)
# ── LLMConfig ─────────────────────────────────────────────────────────────
class TestLLMConfig:
    """LLMConfig resolution from environment variables and explicit arguments."""

    def test_reads_env_vars(self, monkeypatch):
        # With no arguments, every field comes from the environment.
        env = {
            "LLM_ENDPOINT": "http://localhost:11434/v1",
            "LLM_KEY": "sk-abc",
            "LLM_MODEL": "llama3:8b",
        }
        for name, value in env.items():
            monkeypatch.setenv(name, value)
        cfg = LLMConfig()
        assert cfg.endpoint == "http://localhost:11434/v1"
        assert cfg.key == "sk-abc"
        assert cfg.model == "llama3:8b"

    def test_cli_flags_override_env(self, monkeypatch):
        # Explicit arguments win over whatever the environment holds.
        monkeypatch.setenv("LLM_ENDPOINT", "http://env-endpoint/v1")
        monkeypatch.setenv("LLM_MODEL", "env-model")
        cfg = LLMConfig(endpoint="http://flag-endpoint/v1", model="flag-model")
        assert cfg.endpoint == "http://flag-endpoint/v1"
        assert cfg.model == "flag-model"

    def test_trailing_slash_stripped(self):
        cfg = LLMConfig(endpoint="http://foo/v1/", model="m")
        assert cfg.endpoint == "http://foo/v1"

    def test_missing_reports_required(self, monkeypatch):
        for name in ("LLM_ENDPOINT", "LLM_KEY", "LLM_MODEL"):
            monkeypatch.delenv(name, raising=False)
        gaps = LLMConfig().missing()
        assert any("ENDPOINT" in gap for gap in gaps)
        assert any("MODEL" in gap for gap in gaps)
        # key is optional
        assert all("KEY" not in gap for gap in gaps)

    def test_key_is_optional(self, monkeypatch):
        monkeypatch.delenv("LLM_KEY", raising=False)
        cfg = LLMConfig(endpoint="http://local/v1", model="m")
        assert cfg.missing() == []
# ── _parsed_to_closet_lines ──────────────────────────────────────────────
class TestParsedToLines:
    """Conversion of a parsed LLM payload into closet pointer lines."""

    def test_topics_become_pointers(self):
        payload = {"topics": ["authentication", "jwt tokens"], "quotes": [], "summary": ""}
        lines = _parsed_to_closet_lines(payload, ["d1", "d2"], "Alice;Bob")
        assert len(lines) == 2
        for expected in (
            "authentication|Alice;Bob|→d1,d2",
            "jwt tokens|Alice;Bob|→d1,d2",
        ):
            assert expected in lines

    def test_quotes_and_summary_included(self):
        payload = {
            "topics": ["t1"],
            "quotes": ["[Igor] we ship Friday"],
            "summary": "Release planning discussion",
        }
        rendered = "\n".join(_parsed_to_closet_lines(payload, ["d1"], ""))
        assert "we ship Friday" in rendered
        assert "Release planning discussion" in rendered

    def test_caps_topics_at_15(self):
        # 20 topics in, 15 lines out.
        payload = {"topics": [f"t{i}" for i in range(20)], "quotes": [], "summary": ""}
        lines = _parsed_to_closet_lines(payload, ["d1"], "")
        assert len(lines) == 15
# ── _call_llm (HTTP mocked) ──────────────────────────────────────────────
class _FakeResp:
    """Minimal stand-in for the context-manager response urlopen returns.

    The payload is JSON-encoded to UTF-8 bytes once at construction time and
    handed back unchanged by ``read()``.
    """

    def __init__(self, payload: dict, status: int = 200):
        self.status = status
        self._body = json.dumps(payload).encode("utf-8")

    def read(self):
        return self._body

    def __enter__(self):
        return self

    def __exit__(self, *exc_info):
        # Never suppress an exception raised inside the ``with`` body.
        return False
class TestCallLLM:
    """``_call_llm`` exercised against a mocked ``urllib.request.urlopen``."""

    @staticmethod
    def _reply(content, prompt_tokens=1, completion_tokens=1):
        """Wrap ``content`` in a minimal Chat Completions response envelope."""
        return _FakeResp(
            {
                "choices": [{"message": {"content": content}}],
                "usage": {
                    "prompt_tokens": prompt_tokens,
                    "completion_tokens": completion_tokens,
                },
            }
        )

    def _make_cfg(self):
        return LLMConfig(endpoint="http://localhost:11434/v1", key="sk-test", model="llama3:8b")

    def test_request_shape_and_parsing(self):
        """The outgoing request has the expected URL, auth header and body,
        and the JSON reply comes back parsed together with the usage block."""
        seen = {}
        answer = json.dumps(
            {
                "topics": ["postgres"],
                "quotes": ["[Igor] migrate now"],
                "summary": "db migration",
            }
        )

        def fake_urlopen(req, timeout=None):
            seen["url"] = req.full_url
            seen["headers"] = dict(req.header_items())
            seen["body"] = json.loads(req.data.decode("utf-8"))
            return self._reply(answer, prompt_tokens=42, completion_tokens=17)

        with patch("urllib.request.urlopen", side_effect=fake_urlopen):
            parsed, usage = _call_llm(
                self._make_cfg(), "/tmp/test.md", "w", "r", "content body"
            )

        assert parsed["topics"] == ["postgres"]
        assert usage["prompt_tokens"] == 42
        assert seen["url"] == "http://localhost:11434/v1/chat/completions"
        # Compare header names case-insensitively so the check does not
        # depend on how urllib chose to capitalize the stored key.
        auth_values = {
            value for name, value in seen["headers"].items() if name.lower() == "authorization"
        }
        assert "Bearer sk-test" in auth_values
        request_body = seen["body"]
        assert request_body["model"] == "llama3:8b"
        assert request_body["messages"][0]["role"] == "user"

    def test_omits_auth_header_when_no_key(self):
        """A config without a key must not send any Authorization header."""
        keyless_cfg = LLMConfig(endpoint="http://localhost:11434/v1", model="llama3:8b")
        lowered_headers = {}

        def fake_urlopen(req, timeout=None):
            for name, value in req.header_items():
                lowered_headers[name.lower()] = value
            return self._reply(
                '{"topics":[],"quotes":[],"summary":""}',
                prompt_tokens=0,
                completion_tokens=0,
            )

        with patch("urllib.request.urlopen", side_effect=fake_urlopen):
            _call_llm(keyless_cfg, "/tmp/x", "w", "r", "c")

        assert "authorization" not in lowered_headers

    def test_strips_code_fences(self):
        """A reply wrapped in a Markdown ``json`` fence still parses."""
        fenced = '```json\n{"topics":["t1"],"quotes":[],"summary":""}\n```'

        def fake_urlopen(req, timeout=None):
            return self._reply(fenced)

        with patch("urllib.request.urlopen", side_effect=fake_urlopen):
            parsed, _ = _call_llm(self._make_cfg(), "/tmp/x", "w", "r", "c")

        assert parsed == {"topics": ["t1"], "quotes": [], "summary": ""}

    def test_returns_none_on_invalid_json(self):
        """Non-JSON model output yields ``None`` for the parsed part."""

        def fake_urlopen(req, timeout=None):
            return self._reply("not json at all")

        with patch("urllib.request.urlopen", side_effect=fake_urlopen):
            parsed, _ = _call_llm(self._make_cfg(), "/tmp/x", "w", "r", "c")

        assert parsed is None
# ── regenerate_closets error paths ───────────────────────────────────────
class TestRegenerateClosets:
    """Error paths and end-to-end regressions for ``regenerate_closets``."""

    @staticmethod
    def _urlopen_answering(content, prompt_tokens, completion_tokens):
        """Return a ``urlopen`` replacement that always replies with ``content``."""

        def fake_urlopen(req, timeout=None):
            return _FakeResp(
                {
                    "choices": [{"message": {"content": content}}],
                    "usage": {
                        "prompt_tokens": prompt_tokens,
                        "completion_tokens": completion_tokens,
                    },
                }
            )

        return fake_urlopen

    def test_missing_config_returns_error(self, monkeypatch):
        """Without endpoint/model the call reports ``missing-config``."""
        for var in ("LLM_ENDPOINT", "LLM_MODEL"):
            monkeypatch.delenv(var, raising=False)
        with tempfile.TemporaryDirectory() as palace:
            outcome = regenerate_closets(palace)
        assert outcome["error"] == "missing-config"
        assert any("ENDPOINT" in item for item in outcome["missing"])

    def test_regen_purges_regex_closets_and_stamps_normalize_version(self, tmp_path):
        """Regression: before the hardening, regex closets for the same
        source survived alongside fresh LLM closets (the old path used a
        bare ``closets_col.delete(ids=...)`` with a swallowed exception).
        The regen now goes through ``purge_file_closets`` + ``mine_lock`` and
        stamps ``NORMALIZE_VERSION`` so the next mine's stale-version gate
        doesn't treat the LLM closets as leftovers to rebuild over."""
        from mempalace.palace import (
            NORMALIZE_VERSION,
            get_closets_collection,
            get_collection,
            upsert_closet_lines,
        )

        palace = str(tmp_path / "palace")
        source = "/proj/story.md"
        shared_meta = {"wing": "project", "room": "auth", "source_file": source}

        # One drawer plus a pre-existing regex closet for the same source.
        get_collection(palace, create=True).upsert(
            ids=["drawer_01"],
            documents=["Content about JWT authentication."],
            metadatas=[{**shared_meta, "entities": ""}],
        )
        closets = get_closets_collection(palace)
        upsert_closet_lines(
            closets,
            closet_id_base="closet_old_regex",
            lines=["STALE_REGEX_TOPIC|;|→drawer_01"],
            metadata={**shared_meta, "generated_by": "regex"},
        )

        cfg = LLMConfig(endpoint="http://local/v1", model="llama3:8b")
        llm_answer = json.dumps(
            {
                "topics": ["jwt auth", "session expiry"],
                "quotes": [],
                "summary": "auth refactor",
            }
        )
        fake_urlopen = self._urlopen_answering(llm_answer, 10, 5)

        with patch("urllib.request.urlopen", side_effect=fake_urlopen):
            result = regenerate_closets(palace, cfg=cfg)

        assert result["processed"] == 1 and result["failed"] == 0

        # Whatever is left for this source must be LLM-generated and must
        # carry the current NORMALIZE_VERSION.
        survivors = closets.get(where={"source_file": source}, include=["documents", "metadatas"])
        assert survivors["ids"], "LLM closets should have been written"
        all_text = "\n".join(survivors["documents"])
        assert (
            "STALE_REGEX_TOPIC" not in all_text
        ), "pre-existing regex closet was not purged before LLM write"
        assert "jwt auth" in all_text
        for closet_meta in survivors["metadatas"]:
            assert closet_meta.get("generated_by", "").startswith("llm:")
            assert closet_meta.get("normalize_version") == NORMALIZE_VERSION

    def test_regen_uses_basename_not_split_slash(self, tmp_path, monkeypatch):
        """Regression: the old closet_id base used ``source.split('/')[-1]``,
        which silently degrades on Windows-style paths (the whole string
        comes back). ``os.path.basename`` is the replacement."""
        from mempalace.palace import get_collection, get_closets_collection

        palace = str(tmp_path / "palace")
        # A POSIX path cannot tell split('/') apart from os.path.basename,
        # so this only verifies the minimum: IDs carry the bare filename,
        # never the full path.
        source = "/deep/nested/project/dir/mydoc.md"
        get_collection(palace, create=True).upsert(
            ids=["d1"],
            documents=["body"],
            metadatas=[{"wing": "w", "room": "r", "source_file": source, "entities": ""}],
        )

        cfg = LLMConfig(endpoint="http://local/v1", model="m")
        fake_urlopen = self._urlopen_answering('{"topics":["t1"],"quotes":[],"summary":""}', 1, 1)

        with patch("urllib.request.urlopen", side_effect=fake_urlopen):
            regenerate_closets(palace, cfg=cfg)

        closet_ids = (
            get_closets_collection(palace).get(where={"source_file": source}).get("ids", [])
        )
        assert closet_ids
        # A "/" in an ID would mean the full path leaked into it.
        for closet_id in closet_ids:
            assert "/" not in closet_id
            assert "mydoc.md" in closet_id
+740 -82
View File
@@ -1,35 +1,154 @@
""" """
test_closets.py — Tests for the closet (searchable index) layer. test_closets.py — Tests for the closet (searchable index) layer and the
features that ride on top of it: mine_lock serialization, entity metadata,
hybrid BM25+vector search, and diary ingest.
Covers: Coverage map:
* build_closet_lines — pointer-line shape, entity extraction, stoplist, * mine_lock — acquire/release, blocks concurrent acquisition.
quote/header pickup, and the "always emit one line" guarantee. * build_closet_lines — pointer-line shape, header pickup, entity stoplist
* upsert_closet_lines — pure overwrite (no append), char-limit packing, (regression for "When/After/The"), real-name survival, fallback line.
atomic-line guarantee. * upsert_closet_lines — pure overwrite (regression for the append bug),
* purge_file_closets — wipes prior closets so a re-mine starts clean. char-limit packing without splitting a line.
* The end-to-end rebuild: re-mining a file fully replaces its closets, * purge_file_closets — scoped to source_file.
including when the prior run produced more numbered closets. * Project-miner end-to-end rebuild — re-mining with fewer topics fully
* search_memories closet-first path — returns chunk-level hits parsed purges leftover numbered closets from a larger prior run.
from `→drawer_ids` pointers, falls back when closets are empty, * _extract_drawer_ids_from_closet — pointer parsing + dedup.
respects max_distance. * search_memories hybrid path — drawer query always the floor,
closets boost matching source_file, matched_via reflects both signals,
no whole-file glue, max_distance enforcement.
* Entity metadata — extracted, stoplist applied, registry cached by mtime.
* Real BM25 — real IDF over candidate corpus, hybrid rerank.
* Diary ingest — drawers + closets created, incremental skips, state
file lives outside the diary dir, wing-prefixed drawer IDs prevent
cross-diary collisions, force=True purges leftover closets.
""" """
from mempalace.miner import mine import json
import multiprocessing
import os
import tempfile
import threading
import time
import yaml
from mempalace.miner import (
_extract_entities_for_metadata,
_load_known_entities,
mine,
)
from mempalace.palace import ( from mempalace.palace import (
CLOSET_CHAR_LIMIT, CLOSET_CHAR_LIMIT,
build_closet_lines, build_closet_lines,
get_closets_collection, get_closets_collection,
get_collection,
mine_lock,
purge_file_closets, purge_file_closets,
upsert_closet_lines, upsert_closet_lines,
) )
from mempalace.searcher import _extract_drawer_ids_from_closet, search_memories from mempalace.palace_graph import (
create_tunnel,
delete_tunnel,
follow_tunnels,
list_tunnels,
)
from mempalace.searcher import (
_bm25_scores,
_expand_with_neighbors,
_extract_drawer_ids_from_closet,
_hybrid_rank,
search_memories,
)
# ── mine_lock ────────────────────────────────────────────────────────────
def _lock_worker(target: str, name: str, hold_seconds: float, log_path: str) -> None:
    """Child-process body for the spawn-based concurrency test.

    Takes the mine lock on ``target``, stays inside for ``hold_seconds``,
    and appends one ``"<name> <enter_ts> <exit_ts>"`` record to ``log_path``
    so the parent test can check that the critical sections never overlapped.
    """
    import time as _time

    from mempalace.palace import mine_lock as _mine_lock

    with _mine_lock(target):
        entered_at = _time.time()
        _time.sleep(hold_seconds)
        left_at = _time.time()
        record = f"{name} {entered_at} {left_at}\n"
        # Append mode, so one worker's record never overwrites the other's.
        with open(log_path, "a") as log:
            log.write(record)
            log.flush()
class TestMineLock:
    """Acquire/release and cross-process exclusion behavior of ``mine_lock``."""

    def test_lock_acquires_and_releases(self, tmp_path):
        """Holding the lock creates the lock directory; re-acquiring after
        release completes in under a second."""
        target = str(tmp_path / "lock_target.txt")
        with mine_lock(target):
            lock_dir = os.path.expanduser("~/.mempalace/locks")
            assert os.path.isdir(lock_dir)
        # Re-acquire after release should succeed instantly.
        start = time.time()
        with mine_lock(target):
            pass
        assert time.time() - start < 1.0

    def test_lock_blocks_concurrent_access(self, tmp_path):
        """The lock's contract is inter-*process* (multi-agent), not
        inter-thread. Use multiprocessing so the test reflects the real
        use case and is portable: on macOS/BSD ``fcntl.flock`` is
        per-process, so two threads would both acquire — a thread-based
        test would flake there even when the lock is correct.

        Verify mutual exclusion by the effect the critical section
        actually has — each worker records its enter/exit timestamps
        under the lock, and the test asserts the two intervals do not
        overlap. This is robust to spawn-overhead timing, unlike
        "second worker waited at least N seconds" which flakes when CI
        spawn latency eats into the hold window.
        """
        target = str(tmp_path / "concurrent_lock.txt")
        log_path = str(tmp_path / "critical_section.log")

        # Spawn so the same code path runs on every OS (macOS 3.8+ and
        # Windows already default to spawn; Linux is fork by default).
        ctx = multiprocessing.get_context("spawn")

        # Each worker holds the lock for HOLD seconds. With real mutual
        # exclusion, the two [enter, exit] intervals must be disjoint.
        HOLD = 0.3
        p1 = ctx.Process(target=_lock_worker, args=(target, "a", HOLD, log_path))
        p2 = ctx.Process(target=_lock_worker, args=(target, "b", HOLD, log_path))
        try:
            p1.start()
            p2.start()
            p1.join(timeout=30)
            p2.join(timeout=30)
            assert p1.exitcode == 0, f"p1 exited non-zero: {p1.exitcode}"
            assert p2.exitcode == 0, f"p2 exited non-zero: {p2.exitcode}"
        finally:
            # join(timeout=...) returns even when the child is still running
            # (exitcode stays None). Don't leave a hung worker behind to hold
            # the lock or outlive the test run.
            for proc in (p1, p2):
                if proc.is_alive():
                    proc.terminate()
                    proc.join()

        # Parse the log: "<name> <enter_ts> <exit_ts>".
        intervals = []
        with open(log_path) as f:
            for line in f:
                parts = line.strip().split()
                if len(parts) == 3:
                    intervals.append((parts[0], float(parts[1]), float(parts[2])))
        assert len(intervals) == 2, f"expected two critical sections, got {intervals}"

        # Sort by entry time and verify the second entry is after the first exit.
        intervals.sort(key=lambda iv: iv[1])
        (_, enter_a, exit_a), (_, enter_b, exit_b) = intervals
        assert (
            enter_a < exit_a <= enter_b < exit_b
        ), f"critical sections overlapped — lock failed to serialize: {intervals}"
# ── build_closet_lines ───────────────────────────────────────────────── # ── build_closet_lines ─────────────────────────────────────────────────
class TestBuildClosetLines: class TestBuildClosetLines:
def test_emits_pointer_line_shape(self, tmp_path): def test_emits_pointer_line_shape(self):
content = ( content = (
"# Auth rewrite\n\n" "# Auth rewrite\n\n"
"Decided we need to migrate to passkeys. " "Decided we need to migrate to passkeys. "
@@ -59,7 +178,7 @@ class TestBuildClosetLines:
def test_entity_stoplist_filters_sentence_starters(self): def test_entity_stoplist_filters_sentence_starters(self):
# "When", "After", "The" repeat 3+ times — old code would index them # "When", "After", "The" repeat 3+ times — old code would index them
# as entities. New code's stoplist drops them. # as entities. Stoplist drops them.
content = ( content = (
"When the pipeline ran, the result was good. " "When the pipeline ran, the result was good. "
"When the user logged in, the token was issued. " "When the user logged in, the token was issued. "
@@ -68,7 +187,6 @@ class TestBuildClosetLines:
"The new flow is stable. The audit cleared." "The new flow is stable. The audit cleared."
) )
lines = build_closet_lines("/x.md", ["d1"], content, "w", "r") lines = build_closet_lines("/x.md", ["d1"], content, "w", "r")
# Entities sit between the two pipes
entity_segments = [line.split("|")[1] for line in lines] entity_segments = [line.split("|")[1] for line in lines]
for seg in entity_segments: for seg in entity_segments:
tokens = set(seg.split(";")) if seg else set() tokens = set(seg.split(";")) if seg else set()
@@ -83,13 +201,11 @@ class TestBuildClosetLines:
"Igor and Milla shipped together." "Igor and Milla shipped together."
) )
lines = build_closet_lines("/x.md", ["d1"], content, "w", "r") lines = build_closet_lines("/x.md", ["d1"], content, "w", "r")
entity_segments = [line.split("|")[1] for line in lines] joined_entities = ";".join(line.split("|")[1] for line in lines)
joined_entities = ";".join(entity_segments)
assert "Igor" in joined_entities assert "Igor" in joined_entities
assert "Milla" in joined_entities assert "Milla" in joined_entities
def test_emits_fallback_line_when_nothing_extractable(self): def test_emits_fallback_line_when_nothing_extractable(self):
# No headers, no action verbs, no quotes, no repeated capitalized words
content = "lorem ipsum dolor sit amet consectetur adipiscing elit" content = "lorem ipsum dolor sit amet consectetur adipiscing elit"
lines = build_closet_lines("/x/notes.txt", ["d1"], content, "wing", "room") lines = build_closet_lines("/x/notes.txt", ["d1"], content, "wing", "room")
assert len(lines) == 1 assert len(lines) == 1
@@ -111,11 +227,9 @@ class TestUpsertClosetLines:
base = "closet_test_room_abc" base = "closet_test_room_abc"
meta = {"wing": "test", "room": "room", "source_file": "/x.md"} meta = {"wing": "test", "room": "room", "source_file": "/x.md"}
# First mine — three short lines.
upsert_closet_lines(col, base, ["alpha|;|→d1", "beta|;|→d2", "gamma|;|→d3"], meta) upsert_closet_lines(col, base, ["alpha|;|→d1", "beta|;|→d2", "gamma|;|→d3"], meta)
first = col.get(ids=[f"{base}_01"]) first = col.get(ids=[f"{base}_01"])
assert "alpha" in first["documents"][0] assert "alpha" in first["documents"][0]
assert "beta" in first["documents"][0]
# Second mine — entirely different lines. Must replace, not append. # Second mine — entirely different lines. Must replace, not append.
upsert_closet_lines(col, base, ["delta|;|→d4", "epsilon|;|→d5"], meta) upsert_closet_lines(col, base, ["delta|;|→d4", "epsilon|;|→d5"], meta)
@@ -131,18 +245,15 @@ class TestUpsertClosetLines:
base = "closet_pack_room_def" base = "closet_pack_room_def"
meta = {"wing": "test", "room": "room", "source_file": "/y.md"} meta = {"wing": "test", "room": "room", "source_file": "/y.md"}
# Build lines that approach but never exceed the limit.
line = "x" * 600 # well under CLOSET_CHAR_LIMIT line = "x" * 600 # well under CLOSET_CHAR_LIMIT
n_written = upsert_closet_lines(col, base, [line, line, line, line], meta) n_written = upsert_closet_lines(col, base, [line, line, line, line], meta)
# 4 lines @ 600+1 chars = 2404 — should pack into 2 closets (≤1500 each) # 4 lines @ 601 chars each = 2404 — should pack into 2 closets
assert n_written == 2 assert n_written == 2
for i in range(1, n_written + 1): for i in range(1, n_written + 1):
doc = col.get(ids=[f"{base}_{i:02d}"])["documents"][0] doc = col.get(ids=[f"{base}_{i:02d}"])["documents"][0]
# Every line is intact (never split mid-line)
for chunk in doc.split("\n"): for chunk in doc.split("\n"):
assert len(chunk) == 600, f"line was truncated in closet {i}" assert len(chunk) == 600, f"line was truncated in closet {i}"
# Closet stays under the cap
assert len(doc) <= CLOSET_CHAR_LIMIT assert len(doc) <= CLOSET_CHAR_LIMIT
@@ -161,19 +272,16 @@ class TestPurgeFileClosets:
], ],
) )
purge_file_closets(col, "/drop.md") purge_file_closets(col, "/drop.md")
remaining_ids = set(col.get()["ids"]) remaining_ids = set(col.get()["ids"])
assert "closet_a_01" in remaining_ids assert "closet_a_01" in remaining_ids
assert "closet_b_01" not in remaining_ids assert "closet_b_01" not in remaining_ids
# ── End-to-end rebuild via the project miner ────────────────────────── # ── project miner: closet rebuild end-to-end ──────────────────────────
class TestMinerClosetRebuild: class TestMinerClosetRebuild:
def test_remine_replaces_closets_completely(self, tmp_path): def test_remine_replaces_closets_completely(self, tmp_path):
import yaml
project = tmp_path / "proj" project = tmp_path / "proj"
project.mkdir() project.mkdir()
(project / "mempalace.yaml").write_text( (project / "mempalace.yaml").write_text(
@@ -193,16 +301,11 @@ class TestMinerClosetRebuild:
first_ids = set(first_pass["ids"]) first_ids = set(first_pass["ids"])
assert any("topic 0" in (d or "").lower() for d in first_pass["documents"]) assert any("topic 0" in (d or "").lower() for d in first_pass["documents"])
# Touch mtime so file_already_mined doesn't short-circuit, and # Touch mtime + shrink content so the rebuild produces fewer closets.
# rewrite with fewer topics (so the rebuild produces fewer closets
# than the first run).
import os
import time
target.write_text("# Only Topic Now\n" + ("short body " * 5)) target.write_text("# Only Topic Now\n" + ("short body " * 5))
new_mtime = os.path.getmtime(target) + 60 new_mtime = os.path.getmtime(target) + 60
os.utime(target, (new_mtime, new_mtime)) os.utime(target, (new_mtime, new_mtime))
time.sleep(0.01) # ensure mtime delta is visible time.sleep(0.01)
mine(str(project), str(palace), wing_override="proj", agent="test") mine(str(project), str(palace), wing_override="proj", agent="test")
@@ -231,19 +334,11 @@ class TestExtractDrawerIds:
def test_parses_multiple_pointers_per_line(self): def test_parses_multiple_pointers_per_line(self):
line = "topic|ent|→drawer_a,drawer_b,drawer_c" line = "topic|ent|→drawer_a,drawer_b,drawer_c"
assert _extract_drawer_ids_from_closet(line) == [ assert _extract_drawer_ids_from_closet(line) == ["drawer_a", "drawer_b", "drawer_c"]
"drawer_a",
"drawer_b",
"drawer_c",
]
def test_dedupes_across_lines(self): def test_dedupes_across_lines(self):
doc = "one|;|→drawer_a,drawer_b\ntwo|;|→drawer_b,drawer_c" doc = "one|;|→drawer_a,drawer_b\ntwo|;|→drawer_b,drawer_c"
assert _extract_drawer_ids_from_closet(doc) == [ assert _extract_drawer_ids_from_closet(doc) == ["drawer_a", "drawer_b", "drawer_c"]
"drawer_a",
"drawer_b",
"drawer_c",
]
def test_empty_doc_returns_empty(self): def test_empty_doc_returns_empty(self):
assert _extract_drawer_ids_from_closet("") == [] assert _extract_drawer_ids_from_closet("") == []
@@ -253,64 +348,627 @@ class TestExtractDrawerIds:
# ── search_memories closet-first path ──────────────────────────────── # ── search_memories closet-first path ────────────────────────────────
class TestSearchMemoriesClosetFirst: class TestSearchMemoriesHybrid:
def test_falls_back_to_direct_when_no_closets(self, palace_path, seeded_collection): def test_pure_drawer_when_no_closets(self, palace_path, seeded_collection):
# seeded_collection populates only mempalace_drawers, not closets. """Palaces without closets return results via direct drawer search —
every hit must advertise that the closet signal was absent."""
result = search_memories("JWT authentication", palace_path) result = search_memories("JWT authentication", palace_path)
assert result["results"], "should still find drawer hits via fallback" assert result["results"], "should still find drawer hits"
for hit in result["results"]: for hit in result["results"]:
assert hit.get("matched_via") == "drawer" assert hit.get("matched_via") == "drawer"
assert hit.get("closet_boost") == 0.0
assert "closet_preview" not in hit
def test_closet_first_returns_chunk_level_hits(self, palace_path, seeded_collection): def test_closet_boost_marks_hit_as_drawer_plus_closet(self, palace_path, seeded_collection):
# Build a closet that points at the JWT drawer specifically. """When a closet agrees with direct search on source_file, the
matching drawer's ``matched_via`` switches to ``drawer+closet`` and
``closet_preview`` exposes the hydrated index line."""
closets = get_closets_collection(palace_path) closets = get_closets_collection(palace_path)
# Seed the closet against the same source_file the drawer uses so
# the boost lookup keys align.
closets.upsert( closets.upsert(
ids=["closet_proj_backend_aaa_01"], ids=["closet_proj_backend_aaa_01"],
documents=["JWT auth tokens|;|→drawer_proj_backend_aaa"], documents=["JWT auth tokens|;|→drawer_proj_backend_aaa"],
metadatas=[ metadatas=[{"wing": "project", "room": "backend", "source_file": "auth.py"}],
{
"wing": "project",
"room": "backend",
"source_file": "auth.py",
}
],
) )
result = search_memories("JWT authentication", palace_path) result = search_memories("JWT authentication", palace_path)
assert result["results"], "closet-first search should hydrate the drawer" assert result["results"], "hybrid search should still return results"
top = result["results"][0] # The JWT-bearing drawer should surface with closet agreement.
assert top["matched_via"] == "closet" boosted = [h for h in result["results"] if h["matched_via"] == "drawer+closet"]
# Must be the chunk-level drawer text, not a concatenation of every assert boosted, "closet agreement should promote the matching source"
# drawer in the file. top = boosted[0]
assert "JWT" in top["text"] assert "JWT" in top["text"]
assert ( assert top["closet_boost"] > 0
"Database migrations" not in top["text"]
), "closet path should not glue unrelated drawers together"
assert "closet_preview" in top
assert "→drawer_proj_backend_aaa" in top["closet_preview"] assert "→drawer_proj_backend_aaa" in top["closet_preview"]
def test_max_distance_filters_closet_hits(self, palace_path, seeded_collection): def test_max_distance_filters_hybrid_hits(self, palace_path, seeded_collection):
closets = get_closets_collection(palace_path) closets = get_closets_collection(palace_path)
closets.upsert( closets.upsert(
ids=["closet_proj_backend_aaa_01"], ids=["closet_proj_backend_aaa_01"],
documents=["JWT auth tokens|;|→drawer_proj_backend_aaa"], documents=["JWT auth tokens|;|→drawer_proj_backend_aaa"],
metadatas=[ metadatas=[{"wing": "project", "room": "backend", "source_file": "auth.py"}],
{
"wing": "project",
"room": "backend",
"source_file": "auth.py",
}
],
) )
# max_distance=0.001 is essentially "must match exactly". The closet
# path should reject everything and the caller falls back to direct
# search (which also filters with the same threshold).
result = search_memories( result = search_memories(
"completely unrelated query about quantum gardening", "completely unrelated query about quantum gardening",
palace_path, palace_path,
max_distance=0.001, max_distance=0.001,
) )
# Either no results, or every result respected the threshold.
for hit in result["results"]: for hit in result["results"]:
assert hit["distance"] <= 0.001 assert hit["distance"] <= 0.001
# ── entity metadata ──────────────────────────────────────────────────
class TestEntityMetadata:
    """Entity extraction for drawer metadata and the known-entity registry cache."""

    def test_extracts_capitalized_names(self):
        """Names that repeat in the text are picked up as entities."""
        text = "Ben reviewed the code. Ben approved it. Igor flagged two issues. Igor fixed them."
        entities = _extract_entities_for_metadata(text)
        assert "Ben" in entities
        assert "Igor" in entities

    def test_empty_for_no_entities(self):
        """All-lowercase text yields the empty string."""
        text = "this is all lowercase with no proper nouns at all"
        assert _extract_entities_for_metadata(text) == ""

    def test_semicolon_separated(self):
        """Multiple entities are joined with ``;``."""
        text = "Alice and Bob met Charlie. Alice said hello. Bob agreed. Charlie laughed."
        entities = _extract_entities_for_metadata(text)
        assert ";" in entities

    def test_stoplist_filters_sentence_starters(self):
        # Same regression as the closet entity test — "When/After/The" must
        # not become entities just because they're capitalized 2+ times.
        text = (
            "When the build broke, the team paged. "
            "When the fix landed, the alarm cleared. "
            "After the rollback, the queue drained. "
            "After the deploy, the latency normalized."
        )
        entities = _extract_entities_for_metadata(text)
        tokens = set(entities.split(";")) if entities else set()
        assert "When" not in tokens
        assert "After" not in tokens
        assert "The" not in tokens

    def test_capped_list_never_truncates_a_name(self):
        # 30 distinct repeated proper nouns — extraction should cap the list
        # before joining so a name never gets cut in half.
        # Use morphologically distinct stems so the [A-Z][a-z]+ regex sees
        # each as its own token.
        names = [
            "Anna", "Brian", "Carol", "David", "Elena", "Frank",
            "Grace", "Harold", "Iris", "Julian", "Kira", "Liam",
            "Maya", "Noah", "Oscar", "Penny", "Quinn", "Rosa",
            "Sergei", "Tara", "Umar", "Vera", "Walter", "Xander",
            "Yvonne", "Zachary", "Amelia", "Boris", "Clara", "Dmitri",
        ]
        text = " ".join(f"{n} met {n}." for n in names)
        entities = _extract_entities_for_metadata(text)
        extracted = [n for n in entities.split(";") if n]
        assert extracted, "should have extracted some entities"
        for name in extracted:
            assert name in names, f"truncation produced a partial token: {name!r}"

    def test_known_registry_is_cached_by_mtime(self, monkeypatch, tmp_path):
        """The registry is re-read only when the file's mtime changes."""
        # Point the registry at a temp file we control, exercise the cache.
        registry = tmp_path / "known_entities.json"
        registry.write_text(json.dumps({"people": ["Zelda"]}))

        from mempalace import miner

        monkeypatch.setattr(miner, "_ENTITY_REGISTRY_PATH", str(registry))
        # Reset the module-level cache through monkeypatch rather than by
        # plain assignment, so both entries are put back on teardown and
        # the temp registry's names do not leak into later tests.
        monkeypatch.setitem(miner._ENTITY_REGISTRY_CACHE, "mtime", None)
        monkeypatch.setitem(miner._ENTITY_REGISTRY_CACHE, "names", frozenset())

        first = _load_known_entities()
        assert "Zelda" in first

        # Second call without changing mtime: must reuse cache, not re-read.
        read_count = {"n": 0}
        original_open = open

        def counting_open(path, *a, **kw):
            if str(path) == str(registry):
                read_count["n"] += 1
            return original_open(path, *a, **kw)

        monkeypatch.setattr("builtins.open", counting_open)
        _load_known_entities()
        assert read_count["n"] == 0, "registry should not be re-read when mtime unchanged"

        # Rewrite the file, then force its mtime forward → cache must
        # invalidate. utime goes after the write because writing resets
        # the mtime to "now".
        new_mtime = os.path.getmtime(registry) + 5
        registry.write_text(json.dumps({"people": ["Zelda", "Link"]}))
        os.utime(registry, (new_mtime, new_mtime))
        names = _load_known_entities()
        assert "Link" in names
# ── BM25 hybrid search (real IDF over candidate corpus) ──────────────
class TestBM25:
    """``_bm25_scores`` and the ``_hybrid_rank`` reranker built on it."""

    def test_scores_positive_for_matching_doc(self):
        """A doc sharing query terms scores above zero; an unrelated one scores zero."""
        corpus = ["We migrated the database to Postgres.", "unrelated cookery tips"]
        matching, unrelated = _bm25_scores("database migration", corpus)[:2]
        assert matching > 0
        assert unrelated == 0.0

    def test_scores_zero_when_no_overlap(self):
        """No shared terms means a zero score."""
        result = _bm25_scores("quantum physics", ["We built a web app in React"])
        assert result == [0.0]

    def test_idf_downweights_terms_present_in_every_doc(self):
        # "database" appears in every candidate → low IDF → low contribution.
        # "vacuum" is unique to one → high IDF → that doc dominates.
        corpus = [
            "database backup nightly schedule",
            "database vacuum scheduled weekly",
            "database failover plan",
        ]
        result = _bm25_scores("database vacuum", corpus)
        assert result[1] == max(result), "doc with the rare query term should win on IDF"

    def test_empty_inputs_return_zeros(self):
        """Empty query, empty corpus and empty document are all handled."""
        assert _bm25_scores("", ["hello world"]) == [0.0]
        assert _bm25_scores("query here", []) == []
        assert _bm25_scores("query", [""]) == [0.0]

    def test_hybrid_rank_promotes_keyword_match(self):
        """A keyword-rich hit outranks a closer-vector but irrelevant one."""
        candidates = [
            {"text": "database schema design for Postgres", "distance": 0.5},
            {"text": "unrelated topic about cooking", "distance": 0.3},
        ]
        winner = _hybrid_rank(candidates, "database Postgres schema")[0]
        assert "database" in winner["text"]
        # bm25_score is exposed for debugging ...
        assert "bm25_score" in winner
        # ... while the internal combined score must not leak out.
        assert "_hybrid_score" not in winner

    def test_hybrid_rank_absolute_normalization(self):
        # Adding a much-worse result to the candidate set must NOT reshuffle
        # the top two — proves we're using absolute (1 - dist) and not
        # dist / max_dist normalization.
        pair = [
            {"text": "alpha alpha alpha", "distance": 0.1},
            {"text": "beta beta beta", "distance": 0.4},
        ]
        trio = [*pair, {"text": "gamma gamma gamma", "distance": 1.9}]

        # Rank fresh copies so ranking one list cannot affect the other.
        ranked_pair = _hybrid_rank([dict(item) for item in pair], "alpha")
        ranked_trio = _hybrid_rank([dict(item) for item in trio], "alpha")

        for position in (0, 1):
            assert ranked_pair[position]["text"] == ranked_trio[position]["text"]
# ── diary ingest ─────────────────────────────────────────────────────
class TestDiaryIngest:
def test_ingest_creates_drawers_and_closets(self, tmp_path):
diary_dir = tmp_path / "diaries"
diary_dir.mkdir()
(diary_dir / "2026-04-13.md").write_text(
"# 2026-04-13\n\n## 10:00 PDT — Test\n\nBuilt the auth system.\n"
)
palace_dir = tmp_path / "palace"
from mempalace.diary_ingest import ingest_diaries
result = ingest_diaries(str(diary_dir), str(palace_dir), force=True)
assert result["days_updated"] >= 1
assert get_collection(str(palace_dir)).count() >= 1
def test_ingest_skips_unchanged_on_second_run(self, tmp_path):
diary_dir = tmp_path / "diaries"
diary_dir.mkdir()
(diary_dir / "2026-04-13.md").write_text(
"# 2026-04-13\n\n## 10:00 — Test\n\nContent here that's long enough.\n"
)
palace_dir = tmp_path / "palace"
from mempalace.diary_ingest import ingest_diaries
ingest_diaries(str(diary_dir), str(palace_dir), force=True)
result = ingest_diaries(str(diary_dir), str(palace_dir))
assert result["days_updated"] == 0
def test_state_file_lives_outside_diary_dir(self, tmp_path):
# Regression: the original implementation wrote
# ``.diary_ingest_state.json`` *inside* the user's diary directory,
# polluting their content folder. State must live under
# ``~/.mempalace/state/`` instead.
diary_dir = tmp_path / "diaries"
diary_dir.mkdir()
(diary_dir / "2026-04-13.md").write_text(
"# 2026-04-13\n\n## 10:00 — Test\n\nBody content here long enough.\n"
)
palace_dir = tmp_path / "palace"
from mempalace.diary_ingest import _state_file_for, ingest_diaries
ingest_diaries(str(diary_dir), str(palace_dir), force=True)
# No state file inside the user's diary dir.
for entry in diary_dir.iterdir():
assert (
"diary_ingest" not in entry.name
), f"state file leaked into user diary dir: {entry}"
# State file does exist under ~/.mempalace/state/.
state_path = _state_file_for(str(palace_dir), diary_dir.resolve())
assert state_path.exists()
# Platform-neutral path check: compare parents rather than a hardcoded
# separator string that would fail on Windows (``\.mempalace\state\``).
assert state_path.parent.name == "state"
assert state_path.parent.parent.name == ".mempalace"
def test_wing_prefixed_drawer_id_prevents_cross_diary_collision(self, tmp_path):
    """Same-dated diaries in different wings must land in distinct drawers.

    Regression: the original implementation used
    ``drawer_diary_{date_str}`` regardless of wing — two diaries with
    the same date in different wings would clobber each other.
    """
    from mempalace.diary_ingest import _diary_drawer_id, ingest_diaries

    shared_body = "# 2026-04-13\n\n## 10:00 — entry\n\nThis is the day's content.\n"

    # Two separate diary dirs holding the same date, each tagged with a
    # marker line that identifies which diary it came from.
    personal_root = tmp_path / "personal"
    work_root = tmp_path / "work"
    personal_root.mkdir()
    (personal_root / "2026-04-13.md").write_text(shared_body + "Personal-only marker.\n")
    work_root.mkdir()
    (work_root / "2026-04-13.md").write_text(shared_body + "Work-only marker.\n")

    # Both diaries go into the same palace, under different wings.
    palace = str(tmp_path / "palace")
    ingest_diaries(str(personal_root), palace, wing="personal", force=True)
    ingest_diaries(str(work_root), palace, wing="work", force=True)

    drawers = get_collection(palace)
    personal_id = _diary_drawer_id("personal", "2026-04-13")
    work_id = _diary_drawer_id("work", "2026-04-13")
    assert personal_id != work_id

    personal_hit = drawers.get(ids=[personal_id])
    work_hit = drawers.get(ids=[work_id])
    assert personal_hit["ids"] == [personal_id]
    assert work_hit["ids"] == [work_id]
    # Each drawer must still carry its own diary's marker.
    assert "Personal-only marker." in personal_hit["documents"][0]
    assert "Work-only marker." in work_hit["documents"][0]
# ── cross-wing tunnels ───────────────────────────────────────────────
class TestTunnels:
    """Tunnels are explicit cross-wing connections stored in
    ``~/.mempalace/tunnels.json``. Each test points the module-level
    ``_TUNNEL_FILE`` at a fresh tmp file so tests don't cross-contaminate
    or touch the user's real tunnels."""

    def setup_method(self):
        """Point ``palace_graph._TUNNEL_FILE`` at a file in a new temp directory."""
        import mempalace.palace_graph as pg

        self._orig = pg._TUNNEL_FILE
        self._tmpdir = tempfile.mkdtemp()
        pg._TUNNEL_FILE = os.path.join(self._tmpdir, "tunnels.json")

    def teardown_method(self):
        """Restore the original ``_TUNNEL_FILE`` and delete the temp directory."""
        import mempalace.palace_graph as pg

        pg._TUNNEL_FILE = self._orig
        import shutil

        shutil.rmtree(self._tmpdir, ignore_errors=True)

    def test_create_tunnel(self):
        """A created tunnel echoes back its endpoints and label and has an id."""
        t = create_tunnel("wing_api", "auth", "wing_db", "users", label="auth uses users table")
        assert t["id"]
        assert t["source"]["wing"] == "wing_api"
        assert t["source"]["room"] == "auth"
        assert t["target"]["wing"] == "wing_db"
        assert t["target"]["room"] == "users"
        assert t["label"] == "auth uses users table"

    def test_list_tunnels_with_and_without_filter(self):
        """``list_tunnels`` returns everything unfiltered and matches a wing on either end."""
        create_tunnel("wing_a", "room1", "wing_b", "room2")
        create_tunnel("wing_a", "room3", "wing_c", "room4")
        assert len(list_tunnels()) == 2
        # Filtering by a wing that appears on either endpoint.
        assert len(list_tunnels("wing_a")) == 2
        assert len(list_tunnels("wing_c")) == 1
        assert len(list_tunnels("wing_nonexistent")) == 0

    def test_delete_tunnel(self):
        """Deleting the only tunnel by id leaves the store empty."""
        t = create_tunnel("wing_x", "r1", "wing_y", "r2")
        delete_tunnel(t["id"])
        assert list_tunnels() == []

    def test_dedup_same_endpoints_updates_label(self):
        """Re-creating a tunnel with identical endpoints updates the label in place."""
        create_tunnel("wing_a", "r1", "wing_b", "r2", label="first")
        create_tunnel("wing_a", "r1", "wing_b", "r2", label="updated")
        tunnels = list_tunnels()
        assert len(tunnels) == 1
        assert tunnels[0]["label"] == "updated"

    def test_follow_tunnels_returns_connected_endpoints(self):
        """``follow_tunnels`` surfaces only tunnels touching the given wing/room."""
        create_tunnel("wing_api", "auth", "wing_db", "users")
        create_tunnel("wing_api", "auth", "wing_frontend", "login")
        # Unrelated tunnel that must not surface.
        create_tunnel("wing_other", "notes", "wing_misc", "scratch")
        connections = follow_tunnels("wing_api", "auth")
        assert len(connections) == 2
        wings = {c["connected_wing"] for c in connections}
        assert wings == {"wing_db", "wing_frontend"}

    # ── regression: symmetry, durability, validation, concurrency ─────

    def test_tunnel_is_symmetric(self):
        """Regression: tunnels are undirected. create(A, B) and create(B, A)
        must resolve to the same canonical ID and dedupe into one record —
        the second call updates the label instead of creating a dupe."""
        first = create_tunnel("wing_a", "r1", "wing_b", "r2", label="forward")
        second = create_tunnel("wing_b", "r2", "wing_a", "r1", label="reversed")
        assert first["id"] == second["id"]
        assert len(list_tunnels()) == 1
        assert list_tunnels()[0]["label"] == "reversed"

    def test_follow_tunnels_works_from_either_endpoint(self):
        """Symmetric: you can follow_tunnels from either end of the link."""
        create_tunnel("wing_api", "auth", "wing_db", "users", label="auth uses users")
        from_source = follow_tunnels("wing_api", "auth")
        from_target = follow_tunnels("wing_db", "users")
        assert len(from_source) == 1
        assert len(from_target) == 1
        assert from_source[0]["connected_wing"] == "wing_db"
        assert from_target[0]["connected_wing"] == "wing_api"
        # Both surfaces should carry the same label.
        assert from_source[0]["label"] == "auth uses users"
        assert from_target[0]["label"] == "auth uses users"

    def test_empty_endpoint_fields_rejected(self):
        """Regression: create_tunnel must reject empty strings on any
        endpoint field so the JSON store can't grow phantom tunnels."""
        import pytest

        for args in [
            ("", "r1", "wing", "r2"),
            ("wing", "", "wing", "r2"),
            ("wing", "r1", "", "r2"),
            ("wing", "r1", "wing", ""),
            ("   ", "r1", "wing", "r2"),  # whitespace-only also rejected
        ]:
            with pytest.raises(ValueError):
                create_tunnel(*args)

    def test_corrupt_tunnel_file_does_not_lose_new_writes(self):
        """A truncated/corrupt tunnels.json (crash mid-write on a system
        without atomic rename) must not leak into subsequent reads — the
        file should be treated as empty and a fresh create_tunnel should
        persist cleanly."""
        import mempalace.palace_graph as pg

        # Simulate a crash that left a truncated file behind.
        with open(pg._TUNNEL_FILE, "w") as f:
            f.write("{not valid json")
        # Load should return [] rather than raising.
        assert list_tunnels() == []
        # A subsequent create must persist (atomic write replaces the corrupt file).
        t = create_tunnel("wing_a", "r1", "wing_b", "r2")
        assert list_tunnels() == [t]

    def test_atomic_write_leaves_no_stray_tmp_file(self):
        """Regression: _save_tunnels uses write-then-os.replace. After a
        successful create, there must be no leftover ``tunnels.json.tmp``."""
        import mempalace.palace_graph as pg

        create_tunnel("wing_a", "r1", "wing_b", "r2")
        assert os.path.exists(pg._TUNNEL_FILE)
        assert not os.path.exists(pg._TUNNEL_FILE + ".tmp")

    def test_concurrent_creates_preserve_all_tunnels(self):
        """Regression: two concurrent create_tunnel calls must not clobber
        each other. Without the mine_lock around load+save, the later
        writer's snapshot would overwrite the earlier writer's tunnel."""
        barrier = threading.Barrier(5)
        errors: list = []

        def worker(i):
            # Line all five workers up on the barrier so their creates
            # start as close together as possible.
            try:
                barrier.wait(timeout=2)
                create_tunnel(f"wing_{i}", "r", "wing_shared", "hub")
            except Exception as e:
                # Collected here because an exception raised in a thread
                # would not fail the test on its own.
                errors.append(e)

        threads = [threading.Thread(target=worker, args=(i,)) for i in range(5)]
        for t in threads:
            t.start()
        for t in threads:
            t.join()
        assert not errors, f"worker raised: {errors}"
        tunnels = list_tunnels()
        # One f-string with an explicit separator: the previous pair of
        # adjacent literals rendered as "got 4write race dropped some".
        assert len(tunnels) == 5, (
            f"expected 5 concurrent tunnels, got {len(tunnels)} — write race dropped some"
        )

    def test_created_at_is_timezone_aware(self):
        """Regression: created_at must be tz-aware UTC, not naive."""
        t = create_tunnel("wing_a", "r1", "wing_b", "r2")
        # A tz-aware UTC ISO timestamp ends in '+00:00' or 'Z'.
        assert t["created_at"].endswith("+00:00") or t["created_at"].endswith("Z")
# ── drawer-grep neighbor expansion ────────────────────────────────────
#
# When a closet hit lands on a drawer whose chunk boundary clips a thought
# (matched chunk says "here's a breakdown:" and the breakdown lives in the
# next chunk), the closet path now expands to ±1 neighbor chunks from the
# same source file. These tests pin that behavior end-to-end and at the
# helper level.
class TestDrawerGrepExpansion:
    """Pins neighbor-chunk expansion, both at the helper level and through search."""

    def _seed_source_file(self, palace_path, source: str, n_chunks: int):
        """Store ``n_chunks`` sequential drawers for one source file.

        Returns the drawer collection together with a mapping from
        chunk_index to drawer ID.
        """
        col = get_collection(palace_path)
        ids, docs, metas = [], [], []
        for i in range(n_chunks):
            ids.append(f"drawer_test_room_{source.replace('/', '_')}_{i:03d}")
            docs.append(f"chunk_{i} content about topic alpha")
            metas.append(
                {
                    "wing": "test",
                    "room": "room",
                    "source_file": source,
                    "chunk_index": i,
                    "filed_at": "2026-04-13T00:00:00",
                }
            )
        col.upsert(ids=ids, documents=docs, metadatas=metas)
        return col, dict(enumerate(ids))

    def test_expand_returns_matched_plus_neighbors(self, palace_path):
        """With radius=1 a mid-file match yields chunks 1, 2 and 3 in order."""
        col, _ = self._seed_source_file(palace_path, "/proj/doc.md", n_chunks=5)
        out = _expand_with_neighbors(
            col,
            "chunk_2 content about topic alpha",
            {"source_file": "/proj/doc.md", "chunk_index": 2},
            radius=1,
        )
        assert out["drawer_index"] == 2
        assert out["total_drawers"] == 5

        text = out["text"]
        # The matched chunk and its immediate neighbors are all present.
        for wanted in ("chunk_1", "chunk_2", "chunk_3"):
            assert wanted in text
        # Chunks outside the radius must not leak in.
        for unwanted in ("chunk_0", "chunk_4"):
            assert unwanted not in text
        # The joined text preserves chunk_index order.
        pos_prev = text.index("chunk_1")
        pos_match = text.index("chunk_2")
        pos_next = text.index("chunk_3")
        assert pos_prev < pos_match < pos_next

    def test_expand_at_start_of_file_only_has_next_neighbor(self, palace_path):
        """A match on the first chunk expands forward only."""
        col, _ = self._seed_source_file(palace_path, "/proj/edge_start.md", n_chunks=3)
        meta = {"source_file": "/proj/edge_start.md", "chunk_index": 0}
        out = _expand_with_neighbors(col, "chunk_0 content", meta)
        assert out["drawer_index"] == 0
        assert out["total_drawers"] == 3
        expanded = out["text"]
        assert "chunk_0" in expanded
        assert "chunk_1" in expanded
        # There is no chunk before index 0; the expansion must not invent one.
        assert "chunk_-1" not in expanded

    def test_expand_at_end_of_file_only_has_prev_neighbor(self, palace_path):
        """A match on the last chunk expands backward only."""
        col, _ = self._seed_source_file(palace_path, "/proj/edge_end.md", n_chunks=3)
        meta = {"source_file": "/proj/edge_end.md", "chunk_index": 2}
        out = _expand_with_neighbors(col, "chunk_2 content", meta)
        assert out["drawer_index"] == 2
        assert out["total_drawers"] == 3
        expanded = out["text"]
        assert "chunk_1" in expanded
        assert "chunk_2" in expanded
        # Only three chunks were seeded, so chunk_3 cannot appear.
        assert "chunk_3" not in expanded

    def test_expand_single_drawer_file_returns_just_matched(self, palace_path):
        """A one-chunk source returns exactly that chunk's stored text."""
        col, _ = self._seed_source_file(palace_path, "/proj/lone.md", n_chunks=1)
        meta = {"source_file": "/proj/lone.md", "chunk_index": 0}
        out = _expand_with_neighbors(col, "chunk_0 content", meta)
        assert out["drawer_index"] == 0
        assert out["total_drawers"] == 1
        assert out["text"] == "chunk_0 content about topic alpha"

    def test_expand_falls_back_when_metadata_missing(self, palace_path):
        """Without source_file / chunk_index the matched document comes back as-is."""
        col = get_collection(palace_path)
        out = _expand_with_neighbors(col, "matched doc", {})
        assert out["text"] == "matched doc"
        assert out["drawer_index"] is None
        assert out["total_drawers"] is None

    def test_hybrid_search_enrichment_populates_drawer_index_and_total(self, palace_path):
        """End-to-end: a closet-boosted hit exposes drawer_index and total_drawers.

        Five drawers are seeded for one source file plus a closet that
        points at one of them; the boosted search hit must report how many
        drawers the source has and which chunk the text was expanded around.
        """
        source = "/proj/indexed.md"
        drawers = get_collection(palace_path)

        # Five drawers for a single source file, upserted one at a time.
        for i in range(5):
            drawer_meta = {
                "wing": "project",
                "room": "backend",
                "source_file": source,
                "chunk_index": i,
                "filed_at": "2026-04-13T00:00:00",
            }
            drawers.upsert(
                ids=[f"drawer_proj_backend_indexed_{i:03d}"],
                documents=[f"chunk_{i} talks about JWT authentication flow"],
                metadatas=[drawer_meta],
            )

        # A closet for the same source that references drawer 002.
        closet_meta = {"wing": "project", "room": "backend", "source_file": source}
        get_closets_collection(palace_path).upsert(
            ids=["closet_proj_backend_indexed_01"],
            documents=["JWT auth|;|→drawer_proj_backend_indexed_002"],
            metadatas=[closet_meta],
        )

        result = search_memories("JWT authentication", palace_path)
        assert result["results"]

        # Hits where the closet agrees with direct search are tagged drawer+closet.
        boosted = [hit for hit in result["results"] if hit["matched_via"] == "drawer+closet"]
        assert boosted, "hybrid search should mark the closet-agreeing source"

        top = boosted[0]
        assert top["total_drawers"] == 5
        assert isinstance(top["drawer_index"], int)
        # The enriched text must contain at least one seeded chunk.
        assert "chunk_" in top["text"]
+288
View File
@@ -0,0 +1,288 @@
"""
test_fact_checker.py — Regression + integration tests for fact_checker.
Covers every detection path + the three bugs the original PR silently
hid behind ``except Exception: pass``:
* ``kg.query()`` doesn't exist — code must use ``query_entity``.
* ``KnowledgeGraph(palace_path=...)`` is not a valid kwarg — code
must pass ``db_path``.
* O(n²) edit-distance over the full registry — must filter to names
actually mentioned in the text.
Also pins the three feature contracts:
* similar_name — "Mila" vs "Milla" in a registry with both.
* relationship_mismatch — "Bob is Alice's brother" vs KG "husband".
* stale_fact — claim matches a triple whose valid_to is in the past.
"""
from __future__ import annotations
import json
from unittest.mock import MagicMock, patch
import pytest
from mempalace.fact_checker import (
_check_entity_confusion,
_edit_distance,
_extract_claims,
_flatten_names,
check_text,
)
from mempalace.knowledge_graph import KnowledgeGraph
# ── claim extraction ─────────────────────────────────────────────────
class TestExtractClaims:
    """Claim extraction from possessive relationship sentences."""

    def test_parses_x_is_ys_z(self):
        """``X is Y's Z`` yields subject X, predicate Z, object Y and the full span."""
        sentence = "Bob is Alice's brother"
        claims = _extract_claims(sentence)
        assert len(claims) == 1
        expected = {
            "subject": "Bob",
            "predicate": "brother",
            "object": "Alice",
            "span": "Bob is Alice's brother",
        }
        assert claims[0] == expected

    def test_parses_xs_z_is_y(self):
        """``Y's Z is X`` resolves to the same subject/predicate/object roles."""
        claims = _extract_claims("Alice's brother is Bob")
        assert len(claims) == 1
        claim = claims[0]
        assert claim["subject"] == "Bob"
        assert claim["predicate"] == "brother"
        assert claim["object"] == "Alice"

    def test_ignores_sentences_without_possessive_role(self):
        """Plain prose with no possessive relationship produces no claims."""
        for prose in (
            "Bob drove to the store today",
            "Just some prose without relationships",
        ):
            assert _extract_claims(prose) == []

    def test_multiple_claims_in_one_text(self):
        """Each sentence in a multi-sentence text contributes its own claim."""
        claims = _extract_claims("Bob is Alice's brother. Carol is Dave's sister.")
        found_subjects = {claim["subject"] for claim in claims}
        assert found_subjects == {"Bob", "Carol"}
# ── entity confusion ─────────────────────────────────────────────────
class TestEntityConfusion:
    """Similar-name detection against the entity registry."""

    def test_flags_near_name_when_only_one_mentioned(self):
        """One of two near-identical registry names in the text raises one issue."""
        registry = {"people": ["Milla", "Mila"]}
        issues = _check_entity_confusion("I spoke with Mila today.", registry)
        # "Mila" is mentioned and "Milla" is not; the registry holds both at
        # edit-distance 1, so the possible confusion is flagged.
        assert len(issues) == 1
        flagged = issues[0]
        assert flagged["type"] == "similar_name"
        assert set(flagged["names"]) == {"Mila", "Milla"}
        assert flagged["distance"] == 1

    def test_no_false_positive_when_both_names_mentioned(self):
        """Regression: a text discussing both Mila and Milla is fine —
        the user clearly knows they're different. Don't nag."""
        registry = {"people": ["Milla", "Mila"]}
        text = "Mila and Milla met for lunch."
        assert _check_entity_confusion(text, registry) == []

    def test_no_issues_when_registry_empty(self):
        """An empty registry, or an empty category, produces no issues."""
        for empty_registry in ({}, {"people": []}):
            assert _check_entity_confusion("Bob said hi", empty_registry) == []

    def test_no_issues_when_no_mentioned_names(self):
        """Text mentioning none of the registered names produces no issues."""
        registry = {"people": ["Zelda", "Link", "Sheik"]}
        result = _check_entity_confusion("nothing relevant here", registry)
        assert result == []

    def test_registry_dict_shape_is_supported(self):
        """Dict-valued categories contribute their keys as candidate names."""
        # Shape under test: {"people": {"Alice": {...meta}}}.
        registry = {"people": {"Milla": {"role": "creator"}, "Mila": {}}}
        issues = _check_entity_confusion("I messaged Mila yesterday", registry)
        assert any("Milla" in (issue["names"] or []) for issue in issues)
class TestEditDistance:
    """Edit-distance values, plus a timing guard on the entity-confusion check."""

    def test_basic_distances(self):
        """Known edit distances for a few classic pairs."""
        cases = [
            ("kitten", "sitting", 3),
            ("mila", "milla", 1),
            ("abc", "abc", 0),
        ]
        for left, right, expected in cases:
            assert _edit_distance(left, right) == expected

    def test_empty_strings(self):
        """Distance to or from the empty string equals the other string's length."""
        cases = [
            ("", "", 0),
            ("abc", "", 3),
            ("", "abc", 3),
        ]
        for left, right, expected in cases:
            assert _edit_distance(left, right) == expected

    def test_performance_bounded_by_mentioned_names(self):
        """Regression: an earlier implementation ran O(n²) pairwise
        edit-distance over every registry entry on every check_text call.
        With 500 registered names and zero mentions, the call must return
        almost immediately because no comparison should even start."""
        import time

        # 500 generated names, none of which appear in the text below.
        registry = {"people": [f"Zelda{i:03d}" for i in range(500)]}
        text = "completely irrelevant prose with no registered names at all"

        started = time.perf_counter()
        issues = _check_entity_confusion(text, registry)
        elapsed = time.perf_counter() - started

        assert issues == []
        # The budget is generous: once candidates are filtered to the names
        # actually mentioned (none here) the call should finish far inside
        # it, while a full pairwise scan is expected to overrun.
        assert elapsed < 0.2, f"entity confusion took {elapsed:.3f}s on empty mentions"
# ── _flatten_names helper ────────────────────────────────────────────
class TestFlattenNames:
    """``_flatten_names`` collapses registry categories into one set of names."""

    def test_handles_list_categories(self):
        """List-valued categories contribute their elements."""
        flattened = _flatten_names({"people": ["Ada", "Bob"]})
        assert flattened == {"Ada", "Bob"}

    def test_handles_dict_categories(self):
        """Dict-valued categories contribute their keys."""
        flattened = _flatten_names({"people": {"Ada": {}, "Bob": {}}})
        assert flattened == {"Ada", "Bob"}

    def test_skips_falsy_entries(self):
        """Empty strings and ``None`` entries are dropped."""
        flattened = _flatten_names({"people": ["Ada", "", None, "Bob"]})
        assert flattened == {"Ada", "Bob"}
# ── KG integration (uses a real tmp SQLite palace) ───────────────────
@pytest.fixture
def palace_with_kg(tmp_path):
    """Yield ``(palace_dir, kg)`` for a fresh palace backed by a real KG file.

    The graph is created at ``<palace>/knowledge_graph.sqlite3`` via the
    ``db_path`` kwarg and is handed over without any triples; tests add
    the ones they need. Fact-checker must locate the KG through that
    file path, not via a bogus ``palace_path`` kwarg.
    """
    palace_dir = tmp_path / "palace"
    palace_dir.mkdir()
    graph = KnowledgeGraph(db_path=str(palace_dir / "knowledge_graph.sqlite3"))
    yield palace_dir, graph
class TestKGContradictions:
    """Knowledge-graph-backed checks: mismatched relationships and stale facts."""

    def test_kg_init_uses_db_path_not_palace_path_kwarg(self):
        """Regression: the original code passed ``palace_path=`` to a
        constructor whose only kwarg is ``db_path``. That raised
        TypeError — silently swallowed — and the KG path became dead
        code. This test pins the correct call signature."""
        # If this construction raises, the constructor signature changed
        # and fact_checker has to follow.
        graph = KnowledgeGraph(db_path=":memory:")
        # ``query_entity`` is the method fact_checker relies on.
        assert callable(getattr(graph, "query_entity", None))
        # The method name the old code called must not exist.
        assert not hasattr(graph, "query")

    def test_relationship_mismatch_detected(self, palace_with_kg):
        """The feature's headline example: text says brother, KG says husband."""
        palace, kg = palace_with_kg
        palace_path = str(palace)
        kg.add_triple("Bob", "husband_of", "Alice", valid_from="2020-01-01")

        # Same predicate and same object as the KG triple → no mismatch.
        agreeing = check_text("Bob is Alice's husband_of", palace_path)
        assert all(issue["type"] != "relationship_mismatch" for issue in agreeing)

        # Different predicate for the same (subject, object) pair → mismatch.
        conflicting = check_text("Bob is Alice's brother", palace_path)
        mismatches = [
            issue for issue in conflicting if issue["type"] == "relationship_mismatch"
        ]
        assert mismatches, "should flag text/KG mismatch for same (subject, object)"
        first = mismatches[0]
        assert first["entity"] == "Bob"
        assert first["claim"]["predicate"] == "brother"
        assert first["kg_fact"]["predicate"] == "husband_of"

    def test_no_false_positive_when_kg_has_no_facts_about_subject(self, palace_with_kg):
        """With no triples in the KG, a claim raises no issues at all."""
        palace, _ = palace_with_kg
        issues = check_text("Bob is Alice's brother", str(palace))
        assert issues == []

    def test_stale_fact_detected(self, palace_with_kg):
        """A claim matching a triple whose validity window has closed is flagged stale."""
        palace, kg = palace_with_kg
        # A relationship that ended in 2023. The claim is phrased in the
        # possessive shape so the narrow claim-extraction regex reaches
        # the stale-fact branch.
        kg.add_triple(
            "Bob", "brother", "Alice", valid_from="2010-01-01", valid_to="2023-06-01"
        )
        issues = check_text("Bob is Alice's brother", str(palace))
        stale = [issue for issue in issues if issue["type"] == "stale_fact"]
        assert stale, "should flag closed-window fact as stale"
        first = stale[0]
        assert first["entity"] == "Bob"
        assert first["valid_to"].startswith("2023")

    def test_current_fact_same_triple_is_not_flagged(self, palace_with_kg):
        """A claim matching a triple with no ``valid_to`` raises no issues."""
        palace, kg = palace_with_kg
        kg.add_triple("Bob", "brother", "Alice", valid_from="2010-01-01")
        assert check_text("Bob is Alice's brother", str(palace)) == []

    def test_missing_palace_does_not_crash(self, tmp_path):
        """Brand-new palace (no KG file yet) — check_text must return []
        rather than raising or hanging."""
        never_created = tmp_path / "never_created"
        assert check_text("Bob is Alice's brother", str(never_created)) == []
# ── end-to-end check_text contract ───────────────────────────────────
class TestCheckTextContract:
    """End-to-end behavior of ``check_text``."""

    def test_empty_text_returns_empty_list(self, tmp_path):
        """Empty input text yields an empty issue list."""
        palace = tmp_path / "palace"
        assert check_text("", str(palace)) == []

    def test_registry_confusion_path_isolated_from_kg(self, tmp_path, monkeypatch):
        """If the registry file is present but the KG is missing, the
        similar-name path must still fire. Prior implementations had
        such entangled state that one failure killed both paths."""
        from mempalace import miner

        # Write a temp registry file and point the miner at it instead of
        # the real one.
        registry_file = tmp_path / "known_entities.json"
        registry_file.write_text(json.dumps({"people": ["Milla", "Mila"]}))
        monkeypatch.setattr(miner, "_ENTITY_REGISTRY_PATH", str(registry_file))
        # Overwrite the cache fields with empty values before the check runs.
        blank_cache = {"mtime": None, "names": frozenset(), "raw": {}}
        miner._ENTITY_REGISTRY_CACHE.update(blank_cache)

        missing_palace = tmp_path / "nonexistent_palace"
        issues = check_text("Chatted with Mila.", str(missing_palace))
        assert any(issue["type"] == "similar_name" for issue in issues)
# ── CLI ──────────────────────────────────────────────────────────────
class TestCLI:
    """Command-line entry point of ``mempalace.fact_checker``."""

    def test_exits_nonzero_when_issues_found(self, tmp_path, monkeypatch, capsys):
        """The CLI exit code is how shell scripts / hooks know to act —
        pin it explicitly."""
        import runpy

        from mempalace import fact_checker, miner

        # Write a temp registry holding two near-identical names and point
        # the miner at it.
        registry_file = tmp_path / "known_entities.json"
        registry_file.write_text(json.dumps({"people": ["Milla", "Mila"]}))
        monkeypatch.setattr(miner, "_ENTITY_REGISTRY_PATH", str(registry_file))
        blank_cache = {"mtime": None, "names": frozenset(), "raw": {}}
        miner._ENTITY_REGISTRY_CACHE.update(blank_cache)

        # Command line under test: the text "Mila said hi" plus --palace.
        cli_args = ["fact_checker", "Mila said hi", "--palace", str(tmp_path / "palace")]
        monkeypatch.setattr("sys.argv", cli_args)

        # Run the module's ``__main__`` block in-process via runpy.
        with pytest.raises(SystemExit) as excinfo:
            runpy.run_module("mempalace.fact_checker", run_name="__main__")

        # Issues found → exit code 1, and the issue type appears on stdout.
        assert excinfo.value.code == 1
        captured = capsys.readouterr()
        assert "similar_name" in captured.out

        # Reference otherwise-unused imports so linters stay quiet.
        _ = (MagicMock, patch, fact_checker)
+133
View File
@@ -0,0 +1,133 @@
"""Tests for the hybrid closet+drawer retrieval in search_memories.
The hybrid path queries drawers directly (the floor) AND closets, applying a
rank-based boost to drawers whose source_file appears in top closet hits.
This avoids the "weak-closets regression" where low-signal closets (from
regex extraction on narrative content) could hide drawers that direct
search would have found.
"""
from mempalace.palace import (
get_closets_collection,
get_collection,
upsert_closet_lines,
)
from mempalace.searcher import search_memories
def _seed_drawers(palace_path):
    """Upsert four fixed drawers (D1–D4), each with its own source file."""
    fixtures = [
        (
            "D1",
            "We switched the auth service to use JWT tokens with a 24h expiry.",
            {"wing": "backend", "room": "auth", "source_file": "fixture_D1.md"},
        ),
        (
            "D2",
            "Database migration to PostgreSQL 15 completed last Tuesday.",
            {"wing": "backend", "room": "db", "source_file": "fixture_D2.md"},
        ),
        (
            "D3",
            "The frontend team is debating whether to adopt TanStack Query.",
            {"wing": "frontend", "room": "state", "source_file": "fixture_D3.md"},
        ),
        (
            "D4",
            "Kafka consumer rebalance timeout set to 45 seconds after incident.",
            {"wing": "backend", "room": "queue", "source_file": "fixture_D4.md"},
        ),
    ]
    # Split the per-drawer rows into the parallel lists the collection expects.
    drawer_ids = [row[0] for row in fixtures]
    texts = [row[1] for row in fixtures]
    metas = [row[2] for row in fixtures]

    drawers = get_collection(palace_path, create=True)
    drawers.upsert(ids=drawer_ids, documents=texts, metadatas=metas)
def _seed_strong_closet_for(palace_path, drawer_id, source_file, topics):
    """Write one closet for ``source_file`` with a line per topic pointing at ``drawer_id``.

    The closet metadata always uses wing ``backend`` and room ``auth``,
    whichever drawer is targeted.
    """
    closet_meta = {
        "wing": "backend",
        "room": "auth",
        "source_file": source_file,
        "generated_by": "test",
    }
    pointer_lines = []
    for t in topics:
        pointer_lines.append(f"{t}||→{drawer_id}")

    closets = get_closets_collection(palace_path)
    upsert_closet_lines(
        closets,
        closet_id_base=f"closet_{drawer_id}",
        lines=pointer_lines,
        metadata=closet_meta,
    )
# ── core invariant: closets can only HELP, never HIDE ─────────────────────
class TestHybridInvariant:
    """Closets may boost a drawer's rank but must never hide a direct hit."""

    def test_no_closets_degrades_to_direct_drawer_search(self, tmp_path):
        """With no closets at all, direct drawer search still finds the Kafka drawer."""
        palace = str(tmp_path / "palace")
        _seed_drawers(palace)
        # No closets are created for this case.
        result = search_memories("Kafka rebalance timeout", palace, n_results=3)
        sources = [hit["source_file"] for hit in result["results"]]
        assert sources, "should return results"
        assert (
            "fixture_D4.md" in sources
        ), "direct drawer search alone should surface the Kafka drawer"

    def test_weak_closets_do_not_hide_direct_drawer_hits(self, tmp_path):
        """A closet that points at a wrong drawer must NOT suppress the
        drawer that direct search would have ranked first."""
        palace = str(tmp_path / "palace")
        _seed_drawers(palace)
        # Misleading closet: Kafka-flavored topics, but it points at D3.
        misleading_topics = ["Kafka queue tuning", "consumer rebalance config"]
        _seed_strong_closet_for(
            palace,
            drawer_id="D3",
            source_file="fixture_D3.md",
            topics=misleading_topics,
        )
        result = search_memories("Kafka consumer rebalance timeout", palace, n_results=5)
        sources = [hit["source_file"] for hit in result["results"]]
        assert "fixture_D4.md" in sources, (
            "D4 must appear — direct drawer search alone would rank it first. "
            "Closet pointing to D3 should only boost D3, never hide D4."
        )

    def test_closet_boost_lifts_matching_drawer(self, tmp_path):
        """When a closet agrees with direct search, the matching drawer
        should be boosted to rank 1."""
        palace = str(tmp_path / "palace")
        _seed_drawers(palace)
        agreeing_topics = ["JWT auth tokens", "session expiry", "authentication service"]
        _seed_strong_closet_for(
            palace,
            drawer_id="D1",
            source_file="fixture_D1.md",
            topics=agreeing_topics,
        )
        result = search_memories("JWT auth tokens expiry", palace, n_results=3)
        hits = result["results"]
        sources = [hit["source_file"] for hit in hits]
        assert sources[0] == "fixture_D1.md"
        best = hits[0]
        assert best["matched_via"] == "drawer+closet"
        assert best["closet_boost"] > 0
# ── closet_boost metadata ────────────────────────────────────────────────
class TestClosetMetadata:
    """Closet-related fields exposed on individual search hits."""

    def test_closet_preview_exposed_when_boosted(self, tmp_path):
        """The top hit for a closet-backed source carries a ``closet_preview`` key."""
        palace = str(tmp_path / "palace")
        _seed_drawers(palace)
        closet_topics = ["JWT auth tokens", "24h expiry", "authentication"]
        _seed_strong_closet_for(
            palace,
            drawer_id="D1",
            source_file="fixture_D1.md",
            topics=closet_topics,
        )
        result = search_memories("JWT authentication", palace, n_results=2)
        best = result["results"][0]
        assert best["source_file"] == "fixture_D1.md"
        assert "closet_preview" in best

    def test_drawer_only_hits_have_no_closet_preview(self, tmp_path):
        """Without closets every hit is drawer-only: no preview and a zero boost."""
        palace = str(tmp_path / "palace")
        _seed_drawers(palace)
        # No closets are seeded for this case.
        result = search_memories("TanStack Query", palace, n_results=2)
        hits = result["results"]
        assert hits
        for hit in hits:
            assert hit["matched_via"] == "drawer"
            assert "closet_preview" not in hit
            assert hit["closet_boost"] == 0.0