chore: rescue merged stacked PRs #1150 and #1157 into develop

#1148, #1150, and #1157 were reviewed and merged on GitHub, but the two
stacked children landed on their parent feature branches (now stale)
rather than on develop. Only #1148's commits reached develop via the
direct merge. Release PR #1159 (develop → main for v3.3.3) is therefore
missing the LLM refinement, Claude-conversation scanner, and miner-
registry wire-up that were ostensibly part of the release.

This merge brings the stale `feat/llm-entity-refine` branch (which
contains the rolled-up merge commit for #1157 → #1150 → everything
below) into develop so the release tag includes it.

No code changes here — only history recovery.
This commit is contained in:
Igor Lins e Silva
2026-04-24 13:49:12 -03:00
14 changed files with 2588 additions and 12 deletions
+1 -2
View File
@@ -120,8 +120,7 @@ def quarantine_stale_hnsw(palace_path: str, stale_seconds: float = 3600.0) -> li
os.rename(seg_dir, target) os.rename(seg_dir, target)
moved.append(target) moved.append(target)
logger.warning( logger.warning(
"Quarantined stale HNSW segment %s " "Quarantined stale HNSW segment %s (sqlite %.0fs newer than HNSW); renamed to %s",
"(sqlite %.0fs newer than HNSW); renamed to %s",
seg_dir, seg_dir,
sqlite_mtime - hnsw_mtime, sqlite_mtime - hnsw_mtime,
target, target,
+74 -5
View File
@@ -86,21 +86,53 @@ def cmd_init(args):
languages = cfg.entity_languages languages = cfg.entity_languages
languages_tuple = tuple(languages) languages_tuple = tuple(languages)
# Optional phase-2 LLM provider (opt-in via --llm).
llm_provider = None
if getattr(args, "llm", False):
from .llm_client import LLMError, get_provider
try:
llm_provider = get_provider(
name=args.llm_provider,
model=args.llm_model,
endpoint=args.llm_endpoint,
api_key=args.llm_api_key,
)
except LLMError as e:
print(f" ERROR: {e}", file=sys.stderr)
sys.exit(2)
ok, msg = llm_provider.check_available()
if not ok:
print(
f" ERROR: LLM provider '{args.llm_provider}' unavailable: {msg}",
file=sys.stderr,
)
sys.exit(2)
print(f" LLM refinement enabled: {args.llm_provider}/{args.llm_model}")
# Pass 1: discover entities — manifests + git authors first, prose detection # Pass 1: discover entities — manifests + git authors first, prose detection
# as supplement for names mentioned only in docs/notes. # as supplement for names mentioned only in docs/notes. Optional phase-2
# LLM refinement runs inside discover_entities when llm_provider is given.
print(f"\n Scanning for entities in: {args.dir}") print(f"\n Scanning for entities in: {args.dir}")
if languages_tuple != ("en",): if languages_tuple != ("en",):
print(f" Languages: {', '.join(languages_tuple)}") print(f" Languages: {', '.join(languages_tuple)}")
detected = discover_entities(args.dir, languages=languages_tuple) detected = discover_entities(args.dir, languages=languages_tuple, llm_provider=llm_provider)
total = len(detected["people"]) + len(detected["projects"]) + len(detected["uncertain"]) total = len(detected["people"]) + len(detected["projects"]) + len(detected["uncertain"])
if total > 0: if total > 0:
confirmed = confirm_entities(detected, yes=getattr(args, "yes", False)) confirmed = confirm_entities(detected, yes=getattr(args, "yes", False))
# Save confirmed entities to <project>/entities.json for the miner # Save confirmed entities to <project>/entities.json (per-project
# audit trail — user can inspect or hand-edit) AND merge into the
# global registry the miner reads at mine time.
if confirmed["people"] or confirmed["projects"]: if confirmed["people"] or confirmed["projects"]:
entities_path = Path(args.dir).expanduser().resolve() / "entities.json" entities_path = Path(args.dir).expanduser().resolve() / "entities.json"
with open(entities_path, "w") as f: with open(entities_path, "w", encoding="utf-8") as f:
json.dump(confirmed, f, indent=2) json.dump(confirmed, f, indent=2, ensure_ascii=False)
print(f" Entities saved: {entities_path}") print(f" Entities saved: {entities_path}")
from .miner import add_to_known_entities
registry_path = add_to_known_entities(confirmed)
print(f" Registry updated: {registry_path}")
else: else:
print(" No entities detected — proceeding with directory-based rooms.") print(" No entities detected — proceeding with directory-based rooms.")
@@ -550,6 +582,43 @@ def main():
"When given, the value is also persisted to config.json." "When given, the value is also persisted to config.json."
), ),
) )
p_init.add_argument(
"--llm",
action="store_true",
help=(
"Enable LLM-assisted entity refinement (opt-in, local-first). "
"Runs after manifest/git/regex detection, asking the configured "
"provider to reclassify ambiguous candidates. "
"Ctrl-C during refinement returns partial results."
),
)
p_init.add_argument(
"--llm-provider",
default="ollama",
choices=["ollama", "openai-compat", "anthropic"],
help="LLM provider (default: ollama). Use --llm to enable.",
)
p_init.add_argument(
"--llm-model",
default="gemma4:e4b",
help="Model name for the chosen provider (default: gemma4:e4b for Ollama).",
)
p_init.add_argument(
"--llm-endpoint",
default=None,
help=(
"Provider endpoint URL. Default for Ollama: http://localhost:11434. "
"Required for openai-compat."
),
)
p_init.add_argument(
"--llm-api-key",
default=None,
help=(
"API key for the provider. For anthropic, defaults to $ANTHROPIC_API_KEY; "
"for openai-compat, defaults to $OPENAI_API_KEY."
),
)
# mine # mine
p_mine = sub.add_parser("mine", help="Mine files into the palace") p_mine = sub.add_parser("mine", help="Mine files into the palace")
+160
View File
@@ -0,0 +1,160 @@
"""
convo_scanner.py — Parse Claude Code conversation directories into ProjectInfo.
Claude Code stores sessions under ``~/.claude/projects/<slug>/<id>.jsonl``,
where the ``<slug>`` is the original CWD with ``/`` replaced by ``-``. That
encoding is lossy: we can't tell whether ``foo-bar`` in a slug is the
literal project name ``foo-bar`` or two path segments ``foo/bar``.
Fortunately, every message record in the JSONL carries a ``cwd`` field with
the true path. This scanner reads one record per session to recover the
accurate project name, falling back to slug-decoding only if the JSONL
is malformed or empty.
Output is the same ``ProjectInfo`` shape used by ``project_scanner``, so the
``discover_entities`` orchestrator can mix-and-match sources.
Public:
is_claude_projects_root(path) -> bool
scan_claude_projects(path) -> list[ProjectInfo]
"""
from __future__ import annotations
import json
from pathlib import Path
from typing import Optional
from mempalace.project_scanner import ProjectInfo
MAX_HEADER_LINES = 20 # lines to read per session looking for `cwd`
def is_claude_projects_root(path: Path) -> bool:
    """Report whether ``path`` has the layout of a ``.claude/projects/`` root.

    The check is a heuristic: ``path`` qualifies as soon as one of its
    sub-directories has a name beginning with ``-`` and holds at least one
    regular file with the ``.jsonl`` suffix.

    Returns False when ``path`` is not a directory or cannot be listed.
    Sub-directories that cannot be listed are ignored.
    """
    if not path.is_dir():
        return False
    try:
        entries = list(path.iterdir())
    except OSError:
        return False

    slug_dirs = (e for e in entries if e.is_dir() and e.name.startswith("-"))
    for slug_dir in slug_dirs:
        try:
            for item in slug_dir.iterdir():
                if item.is_file() and item.suffix == ".jsonl":
                    return True
        except OSError:
            # Unreadable slug directory: keep looking at the remaining ones.
            continue
    return False
def _extract_cwd_from_session(session_file: Path) -> Optional[str]:
    """Return the ``cwd`` from the first header record that carries one.

    Only the first ``MAX_HEADER_LINES`` lines of the file are inspected.
    Blank lines, lines that are not valid JSON, and JSON values that are not
    objects (lists, numbers, strings) are skipped rather than treated as
    errors.

    Returns:
        The first non-empty string ``cwd`` value found, or None if the file
        cannot be read or no inspected record has a usable ``cwd``.
    """
    try:
        with open(session_file, encoding="utf-8", errors="replace") as f:
            for i, line in enumerate(f):
                if i >= MAX_HEADER_LINES:
                    break
                line = line.strip()
                if not line:
                    continue
                try:
                    obj = json.loads(line)
                except json.JSONDecodeError:
                    continue
                # A line can decode to a list or scalar; only objects have
                # ``.get`` and can carry a ``cwd`` key.
                if not isinstance(obj, dict):
                    continue
                cwd = obj.get("cwd")
                if isinstance(cwd, str) and cwd:
                    return cwd
    except OSError:
        return None
    return None
def _decode_slug_fallback(slug: str) -> str:
"""Best-effort project name from slug when cwd is unavailable.
The slug is lossy (`/` and `-` both become `-`). Last non-empty segment
is the closest guess at the project name, preserving kebab-case is
impossible without cwd.
"""
stripped = slug.lstrip("-")
parts = [p for p in stripped.split("-") if p]
return parts[-1] if parts else slug
def _safe_mtime(path: Path) -> float:
"""Return file mtime, defaulting old on permission or filesystem errors."""
try:
return path.stat().st_mtime
except OSError:
return 0.0
def _resolve_project_name(project_dir: Path) -> str:
    """Recover the original project name for one slug directory.

    The ``.jsonl`` session files are tried newest-first (by mtime). The first
    one that yields a ``cwd`` decides the name: the final path component of
    that ``cwd``, or the whole ``cwd`` string when it has no final component.
    If no session yields a ``cwd``, the name is guessed from the directory
    slug instead.
    """
    session_files = [
        p for p in project_dir.iterdir() if p.is_file() and p.suffix == ".jsonl"
    ]
    # Newest first: recent sessions are the most likely to be well-formed.
    session_files.sort(key=_safe_mtime, reverse=True)

    for session_file in session_files:
        recorded_cwd = _extract_cwd_from_session(session_file)
        if not recorded_cwd:
            continue
        basename = Path(recorded_cwd).name
        return basename if basename else recorded_cwd

    return _decode_slug_fallback(project_dir.name)
def scan_claude_projects(path: str | Path) -> list[ProjectInfo]:
    """Build one ``ProjectInfo`` per Claude Code project under ``path``.

    ``path`` is expanded and resolved. If it does not look like a
    ``.claude/projects/`` root, an empty list is returned. Every
    sub-directory whose name starts with ``-`` and that holds at least one
    ``.jsonl`` file contributes a project. When two directories resolve to
    the same project name, the one with more sessions is kept.

    The git-oriented fields are repurposed: ``has_git`` is always False, and
    both ``total_commits`` and ``user_commits`` carry the session count.

    Returns:
        Projects ordered by descending session count, then by name.
    """
    root = Path(path).expanduser().resolve()
    if not is_claude_projects_root(root):
        return []

    best_by_name: dict[str, ProjectInfo] = {}
    for slug_dir in sorted(root.iterdir()):
        if not slug_dir.is_dir() or not slug_dir.name.startswith("-"):
            continue
        try:
            session_count = sum(
                1 for item in slug_dir.iterdir() if item.is_file() and item.suffix == ".jsonl"
            )
        except OSError:
            continue
        if session_count == 0:
            continue

        project_name = _resolve_project_name(slug_dir)
        candidate = ProjectInfo(
            name=project_name,
            repo_root=slug_dir,
            manifest=None,
            has_git=False,
            total_commits=session_count,
            user_commits=session_count,
            is_mine=True,  # Claude Code sessions are authored by the user
        )
        incumbent = best_by_name.get(project_name)
        if incumbent is not None and session_count <= incumbent.user_commits:
            continue
        best_by_name[project_name] = candidate

    ranked = list(best_by_name.values())
    ranked.sort(key=lambda info: (-info.user_commits, info.name))
    return ranked
+305
View File
@@ -0,0 +1,305 @@
"""
llm_client.py — Minimal provider abstraction for LLM-assisted entity refinement.
Three providers cover the useful space:
- ``ollama`` (default): local models via http://localhost:11434. Works fully
offline. Honors MemPalace's "zero-API required" principle.
- ``openai-compat``: any OpenAI-compatible ``/v1/chat/completions`` endpoint.
Covers OpenRouter, LM Studio, llama.cpp server, vLLM, Groq, Fireworks,
Together, and most self-hosted setups.
- ``anthropic``: the official Messages API. Opt-in for users who want Haiku
quality without setting up a local model.
All providers expose the same ``classify(system, user, json_mode)`` method and
the same ``check_available()`` probe. No external SDK dependencies — stdlib
``urllib`` only.
JSON mode matters here: we always ask for structured output. Providers
differ on how to request it (Ollama: ``format: json``; OpenAI-compat:
``response_format``; Anthropic: prompt-level instruction) and this module
normalizes that away from the caller.
"""
from __future__ import annotations
import json
import os
from dataclasses import dataclass
from typing import Optional
from urllib.error import HTTPError, URLError
from urllib.request import Request, urlopen
class LLMError(RuntimeError):
    """Single error type for every provider failure.

    Covers transport problems, unparseable responses, authentication
    failures, and missing models.
    """
@dataclass
class LLMResponse:
    """Result of one provider call, in the same shape for every provider."""

    # Text content extracted from the provider's reply.
    text: str
    # Model name the request was sent to.
    model: str
    # ``name`` of the provider class that produced this response.
    provider: str
    # Parsed JSON body of the HTTP response, unmodified.
    raw: dict
