#1148, #1150, and #1157 were reviewed and merged on GitHub, but the two stacked children landed on their parent feature branches (now stale) rather than on develop. Only #1148's commits reached develop via the direct merge. Release PR #1159 (develop → main for v3.3.3) is therefore missing the LLM refinement, Claude-conversation scanner, and miner-registry wire-up that were ostensibly part of the release. This merge brings the stale `feat/llm-entity-refine` branch (which contains the rolled-up merge commit for #1157 → #1150 → everything below) into develop so the release tag includes it. No code changes here — only history recovery.
This commit is contained in:
@@ -120,8 +120,7 @@ def quarantine_stale_hnsw(palace_path: str, stale_seconds: float = 3600.0) -> li
|
|||||||
os.rename(seg_dir, target)
|
os.rename(seg_dir, target)
|
||||||
moved.append(target)
|
moved.append(target)
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"Quarantined stale HNSW segment %s "
|
"Quarantined stale HNSW segment %s (sqlite %.0fs newer than HNSW); renamed to %s",
|
||||||
"(sqlite %.0fs newer than HNSW); renamed to %s",
|
|
||||||
seg_dir,
|
seg_dir,
|
||||||
sqlite_mtime - hnsw_mtime,
|
sqlite_mtime - hnsw_mtime,
|
||||||
target,
|
target,
|
||||||
|
|||||||
+74
-5
@@ -86,21 +86,53 @@ def cmd_init(args):
|
|||||||
languages = cfg.entity_languages
|
languages = cfg.entity_languages
|
||||||
languages_tuple = tuple(languages)
|
languages_tuple = tuple(languages)
|
||||||
|
|
||||||
|
# Optional phase-2 LLM provider (opt-in via --llm).
|
||||||
|
llm_provider = None
|
||||||
|
if getattr(args, "llm", False):
|
||||||
|
from .llm_client import LLMError, get_provider
|
||||||
|
|
||||||
|
try:
|
||||||
|
llm_provider = get_provider(
|
||||||
|
name=args.llm_provider,
|
||||||
|
model=args.llm_model,
|
||||||
|
endpoint=args.llm_endpoint,
|
||||||
|
api_key=args.llm_api_key,
|
||||||
|
)
|
||||||
|
except LLMError as e:
|
||||||
|
print(f" ERROR: {e}", file=sys.stderr)
|
||||||
|
sys.exit(2)
|
||||||
|
ok, msg = llm_provider.check_available()
|
||||||
|
if not ok:
|
||||||
|
print(
|
||||||
|
f" ERROR: LLM provider '{args.llm_provider}' unavailable: {msg}",
|
||||||
|
file=sys.stderr,
|
||||||
|
)
|
||||||
|
sys.exit(2)
|
||||||
|
print(f" LLM refinement enabled: {args.llm_provider}/{args.llm_model}")
|
||||||
|
|
||||||
# Pass 1: discover entities — manifests + git authors first, prose detection
|
# Pass 1: discover entities — manifests + git authors first, prose detection
|
||||||
# as supplement for names mentioned only in docs/notes.
|
# as supplement for names mentioned only in docs/notes. Optional phase-2
|
||||||
|
# LLM refinement runs inside discover_entities when llm_provider is given.
|
||||||
print(f"\n Scanning for entities in: {args.dir}")
|
print(f"\n Scanning for entities in: {args.dir}")
|
||||||
if languages_tuple != ("en",):
|
if languages_tuple != ("en",):
|
||||||
print(f" Languages: {', '.join(languages_tuple)}")
|
print(f" Languages: {', '.join(languages_tuple)}")
|
||||||
detected = discover_entities(args.dir, languages=languages_tuple)
|
detected = discover_entities(args.dir, languages=languages_tuple, llm_provider=llm_provider)
|
||||||
total = len(detected["people"]) + len(detected["projects"]) + len(detected["uncertain"])
|
total = len(detected["people"]) + len(detected["projects"]) + len(detected["uncertain"])
|
||||||
if total > 0:
|
if total > 0:
|
||||||
confirmed = confirm_entities(detected, yes=getattr(args, "yes", False))
|
confirmed = confirm_entities(detected, yes=getattr(args, "yes", False))
|
||||||
# Save confirmed entities to <project>/entities.json for the miner
|
# Save confirmed entities to <project>/entities.json (per-project
|
||||||
|
# audit trail — user can inspect or hand-edit) AND merge into the
|
||||||
|
# global registry the miner reads at mine time.
|
||||||
if confirmed["people"] or confirmed["projects"]:
|
if confirmed["people"] or confirmed["projects"]:
|
||||||
entities_path = Path(args.dir).expanduser().resolve() / "entities.json"
|
entities_path = Path(args.dir).expanduser().resolve() / "entities.json"
|
||||||
with open(entities_path, "w") as f:
|
with open(entities_path, "w", encoding="utf-8") as f:
|
||||||
json.dump(confirmed, f, indent=2)
|
json.dump(confirmed, f, indent=2, ensure_ascii=False)
|
||||||
print(f" Entities saved: {entities_path}")
|
print(f" Entities saved: {entities_path}")
|
||||||
|
|
||||||
|
from .miner import add_to_known_entities
|
||||||
|
|
||||||
|
registry_path = add_to_known_entities(confirmed)
|
||||||
|
print(f" Registry updated: {registry_path}")
|
||||||
else:
|
else:
|
||||||
print(" No entities detected — proceeding with directory-based rooms.")
|
print(" No entities detected — proceeding with directory-based rooms.")
|
||||||
|
|
||||||
@@ -550,6 +582,43 @@ def main():
|
|||||||
"When given, the value is also persisted to config.json."
|
"When given, the value is also persisted to config.json."
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
p_init.add_argument(
|
||||||
|
"--llm",
|
||||||
|
action="store_true",
|
||||||
|
help=(
|
||||||
|
"Enable LLM-assisted entity refinement (opt-in, local-first). "
|
||||||
|
"Runs after manifest/git/regex detection, asking the configured "
|
||||||
|
"provider to reclassify ambiguous candidates. "
|
||||||
|
"Ctrl-C during refinement returns partial results."
|
||||||
|
),
|
||||||
|
)
|
||||||
|
p_init.add_argument(
|
||||||
|
"--llm-provider",
|
||||||
|
default="ollama",
|
||||||
|
choices=["ollama", "openai-compat", "anthropic"],
|
||||||
|
help="LLM provider (default: ollama). Use --llm to enable.",
|
||||||
|
)
|
||||||
|
p_init.add_argument(
|
||||||
|
"--llm-model",
|
||||||
|
default="gemma4:e4b",
|
||||||
|
help="Model name for the chosen provider (default: gemma4:e4b for Ollama).",
|
||||||
|
)
|
||||||
|
p_init.add_argument(
|
||||||
|
"--llm-endpoint",
|
||||||
|
default=None,
|
||||||
|
help=(
|
||||||
|
"Provider endpoint URL. Default for Ollama: http://localhost:11434. "
|
||||||
|
"Required for openai-compat."
|
||||||
|
),
|
||||||
|
)
|
||||||
|
p_init.add_argument(
|
||||||
|
"--llm-api-key",
|
||||||
|
default=None,
|
||||||
|
help=(
|
||||||
|
"API key for the provider. For anthropic, defaults to $ANTHROPIC_API_KEY; "
|
||||||
|
"for openai-compat, defaults to $OPENAI_API_KEY."
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
# mine
|
# mine
|
||||||
p_mine = sub.add_parser("mine", help="Mine files into the palace")
|
p_mine = sub.add_parser("mine", help="Mine files into the palace")
|
||||||
|
|||||||
@@ -0,0 +1,160 @@
|
|||||||
|
"""
|
||||||
|
convo_scanner.py — Parse Claude Code conversation directories into ProjectInfo.
|
||||||
|
|
||||||
|
Claude Code stores sessions under ``~/.claude/projects/<slug>/<id>.jsonl``,
|
||||||
|
where the ``<slug>`` is the original CWD with ``/`` replaced by ``-``. That
|
||||||
|
encoding is lossy: we can't tell whether ``foo-bar`` in a slug is the
|
||||||
|
literal project name ``foo-bar`` or two path segments ``foo/bar``.
|
||||||
|
|
||||||
|
Fortunately, every message record in the JSONL carries a ``cwd`` field with
|
||||||
|
the true path. This scanner reads one record per session to recover the
|
||||||
|
accurate project name, falling back to slug-decoding only if the JSONL
|
||||||
|
is malformed or empty.
|
||||||
|
|
||||||
|
Output is the same ``ProjectInfo`` shape used by ``project_scanner``, so the
|
||||||
|
``discover_entities`` orchestrator can mix-and-match sources.
|
||||||
|
|
||||||
|
Public:
|
||||||
|
is_claude_projects_root(path) -> bool
|
||||||
|
scan_claude_projects(path) -> list[ProjectInfo]
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from mempalace.project_scanner import ProjectInfo
|
||||||
|
|
||||||
|
|
||||||
|
# Upper bound on how many physical lines of a session file are examined when
# searching for a record that carries a `cwd` field; blank and unparseable
# lines count toward this limit.
MAX_HEADER_LINES = 20  # lines to read per session looking for `cwd`
|
||||||
|
|
||||||
|
|
||||||
|
def is_claude_projects_root(path: Path) -> bool:
    """Return True if ``path`` looks like a ``.claude/projects/`` directory.

    Heuristic: at least one child directory whose name starts with ``-`` and
    which directly contains at least one ``.jsonl`` file.

    This is a best-effort detector: any filesystem error raised while probing
    (unreadable directory, entry that cannot be stat'ed) counts as "no match"
    for that entry instead of propagating to the caller.

    Args:
        path: Candidate directory.

    Returns:
        True when the heuristic matches, False otherwise (including when
        ``path`` is missing, not a directory, or unreadable).
    """
    try:
        if not path.is_dir():
            return False
        children = list(path.iterdir())
    except OSError:
        return False

    for child in children:
        # Cheap name test first; it needs no filesystem access.
        if not child.name.startswith("-"):
            continue
        try:
            # is_dir()/is_file() stat the entry and may raise PermissionError,
            # so they are kept inside the same guard as iterdir().
            if not child.is_dir():
                continue
            if any(p.suffix == ".jsonl" and p.is_file() for p in child.iterdir()):
                return True
        except OSError:
            continue
    return False
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_cwd_from_session(session_file: Path) -> Optional[str]:
    """Return the ``cwd`` from the first record in the file that carries one.

    Only the first ``MAX_HEADER_LINES`` physical lines are examined. Blank
    lines, lines that are not valid JSON, and lines whose JSON value is not
    an object (for example a bare list, number, string, or ``null``) are
    skipped rather than treated as errors.

    Args:
        session_file: Path of a ``.jsonl`` session file.

    Returns:
        The first non-empty string ``cwd`` value found, or None if the file
        cannot be read or no examined record has a usable ``cwd``.
    """
    try:
        with open(session_file, encoding="utf-8", errors="replace") as f:
            for i, line in enumerate(f):
                if i >= MAX_HEADER_LINES:
                    break
                line = line.strip()
                if not line:
                    continue
                try:
                    obj = json.loads(line)
                except json.JSONDecodeError:
                    continue
                # A syntactically valid line need not be an object; only
                # dicts can carry a "cwd" key.
                if not isinstance(obj, dict):
                    continue
                cwd = obj.get("cwd")
                if isinstance(cwd, str) and cwd:
                    return cwd
    except OSError:
        return None
    return None
|
||||||
|
|
||||||
|
|
||||||
|
def _decode_slug_fallback(slug: str) -> str:
    """Guess a project name from a directory slug when no cwd is available.

    The slug cannot be decoded exactly: per the module's description, path
    separators in the original working directory are replaced by ``-``, so a
    hyphen that was part of a name is indistinguishable from a separator.
    The last non-empty ``-``-separated piece is used as the closest guess;
    kebab-case names therefore lose everything before their final hyphen.

    Returns the slug unchanged when it contains no non-empty piece.
    """
    segments = [piece for piece in slug.lstrip("-").split("-") if piece]
    if not segments:
        return slug
    return segments[-1]
|
||||||
|
|
||||||
|
|
||||||
|
def _safe_mtime(path: Path) -> float:
    """Return the modification time of ``path``, or 0.0 if it cannot be stat'ed.

    The 0.0 fallback makes an unreadable file compare as older than any real
    file when the result is used as a sort key.
    """
    try:
        stat_result = path.stat()
    except OSError:
        return 0.0
    return stat_result.st_mtime
|
||||||
|
|
||||||
|
|
||||||
|
def _resolve_project_name(project_dir: Path) -> str:
    """Recover the original project name for one conversation directory.

    Session files (``*.jsonl``) are tried newest-first; the first one that
    yields a ``cwd`` decides the name, which is the final path component of
    that cwd (or the whole cwd string when it has no final component).
    If no session provides a cwd, the directory's own name is decoded with
    the lossy slug fallback.
    """
    session_files = [
        entry
        for entry in project_dir.iterdir()
        if entry.is_file() and entry.suffix == ".jsonl"
    ]
    # Newest first — most likely to be well-formed.
    session_files.sort(key=_safe_mtime, reverse=True)

    for session_file in session_files:
        recorded_cwd = _extract_cwd_from_session(session_file)
        if not recorded_cwd:
            continue
        basename = Path(recorded_cwd).name
        return basename if basename else recorded_cwd

    return _decode_slug_fallback(project_dir.name)
|
||||||
|
|
||||||
|
|
||||||
|
def scan_claude_projects(path: str | Path) -> list[ProjectInfo]:
    """Scan a ``.claude/projects/`` directory for Claude Code conversations.

    Each ``-``-prefixed subdirectory holding at least one ``.jsonl`` session
    becomes a ``ProjectInfo``. The directory is not a repository, so
    ``has_git`` is False; ``total_commits`` and ``user_commits`` are both
    repurposed to carry the session count as a density signal for ranking.

    When several subdirectories resolve to the same project name, only the
    one with the most sessions is kept (the first one seen wins ties).

    Returns:
        Projects ordered by descending session count, then by name. An empty
        list when ``path`` does not look like a Claude projects root.
    """
    root = Path(path).expanduser().resolve()
    if not is_claude_projects_root(root):
        return []

    best_by_name: dict[str, ProjectInfo] = {}
    for project_dir in sorted(root.iterdir()):
        if not project_dir.is_dir() or not project_dir.name.startswith("-"):
            continue
        try:
            jsonl_files = [
                entry
                for entry in project_dir.iterdir()
                if entry.is_file() and entry.suffix == ".jsonl"
            ]
        except OSError:
            continue
        if not jsonl_files:
            continue

        project_name = _resolve_project_name(project_dir)
        n_sessions = len(jsonl_files)

        candidate = ProjectInfo(
            name=project_name,
            repo_root=project_dir,
            manifest=None,
            has_git=False,
            total_commits=n_sessions,
            user_commits=n_sessions,
            is_mine=True,  # Claude Code sessions are authored by the user
        )
        incumbent = best_by_name.get(project_name)
        if incumbent is not None and n_sessions <= incumbent.user_commits:
            continue
        best_by_name[project_name] = candidate

    return sorted(
        best_by_name.values(),
        key=lambda info: (-info.user_commits, info.name),
    )
|
||||||
@@ -0,0 +1,305 @@
|
|||||||
|
"""
|
||||||
|
llm_client.py — Minimal provider abstraction for LLM-assisted entity refinement.
|
||||||
|
|
||||||
|
Three providers cover the useful space:
|
||||||
|
|
||||||
|
- ``ollama`` (default): local models via http://localhost:11434. Works fully
|
||||||
|
offline. Honors MemPalace's "zero-API required" principle.
|
||||||
|
- ``openai-compat``: any OpenAI-compatible ``/v1/chat/completions`` endpoint.
|
||||||
|
Covers OpenRouter, LM Studio, llama.cpp server, vLLM, Groq, Fireworks,
|
||||||
|
Together, and most self-hosted setups.
|
||||||
|
- ``anthropic``: the official Messages API. Opt-in for users who want Haiku
|
||||||
|
quality without setting up a local model.
|
||||||
|
|
||||||
|
All providers expose the same ``classify(system, user, json_mode)`` method and
|
||||||
|
the same ``check_available()`` probe. No external SDK dependencies — stdlib
|
||||||
|
``urllib`` only.
|
||||||
|
|
||||||
|
JSON mode matters here: we always ask for structured output. Providers
|
||||||
|
differ on how to request it (Ollama: ``format: json``; OpenAI-compat:
|
||||||
|
``response_format``; Anthropic: prompt-level instruction) and this module
|
||||||
|
normalizes that away from the caller.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Optional
|
||||||
|
from urllib.error import HTTPError, URLError
|
||||||
|
from urllib.request import Request, urlopen
|
||||||
|
|
||||||
|
|
||||||
|
class LLMError(RuntimeError):
    """Single exception type for every provider failure.

    Covers transport errors, unparseable responses, authentication problems,
    and missing models, so callers need only one ``except`` clause.
    """
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class LLMResponse:
    """Normalized result of one provider ``classify`` call."""

    # Text content extracted from the provider's reply.
    text: str
    # Model name the request was sent with.
    model: str
    # Provider identifier (the provider class's ``name``).
    provider: str
    # Full decoded JSON response, kept for callers that need more detail.
    raw: dict
|
||||||
|
|
||||||
|
|
||||||
|
# ==================== BASE ====================
|
||||||
|
|
||||||
|
|
||||||
|
class LLMProvider:
    """Base class defining the interface every provider implements.

    Subclasses override ``classify`` and ``check_available``; this base only
    stores the shared connection settings.
    """

    name: str = "base"

    def __init__(
        self,
        model: str,
        endpoint: Optional[str] = None,
        api_key: Optional[str] = None,
        timeout: int = 120,
    ):
        """Store the model name, endpoint, API key and request timeout."""
        self.model, self.endpoint = model, endpoint
        self.api_key, self.timeout = api_key, timeout

    def classify(self, system: str, user: str, json_mode: bool = True) -> LLMResponse:
        """Send one system + user prompt pair; implemented by subclasses."""
        raise NotImplementedError

    def check_available(self) -> tuple[bool, str]:
        """Return ``(ok, message)``. Fast probe that the provider is reachable."""
        raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
def _http_post_json(url: str, body: dict, headers: dict, timeout: int) -> dict:
    """POST ``body`` as JSON to ``url`` and return the decoded JSON object.

    Args:
        url: Full request URL.
        body: JSON-serializable request payload.
        headers: Extra request headers; they are merged over the default
            ``Content-Type: application/json``.
        timeout: Socket timeout in seconds, passed to ``urlopen``.

    Returns:
        The response body decoded as a JSON object (dict).

    Raises:
        LLMError: if the URL is malformed, the server cannot be reached, the
            connection breaks mid-response, the server answers with an HTTP
            error status, the body is not decodable JSON, or the decoded
            JSON is not an object.
    """
    # Local import: only the exception type is needed, and only here.
    from http.client import HTTPException

    # Serialize outside the guarded region so that a non-serializable body
    # (a caller bug) is not misreported as a network problem.
    payload = json.dumps(body).encode("utf-8")

    try:
        req = Request(
            url,
            data=payload,
            headers={"Content-Type": "application/json", **headers},
        )
        with urlopen(req, timeout=timeout) as resp:
            raw = resp.read()
    except HTTPError as e:
        detail = ""
        try:
            detail = e.read().decode("utf-8", errors="replace")[:500]
        except Exception:
            # Best effort only: the error body is optional extra context.
            pass
        raise LLMError(f"HTTP {e.code} from {url}: {detail or e.reason}") from e
    except (URLError, OSError, HTTPException, ValueError) as e:
        # ValueError: Request() rejects URLs without a scheme.
        # HTTPException: e.g. a response truncated while being read.
        raise LLMError(f"Cannot reach {url}: {e}") from e

    try:
        parsed = json.loads(raw)
    except ValueError as e:
        # Covers json.JSONDecodeError and UnicodeDecodeError alike.
        raise LLMError(f"Malformed response from {url}: {e}") from e
    if not isinstance(parsed, dict):
        # Callers index the result by key, so anything else is unusable.
        raise LLMError(
            f"Malformed response from {url}: expected a JSON object, "
            f"got {type(parsed).__name__}"
        )
    return parsed
|
||||||
|
|
||||||
|
|
||||||
|
# ==================== OLLAMA ====================
|
||||||
|
|
||||||
|
|
||||||
|
class OllamaProvider(LLMProvider):
    """Provider for models served by Ollama's HTTP API.

    Talks to ``<endpoint>/api/tags`` (availability probe) and
    ``<endpoint>/api/chat`` (classification). No API key is used.
    """

    name = "ollama"
    DEFAULT_ENDPOINT = "http://localhost:11434"

    def __init__(
        self,
        model: str,
        endpoint: Optional[str] = None,
        timeout: int = 180,
        **_: object,
    ):
        """Store settings, defaulting the endpoint to ``DEFAULT_ENDPOINT``.

        Extra keyword arguments (such as ``api_key``) are accepted and
        ignored so all providers can be constructed with the same call.
        """
        super().__init__(
            model=model,
            endpoint=endpoint or self.DEFAULT_ENDPOINT,
            timeout=timeout,
        )

    def _url(self, path: str) -> str:
        """Join ``path`` onto the endpoint, tolerating a trailing slash."""
        return f"{self.endpoint.rstrip('/')}{path}"

    def check_available(self) -> tuple[bool, str]:
        """Probe ``/api/tags`` and verify the configured model is listed.

        Returns:
            ``(True, "ok")`` when the server answers and lists the model
            (with or without a ``:latest`` suffix); otherwise ``(False,
            message)`` describing what went wrong. Never raises for network,
            URL, or response-shape problems.
        """
        try:
            with urlopen(self._url("/api/tags"), timeout=5) as resp:
                data = json.loads(resp.read())
        except (URLError, OSError, ValueError) as e:
            # URLError includes HTTPError; ValueError includes JSON and
            # Unicode decode errors as well as a schemeless endpoint URL.
            return False, f"Cannot reach Ollama at {self.endpoint}: {e}"
        if not isinstance(data, dict):
            return False, f"Unexpected response from Ollama at {self.endpoint}"
        models = data.get("models")
        if not isinstance(models, list):
            models = []
        names = {m.get("name", "") for m in models if isinstance(m, dict)}
        # Ollama tags may or may not include ':latest' — accept either form
        wanted = {self.model, f"{self.model}:latest"}
        if not names & wanted:
            return (
                False,
                f"Model '{self.model}' not loaded in Ollama. Run: ollama pull {self.model}",
            )
        return True, "ok"

    def classify(self, system: str, user: str, json_mode: bool = True) -> LLMResponse:
        """Send one chat request and return the reply text.

        Args:
            system: System prompt.
            user: User prompt.
            json_mode: When true, ask Ollama for JSON output via
                ``format: json``.

        Raises:
            LLMError: on transport failure (from ``_http_post_json``) or when
                the reply contains no message text.
        """
        body: dict = {
            "model": self.model,
            "messages": [
                {"role": "system", "content": system},
                {"role": "user", "content": user},
            ],
            "stream": False,
            "options": {"temperature": 0.1},
        }
        if json_mode:
            body["format"] = "json"
        data = _http_post_json(self._url("/api/chat"), body, headers={}, timeout=self.timeout)
        # Navigate defensively: an unexpected shape must surface as LLMError,
        # not as AttributeError.
        message = data.get("message") if isinstance(data, dict) else None
        text = message.get("content", "") if isinstance(message, dict) else ""
        if not text or not isinstance(text, str):
            raise LLMError(f"Empty response from Ollama (model={self.model})")
        return LLMResponse(text=text, model=self.model, provider=self.name, raw=data)
|
||||||
|
|
||||||
|
|
||||||
|
# ==================== OPENAI-COMPAT ====================
|
||||||
|
|
||||||
|
|
||||||
|
class OpenAICompatProvider(LLMProvider):
    """Any OpenAI-compatible ``/v1/chat/completions`` endpoint.

    Supply ``--llm-endpoint http://host:port`` (with or without ``/v1``).
    API key via ``--llm-api-key`` or the ``OPENAI_API_KEY`` env var.
    """

    name = "openai-compat"

    def __init__(
        self,
        model: str,
        endpoint: Optional[str] = None,
        api_key: Optional[str] = None,
        timeout: int = 120,
        **_: object,
    ):
        """Store settings; the key falls back to ``$OPENAI_API_KEY``."""
        resolved_key = api_key or os.environ.get("OPENAI_API_KEY")
        super().__init__(model=model, endpoint=endpoint, api_key=resolved_key, timeout=timeout)

    def _resolve_url(self) -> str:
        """Return the chat-completions URL derived from the endpoint.

        Accepts an endpoint given as a bare host, with ``/v1``, or as the
        full ``.../chat/completions`` URL.

        Raises:
            LLMError: if no endpoint is configured.
        """
        if not self.endpoint:
            raise LLMError("openai-compat provider requires --llm-endpoint")
        url = self.endpoint.rstrip("/")
        if url.endswith("/chat/completions"):
            return url
        if not url.endswith("/v1"):
            url = f"{url}/v1"
        return f"{url}/chat/completions"

    def check_available(self) -> tuple[bool, str]:
        """Probe ``<base>/v1/models`` to confirm the server responds.

        Returns:
            ``(True, "ok")`` on any successful response; ``(False, message)``
            when no endpoint is set, the URL is malformed, the server is
            unreachable, or it answers with an HTTP error status.
        """
        if not self.endpoint:
            return False, "no --llm-endpoint configured"
        base = self.endpoint.rstrip("/")
        base = base.removesuffix("/chat/completions").removesuffix("/v1")
        try:
            req = Request(f"{base}/v1/models")
            if self.api_key:
                req.add_header("Authorization", f"Bearer {self.api_key}")
            with urlopen(req, timeout=5):
                pass
        except (URLError, HTTPError, OSError, ValueError) as e:
            # ValueError: Request() rejects an endpoint without a URL scheme;
            # report it like any other unreachable endpoint instead of crashing.
            return False, f"Cannot reach {self.endpoint}: {e}"
        return True, "ok"

    def classify(self, system: str, user: str, json_mode: bool = True) -> LLMResponse:
        """Send one chat-completions request and return the reply text.

        Args:
            system: System prompt.
            user: User prompt.
            json_mode: When true, request JSON output via ``response_format``.

        Raises:
            LLMError: on transport failure, a response without the expected
                ``choices[0].message.content`` path, or empty content.
        """
        body: dict = {
            "model": self.model,
            "messages": [
                {"role": "system", "content": system},
                {"role": "user", "content": user},
            ],
            "temperature": 0.1,
        }
        if json_mode:
            body["response_format"] = {"type": "json_object"}
        headers = {}
        if self.api_key:
            headers["Authorization"] = f"Bearer {self.api_key}"
        data = _http_post_json(self._resolve_url(), body, headers=headers, timeout=self.timeout)
        try:
            text = data["choices"][0]["message"]["content"]
        except (KeyError, IndexError, TypeError) as e:
            raise LLMError(f"Unexpected response shape: {e}") from e
        if not text:
            raise LLMError(f"Empty response from {self.name} (model={self.model})")
        return LLMResponse(text=text, model=self.model, provider=self.name, raw=data)
|
||||||
|
|
||||||
|
|
||||||
|
# ==================== ANTHROPIC ====================
|
||||||
|
|
||||||
|
|
||||||
|
class AnthropicProvider(LLMProvider):
    """Provider for the Anthropic Messages API (``/v1/messages``).

    The API key comes from the constructor argument or, failing that, the
    ``ANTHROPIC_API_KEY`` environment variable.
    """

    name = "anthropic"
    DEFAULT_ENDPOINT = "https://api.anthropic.com"
    API_VERSION = "2023-06-01"

    def __init__(
        self,
        model: str,
        api_key: Optional[str] = None,
        endpoint: Optional[str] = None,
        timeout: int = 120,
        **_: object,
    ):
        """Store settings, defaulting endpoint and key as described above."""
        key = api_key or os.environ.get("ANTHROPIC_API_KEY")
        super().__init__(
            model=model,
            endpoint=endpoint or self.DEFAULT_ENDPOINT,
            api_key=key,
            timeout=timeout,
        )

    def check_available(self) -> tuple[bool, str]:
        """Report whether an API key is configured; performs no network call."""
        if not self.api_key:
            return False, "ANTHROPIC_API_KEY not set (use --llm-api-key or env)"
        # Don't probe — a live request would cost money. First real call will
        # surface auth errors if the key is invalid.
        return True, "ok"

    def classify(self, system: str, user: str, json_mode: bool = True) -> LLMResponse:
        """Send one Messages API request and return the concatenated text.

        Args:
            system: System prompt.
            user: User prompt.
            json_mode: When true, a JSON-only instruction is appended to the
                system prompt (the request body carries no JSON-mode flag).

        Raises:
            LLMError: if no API key is configured, on transport failure, on
                an unexpected response shape, or when no text is returned.
        """
        if not self.api_key:
            raise LLMError("Anthropic provider requires ANTHROPIC_API_KEY env or --llm-api-key")
        sys_prompt = system
        if json_mode:
            sys_prompt += "\n\nRespond with valid JSON only, no prose."
        body = {
            "model": self.model,
            "max_tokens": 2048,
            "temperature": 0.1,
            "system": sys_prompt,
            "messages": [{"role": "user", "content": user}],
        }
        headers = {
            "X-API-Key": self.api_key,
            "anthropic-version": self.API_VERSION,
        }
        # Strip a trailing slash so a custom endpoint such as
        # "https://proxy.example/" does not produce "//v1/messages".
        url = f"{self.endpoint.rstrip('/')}/v1/messages"
        data = _http_post_json(url, body, headers=headers, timeout=self.timeout)
        try:
            # Keep only blocks whose type is "text" and join their text.
            text = "".join(
                b.get("text", "") for b in data.get("content", []) or [] if b.get("type") == "text"
            )
        except (AttributeError, TypeError) as e:
            raise LLMError(f"Unexpected response shape: {e}") from e
        if not text:
            raise LLMError(f"Empty response from Anthropic (model={self.model})")
        return LLMResponse(text=text, model=self.model, provider=self.name, raw=data)
|
||||||
|
|
||||||
|
|
||||||
|
# ==================== FACTORY ====================
|
||||||
|
|
||||||
|
|
||||||
|
# Registry mapping each provider name to the class that implements it.
PROVIDERS: dict[str, type[LLMProvider]] = {
    "ollama": OllamaProvider,
    "openai-compat": OpenAICompatProvider,
    "anthropic": AnthropicProvider,
}
|
||||||
|
|
||||||
|
|
||||||
|
def get_provider(
    name: str,
    model: str,
    endpoint: Optional[str] = None,
    api_key: Optional[str] = None,
    timeout: int = 120,
) -> LLMProvider:
    """Instantiate the provider registered under ``name``.

    All four settings are forwarded as keyword arguments to the provider
    class looked up in ``PROVIDERS``.

    Raises:
        LLMError: if ``name`` is not a registered provider.
    """
    if name not in PROVIDERS:
        choices = sorted(PROVIDERS)
        raise LLMError(f"Unknown provider '{name}'. Choices: {choices}")
    provider_cls = PROVIDERS[name]
    # NOTE(review): timeout is always passed explicitly, so a provider's own
    # constructor default (e.g. Ollama's 180) never applies through this
    # factory — confirm that is intended.
    return provider_cls(model=model, endpoint=endpoint, api_key=api_key, timeout=timeout)
|
||||||
@@ -0,0 +1,446 @@
|
|||||||
|
"""
|
||||||
|
llm_refine.py — Optional LLM refinement of regex-detected entities.
|
||||||
|
|
||||||
|
Takes the candidate set produced by phase-1 detection (manifests, git
|
||||||
|
authors, regex on prose) and asks an LLM to reclassify each candidate as
|
||||||
|
PERSON / PROJECT / TOPIC / COMMON_WORD / AMBIGUOUS.
|
||||||
|
|
||||||
|
Design constraints:
|
||||||
|
- Opt-in. Default init path never imports this module.
|
||||||
|
- Local-first by default (Ollama).
|
||||||
|
- Interactive UX: visible progress, clean cancellation (Ctrl-C returns
|
||||||
|
whatever was classified before the interrupt).
|
||||||
|
- Don't feed the raw corpus to the LLM — feed candidates + a few sampled
|
||||||
|
context lines each. Keeps total input to ~50-100K tokens even for huge
|
||||||
|
prose corpora.
|
||||||
|
|
||||||
|
Public:
|
||||||
|
refine_entities(detected, corpus_text, provider, ...) -> dict
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import re
|
||||||
|
import sys
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
from mempalace.llm_client import LLMError, LLMProvider
|
||||||
|
|
||||||
|
|
||||||
|
# Tuning knobs that bound how much text is sent to the LLM.
BATCH_SIZE = 25  # candidates per LLM call; tuned for 4B local models
CONTEXT_LINES_PER_CANDIDATE = 3  # max sampled corpus lines shown per candidate
CONTEXT_WINDOW_CHARS = 240  # max chars per context line to keep tokens bounded

# Valid labels the LLM is allowed to return. Anything else is treated as
# AMBIGUOUS so the user reviews it.
VALID_LABELS = {"PERSON", "PROJECT", "TOPIC", "COMMON_WORD", "AMBIGUOUS"}


# System prompt: defines the five labels and the JSON response schema that
# the response parser expects (a "classifications" list of objects with
# "name", "label" and "reason").
SYSTEM_PROMPT = """You are helping organize a user's memory palace by classifying capitalized tokens found in their files.

For each candidate, pick exactly ONE label:
- PERSON: a specific real person the user knows (colleague, family, character they write about)
- PROJECT: a named product, codebase, or effort the user works on
- TOPIC: a recurring theme or subject (not a person, not a project) — cities, technologies, concepts
- COMMON_WORD: an English word, verb, or fragment that isn't a named entity at all (e.g. "Created", "Before", "Never")
- AMBIGUOUS: context is insufficient to decide between two of the above

Frameworks, runtimes, APIs, cloud services, vendors, and third-party products
(e.g. Angular, OpenAPI, Terraform, Bun, Google) are TOPIC unless the context
clearly says this is the user's own named codebase, product, or active effort.

Use the provided context lines to disambiguate. A capitalized word that only appears in metadata ("Created: 2026-04-24") is COMMON_WORD. A name that appears with pronouns and dialogue is PERSON.

Respond with JSON only. Schema:
{"classifications": [{"name": "<exact candidate name>", "label": "<LABEL>", "reason": "<one short sentence>"}]}

One entry per candidate, same order as the input."""
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class RefineResult:
    """Outcome of one refinement run.

    Attributes:
        merged: Updated detected dict.
        reclassified: Number of entries whose type changed.
        dropped: Number of entries removed from the merged result
            (COMMON_WORD only).
        errors: Per-batch error messages (transport/parse failures).
        batches_completed: Count of batches completed.
        batches_total: Total count of batches.
        cancelled: Cancellation flag for the run.
    """

    merged: dict
    reclassified: int
    dropped: int
    errors: list[str]
    batches_completed: int
    batches_total: int
    cancelled: bool
|
||||||
|
|
||||||
|
|
||||||
|
def _collect_contexts(
    corpus_lines: list[str], name: str, max_lines: int = CONTEXT_LINES_PER_CANDIDATE
) -> list[str]:
    """Sample distinct corpus lines that mention ``name``.

    A line matches when ``name`` occurs case-insensitively and is not
    directly preceded or followed by a word character. Each matching line is
    stripped and cut to CONTEXT_WINDOW_CHARS characters; duplicates (after
    cutting) and empty results are skipped. Collection stops once
    ``max_lines`` snippets have been gathered.
    """
    pattern = re.compile(rf"(?<!\w){re.escape(name)}(?!\w)", re.IGNORECASE)
    snippets: list[str] = []
    already_seen: set[str] = set()
    for raw_line in corpus_lines:
        if pattern.search(raw_line) is None:
            continue
        snippet = raw_line.strip()[:CONTEXT_WINDOW_CHARS]
        if snippet and snippet not in already_seen:
            already_seen.add(snippet)
            snippets.append(snippet)
            if len(snippets) >= max_lines:
                break
    return snippets
|
||||||
|
|
||||||
|
|
||||||
|
def _build_user_prompt(candidates_with_contexts: list[tuple[str, str, list[str]]]) -> str:
    """Render the candidate list as the user-message text.

    Each ``(name, current_type, contexts)`` tuple becomes a numbered entry
    (starting at 1) showing the current type guess, followed by one quoted
    line per sampled context, or a placeholder when there is none.
    """
    lines: list[str] = ["CANDIDATES:"]
    for index, (name, current_type, contexts) in enumerate(candidates_with_contexts, start=1):
        lines.append(f"\n{index}. {name} (currently: {current_type})")
        if not contexts:
            lines.append(" > (no context available)")
            continue
        lines.extend(f" > {context}" for context in contexts)
    return "\n".join(lines)
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_json_candidates(text: str) -> list[str]:
    """Return plausible JSON payloads extracted from an LLM response.

    Candidates are returned in this order, without duplicates:

    1. the whole stripped text;
    2. the contents of each fenced code block (with or without a ``json``
       language tag);
    3. for every ``{`` or ``[`` in the text, the substring running to the
       point where that bracket type returns to depth zero.

    The strings are not validated here; an empty list is returned only for
    empty or whitespace-only input.
    """
    text = text.strip()
    if not text:
        return []

    # The whole reply comes first.
    candidates: list[str] = [text]

    # Fenced code blocks.
    for match in re.finditer(r"```(?:json)?\s*([\s\S]*?)\s*```", text, re.IGNORECASE):
        candidate = match.group(1).strip()
        if candidate and candidate not in candidates:
            candidates.append(candidate)

    # Bracket-balanced spans, one scan per opening bracket found anywhere in
    # the text. Only the opener's own bracket type is counted for depth.
    for start, opener in ((i, ch) for i, ch in enumerate(text) if ch in "{["):
        closer = "}" if opener == "{" else "]"
        depth = 0
        in_string = False
        escaped = False
        for i in range(start, len(text)):
            ch = text[i]
            if in_string:
                # Inside a double-quoted string brackets are ignored; track
                # backslash escapes so an escaped quote does not end it.
                if escaped:
                    escaped = False
                elif ch == "\\":
                    escaped = True
                elif ch == '"':
                    in_string = False
                continue

            if ch == '"':
                in_string = True
            elif ch == opener:
                depth += 1
            elif ch == closer:
                depth -= 1
                if depth == 0:
                    # Balanced span found for this opener; record it and stop
                    # scanning from this start position.
                    candidate = text[start : i + 1].strip()
                    if candidate and candidate not in candidates:
                        candidates.append(candidate)
                    break

    return candidates
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_response(text: str, expected_names: list[str]) -> dict[str, tuple[str, str]]:
    """Parse the LLM's JSON response into ``{name: (label, reason)}``.

    Robust to the model wrapping JSON in prose or code fences and to slight
    schema variations:

    * the payload may be a bare list or an object with a ``classifications``
      list;
    * each entry may name the candidate under ``name`` or ``candidate`` and
      the label under ``label``, ``type`` or ``classification``.

    Every JSON candidate extracted from ``text`` is tried in order and the
    first one that yields at least one usable entry wins, so an unrelated
    JSON-looking fragment earlier in the reply does not mask the real payload.

    Args:
        text: Raw model reply.
        expected_names: Candidate names sent in the batch; used to restore
            canonical casing when the model changes the case of a name.

    Returns:
        Mapping of name to ``(LABEL, reason)``. Labels are upper-cased and
        anything outside ``VALID_LABELS`` becomes ``"AMBIGUOUS"``. Reasons are
        stripped and truncated to 120 characters. Empty dict when nothing
        usable could be parsed.
    """
    canonical_names = {n.lower(): n for n in expected_names}

    for candidate in _extract_json_candidates(text):
        try:
            data = json.loads(candidate)
        except json.JSONDecodeError:
            continue

        entries = data.get("classifications") if isinstance(data, dict) else data
        if not isinstance(entries, list):
            continue

        name_to_label: dict[str, tuple[str, str]] = {}
        for entry in entries:
            if not isinstance(entry, dict):
                continue
            name = entry.get("name") or entry.get("candidate")
            label = entry.get("label") or entry.get("type") or entry.get("classification")
            if not isinstance(name, str) or not isinstance(label, str):
                continue

            # Models occasionally emit a number/list/object as the reason;
            # coerce it so a malformed reason cannot raise out of the parser.
            raw_reason = entry.get("reason") or ""
            reason = raw_reason if isinstance(raw_reason, str) else str(raw_reason)

            # Restore canonical casing from expected_names
            canonical = canonical_names.get(name.lower(), name)
            lbl = label.strip().upper()
            if lbl not in VALID_LABELS:
                lbl = "AMBIGUOUS"
            name_to_label[canonical] = (lbl, reason.strip()[:120])

        if name_to_label:
            return name_to_label

    return {}
|
||||||
|
|
||||||
|
|
||||||
|
def _apply_classifications(
|
||||||
|
detected: dict,
|
||||||
|
decisions: dict[str, tuple[str, str]],
|
||||||
|
allow_project_promotions: bool = True,
|
||||||
|
) -> tuple[dict, int, int]:
|
||||||
|
"""Merge LLM decisions back into the detected dict.
|
||||||
|
|
||||||
|
Returns (new_detected, reclassified_count, dropped_count).
|
||||||
|
"""
|
||||||
|
label_to_bucket = {
|
||||||
|
"PERSON": "people",
|
||||||
|
"PROJECT": "projects",
|
||||||
|
"TOPIC": "uncertain",
|
||||||
|
"AMBIGUOUS": "uncertain",
|
||||||
|
}
|
||||||
|
|
||||||
|
# Index every entity by name for in-place update
|
||||||
|
all_entries: list[tuple[str, dict]] = []
|
||||||
|
for bucket, items in detected.items():
|
||||||
|
for e in items:
|
||||||
|
all_entries.append((bucket, e))
|
||||||
|
|
||||||
|
reclassified = 0
|
||||||
|
dropped = 0
|
||||||
|
new_detected: dict[str, list[dict]] = {
|
||||||
|
"people": [],
|
||||||
|
"projects": [],
|
||||||
|
"uncertain": [],
|
||||||
|
}
|
||||||
|
|
||||||
|
for old_bucket, entry in all_entries:
|
||||||
|
decision = decisions.get(entry["name"])
|
||||||
|
if decision is None:
|
||||||
|
# No LLM opinion — keep as-is
|
||||||
|
new_detected[old_bucket].append(entry)
|
||||||
|
continue
|
||||||
|
|
||||||
|
label, reason = decision
|
||||||
|
if label == "COMMON_WORD":
|
||||||
|
dropped += 1
|
||||||
|
continue
|
||||||
|
|
||||||
|
target_bucket = label_to_bucket[label]
|
||||||
|
if (
|
||||||
|
label == "PROJECT"
|
||||||
|
and not allow_project_promotions
|
||||||
|
and not _is_authoritative_project(entry)
|
||||||
|
):
|
||||||
|
target_bucket = "uncertain"
|
||||||
|
updated = dict(entry)
|
||||||
|
# Append the LLM's reason as a new signal so the user sees why it moved
|
||||||
|
signals = list(updated.get("signals", []))
|
||||||
|
signals.append(f"LLM: {label.lower()} — {reason}" if reason else f"LLM: {label.lower()}")
|
||||||
|
updated["signals"] = signals
|
||||||
|
if target_bucket != old_bucket:
|
||||||
|
reclassified += 1
|
||||||
|
updated["type"] = (
|
||||||
|
"person"
|
||||||
|
if target_bucket == "people"
|
||||||
|
else "project"
|
||||||
|
if target_bucket == "projects"
|
||||||
|
else "uncertain"
|
||||||
|
)
|
||||||
|
new_detected[target_bucket].append(updated)
|
||||||
|
|
||||||
|
return new_detected, reclassified, dropped
|
||||||
|
|
||||||
|
|
||||||
|
def _is_authoritative_person(entry: dict) -> bool:
|
||||||
|
"""Return True for git-author people that should not be second-guessed."""
|
||||||
|
signals = " ".join(entry.get("signals", [])).lower()
|
||||||
|
return "commit" in signals and "repo" in signals
|
||||||
|
|
||||||
|
|
||||||
|
def _is_authoritative_project(entry: dict) -> bool:
|
||||||
|
"""Return True for manifest/git-backed projects that are already source-backed."""
|
||||||
|
signals = " ".join(entry.get("signals", [])).lower()
|
||||||
|
manifest_markers = ("package.json", "pyproject.toml", "cargo.toml", "go.mod")
|
||||||
|
return any(marker in signals for marker in manifest_markers) or "commit" in signals
|
||||||
|
|
||||||
|
|
||||||
|
def _print_progress(batch_idx: int, total: int, current_name: str) -> None:
|
||||||
|
"""Overwrite-line progress indicator."""
|
||||||
|
width = 40
|
||||||
|
filled = int(width * batch_idx / total) if total else 0
|
||||||
|
bar = "█" * filled + "░" * (width - filled)
|
||||||
|
msg = f"\r LLM refine: [{bar}] batch {batch_idx}/{total} current: {current_name[:30]:<30}"
|
||||||
|
sys.stderr.write(msg)
|
||||||
|
sys.stderr.flush()
|
||||||
|
|
||||||
|
|
||||||
|
def refine_entities(
    detected: dict,
    corpus_text: str,
    provider: LLMProvider,
    batch_size: int = BATCH_SIZE,
    show_progress: bool = True,
    allow_project_promotions: bool = True,
) -> RefineResult:
    """Run the LLM reclassification pass over the detected entities.

    Which entities are sent: everything in the ``people``, ``projects`` and
    ``uncertain`` buckets except authoritative people and authoritative
    projects, which are already source-backed. Names are de-duplicated,
    keeping the first occurrence and its current type.

    Interruption: a ``KeyboardInterrupt`` raised while waiting on the provider
    stops the remaining batches. The result then has ``cancelled=True`` and
    carries whatever was classified before the interrupt.

    Failures: an ``LLMError`` from the provider, or a reply that cannot be
    parsed, is recorded in ``errors`` and the run moves on to the next batch.

    ``allow_project_promotions=False`` is forwarded to the merge step so
    LLM-only project guesses stay in the uncertain bucket.

    Progress is drawn on stderr when ``show_progress`` is true.
    """
    kind_for_bucket = {"people": "person", "projects": "project", "uncertain": "uncertain"}
    source_backed_check = {
        "people": _is_authoritative_person,
        "projects": _is_authoritative_project,
    }

    # Insertion-ordered mapping: the first bucket to mention a name decides
    # the kind that is sent to the model.
    pending: dict[str, str] = {}
    for bucket, kind in kind_for_bucket.items():
        is_source_backed = source_backed_check.get(bucket)
        for item in detected.get(bucket, []):
            if is_source_backed is not None and is_source_backed(item):
                continue
            pending.setdefault(item["name"], kind)

    corpus_lines = corpus_text.splitlines() if corpus_text else []

    if not pending:
        # Nothing to refine: hand the input back untouched.
        return RefineResult(
            merged=detected,
            reclassified=0,
            dropped=0,
            errors=[],
            batches_completed=0,
            batches_total=0,
            cancelled=False,
        )

    ordered = list(pending.items())
    batches: list[list[tuple[str, str, list[str]]]] = [
        [
            (name, kind, _collect_contexts(corpus_lines, name))
            for name, kind in ordered[start : start + batch_size]
        ]
        for start in range(0, len(ordered), batch_size)
    ]
    batch_count = len(batches)

    verdicts: dict[str, tuple[str, str]] = {}
    failures: list[str] = []
    finished = 0
    interrupted = False

    for number, batch in enumerate(batches, 1):
        if show_progress and batch:
            _print_progress(number - 1, batch_count, batch[0][0])
        prompt = _build_user_prompt(batch)
        try:
            reply = provider.classify(SYSTEM_PROMPT, prompt, json_mode=True)
        except KeyboardInterrupt:
            interrupted = True
            break
        except LLMError as exc:
            failures.append(f"batch {number}: {exc}")
            continue

        parsed = _parse_response(reply.text, [entry[0] for entry in batch])
        if parsed:
            verdicts.update(parsed)
        else:
            failures.append(f"batch {number}: could not parse response")
        finished += 1
        if show_progress:
            _print_progress(number, batch_count, batch[-1][0])

    if show_progress:
        # Terminate the overwrite-style progress line.
        sys.stderr.write("\n")
        sys.stderr.flush()

    merged, reclassified, dropped = _apply_classifications(
        detected,
        verdicts,
        allow_project_promotions=allow_project_promotions,
    )

    return RefineResult(
        merged=merged,
        reclassified=reclassified,
        dropped=dropped,
        errors=failures,
        batches_completed=finished,
        batches_total=batch_count,
        cancelled=interrupted,
    )
|
||||||
|
|
||||||
|
|
||||||
|
def collect_corpus_text(
    project_dir: str,
    max_files: int = 30,
    max_bytes_per_file: int = 20_000,
) -> str:
    """Collect prose from ``project_dir`` to serve as LLM context.

    Walks the tree (pruned by ``_walk_prose``), keeps files whose lower-cased
    suffix is in ``PROSE_EXTENSIONS``, and reads the ``max_files`` most
    recently modified ones. Files are opened as UTF-8 text with undecodable
    bytes replaced, so the ``max_bytes_per_file`` cap is applied by
    ``read()`` in characters. Files that cannot be stat'ed or opened are
    skipped.

    Returns the chunks joined by newlines, or ``""`` when ``project_dir`` is
    not a directory.
    """
    from pathlib import Path

    from mempalace.entity_detector import PROSE_EXTENSIONS, SKIP_DIRS

    base = Path(project_dir).expanduser().resolve()
    if not base.is_dir():
        return ""

    dated: list[tuple[float, Path]] = []
    for folder, _subdirs, filenames in _walk_prose(base, SKIP_DIRS):
        for filename in filenames:
            path = folder / filename
            if path.suffix.lower() not in PROSE_EXTENSIONS:
                continue
            try:
                modified = path.stat().st_mtime
            except OSError:
                continue
            dated.append((modified, path))

    # Newest first, then keep the head of the list.
    newest_first = sorted(dated, reverse=True)[:max_files]

    texts: list[str] = []
    for _modified, path in newest_first:
        try:
            with open(path, encoding="utf-8", errors="replace") as handle:
                texts.append(handle.read(max_bytes_per_file))
        except OSError:
            continue
    return "\n".join(texts)
|
||||||
|
|
||||||
|
|
||||||
|
def _walk_prose(root, skip_dirs):
|
||||||
|
"""Walk a directory yielding (Path, dirs, files), pruning skip_dirs.
|
||||||
|
|
||||||
|
Inlined from ``project_scanner._walk`` to avoid a private-name import
|
||||||
|
coupling. Functionality is intentionally narrow: prose collection only.
|
||||||
|
"""
|
||||||
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
for dirpath, dirs, files in os.walk(root):
|
||||||
|
dirs[:] = [d for d in dirs if d not in skip_dirs and not d.startswith(".")]
|
||||||
|
yield Path(dirpath), dirs, files
|
||||||
@@ -52,6 +52,7 @@ READABLE_EXTENSIONS = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
SKIP_FILENAMES = {
|
SKIP_FILENAMES = {
|
||||||
|
"entities.json",
|
||||||
"mempalace.yaml",
|
"mempalace.yaml",
|
||||||
"mempalace.yml",
|
"mempalace.yml",
|
||||||
"mempal.yaml",
|
"mempal.yaml",
|
||||||
@@ -471,6 +472,97 @@ def _load_known_entities_raw() -> dict:
|
|||||||
return dict(_ENTITY_REGISTRY_CACHE["raw"])
|
return dict(_ENTITY_REGISTRY_CACHE["raw"])
|
||||||
|
|
||||||
|
|
||||||
|
def add_to_known_entities(entities_by_category: dict) -> str:
    """Merge ``{category: [names]}`` into the known-entities registry file.

    The registry lives at ``_ENTITY_REGISTRY_PATH``; its parent directory is
    created if needed. An unreadable, invalid, or non-object registry is
    treated as empty.

    Per input category (non-list or empty values are ignored):

    * stored as a list — unseen names are appended; existing order is kept;
    * stored as a ``{name: code}`` mapping — unseen names are added as keys
      with ``None`` values, so existing codes are never overwritten;
    * missing or any other shape — replaced by a fresh de-duplicated list.

    Names are compared case-insensitively, the first-seen spelling wins, and
    falsy or empty names are skipped. Categories absent from the input are
    left untouched.

    After writing, the file mode is set to ``0o600`` on a best-effort basis
    and the in-process cache is reset so later calls in the same process
    re-read the file.

    Returns the registry path as a string for logging.
    """
    import json as _json
    from pathlib import Path as _Path

    target = _Path(_ENTITY_REGISTRY_PATH)
    target.parent.mkdir(parents=True, exist_ok=True)

    registry: dict = {}
    if target.exists():
        try:
            on_disk = _json.loads(target.read_text(encoding="utf-8"))
        except (_json.JSONDecodeError, OSError):
            on_disk = None
        if isinstance(on_disk, dict):
            registry = on_disk

    def _fresh_names(incoming, known_lower):
        """Yield usable names not yet in ``known_lower``, recording each one."""
        for raw in incoming:
            if not raw:
                continue
            text = str(raw)
            if not text:
                continue
            folded = text.lower()
            if folded in known_lower:
                continue
            known_lower.add(folded)
            yield text

    for category, names in entities_by_category.items():
        if not isinstance(names, list) or not names:
            continue
        stored = registry.get(category)
        if isinstance(stored, list):
            already = {str(item).lower() for item in stored}
            stored.extend(_fresh_names(names, already))
        elif isinstance(stored, dict):
            already = {str(key).lower() for key in stored}
            for fresh in _fresh_names(names, already):
                stored[fresh] = None
        else:
            # Missing or unrecognized shape — seed as a fresh list, deduped
            registry[category] = list(_fresh_names(names, set()))

    payload = _json.dumps(registry, indent=2, ensure_ascii=False)
    target.write_text(payload, encoding="utf-8")
    try:
        target.chmod(0o600)
    except (OSError, NotImplementedError):
        pass

    # Invalidate in-process cache so later calls in the same run see the write.
    _ENTITY_REGISTRY_CACHE["mtime"] = None
    _ENTITY_REGISTRY_CACHE["names"] = frozenset()
    _ENTITY_REGISTRY_CACHE["raw"] = {}

    return str(target)
|
||||||
|
|
||||||
|
|
||||||
_HALL_KEYWORDS_CACHE = None
|
_HALL_KEYWORDS_CACHE = None
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -594,6 +594,8 @@ def discover_entities(
|
|||||||
prose_file_cap: int = 10,
|
prose_file_cap: int = 10,
|
||||||
project_cap: int = 15,
|
project_cap: int = 15,
|
||||||
people_cap: int = 15,
|
people_cap: int = 15,
|
||||||
|
llm_provider: object = None,
|
||||||
|
show_progress: bool = True,
|
||||||
) -> dict:
|
) -> dict:
|
||||||
"""Top-level entity discovery: real signals first, prose detection second.
|
"""Top-level entity discovery: real signals first, prose detection second.
|
||||||
|
|
||||||
@@ -604,10 +606,39 @@ def discover_entities(
|
|||||||
1. Package manifests (package.json, pyproject.toml, Cargo.toml, go.mod)
|
1. Package manifests (package.json, pyproject.toml, Cargo.toml, go.mod)
|
||||||
→ canonical project names
|
→ canonical project names
|
||||||
2. Git commit authors → real people with real commit counts
|
2. Git commit authors → real people with real commit counts
|
||||||
3. Regex entity detection on prose files → supplementary names only
|
3. Claude Code conversation dirs (~/.claude/projects/) → per-session
|
||||||
|
project names (pulled from each session's ``cwd`` metadata)
|
||||||
|
4. Regex entity detection on prose files → supplementary names only
|
||||||
mentioned in docs/notes (not code)
|
mentioned in docs/notes (not code)
|
||||||
|
5. Optional LLM refinement pass — reclassifies ambiguous candidates
|
||||||
|
using the caller-supplied provider
|
||||||
|
|
||||||
|
Passing ``llm_provider`` enables phase-2 refinement. The caller is
|
||||||
|
responsible for constructing the provider (``llm_client.get_provider``)
|
||||||
|
and confirming availability. Refinement is blocking-interactive:
|
||||||
|
progress prints to stderr; Ctrl-C returns partial results.
|
||||||
"""
|
"""
|
||||||
projects, people = scan(project_dir)
|
projects, people = scan(project_dir)
|
||||||
|
|
||||||
|
# If the target is a Claude Code conversations root, extract per-project
|
||||||
|
# entries from there too. Same ProjectInfo shape, so dedup logic works.
|
||||||
|
from mempalace.convo_scanner import is_claude_projects_root, scan_claude_projects
|
||||||
|
|
||||||
|
root_path = Path(project_dir).expanduser().resolve()
|
||||||
|
if is_claude_projects_root(root_path):
|
||||||
|
convo_projects = scan_claude_projects(root_path)
|
||||||
|
# Dedup by name against the git-manifest list, preferring entries with
|
||||||
|
# more user_commits as signal strength.
|
||||||
|
by_name: dict[str, ProjectInfo] = {p.name: p for p in projects}
|
||||||
|
for cp in convo_projects:
|
||||||
|
existing = by_name.get(cp.name)
|
||||||
|
if existing is None or cp.user_commits > existing.user_commits:
|
||||||
|
by_name[cp.name] = cp
|
||||||
|
projects = sorted(
|
||||||
|
by_name.values(),
|
||||||
|
key=lambda p: (not p.is_mine, -p.user_commits, -p.total_commits, p.name),
|
||||||
|
)
|
||||||
|
|
||||||
real_signal = to_detected_dict(projects, people, project_cap=project_cap, people_cap=people_cap)
|
real_signal = to_detected_dict(projects, people, project_cap=project_cap, people_cap=people_cap)
|
||||||
|
|
||||||
# Secondary pass: prose-only extraction catches names mentioned in docs
|
# Secondary pass: prose-only extraction catches names mentioned in docs
|
||||||
@@ -621,11 +652,45 @@ def discover_entities(
|
|||||||
else {"people": [], "projects": [], "uncertain": []}
|
else {"people": [], "projects": [], "uncertain": []}
|
||||||
)
|
)
|
||||||
|
|
||||||
# If git/manifests gave us real projects, suppress the regex "uncertain" bucket.
|
# Without LLM refinement, suppress regex "uncertain" noise when real
|
||||||
# That bucket is mostly noise (common words, CamelCase tech terms, etc.) and
|
# manifest/git signal exists. With LLM refinement enabled, keep those
|
||||||
# adding it to the review flow just makes the user do triage we can skip.
|
# candidates so the model can promote real entities or drop common words.
|
||||||
has_real_signal = bool(projects) or bool(people)
|
has_real_signal = bool(projects) or bool(people)
|
||||||
return _merge_detected(real_signal, prose_detected, drop_secondary_uncertain=has_real_signal)
|
merged = _merge_detected(
|
||||||
|
real_signal,
|
||||||
|
prose_detected,
|
||||||
|
drop_secondary_uncertain=has_real_signal and llm_provider is None,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Optional phase 2: LLM refinement.
|
||||||
|
if llm_provider is not None:
|
||||||
|
from mempalace.llm_refine import collect_corpus_text, refine_entities
|
||||||
|
|
||||||
|
corpus = collect_corpus_text(str(project_dir))
|
||||||
|
result = refine_entities(
|
||||||
|
merged,
|
||||||
|
corpus,
|
||||||
|
llm_provider,
|
||||||
|
show_progress=show_progress,
|
||||||
|
allow_project_promotions=not has_real_signal,
|
||||||
|
)
|
||||||
|
if show_progress:
|
||||||
|
status_bits = []
|
||||||
|
if result.cancelled:
|
||||||
|
status_bits.append("cancelled")
|
||||||
|
if result.reclassified:
|
||||||
|
status_bits.append(f"reclassified {result.reclassified}")
|
||||||
|
if result.dropped:
|
||||||
|
status_bits.append(f"dropped {result.dropped}")
|
||||||
|
if result.errors:
|
||||||
|
status_bits.append(f"{len(result.errors)} batch error(s)")
|
||||||
|
if status_bits:
|
||||||
|
import sys as _sys
|
||||||
|
|
||||||
|
print(f" LLM refine: {', '.join(status_bits)}", file=_sys.stderr)
|
||||||
|
merged = result.merged
|
||||||
|
|
||||||
|
return merged
|
||||||
|
|
||||||
|
|
||||||
# ==================== CLI ====================
|
# ==================== CLI ====================
|
||||||
|
|||||||
@@ -0,0 +1,218 @@
|
|||||||
|
"""Tests for mempalace.convo_scanner."""
|
||||||
|
|
||||||
|
import json
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from mempalace.convo_scanner import (
|
||||||
|
_decode_slug_fallback,
|
||||||
|
_extract_cwd_from_session,
|
||||||
|
_resolve_project_name,
|
||||||
|
_safe_mtime,
|
||||||
|
is_claude_projects_root,
|
||||||
|
scan_claude_projects,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ── is_claude_projects_root ─────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_is_claude_projects_root_true(tmp_path):
    """A dash-prefixed child directory holding a ``.jsonl`` file qualifies the root."""
    encoded_dir = tmp_path / "-home-user-dev-foo"
    encoded_dir.mkdir()
    session_file = encoded_dir / "abc.jsonl"
    session_file.write_text("{}\n")
    assert is_claude_projects_root(tmp_path)
|
||||||
|
|
||||||
|
|
||||||
|
def test_is_claude_projects_root_false_no_dash_prefix(tmp_path):
    """A child directory without the leading dash does not qualify the root."""
    plain_dir = tmp_path / "normal-folder"
    plain_dir.mkdir()
    session_file = plain_dir / "abc.jsonl"
    session_file.write_text("{}\n")
    assert not is_claude_projects_root(tmp_path)
|
||||||
|
|
||||||
|
|
||||||
|
def test_is_claude_projects_root_false_no_jsonl(tmp_path):
    """A dash-prefixed child directory with no ``.jsonl`` file does not qualify."""
    encoded_dir = tmp_path / "-home-user-foo"
    encoded_dir.mkdir()
    stray_file = encoded_dir / "other.txt"
    stray_file.write_text("hello")
    assert not is_claude_projects_root(tmp_path)
|
||||||
|
|
||||||
|
|
||||||
|
def test_is_claude_projects_root_false_empty(tmp_path):
    """An empty directory is not a projects root."""
    looks_like_root = is_claude_projects_root(tmp_path)
    assert not looks_like_root
|
||||||
|
|
||||||
|
|
||||||
|
def test_is_claude_projects_root_false_nonexistent(tmp_path):
    """A path that does not exist is not a projects root."""
    missing_dir = tmp_path / "does-not-exist"
    assert not is_claude_projects_root(missing_dir)
|
||||||
|
|
||||||
|
|
||||||
|
# ── cwd extraction ──────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_extract_cwd_from_session(tmp_path):
    """A record without ``cwd`` followed by one with it yields that ``cwd``."""
    session_file = tmp_path / "session.jsonl"
    records = (
        {"type": "file-history-snapshot", "messageId": "x"},
        {"type": "user", "cwd": "/home/user/dev/myproj", "content": "hi"},
    )
    session_file.write_text("".join(json.dumps(record) + "\n" for record in records))
    assert _extract_cwd_from_session(session_file) == "/home/user/dev/myproj"
|
||||||
|
|
||||||
|
|
||||||
|
def test_extract_cwd_from_session_skips_malformed(tmp_path):
    """An invalid JSON line before a valid record does not block extraction."""
    session_file = tmp_path / "session.jsonl"
    valid_record = json.dumps({"type": "user", "cwd": "/home/user/dev/good"})
    session_file.write_text("{not valid json\n" + valid_record + "\n")
    assert _extract_cwd_from_session(session_file) == "/home/user/dev/good"
|
||||||
|
|
||||||
|
|
||||||
|
def test_extract_cwd_from_session_none_if_absent(tmp_path):
    """A session whose only record has no ``cwd`` yields ``None``."""
    session_file = tmp_path / "session.jsonl"
    record = json.dumps({"type": "x", "messageId": "y"})
    session_file.write_text(record + "\n")
    assert _extract_cwd_from_session(session_file) is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_extract_cwd_from_session_none_if_file_missing(tmp_path):
    """A session path that does not exist yields ``None``."""
    missing_file = tmp_path / "missing.jsonl"
    assert _extract_cwd_from_session(missing_file) is None
|
||||||
|
|
||||||
|
|
||||||
|
# ── slug fallback ───────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_decode_slug_fallback_last_segment():
    """The last dash-separated segment of the slug is used as the name."""
    decoded = _decode_slug_fallback("-home-user-dev-foo")
    assert decoded == "foo"
|
||||||
|
|
||||||
|
|
||||||
|
def test_decode_slug_fallback_double_dash():
    """A double dash before the final segment still decodes to that segment."""
    decoded = _decode_slug_fallback("-home-user--bentokit")
    assert decoded == "bentokit"
|
||||||
|
|
||||||
|
|
||||||
|
def test_decode_slug_fallback_empty():
    """An empty slug decodes to an empty string."""
    decoded = _decode_slug_fallback("")
    assert decoded == ""
|
||||||
|
|
||||||
|
|
||||||
|
def test_decode_slug_fallback_only_dashes():
    """A slug made only of dashes is returned unchanged."""
    decoded = _decode_slug_fallback("---")
    assert decoded == "---"
|
||||||
|
|
||||||
|
|
||||||
|
# ── safe metadata helpers ───────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_safe_mtime_returns_zero_on_stat_error(tmp_path, monkeypatch):
    """``_safe_mtime`` reports 0.0 when ``stat`` on the file raises ``OSError``."""
    target = tmp_path / "session.jsonl"
    target.write_text("{}\n")
    real_stat = Path.stat

    def fail_stat(self):
        # Only the file under test fails; every other path stats normally.
        if self != target:
            return real_stat(self)
        raise OSError("permission denied")

    monkeypatch.setattr(Path, "stat", fail_stat)
    assert _safe_mtime(target) == 0.0
|
||||||
|
|
||||||
|
|
||||||
|
# ── _resolve_project_name ───────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_resolve_project_name_uses_cwd(tmp_path):
    """The name comes from the session's ``cwd``, not from the directory slug."""
    encoded_dir = tmp_path / "-home-user-dev-coolproj"
    encoded_dir.mkdir()
    record = json.dumps({"type": "user", "cwd": "/home/user/dev/cool-proj-real"})
    (encoded_dir / "a.jsonl").write_text(record + "\n")
    assert _resolve_project_name(encoded_dir) == "cool-proj-real"
|
||||||
|
|
||||||
|
|
||||||
|
def test_resolve_project_name_falls_back_when_no_cwd(tmp_path):
    """Without any ``cwd`` in the session, the slug's last segment is used."""
    encoded_dir = tmp_path / "-home-user-dev-foo"
    encoded_dir.mkdir()
    record = json.dumps({"type": "x"})
    (encoded_dir / "a.jsonl").write_text(record + "\n")
    assert _resolve_project_name(encoded_dir) == "foo"
|
||||||
|
|
||||||
|
|
||||||
|
def test_resolve_project_name_prefers_newer_session(tmp_path):
    """The most recently modified session's ``cwd`` wins.

    Covers the case where the user renamed the project directory between
    sessions.
    """
    import os

    encoded_dir = tmp_path / "-home-user-dev-old"
    encoded_dir.mkdir()

    older = encoded_dir / "old.jsonl"
    older.write_text(json.dumps({"type": "user", "cwd": "/home/user/dev/old"}) + "\n")
    # Push the first session 100 seconds into the past so mtimes are distinguishable.
    backdated = older.stat().st_mtime - 100
    os.utime(older, (backdated, backdated))

    newer = encoded_dir / "new.jsonl"
    newer.write_text(json.dumps({"type": "user", "cwd": "/home/user/dev/new-name"}) + "\n")

    assert _resolve_project_name(encoded_dir) == "new-name"
|
||||||
|
|
||||||
|
|
||||||
|
# ── scan_claude_projects ────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_scan_claude_projects_empty_dir(tmp_path):
    """Scanning an empty directory returns an empty list."""
    found = scan_claude_projects(tmp_path)
    assert found == []
|
||||||
|
|
||||||
|
|
||||||
|
def test_scan_claude_projects_not_a_projects_root(tmp_path):
    """A directory that does not look like ``.claude/projects/`` yields nothing."""
    ordinary_dir = tmp_path / "some-folder"
    ordinary_dir.mkdir()
    (ordinary_dir / "readme.md").write_text("hi")
    assert scan_claude_projects(tmp_path) == []
|
||||||
|
|
||||||
|
|
||||||
|
def test_scan_claude_projects_finds_projects(tmp_path):
    """Each encoded directory becomes a project whose ``user_commits`` is its session count."""
    alpha_line = json.dumps({"type": "user", "cwd": "/home/user/dev/alpha"}) + "\n"
    alpha_dir = tmp_path / "-home-user-dev-alpha"
    alpha_dir.mkdir()
    for session_name in ("a.jsonl", "b.jsonl"):
        (alpha_dir / session_name).write_text(alpha_line)

    beta_dir = tmp_path / "-home-user-dev-beta"
    beta_dir.mkdir()
    (beta_dir / "x.jsonl").write_text(
        json.dumps({"type": "user", "cwd": "/home/user/dev/beta"}) + "\n"
    )

    found = scan_claude_projects(tmp_path)
    found_names = [project.name for project in found]
    assert "alpha" in found_names
    assert "beta" in found_names

    # alpha has 2 sessions, beta has 1
    alpha = next(project for project in found if project.name == "alpha")
    beta = next(project for project in found if project.name == "beta")
    assert alpha.user_commits == 2
    assert beta.user_commits == 1
|
||||||
|
|
||||||
|
|
||||||
|
def test_scan_claude_projects_ignores_dirs_without_jsonl(tmp_path):
    """An encoded directory that holds no ``.jsonl`` sessions is not reported."""
    sessionless_dir = tmp_path / "-home-user-dev-empty"
    sessionless_dir.mkdir()
    (sessionless_dir / "notes.md").write_text("hi")
    assert scan_claude_projects(tmp_path) == []
|
||||||
|
|
||||||
|
|
||||||
|
def test_scan_claude_projects_marks_as_mine(tmp_path):
    """The single discovered project is flagged ``is_mine``."""
    encoded_dir = tmp_path / "-home-user-dev-owned"
    encoded_dir.mkdir()
    record = json.dumps({"type": "user", "cwd": "/home/user/dev/owned"})
    (encoded_dir / "s.jsonl").write_text(record + "\n")

    found = scan_claude_projects(tmp_path)
    assert len(found) == 1
    assert found[0].is_mine is True
|
||||||
|
|
||||||
|
|
||||||
|
def test_scan_claude_projects_dedup_by_name(tmp_path):
    """Two encoded dirs resolving to the same project name collapse to one."""
    first_dir = tmp_path / "-home-user-a-proj"
    first_dir.mkdir()
    first_line = json.dumps({"type": "user", "cwd": "/home/user/a/proj"}) + "\n"
    for session_name in ("s.jsonl", "t.jsonl"):
        (first_dir / session_name).write_text(first_line)

    second_dir = tmp_path / "-home-user-b-proj"
    second_dir.mkdir()
    (second_dir / "u.jsonl").write_text(
        json.dumps({"type": "user", "cwd": "/home/user/b/proj"}) + "\n"
    )

    found = scan_claude_projects(tmp_path)
    # Both resolve to "proj"; the survivor is the one with two sessions.
    assert len(found) == 1
    survivor = found[0]
    assert survivor.name == "proj"
    assert survivor.user_commits == 2
|
||||||
@@ -0,0 +1,208 @@
|
|||||||
|
"""Tests for mempalace.miner.add_to_known_entities.
|
||||||
|
|
||||||
|
Covers the init → miner wire-up: init's confirmed entities merged into
|
||||||
|
``~/.mempalace/known_entities.json`` so the miner's drawer-tagging path
|
||||||
|
recognizes them at mine time.
|
||||||
|
|
||||||
|
Every test redirects the registry path to a tmp_path to avoid touching
|
||||||
|
the real ~/.mempalace/ on the developer's machine.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from mempalace import miner
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def temp_registry(tmp_path, monkeypatch):
    """Point the miner's registry path at a tmp file and reset its cache.

    Returns the tmp registry path; the file itself is not created here.
    """
    registry_file = tmp_path / "known_entities.json"
    monkeypatch.setattr(miner, "_ENTITY_REGISTRY_PATH", str(registry_file))
    miner._ENTITY_REGISTRY_CACHE.update(mtime=None, names=frozenset(), raw={})
    return registry_file
|
||||||
|
|
||||||
|
|
||||||
|
# ── fresh-file cases ────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_creates_registry_when_absent(temp_registry):
    """A first write creates the registry file holding the supplied names."""
    assert not temp_registry.exists()

    entities = {"people": ["Alice", "Bob"], "projects": ["foo"]}
    miner.add_to_known_entities(entities)

    assert temp_registry.exists()
    written = json.loads(temp_registry.read_text())
    assert sorted(written["people"]) == ["Alice", "Bob"]
    assert written["projects"] == ["foo"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_returns_registry_path(temp_registry):
    """The call returns the registry path, as a string."""
    returned = miner.add_to_known_entities({"people": ["Alice"]})
    expected = str(temp_registry)
    assert returned == expected
|
||||||
|
|
||||||
|
|
||||||
|
def test_empty_input_still_creates_file(temp_registry):
    """An empty merge adds no entries, whether or not a file gets written."""
    miner.add_to_known_entities({})
    if not temp_registry.exists():
        # Skipping the write altogether is tolerated for a no-op merge.
        return
    written = json.loads(temp_registry.read_text())
    assert written == {} or all(not names for names in written.values())
|
||||||
|
|
||||||
|
|
||||||
|
def test_skips_empty_name_strings(temp_registry):
    """Empty-string and ``None`` names are left out of the written list."""
    candidates = ["Alice", "", None]
    miner.add_to_known_entities({"people": candidates})
    written = json.loads(temp_registry.read_text())
    assert written["people"] == ["Alice"]
|
||||||
|
|
||||||
|
|
||||||
|
# ── union / dedup cases ────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_unions_with_existing_list_category(temp_registry):
    """New names are appended to an existing list; repeats are not re-added."""
    seed = {"people": ["Alice", "Bob"]}
    temp_registry.write_text(json.dumps(seed))

    miner.add_to_known_entities({"people": ["Bob", "Carol"]})

    merged = json.loads(temp_registry.read_text())
    # Existing order is kept, Bob appears once, Carol lands at the end.
    assert merged["people"] == ["Alice", "Bob", "Carol"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_case_insensitive_dedup_preserves_first_seen_variant(temp_registry):
    """Case variants of a stored name add nothing; the stored spelling stays."""
    temp_registry.write_text(json.dumps({"people": ["Alice"]}))

    incoming = ["alice", "ALICE", "Bob"]
    miner.add_to_known_entities({"people": incoming})

    merged = json.loads(temp_registry.read_text())
    # Only Bob is new; "alice" and "ALICE" collapse into the existing "Alice".
    assert merged["people"] == ["Alice", "Bob"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_preserves_untouched_categories(temp_registry):
    """Categories absent from the input keep their stored contents."""
    seed = {"people": ["Alice"], "places": ["Paris", "Tokyo"]}
    temp_registry.write_text(json.dumps(seed))

    miner.add_to_known_entities({"people": ["Bob"]})

    merged = json.loads(temp_registry.read_text())
    assert merged["places"] == ["Paris", "Tokyo"]
    assert merged["people"] == ["Alice", "Bob"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_adds_new_categories(temp_registry):
    """A category missing from the stored registry is created alongside the old ones."""
    temp_registry.write_text(json.dumps({"people": ["Alice"]}))

    miner.add_to_known_entities({"projects": ["foo", "bar"]})

    merged = json.loads(temp_registry.read_text())
    assert merged["people"] == ["Alice"]
    assert merged["projects"] == ["foo", "bar"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_dedupes_within_input(temp_registry):
    """Repeats inside a single input list collapse to the first spelling."""
    repeated = ["Alice", "alice", "Alice"]
    miner.add_to_known_entities({"people": repeated})
    written = json.loads(temp_registry.read_text())
    assert written["people"] == ["Alice"]
|
||||||
|
|
||||||
|
|
||||||
|
# ── dict-format existing registry ──────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_dict_format_existing_category_gets_new_keys(temp_registry):
    """A ``{name: code}`` category gains new names without losing stored codes.

    Names already present keep their code; names added by the call are
    stored with a ``None`` code.
    """
    seed = {"people": {"Alice": "ALC", "Bob": "BOB"}}
    temp_registry.write_text(json.dumps(seed))

    miner.add_to_known_entities({"people": ["Alice", "Carol"]})

    people = json.loads(temp_registry.read_text())["people"]
    assert people["Alice"] == "ALC"
    assert people["Bob"] == "BOB"
    assert "Carol" in people
    assert people["Carol"] is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_dict_format_dedupes_case_insensitively_and_stringifies_new_names(temp_registry):
    """Dict categories ignore case variants and store non-string names as strings."""
    temp_registry.write_text(json.dumps({"people": {"Alice": "ALC"}}))

    miner.add_to_known_entities({"people": ["alice", 123]})

    merged = json.loads(temp_registry.read_text())
    expected_people = {"Alice": "ALC", "123": None}
    assert merged["people"] == expected_people
|
||||||
|
|
||||||
|
|
||||||
|
# ── error tolerance ───────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_malformed_existing_registry_starts_fresh(temp_registry):
    """Unparseable registry content is replaced by just the new entries."""
    temp_registry.write_text("{ not valid json")

    miner.add_to_known_entities({"people": ["Alice"]})

    rewritten = json.loads(temp_registry.read_text())
    assert rewritten == {"people": ["Alice"]}
|
||||||
|
|
||||||
|
|
||||||
|
def test_non_dict_existing_registry_starts_fresh(temp_registry):
    """A registry holding a JSON array is replaced by just the new entries."""
    unexpected = ["unexpected", "array"]
    temp_registry.write_text(json.dumps(unexpected))

    miner.add_to_known_entities({"people": ["Alice"]})

    rewritten = json.loads(temp_registry.read_text())
    assert rewritten == {"people": ["Alice"]}
|
||||||
|
|
||||||
|
|
||||||
|
def test_non_list_input_category_ignored(temp_registry):
    """A non-list category value does not disturb the valid categories."""
    payload = {"people": ["Alice"], "weird": "not a list"}
    miner.add_to_known_entities(payload)

    written = json.loads(temp_registry.read_text())
    # Either the odd category is skipped or it is stored unchanged.
    assert "weird" not in written or written.get("weird") == "not a list"
    assert written["people"] == ["Alice"]
|
||||||
|
|
||||||
|
|
||||||
|
# ── cache invalidation ───────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_cache_invalidated_so_subsequent_load_sees_write(temp_registry):
    """A write is visible to the very next load within the same process.

    The loader is called first against the empty state so that a cached
    result exists before the write happens.
    """
    miner._load_known_entities()
    assert miner._load_known_entities() == frozenset()

    miner.add_to_known_entities({"people": ["Alice", "Bob"], "projects": ["foo"]})

    names = miner._load_known_entities()
    for expected in ("Alice", "Bob", "foo"):
        assert expected in names
|
||||||
|
|
||||||
|
|
||||||
|
def test_raw_view_reflects_write(temp_registry):
    """The raw registry view exposes the category that was just written."""
    miner.add_to_known_entities({"people": ["Alice"]})
    raw_registry = miner._load_known_entities_raw()
    assert raw_registry.get("people") == ["Alice"]
|
||||||
|
|
||||||
|
|
||||||
|
# ── Unicode round-trip ────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_unicode_names_written_literally_not_escaped(temp_registry):
    """Non-ASCII names appear verbatim in the file and survive a JSON round trip."""
    names = ["Gergő Móricz", "Arturo Domínguez"]
    miner.add_to_known_entities({"people": names})

    on_disk = temp_registry.read_text(encoding="utf-8")
    # The literal characters must be present, not \uXXXX escapes.
    assert "Gergő" in on_disk
    assert "Móricz" in on_disk

    parsed = json.loads(on_disk)
    assert "Gergő Móricz" in parsed["people"]
|
||||||
|
|
||||||
|
|
||||||
|
# ── end-to-end: does the write actually help _extract_entities_for_metadata? ──
|
||||||
|
|
||||||
|
|
||||||
|
def test_populated_registry_improves_miner_recall(temp_registry):
    """Registered names show up in the miner's entity metadata for matching text.

    The metadata string is split on ``;`` and every registered person and
    project mentioned in the sample must be among the resulting tags.
    """
    registered = {
        "people": ["Julia Grib", "Kevin Heifner"],
        "projects": ["hyperion-history", "mempalace"],
    }
    miner.add_to_known_entities(registered)

    sample = (
        "Met with Julia Grib yesterday about the mempalace release. "
        "Kevin Heifner pushed the hyperion-history fix."
    )
    metadata = miner._extract_entities_for_metadata(sample)
    if metadata:
        tagged = set(metadata.split(";"))
    else:
        tagged = set()

    for expected in ("Julia Grib", "Kevin Heifner", "hyperion-history", "mempalace"):
        assert expected in tagged, f"expected '{expected}' in metadata {tagged!r}"
|
||||||
@@ -0,0 +1,327 @@
|
|||||||
|
"""Tests for mempalace.llm_client.
|
||||||
|
|
||||||
|
HTTP is mocked throughout — these tests do not require a running Ollama
|
||||||
|
or network access. Live-provider smoke tests live outside the unit-test
|
||||||
|
suite.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
from unittest.mock import patch, MagicMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from mempalace.llm_client import (
|
||||||
|
AnthropicProvider,
|
||||||
|
LLMError,
|
||||||
|
OllamaProvider,
|
||||||
|
OpenAICompatProvider,
|
||||||
|
_http_post_json,
|
||||||
|
get_provider,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ── factory ─────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_provider_ollama():
    """The "ollama" name yields an OllamaProvider using the default endpoint."""
    provider = get_provider("ollama", "gemma4:e4b")
    assert isinstance(provider, OllamaProvider)
    assert provider.model == "gemma4:e4b"
    assert provider.endpoint == OllamaProvider.DEFAULT_ENDPOINT
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_provider_openai_compat():
    """The "openai-compat" name yields an OpenAICompatProvider."""
    provider = get_provider("openai-compat", "foo", endpoint="http://localhost:1234")
    assert isinstance(provider, OpenAICompatProvider)
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_provider_anthropic():
    """The "anthropic" name yields an AnthropicProvider carrying the given key."""
    provider = get_provider("anthropic", "claude-haiku", api_key="sk-xxx")
    assert isinstance(provider, AnthropicProvider)
    assert provider.api_key == "sk-xxx"
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_provider_unknown_raises():
    """An unrecognized provider name raises ``LLMError``."""
    expect_unknown = pytest.raises(LLMError, match="Unknown provider")
    with expect_unknown:
        get_provider("nonsense", "x")
|
||||||
|
|
||||||
|
|
||||||
|
# ── _http_post_json ─────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_http_post_json_success():
    """A JSON response body is decoded and returned as a Python object."""
    response = MagicMock()
    # urlopen's result is used as a context manager, so wire up enter/exit.
    response.__enter__.return_value = response
    response.__exit__.return_value = False
    response.read.return_value = b'{"ok": true}'

    with patch("mempalace.llm_client.urlopen", return_value=response):
        parsed = _http_post_json("http://x/y", {"a": 1}, {}, timeout=5)

    assert parsed == {"ok": True}
|
||||||
|
|
||||||
|
|
||||||
|
def test_http_post_json_http_error_wraps_as_llm_error():
    """An ``HTTPError`` from ``urlopen`` surfaces as ``LLMError`` naming the status."""
    from urllib.error import HTTPError
    import io

    error_body = io.BytesIO(b"model missing")
    http_error = HTTPError("http://x", 404, "Not Found", {}, error_body)
    patched = patch("mempalace.llm_client.urlopen", side_effect=http_error)
    with patched, pytest.raises(LLMError, match="HTTP 404"):
        _http_post_json("http://x", {}, {}, timeout=5)
|
||||||
|
|
||||||
|
|
||||||
|
def test_http_post_json_url_error_wraps_as_llm_error():
    """A ``URLError`` from ``urlopen`` surfaces as an unreachable-host ``LLMError``."""
    from urllib.error import URLError

    patched = patch("mempalace.llm_client.urlopen", side_effect=URLError("conn refused"))
    with patched, pytest.raises(LLMError, match="Cannot reach"):
        _http_post_json("http://x", {}, {}, timeout=5)
|
||||||
|
|
||||||
|
|
||||||
|
def test_http_post_json_malformed_response():
    """A response body that is not JSON raises ``LLMError``."""
    response = MagicMock()
    response.__enter__.return_value = response
    response.__exit__.return_value = False
    response.read.return_value = b"not json"

    patched = patch("mempalace.llm_client.urlopen", return_value=response)
    with patched, pytest.raises(LLMError, match="Malformed"):
        _http_post_json("http://x", {}, {}, timeout=5)
|
||||||
|
|
||||||
|
|
||||||
|
# ── OllamaProvider ──────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def _mock_ollama_chat_response(content: str):
    """Build a context-manager mock whose ``read()`` yields an Ollama chat payload.

    The payload is ``{"message": {"content": content}}`` encoded as JSON bytes.
    """
    payload = {"message": {"content": content}}
    response = MagicMock()
    response.__enter__.return_value = response
    response.__exit__.return_value = False
    response.read.return_value = json.dumps(payload).encode()
    return response
|
||||||
|
|
||||||
|
|
||||||
|
def test_ollama_check_available_finds_model():
    """A model listed by the server under its exact name is reported available."""
    listing = {"models": [{"name": "gemma4:e4b"}, {"name": "other:latest"}]}
    response = MagicMock()
    response.__enter__.return_value = response
    response.__exit__.return_value = False
    response.read.return_value = json.dumps(listing).encode()

    with patch("mempalace.llm_client.urlopen", return_value=response):
        provider = OllamaProvider(model="gemma4:e4b")
        available, message = provider.check_available()

    assert available
    assert message == "ok"
|
||||||
|
|
||||||
|
|
||||||
|
def test_ollama_check_available_accepts_latest_suffix():
    """A bare model name matches the server's ``<name>:latest`` entry."""
    listing = {"models": [{"name": "mymodel:latest"}]}
    response = MagicMock()
    response.__enter__.return_value = response
    response.__exit__.return_value = False
    response.read.return_value = json.dumps(listing).encode()

    with patch("mempalace.llm_client.urlopen", return_value=response):
        provider = OllamaProvider(model="mymodel")
        available, _message = provider.check_available()

    assert available
|
||||||
|
|
||||||
|
|
||||||
|
def test_ollama_check_available_missing_model():
    """A model absent from the listing is unavailable and the message suggests a pull."""
    listing = {"models": [{"name": "other:latest"}]}
    response = MagicMock()
    response.__enter__.return_value = response
    response.__exit__.return_value = False
    response.read.return_value = json.dumps(listing).encode()

    with patch("mempalace.llm_client.urlopen", return_value=response):
        provider = OllamaProvider(model="absent")
        available, message = provider.check_available()

    assert not available
    assert "ollama pull absent" in message
|
||||||
|
|
||||||
|
|
||||||
|
def test_ollama_check_available_unreachable():
    """A connection failure is reported as unavailable rather than raised."""
    from urllib.error import URLError

    refused = URLError("refused")
    with patch("mempalace.llm_client.urlopen", side_effect=refused):
        provider = OllamaProvider(model="gemma4:e4b")
        available, message = provider.check_available()

    assert not available
    assert "Cannot reach Ollama" in message
|
||||||
|
|
||||||
|
|
||||||
|
def test_ollama_classify_sends_json_format():
    """JSON mode posts ``format: json`` to ``/api/chat`` and returns the content text."""
    seen = {}

    def fake_urlopen(req, *, timeout):
        # Record what the provider sent so the request can be inspected below.
        seen["url"] = req.full_url
        seen["body"] = json.loads(req.data.decode())
        return _mock_ollama_chat_response('{"classifications": []}')

    with patch("mempalace.llm_client.urlopen", side_effect=fake_urlopen):
        provider = OllamaProvider(model="gemma4:e4b")
        reply = provider.classify("sys", "user", json_mode=True)

    sent_body = seen["body"]
    assert sent_body["format"] == "json"
    assert sent_body["model"] == "gemma4:e4b"
    assert seen["url"].endswith("/api/chat")
    assert reply.provider == "ollama"
    assert reply.text == '{"classifications": []}'
|
||||||
|
|
||||||
|
|
||||||
|
def test_ollama_classify_empty_content_raises():
    """An empty message content from the server raises ``LLMError``."""
    empty_reply = _mock_ollama_chat_response("")
    with patch("mempalace.llm_client.urlopen", return_value=empty_reply):
        provider = OllamaProvider(model="x")
        with pytest.raises(LLMError, match="Empty response"):
            provider.classify("s", "u")
|
||||||
|
|
||||||
|
|
||||||
|
# ── OpenAICompatProvider ────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def _mock_openai_response(content: str):
    """Build a context-manager mock whose ``read()`` yields a chat-completions payload.

    The payload is ``{"choices": [{"message": {"content": content}}]}``
    encoded as JSON bytes.
    """
    body = json.dumps({"choices": [{"message": {"content": content}}]})
    response = MagicMock()
    response.__enter__.return_value = response
    response.__exit__.return_value = False
    response.read.return_value = body.encode()
    return response
|
||||||
|
|
||||||
|
|
||||||
|
def test_openai_compat_resolves_url_with_v1_suffix():
    """An endpoint without ``/v1`` is expanded to ``/v1/chat/completions``."""
    seen = {}

    def fake_urlopen(req, *, timeout):
        seen["url"] = req.full_url
        return _mock_openai_response('{"ok": true}')

    with patch("mempalace.llm_client.urlopen", side_effect=fake_urlopen):
        provider = OpenAICompatProvider(model="x", endpoint="http://h:1234")
        provider.classify("s", "u")

    assert seen["url"] == "http://h:1234/v1/chat/completions"
|
||||||
|
|
||||||
|
|
||||||
|
def test_openai_compat_resolves_url_with_existing_v1():
    """An endpoint already ending in ``/v1`` does not get a second ``/v1``."""
    seen = {}

    def fake_urlopen(req, *, timeout):
        seen["url"] = req.full_url
        return _mock_openai_response('{"ok": true}')

    with patch("mempalace.llm_client.urlopen", side_effect=fake_urlopen):
        provider = OpenAICompatProvider(model="x", endpoint="http://h:1234/v1")
        provider.classify("s", "u")

    assert seen["url"] == "http://h:1234/v1/chat/completions"
|
||||||
|
|
||||||
|
|
||||||
|
def test_openai_compat_requires_endpoint():
    """Classifying without an endpoint raises ``LLMError``."""
    provider = OpenAICompatProvider(model="x")
    expect_missing = pytest.raises(LLMError, match="requires --llm-endpoint")
    with expect_missing:
        provider.classify("s", "u")
|
||||||
|
|
||||||
|
|
||||||
|
def test_openai_compat_sends_authorization_when_key_present():
    """A configured API key is sent as a ``Bearer`` Authorization header."""
    seen = {}

    def fake_urlopen(req, *, timeout):
        seen["auth"] = req.get_header("Authorization")
        return _mock_openai_response('{"ok": true}')

    with patch("mempalace.llm_client.urlopen", side_effect=fake_urlopen):
        provider = OpenAICompatProvider(model="x", endpoint="http://h", api_key="sk-aaa")
        provider.classify("s", "u")

    assert seen["auth"] == "Bearer sk-aaa"
|
||||||
|
|
||||||
|
|
||||||
|
def test_openai_compat_uses_env_var_fallback(monkeypatch):
    """With no explicit key, the provider picks up ``OPENAI_API_KEY``."""
    env_key = "sk-from-env"
    monkeypatch.setenv("OPENAI_API_KEY", env_key)
    provider = OpenAICompatProvider(model="x", endpoint="http://h")
    assert provider.api_key == env_key
|
||||||
|
|
||||||
|
|
||||||
|
def test_openai_compat_sends_response_format_json():
    """JSON mode adds ``response_format: {"type": "json_object"}`` to the request."""
    seen = {}

    def fake_urlopen(req, *, timeout):
        seen["body"] = json.loads(req.data.decode())
        return _mock_openai_response('{"ok": true}')

    with patch("mempalace.llm_client.urlopen", side_effect=fake_urlopen):
        provider = OpenAICompatProvider(model="x", endpoint="http://h")
        provider.classify("s", "u", json_mode=True)

    assert seen["body"]["response_format"] == {"type": "json_object"}
|
||||||
|
|
||||||
|
|
||||||
|
def test_openai_compat_unexpected_shape_raises():
    """A JSON reply lacking the ``choices`` structure raises ``LLMError``."""
    response = MagicMock()
    response.__enter__.return_value = response
    response.__exit__.return_value = False
    response.read.return_value = b'{"nothing": "here"}'

    with patch("mempalace.llm_client.urlopen", return_value=response):
        provider = OpenAICompatProvider(model="x", endpoint="http://h")
        with pytest.raises(LLMError, match="Unexpected response shape"):
            provider.classify("s", "u")
|
||||||
|
|
||||||
|
|
||||||
|
# ── AnthropicProvider ───────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def _mock_anthropic_response(text: str):
    """Build a context-manager mock whose ``read()`` yields a one-block Anthropic payload.

    The payload is ``{"content": [{"type": "text", "text": text}]}`` encoded
    as JSON bytes.
    """
    text_block = {"type": "text", "text": text}
    response = MagicMock()
    response.__enter__.return_value = response
    response.__exit__.return_value = False
    response.read.return_value = json.dumps({"content": [text_block]}).encode()
    return response
|
||||||
|
|
||||||
|
|
||||||
|
def test_anthropic_requires_api_key(monkeypatch):
    """Without a key in args or environment, availability fails and names the variable."""
    monkeypatch.delenv("ANTHROPIC_API_KEY", raising=False)
    provider = AnthropicProvider(model="claude-haiku")
    available, message = provider.check_available()
    assert not available
    assert "ANTHROPIC_API_KEY" in message
|
||||||
|
|
||||||
|
|
||||||
|
def test_anthropic_reads_env_key(monkeypatch):
    """A key from ``ANTHROPIC_API_KEY`` is adopted and makes the provider available."""
    env_key = "sk-ant-env"
    monkeypatch.setenv("ANTHROPIC_API_KEY", env_key)
    provider = AnthropicProvider(model="claude-haiku")
    assert provider.api_key == env_key
    available, _message = provider.check_available()
    assert available
|
||||||
|
|
||||||
|
|
||||||
|
def test_anthropic_classify_sends_version_and_key():
    """The request carries the API key and version headers; the reply text is returned."""
    seen = {}

    def fake_urlopen(req, *, timeout):
        # Header names use the capitalization that urllib's Request stores.
        seen["api_key"] = req.get_header("X-api-key")
        seen["version"] = req.get_header("Anthropic-version")
        return _mock_anthropic_response('{"ok": true}')

    with patch("mempalace.llm_client.urlopen", side_effect=fake_urlopen):
        provider = AnthropicProvider(model="claude-haiku", api_key="sk-ant-abc")
        reply = provider.classify("s", "u")

    assert seen["api_key"] == "sk-ant-abc"
    assert seen["version"] == AnthropicProvider.API_VERSION
    assert reply.text == '{"ok": true}'
|
||||||
|
|
||||||
|
|
||||||
|
def test_anthropic_joins_multiple_text_blocks():
    """Several text blocks in one reply are concatenated in order."""
    blocks = [
        {"type": "text", "text": "part one. "},
        {"type": "text", "text": "part two."},
    ]
    response = MagicMock()
    response.__enter__.return_value = response
    response.__exit__.return_value = False
    response.read.return_value = json.dumps({"content": blocks}).encode()

    with patch("mempalace.llm_client.urlopen", return_value=response):
        provider = AnthropicProvider(model="claude-haiku", api_key="sk-ant")
        reply = provider.classify("s", "u")

    assert reply.text == "part one. part two."
|
||||||
|
|
||||||
|
|
||||||
|
def test_anthropic_no_key_raises_on_classify(monkeypatch):
    """Classifying with no key available raises ``LLMError``."""
    monkeypatch.delenv("ANTHROPIC_API_KEY", raising=False)
    provider = AnthropicProvider(model="claude-haiku")
    expect_missing = pytest.raises(LLMError, match="requires ANTHROPIC_API_KEY")
    with expect_missing:
        provider.classify("s", "u")
|
||||||
@@ -0,0 +1,631 @@
|
|||||||
|
"""Tests for mempalace.llm_refine.
|
||||||
|
|
||||||
|
Uses a fake provider for deterministic, offline tests. No network.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
|
||||||
|
from mempalace.llm_client import LLMError, LLMResponse
|
||||||
|
from mempalace.llm_refine import (
|
||||||
|
_apply_classifications,
|
||||||
|
_build_user_prompt,
|
||||||
|
_collect_contexts,
|
||||||
|
_extract_json_candidates,
|
||||||
|
_is_authoritative_person,
|
||||||
|
_is_authoritative_project,
|
||||||
|
_parse_response,
|
||||||
|
collect_corpus_text,
|
||||||
|
refine_entities,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ── fake provider ───────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class FakeProvider:
    """Offline stand-in for an LLM provider, used by the refine tests.

    ``classify`` counts its calls and then does one of three things: raises
    ``KeyboardInterrupt`` when the call number equals ``interrupt_on_call``,
    raises ``should_raise`` when one is set, or returns ``response_text``
    wrapped in an ``LLMResponse``. ``check_available`` always reports success.
    """

    # Text handed back as ``LLMResponse.text`` by every successful call.
    response_text: str = ""
    # Exception instance raised from ``classify``; ``None`` disables raising.
    # Quoted so the union syntax is never evaluated on older interpreters.
    should_raise: "Exception | None" = None
    # Number of ``classify`` calls made so far.
    call_count: int = 0
    # 1-based call number that raises ``KeyboardInterrupt``; the default -1
    # never equals the counter of a freshly built provider.
    interrupt_on_call: int = -1

    def classify(self, system, user, json_mode=True):
        """Record the call, then interrupt, raise, or return the canned text."""
        self.call_count += 1
        # The interrupt check comes first, so it wins over ``should_raise``.
        if self.call_count == self.interrupt_on_call:
            raise KeyboardInterrupt()
        if self.should_raise is not None:
            raise self.should_raise
        return LLMResponse(text=self.response_text, model="fake", provider="fake", raw={})

    def check_available(self):
        """Report the provider as available: ``(True, "ok")``."""
        return True, "ok"
|
||||||
|
|
||||||
|
|
||||||
|
# ── _collect_contexts ───────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_collect_contexts_finds_matches():
    """Only lines mentioning the name are returned, capped at ``max_lines``."""
    corpus = ["Something about Alice", "Bob said hello", "Alice was here", "Alice walked by"]
    matches = _collect_contexts(corpus, "Alice", max_lines=2)
    assert len(matches) == 2
    assert all("alice" in line.lower() for line in matches)
|
||||||
|
|
||||||
|
|
||||||
|
def test_collect_contexts_case_insensitive():
    """A lowercase mention matches a capitalized name."""
    corpus = ["lowercase alice mention"]
    matches = _collect_contexts(corpus, "Alice")
    assert matches == ["lowercase alice mention"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_collect_contexts_uses_token_boundaries():
    """The name must appear as its own token, not inside a longer word."""
    corpus = ["forgot should not match", "Go is a language.", "go-v1 shipped."]
    matches = _collect_contexts(corpus, "Go", max_lines=5)
    # "forgot" contains "go" but is not a standalone token.
    assert matches == ["Go is a language.", "go-v1 shipped."]
|
||||||
|
|
||||||
|
|
||||||
|
def test_collect_contexts_dedupes_identical_lines():
    """Identical matching lines are returned once."""
    corpus = ["Alice", "Alice", "Alice was here"]
    matches = _collect_contexts(corpus, "Alice", max_lines=5)
    # Three matching lines, but only two distinct ones.
    assert len(matches) == 2
|
||||||
|
|
||||||
|
|
||||||
|
def test_collect_contexts_truncates_long_lines():
    """A very long matching line comes back at no more than 240 characters."""
    long_line = "Alice " + ("x" * 1000)
    matches = _collect_contexts([long_line], "Alice")
    assert len(matches[0]) <= 240
|
||||||
|
|
||||||
|
|
||||||
|
def test_collect_contexts_no_matches():
    """No mention of the name yields an empty list."""
    matches = _collect_contexts(["nothing here"], "Alice")
    assert matches == []
|
||||||
|
|
||||||
|
|
||||||
|
# ── _build_user_prompt ──────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_build_user_prompt_numbers_and_includes_contexts():
    """Entries are numbered from 1, with their context or a no-context placeholder."""
    entries = [
        ("Alice", "uncertain", ["Alice said hi"]),
        ("Bob", "project", []),
    ]
    prompt = _build_user_prompt(entries)
    for fragment in ("1. Alice", "2. Bob", "Alice said hi", "(no context available)"):
        assert fragment in prompt
|
||||||
|
|
||||||
|
|
||||||
|
# ── _parse_response ─────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_parse_response_canonicalizes_label():
    """A lowercase label is upper-cased and paired with its reason."""
    reply = '{"classifications": [{"name": "Alice", "label": "person", "reason": "x"}]}'
    parsed = _parse_response(reply, ["Alice"])
    assert parsed["Alice"] == ("PERSON", "x")
|
||||||
|
|
||||||
|
|
||||||
|
def test_parse_response_accepts_type_alias():
    """A ``type`` key is accepted in place of ``label``."""
    reply = '{"classifications": [{"name": "Bob", "type": "PROJECT"}]}'
    parsed = _parse_response(reply, ["Bob"])
    label = parsed["Bob"][0]
    assert label == "PROJECT"
|
||||||
|
|
||||||
|
|
||||||
|
def test_parse_response_maps_unknown_label_to_ambiguous():
    """An unrecognized label is reported as ``AMBIGUOUS``."""
    reply = '{"classifications": [{"name": "X", "label": "WEIRD"}]}'
    parsed = _parse_response(reply, ["X"])
    label = parsed["X"][0]
    assert label == "AMBIGUOUS"
|
||||||
|
|
||||||
|
|
||||||
|
def test_parse_response_restores_canonical_casing():
    """A name returned in different case is keyed by the expected spelling."""
    reply = '{"classifications": [{"name": "mempalace", "label": "PROJECT"}]}'
    parsed = _parse_response(reply, ["MemPalace"])
    assert "MemPalace" in parsed
    assert parsed["MemPalace"][0] == "PROJECT"
|
||||||
|
|
||||||
|
|
||||||
|
def test_parse_response_strips_code_fences():
    """JSON wrapped in a Markdown code fence is still parsed."""
    reply = '```json\n{"classifications": [{"name": "X", "label": "TOPIC"}]}\n```'
    parsed = _parse_response(reply, ["X"])
    label = parsed["X"][0]
    assert label == "TOPIC"
|
||||||
|
|
||||||
|
|
||||||
|
def test_parse_response_extracts_json_after_prose():
    """JSON preceded by conversational text is still found and parsed."""
    reply = 'Sure, here is the JSON: {"classifications": [{"name": "X", "label": "TOPIC"}]}'
    parsed = _parse_response(reply, ["X"])
    label = parsed["X"][0]
    assert label == "TOPIC"
|
||||||
|
|
||||||
|
|
||||||
|
def test_parse_response_extracts_fenced_json_after_prose():
    """A fenced JSON block that follows prose is still found and parsed."""
    reply = 'Sure:\n```json\n{"classifications": [{"name": "X", "label": "PROJECT"}]}\n```'
    parsed = _parse_response(reply, ["X"])
    label = parsed["X"][0]
    assert label == "PROJECT"
|
||||||
|
|
||||||
|
|
||||||
|
def test_extract_json_candidates_handles_embedded_array():
    """A JSON array embedded in surrounding text is among the candidates."""
    embedded = '[{"name": "Y", "label": "PERSON"}]'
    candidates = _extract_json_candidates('prefix [{"name": "Y", "label": "PERSON"}] suffix')
    assert embedded in candidates
|
||||||
|
|
||||||
|
|
||||||
|
def test_parse_response_ignores_non_json_brackets_before_payload():
    """Bracketed prose ahead of the payload does not derail parsing."""
    reply = 'See [note] first. JSON: {"classifications": [{"name": "X", "label": "TOPIC"}]}'
    parsed = _parse_response(reply, ["X"])
    label = parsed["X"][0]
    assert label == "TOPIC"
|
||||||
|
|
||||||
|
|
||||||
|
def test_parse_response_malformed_returns_empty():
    """Text containing no JSON yields an empty mapping."""
    parsed = _parse_response("not json at all", ["X"])
    assert parsed == {}
|
||||||
|
|
||||||
|
|
||||||
|
def test_parse_response_accepts_top_level_list():
    """A bare list of classifications, without the wrapping object, is accepted."""
    reply = '[{"name": "Y", "label": "PERSON"}]'
    parsed = _parse_response(reply, ["Y"])
    label = parsed["Y"][0]
    assert label == "PERSON"
|
||||||
|
|
||||||
|
|
||||||
|
# ── _apply_classifications ──────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_apply_classifications_moves_to_correct_bucket():
    """An uncertain entry labelled PERSON moves to ``people`` and counts as one reclassification."""
    foo = {
        "name": "Foo",
        "type": "project",
        "confidence": 0.8,
        "frequency": 3,
        "signals": ["old"],
    }
    alice = {"name": "Alice", "type": "uncertain", "confidence": 0.4, "frequency": 5, "signals": []}
    detected = {"people": [], "projects": [foo], "uncertain": [alice]}
    decisions = {
        "Foo": ("PROJECT", "real project name"),
        "Alice": ("PERSON", "clearly a person"),
    }

    updated, reclassified, dropped = _apply_classifications(detected, decisions)

    people = updated["people"]
    assert len(people) == 1
    assert people[0]["name"] == "Alice"
    assert people[0]["type"] == "person"
    # Foo was already a project, so only Alice's move is counted.
    assert reclassified == 1
    assert dropped == 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_apply_classifications_drops_common_word():
    """An entry labelled COMMON_WORD is removed and counted as dropped."""
    never = {
        "name": "Never",
        "type": "uncertain",
        "confidence": 0.4,
        "frequency": 20,
        "signals": [],
    }
    detected = {"people": [], "projects": [], "uncertain": [never]}
    decisions = {"Never": ("COMMON_WORD", "adverb")}

    updated, _reclassified, dropped = _apply_classifications(detected, decisions)

    assert dropped == 1
    assert updated["uncertain"] == []
|
||||||
|
|
||||||
|
|
||||||
|
def test_apply_classifications_keeps_unvisited_entries():
    """An entry with no decision stays in its bucket and changes no counters."""
    igor = {
        "name": "Igor",
        "type": "person",
        "confidence": 0.99,
        "frequency": 100,
        "signals": ["git"],
    }
    detected = {"people": [igor], "projects": [], "uncertain": []}

    # An empty decisions mapping means nothing was classified for Igor.
    updated, reclassified, dropped = _apply_classifications(detected, {})

    assert updated["people"][0]["name"] == "Igor"
    assert reclassified == 0
    assert dropped == 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_apply_classifications_appends_reason_signal():
    """The moved entry's signals mention both the LLM label and its reason."""
    foo = {
        "name": "Foo",
        "type": "uncertain",
        "confidence": 0.4,
        "frequency": 5,
        "signals": ["regex"],
    }
    detected = {"people": [], "projects": [], "uncertain": [foo]}
    decisions = {"Foo": ("PERSON", "spoken of by name")}

    updated, _reclassified, _dropped = _apply_classifications(detected, decisions)

    signals = updated["people"][0]["signals"]
    assert any("LLM: person" in signal for signal in signals)
    assert any("spoken of by name" in signal for signal in signals)
|
||||||
|
|
||||||
|
|
||||||
|
def test_apply_classifications_topic_goes_to_uncertain():
    """A TOPIC decision on a project is expected to move it into ``uncertain``."""
    paris = {
        "name": "Paris",
        "type": "project",
        "confidence": 0.7,
        "frequency": 5,
        "signals": ["regex"],
    }
    buckets = {"people": [], "projects": [paris], "uncertain": []}
    verdicts = {"Paris": ("TOPIC", "city, not a project")}

    merged, reclassified, _dropped = _apply_classifications(buckets, verdicts)

    assert len(merged["projects"]) == 0
    uncertain_bucket = merged["uncertain"]
    assert len(uncertain_bucket) == 1
    assert uncertain_bucket[0]["name"] == "Paris"
    assert reclassified == 1
|
||||||
|
|
||||||
|
|
||||||
|
def test_apply_classifications_can_block_llm_only_project_promotion():
    """With ``allow_project_promotions=False`` a PROJECT verdict must not promote."""
    terraform = {
        "name": "Terraform",
        "type": "uncertain",
        "confidence": 0.4,
        "frequency": 5,
        "signals": ["regex"],
    }
    buckets = {"people": [], "projects": [], "uncertain": [terraform]}
    verdicts = {"Terraform": ("PROJECT", "tool")}

    merged, reclassified, _dropped = _apply_classifications(
        buckets,
        verdicts,
        allow_project_promotions=False,
    )

    assert merged["projects"] == []
    still_uncertain = merged["uncertain"][0]
    assert still_uncertain["name"] == "Terraform"
    assert still_uncertain["type"] == "uncertain"
    assert reclassified == 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_apply_classifications_allows_project_promotion_for_prose_only_mode():
    """By default a PROJECT verdict is expected to promote an uncertain entry."""
    aurora = {
        "name": "Project Aurora",
        "type": "uncertain",
        "confidence": 0.4,
        "frequency": 5,
        "signals": ["regex"],
    }
    buckets = {"people": [], "projects": [], "uncertain": [aurora]}
    verdicts = {"Project Aurora": ("PROJECT", "user effort")}

    merged, reclassified, _dropped = _apply_classifications(buckets, verdicts)

    promoted = merged["projects"][0]
    assert promoted["name"] == "Project Aurora"
    assert promoted["type"] == "project"
    assert reclassified == 1
|
||||||
|
|
||||||
|
|
||||||
|
# ── authoritative source filters ────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_is_authoritative_person_requires_git_signal():
    """A commit-count signal is expected to qualify; a pronoun signal is not."""
    git_backed = {"signals": ["5 commits across 2 repos"]}
    prose_only = {"signals": ["pronoun nearby (5x)"]}

    assert _is_authoritative_person(git_backed)
    assert not _is_authoritative_person(prose_only)
|
||||||
|
|
||||||
|
|
||||||
|
def test_is_authoritative_project_requires_manifest_or_git_signal():
    """Manifest or commit signals are expected to qualify; a code-file reference is not."""
    manifest_backed = {"signals": ["package.json, 12 of your commits"]}
    git_backed = {"signals": ["57 commits (none by you)"]}
    reference_only = {"signals": ["code file reference (5x)"]}

    assert _is_authoritative_project(manifest_backed)
    assert _is_authoritative_project(git_backed)
    assert not _is_authoritative_project(reference_only)
|
||||||
|
|
||||||
|
|
||||||
|
# ── refine_entities ─────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def _sample_detected():
|
||||||
|
return {
|
||||||
|
"people": [
|
||||||
|
{
|
||||||
|
"name": "Igor",
|
||||||
|
"type": "person",
|
||||||
|
"confidence": 0.99,
|
||||||
|
"frequency": 100,
|
||||||
|
"signals": ["git"],
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"projects": [
|
||||||
|
{
|
||||||
|
"name": "Foo",
|
||||||
|
"type": "project",
|
||||||
|
"confidence": 0.7,
|
||||||
|
"frequency": 5,
|
||||||
|
"signals": ["regex"],
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"uncertain": [
|
||||||
|
{
|
||||||
|
"name": "Never",
|
||||||
|
"type": "uncertain",
|
||||||
|
"confidence": 0.4,
|
||||||
|
"frequency": 10,
|
||||||
|
"signals": [],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "Alice",
|
||||||
|
"type": "uncertain",
|
||||||
|
"confidence": 0.4,
|
||||||
|
"frequency": 5,
|
||||||
|
"signals": [],
|
||||||
|
},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def test_refine_entities_end_to_end_with_fake_provider():
    """Run one full refinement batch against a canned provider response."""
    classification_items = ",".join(
        [
            '{"name": "Foo", "label": "PROJECT", "reason": "real"}',
            '{"name": "Never", "label": "COMMON_WORD"}',
            '{"name": "Alice", "label": "PERSON", "reason": "name"}',
        ]
    )
    canned_reply = '{"classifications": [' + classification_items + "]}"
    provider = FakeProvider(response_text=canned_reply)

    outcome = refine_entities(
        _sample_detected(),
        corpus_text="Alice said hi. Foo was shipped. Never gonna.",
        provider=provider,
        show_progress=False,
    )

    assert outcome.batches_total == 1
    assert outcome.batches_completed == 1
    assert not outcome.cancelled

    # Alice → people, Never → dropped, Foo stays in projects
    people_names = [person["name"] for person in outcome.merged["people"]]
    uncertain_names = [entry["name"] for entry in outcome.merged["uncertain"]]
    assert "Alice" in people_names
    assert "Igor" in people_names  # untouched
    assert "Never" not in uncertain_names
    assert outcome.dropped == 1
|
||||||
|
|
||||||
|
|
||||||
|
def test_refine_entities_skips_high_confidence_projects():
    """Manifest-backed projects (conf >= 0.95) aren't sent to the LLM."""
    manifest_project = {
        "name": "manifest-backed",
        "type": "project",
        "confidence": 0.99,
        "frequency": 50,
        "signals": ["pyproject.toml"],
    }
    buckets = {"people": [], "projects": [manifest_project], "uncertain": []}
    provider = FakeProvider(response_text='{"classifications": []}')

    refine_entities(buckets, "", provider, show_progress=False)

    # Nothing needed review, so the provider must never have been invoked.
    assert provider.call_count == 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_refine_entities_refines_high_confidence_regex_projects():
    """High-confidence regex projects still need LLM review without source signal."""
    openapi = {
        "name": "OpenAPI",
        "type": "project",
        "confidence": 0.99,
        "frequency": 5,
        "signals": ["code file reference (5x)"],
    }
    buckets = {"people": [], "projects": [openapi], "uncertain": []}
    canned_reply = (
        '{"classifications": [{"name": "OpenAPI", "label": "TOPIC", "reason": "technology"}]}'
    )
    provider = FakeProvider(response_text=canned_reply)

    outcome = refine_entities(buckets, "OpenAPI schemas", provider, show_progress=False)

    assert provider.call_count == 1
    assert outcome.reclassified == 1
    merged = outcome.merged
    assert merged["projects"] == []
    assert merged["uncertain"][0]["name"] == "OpenAPI"
|
||||||
|
|
||||||
|
|
||||||
|
def test_refine_entities_refines_regex_people_but_skips_git_people():
    """Only the pronoun-derived person is expected to be reviewed and dropped."""
    git_person = {
        "name": "Igor Lins e Silva",
        "type": "person",
        "confidence": 0.99,
        "frequency": 100,
        "signals": ["100 commits across 3 repos"],
    }
    regex_person = {
        "name": "Tool",
        "type": "person",
        "confidence": 0.99,
        "frequency": 5,
        "signals": ["pronoun nearby (5x)"],
    }
    buckets = {"people": [git_person, regex_person], "projects": [], "uncertain": []}
    provider = FakeProvider(
        response_text='{"classifications": [{"name": "Tool", "label": "COMMON_WORD"}]}'
    )

    outcome = refine_entities(buckets, "Tool is a common noun.", provider, show_progress=False)

    assert provider.call_count == 1
    surviving_people = [person["name"] for person in outcome.merged["people"]]
    assert surviving_people == ["Igor Lins e Silva"]
    assert outcome.dropped == 1
|
||||||
|
|
||||||
|
|
||||||
|
def test_refine_entities_can_keep_llm_only_project_in_uncertain():
    """With promotions disabled the entry stays uncertain but carries the LLM signal."""
    terraform = {
        "name": "Terraform",
        "type": "uncertain",
        "confidence": 0.4,
        "frequency": 9,
        "signals": ["regex"],
    }
    buckets = {"people": [], "projects": [], "uncertain": [terraform]}
    provider = FakeProvider(
        response_text='{"classifications": [{"name": "Terraform", "label": "PROJECT"}]}'
    )

    outcome = refine_entities(
        buckets,
        "Terraform config",
        provider,
        show_progress=False,
        allow_project_promotions=False,
    )

    merged = outcome.merged
    assert merged["projects"] == []
    assert merged["uncertain"][0]["name"] == "Terraform"
    assert any("LLM: project" in signal for signal in merged["uncertain"][0]["signals"])
|
||||||
|
|
||||||
|
|
||||||
|
def test_refine_entities_empty_candidates_returns_noop():
    """Empty buckets are expected to yield zero batches and an equal merged result."""
    empty_buckets = {"people": [], "projects": [], "uncertain": []}
    provider = FakeProvider()

    outcome = refine_entities(empty_buckets, "", provider, show_progress=False)

    assert outcome.batches_total == 0
    assert outcome.reclassified == 0
    assert outcome.merged == empty_buckets
|
||||||
|
|
||||||
|
|
||||||
|
def test_refine_entities_handles_batch_error_gracefully():
    """A provider ``LLMError`` is expected to be recorded, not raised or treated as cancel."""
    failing_provider = FakeProvider(should_raise=LLMError("transport broke"))

    outcome = refine_entities(
        _sample_detected(),
        corpus_text="",
        provider=failing_provider,
        show_progress=False,
    )

    recorded_errors = outcome.errors
    assert recorded_errors
    assert "transport broke" in recorded_errors[0]
    # No decision succeeded, so nothing may have been reclassified.
    assert outcome.reclassified == 0
    assert outcome.cancelled is False
|
||||||
|
|
||||||
|
|
||||||
|
def test_refine_entities_ctrl_c_returns_partial():
    """Ctrl-C during refinement marks cancelled=True and returns partial result."""
    # 50 candidates at batch_size=25 gives exactly two batches.
    candidates = []
    for i in range(50):
        candidates.append(
            {
                "name": f"Cand{i}",
                "type": "uncertain",
                "confidence": 0.4,
                "frequency": 3,
                "signals": [],
            }
        )
    buckets = {"people": [], "projects": [], "uncertain": candidates}
    provider = FakeProvider(
        response_text='{"classifications": []}',
        interrupt_on_call=2,  # interrupt on second batch
    )

    outcome = refine_entities(buckets, "", provider, batch_size=25, show_progress=False)

    assert outcome.cancelled is True
    assert outcome.batches_completed == 1  # first batch finished; second interrupted
    assert outcome.batches_total == 2
|
||||||
|
|
||||||
|
|
||||||
|
def test_refine_entities_malformed_response_recorded_as_error():
    """A non-JSON reply is expected to surface as a "could not parse" error entry."""
    garbage_provider = FakeProvider(response_text="not json")

    outcome = refine_entities(_sample_detected(), "", garbage_provider, show_progress=False)

    assert any("could not parse" in message for message in outcome.errors)
|
||||||
|
|
||||||
|
|
||||||
|
# ── collect_corpus_text ─────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_collect_corpus_text_reads_prose_files(tmp_path):
    """``.md`` and ``.txt`` content is expected in the corpus; ``.py`` content is not."""
    fixtures = {
        "a.md": "hello world",
        "b.txt": "more prose",
        "c.py": "import os",  # not prose, skipped
    }
    for filename, content in fixtures.items():
        (tmp_path / filename).write_text(content)

    corpus = collect_corpus_text(str(tmp_path))

    assert "hello world" in corpus
    assert "more prose" in corpus
    assert "import os" not in corpus
|
||||||
|
|
||||||
|
|
||||||
|
def test_collect_corpus_text_prefers_recent(tmp_path):
    """With ``max_files=1`` only the most recently modified file is expected.

    The older file's mtime is pushed one hour into the past with ``os.utime``,
    so the ordering never depends on wall-clock timing between the two writes
    and no sleep is needed.
    """
    import os

    stale = tmp_path / "old.md"
    stale.write_text("OLD_CONTENT")
    fresh = tmp_path / "new.md"
    fresh.write_text("NEW_CONTENT")

    # Backdate explicitly instead of sleeping between the writes.
    backdated = stale.stat().st_mtime - 3600
    os.utime(stale, (backdated, backdated))

    text = collect_corpus_text(str(tmp_path), max_files=1)

    assert "NEW_CONTENT" in text
    assert "OLD_CONTENT" not in text
|
||||||
|
|
||||||
|
|
||||||
|
def test_collect_corpus_text_missing_dir_returns_empty(tmp_path):
    """A directory that does not exist is expected to produce an empty string."""
    absent_dir = tmp_path / "nope"
    corpus = collect_corpus_text(str(absent_dir))
    assert corpus == ""
|
||||||
|
|
||||||
|
|
||||||
|
def test_collect_corpus_text_caps_bytes_per_file(tmp_path):
    """A 100 000-character file is expected to be truncated near the 500-byte cap."""
    oversized = tmp_path / "big.md"
    oversized.write_text("x" * 100_000)

    corpus = collect_corpus_text(str(tmp_path), max_files=1, max_bytes_per_file=500)

    assert len(corpus) <= 600  # 500 + newlines
|
||||||
@@ -66,6 +66,16 @@ def test_load_config_uses_defaults_when_yaml_missing():
|
|||||||
shutil.rmtree(tmpdir)
|
shutil.rmtree(tmpdir)
|
||||||
|
|
||||||
|
|
||||||
|
def test_scan_project_skips_mempalace_generated_files():
    """Only ``notes.md`` is expected when entities.json and mempalace.yaml sit beside it."""
    fixtures = {
        "entities.json": '{"people": [], "projects": []}',
        "mempalace.yaml": "wing: test\nrooms: []\n",
        "notes.md": "real user content\n" * 10,
    }
    with tempfile.TemporaryDirectory() as tmpdir:
        project_root = Path(tmpdir).resolve()
        for filename, content in fixtures.items():
            write_file(project_root / filename, content)

        found = scanned_files(project_root)
        assert found == ["notes.md"]
|
||||||
|
|
||||||
|
|
||||||
def test_scan_project_respects_gitignore():
|
def test_scan_project_respects_gitignore():
|
||||||
tmpdir = tempfile.mkdtemp()
|
tmpdir = tempfile.mkdtemp()
|
||||||
try:
|
try:
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ import os
|
|||||||
import shutil
|
import shutil
|
||||||
import subprocess
|
import subprocess
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from types import SimpleNamespace
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
@@ -480,6 +481,49 @@ def test_discover_entities_prefers_real_signal_over_prose(tmp_path):
|
|||||||
assert "realproj" in proj_names
|
assert "realproj" in proj_names
|
||||||
|
|
||||||
|
|
||||||
|
def test_discover_entities_keeps_uncertain_for_llm_when_real_signal(tmp_path):
    """With --llm, regex-uncertain prose candidates should reach refinement."""
    manifest = tmp_path / "package.json"
    manifest.write_text(json.dumps({"name": "realproj"}))
    _init_git_repo(tmp_path)
    prose = tmp_path / "doc.md"
    prose.write_text("Noise appeared. Noise repeated. Noise again.")

    class RecordingProvider:
        """Stub provider that records each user prompt and always drops "Noise"."""

        def __init__(self):
            self.prompts = []

        def classify(self, _system, user, json_mode=True):
            self.prompts.append(user)
            reply = '{"classifications": [{"name": "Noise", "label": "COMMON_WORD"}]}'
            return SimpleNamespace(text=reply)

    provider = RecordingProvider()
    discovered = discover_entities(str(tmp_path), llm_provider=provider, show_progress=False)

    assert len(provider.prompts) == 1
    assert "Noise" in provider.prompts[0]
    all_names = [entry["name"] for bucket in discovered.values() for entry in bucket]
    assert "Noise" not in all_names
|
||||||
|
|
||||||
|
|
||||||
|
def test_discover_entities_keeps_llm_only_project_uncertain_when_real_signal(tmp_path):
    """Repo roots should not auto-promote LLM-only tools/topics into projects."""
    manifest = tmp_path / "package.json"
    manifest.write_text(json.dumps({"name": "realproj"}))
    _init_git_repo(tmp_path)
    prose = tmp_path / "doc.md"
    prose.write_text("Terraform shipped. Terraform changed. Terraform runs.")

    class ProjectLabellingProvider:
        """Stub provider that always labels "Terraform" as PROJECT."""

        def classify(self, _system, _user, json_mode=True):
            reply = '{"classifications": [{"name": "Terraform", "label": "PROJECT"}]}'
            return SimpleNamespace(text=reply)

    discovered = discover_entities(
        str(tmp_path), llm_provider=ProjectLabellingProvider(), show_progress=False
    )

    project_names = [entry["name"] for entry in discovered["projects"]]
    uncertain_names = [entry["name"] for entry in discovered["uncertain"]]
    assert "realproj" in project_names
    assert "Terraform" not in project_names
    assert "Terraform" in uncertain_names
|
||||||
|
|
||||||
|
|
||||||
# ── _UnionFind basics ──────────────────────────────────────────────────
|
# ── _UnionFind basics ──────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1174,6 +1174,7 @@ source = { editable = "." }
|
|||||||
dependencies = [
|
dependencies = [
|
||||||
{ name = "chromadb" },
|
{ name = "chromadb" },
|
||||||
{ name = "pyyaml" },
|
{ name = "pyyaml" },
|
||||||
|
{ name = "tomli", marker = "python_full_version < '3.11'" },
|
||||||
]
|
]
|
||||||
|
|
||||||
[package.optional-dependencies]
|
[package.optional-dependencies]
|
||||||
@@ -1206,6 +1207,7 @@ requires-dist = [
|
|||||||
{ name = "pytest-cov", marker = "extra == 'dev'", specifier = ">=4.0" },
|
{ name = "pytest-cov", marker = "extra == 'dev'", specifier = ">=4.0" },
|
||||||
{ name = "pyyaml", specifier = ">=6.0,<7" },
|
{ name = "pyyaml", specifier = ">=6.0,<7" },
|
||||||
{ name = "ruff", marker = "extra == 'dev'", specifier = ">=0.4.0" },
|
{ name = "ruff", marker = "extra == 'dev'", specifier = ">=0.4.0" },
|
||||||
|
{ name = "tomli", marker = "python_full_version < '3.11'", specifier = ">=2.0.0" },
|
||||||
]
|
]
|
||||||
provides-extras = ["dev", "spellcheck"]
|
provides-extras = ["dev", "spellcheck"]
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user