feat(llm): interactive entity refinement with batching and cancellation
Takes the candidate set produced by phase-1 detection (manifests, git authors, regex on prose) and asks an LLM to reclassify each candidate as PERSON / PROJECT / TOPIC / COMMON_WORD / AMBIGUOUS.

Scale approach: never feed the raw corpus to the LLM. For each candidate, collect up to 3 context lines from sampled prose, cap each at 240 chars, and batch 25 candidates per call. This keeps total input around 50-100K tokens even on large corpora and completes in a few minutes on a 4B local model.

Interactive UX:
- Stderr progress bar with the current candidate name, updated per batch.
- Ctrl-C interrupts cleanly: returns a RefineResult with `cancelled=True` and whatever was classified before the interrupt. The partial result is safe to pass straight to confirm_entities.
- Per-batch errors (transport, parse) are recorded in `errors` and don't abort the whole run.

Refinement scope: only `uncertain` and low-confidence `projects` entries are sent. Manifest-backed projects (conf >= 0.95) and git-authored people are already authoritative and skip the LLM.

The response parser is defensive — it accepts `label` or `type` keys, lowercase/uppercase variants, a top-level list or a wrapped object, and strips markdown code fences. Unknown labels become AMBIGUOUS so the user reviews them rather than silently accepting a bad classification.

`collect_corpus_text` provides a simple stratified prose sampler (recent first, capped per file) so callers don't need to build their own corpus window.