# ==================== BASE ====================
class LLMProvider:
    """Base class for LLM providers.

    It only stores the connection settings. Subclasses implement
    ``classify`` and ``check_available``.
    """

    name: str = "base"

    def __init__(
        self,
        model: str,
        endpoint: Optional[str] = None,
        api_key: Optional[str] = None,
        timeout: int = 120,
    ):
        """Store the model name, endpoint URL, API key, and request timeout."""
        self.timeout = timeout
        self.api_key = api_key
        self.endpoint = endpoint
        self.model = model

    def classify(self, system: str, user: str, json_mode: bool = True) -> LLMResponse:
        """Send one system + user prompt pair to the model; subclasses override."""
        raise NotImplementedError

    def check_available(self) -> tuple[bool, str]:
        """Return ``(ok, message)`` from a fast reachability probe; subclasses override."""
        raise NotImplementedError
def _http_post_json(url: str, body: dict, headers: dict, timeout: int) -> dict:
    """POST ``body`` as JSON to ``url`` and return the decoded JSON object.

    Args:
        url: Full request URL.
        body: JSON-serializable request payload.
        headers: Extra request headers, merged over the JSON content type.
        timeout: Socket timeout in seconds, passed to ``urlopen``.

    Returns:
        The response body parsed as a JSON object.

    Raises:
        LLMError: on an HTTP error status, an unusable URL, a transport
            failure, a truncated or otherwise broken HTTP exchange, a body
            that is not valid JSON, or a JSON body that is not an object.
    """
    # Local import: this block is the only user of HTTPException.
    from http.client import HTTPException

    try:
        req = Request(
            url,
            data=json.dumps(body).encode("utf-8"),
            headers={"Content-Type": "application/json", **headers},
        )
        with urlopen(req, timeout=timeout) as resp:
            raw = resp.read()
    except HTTPError as e:
        detail = ""
        try:
            detail = e.read().decode("utf-8", errors="replace")[:500]
        except Exception:
            # Best effort only: the error body is optional diagnostic detail.
            pass
        raise LLMError(f"HTTP {e.code} from {url}: {detail or e.reason}") from e
    except (URLError, OSError, HTTPException, ValueError) as e:
        # ValueError covers "unknown url type" (e.g. a missing scheme), and
        # HTTPException covers IncompleteRead and similar protocol errors.
        raise LLMError(f"Cannot reach {url}: {e}") from e

    try:
        payload = json.loads(raw)
    except ValueError as e:
        # JSONDecodeError and UnicodeDecodeError are both ValueError subclasses.
        raise LLMError(f"Malformed response from {url}: {e}") from e
    if not isinstance(payload, dict):
        raise LLMError(
            f"Malformed response from {url}: expected a JSON object, "
            f"got {type(payload).__name__}"
        )
    return payload
# ==================== OLLAMA ====================
class OllamaProvider(LLMProvider):
    """Provider for a local Ollama server (``/api/tags`` and ``/api/chat``)."""

    name = "ollama"
    DEFAULT_ENDPOINT = "http://localhost:11434"

    def __init__(
        self,
        model: str,
        endpoint: Optional[str] = None,
        timeout: int = 180,
        **_: object,
    ):
        """Use ``DEFAULT_ENDPOINT`` when no endpoint is given.

        Extra keyword arguments (such as ``api_key``) are accepted and
        ignored.
        """
        super().__init__(
            model=model,
            endpoint=endpoint or self.DEFAULT_ENDPOINT,
            timeout=timeout,
        )

    def _url(self, path: str) -> str:
        """Join ``path`` onto the endpoint, tolerating a trailing slash in it."""
        return f"{self.endpoint.rstrip('/')}{path}"

    def check_available(self) -> tuple[bool, str]:
        """Probe ``/api/tags`` and confirm the configured model is listed.

        Returns ``(False, message)`` when the server cannot be reached, the
        URL is unusable, the reply is not JSON, or the model is not listed
        (with or without a ``:latest`` suffix).
        """
        try:
            with urlopen(self._url("/api/tags"), timeout=5) as resp:
                data = json.loads(resp.read())
        except (URLError, HTTPError, OSError, ValueError) as e:
            # ValueError covers JSONDecodeError and "unknown url type".
            return False, f"Cannot reach Ollama at {self.endpoint}: {e}"
        models = data.get("models") if isinstance(data, dict) else None
        names = {m.get("name", "") for m in models or [] if isinstance(m, dict)}
        # Ollama tags may or may not include ':latest' — accept either form
        wanted = {self.model, f"{self.model}:latest"}
        if not names & wanted:
            return (
                False,
                f"Model '{self.model}' not loaded in Ollama. Run: ollama pull {self.model}",
            )
        return True, "ok"

    def classify(self, system: str, user: str, json_mode: bool = True) -> LLMResponse:
        """Send one non-streaming chat request and return the reply text.

        ``json_mode`` adds ``format: json`` to the request.

        Raises:
            LLMError: on transport failure or an empty reply.
        """
        body: dict = {
            "model": self.model,
            "messages": [
                {"role": "system", "content": system},
                {"role": "user", "content": user},
            ],
            "stream": False,
            "options": {"temperature": 0.1},
        }
        if json_mode:
            body["format"] = "json"
        data = _http_post_json(self._url("/api/chat"), body, headers={}, timeout=self.timeout)
        text = (data.get("message") or {}).get("content", "")
        if not text:
            raise LLMError(f"Empty response from Ollama (model={self.model})")
        return LLMResponse(text=text, model=self.model, provider=self.name, raw=data)
# ==================== OPENAI-COMPAT ====================
class OpenAICompatProvider(LLMProvider):
    """Provider for any OpenAI-compatible ``/v1/chat/completions`` endpoint.

    The endpoint comes from ``--llm-endpoint http://host:port`` and may be
    given with or without a trailing ``/v1``. The API key comes from
    ``--llm-api-key`` or, failing that, the ``OPENAI_API_KEY`` env var.
    """

    name = "openai-compat"

    def __init__(
        self,
        model: str,
        endpoint: Optional[str] = None,
        api_key: Optional[str] = None,
        timeout: int = 120,
        **_: object,
    ):
        """Resolve the API key (argument first, then environment) and store settings."""
        if not api_key:
            api_key = os.environ.get("OPENAI_API_KEY")
        super().__init__(model=model, endpoint=endpoint, api_key=api_key, timeout=timeout)

    def _resolve_url(self) -> str:
        """Return the full chat-completions URL derived from the endpoint.

        Raises:
            LLMError: when no endpoint is configured.
        """
        if not self.endpoint:
            raise LLMError("openai-compat provider requires --llm-endpoint")
        suffix = "/chat/completions"
        base = self.endpoint.rstrip("/")
        if base.endswith(suffix):
            return base
        versioned = base if base.endswith("/v1") else f"{base}/v1"
        return f"{versioned}{suffix}"

    def check_available(self) -> tuple[bool, str]:
        """Probe ``<base>/v1/models`` with a 5 second timeout."""
        if not self.endpoint:
            return False, "no --llm-endpoint configured"
        root = self.endpoint.rstrip("/").removesuffix("/chat/completions").removesuffix("/v1")
        auth = {"Authorization": f"Bearer {self.api_key}"} if self.api_key else {}
        try:
            with urlopen(Request(f"{root}/v1/models", headers=auth), timeout=5):
                pass
        except (URLError, HTTPError, OSError) as e:
            return False, f"Cannot reach {self.endpoint}: {e}"
        return True, "ok"

    def classify(self, system: str, user: str, json_mode: bool = True) -> LLMResponse:
        """Send one chat-completions request and return the first choice's text.

        ``json_mode`` adds ``response_format: {"type": "json_object"}``.

        Raises:
            LLMError: on transport failure, an unexpected response shape, or
                an empty reply.
        """
        conversation = [
            {"role": "system", "content": system},
            {"role": "user", "content": user},
        ]
        payload: dict = {"model": self.model, "messages": conversation, "temperature": 0.1}
        if json_mode:
            payload["response_format"] = {"type": "json_object"}
        auth_headers = {"Authorization": f"Bearer {self.api_key}"} if self.api_key else {}

        data = _http_post_json(
            self._resolve_url(), payload, headers=auth_headers, timeout=self.timeout
        )
        try:
            text = data["choices"][0]["message"]["content"]
        except (KeyError, IndexError, TypeError) as e:
            raise LLMError(f"Unexpected response shape: {e}") from e
        if text:
            return LLMResponse(text=text, model=self.model, provider=self.name, raw=data)
        raise LLMError(f"Empty response from {self.name} (model={self.model})")
# ==================== ANTHROPIC ====================
class AnthropicProvider(LLMProvider):
    """Provider for the Anthropic Messages API (``/v1/messages``)."""

    name = "anthropic"
    DEFAULT_ENDPOINT = "https://api.anthropic.com"
    API_VERSION = "2023-06-01"

    def __init__(
        self,
        model: str,
        api_key: Optional[str] = None,
        endpoint: Optional[str] = None,
        timeout: int = 120,
        **_: object,
    ):
        """Resolve the API key (argument first, then ``ANTHROPIC_API_KEY``).

        ``DEFAULT_ENDPOINT`` is used when no endpoint is given.
        """
        key = api_key or os.environ.get("ANTHROPIC_API_KEY")
        super().__init__(
            model=model,
            endpoint=endpoint or self.DEFAULT_ENDPOINT,
            api_key=key,
            timeout=timeout,
        )

    def check_available(self) -> tuple[bool, str]:
        """Check only that an API key is configured; no request is sent."""
        if not self.api_key:
            return False, "ANTHROPIC_API_KEY not set (use --llm-api-key or env)"
        # Don't probe — a live request would cost money. First real call will
        # surface auth errors if the key is invalid.
        return True, "ok"

    def classify(self, system: str, user: str, json_mode: bool = True) -> LLMResponse:
        """Send one Messages API request and return the concatenated text blocks.

        This API call sets no JSON switch, so ``json_mode`` appends a
        JSON-only instruction to the system prompt instead.

        Raises:
            LLMError: when no API key is set, on transport failure, on an
                unexpected response shape, or on an empty reply.
        """
        if not self.api_key:
            raise LLMError("Anthropic provider requires ANTHROPIC_API_KEY env or --llm-api-key")
        sys_prompt = system
        if json_mode:
            sys_prompt += "\n\nRespond with valid JSON only, no prose."
        body = {
            "model": self.model,
            "max_tokens": 2048,
            "temperature": 0.1,
            "system": sys_prompt,
            "messages": [{"role": "user", "content": user}],
        }
        headers = {
            "X-API-Key": self.api_key,
            "anthropic-version": self.API_VERSION,
        }
        # Strip a trailing slash so a custom endpoint such as
        # "https://proxy/" does not produce "https://proxy//v1/messages".
        url = f"{self.endpoint.rstrip('/')}/v1/messages"
        data = _http_post_json(url, body, headers=headers, timeout=self.timeout)
        try:
            text = "".join(
                b.get("text", "") for b in data.get("content", []) or [] if b.get("type") == "text"
            )
        except (AttributeError, TypeError) as e:
            raise LLMError(f"Unexpected response shape: {e}") from e
        if not text:
            raise LLMError(f"Empty response from Anthropic (model={self.model})")
        return LLMResponse(text=text, model=self.model, provider=self.name, raw=data)
# ==================== FACTORY ====================
# Registry of provider classes keyed by provider name. ``get_provider`` looks
# names up here and reports the sorted keys when a name is unknown.
PROVIDERS: dict[str, type[LLMProvider]] = {
    "ollama": OllamaProvider,
    "openai-compat": OpenAICompatProvider,
    "anthropic": AnthropicProvider,
}
def get_provider(
    name: str,
    model: str,
    endpoint: Optional[str] = None,
    api_key: Optional[str] = None,
    timeout: int = 120,
) -> LLMProvider:
    """Instantiate the provider registered under ``name``.

    All settings are forwarded as keyword arguments to the provider class.

    Raises:
        LLMError: when ``name`` is not a key of ``PROVIDERS``.
    """
    provider_cls = PROVIDERS.get(name)
    if provider_cls is not None:
        return provider_cls(model=model, endpoint=endpoint, api_key=api_key, timeout=timeout)
    known = sorted(PROVIDERS.keys())
    raise LLMError(f"Unknown provider '{name}'. Choices: {known}")
+446
View File
@@ -0,0 +1,446 @@
"""
llm_refine.py — Optional LLM refinement of regex-detected entities.
Takes the candidate set produced by phase-1 detection (manifests, git
authors, regex on prose) and asks an LLM to reclassify each candidate as
PERSON / PROJECT / TOPIC / COMMON_WORD / AMBIGUOUS.
Design constraints:
- Opt-in. Default init path never imports this module.
- Local-first by default (Ollama).
- Interactive UX: visible progress, clean cancellation (Ctrl-C returns
whatever was classified before the interrupt).
- Don't feed the raw corpus to the LLM — feed candidates + a few sampled
context lines each. Keeps total input to ~50-100K tokens even for huge
prose corpora.
Public:
refine_entities(detected, corpus_text, provider, ...) -> dict
"""
from __future__ import annotations
import json
import re
import sys
from dataclasses import dataclass
from mempalace.llm_client import LLMError, LLMProvider
# Default ``batch_size`` for ``refine_entities``.
BATCH_SIZE = 25 # candidates per LLM call; tuned for 4B local models
# Default number of corpus lines ``_collect_contexts`` samples per candidate.
CONTEXT_LINES_PER_CANDIDATE = 3
CONTEXT_WINDOW_CHARS = 240 # max chars per context line to keep tokens bounded
# Valid labels the LLM is allowed to return. Anything else is treated as
# AMBIGUOUS so the user reviews it.
VALID_LABELS = {"PERSON", "PROJECT", "TOPIC", "COMMON_WORD", "AMBIGUOUS"}
# System prompt sent unchanged with every batch by ``refine_entities``. It
# defines the five labels and the JSON schema that ``_parse_response`` reads.
SYSTEM_PROMPT = """You are helping organize a user's memory palace by classifying capitalized tokens found in their files.
For each candidate, pick exactly ONE label:
- PERSON: a specific real person the user knows (colleague, family, character they write about)
- PROJECT: a named product, codebase, or effort the user works on
- TOPIC: a recurring theme or subject (not a person, not a project) — cities, technologies, concepts
- COMMON_WORD: an English word, verb, or fragment that isn't a named entity at all (e.g. "Created", "Before", "Never")
- AMBIGUOUS: context is insufficient to decide between two of the above
Frameworks, runtimes, APIs, cloud services, vendors, and third-party products
(e.g. Angular, OpenAPI, Terraform, Bun, Google) are TOPIC unless the context
clearly says this is the user's own named codebase, product, or active effort.
Use the provided context lines to disambiguate. A capitalized word that only appears in metadata ("Created: 2026-04-24") is COMMON_WORD. A name that appears with pronouns and dialogue is PERSON.
Respond with JSON only. Schema:
{"classifications": [{"name": "<exact candidate name>", "label": "<LABEL>", "reason": "<one short sentence>"}]}
One entry per candidate, same order as the input."""
@dataclass
class RefineResult:
    """Outcome of one ``refine_entities`` run."""

    # Detected-entities dict after the LLM decisions were merged in.
    merged: dict
    # Number of entries that ended up in a different bucket.
    reclassified: int
    # Number of entries removed from the merged result (COMMON_WORD only).
    dropped: int
    # One message per failed batch (transport or parse failures).
    errors: list[str]
    # Batches that received a response from the provider.
    batches_completed: int
    # Batches that were planned in total.
    batches_total: int
    # True when the run was interrupted by Ctrl-C.
    cancelled: bool
def _collect_contexts(
    corpus_lines: list[str], name: str, max_lines: int = CONTEXT_LINES_PER_CANDIDATE
) -> list[str]:
    """Sample distinct corpus lines that mention ``name``.

    Matching is case-insensitive and requires a non-word character (or a
    line edge) on both sides of the name. Each hit is stripped and cut to
    ``CONTEXT_WINDOW_CHARS`` characters. Hits that become empty or repeat an
    earlier snippet are skipped. Scanning stops once ``max_lines`` snippets
    have been collected.
    """
    pattern = re.compile(rf"(?<!\w){re.escape(name)}(?!\w)", re.IGNORECASE)
    # A dict doubles as an insertion-ordered set of snippets.
    picked: dict[str, None] = {}
    for raw_line in corpus_lines:
        if pattern.search(raw_line) is None:
            continue
        snippet = raw_line.strip()[:CONTEXT_WINDOW_CHARS]
        if snippet and snippet not in picked:
            picked[snippet] = None
            if len(picked) >= max_lines:
                break
    return list(picked)
def _build_user_prompt(candidates_with_contexts: list[tuple[str, str, list[str]]]) -> str:
"""Shape: for each candidate, list its current type guess + sampled contexts."""
parts: list[str] = ["CANDIDATES:"]
for i, (name, current_type, contexts) in enumerate(candidates_with_contexts, 1):
parts.append(f"\n{i}. {name} (currently: {current_type})")
if contexts:
for c in contexts:
parts.append(f" > {c}")
else:
parts.append(" > (no context available)")
return "\n".join(parts)
def _extract_json_candidates(text: str) -> list[str]:
"""Return plausible JSON payloads extracted from an LLM response."""
text = text.strip()
if not text:
return []
candidates: list[str] = [text]
for match in re.finditer(r"```(?:json)?\s*([\s\S]*?)\s*```", text, re.IGNORECASE):
candidate = match.group(1).strip()
if candidate and candidate not in candidates:
candidates.append(candidate)
for start, opener in ((i, ch) for i, ch in enumerate(text) if ch in "{["):
closer = "}" if opener == "{" else "]"
depth = 0
in_string = False
escaped = False
for i in range(start, len(text)):
ch = text[i]
if in_string:
if escaped:
escaped = False
elif ch == "\\":
escaped = True
elif ch == '"':
in_string = False
continue
if ch == '"':
in_string = True
elif ch == opener:
depth += 1
elif ch == closer:
depth -= 1
if depth == 0:
candidate = text[start : i + 1].strip()
if candidate and candidate not in candidates:
candidates.append(candidate)
break
return candidates
def _parse_response(text: str, expected_names: list[str]) -> dict[str, tuple[str, str]]:
    """Parse the LLM's JSON response into ``{name: (label, reason)}``.

    Every JSON fragment found in ``text`` is tried in order, and the first
    one that yields at least one usable classification is used. A usable
    payload is either ``{"classifications": [...]}`` or a bare list of
    entries. Each entry needs a string name (``name`` / ``candidate``) and a
    string label (``label`` / ``type`` / ``classification``).

    Names are matched case-insensitively against ``expected_names`` to
    restore canonical casing; unmatched names are kept as returned. Labels
    outside ``VALID_LABELS`` become ``AMBIGUOUS``. Reasons are converted to
    text, stripped, and cut to 120 characters.

    Returns:
        The decision map, or an empty dict when nothing usable was found.
    """
    canonical_by_lower = {n.lower(): n for n in expected_names}

    for candidate in _extract_json_candidates(text):
        try:
            data = json.loads(candidate)
        except json.JSONDecodeError:
            continue
        entries = data.get("classifications") if isinstance(data, dict) else data
        if not isinstance(entries, list):
            # Valid JSON but not the payload (e.g. a nested object); keep looking.
            continue

        name_to_label: dict[str, tuple[str, str]] = {}
        for entry in entries:
            if not isinstance(entry, dict):
                continue
            name = entry.get("name") or entry.get("candidate")
            label = entry.get("label") or entry.get("type") or entry.get("classification")
            if not isinstance(name, str) or not isinstance(label, str):
                continue
            reason = entry.get("reason") or ""
            if not isinstance(reason, str):
                # Models sometimes return a list or number here; never crash on it.
                reason = str(reason)
            # Restore canonical casing from expected_names
            canonical = canonical_by_lower.get(name.lower(), name)
            lbl = label.strip().upper()
            if lbl not in VALID_LABELS:
                lbl = "AMBIGUOUS"
            name_to_label[canonical] = (lbl, reason.strip()[:120])

        if name_to_label:
            return name_to_label
    return {}
def _apply_classifications(
detected: dict,
decisions: dict[str, tuple[str, str]],
allow_project_promotions: bool = True,
) -> tuple[dict, int, int]:
"""Merge LLM decisions back into the detected dict.
Returns (new_detected, reclassified_count, dropped_count).
"""
label_to_bucket = {
"PERSON": "people",
"PROJECT": "projects",
"TOPIC": "uncertain",
"AMBIGUOUS": "uncertain",
}
# Index every entity by name for in-place update
all_entries: list[tuple[str, dict]] = []
for bucket, items in detected.items():
for e in items:
all_entries.append((bucket, e))
reclassified = 0
dropped = 0
new_detected: dict[str, list[dict]] = {
"people": [],
"projects": [],
"uncertain": [],
}
for old_bucket, entry in all_entries:
decision = decisions.get(entry["name"])
if decision is None:
# No LLM opinion — keep as-is
new_detected[old_bucket].append(entry)
continue
label, reason = decision
if label == "COMMON_WORD":
dropped += 1
continue
target_bucket = label_to_bucket[label]
if (
label == "PROJECT"
and not allow_project_promotions
and not _is_authoritative_project(entry)
):
target_bucket = "uncertain"
updated = dict(entry)
# Append the LLM's reason as a new signal so the user sees why it moved
signals = list(updated.get("signals", []))
signals.append(f"LLM: {label.lower()}{reason}" if reason else f"LLM: {label.lower()}")
updated["signals"] = signals
if target_bucket != old_bucket:
reclassified += 1
updated["type"] = (
"person"
if target_bucket == "people"
else "project"
if target_bucket == "projects"
else "uncertain"
)
new_detected[target_bucket].append(updated)
return new_detected, reclassified, dropped
def _is_authoritative_person(entry: dict) -> bool:
"""Return True for git-author people that should not be second-guessed."""
signals = " ".join(entry.get("signals", [])).lower()
return "commit" in signals and "repo" in signals
def _is_authoritative_project(entry: dict) -> bool:
"""Return True for manifest/git-backed projects that are already source-backed."""
signals = " ".join(entry.get("signals", [])).lower()
manifest_markers = ("package.json", "pyproject.toml", "cargo.toml", "go.mod")
return any(marker in signals for marker in manifest_markers) or "commit" in signals
def _print_progress(batch_idx: int, total: int, current_name: str) -> None:
"""Overwrite-line progress indicator."""
width = 40
filled = int(width * batch_idx / total) if total else 0
bar = "" * filled + "" * (width - filled)
msg = f"\r LLM refine: [{bar}] batch {batch_idx}/{total} current: {current_name[:30]:<30}"
sys.stderr.write(msg)
sys.stderr.flush()
def refine_entities(
    detected: dict,
    corpus_text: str,
    provider: LLMProvider,
    batch_size: int = BATCH_SIZE,
    show_progress: bool = True,
    allow_project_promotions: bool = True,
) -> RefineResult:
    """Reclassify detected entities using the LLM provider.

    Only regex-derived candidates are sent for refinement. Git authors and
    manifest/git-backed projects are already source-backed and don't benefit
    from LLM second-guessing.

    Ctrl-C during refinement: cancels the remaining batches, returns a
    RefineResult with ``cancelled=True`` and whatever was classified before
    the interrupt. The partial result is safe to pass straight to
    ``confirm_entities``.

    Transport or parse failures in individual batches are recorded in
    ``errors`` and do not abort the run.

    Args:
        detected: ``{"people": [...], "projects": [...], "uncertain": [...]}``
            where every entry is a dict with at least a ``name`` key.
        corpus_text: Prose used to sample context lines for each candidate.
        provider: Provider whose ``classify`` method is called once per batch.
        batch_size: Candidates per LLM call; values below 1 are treated as 1.
        show_progress: Draw a progress bar on stderr.
        allow_project_promotions: ``False`` keeps LLM-only project guesses in
            the uncertain bucket. This is useful when manifest/git signal
            already supplied canonical projects and regex/LLM hits are likely
            tools, vendors, or topics.

    Returns:
        A ``RefineResult``. When there is nothing to refine, ``merged`` is
        the ``detected`` argument itself.
    """
    current_type = {"people": "person", "projects": "project", "uncertain": "uncertain"}

    # Deduplicate candidate names while preserving order: the first bucket a
    # name appears in decides the "currently" hint shown to the model.
    first_kind: dict[str, str] = {}
    for bucket in ("people", "projects", "uncertain"):
        for e in detected.get(bucket, []):
            if bucket == "people" and _is_authoritative_person(e):
                continue
            if bucket == "projects" and _is_authoritative_project(e):
                continue
            first_kind.setdefault(e["name"], current_type[bucket])
    unique = list(first_kind.items())

    if not unique:
        return RefineResult(
            merged=detected,
            reclassified=0,
            dropped=0,
            errors=[],
            batches_completed=0,
            batches_total=0,
            cancelled=False,
        )

    corpus_lines = corpus_text.splitlines() if corpus_text else []

    # Build batches. A non-positive batch_size would make range() raise
    # ValueError, so fall back to one candidate per call.
    step = max(1, batch_size)
    batches: list[list[tuple[str, str, list[str]]]] = []
    for i in range(0, len(unique), step):
        chunk = unique[i : i + step]
        batches.append(
            [(name, kind, _collect_contexts(corpus_lines, name)) for name, kind in chunk]
        )

    all_decisions: dict[str, tuple[str, str]] = {}
    errors: list[str] = []
    completed = 0
    cancelled = False

    for idx, batch in enumerate(batches, 1):
        if show_progress and batch:
            _print_progress(idx - 1, len(batches), batch[0][0])
        user_prompt = _build_user_prompt(batch)
        try:
            resp = provider.classify(SYSTEM_PROMPT, user_prompt, json_mode=True)
        except KeyboardInterrupt:
            # Keep what earlier batches produced and stop asking for more.
            cancelled = True
            break
        except LLMError as e:
            errors.append(f"batch {idx}: {e}")
            continue

        names_in_batch = [name for name, _, _ in batch]
        decisions = _parse_response(resp.text, names_in_batch)
        if not decisions:
            errors.append(f"batch {idx}: could not parse response")
        all_decisions.update(decisions)
        completed += 1
        if show_progress:
            _print_progress(idx, len(batches), batch[-1][0])

    if show_progress:
        # Finish the overwrite-in-place progress line.
        sys.stderr.write("\n")
        sys.stderr.flush()

    merged, reclassified, dropped = _apply_classifications(
        detected,
        all_decisions,
        allow_project_promotions=allow_project_promotions,
    )
    return RefineResult(
        merged=merged,
        reclassified=reclassified,
        dropped=dropped,
        errors=errors,
        batches_completed=completed,
        batches_total=len(batches),
        cancelled=cancelled,
    )
def collect_corpus_text(
    project_dir: str,
    max_files: int = 30,
    max_bytes_per_file: int = 20_000,
) -> str:
    """Gather prose text under ``project_dir`` to serve as LLM context.

    The tree is walked with ``_walk_prose``. Files whose lower-cased suffix
    is in ``PROSE_EXTENSIONS`` are ranked newest-first by mtime, and the
    first ``max_files`` of them are read. Each read is capped at
    ``max_bytes_per_file``; because files are opened in text mode, the cap
    counts characters. Files that cannot be stat'ed or opened are skipped.

    Returns:
        The texts joined with newlines, or ``""`` when ``project_dir`` is
        not a directory.
    """
    from pathlib import Path

    from mempalace.entity_detector import PROSE_EXTENSIONS, SKIP_DIRS

    root = Path(project_dir).expanduser().resolve()
    if not root.is_dir():
        return ""

    dated: list[tuple[float, Path]] = []
    for dirpath, _dirs, files in _walk_prose(root, SKIP_DIRS):
        for fname in files:
            prose_path = dirpath / fname
            if prose_path.suffix.lower() not in PROSE_EXTENSIONS:
                continue
            try:
                dated.append((prose_path.stat().st_mtime, prose_path))
            except OSError:
                continue

    newest_first = sorted(dated, reverse=True)

    texts: list[str] = []
    for _mtime, prose_path in newest_first[:max_files]:
        try:
            with open(prose_path, encoding="utf-8", errors="replace") as handle:
                texts.append(handle.read(max_bytes_per_file))
        except OSError:
            continue
    return "\n".join(texts)