28 tests with a FakeProvider (no network). They cover context collection, prompt building, response parsing variants, classification apply, end-to-end refine, and Ctrl-C partial-result behavior.
This commit is contained in:
@@ -0,0 +1,368 @@
|
|||||||
|
"""
|
||||||
|
llm_refine.py — Optional LLM refinement of regex-detected entities.
|
||||||
|
|
||||||
|
Takes the candidate set produced by phase-1 detection (manifests, git
|
||||||
|
authors, regex on prose) and asks an LLM to reclassify each candidate as
|
||||||
|
PERSON / PROJECT / TOPIC / COMMON_WORD / AMBIGUOUS.
|
||||||
|
|
||||||
|
Design constraints:
|
||||||
|
- Opt-in. Default init path never imports this module.
|
||||||
|
- Local-first by default (Ollama).
|
||||||
|
- Interactive UX: visible progress, clean cancellation (Ctrl-C returns
|
||||||
|
whatever was classified before the interrupt).
|
||||||
|
- Don't feed the raw corpus to the LLM — feed candidates + a few sampled
|
||||||
|
context lines each. Keeps total input to ~50-100K tokens even for huge
|
||||||
|
prose corpora.
|
||||||
|
|
||||||
|
Public:
|
||||||
|
    refine_entities(detected, corpus_text, provider, ...) -> RefineResult
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import re
|
||||||
|
import sys
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
from mempalace.llm_client import LLMError, LLMProvider
|
||||||
|
|
||||||
|
|
||||||
|
# Default number of candidates packed into one provider.classify() call.
BATCH_SIZE = 25  # candidates per LLM call; tuned for 4B local models
# Default cap on how many corpus lines are quoted as context per candidate.
CONTEXT_LINES_PER_CANDIDATE = 3
CONTEXT_WINDOW_CHARS = 240  # max chars per context line to keep tokens bounded

# Valid labels the LLM is allowed to return. Anything else is treated as
# AMBIGUOUS so the user reviews it.
VALID_LABELS = {"PERSON", "PROJECT", "TOPIC", "COMMON_WORD", "AMBIGUOUS"}
|
||||||
|
|
||||||
|
|
||||||
|
SYSTEM_PROMPT = """You are helping organize a user's memory palace by classifying capitalized tokens found in their files.
|
||||||
|
|
||||||
|
For each candidate, pick exactly ONE label:
|
||||||
|
- PERSON: a specific real person the user knows (colleague, family, character they write about)
|
||||||
|
- PROJECT: a named product, codebase, or effort the user works on
|
||||||
|
- TOPIC: a recurring theme or subject (not a person, not a project) — cities, technologies, concepts
|
||||||
|
- COMMON_WORD: an English word, verb, or fragment that isn't a named entity at all (e.g. "Created", "Before", "Never")
|
||||||
|
- AMBIGUOUS: context is insufficient to decide between two of the above
|
||||||
|
|
||||||
|
Use the provided context lines to disambiguate. A capitalized word that only appears in metadata ("Created: 2026-04-24") is COMMON_WORD. A name that appears with pronouns and dialogue is PERSON.
|
||||||
|
|
||||||
|
Respond with JSON only. Schema:
|
||||||
|
{"classifications": [{"name": "<exact candidate name>", "label": "<LABEL>", "reason": "<one short sentence>"}]}
|
||||||
|
|
||||||
|
One entry per candidate, same order as the input."""
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class RefineResult:
    """Outcome of one ``refine_entities`` run.

    Attributes:
        merged: The detected dict after the LLM decisions were applied.
        reclassified: Number of entries that ended up in a different bucket
            than the one they started in.
        dropped: Number of entries removed because they were labelled
            COMMON_WORD.
        errors: One message per batch that failed (provider error or a
            response that yielded no usable classifications).
        batches_completed: Batches for which a provider response was
            received and processed.
        batches_total: Batches planned for the run.
        cancelled: True when the run was cut short by Ctrl-C.
    """

    merged: dict
    reclassified: int
    dropped: int
    errors: list[str]
    batches_completed: int
    batches_total: int
    cancelled: bool
|
||||||
|
|
||||||
|
|
||||||
|
def _collect_contexts(
    corpus_lines: list[str], name: str, max_lines: int = CONTEXT_LINES_PER_CANDIDATE
) -> list[str]:
    """Return up to `max_lines` distinct corpus snippets that mention `name`.

    Matching is a case-insensitive substring test. Each snippet is at most
    CONTEXT_WINDOW_CHARS characters long to keep token usage bounded. Lines
    that fit are returned whole (stripped). For longer lines the head of the
    line is used when it already contains the mention; otherwise the window
    is shifted so the mention stays inside the snippet, because a context
    line without the candidate in it is useless to the classifier.

    Args:
        corpus_lines: Corpus split into lines, scanned in order.
        name: Candidate name to look for. An empty name yields no contexts.
        max_lines: Maximum number of snippets to return; values <= 0 yield
            an empty list.

    Returns:
        Distinct snippets in corpus order.
    """
    needle = name.lower()
    # "" is a substring of every line, so an empty name would otherwise
    # "match" the whole corpus.
    if not needle or max_lines <= 0:
        return []

    window = CONTEXT_WINDOW_CHARS
    seen: set[str] = set()
    out: list[str] = []
    for line in corpus_lines:
        stripped = line.strip()
        pos = stripped.lower().find(needle)
        if pos < 0:
            continue
        if len(stripped) > window:
            start = 0
            if pos + len(needle) > window:
                # Mention lies beyond the head of the line: centre the window
                # on it. `pos` comes from the lowercased text, which can
                # differ in length for a few Unicode characters, so the
                # centring is approximate but the length bound is exact.
                lead = max(0, window - len(needle)) // 2
                start = min(max(0, pos - lead), len(stripped) - window)
            stripped = stripped[start : start + window]
        if not stripped or stripped in seen:
            continue
        seen.add(stripped)
        out.append(stripped)
        if len(out) >= max_lines:
            break
    return out
|
||||||
|
|
||||||
|
|
||||||
|
def _build_user_prompt(candidates_with_contexts: list[tuple[str, str, list[str]]]) -> str:
|
||||||
|
"""Shape: for each candidate, list its current type guess + sampled contexts."""
|
||||||
|
parts: list[str] = ["CANDIDATES:"]
|
||||||
|
for i, (name, current_type, contexts) in enumerate(candidates_with_contexts, 1):
|
||||||
|
parts.append(f"\n{i}. {name} (currently: {current_type})")
|
||||||
|
if contexts:
|
||||||
|
for c in contexts:
|
||||||
|
parts.append(f" > {c}")
|
||||||
|
else:
|
||||||
|
parts.append(" > (no context available)")
|
||||||
|
return "\n".join(parts)
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_response(text: str, expected_names: list[str]) -> dict[str, tuple[str, str]]:
    """Parse the LLM's JSON response into ``{name: (label, reason)}``.

    Tolerated deviations from the requested schema:

    - markdown code fences around the JSON;
    - prose before or after the JSON document (the first embedded object or
      array that decodes is used);
    - a bare top-level list instead of the ``{"classifications": [...]}``
      wrapper;
    - ``candidate`` instead of ``name``, and ``type`` / ``classification``
      instead of ``label``;
    - labels in any letter case; unknown labels become ``AMBIGUOUS``;
    - a missing or non-string ``reason`` (treated as empty).

    Names are matched case-insensitively against `expected_names` and
    returned with the expected casing. Entries for names that were not asked
    about are ignored, so a hallucinated name can never affect an entity
    that was not part of this batch.

    Args:
        text: Raw response text from the provider.
        expected_names: Candidate names that were sent in this batch.

    Returns:
        Mapping of canonical candidate name to ``(LABEL, reason)`` where the
        reason is trimmed to 120 characters. Empty when nothing usable could
        be parsed.
    """
    if not isinstance(text, str):
        return {}

    # Strip any surrounding fences
    text = text.strip()
    if text.startswith("```"):
        text = re.sub(r"^```(?:json)?\s*", "", text)
        text = re.sub(r"\s*```\s*$", "", text)

    try:
        data = json.loads(text)
    except json.JSONDecodeError:
        # The model may have wrapped the JSON in prose. Try decoding from the
        # first "{" and the first "[", earliest position first; raw_decode
        # ignores whatever trails the document.
        data = None
        decoder = json.JSONDecoder()
        starts = sorted(i for i in (text.find("{"), text.find("[")) if i >= 0)
        for start in starts:
            try:
                data, _ = decoder.raw_decode(text, start)
            except json.JSONDecodeError:
                continue
            break
        if data is None:
            return {}

    entries = data.get("classifications") if isinstance(data, dict) else data
    if not isinstance(entries, list):
        return {}

    canonical_by_lower = {n.lower(): n for n in expected_names}
    name_to_label: dict[str, tuple[str, str]] = {}
    for entry in entries:
        if not isinstance(entry, dict):
            continue
        name = entry.get("name") or entry.get("candidate")
        label = entry.get("label") or entry.get("type") or entry.get("classification")
        if not isinstance(name, str) or not isinstance(label, str):
            continue
        # Restore canonical casing from expected_names; skip anything the
        # batch did not ask about.
        canonical = canonical_by_lower.get(name.strip().lower())
        if canonical is None:
            continue
        lbl = label.strip().upper()
        if lbl not in VALID_LABELS:
            lbl = "AMBIGUOUS"
        raw_reason = entry.get("reason")
        # A non-string reason (number, list, ...) must not crash the run.
        reason = raw_reason if isinstance(raw_reason, str) else ""
        name_to_label[canonical] = (lbl, reason.strip()[:120])
    return name_to_label
|
||||||
|
|
||||||
|
|
||||||
|
def _apply_classifications(
|
||||||
|
detected: dict, decisions: dict[str, tuple[str, str]]
|
||||||
|
) -> tuple[dict, int, int]:
|
||||||
|
"""Merge LLM decisions back into the detected dict.
|
||||||
|
|
||||||
|
Returns (new_detected, reclassified_count, dropped_count).
|
||||||
|
"""
|
||||||
|
label_to_bucket = {
|
||||||
|
"PERSON": "people",
|
||||||
|
"PROJECT": "projects",
|
||||||
|
"TOPIC": "uncertain",
|
||||||
|
"AMBIGUOUS": "uncertain",
|
||||||
|
}
|
||||||
|
|
||||||
|
# Index every entity by name for in-place update
|
||||||
|
all_entries: list[tuple[str, dict]] = []
|
||||||
|
for bucket, items in detected.items():
|
||||||
|
for e in items:
|
||||||
|
all_entries.append((bucket, e))
|
||||||
|
|
||||||
|
reclassified = 0
|
||||||
|
dropped = 0
|
||||||
|
new_detected: dict[str, list[dict]] = {
|
||||||
|
"people": [],
|
||||||
|
"projects": [],
|
||||||
|
"uncertain": [],
|
||||||
|
}
|
||||||
|
|
||||||
|
for old_bucket, entry in all_entries:
|
||||||
|
decision = decisions.get(entry["name"])
|
||||||
|
if decision is None:
|
||||||
|
# No LLM opinion — keep as-is
|
||||||
|
new_detected[old_bucket].append(entry)
|
||||||
|
continue
|
||||||
|
|
||||||
|
label, reason = decision
|
||||||
|
if label == "COMMON_WORD":
|
||||||
|
dropped += 1
|
||||||
|
continue
|
||||||
|
|
||||||
|
target_bucket = label_to_bucket[label]
|
||||||
|
updated = dict(entry)
|
||||||
|
# Append the LLM's reason as a new signal so the user sees why it moved
|
||||||
|
signals = list(updated.get("signals", []))
|
||||||
|
signals.append(f"LLM: {label.lower()} — {reason}" if reason else f"LLM: {label.lower()}")
|
||||||
|
updated["signals"] = signals
|
||||||
|
if target_bucket != old_bucket:
|
||||||
|
reclassified += 1
|
||||||
|
updated["type"] = (
|
||||||
|
"person"
|
||||||
|
if target_bucket == "people"
|
||||||
|
else "project"
|
||||||
|
if target_bucket == "projects"
|
||||||
|
else "uncertain"
|
||||||
|
)
|
||||||
|
new_detected[target_bucket].append(updated)
|
||||||
|
|
||||||
|
return new_detected, reclassified, dropped
|
||||||
|
|
||||||
|
|
||||||
|
def _print_progress(batch_idx: int, total: int, current_name: str) -> None:
|
||||||
|
"""Overwrite-line progress indicator."""
|
||||||
|
width = 40
|
||||||
|
filled = int(width * batch_idx / total) if total else 0
|
||||||
|
bar = "█" * filled + "░" * (width - filled)
|
||||||
|
msg = f"\r LLM refine: [{bar}] batch {batch_idx}/{total} current: {current_name[:30]:<30}"
|
||||||
|
sys.stderr.write(msg)
|
||||||
|
sys.stderr.flush()
|
||||||
|
|
||||||
|
|
||||||
|
def refine_entities(
    detected: dict,
    corpus_text: str,
    provider: LLMProvider,
    batch_size: int = BATCH_SIZE,
    show_progress: bool = True,
) -> RefineResult:
    """Reclassify detected entities using the LLM provider.

    Only candidates in the ``uncertain`` and ``projects`` buckets are sent
    for refinement, and ``projects`` entries with confidence >= 0.95 are
    skipped. ``people`` entries are never sent.

    Ctrl-C at any point of the batch loop (context collection, the provider
    call, response parsing or progress output) cancels the remaining
    batches and returns a RefineResult with ``cancelled=True`` and whatever
    was classified before the interrupt. The partial result is safe to pass
    straight to ``confirm_entities``.

    Provider failures (``LLMError``) and responses that yield no usable
    classifications are recorded in ``errors`` and do not abort the run.

    Args:
        detected: Mapping of bucket name to a list of entity dicts.
        corpus_text: Prose used as the source of per-candidate context lines.
        provider: Object exposing ``classify(system, user, json_mode=...)``.
        batch_size: Candidates per provider call; must be >= 1.
        show_progress: Draw a progress bar on stderr.

    Returns:
        A RefineResult. When there is nothing to refine, ``merged`` is the
        ``detected`` object that was passed in.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    # Bucket -> "currently" label shown to the model. `people` is left out
    # on purpose: those entries are kept as-is.
    kind_for_bucket = {"projects": "project", "uncertain": "uncertain"}

    # Collect candidates, deduplicated by name while preserving order.
    seen: set[str] = set()
    unique: list[tuple[str, str]] = []
    for bucket, kind in kind_for_bucket.items():
        for e in detected.get(bucket, []):
            # Skip already-high-confidence entries (manifest-backed projects etc.)
            if bucket == "projects" and e.get("confidence", 0) >= 0.95:
                continue
            name = e["name"]
            if name not in seen:
                seen.add(name)
                unique.append((name, kind))

    if not unique:
        return RefineResult(
            merged=detected,
            reclassified=0,
            dropped=0,
            errors=[],
            batches_completed=0,
            batches_total=0,
            cancelled=False,
        )

    corpus_lines = corpus_text.splitlines() if corpus_text else []
    chunks = [unique[i : i + batch_size] for i in range(0, len(unique), batch_size)]
    total = len(chunks)

    all_decisions: dict[str, tuple[str, str]] = {}
    errors: list[str] = []
    completed = 0
    cancelled = False

    try:
        for idx, chunk in enumerate(chunks, 1):
            if show_progress:
                _print_progress(idx - 1, total, chunk[0][0])
            # Contexts are gathered per batch, inside the interruptible
            # region, so no corpus scanning happens for batches that are
            # never sent.
            batch = [(name, kind, _collect_contexts(corpus_lines, name)) for name, kind in chunk]
            user_prompt = _build_user_prompt(batch)
            try:
                resp = provider.classify(SYSTEM_PROMPT, user_prompt, json_mode=True)
            except LLMError as e:
                errors.append(f"batch {idx}: {e}")
                continue
            decisions = _parse_response(resp.text, [name for name, _ in chunk])
            if not decisions:
                errors.append(f"batch {idx}: could not parse response")
            all_decisions.update(decisions)
            completed += 1
            if show_progress:
                _print_progress(idx, total, chunk[-1][0])
    except KeyboardInterrupt:
        # Keep everything classified so far and stop asking the provider.
        cancelled = True

    if show_progress:
        sys.stderr.write("\n")
        sys.stderr.flush()

    merged, reclassified, dropped = _apply_classifications(detected, all_decisions)

    return RefineResult(
        merged=merged,
        reclassified=reclassified,
        dropped=dropped,
        errors=errors,
        batches_completed=completed,
        batches_total=total,
        cancelled=cancelled,
    )