def _walk_prose(root, skip_dirs):
    """Yield ``(Path, dirs, files)`` for each directory under ``root``.

    Directories named in ``skip_dirs`` and directories whose name starts
    with ``"."`` are pruned, so ``os.walk`` never descends into them.

    Inlined from ``project_scanner._walk`` to avoid importing a private
    name; kept deliberately narrow (prose collection only).
    """
    import os
    from pathlib import Path

    def _keep(dirname):
        return not (dirname in skip_dirs or dirname.startswith("."))

    for current, subdirs, filenames in os.walk(root):
        # Slice-assign so the very list os.walk handed out is mutated;
        # that in-place edit is what makes the pruning take effect.
        subdirs[:] = list(filter(_keep, subdirs))
        yield Path(current), subdirs, filenames
+92
View File
@@ -52,6 +52,7 @@ READABLE_EXTENSIONS = {
} }
SKIP_FILENAMES = { SKIP_FILENAMES = {
"entities.json",
"mempalace.yaml", "mempalace.yaml",
"mempalace.yml", "mempalace.yml",
"mempal.yaml", "mempal.yaml",
@@ -471,6 +472,97 @@ def _load_known_entities_raw() -> dict:
return dict(_ENTITY_REGISTRY_CACHE["raw"]) return dict(_ENTITY_REGISTRY_CACHE["raw"])
def add_to_known_entities(entities_by_category: dict) -> str:
    """Union ``entities_by_category`` into the known-entities registry file.

    Accepts the ``{category: [names]}`` shape and merges it into the JSON
    file at ``_ENTITY_REGISTRY_PATH``. Categories absent from the input are
    left untouched. For categories present in both, names are unioned
    case-insensitively and appended after the pre-existing entries, so the
    on-disk ordering of existing names does not change.

    If a category is stored on disk as a ``{name: code}`` dict, new names are
    added as keys with ``None`` values; existing codes are never overwritten.
    A category that is missing, or stored in any other shape, is seeded as a
    fresh de-duplicated list.

    Input handling:
      * category values that are not non-empty lists are ignored;
      * each name is stringified and stripped of surrounding whitespace;
        falsy or whitespace-only names are skipped.

    A registry that cannot be read, decoded, or parsed as a JSON object is
    treated as empty. The file is written to a temporary sibling and moved
    into place with ``os.replace``, so an interrupted write cannot leave a
    truncated registry behind. Permissions are then set to ``0o600`` on a
    best-effort basis.

    The in-process cache (``_ENTITY_REGISTRY_CACHE``) is reset after the
    write so same-process callers see the update immediately.

    Returns:
        The registry path as a string, for logging.
    """
    import json as _json
    import os as _os
    import tempfile as _tempfile
    from pathlib import Path as _Path

    registry_path = _Path(_ENTITY_REGISTRY_PATH)
    registry_path.parent.mkdir(parents=True, exist_ok=True)

    existing: dict = {}
    if registry_path.exists():
        try:
            loaded = _json.loads(registry_path.read_text(encoding="utf-8"))
        except (ValueError, OSError):
            # ValueError covers both JSONDecodeError and UnicodeDecodeError.
            loaded = None
        if isinstance(loaded, dict):
            existing = loaded

    def _coerce_name(value):
        """Return the cleaned string form of ``value``, or None if blank."""
        if not value:
            return None
        name = str(value).strip()
        return name or None

    def _fresh_names(names, seen_lower):
        """Yield usable names not already in ``seen_lower``, marking each seen."""
        for raw in names:
            name = _coerce_name(raw)
            if name is None:
                continue
            key = name.lower()
            if key in seen_lower:
                continue
            seen_lower.add(key)
            yield name

    for category, names in entities_by_category.items():
        if not isinstance(names, list) or not names:
            continue
        current = existing.get(category)
        if isinstance(current, list):
            seen_lower = {str(n).lower() for n in current}
            current.extend(_fresh_names(names, seen_lower))
        elif isinstance(current, dict):
            seen_lower = {str(n).lower() for n in current}
            for name in _fresh_names(names, seen_lower):
                current[name] = None
        else:
            # Missing or unrecognized shape — seed as a fresh list, deduped.
            existing[category] = list(_fresh_names(names, set()))

    payload = _json.dumps(existing, indent=2, ensure_ascii=False)

    # Resolve symlinks first so a symlinked registry keeps its link and the
    # real file is the one replaced.
    target = _Path(_os.path.realpath(registry_path))
    fd, tmp_name = _tempfile.mkstemp(
        dir=str(target.parent), prefix=f".{target.name}.", suffix=".tmp"
    )
    try:
        with _os.fdopen(fd, "w", encoding="utf-8") as handle:
            handle.write(payload)
        _os.replace(tmp_name, target)
    except BaseException:
        # Don't leave the temp file behind on any failure, then re-raise.
        try:
            _os.unlink(tmp_name)
        except OSError:
            pass
        raise

    try:
        target.chmod(0o600)
    except (OSError, NotImplementedError):
        pass

    # Invalidate in-process cache so later calls in the same run see the write.
    _ENTITY_REGISTRY_CACHE["mtime"] = None
    _ENTITY_REGISTRY_CACHE["names"] = frozenset()
    _ENTITY_REGISTRY_CACHE["raw"] = {}

    return str(registry_path)
_HALL_KEYWORDS_CACHE = None _HALL_KEYWORDS_CACHE = None
+70 -5
View File
@@ -594,6 +594,8 @@ def discover_entities(
prose_file_cap: int = 10, prose_file_cap: int = 10,
project_cap: int = 15, project_cap: int = 15,
people_cap: int = 15, people_cap: int = 15,
llm_provider: object = None,
show_progress: bool = True,
) -> dict: ) -> dict:
"""Top-level entity discovery: real signals first, prose detection second. """Top-level entity discovery: real signals first, prose detection second.
@@ -604,10 +606,39 @@ def discover_entities(
1. Package manifests (package.json, pyproject.toml, Cargo.toml, go.mod) 1. Package manifests (package.json, pyproject.toml, Cargo.toml, go.mod)
→ canonical project names → canonical project names
2. Git commit authors → real people with real commit counts 2. Git commit authors → real people with real commit counts
3. Regex entity detection on prose files → supplementary names only 3. Claude Code conversation dirs (~/.claude/projects/) → per-session
project names (pulled from each session's ``cwd`` metadata)
4. Regex entity detection on prose files → supplementary names only
mentioned in docs/notes (not code) mentioned in docs/notes (not code)
5. Optional LLM refinement pass — reclassifies ambiguous candidates
using the caller-supplied provider
Passing ``llm_provider`` enables phase-2 refinement. The caller is
responsible for constructing the provider (``llm_client.get_provider``)
and confirming availability. Refinement is blocking-interactive:
progress prints to stderr; Ctrl-C returns partial results.
""" """
projects, people = scan(project_dir) projects, people = scan(project_dir)
# If the target is a Claude Code conversations root, extract per-project
# entries from there too. Same ProjectInfo shape, so dedup logic works.
from mempalace.convo_scanner import is_claude_projects_root, scan_claude_projects
root_path = Path(project_dir).expanduser().resolve()
if is_claude_projects_root(root_path):
convo_projects = scan_claude_projects(root_path)
# Dedup by name against the git-manifest list, preferring entries with
# more user_commits as signal strength.
by_name: dict[str, ProjectInfo] = {p.name: p for p in projects}
for cp in convo_projects:
existing = by_name.get(cp.name)
if existing is None or cp.user_commits > existing.user_commits:
by_name[cp.name] = cp
projects = sorted(
by_name.values(),
key=lambda p: (not p.is_mine, -p.user_commits, -p.total_commits, p.name),
)
real_signal = to_detected_dict(projects, people, project_cap=project_cap, people_cap=people_cap) real_signal = to_detected_dict(projects, people, project_cap=project_cap, people_cap=people_cap)
# Secondary pass: prose-only extraction catches names mentioned in docs # Secondary pass: prose-only extraction catches names mentioned in docs
@@ -621,11 +652,45 @@ def discover_entities(
else {"people": [], "projects": [], "uncertain": []} else {"people": [], "projects": [], "uncertain": []}
) )
# If git/manifests gave us real projects, suppress the regex "uncertain" bucket. # Without LLM refinement, suppress regex "uncertain" noise when real
# That bucket is mostly noise (common words, CamelCase tech terms, etc.) and # manifest/git signal exists. With LLM refinement enabled, keep those
# adding it to the review flow just makes the user do triage we can skip. # candidates so the model can promote real entities or drop common words.
has_real_signal = bool(projects) or bool(people) has_real_signal = bool(projects) or bool(people)
return _merge_detected(real_signal, prose_detected, drop_secondary_uncertain=has_real_signal) merged = _merge_detected(
real_signal,
prose_detected,
drop_secondary_uncertain=has_real_signal and llm_provider is None,
)
# Optional phase 2: LLM refinement.
if llm_provider is not None:
from mempalace.llm_refine import collect_corpus_text, refine_entities
corpus = collect_corpus_text(str(project_dir))
result = refine_entities(
merged,
corpus,
llm_provider,
show_progress=show_progress,
allow_project_promotions=not has_real_signal,
)
if show_progress:
status_bits = []
if result.cancelled:
status_bits.append("cancelled")
if result.reclassified:
status_bits.append(f"reclassified {result.reclassified}")
if result.dropped:
status_bits.append(f"dropped {result.dropped}")
if result.errors:
status_bits.append(f"{len(result.errors)} batch error(s)")
if status_bits:
import sys as _sys
print(f" LLM refine: {', '.join(status_bits)}", file=_sys.stderr)
merged = result.merged
return merged
# ==================== CLI ==================== # ==================== CLI ====================
+218
View File
@@ -0,0 +1,218 @@
"""Tests for mempalace.convo_scanner."""
import json
from pathlib import Path
from mempalace.convo_scanner import (
_decode_slug_fallback,
_extract_cwd_from_session,
_resolve_project_name,
_safe_mtime,
is_claude_projects_root,
scan_claude_projects,
)
# ── is_claude_projects_root ─────────────────────────────────────────────
def test_is_claude_projects_root_true(tmp_path):
    """A dash-prefixed child directory holding a ``.jsonl`` file qualifies."""
    encoded = tmp_path / "-home-user-dev-foo"
    encoded.mkdir()
    session = encoded / "abc.jsonl"
    session.write_text("{}\n")
    assert is_claude_projects_root(tmp_path)
def test_is_claude_projects_root_false_no_dash_prefix(tmp_path):
    """A child directory without a leading dash does not qualify."""
    plain = tmp_path / "normal-folder"
    plain.mkdir()
    session = plain / "abc.jsonl"
    session.write_text("{}\n")
    detected = is_claude_projects_root(tmp_path)
    assert not detected
def test_is_claude_projects_root_false_no_jsonl(tmp_path):
    """A dash-prefixed directory with no ``.jsonl`` file does not qualify."""
    encoded = tmp_path / "-home-user-foo"
    encoded.mkdir()
    stray = encoded / "other.txt"
    stray.write_text("hello")
    detected = is_claude_projects_root(tmp_path)
    assert not detected
def test_is_claude_projects_root_false_empty(tmp_path):
    """An empty directory is not reported as a projects root."""
    detected = is_claude_projects_root(tmp_path)
    assert not detected
def test_is_claude_projects_root_false_nonexistent(tmp_path):
    """A path that does not exist is not reported as a projects root."""
    missing = tmp_path / "does-not-exist"
    detected = is_claude_projects_root(missing)
    assert not detected
# ── cwd extraction ──────────────────────────────────────────────────────
def test_extract_cwd_from_session(tmp_path):
    """The ``cwd`` is found even when the first record carries none."""
    records = (
        {"type": "file-history-snapshot", "messageId": "x"},
        {"type": "user", "cwd": "/home/user/dev/myproj", "content": "hi"},
    )
    session = tmp_path / "session.jsonl"
    session.write_text("".join(json.dumps(record) + "\n" for record in records))
    assert _extract_cwd_from_session(session) == "/home/user/dev/myproj"
def test_extract_cwd_from_session_skips_malformed(tmp_path):
    """An unparsable first line does not prevent reading a later ``cwd``."""
    good_record = json.dumps({"type": "user", "cwd": "/home/user/dev/good"})
    session = tmp_path / "session.jsonl"
    session.write_text("\n".join(["{not valid json", good_record, ""]))
    extracted = _extract_cwd_from_session(session)
    assert extracted == "/home/user/dev/good"
def test_extract_cwd_from_session_none_if_absent(tmp_path):
    """``None`` comes back when no record has a ``cwd`` key."""
    session = tmp_path / "session.jsonl"
    only_record = json.dumps({"type": "x", "messageId": "y"})
    session.write_text(only_record + "\n")
    extracted = _extract_cwd_from_session(session)
    assert extracted is None
def test_extract_cwd_from_session_none_if_file_missing(tmp_path):
    """``None`` comes back for a session file that does not exist."""
    missing = tmp_path / "missing.jsonl"
    extracted = _extract_cwd_from_session(missing)
    assert extracted is None
# ── slug fallback ───────────────────────────────────────────────────────
def test_decode_slug_fallback_last_segment():
    """``-home-user-dev-foo`` decodes to ``foo``."""
    decoded = _decode_slug_fallback("-home-user-dev-foo")
    assert decoded == "foo"
def test_decode_slug_fallback_double_dash():
    """A doubled dash before the last segment still decodes to ``bentokit``."""
    decoded = _decode_slug_fallback("-home-user--bentokit")
    assert decoded == "bentokit"
def test_decode_slug_fallback_empty():
    """An empty slug decodes to an empty string."""
    decoded = _decode_slug_fallback("")
    assert decoded == ""
def test_decode_slug_fallback_only_dashes():
    """A slug made solely of dashes is returned unchanged."""
    slug = "---"
    decoded = _decode_slug_fallback("---")
    assert decoded == slug
# ── safe metadata helpers ───────────────────────────────────────────────
def test_safe_mtime_returns_zero_on_stat_error(tmp_path, monkeypatch):
    """``_safe_mtime`` reports ``0.0`` when ``Path.stat`` raises ``OSError``."""
    f = tmp_path / "session.jsonl"
    f.write_text("{}\n")

    original_stat = Path.stat

    # Accept and forward any extra arguments: ``Path.stat`` takes a
    # ``follow_symlinks`` keyword on Python 3.10+, and a replacement that
    # only accepts ``self`` would raise TypeError for callers that pass it
    # instead of exercising the OSError path under test.
    def fail_stat(self, *args, **kwargs):
        if self == f:
            raise OSError("permission denied")
        return original_stat(self, *args, **kwargs)

    monkeypatch.setattr(Path, "stat", fail_stat)

    assert _safe_mtime(f) == 0.0
# ── _resolve_project_name ───────────────────────────────────────────────
def test_resolve_project_name_uses_cwd(tmp_path):
    """The name comes from the session's ``cwd``, not the encoded dir name."""
    encoded = tmp_path / "-home-user-dev-coolproj"
    encoded.mkdir()
    record = {"type": "user", "cwd": "/home/user/dev/cool-proj-real"}
    (encoded / "a.jsonl").write_text(json.dumps(record) + "\n")
    resolved = _resolve_project_name(encoded)
    assert resolved == "cool-proj-real"
def test_resolve_project_name_falls_back_when_no_cwd(tmp_path):
    """Without any ``cwd`` in the session, the slug-derived name is used."""
    encoded = tmp_path / "-home-user-dev-foo"
    encoded.mkdir()
    session = encoded / "a.jsonl"
    session.write_text(json.dumps({"type": "x"}) + "\n")
    resolved = _resolve_project_name(encoded)
    assert resolved == "foo"
def test_resolve_project_name_prefers_newer_session(tmp_path):
    """The most recently modified session's ``cwd`` decides the name.

    Covers a project directory that was renamed between sessions.
    """
    import os

    encoded = tmp_path / "-home-user-dev-old"
    encoded.mkdir()

    older = encoded / "old.jsonl"
    older.write_text(json.dumps({"type": "user", "cwd": "/home/user/dev/old"}) + "\n")
    # Push the older session 100s into the past so the mtimes are distinct.
    backdated = older.stat().st_mtime - 100
    os.utime(older, (backdated, backdated))

    newer = encoded / "new.jsonl"
    newer.write_text(json.dumps({"type": "user", "cwd": "/home/user/dev/new-name"}) + "\n")

    resolved = _resolve_project_name(encoded)
    assert resolved == "new-name"
# ── scan_claude_projects ────────────────────────────────────────────────
def test_scan_claude_projects_empty_dir(tmp_path):
    """Scanning an empty directory yields an empty list."""
    found = scan_claude_projects(tmp_path)
    assert found == []
def test_scan_claude_projects_not_a_projects_root(tmp_path):
    """A directory that doesn't look like ``.claude/projects/`` yields nothing."""
    folder = tmp_path / "some-folder"
    folder.mkdir()
    (folder / "readme.md").write_text("hi")
    found = scan_claude_projects(tmp_path)
    assert found == []
def test_scan_claude_projects_finds_projects(tmp_path):
    """Both encoded dirs are reported; ``user_commits`` equals the session count."""
    layout = (
        ("-home-user-dev-alpha", "/home/user/dev/alpha", ("a.jsonl", "b.jsonl")),
        ("-home-user-dev-beta", "/home/user/dev/beta", ("x.jsonl",)),
    )
    for slug, cwd, sessions in layout:
        encoded = tmp_path / slug
        encoded.mkdir()
        for session in sessions:
            (encoded / session).write_text(json.dumps({"type": "user", "cwd": cwd}) + "\n")

    found = scan_claude_projects(tmp_path)
    found_names = [project.name for project in found]
    assert "alpha" in found_names
    assert "beta" in found_names

    # alpha has 2 sessions, beta has 1
    alpha = next(project for project in found if project.name == "alpha")
    beta = next(project for project in found if project.name == "beta")
    assert alpha.user_commits == 2
    assert beta.user_commits == 1
def test_scan_claude_projects_ignores_dirs_without_jsonl(tmp_path):
    """An encoded dir with no ``.jsonl`` session contributes no project."""
    encoded = tmp_path / "-home-user-dev-empty"
    encoded.mkdir()
    notes = encoded / "notes.md"
    notes.write_text("hi")
    found = scan_claude_projects(tmp_path)
    assert found == []
def test_scan_claude_projects_marks_as_mine(tmp_path):
    """A scanned conversation project is flagged ``is_mine``."""
    encoded = tmp_path / "-home-user-dev-owned"
    encoded.mkdir()
    record = {"type": "user", "cwd": "/home/user/dev/owned"}
    (encoded / "s.jsonl").write_text(json.dumps(record) + "\n")

    found = scan_claude_projects(tmp_path)
    assert len(found) == 1
    only = found[0]
    assert only.is_mine is True
def test_scan_claude_projects_dedup_by_name(tmp_path):
    """Two encoded dirs that resolve to the same name collapse into one entry."""
    layout = (
        ("-home-user-a-proj", "/home/user/a/proj", ("s.jsonl", "t.jsonl")),
        ("-home-user-b-proj", "/home/user/b/proj", ("u.jsonl",)),
    )
    for slug, cwd, sessions in layout:
        encoded = tmp_path / slug
        encoded.mkdir()
        for session in sessions:
            (encoded / session).write_text(json.dumps({"type": "user", "cwd": cwd}) + "\n")

    found = scan_claude_projects(tmp_path)
    # Both resolve to "proj"; the survivor is the one with more sessions.
    assert len(found) == 1
    survivor = found[0]
    assert survivor.name == "proj"
    assert survivor.user_commits == 2
+208
View File
@@ -0,0 +1,208 @@
"""Tests for mempalace.miner.add_to_known_entities.
Covers the init → miner wire-up: init's confirmed entities merged into
``~/.mempalace/known_entities.json`` so the miner's drawer-tagging path
recognizes them at mine time.
Every test redirects the registry path to a tmp_path to avoid touching
the real ~/.mempalace/ on the developer's machine.
"""
import json
import pytest
from mempalace import miner
@pytest.fixture
def temp_registry(tmp_path, monkeypatch):
    """Point the miner at a registry file under ``tmp_path`` and clear its cache."""
    registry_file = tmp_path / "known_entities.json"
    monkeypatch.setattr(miner, "_ENTITY_REGISTRY_PATH", str(registry_file))
    cache = miner._ENTITY_REGISTRY_CACHE
    cache["mtime"] = None
    cache["names"] = frozenset()
    cache["raw"] = {}
    return registry_file
# ── fresh-file cases ────────────────────────────────────────────────────
def test_creates_registry_when_absent(temp_registry):
    """The registry file is created on first write with the given entries."""
    assert not temp_registry.exists()
    incoming = {"people": ["Alice", "Bob"], "projects": ["foo"]}
    miner.add_to_known_entities(incoming)
    assert temp_registry.exists()
    stored = json.loads(temp_registry.read_text())
    assert sorted(stored["people"]) == ["Alice", "Bob"]
    assert stored["projects"] == ["foo"]
def test_returns_registry_path(temp_registry):
    """The return value is the registry path as a string."""
    returned = miner.add_to_known_entities({"people": ["Alice"]})
    assert returned == str(temp_registry)
def test_empty_input_still_creates_file(temp_registry):
    """If an empty merge writes the registry at all, it holds no entries.

    Whether the file is written for a truly empty call is tolerated either
    way; only its contents are checked when it exists.
    """
    miner.add_to_known_entities({})
    if not temp_registry.exists():
        return
    stored = json.loads(temp_registry.read_text())
    assert stored == {} or all(not entries for entries in stored.values())
def test_skips_empty_name_strings(temp_registry):
    """Empty-string and ``None`` names are not written."""
    incoming = {"people": ["Alice", "", None]}
    miner.add_to_known_entities(incoming)
    stored = json.loads(temp_registry.read_text())
    assert stored["people"] == ["Alice"]
# ── union / dedup cases ────────────────────────────────────────────────
def test_unions_with_existing_list_category(temp_registry):
    """New names are appended; duplicates are not; existing order is kept."""
    seeded = {"people": ["Alice", "Bob"]}
    temp_registry.write_text(json.dumps(seeded))
    miner.add_to_known_entities({"people": ["Bob", "Carol"]})
    stored = json.loads(temp_registry.read_text())
    # Bob appears once, Carol lands at the end.
    assert stored["people"] == ["Alice", "Bob", "Carol"]
def test_case_insensitive_dedup_preserves_first_seen_variant(temp_registry):
    """Case variants of a stored name add nothing; the stored spelling stays."""
    seeded = {"people": ["Alice"]}
    temp_registry.write_text(json.dumps(seeded))
    miner.add_to_known_entities({"people": ["alice", "ALICE", "Bob"]})
    stored = json.loads(temp_registry.read_text())
    assert stored["people"] == ["Alice", "Bob"]
def test_preserves_untouched_categories(temp_registry):
    """Categories the caller did not mention are left exactly as stored."""
    seeded = {"people": ["Alice"], "places": ["Paris", "Tokyo"]}
    temp_registry.write_text(json.dumps(seeded))
    miner.add_to_known_entities({"people": ["Bob"]})
    stored = json.loads(temp_registry.read_text())
    assert stored["places"] == ["Paris", "Tokyo"]
    assert stored["people"] == ["Alice", "Bob"]
def test_adds_new_categories(temp_registry):
    """A category missing on disk is added alongside the existing ones."""
    seeded = {"people": ["Alice"]}
    temp_registry.write_text(json.dumps(seeded))
    miner.add_to_known_entities({"projects": ["foo", "bar"]})
    stored = json.loads(temp_registry.read_text())
    assert stored["people"] == ["Alice"]
    assert stored["projects"] == ["foo", "bar"]
def test_dedupes_within_input(temp_registry):
    """Case-insensitive duplicates inside a single input list collapse."""
    incoming = {"people": ["Alice", "alice", "Alice"]}
    miner.add_to_known_entities(incoming)
    stored = json.loads(temp_registry.read_text())
    assert stored["people"] == ["Alice"]
# ── dict-format existing registry ──────────────────────────────────────
def test_dict_format_existing_category_gets_new_keys(temp_registry):
    """A ``{name: code}`` category gains new keys without losing codes."""
    seeded = {"people": {"Alice": "ALC", "Bob": "BOB"}}
    temp_registry.write_text(json.dumps(seeded))
    miner.add_to_known_entities({"people": ["Alice", "Carol"]})
    people = json.loads(temp_registry.read_text())["people"]
    # Existing codes survive; the new name is added with a None code.
    assert people["Alice"] == "ALC"
    assert people["Bob"] == "BOB"
    assert "Carol" in people
    assert people["Carol"] is None
def test_dict_format_dedupes_case_insensitively_and_stringifies_new_names(temp_registry):
    """Dict categories dedupe by case and store non-string names as strings."""
    seeded = {"people": {"Alice": "ALC"}}
    temp_registry.write_text(json.dumps(seeded))
    miner.add_to_known_entities({"people": ["alice", 123]})
    stored = json.loads(temp_registry.read_text())
    assert stored["people"] == {"Alice": "ALC", "123": None}
# ── error tolerance ───────────────────────────────────────────────────
def test_malformed_existing_registry_starts_fresh(temp_registry):
    """Unparsable JSON on disk is replaced by the incoming entries."""
    temp_registry.write_text("{ not valid json")
    incoming = {"people": ["Alice"]}
    miner.add_to_known_entities(incoming)
    stored = json.loads(temp_registry.read_text())
    assert stored == {"people": ["Alice"]}
def test_non_dict_existing_registry_starts_fresh(temp_registry):
    """A JSON array on disk is replaced by the incoming entries."""
    temp_registry.write_text(json.dumps(["unexpected", "array"]))
    incoming = {"people": ["Alice"]}
    miner.add_to_known_entities(incoming)
    stored = json.loads(temp_registry.read_text())
    assert stored == {"people": ["Alice"]}
def test_non_list_input_category_ignored(temp_registry):
    """A non-list category value does not disturb the valid categories."""
    incoming = {"people": ["Alice"], "weird": "not a list"}
    miner.add_to_known_entities(incoming)
    stored = json.loads(temp_registry.read_text())
    assert "weird" not in stored or stored.get("weird") == "not a list"
    assert stored["people"] == ["Alice"]
# ── cache invalidation ───────────────────────────────────────────────
def test_cache_invalidated_so_subsequent_load_sees_write(temp_registry):
    """A load after a write in the same process returns the written names.

    ``cmd_init`` → ``cmd_mine`` run in one process, so the loader must not
    keep serving what it cached before the write.
    """
    # Prime the cache while the registry is still empty.
    miner._load_known_entities()
    assert miner._load_known_entities() == frozenset()

    miner.add_to_known_entities({"people": ["Alice", "Bob"], "projects": ["foo"]})

    names = miner._load_known_entities()
    assert "Alice" in names
    assert "Bob" in names
    assert "foo" in names
def test_raw_view_reflects_write(temp_registry):
    """The raw registry view exposes a category right after it is written."""
    miner.add_to_known_entities({"people": ["Alice"]})
    raw_view = miner._load_known_entities_raw()
    assert raw_view.get("people") == ["Alice"]
# ── Unicode round-trip ────────────────────────────────────────────────
def test_unicode_names_written_literally_not_escaped(temp_registry):
    """Non-ASCII names appear verbatim on disk (``ensure_ascii=False``)."""
    incoming = {"people": ["Gergő Móricz", "Arturo Domínguez"]}
    miner.add_to_known_entities(incoming)
    on_disk = temp_registry.read_text(encoding="utf-8")
    assert "Gergő" in on_disk
    assert "Móricz" in on_disk
    # The literal text still parses back to the same name.
    stored = json.loads(on_disk)
    assert "Gergő Móricz" in stored["people"]
# ── end-to-end: does the write actually help _extract_entities_for_metadata? ──
def test_populated_registry_improves_miner_recall(temp_registry):
    """Names written to the registry show up in the miner's entity metadata."""
    registered = {
        "people": ["Julia Grib", "Kevin Heifner"],
        "projects": ["hyperion-history", "mempalace"],
    }
    miner.add_to_known_entities(registered)

    sample = (
        "Met with Julia Grib yesterday about the mempalace release. "
        "Kevin Heifner pushed the hyperion-history fix."
    )
    metadata = miner._extract_entities_for_metadata(sample)
    tagged = set(metadata.split(";")) if metadata else set()

    # Every registered entity must be present in the metadata string.
    for expected in ("Julia Grib", "Kevin Heifner", "hyperion-history", "mempalace"):
        assert expected in tagged, f"expected '{expected}' in metadata {tagged!r}"
+327
View File
@@ -0,0 +1,327 @@
"""Tests for mempalace.llm_client.
HTTP is mocked throughout — these tests do not require a running Ollama
or network access. Live-provider smoke tests live outside the unit-test
suite.
"""
import json
from unittest.mock import patch, MagicMock
import pytest
from mempalace.llm_client import (
AnthropicProvider,
LLMError,
OllamaProvider,
OpenAICompatProvider,
_http_post_json,
get_provider,
)
# ── factory ─────────────────────────────────────────────────────────────
def test_get_provider_ollama():
    """``"ollama"`` builds an ``OllamaProvider`` on the default endpoint."""
    provider = get_provider("ollama", "gemma4:e4b")
    assert isinstance(provider, OllamaProvider)
    assert provider.model == "gemma4:e4b"
    assert provider.endpoint == OllamaProvider.DEFAULT_ENDPOINT
def test_get_provider_openai_compat():
    """``"openai-compat"`` builds an ``OpenAICompatProvider``."""
    provider = get_provider("openai-compat", "foo", endpoint="http://localhost:1234")
    assert isinstance(provider, OpenAICompatProvider)
def test_get_provider_anthropic():
    """``"anthropic"`` builds an ``AnthropicProvider`` carrying the API key."""
    provider = get_provider("anthropic", "claude-haiku", api_key="sk-xxx")
    assert isinstance(provider, AnthropicProvider)
    assert provider.api_key == "sk-xxx"
def test_get_provider_unknown_raises():
    """An unrecognized provider name raises ``LLMError``."""
    expect_unknown = pytest.raises(LLMError, match="Unknown provider")
    with expect_unknown:
        get_provider("nonsense", "x")
# ── _http_post_json ─────────────────────────────────────────────────────
def test_http_post_json_success():
    """A JSON response body is returned as the decoded object."""
    response = MagicMock()
    response.__enter__.return_value = response
    response.__exit__.return_value = False
    response.read.return_value = b'{"ok": true}'
    with patch("mempalace.llm_client.urlopen", return_value=response):
        parsed = _http_post_json("http://x/y", {"a": 1}, {}, timeout=5)
    assert parsed == {"ok": True}
def test_http_post_json_http_error_wraps_as_llm_error():
    """An ``HTTPError`` from ``urlopen`` surfaces as ``LLMError`` with the status."""
    import io
    from urllib.error import HTTPError

    body = io.BytesIO(b"model missing")
    http_error = HTTPError("http://x", 404, "Not Found", {}, body)
    with patch("mempalace.llm_client.urlopen", side_effect=http_error), pytest.raises(
        LLMError, match="HTTP 404"
    ):
        _http_post_json("http://x", {}, {}, timeout=5)
def test_http_post_json_url_error_wraps_as_llm_error():
    """A ``URLError`` from ``urlopen`` surfaces as a "Cannot reach" ``LLMError``."""
    from urllib.error import URLError

    unreachable = URLError("conn refused")
    with patch("mempalace.llm_client.urlopen", side_effect=unreachable), pytest.raises(
        LLMError, match="Cannot reach"
    ):
        _http_post_json("http://x", {}, {}, timeout=5)
def test_http_post_json_malformed_response():
    """A non-JSON response body raises a "Malformed" ``LLMError``."""
    response = MagicMock()
    response.__enter__.return_value = response
    response.__exit__.return_value = False
    response.read.return_value = b"not json"
    with patch("mempalace.llm_client.urlopen", return_value=response), pytest.raises(
        LLMError, match="Malformed"
    ):
        _http_post_json("http://x", {}, {}, timeout=5)
# ── OllamaProvider ──────────────────────────────────────────────────────
def _mock_ollama_chat_response(content: str):
    """Build a context-manager mock whose ``read()`` returns an Ollama chat payload."""
    payload = {"message": {"content": content}}
    response = MagicMock()
    response.__enter__.return_value = response
    response.__exit__.return_value = False
    response.read.return_value = json.dumps(payload).encode()
    return response
def test_ollama_check_available_finds_model():
    """A model listed in the tags response is reported as available."""
    listing = {"models": [{"name": "gemma4:e4b"}, {"name": "other:latest"}]}
    response = MagicMock()
    response.__enter__.return_value = response
    response.__exit__.return_value = False
    response.read.return_value = json.dumps(listing).encode()
    with patch("mempalace.llm_client.urlopen", return_value=response):
        provider = OllamaProvider(model="gemma4:e4b")
        available, message = provider.check_available()
    assert available
    assert message == "ok"
def test_ollama_check_available_accepts_latest_suffix():
    """``mymodel`` matches a listed ``mymodel:latest``."""
    listing = {"models": [{"name": "mymodel:latest"}]}
    response = MagicMock()
    response.__enter__.return_value = response
    response.__exit__.return_value = False
    response.read.return_value = json.dumps(listing).encode()
    with patch("mempalace.llm_client.urlopen", return_value=response):
        provider = OllamaProvider(model="mymodel")
        available, _message = provider.check_available()
    assert available
def test_ollama_check_available_missing_model():
    """An unlisted model is unavailable and the message suggests pulling it."""
    listing = {"models": [{"name": "other:latest"}]}
    response = MagicMock()
    response.__enter__.return_value = response
    response.__exit__.return_value = False
    response.read.return_value = json.dumps(listing).encode()
    with patch("mempalace.llm_client.urlopen", return_value=response):
        provider = OllamaProvider(model="absent")
        available, message = provider.check_available()
    assert not available
    assert "ollama pull absent" in message
def test_ollama_check_available_unreachable():
    """A connection failure is reported as "Cannot reach Ollama"."""
    from urllib.error import URLError

    refused = URLError("refused")
    with patch("mempalace.llm_client.urlopen", side_effect=refused):
        provider = OllamaProvider(model="gemma4:e4b")
        available, message = provider.check_available()
    assert not available
    assert "Cannot reach Ollama" in message
def test_ollama_classify_sends_json_format():
    """``classify`` posts to ``/api/chat`` with ``format: json`` and the model."""
    seen = {}

    def fake_urlopen(req, *, timeout):
        seen.update(url=req.full_url, body=json.loads(req.data.decode()))
        return _mock_ollama_chat_response('{"classifications": []}')

    with patch("mempalace.llm_client.urlopen", side_effect=fake_urlopen):
        provider = OllamaProvider(model="gemma4:e4b")
        reply = provider.classify("sys", "user", json_mode=True)

    sent_body = seen["body"]
    assert sent_body["format"] == "json"
    assert sent_body["model"] == "gemma4:e4b"
    assert seen["url"].endswith("/api/chat")
    assert reply.provider == "ollama"
    assert reply.text == '{"classifications": []}'
def test_ollama_classify_empty_content_raises():
    """Empty message content from Ollama raises an "Empty response" ``LLMError``."""
    empty_reply = _mock_ollama_chat_response("")
    with patch("mempalace.llm_client.urlopen", return_value=empty_reply):
        provider = OllamaProvider(model="x")
        with pytest.raises(LLMError, match="Empty response"):
            provider.classify("s", "u")
# ── OpenAICompatProvider ────────────────────────────────────────────────
def _mock_openai_response(content: str):
    """Build a context-manager mock whose ``read()`` returns a chat-completions payload."""
    encoded = json.dumps({"choices": [{"message": {"content": content}}]}).encode()
    response = MagicMock()
    response.__enter__.return_value = response
    response.__exit__.return_value = False
    response.read.return_value = encoded
    return response
def test_openai_compat_resolves_url_with_v1_suffix():
    """A bare host endpoint gets ``/v1/chat/completions`` appended."""
    seen = {}

    def fake_urlopen(req, *, timeout):
        seen["url"] = req.full_url
        return _mock_openai_response('{"ok": true}')

    with patch("mempalace.llm_client.urlopen", side_effect=fake_urlopen):
        provider = OpenAICompatProvider(model="x", endpoint="http://h:1234")
        provider.classify("s", "u")

    assert seen["url"] == "http://h:1234/v1/chat/completions"
def test_openai_compat_resolves_url_with_existing_v1():
    """An endpoint already ending in ``/v1`` is not given a second ``/v1``."""
    seen = {}

    def fake_urlopen(req, *, timeout):
        seen["url"] = req.full_url
        return _mock_openai_response('{"ok": true}')

    with patch("mempalace.llm_client.urlopen", side_effect=fake_urlopen):
        provider = OpenAICompatProvider(model="x", endpoint="http://h:1234/v1")
        provider.classify("s", "u")

    assert seen["url"] == "http://h:1234/v1/chat/completions"
def test_openai_compat_requires_endpoint():
    """Classifying without an endpoint raises ``LLMError``."""
    provider = OpenAICompatProvider(model="x")
    expect_missing = pytest.raises(LLMError, match="requires --llm-endpoint")
    with expect_missing:
        provider.classify("s", "u")
def test_openai_compat_sends_authorization_when_key_present():
    """A configured API key is sent as a ``Bearer`` Authorization header."""
    seen = {}

    def fake_urlopen(req, *, timeout):
        seen["auth"] = req.get_header("Authorization")
        return _mock_openai_response('{"ok": true}')

    with patch("mempalace.llm_client.urlopen", side_effect=fake_urlopen):
        provider = OpenAICompatProvider(model="x", endpoint="http://h", api_key="sk-aaa")
        provider.classify("s", "u")

    assert seen["auth"] == "Bearer sk-aaa"
def test_openai_compat_uses_env_var_fallback(monkeypatch):
    """With no explicit key, ``OPENAI_API_KEY`` from the environment is used."""
    monkeypatch.setenv("OPENAI_API_KEY", "sk-from-env")
    provider = OpenAICompatProvider(model="x", endpoint="http://h")
    assert provider.api_key == "sk-from-env"
def test_openai_compat_sends_response_format_json():
    """``json_mode=True`` adds ``response_format: json_object`` to the body."""
    seen = {}

    def fake_urlopen(req, *, timeout):
        seen["body"] = json.loads(req.data.decode())
        return _mock_openai_response('{"ok": true}')

    with patch("mempalace.llm_client.urlopen", side_effect=fake_urlopen):
        provider = OpenAICompatProvider(model="x", endpoint="http://h")
        provider.classify("s", "u", json_mode=True)

    assert seen["body"]["response_format"] == {"type": "json_object"}
def test_openai_compat_unexpected_shape_raises():
    """A response without the expected structure raises ``LLMError``."""
    response = MagicMock()
    response.__enter__.return_value = response
    response.__exit__.return_value = False
    response.read.return_value = b'{"nothing": "here"}'
    with patch("mempalace.llm_client.urlopen", return_value=response):
        provider = OpenAICompatProvider(model="x", endpoint="http://h")
        with pytest.raises(LLMError, match="Unexpected response shape"):
            provider.classify("s", "u")
# ── AnthropicProvider ───────────────────────────────────────────────────
def _mock_anthropic_response(text: str):
    """Build a context-manager mock whose ``read()`` returns one Anthropic text block."""
    encoded = json.dumps({"content": [{"type": "text", "text": text}]}).encode()
    response = MagicMock()
    response.__enter__.return_value = response
    response.__exit__.return_value = False
    response.read.return_value = encoded
    return response
def test_anthropic_requires_api_key(monkeypatch):
    """Without any key the provider is unavailable and names the env var."""
    monkeypatch.delenv("ANTHROPIC_API_KEY", raising=False)
    provider = AnthropicProvider(model="claude-haiku")
    available, message = provider.check_available()
    assert not available
    assert "ANTHROPIC_API_KEY" in message
def test_anthropic_reads_env_key(monkeypatch):
    """``ANTHROPIC_API_KEY`` from the environment is picked up as the key."""
    monkeypatch.setenv("ANTHROPIC_API_KEY", "sk-ant-env")
    provider = AnthropicProvider(model="claude-haiku")
    assert provider.api_key == "sk-ant-env"
    available, _message = provider.check_available()
    assert available
def test_anthropic_classify_sends_version_and_key():
    """``classify`` sends the API key and version headers and returns the text."""
    seen = {}

    def fake_urlopen(req, *, timeout):
        seen["api_key"] = req.get_header("X-api-key")
        seen["version"] = req.get_header("Anthropic-version")
        return _mock_anthropic_response('{"ok": true}')

    with patch("mempalace.llm_client.urlopen", side_effect=fake_urlopen):
        provider = AnthropicProvider(model="claude-haiku", api_key="sk-ant-abc")
        reply = provider.classify("s", "u")

    assert seen["api_key"] == "sk-ant-abc"
    assert seen["version"] == AnthropicProvider.API_VERSION
    assert reply.text == '{"ok": true}'
def test_anthropic_joins_multiple_text_blocks():
    """Several text blocks in one response are concatenated in order."""
    blocks = [
        {"type": "text", "text": "part one. "},
        {"type": "text", "text": "part two."},
    ]
    response = MagicMock()
    response.__enter__.return_value = response
    response.__exit__.return_value = False
    response.read.return_value = json.dumps({"content": blocks}).encode()
    with patch("mempalace.llm_client.urlopen", return_value=response):
        provider = AnthropicProvider(model="claude-haiku", api_key="sk-ant")
        reply = provider.classify("s", "u")
    assert reply.text == "part one. part two."
def test_anthropic_no_key_raises_on_classify(monkeypatch):
    """Classifying without any API key raises ``LLMError``."""
    monkeypatch.delenv("ANTHROPIC_API_KEY", raising=False)
    provider = AnthropicProvider(model="claude-haiku")
    expect_missing_key = pytest.raises(LLMError, match="requires ANTHROPIC_API_KEY")
    with expect_missing_key:
        provider.classify("s", "u")
+631
View File
@@ -0,0 +1,631 @@
"""Tests for mempalace.llm_refine.
Uses a fake provider for deterministic, offline tests. No network.
"""
from dataclasses import dataclass
from mempalace.llm_client import LLMError, LLMResponse
from mempalace.llm_refine import (
_apply_classifications,
_build_user_prompt,
_collect_contexts,
_extract_json_candidates,
_is_authoritative_person,
_is_authoritative_project,
_parse_response,
collect_corpus_text,
refine_entities,
)
# ── fake provider ───────────────────────────────────────────────────────
@dataclass
class FakeProvider:
    """Offline provider stand-in that answers every classify call with ``response_text``."""

    response_text: str = ""
    should_raise: Exception = None
    call_count: int = 0
    interrupt_on_call: int = -1

    def classify(self, system, user, json_mode=True):
        """Count the call, then interrupt, raise, or return the canned response.

        Raises ``KeyboardInterrupt`` when the running call count equals
        ``interrupt_on_call``; otherwise raises ``should_raise`` if it is set.
        """
        self.call_count = self.call_count + 1
        if self.interrupt_on_call == self.call_count:
            raise KeyboardInterrupt()
        failure = self.should_raise
        if failure is not None:
            raise failure
        return LLMResponse(text=self.response_text, model="fake", provider="fake", raw={})

    def check_available(self):
        """Always report the provider as available."""
        return (True, "ok")
# ── _collect_contexts ───────────────────────────────────────────────────
def test_collect_contexts_finds_matches():
    """Only lines mentioning the name come back, capped at ``max_lines``."""
    corpus = [
        "Something about Alice",
        "Bob said hello",
        "Alice was here",
        "Alice walked by",
    ]
    matches = _collect_contexts(corpus, "Alice", max_lines=2)
    assert len(matches) == 2
    for line in matches:
        assert "alice" in line.lower()
def test_collect_contexts_case_insensitive():
    """A lowercase mention still matches a capitalised name."""
    only_line = "lowercase alice mention"
    assert _collect_contexts([only_line], "Alice") == [only_line]
def test_collect_contexts_uses_token_boundaries():
    """The name must appear as its own token, not inside a longer word."""
    embedded = "forgot should not match"
    standalone = "Go is a language."
    hyphenated = "go-v1 shipped."
    matches = _collect_contexts([embedded, standalone, hyphenated], "Go", max_lines=5)
    assert matches == [standalone, hyphenated]
def test_collect_contexts_dedupes_identical_lines():
    """Repeated identical lines are returned only once."""
    corpus = ["Alice"] * 2 + ["Alice was here"]
    matches = _collect_contexts(corpus, "Alice", max_lines=5)
    # three input lines, but only two distinct ones
    assert len(matches) == 2
def test_collect_contexts_truncates_long_lines():
    """A very long matching line is cut down to at most 240 characters."""
    long_line = "Alice " + ("x" * 1000)
    matches = _collect_contexts([long_line], "Alice")
    assert len(matches[0]) <= 240
def test_collect_contexts_no_matches():
    """No mention of the name yields an empty list."""
    matches = _collect_contexts(["nothing here"], "Alice")
    assert matches == []
# ── _build_user_prompt ──────────────────────────────────────────────────
def test_build_user_prompt_numbers_and_includes_contexts():
    """Candidates are numbered, contexts are included, and a missing context is flagged."""
    candidates = [
        ("Alice", "uncertain", ["Alice said hi"]),
        ("Bob", "project", []),
    ]
    prompt = _build_user_prompt(candidates)
    expected_fragments = (
        "1. Alice",
        "2. Bob",
        "Alice said hi",
        "(no context available)",
    )
    for fragment in expected_fragments:
        assert fragment in prompt
# ── _parse_response ─────────────────────────────────────────────────────
def test_parse_response_canonicalizes_label():
    """A lowercase label is upper-cased and paired with its reason."""
    payload = '{"classifications": [{"name": "Alice", "label": "person", "reason": "x"}]}'
    parsed = _parse_response(payload, ["Alice"])
    assert parsed["Alice"] == ("PERSON", "x")
def test_parse_response_accepts_type_alias():
    """A ``type`` key is accepted in place of ``label``."""
    payload = '{"classifications": [{"name": "Bob", "type": "PROJECT"}]}'
    label, *_ = _parse_response(payload, ["Bob"])["Bob"]
    assert label == "PROJECT"
def test_parse_response_maps_unknown_label_to_ambiguous():
    """An unrecognised label is reported as AMBIGUOUS."""
    payload = '{"classifications": [{"name": "X", "label": "WEIRD"}]}'
    parsed = _parse_response(payload, ["X"])
    label = parsed["X"][0]
    assert label == "AMBIGUOUS"
def test_parse_response_restores_canonical_casing():
    """A lowercased name in the reply is keyed by the expected name's casing."""
    payload = '{"classifications": [{"name": "mempalace", "label": "PROJECT"}]}'
    expected_name = "MemPalace"
    parsed = _parse_response(payload, [expected_name])
    assert expected_name in parsed
    assert parsed[expected_name][0] == "PROJECT"
def test_parse_response_strips_code_fences():
    """JSON wrapped in a Markdown ``json`` fence is still parsed."""
    fenced = '```json\n{"classifications": [{"name": "X", "label": "TOPIC"}]}\n```'
    parsed = _parse_response(fenced, ["X"])
    assert parsed["X"][0] == "TOPIC"
def test_parse_response_extracts_json_after_prose():
    """JSON that follows a prose preamble is still found and parsed."""
    reply = 'Sure, here is the JSON: {"classifications": [{"name": "X", "label": "TOPIC"}]}'
    parsed = _parse_response(reply, ["X"])
    assert parsed["X"][0] == "TOPIC"
def test_parse_response_extracts_fenced_json_after_prose():
    """A fenced JSON block preceded by prose is still parsed."""
    reply = 'Sure:\n```json\n{"classifications": [{"name": "X", "label": "PROJECT"}]}\n```'
    parsed = _parse_response(reply, ["X"])
    assert parsed["X"][0] == "PROJECT"
def test_extract_json_candidates_handles_embedded_array():
    """A JSON array embedded between prose is offered as a candidate."""
    embedded = '[{"name": "Y", "label": "PERSON"}]'
    found = _extract_json_candidates('prefix [{"name": "Y", "label": "PERSON"}] suffix')
    assert embedded in found
def test_parse_response_ignores_non_json_brackets_before_payload():
    """Bracketed prose ahead of the payload does not derail parsing."""
    reply = 'See [note] first. JSON: {"classifications": [{"name": "X", "label": "TOPIC"}]}'
    parsed = _parse_response(reply, ["X"])
    assert parsed["X"][0] == "TOPIC"
def test_parse_response_malformed_returns_empty():
    """Text with no JSON in it parses to an empty mapping."""
    assert _parse_response("not json at all", ["X"]) == {}
def test_parse_response_accepts_top_level_list():
    """A bare list, without the wrapping object, is accepted too."""
    bare_list = '[{"name": "Y", "label": "PERSON"}]'
    parsed = _parse_response(bare_list, ["Y"])
    assert parsed["Y"][0] == "PERSON"
# ── _apply_classifications ──────────────────────────────────────────────
def test_apply_classifications_moves_to_correct_bucket():
    """A PERSON decision moves an uncertain entry into ``people`` and counts one reclassification."""
    foo = {
        "name": "Foo",
        "type": "project",
        "confidence": 0.8,
        "frequency": 3,
        "signals": ["old"],
    }
    alice = {
        "name": "Alice",
        "type": "uncertain",
        "confidence": 0.4,
        "frequency": 5,
        "signals": [],
    }
    detected = {"people": [], "projects": [foo], "uncertain": [alice]}
    decisions = {
        "Foo": ("PROJECT", "real project name"),
        "Alice": ("PERSON", "clearly a person"),
    }

    merged, reclassified, dropped = _apply_classifications(detected, decisions)

    people = merged["people"]
    assert len(people) == 1
    assert people[0]["name"] == "Alice"
    assert people[0]["type"] == "person"
    assert reclassified == 1  # only Alice changed bucket (uncertain -> people)
    assert dropped == 0
def test_apply_classifications_drops_common_word():
    """A COMMON_WORD decision removes the entry and counts it as dropped."""
    never = {
        "name": "Never",
        "type": "uncertain",
        "confidence": 0.4,
        "frequency": 20,
        "signals": [],
    }
    detected = {"people": [], "projects": [], "uncertain": [never]}

    merged, _reclassified, dropped = _apply_classifications(
        detected, {"Never": ("COMMON_WORD", "adverb")}
    )

    assert dropped == 1
    assert merged["uncertain"] == []
def test_apply_classifications_keeps_unvisited_entries():
    """An entry with no decision stays where it was."""
    igor = {
        "name": "Igor",
        "type": "person",
        "confidence": 0.99,
        "frequency": 100,
        "signals": ["git"],
    }
    detected = {"people": [igor], "projects": [], "uncertain": []}

    # Empty decisions: nothing is said about Igor.
    merged, reclassified, dropped = _apply_classifications(detected, {})

    assert merged["people"][0]["name"] == "Igor"
    assert reclassified == 0
    assert dropped == 0
def test_apply_classifications_appends_reason_signal():
    """The LLM label and its reason show up in the moved entry's signals."""
    foo = {
        "name": "Foo",
        "type": "uncertain",
        "confidence": 0.4,
        "frequency": 5,
        "signals": ["regex"],
    }
    detected = {"people": [], "projects": [], "uncertain": [foo]}

    merged, _, _ = _apply_classifications(detected, {"Foo": ("PERSON", "spoken of by name")})

    signals = merged["people"][0]["signals"]
    assert any("LLM: person" in signal for signal in signals)
    assert any("spoken of by name" in signal for signal in signals)
def test_apply_classifications_topic_goes_to_uncertain():
    """A TOPIC decision moves a project into ``uncertain``."""
    paris = {
        "name": "Paris",
        "type": "project",
        "confidence": 0.7,
        "frequency": 5,
        "signals": ["regex"],
    }
    detected = {"people": [], "projects": [paris], "uncertain": []}

    merged, reclassified, _dropped = _apply_classifications(
        detected, {"Paris": ("TOPIC", "city, not a project")}
    )

    assert len(merged["projects"]) == 0
    uncertain = merged["uncertain"]
    assert len(uncertain) == 1
    assert uncertain[0]["name"] == "Paris"
    assert reclassified == 1
def test_apply_classifications_can_block_llm_only_project_promotion():
    """With ``allow_project_promotions=False`` a PROJECT decision leaves the entry uncertain."""
    terraform = {
        "name": "Terraform",
        "type": "uncertain",
        "confidence": 0.4,
        "frequency": 5,
        "signals": ["regex"],
    }
    detected = {"people": [], "projects": [], "uncertain": [terraform]}

    merged, reclassified, _dropped = _apply_classifications(
        detected,
        {"Terraform": ("PROJECT", "tool")},
        allow_project_promotions=False,
    )

    assert merged["projects"] == []
    kept = merged["uncertain"][0]
    assert kept["name"] == "Terraform"
    assert kept["type"] == "uncertain"
    assert reclassified == 0
def test_apply_classifications_allows_project_promotion_for_prose_only_mode():
    """By default a PROJECT decision promotes an uncertain entry into ``projects``."""
    aurora = {
        "name": "Project Aurora",
        "type": "uncertain",
        "confidence": 0.4,
        "frequency": 5,
        "signals": ["regex"],
    }
    detected = {"people": [], "projects": [], "uncertain": [aurora]}

    merged, reclassified, _dropped = _apply_classifications(
        detected, {"Project Aurora": ("PROJECT", "user effort")}
    )

    promoted = merged["projects"][0]
    assert promoted["name"] == "Project Aurora"
    assert promoted["type"] == "project"
    assert reclassified == 1
# ── authoritative source filters ────────────────────────────────────────
def test_is_authoritative_person_requires_git_signal():
    """A commit-count signal is authoritative for a person; a pronoun signal is not."""
    from_git = {"signals": ["5 commits across 2 repos"]}
    from_prose = {"signals": ["pronoun nearby (5x)"]}
    assert _is_authoritative_person(from_git)
    assert not _is_authoritative_person(from_prose)
def test_is_authoritative_project_requires_manifest_or_git_signal():
    """Manifest or commit signals are authoritative for a project; a code-file reference is not."""
    authoritative_signals = (
        "package.json, 12 of your commits",
        "57 commits (none by you)",
    )
    for signal in authoritative_signals:
        assert _is_authoritative_project({"signals": [signal]})
    assert not _is_authoritative_project({"signals": ["code file reference (5x)"]})
# ── refine_entities ─────────────────────────────────────────────────────
def _sample_detected():
return {
"people": [
{
"name": "Igor",
"type": "person",
"confidence": 0.99,
"frequency": 100,
"signals": ["git"],
}
],
"projects": [
{
"name": "Foo",
"type": "project",
"confidence": 0.7,
"frequency": 5,
"signals": ["regex"],
}
],
"uncertain": [
{
"name": "Never",
"type": "uncertain",
"confidence": 0.4,
"frequency": 10,
"signals": [],
},
{
"name": "Alice",
"type": "uncertain",
"confidence": 0.4,
"frequency": 5,
"signals": [],
},
],
}
def test_refine_entities_end_to_end_with_fake_provider():
    """One batch through the fake provider: Alice is promoted, Never dropped, Igor kept."""
    entries = [
        '{"name": "Foo", "label": "PROJECT", "reason": "real"}',
        '{"name": "Never", "label": "COMMON_WORD"}',
        '{"name": "Alice", "label": "PERSON", "reason": "name"}',
    ]
    provider = FakeProvider(response_text='{"classifications": [' + ",".join(entries) + "]}")

    result = refine_entities(
        _sample_detected(),
        corpus_text="Alice said hi. Foo was shipped. Never gonna.",
        provider=provider,
        show_progress=False,
    )

    assert result.batches_total == 1
    assert result.batches_completed == 1
    assert not result.cancelled

    people = {entry["name"] for entry in result.merged["people"]}
    uncertain = {entry["name"] for entry in result.merged["uncertain"]}
    assert "Alice" in people  # promoted from uncertain
    assert "Igor" in people  # untouched
    assert "Never" not in uncertain  # dropped as a common word
    assert result.dropped == 1
def test_refine_entities_skips_high_confidence_projects():
    """A manifest-backed project (confidence 0.99 here) triggers no LLM call."""
    manifest_project = {
        "name": "manifest-backed",
        "type": "project",
        "confidence": 0.99,
        "frequency": 50,
        "signals": ["pyproject.toml"],
    }
    detected = {"people": [], "projects": [manifest_project], "uncertain": []}
    provider = FakeProvider(response_text='{"classifications": []}')

    refine_entities(detected, "", provider, show_progress=False)

    # The provider must never have been consulted.
    assert provider.call_count == 0
def test_refine_entities_refines_high_confidence_regex_projects():
    """A high-confidence project backed only by a regex signal is still sent for review."""
    openapi = {
        "name": "OpenAPI",
        "type": "project",
        "confidence": 0.99,
        "frequency": 5,
        "signals": ["code file reference (5x)"],
    }
    detected = {"people": [], "projects": [openapi], "uncertain": []}
    reply = '{"classifications": [{"name": "OpenAPI", "label": "TOPIC", "reason": "technology"}]}'
    provider = FakeProvider(response_text=reply)

    result = refine_entities(detected, "OpenAPI schemas", provider, show_progress=False)

    assert provider.call_count == 1
    assert result.reclassified == 1
    assert result.merged["projects"] == []
    assert result.merged["uncertain"][0]["name"] == "OpenAPI"
def test_refine_entities_refines_regex_people_but_skips_git_people():
    """The pronoun-signal person is reviewed and dropped; the commit-signal person is kept."""
    git_person = {
        "name": "Igor Lins e Silva",
        "type": "person",
        "confidence": 0.99,
        "frequency": 100,
        "signals": ["100 commits across 3 repos"],
    }
    regex_person = {
        "name": "Tool",
        "type": "person",
        "confidence": 0.99,
        "frequency": 5,
        "signals": ["pronoun nearby (5x)"],
    }
    detected = {"people": [git_person, regex_person], "projects": [], "uncertain": []}
    provider = FakeProvider(
        response_text='{"classifications": [{"name": "Tool", "label": "COMMON_WORD"}]}'
    )

    result = refine_entities(detected, "Tool is a common noun.", provider, show_progress=False)

    assert provider.call_count == 1
    remaining = [person["name"] for person in result.merged["people"]]
    assert remaining == ["Igor Lins e Silva"]
    assert result.dropped == 1
def test_refine_entities_can_keep_llm_only_project_in_uncertain():
    """With promotions disabled, a PROJECT verdict stays in ``uncertain`` but is recorded as a signal."""
    terraform = {
        "name": "Terraform",
        "type": "uncertain",
        "confidence": 0.4,
        "frequency": 9,
        "signals": ["regex"],
    }
    detected = {"people": [], "projects": [], "uncertain": [terraform]}
    provider = FakeProvider(
        response_text='{"classifications": [{"name": "Terraform", "label": "PROJECT"}]}'
    )

    result = refine_entities(
        detected,
        "Terraform config",
        provider,
        show_progress=False,
        allow_project_promotions=False,
    )

    assert result.merged["projects"] == []
    kept = result.merged["uncertain"][0]
    assert kept["name"] == "Terraform"
    assert any("LLM: project" in signal for signal in kept["signals"])
def test_refine_entities_empty_candidates_returns_noop():
    """With nothing to refine, no batches run and the result equals the input."""
    empty = {"people": [], "projects": [], "uncertain": []}
    result = refine_entities(empty, "", FakeProvider(), show_progress=False)
    assert result.batches_total == 0
    assert result.reclassified == 0
    assert result.merged == empty
def test_refine_entities_handles_batch_error_gracefully():
    """An LLMError from the provider is recorded in ``errors`` rather than raised."""
    failing = FakeProvider(should_raise=LLMError("transport broke"))

    result = refine_entities(
        _sample_detected(),
        corpus_text="",
        provider=failing,
        show_progress=False,
    )

    assert result.errors
    first_error = result.errors[0]
    assert "transport broke" in first_error
    # No decision succeeded, so nothing was reclassified or cancelled.
    assert result.reclassified == 0
    assert result.cancelled is False
def test_refine_entities_ctrl_c_returns_partial():
    """A KeyboardInterrupt mid-run sets ``cancelled`` and keeps the completed batch count."""
    # 50 candidates at batch_size=25 means two batches.
    candidates = []
    for index in range(50):
        candidates.append(
            {
                "name": f"Cand{index}",
                "type": "uncertain",
                "confidence": 0.4,
                "frequency": 3,
                "signals": [],
            }
        )
    detected = {"people": [], "projects": [], "uncertain": candidates}
    # The fake raises KeyboardInterrupt on its second classify call.
    provider = FakeProvider(response_text='{"classifications": []}', interrupt_on_call=2)

    result = refine_entities(detected, "", provider, batch_size=25, show_progress=False)

    assert result.cancelled is True
    assert result.batches_completed == 1  # batch one done, batch two interrupted
    assert result.batches_total == 2
def test_refine_entities_malformed_response_recorded_as_error():
    """An unparseable reply is reported through ``result.errors``."""
    garbage_provider = FakeProvider(response_text="not json")
    result = refine_entities(_sample_detected(), "", garbage_provider, show_progress=False)
    parse_errors = [message for message in result.errors if "could not parse" in message]
    assert parse_errors
# ── collect_corpus_text ─────────────────────────────────────────────────
def test_collect_corpus_text_reads_prose_files(tmp_path):
    """``.md`` and ``.txt`` content is collected; ``.py`` content is not."""
    files = {
        "a.md": "hello world",
        "b.txt": "more prose",
        "c.py": "import os",  # not prose, so it should be skipped
    }
    for filename, content in files.items():
        (tmp_path / filename).write_text(content)

    collected = collect_corpus_text(str(tmp_path))

    assert "hello world" in collected
    assert "more prose" in collected
    assert "import os" not in collected
def test_collect_corpus_text_prefers_recent(tmp_path):
    """With ``max_files=1`` the newer of two files is the one whose content is returned.

    The older file's modification time is set explicitly with ``os.utime``,
    an hour behind the newer file, so the ordering does not depend on
    wall-clock delays between writes or on filesystem timestamp granularity.
    """
    import os

    old = tmp_path / "old.md"
    old.write_text("OLD_CONTENT")
    new = tmp_path / "new.md"
    new.write_text("NEW_CONTENT")

    # Back-date the old file relative to the new one; no sleep is needed.
    stale_mtime = new.stat().st_mtime - 3600
    os.utime(old, (stale_mtime, stale_mtime))

    text = collect_corpus_text(str(tmp_path), max_files=1)

    assert "NEW_CONTENT" in text
    assert "OLD_CONTENT" not in text
def test_collect_corpus_text_missing_dir_returns_empty(tmp_path):
    """A directory that does not exist yields an empty string."""
    missing_dir = tmp_path / "nope"
    assert collect_corpus_text(str(missing_dir)) == ""
def test_collect_corpus_text_caps_bytes_per_file(tmp_path):
    """``max_bytes_per_file`` bounds how much of a large file is collected."""
    (tmp_path / "big.md").write_text("x" * 100_000)
    collected = collect_corpus_text(str(tmp_path), max_files=1, max_bytes_per_file=500)
    # 500 bytes of content plus some slack for newlines
    assert len(collected) <= 600
+10
View File
@@ -66,6 +66,16 @@ def test_load_config_uses_defaults_when_yaml_missing():
shutil.rmtree(tmpdir) shutil.rmtree(tmpdir)
def test_scan_project_skips_mempalace_generated_files():
    """``entities.json`` and ``mempalace.yaml`` are left out of the scan; user notes are kept."""
    with tempfile.TemporaryDirectory() as tmpdir:
        root = Path(tmpdir).resolve()
        fixtures = (
            ("entities.json", '{"people": [], "projects": []}'),
            ("mempalace.yaml", "wing: test\nrooms: []\n"),
            ("notes.md", "real user content\n" * 10),
        )
        for filename, content in fixtures:
            write_file(root / filename, content)

        assert scanned_files(root) == ["notes.md"]
def test_scan_project_respects_gitignore(): def test_scan_project_respects_gitignore():
tmpdir = tempfile.mkdtemp() tmpdir = tempfile.mkdtemp()
try: try:
+44
View File
@@ -5,6 +5,7 @@ import os
import shutil import shutil
import subprocess import subprocess
from pathlib import Path from pathlib import Path
from types import SimpleNamespace
import pytest import pytest
@@ -480,6 +481,49 @@ def test_discover_entities_prefers_real_signal_over_prose(tmp_path):
assert "realproj" in proj_names assert "realproj" in proj_names
def test_discover_entities_keeps_uncertain_for_llm_when_real_signal(tmp_path):
    """With a provider supplied, a prose-only candidate reaches the LLM and can be dropped."""
    (tmp_path / "package.json").write_text(json.dumps({"name": "realproj"}))
    _init_git_repo(tmp_path)
    (tmp_path / "doc.md").write_text("Noise appeared. Noise repeated. Noise again.")

    class FakeProvider:
        """Records each user prompt and always labels "Noise" a common word."""

        def __init__(self):
            self.prompts = []

        def classify(self, _system, user, json_mode=True):
            self.prompts.append(user)
            reply = '{"classifications": [{"name": "Noise", "label": "COMMON_WORD"}]}'
            return SimpleNamespace(text=reply)

    provider = FakeProvider()
    discovered = discover_entities(str(tmp_path), llm_provider=provider, show_progress=False)

    assert len(provider.prompts) == 1
    assert "Noise" in provider.prompts[0]
    surviving_names = [
        entity["name"] for bucket in discovered.values() for entity in bucket
    ]
    assert "Noise" not in surviving_names
def test_discover_entities_keeps_llm_only_project_uncertain_when_real_signal(tmp_path):
    """In a repo with a manifest, an LLM-only PROJECT verdict stays in ``uncertain``."""
    (tmp_path / "package.json").write_text(json.dumps({"name": "realproj"}))
    _init_git_repo(tmp_path)
    (tmp_path / "doc.md").write_text("Terraform shipped. Terraform changed. Terraform runs.")

    class FakeProvider:
        """Always labels "Terraform" as a PROJECT."""

        def classify(self, _system, _user, json_mode=True):
            reply = '{"classifications": [{"name": "Terraform", "label": "PROJECT"}]}'
            return SimpleNamespace(text=reply)

    discovered = discover_entities(
        str(tmp_path), llm_provider=FakeProvider(), show_progress=False
    )

    project_names = [entity["name"] for entity in discovered["projects"]]
    uncertain_names = [entity["name"] for entity in discovered["uncertain"]]
    assert "realproj" in project_names
    assert "Terraform" not in project_names
    assert "Terraform" in uncertain_names
# ── _UnionFind basics ────────────────────────────────────────────────── # ── _UnionFind basics ──────────────────────────────────────────────────
Generated
+2
View File
@@ -1174,6 +1174,7 @@ source = { editable = "." }
dependencies = [ dependencies = [
{ name = "chromadb" }, { name = "chromadb" },
{ name = "pyyaml" }, { name = "pyyaml" },
{ name = "tomli", marker = "python_full_version < '3.11'" },
] ]
[package.optional-dependencies] [package.optional-dependencies]
@@ -1206,6 +1207,7 @@ requires-dist = [
{ name = "pytest-cov", marker = "extra == 'dev'", specifier = ">=4.0" }, { name = "pytest-cov", marker = "extra == 'dev'", specifier = ">=4.0" },
{ name = "pyyaml", specifier = ">=6.0,<7" }, { name = "pyyaml", specifier = ">=6.0,<7" },
{ name = "ruff", marker = "extra == 'dev'", specifier = ">=0.4.0" }, { name = "ruff", marker = "extra == 'dev'", specifier = ">=0.4.0" },
{ name = "tomli", marker = "python_full_version < '3.11'", specifier = ">=2.0.0" },
] ]
provides-extras = ["dev", "spellcheck"] provides-extras = ["dev", "spellcheck"]