|
||||||
|
|
||||||
|
|
||||||
|
def collect_corpus_text(
    project_dir: str,
    max_files: int = 30,
    max_bytes_per_file: int = 20_000,
) -> str:
    """Gather prose text from ``project_dir`` for use as LLM context source.

    Walks the tree (pruned by ``_walk_prose``), keeps files whose suffix is
    in ``PROSE_EXTENSIONS``, orders them newest-modified first and reads the
    first ``max_files`` of them. Each file is opened as UTF-8 text with
    undecodable bytes replaced, and ``max_bytes_per_file`` is passed to the
    text-mode ``read``. Files that cannot be stat'ed or opened are skipped.

    Returns:
        The file contents joined by newlines, or ``""`` when ``project_dir``
        is not a directory.
    """
    from pathlib import Path

    from mempalace.entity_detector import PROSE_EXTENSIONS, SKIP_DIRS

    base = Path(project_dir).expanduser().resolve()
    if not base.is_dir():
        return ""

    dated: list[tuple[float, Path]] = []
    for folder, _subdirs, filenames in _walk_prose(base, SKIP_DIRS):
        for filename in filenames:
            path = folder / filename
            if path.suffix.lower() in PROSE_EXTENSIONS:
                try:
                    modified = path.stat().st_mtime
                except OSError:
                    continue
                dated.append((modified, path))

    newest_first = sorted(dated, reverse=True)[:max_files]

    texts: list[str] = []
    for _modified, path in newest_first:
        try:
            with open(path, encoding="utf-8", errors="replace") as handle:
                content = handle.read(max_bytes_per_file)
        except OSError:
            continue
        texts.append(content)
    return "\n".join(texts)
|
||||||
|
|
||||||
|
|
||||||
|
def _walk_prose(root, skip_dirs):
|
||||||
|
"""Walk a directory yielding (Path, dirs, files), pruning skip_dirs.
|
||||||
|
|
||||||
|
Inlined from ``project_scanner._walk`` to avoid a private-name import
|
||||||
|
coupling. Functionality is intentionally narrow: prose collection only.
|
||||||
|
"""
|
||||||
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
for dirpath, dirs, files in os.walk(root):
|
||||||
|
dirs[:] = [d for d in dirs if d not in skip_dirs and not d.startswith(".")]
|
||||||
|
yield Path(dirpath), dirs, files
|
||||||
@@ -0,0 +1,446 @@
|
|||||||
|
"""Tests for mempalace.llm_refine.
|
||||||
|
|
||||||
|
Uses a fake provider for deterministic, offline tests. No network.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
|
||||||
|
from mempalace.llm_client import LLMError, LLMResponse
|
||||||
|
from mempalace.llm_refine import (
|
||||||
|
_apply_classifications,
|
||||||
|
_build_user_prompt,
|
||||||
|
_collect_contexts,
|
||||||
|
_parse_response,
|
||||||
|
collect_corpus_text,
|
||||||
|
refine_entities,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ── fake provider ───────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class FakeProvider:
    """Offline stand-in for an LLM provider.

    Every ``classify`` call increments ``call_count``. The call whose number
    equals ``interrupt_on_call`` raises ``KeyboardInterrupt`` (the default of
    -1 never matches). Otherwise ``should_raise`` is raised when it is set,
    and failing that the canned ``response_text`` is returned.
    """

    response_text: str = ""
    # String annotation: the default really is None, and this spelling works
    # on every supported Python without an extra typing import.
    should_raise: "Exception | None" = None
    call_count: int = 0
    interrupt_on_call: int = -1

    def classify(self, system, user, json_mode=True):
        """Return the canned response, or raise as configured."""
        self.call_count += 1
        if self.call_count == self.interrupt_on_call:
            raise KeyboardInterrupt()
        if self.should_raise is not None:
            raise self.should_raise
        return LLMResponse(text=self.response_text, model="fake", provider="fake", raw={})

    def check_available(self):
        """Always report the provider as available."""
        return True, "ok"
|
||||||
|
|
||||||
|
|
||||||
|
# ── _collect_contexts ───────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_collect_contexts_finds_matches():
    """Only matching lines are returned, in corpus order, capped at ``max_lines``."""
    lines = [
        "Something about Alice",
        "Bob said hello",
        "Alice was here",
        "Alice walked by",
    ]
    out = _collect_contexts(lines, "Alice", max_lines=2)
    # Exact comparison pins the ordering and the cap, not just the count.
    assert out == ["Something about Alice", "Alice was here"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_collect_contexts_case_insensitive():
    """A capitalized candidate still matches a lowercase mention."""
    corpus = ["lowercase alice mention"]
    found = _collect_contexts(corpus, "Alice")
    assert found == ["lowercase alice mention"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_collect_contexts_dedupes_identical_lines():
    """Repeated identical lines are reported once, keeping first-seen order."""
    lines = ["Alice", "Alice", "Alice was here"]
    out = _collect_contexts(lines, "Alice", max_lines=5)
    # two unique lines, not three
    assert out == ["Alice", "Alice was here"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_collect_contexts_truncates_long_lines():
    """Over-long lines are cut to the context window but keep the mention."""
    lines = ["Alice " + ("x" * 1000)]
    out = _collect_contexts(lines, "Alice")
    assert len(out) == 1
    assert len(out[0]) <= 240
    # A truncated context is only useful if the candidate is still in it.
    assert "Alice" in out[0]
|
||||||
|
|
||||||
|
|
||||||
|
def test_collect_contexts_no_matches():
    """A corpus that never mentions the candidate yields an empty list."""
    found = _collect_contexts(["nothing here"], "Alice")
    assert found == []
|
||||||
|
|
||||||
|
|
||||||
|
# ── _build_user_prompt ──────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_build_user_prompt_numbers_and_includes_contexts():
    """Candidates are numbered from 1; contexts or a placeholder are listed."""
    batch = [
        ("Alice", "uncertain", ["Alice said hi"]),
        ("Bob", "project", []),
    ]
    rendered = _build_user_prompt(batch)
    for expected in ("1. Alice", "2. Bob", "Alice said hi", "(no context available)"):
        assert expected in rendered
|
||||||
|
|
||||||
|
|
||||||
|
# ── _parse_response ─────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_parse_response_canonicalizes_label():
    """A lowercase label from the model is normalized to upper case."""
    payload = '{"classifications": [{"name": "Alice", "label": "person", "reason": "x"}]}'
    parsed = _parse_response(payload, ["Alice"])
    assert parsed["Alice"] == ("PERSON", "x")
|
||||||
|
|
||||||
|
|
||||||
|
def test_parse_response_accepts_type_alias():
    """The ``type`` key is accepted as a substitute for ``label``."""
    payload = '{"classifications": [{"name": "Bob", "type": "PROJECT"}]}'
    parsed = _parse_response(payload, ["Bob"])
    label, _reason = parsed["Bob"]
    assert label == "PROJECT"
|
||||||
|
|
||||||
|
|
||||||
|
def test_parse_response_maps_unknown_label_to_ambiguous():
    """A label outside the allowed set is downgraded to AMBIGUOUS."""
    payload = '{"classifications": [{"name": "X", "label": "WEIRD"}]}'
    parsed = _parse_response(payload, ["X"])
    label, _reason = parsed["X"]
    assert label == "AMBIGUOUS"
|
||||||
|
|
||||||
|
|
||||||
|
def test_parse_response_restores_canonical_casing():
    """A lowercased name in the response is keyed by the expected casing."""
    payload = '{"classifications": [{"name": "mempalace", "label": "PROJECT"}]}'
    parsed = _parse_response(payload, ["MemPalace"])
    assert "MemPalace" in parsed
    label, _reason = parsed["MemPalace"]
    assert label == "PROJECT"
|
||||||
|
|
||||||
|
|
||||||
|
def test_parse_response_strips_code_fences():
    """A response wrapped in a json markdown fence is still parsed."""
    payload = '```json\n{"classifications": [{"name": "X", "label": "TOPIC"}]}\n```'
    parsed = _parse_response(payload, ["X"])
    label, _reason = parsed["X"]
    assert label == "TOPIC"
|
||||||
|
|
||||||
|
|
||||||
|
def test_parse_response_malformed_returns_empty():
    """Text that is not JSON produces an empty mapping instead of raising."""
    parsed = _parse_response("not json at all", ["X"])
    assert parsed == {}
|
||||||
|
|
||||||
|
|
||||||
|
def test_parse_response_accepts_top_level_list():
    """A bare list without the ``classifications`` wrapper is accepted."""
    payload = '[{"name": "Y", "label": "PERSON"}]'
    parsed = _parse_response(payload, ["Y"])
    label, _reason = parsed["Y"]
    assert label == "PERSON"
|
||||||
|
|
||||||
|
|
||||||
|
# ── _apply_classifications ──────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_apply_classifications_moves_to_correct_bucket():
    """PERSON moves an uncertain entry to people; a confirmed project stays put."""
    detected = {
        "people": [],
        "projects": [
            {
                "name": "Foo",
                "type": "project",
                "confidence": 0.8,
                "frequency": 3,
                "signals": ["old"],
            }
        ],
        "uncertain": [
            {"name": "Alice", "type": "uncertain", "confidence": 0.4, "frequency": 5, "signals": []}
        ],
    }
    decisions = {
        "Foo": ("PROJECT", "real project name"),
        "Alice": ("PERSON", "clearly a person"),
    }
    new, reclass, dropped = _apply_classifications(detected, decisions)
    assert len(new["people"]) == 1
    assert new["people"][0]["name"] == "Alice"
    assert new["people"][0]["type"] == "person"
    # Foo was confirmed as a project, so it must neither move nor vanish.
    assert [e["name"] for e in new["projects"]] == ["Foo"]
    assert new["projects"][0]["type"] == "project"
    assert new["uncertain"] == []
    assert reclass == 1  # Alice moved uncertain -> people
    assert dropped == 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_apply_classifications_drops_common_word():
    """An entry labelled COMMON_WORD is removed and counted as dropped."""
    never = {
        "name": "Never",
        "type": "uncertain",
        "confidence": 0.4,
        "frequency": 20,
        "signals": [],
    }
    detected = {"people": [], "projects": [], "uncertain": [never]}
    verdicts = {"Never": ("COMMON_WORD", "adverb")}
    merged, _, dropped = _apply_classifications(detected, verdicts)
    assert dropped == 1
    assert merged["uncertain"] == []
|
||||||
|
|
||||||
|
|
||||||
|
def test_apply_classifications_keeps_unvisited_entries():
    """Entries without an LLM decision stay in their original bucket."""
    igor = {
        "name": "Igor",
        "type": "person",
        "confidence": 0.99,
        "frequency": 100,
        "signals": ["git"],
    }
    detected = {"people": [igor], "projects": [], "uncertain": []}
    # Empty decisions: nothing may be moved or dropped.
    merged, reclass, dropped = _apply_classifications(detected, {})
    assert merged["people"][0]["name"] == "Igor"
    assert reclass == 0
    assert dropped == 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_apply_classifications_appends_reason_signal():
    """The LLM verdict is appended as a signal without touching the input."""
    detected = {
        "people": [],
        "projects": [],
        "uncertain": [
            {
                "name": "Foo",
                "type": "uncertain",
                "confidence": 0.4,
                "frequency": 5,
                "signals": ["regex"],
            }
        ],
    }
    decisions = {"Foo": ("PERSON", "spoken of by name")}
    new, _, _ = _apply_classifications(detected, decisions)
    signals = new["people"][0]["signals"]
    assert any("LLM: person" in s for s in signals)
    assert any("spoken of by name" in s for s in signals)
    # "Appends" means the pre-existing signal survives, in first position...
    assert signals[0] == "regex"
    # ...and the caller's entry is not modified in place.
    assert detected["uncertain"][0]["signals"] == ["regex"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_apply_classifications_topic_goes_to_uncertain():
    """A project relabelled TOPIC is moved into the uncertain bucket."""
    paris = {
        "name": "Paris",
        "type": "project",
        "confidence": 0.7,
        "frequency": 5,
        "signals": ["regex"],
    }
    detected = {"people": [], "projects": [paris], "uncertain": []}
    verdicts = {"Paris": ("TOPIC", "city, not a project")}
    merged, reclass, _ = _apply_classifications(detected, verdicts)
    assert len(merged["projects"]) == 0
    assert len(merged["uncertain"]) == 1
    assert merged["uncertain"][0]["name"] == "Paris"
    assert reclass == 1
|
||||||
|
|
||||||
|
|
||||||
|
# ── refine_entities ─────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def _sample_detected():
|
||||||
|
return {
|
||||||
|
"people": [
|
||||||
|
{
|
||||||
|
"name": "Igor",
|
||||||
|
"type": "person",
|
||||||
|
"confidence": 0.99,
|
||||||
|
"frequency": 100,
|
||||||
|
"signals": ["git"],
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"projects": [
|
||||||
|
{
|
||||||
|
"name": "Foo",
|
||||||
|
"type": "project",
|
||||||
|
"confidence": 0.7,
|
||||||
|
"frequency": 5,
|
||||||
|
"signals": ["regex"],
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"uncertain": [
|
||||||
|
{
|
||||||
|
"name": "Never",
|
||||||
|
"type": "uncertain",
|
||||||
|
"confidence": 0.4,
|
||||||
|
"frequency": 10,
|
||||||
|
"signals": [],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "Alice",
|
||||||
|
"type": "uncertain",
|
||||||
|
"confidence": 0.4,
|
||||||
|
"frequency": 5,
|
||||||
|
"signals": [],
|
||||||
|
},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def test_refine_entities_end_to_end_with_fake_provider():
    """One batch end to end: Alice becomes a person, Never is dropped, Foo stays."""
    provider = FakeProvider(
        response_text=(
            '{"classifications": ['
            '{"name": "Foo", "label": "PROJECT", "reason": "real"},'
            '{"name": "Never", "label": "COMMON_WORD"},'
            '{"name": "Alice", "label": "PERSON", "reason": "name"}'
            "]}"
        )
    )
    result = refine_entities(
        _sample_detected(),
        corpus_text="Alice said hi. Foo was shipped. Never gonna.",
        provider=provider,
        show_progress=False,
    )
    assert result.batches_total == 1
    assert result.batches_completed == 1
    assert not result.cancelled
    assert result.errors == []
    # Alice → people, Never → dropped, Foo stays in projects
    names_in_people = [e["name"] for e in result.merged["people"]]
    assert "Alice" in names_in_people
    assert "Igor" in names_in_people  # untouched
    assert [e["name"] for e in result.merged["projects"]] == ["Foo"]
    assert "Never" not in [e["name"] for e in result.merged["uncertain"]]
    assert result.dropped == 1
    assert result.reclassified == 1  # only Alice changed bucket
|
||||||
|
|
||||||
|
|
||||||
|
def test_refine_entities_skips_high_confidence_projects():
    """Manifest-backed projects (conf >= 0.95) aren't sent to the LLM."""
    detected = {
        "people": [],
        "projects": [
            {
                "name": "manifest-backed",
                "type": "project",
                "confidence": 0.99,
                "frequency": 50,
                "signals": ["pyproject.toml"],
            }
        ],
        "uncertain": [],
    }
    provider = FakeProvider(response_text='{"classifications": []}')
    result = refine_entities(detected, "", provider, show_progress=False)
    # Should not have called the LLM at all
    assert provider.call_count == 0
    assert result.batches_total == 0
    # Skipping must leave the entry in place, not lose it.
    assert [e["name"] for e in result.merged["projects"]] == ["manifest-backed"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_refine_entities_empty_candidates_returns_noop():
    """With nothing to refine, the result reports zero batches and equal data."""
    empty = {"people": [], "projects": [], "uncertain": []}
    outcome = refine_entities(empty, "", FakeProvider(), show_progress=False)
    assert outcome.batches_total == 0
    assert outcome.reclassified == 0
    assert outcome.merged == empty
|
||||||
|
|
||||||
|
|
||||||
|
def test_refine_entities_handles_batch_error_gracefully():
    """A provider failure is recorded as an error and leaves the data intact."""
    provider = FakeProvider(should_raise=LLMError("transport broke"))
    result = refine_entities(
        _sample_detected(),
        corpus_text="",
        provider=provider,
        show_progress=False,
    )
    assert result.errors
    assert "transport broke" in result.errors[0]
    # Detected unchanged (no successful decisions)
    assert result.merged == _sample_detected()
    assert result.reclassified == 0
    assert result.dropped == 0
    assert result.batches_completed == 0
    assert result.cancelled is False
|
||||||
|
|
||||||
|
|
||||||
|
def test_refine_entities_ctrl_c_returns_partial():
    """Ctrl-C during refinement marks cancelled=True and returns partial result."""
    # Two batches' worth of candidates
    detected = {
        "people": [],
        "projects": [],
        "uncertain": [
            {
                "name": f"Cand{i}",
                "type": "uncertain",
                "confidence": 0.4,
                "frequency": 3,
                "signals": [],
            }
            for i in range(50)
        ],
    }
    provider = FakeProvider(
        response_text='{"classifications": []}',
        interrupt_on_call=2,  # interrupt on second batch
    )
    result = refine_entities(detected, "", provider, batch_size=25, show_progress=False)
    assert result.cancelled is True
    assert result.batches_completed == 1  # first batch finished; second interrupted
    assert result.batches_total == 2
    # The interrupt must stop the run: no provider call after the second one.
    assert provider.call_count == 2
    # Nothing was classified, so every candidate must still be present.
    assert len(result.merged["uncertain"]) == 50
|
||||||
|
|
||||||
|
|
||||||
|
def test_refine_entities_malformed_response_recorded_as_error():
    """A non-JSON provider reply shows up in ``errors`` as a parse failure."""
    outcome = refine_entities(
        _sample_detected(),
        "",
        FakeProvider(response_text="not json"),
        show_progress=False,
    )
    messages = [msg for msg in outcome.errors if "could not parse" in msg]
    assert messages
|
||||||
|
|
||||||
|
|
||||||
|
# ── collect_corpus_text ─────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_collect_corpus_text_reads_prose_files(tmp_path):
    """Prose files are collected; a ``.py`` file is left out."""
    contents = {
        "a.md": "hello world",
        "b.txt": "more prose",
        "c.py": "import os",  # not prose, skipped
    }
    for filename, body in contents.items():
        (tmp_path / filename).write_text(body)
    collected = collect_corpus_text(str(tmp_path))
    assert "hello world" in collected
    assert "more prose" in collected
    assert "import os" not in collected
|
||||||
|
|
||||||
|
|
||||||
|
def test_collect_corpus_text_prefers_recent(tmp_path):
    """With ``max_files=1`` only the most recently modified file is read."""
    import os

    old = tmp_path / "old.md"
    old.write_text("OLD_CONTENT")
    new = tmp_path / "new.md"
    new.write_text("NEW_CONTENT")
    # Backdate `old` relative to `new` explicitly. This makes the ordering
    # deterministic without sleeping or relying on filesystem timestamp
    # resolution.
    old_mtime = new.stat().st_mtime - 3600
    os.utime(old, (old_mtime, old_mtime))

    text = collect_corpus_text(str(tmp_path), max_files=1)
    assert "NEW_CONTENT" in text
    assert "OLD_CONTENT" not in text
|
||||||
|
|
||||||
|
|
||||||
|
def test_collect_corpus_text_missing_dir_returns_empty(tmp_path):
    """A path that does not exist yields an empty string rather than an error."""
    missing = tmp_path / "nope"
    assert collect_corpus_text(str(missing)) == ""
|
||||||
|
|
||||||
|
|
||||||
|
def test_collect_corpus_text_caps_bytes_per_file(tmp_path):
    """A file larger than the cap contributes exactly the capped amount."""
    big = tmp_path / "big.md"
    big.write_text("x" * 100_000)
    text = collect_corpus_text(str(tmp_path), max_files=1, max_bytes_per_file=500)
    # One ASCII-only file: no joining newline is added, and 500 characters
    # are 500 bytes, so the cap is met exactly.
    assert text == "x" * 500
|
||||||
Reference in New Issue
Block a user