#1148, #1150, and #1157 were reviewed and merged on GitHub, but the two stacked children landed on their parent feature branches (now stale) rather than on develop. Only #1148's commits reached develop via the direct merge. Release PR #1159 (develop → main for v3.3.3) is therefore missing the LLM refinement, Claude-conversation scanner, and miner-registry wire-up that were ostensibly part of the release. This merge brings the stale `feat/llm-entity-refine` branch (which contains the rolled-up merge commit for #1157 → #1150 → everything below) into develop so the release tag includes it. No code changes here — only history recovery.
This commit is contained in:
@@ -0,0 +1,218 @@
|
||||
"""Tests for mempalace.convo_scanner."""
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
from mempalace.convo_scanner import (
|
||||
_decode_slug_fallback,
|
||||
_extract_cwd_from_session,
|
||||
_resolve_project_name,
|
||||
_safe_mtime,
|
||||
is_claude_projects_root,
|
||||
scan_claude_projects,
|
||||
)
|
||||
|
||||
|
||||
# ── is_claude_projects_root ─────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_is_claude_projects_root_true(tmp_path):
    """A dash-prefixed child directory holding a .jsonl file marks the root."""
    encoded = tmp_path / "-home-user-dev-foo"
    encoded.mkdir()
    session_file = encoded / "abc.jsonl"
    session_file.write_text("{}\n")
    detected = is_claude_projects_root(tmp_path)
    assert detected
|
||||
|
||||
|
||||
def test_is_claude_projects_root_false_no_dash_prefix(tmp_path):
    """A child directory without the leading dash does not qualify."""
    plain = tmp_path / "normal-folder"
    plain.mkdir()
    plain.joinpath("abc.jsonl").write_text("{}\n")
    detected = is_claude_projects_root(tmp_path)
    assert not detected
|
||||
|
||||
|
||||
def test_is_claude_projects_root_false_no_jsonl(tmp_path):
    """A dash-prefixed directory with no .jsonl file inside is rejected."""
    encoded = tmp_path / "-home-user-foo"
    encoded.mkdir()
    encoded.joinpath("other.txt").write_text("hello")
    detected = is_claude_projects_root(tmp_path)
    assert not detected
|
||||
|
||||
|
||||
def test_is_claude_projects_root_false_empty(tmp_path):
    """An empty directory is not treated as a projects root."""
    detected = is_claude_projects_root(tmp_path)
    assert not detected
|
||||
|
||||
|
||||
def test_is_claude_projects_root_false_nonexistent(tmp_path):
    """A path that does not exist yields a falsy result."""
    missing = tmp_path / "does-not-exist"
    detected = is_claude_projects_root(missing)
    assert not detected
|
||||
|
||||
|
||||
# ── cwd extraction ──────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_extract_cwd_from_session(tmp_path):
    """The cwd is taken from the second record when the first carries none."""
    records = (
        {"type": "file-history-snapshot", "messageId": "x"},
        {"type": "user", "cwd": "/home/user/dev/myproj", "content": "hi"},
    )
    session = tmp_path / "session.jsonl"
    # One JSON document per line, each newline-terminated.
    session.write_text("".join(json.dumps(rec) + "\n" for rec in records))
    assert _extract_cwd_from_session(session) == "/home/user/dev/myproj"
|
||||
|
||||
|
||||
def test_extract_cwd_from_session_skips_malformed(tmp_path):
    """A line that is not valid JSON is skipped and the next record is used."""
    good_record = json.dumps({"type": "user", "cwd": "/home/user/dev/good"})
    session = tmp_path / "session.jsonl"
    session.write_text("{not valid json\n" + good_record + "\n")
    extracted = _extract_cwd_from_session(session)
    assert extracted == "/home/user/dev/good"
|
||||
|
||||
|
||||
def test_extract_cwd_from_session_none_if_absent(tmp_path):
    """When no record has a cwd key the result is None."""
    payload = json.dumps({"type": "x", "messageId": "y"})
    session = tmp_path / "session.jsonl"
    session.write_text(payload + "\n")
    extracted = _extract_cwd_from_session(session)
    assert extracted is None
|
||||
|
||||
|
||||
def test_extract_cwd_from_session_none_if_file_missing(tmp_path):
    """A session path that does not exist yields None."""
    absent = tmp_path / "missing.jsonl"
    extracted = _extract_cwd_from_session(absent)
    assert extracted is None
|
||||
|
||||
|
||||
# ── slug fallback ───────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_decode_slug_fallback_last_segment():
    """The final dash-separated segment of the slug is returned."""
    decoded = _decode_slug_fallback("-home-user-dev-foo")
    assert decoded == "foo"
|
||||
|
||||
|
||||
def test_decode_slug_fallback_double_dash():
    """A doubled dash before the last segment does not leak into the result."""
    decoded = _decode_slug_fallback("-home-user--bentokit")
    assert decoded == "bentokit"
|
||||
|
||||
|
||||
def test_decode_slug_fallback_empty():
    """An empty slug decodes to an empty string."""
    decoded = _decode_slug_fallback("")
    assert decoded == ""
|
||||
|
||||
|
||||
def test_decode_slug_fallback_only_dashes():
    """A slug made only of dashes is returned unchanged."""
    slug = "---"
    decoded = _decode_slug_fallback(slug)
    assert decoded == "---"
|
||||
|
||||
|
||||
# ── safe metadata helpers ───────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_safe_mtime_returns_zero_on_stat_error(tmp_path, monkeypatch):
    """_safe_mtime returns 0.0 when stat() on the file raises OSError.

    ``Path.stat`` is patched on the class, so the replacement is hit by every
    ``Path`` in the process while the patch is active. It therefore accepts
    and forwards any arguments (``follow_symlinks`` exists on Python 3.10+)
    instead of assuming a bare ``stat()`` call.
    """
    f = tmp_path / "session.jsonl"
    f.write_text("{}\n")
    original_stat = Path.stat

    def fail_stat(self, *args, **kwargs):
        # Only the file under test fails; every other path stats normally.
        if self == f:
            raise OSError("permission denied")
        return original_stat(self, *args, **kwargs)

    monkeypatch.setattr(Path, "stat", fail_stat)
    assert _safe_mtime(f) == 0.0
|
||||
|
||||
|
||||
# ── _resolve_project_name ───────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_resolve_project_name_uses_cwd(tmp_path):
    """The name comes from the last component of the session's cwd, not the slug."""
    project_dir = tmp_path / "-home-user-dev-coolproj"
    project_dir.mkdir()
    record = {"type": "user", "cwd": "/home/user/dev/cool-proj-real"}
    (project_dir / "a.jsonl").write_text(json.dumps(record) + "\n")
    resolved = _resolve_project_name(project_dir)
    assert resolved == "cool-proj-real"
|
||||
|
||||
|
||||
def test_resolve_project_name_falls_back_when_no_cwd(tmp_path):
    """Without a cwd in the session, the name is decoded from the directory slug."""
    project_dir = tmp_path / "-home-user-dev-foo"
    project_dir.mkdir()
    session = project_dir / "a.jsonl"
    session.write_text(json.dumps({"type": "x"}) + "\n")
    resolved = _resolve_project_name(project_dir)
    assert resolved == "foo"
|
||||
|
||||
|
||||
def test_resolve_project_name_prefers_newer_session(tmp_path):
    """The cwd of the most recently modified session decides the name.

    Covers a project directory that was renamed between two sessions.
    """
    import os

    project_dir = tmp_path / "-home-user-dev-old"
    project_dir.mkdir()

    stale = project_dir / "old.jsonl"
    stale.write_text(json.dumps({"type": "user", "cwd": "/home/user/dev/old"}) + "\n")
    # Back-date the first session by 100 seconds so the two mtimes differ.
    backdated = stale.stat().st_mtime - 100
    os.utime(stale, (backdated, backdated))

    fresh = project_dir / "new.jsonl"
    fresh.write_text(json.dumps({"type": "user", "cwd": "/home/user/dev/new-name"}) + "\n")

    resolved = _resolve_project_name(project_dir)
    assert resolved == "new-name"
|
||||
|
||||
|
||||
# ── scan_claude_projects ────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_scan_claude_projects_empty_dir(tmp_path):
    """Scanning an empty directory produces an empty list."""
    found = scan_claude_projects(tmp_path)
    assert found == []
|
||||
|
||||
|
||||
def test_scan_claude_projects_not_a_projects_root(tmp_path):
    """A directory holding only an ordinary sub-folder yields no projects."""
    ordinary = tmp_path / "some-folder"
    ordinary.mkdir()
    (ordinary / "readme.md").write_text("hi")
    found = scan_claude_projects(tmp_path)
    assert found == []
|
||||
|
||||
|
||||
def test_scan_claude_projects_finds_projects(tmp_path):
    """Both project dirs are reported, each with user_commits equal to its session count."""
    layout = {
        "-home-user-dev-alpha": ("/home/user/dev/alpha", ("a.jsonl", "b.jsonl")),
        "-home-user-dev-beta": ("/home/user/dev/beta", ("x.jsonl",)),
    }
    for slug, (cwd, session_names) in layout.items():
        project_dir = tmp_path / slug
        project_dir.mkdir()
        for session_name in session_names:
            record = json.dumps({"type": "user", "cwd": cwd})
            (project_dir / session_name).write_text(record + "\n")

    found = scan_claude_projects(tmp_path)
    seen_names = [proj.name for proj in found]
    assert "alpha" in seen_names
    assert "beta" in seen_names

    # alpha was given two session files, beta one.
    alpha = next(proj for proj in found if proj.name == "alpha")
    beta = next(proj for proj in found if proj.name == "beta")
    assert alpha.user_commits == 2
    assert beta.user_commits == 1
|
||||
|
||||
|
||||
def test_scan_claude_projects_ignores_dirs_without_jsonl(tmp_path):
    """A dash-prefixed directory with no .jsonl sessions is not reported."""
    sessionless = tmp_path / "-home-user-dev-empty"
    sessionless.mkdir()
    sessionless.joinpath("notes.md").write_text("hi")
    found = scan_claude_projects(tmp_path)
    assert found == []
|
||||
|
||||
|
||||
def test_scan_claude_projects_marks_as_mine(tmp_path):
    """The single scanned project is reported with is_mine set to True."""
    owned_dir = tmp_path / "-home-user-dev-owned"
    owned_dir.mkdir()
    record = json.dumps({"type": "user", "cwd": "/home/user/dev/owned"})
    (owned_dir / "s.jsonl").write_text(record + "\n")
    projects = scan_claude_projects(tmp_path)
    assert len(projects) == 1
    only = projects[0]
    assert only.is_mine is True
|
||||
|
||||
|
||||
def test_scan_claude_projects_dedup_by_name(tmp_path):
    """Two encoded dirs that resolve to one project name collapse into one entry."""
    sessions = (
        ("-home-user-a-proj", "s.jsonl", "/home/user/a/proj"),
        ("-home-user-a-proj", "t.jsonl", "/home/user/a/proj"),
        ("-home-user-b-proj", "u.jsonl", "/home/user/b/proj"),
    )
    for slug, filename, cwd in sessions:
        project_dir = tmp_path / slug
        project_dir.mkdir(exist_ok=True)
        record = json.dumps({"type": "user", "cwd": cwd})
        (project_dir / filename).write_text(record + "\n")

    found = scan_claude_projects(tmp_path)
    # Both resolve to "proj"; the survivor is the one with two sessions.
    assert len(found) == 1
    survivor = found[0]
    assert survivor.name == "proj"
    assert survivor.user_commits == 2
|
||||
@@ -0,0 +1,208 @@
|
||||
"""Tests for mempalace.miner.add_to_known_entities.
|
||||
|
||||
Covers the init → miner wire-up: init's confirmed entities merged into
|
||||
``~/.mempalace/known_entities.json`` so the miner's drawer-tagging path
|
||||
recognizes them at mine time.
|
||||
|
||||
Every test redirects the registry path to a tmp_path to avoid touching
|
||||
the real ~/.mempalace/ on the developer's machine.
|
||||
"""
|
||||
|
||||
import json
|
||||
|
||||
import pytest
|
||||
|
||||
from mempalace import miner
|
||||
|
||||
|
||||
@pytest.fixture
def temp_registry(tmp_path, monkeypatch):
    """Redirect the module-level registry path to a tmp file and reset cache.

    Both the path and the three cache entries are set through ``monkeypatch``
    so they are put back when the test finishes. Mutating the shared cache
    dict in place would leave data read from the temporary registry visible
    to later tests.

    Returns the ``Path`` of the temporary registry file, which this fixture
    does not create.
    """
    registry = tmp_path / "known_entities.json"
    monkeypatch.setattr(miner, "_ENTITY_REGISTRY_PATH", str(registry))
    # setitem records each previous value and restores it on teardown.
    for key, value in (("mtime", None), ("names", frozenset()), ("raw", {})):
        monkeypatch.setitem(miner._ENTITY_REGISTRY_CACHE, key, value)
    return registry
|
||||
|
||||
|
||||
# ── fresh-file cases ────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_creates_registry_when_absent(temp_registry):
    """The first call writes the registry file with the supplied categories."""
    assert not temp_registry.exists()
    entities = {"people": ["Alice", "Bob"], "projects": ["foo"]}
    miner.add_to_known_entities(entities)
    assert temp_registry.exists()
    written = json.loads(temp_registry.read_text())
    assert sorted(written["people"]) == ["Alice", "Bob"]
    assert written["projects"] == ["foo"]
|
||||
|
||||
|
||||
def test_returns_registry_path(temp_registry):
    """The return value is the registry path as a string."""
    expected = str(temp_registry)
    returned = miner.add_to_known_entities({"people": ["Alice"]})
    assert returned == expected
|
||||
|
||||
|
||||
def test_empty_input_still_creates_file(temp_registry):
    """An empty merge must not raise and must not add any entries.

    Whether the registry file is written for a truly empty call is left
    unspecified. If it is written, every category in it must be empty.
    """
    miner.add_to_known_entities({})
    # File may or may not be written for a truly empty call — tolerate either.
    if temp_registry.exists():
        data = json.loads(temp_registry.read_text())
        assert data == {} or all(not v for v in data.values())
|
||||
|
||||
|
||||
def test_skips_empty_name_strings(temp_registry):
    """Empty-string and None entries are dropped from the written list."""
    noisy_input = {"people": ["Alice", "", None]}
    miner.add_to_known_entities(noisy_input)
    written = json.loads(temp_registry.read_text())
    assert written["people"] == ["Alice"]
|
||||
|
||||
|
||||
# ── union / dedup cases ────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_unions_with_existing_list_category(temp_registry):
    """New names are appended to an existing list without duplicating old ones."""
    seed = {"people": ["Alice", "Bob"]}
    temp_registry.write_text(json.dumps(seed))
    miner.add_to_known_entities({"people": ["Bob", "Carol"]})
    merged = json.loads(temp_registry.read_text())
    # Existing order is kept, Bob appears once, Carol lands at the end.
    assert merged["people"] == ["Alice", "Bob", "Carol"]
|
||||
|
||||
|
||||
def test_case_insensitive_dedup_preserves_first_seen_variant(temp_registry):
    """Case variants of an existing name are ignored; the stored spelling stays."""
    seed = {"people": ["Alice"]}
    temp_registry.write_text(json.dumps(seed))
    miner.add_to_known_entities({"people": ["alice", "ALICE", "Bob"]})
    merged = json.loads(temp_registry.read_text())
    # "Alice" is untouched and neither "alice" nor "ALICE" becomes an entry.
    assert merged["people"] == ["Alice", "Bob"]
|
||||
|
||||
|
||||
def test_preserves_untouched_categories(temp_registry):
    """Categories absent from the call's input survive the merge unchanged."""
    seed = {"people": ["Alice"], "places": ["Paris", "Tokyo"]}
    temp_registry.write_text(json.dumps(seed))
    miner.add_to_known_entities({"people": ["Bob"]})
    merged = json.loads(temp_registry.read_text())
    assert merged["places"] == ["Paris", "Tokyo"]
    assert merged["people"] == ["Alice", "Bob"]
|
||||
|
||||
|
||||
def test_adds_new_categories(temp_registry):
    """A category missing from the registry is created next to existing ones."""
    seed = {"people": ["Alice"]}
    temp_registry.write_text(json.dumps(seed))
    miner.add_to_known_entities({"projects": ["foo", "bar"]})
    merged = json.loads(temp_registry.read_text())
    assert merged["people"] == ["Alice"]
    assert merged["projects"] == ["foo", "bar"]
|
||||
|
||||
|
||||
def test_dedupes_within_input(temp_registry):
    """Repeated names inside one input list are written only once."""
    repeated = {"people": ["Alice", "alice", "Alice"]}
    miner.add_to_known_entities(repeated)
    written = json.loads(temp_registry.read_text())
    assert written["people"] == ["Alice"]
|
||||
|
||||
|
||||
# ── dict-format existing registry ──────────────────────────────────────
|
||||
|
||||
|
||||
def test_dict_format_existing_category_gets_new_keys(temp_registry):
    """A {name: code} category gains new names as keys and keeps existing codes."""
    seed = {"people": {"Alice": "ALC", "Bob": "BOB"}}
    temp_registry.write_text(json.dumps(seed))
    miner.add_to_known_entities({"people": ["Alice", "Carol"]})
    people = json.loads(temp_registry.read_text())["people"]
    # Existing codes are kept; the new name is stored with a None code.
    assert people["Alice"] == "ALC"
    assert people["Bob"] == "BOB"
    assert "Carol" in people
    assert people["Carol"] is None
|
||||
|
||||
|
||||
def test_dict_format_dedupes_case_insensitively_and_stringifies_new_names(temp_registry):
    """In dict form, case variants are skipped and non-string names become string keys."""
    seed = {"people": {"Alice": "ALC"}}
    temp_registry.write_text(json.dumps(seed))
    miner.add_to_known_entities({"people": ["alice", 123]})
    merged = json.loads(temp_registry.read_text())
    assert merged["people"] == {"Alice": "ALC", "123": None}
|
||||
|
||||
|
||||
# ── error tolerance ───────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_malformed_existing_registry_starts_fresh(temp_registry):
    """Unparseable registry content is discarded and replaced by the new data."""
    temp_registry.write_text("{ not valid json")
    miner.add_to_known_entities({"people": ["Alice"]})
    rewritten = json.loads(temp_registry.read_text())
    assert rewritten == {"people": ["Alice"]}
|
||||
|
||||
|
||||
def test_non_dict_existing_registry_starts_fresh(temp_registry):
    """A registry whose top-level JSON is an array is replaced by the new data."""
    unexpected = ["unexpected", "array"]
    temp_registry.write_text(json.dumps(unexpected))
    miner.add_to_known_entities({"people": ["Alice"]})
    rewritten = json.loads(temp_registry.read_text())
    assert rewritten == {"people": ["Alice"]}
|
||||
|
||||
|
||||
def test_non_list_input_category_ignored(temp_registry):
    """A non-list category value does not stop the valid category being written."""
    mixed_input = {"people": ["Alice"], "weird": "not a list"}
    miner.add_to_known_entities(mixed_input)
    written = json.loads(temp_registry.read_text())
    # Either the odd category is dropped or it is stored verbatim.
    assert "weird" not in written or written.get("weird") == "not a list"
    assert written["people"] == ["Alice"]
|
||||
|
||||
|
||||
# ── cache invalidation ───────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_cache_invalidated_so_subsequent_load_sees_write(temp_registry):
    """A load after a write in the same process returns the newly written names.

    cmd_init and cmd_mine share one process, so a primed cache must not hide
    what init has just written.
    """
    # Prime the cache while the registry is still empty.
    miner._load_known_entities()
    assert miner._load_known_entities() == frozenset()

    miner.add_to_known_entities({"people": ["Alice", "Bob"], "projects": ["foo"]})

    names = miner._load_known_entities()
    assert "Alice" in names
    assert "Bob" in names
    assert "foo" in names
|
||||
|
||||
|
||||
def test_raw_view_reflects_write(temp_registry):
    """The raw registry view exposes the category that was just written."""
    miner.add_to_known_entities({"people": ["Alice"]})
    raw_view = miner._load_known_entities_raw()
    people = raw_view.get("people")
    assert people == ["Alice"]
|
||||
|
||||
|
||||
# ── Unicode round-trip ────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_unicode_names_written_literally_not_escaped(temp_registry):
    """Non-ASCII names appear literally in the file text and survive a JSON round trip."""
    names = ["Gergő Móricz", "Arturo Domínguez"]
    miner.add_to_known_entities({"people": names})
    on_disk = temp_registry.read_text(encoding="utf-8")
    assert "Gergő" in on_disk
    assert "Móricz" in on_disk
    # Parsing the same text gives the full name back.
    parsed = json.loads(on_disk)
    assert "Gergő Móricz" in parsed["people"]
|
||||
|
||||
|
||||
# ── end-to-end: does the write actually help _extract_entities_for_metadata? ──
|
||||
|
||||
|
||||
def test_populated_registry_improves_miner_recall(temp_registry):
    """Names written via add_to_known_entities show up in the miner's
    entity-extraction metadata string for text that mentions them."""
    registered = {
        "people": ["Julia Grib", "Kevin Heifner"],
        "projects": ["hyperion-history", "mempalace"],
    }
    miner.add_to_known_entities(registered)

    sample = (
        "Met with Julia Grib yesterday about the mempalace release. "
        "Kevin Heifner pushed the hyperion-history fix."
    )
    metadata = miner._extract_entities_for_metadata(sample)
    if metadata:
        tagged = set(metadata.split(";"))
    else:
        tagged = set()

    # Each of the four registered entities must be one of the ";"-separated tags.
    for expected in ("Julia Grib", "Kevin Heifner", "hyperion-history", "mempalace"):
        assert expected in tagged, f"expected '{expected}' in metadata {tagged!r}"
|
||||
@@ -0,0 +1,327 @@
|
||||
"""Tests for mempalace.llm_client.
|
||||
|
||||
HTTP is mocked throughout — these tests do not require a running Ollama
|
||||
or network access. Live-provider smoke tests live outside the unit-test
|
||||
suite.
|
||||
"""
|
||||
|
||||
import json
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from mempalace.llm_client import (
|
||||
AnthropicProvider,
|
||||
LLMError,
|
||||
OllamaProvider,
|
||||
OpenAICompatProvider,
|
||||
_http_post_json,
|
||||
get_provider,
|
||||
)
|
||||
|
||||
|
||||
# ── factory ─────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_get_provider_ollama():
    """The "ollama" name builds an OllamaProvider with the default endpoint."""
    provider = get_provider("ollama", "gemma4:e4b")
    assert isinstance(provider, OllamaProvider)
    assert provider.model == "gemma4:e4b"
    assert provider.endpoint == OllamaProvider.DEFAULT_ENDPOINT
|
||||
|
||||
|
||||
def test_get_provider_openai_compat():
    """The "openai-compat" name builds an OpenAICompatProvider."""
    provider = get_provider("openai-compat", "foo", endpoint="http://localhost:1234")
    assert isinstance(provider, OpenAICompatProvider)
|
||||
|
||||
|
||||
def test_get_provider_anthropic():
    """The "anthropic" name builds an AnthropicProvider carrying the given key."""
    provider = get_provider("anthropic", "claude-haiku", api_key="sk-xxx")
    assert isinstance(provider, AnthropicProvider)
    assert provider.api_key == "sk-xxx"
|
||||
|
||||
|
||||
def test_get_provider_unknown_raises():
    """An unrecognised provider name raises LLMError mentioning "Unknown provider"."""
    expect_unknown = pytest.raises(LLMError, match="Unknown provider")
    with expect_unknown:
        get_provider("nonsense", "x")
|
||||
|
||||
|
||||
# ── _http_post_json ─────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_http_post_json_success():
    """A readable JSON response body is parsed and returned."""
    response = MagicMock()
    response.__enter__.return_value = response
    response.__exit__.return_value = False
    response.read.return_value = b'{"ok": true}'
    with patch("mempalace.llm_client.urlopen", return_value=response):
        decoded = _http_post_json("http://x/y", {"a": 1}, {}, timeout=5)
    assert decoded == {"ok": True}
|
||||
|
||||
|
||||
def test_http_post_json_http_error_wraps_as_llm_error():
    """An HTTPError from urlopen surfaces as LLMError mentioning the status code."""
    import io
    from urllib.error import HTTPError

    error_body = io.BytesIO(b"model missing")
    not_found = HTTPError("http://x", 404, "Not Found", {}, error_body)
    with patch("mempalace.llm_client.urlopen", side_effect=not_found):
        with pytest.raises(LLMError, match="HTTP 404"):
            _http_post_json("http://x", {}, {}, timeout=5)
|
||||
|
||||
|
||||
def test_http_post_json_url_error_wraps_as_llm_error():
    """A URLError from urlopen surfaces as LLMError mentioning "Cannot reach"."""
    from urllib.error import URLError

    refused = URLError("conn refused")
    with patch("mempalace.llm_client.urlopen", side_effect=refused):
        with pytest.raises(LLMError, match="Cannot reach"):
            _http_post_json("http://x", {}, {}, timeout=5)
|
||||
|
||||
|
||||
def test_http_post_json_malformed_response():
    """A response body that is not JSON raises LLMError mentioning "Malformed"."""
    response = MagicMock()
    response.__enter__.return_value = response
    response.__exit__.return_value = False
    response.read.return_value = b"not json"
    with patch("mempalace.llm_client.urlopen", return_value=response):
        with pytest.raises(LLMError, match="Malformed"):
            _http_post_json("http://x", {}, {}, timeout=5)
|
||||
|
||||
|
||||
# ── OllamaProvider ──────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def _mock_ollama_chat_response(content: str):
|
||||
mock = MagicMock()
|
||||
mock.read.return_value = json.dumps({"message": {"content": content}}).encode()
|
||||
mock.__enter__.return_value = mock
|
||||
mock.__exit__.return_value = False
|
||||
return mock
|
||||
|
||||
|
||||
def test_ollama_check_available_finds_model():
    """check_available returns (truthy, "ok") when the model is in the tag listing."""
    listing = {"models": [{"name": "gemma4:e4b"}, {"name": "other:latest"}]}
    response = MagicMock()
    response.__enter__.return_value = response
    response.__exit__.return_value = False
    response.read.return_value = json.dumps(listing).encode()
    with patch("mempalace.llm_client.urlopen", return_value=response):
        provider = OllamaProvider(model="gemma4:e4b")
        available, message = provider.check_available()
    assert available
    assert message == "ok"
|
||||
|
||||
|
||||
def test_ollama_check_available_accepts_latest_suffix():
    """A bare model name matches a listed "<name>:latest" tag."""
    listing = {"models": [{"name": "mymodel:latest"}]}
    response = MagicMock()
    response.__enter__.return_value = response
    response.__exit__.return_value = False
    response.read.return_value = json.dumps(listing).encode()
    with patch("mempalace.llm_client.urlopen", return_value=response):
        provider = OllamaProvider(model="mymodel")
        available, _message = provider.check_available()
    assert available
|
||||
|
||||
|
||||
def test_ollama_check_available_missing_model():
    """A model absent from the listing is unavailable; the message suggests a pull."""
    listing = {"models": [{"name": "other:latest"}]}
    response = MagicMock()
    response.__enter__.return_value = response
    response.__exit__.return_value = False
    response.read.return_value = json.dumps(listing).encode()
    with patch("mempalace.llm_client.urlopen", return_value=response):
        provider = OllamaProvider(model="absent")
        available, message = provider.check_available()
    assert not available
    assert "ollama pull absent" in message
|
||||
|
||||
|
||||
def test_ollama_check_available_unreachable():
    """A URLError from urlopen makes check_available report "Cannot reach Ollama"."""
    from urllib.error import URLError

    refused = URLError("refused")
    with patch("mempalace.llm_client.urlopen", side_effect=refused):
        provider = OllamaProvider(model="gemma4:e4b")
        available, message = provider.check_available()
    assert not available
    assert "Cannot reach Ollama" in message
|
||||
|
||||
|
||||
def test_ollama_classify_sends_json_format():
    """classify in JSON mode posts format="json" and the model to /api/chat."""
    seen = {}

    def fake_urlopen(req, *, timeout):
        # Record the outgoing request so it can be inspected after the call.
        seen["url"] = req.full_url
        seen["body"] = json.loads(req.data.decode())
        return _mock_ollama_chat_response('{"classifications": []}')

    with patch("mempalace.llm_client.urlopen", side_effect=fake_urlopen):
        provider = OllamaProvider(model="gemma4:e4b")
        reply = provider.classify("sys", "user", json_mode=True)

    sent_body = seen["body"]
    assert sent_body["format"] == "json"
    assert sent_body["model"] == "gemma4:e4b"
    assert seen["url"].endswith("/api/chat")
    assert reply.provider == "ollama"
    assert reply.text == '{"classifications": []}'
|
||||
|
||||
|
||||
def test_ollama_classify_empty_content_raises():
    """An empty message content raises LLMError mentioning "Empty response"."""
    empty_reply = _mock_ollama_chat_response("")
    with patch("mempalace.llm_client.urlopen", return_value=empty_reply):
        provider = OllamaProvider(model="x")
        with pytest.raises(LLMError, match="Empty response"):
            provider.classify("s", "u")
|
||||
|
||||
|
||||
# ── OpenAICompatProvider ────────────────────────────────────────────────
|
||||
|
||||
|
||||
def _mock_openai_response(content: str):
|
||||
mock = MagicMock()
|
||||
payload = {"choices": [{"message": {"content": content}}]}
|
||||
mock.read.return_value = json.dumps(payload).encode()
|
||||
mock.__enter__.return_value = mock
|
||||
mock.__exit__.return_value = False
|
||||
return mock
|
||||
|
||||
|
||||
def test_openai_compat_resolves_url_with_v1_suffix():
    """An endpoint without /v1 is requested at <endpoint>/v1/chat/completions."""
    seen = {}

    def fake_urlopen(req, *, timeout):
        seen["url"] = req.full_url
        return _mock_openai_response('{"ok": true}')

    with patch("mempalace.llm_client.urlopen", side_effect=fake_urlopen):
        provider = OpenAICompatProvider(model="x", endpoint="http://h:1234")
        provider.classify("s", "u")
    requested = seen["url"]
    assert requested == "http://h:1234/v1/chat/completions"
|
||||
|
||||
|
||||
def test_openai_compat_resolves_url_with_existing_v1():
    """An endpoint already ending in /v1 does not get a second /v1 segment."""
    seen = {}

    def fake_urlopen(req, *, timeout):
        seen["url"] = req.full_url
        return _mock_openai_response('{"ok": true}')

    with patch("mempalace.llm_client.urlopen", side_effect=fake_urlopen):
        provider = OpenAICompatProvider(model="x", endpoint="http://h:1234/v1")
        provider.classify("s", "u")
    requested = seen["url"]
    assert requested == "http://h:1234/v1/chat/completions"
|
||||
|
||||
|
||||
def test_openai_compat_requires_endpoint():
    """classify without an endpoint raises LLMError naming --llm-endpoint."""
    provider = OpenAICompatProvider(model="x")
    expect_missing = pytest.raises(LLMError, match="requires --llm-endpoint")
    with expect_missing:
        provider.classify("s", "u")
|
||||
|
||||
|
||||
def test_openai_compat_sends_authorization_when_key_present():
    """A configured api_key is sent as a Bearer Authorization header."""
    seen = {}

    def fake_urlopen(req, *, timeout):
        seen["auth"] = req.get_header("Authorization")
        return _mock_openai_response('{"ok": true}')

    with patch("mempalace.llm_client.urlopen", side_effect=fake_urlopen):
        provider = OpenAICompatProvider(model="x", endpoint="http://h", api_key="sk-aaa")
        provider.classify("s", "u")
    sent_header = seen["auth"]
    assert sent_header == "Bearer sk-aaa"
|
||||
|
||||
|
||||
def test_openai_compat_uses_env_var_fallback(monkeypatch):
    """With no explicit key, api_key is read from OPENAI_API_KEY."""
    env_key = "sk-from-env"
    monkeypatch.setenv("OPENAI_API_KEY", env_key)
    provider = OpenAICompatProvider(model="x", endpoint="http://h")
    assert provider.api_key == "sk-from-env"
|
||||
|
||||
|
||||
def test_openai_compat_sends_response_format_json():
    """JSON mode adds response_format {"type": "json_object"} to the request body."""
    seen = {}

    def fake_urlopen(req, *, timeout):
        seen["body"] = json.loads(req.data.decode())
        return _mock_openai_response('{"ok": true}')

    with patch("mempalace.llm_client.urlopen", side_effect=fake_urlopen):
        provider = OpenAICompatProvider(model="x", endpoint="http://h")
        provider.classify("s", "u", json_mode=True)
    sent_body = seen["body"]
    assert sent_body["response_format"] == {"type": "json_object"}
|
||||
|
||||
|
||||
def test_openai_compat_unexpected_shape_raises():
    """A JSON reply without the expected structure raises LLMError."""
    response = MagicMock()
    response.__enter__.return_value = response
    response.__exit__.return_value = False
    response.read.return_value = b'{"nothing": "here"}'
    with patch("mempalace.llm_client.urlopen", return_value=response):
        provider = OpenAICompatProvider(model="x", endpoint="http://h")
        with pytest.raises(LLMError, match="Unexpected response shape"):
            provider.classify("s", "u")
|
||||
|
||||
|
||||
# ── AnthropicProvider ───────────────────────────────────────────────────
|
||||
|
||||
|
||||
def _mock_anthropic_response(text: str):
|
||||
mock = MagicMock()
|
||||
payload = {"content": [{"type": "text", "text": text}]}
|
||||
mock.read.return_value = json.dumps(payload).encode()
|
||||
mock.__enter__.return_value = mock
|
||||
mock.__exit__.return_value = False
|
||||
return mock
|
||||
|
||||
|
||||
def test_anthropic_requires_api_key(monkeypatch):
    """With no key passed or in the environment, check_available fails and names the variable."""
    monkeypatch.delenv("ANTHROPIC_API_KEY", raising=False)
    provider = AnthropicProvider(model="claude-haiku")
    available, message = provider.check_available()
    assert not available
    assert "ANTHROPIC_API_KEY" in message
|
||||
|
||||
|
||||
def test_anthropic_reads_env_key(monkeypatch):
    """api_key falls back to ANTHROPIC_API_KEY, which makes the provider available."""
    env_key = "sk-ant-env"
    monkeypatch.setenv("ANTHROPIC_API_KEY", env_key)
    provider = AnthropicProvider(model="claude-haiku")
    assert provider.api_key == "sk-ant-env"
    available, _message = provider.check_available()
    assert available
|
||||
|
||||
|
||||
def test_anthropic_classify_sends_version_and_key():
    """classify sends the API key and API version headers and returns the reply text."""
    seen = {}

    def fake_urlopen(req, *, timeout):
        # Header names as this test looks them up on the request object.
        seen["api_key"] = req.get_header("X-api-key")
        seen["version"] = req.get_header("Anthropic-version")
        return _mock_anthropic_response('{"ok": true}')

    with patch("mempalace.llm_client.urlopen", side_effect=fake_urlopen):
        provider = AnthropicProvider(model="claude-haiku", api_key="sk-ant-abc")
        reply = provider.classify("s", "u")
    assert seen["api_key"] == "sk-ant-abc"
    assert seen["version"] == AnthropicProvider.API_VERSION
    assert reply.text == '{"ok": true}'
|
||||
|
||||
|
||||
def test_anthropic_joins_multiple_text_blocks():
    """Several text blocks in one reply are concatenated in order."""
    blocks = [
        {"type": "text", "text": "part one. "},
        {"type": "text", "text": "part two."},
    ]
    response = MagicMock()
    response.__enter__.return_value = response
    response.__exit__.return_value = False
    response.read.return_value = json.dumps({"content": blocks}).encode()
    with patch("mempalace.llm_client.urlopen", return_value=response):
        provider = AnthropicProvider(model="claude-haiku", api_key="sk-ant")
        reply = provider.classify("s", "u")
    assert reply.text == "part one. part two."
|
||||
|
||||
|
||||
def test_anthropic_no_key_raises_on_classify(monkeypatch):
    """classify without any API key raises LLMError naming ANTHROPIC_API_KEY."""
    monkeypatch.delenv("ANTHROPIC_API_KEY", raising=False)
    provider = AnthropicProvider(model="claude-haiku")
    expect_missing = pytest.raises(LLMError, match="requires ANTHROPIC_API_KEY")
    with expect_missing:
        provider.classify("s", "u")
|
||||
@@ -0,0 +1,631 @@
|
||||
"""Tests for mempalace.llm_refine.
|
||||
|
||||
Uses a fake provider for deterministic, offline tests. No network.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
from mempalace.llm_client import LLMError, LLMResponse
|
||||
from mempalace.llm_refine import (
|
||||
_apply_classifications,
|
||||
_build_user_prompt,
|
||||
_collect_contexts,
|
||||
_extract_json_candidates,
|
||||
_is_authoritative_person,
|
||||
_is_authoritative_project,
|
||||
_parse_response,
|
||||
collect_corpus_text,
|
||||
refine_entities,
|
||||
)
|
||||
|
||||
|
||||
# ── fake provider ───────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@dataclass
class FakeProvider:
    """Offline stand-in for an LLM provider, used to keep tests deterministic.

    ``classify`` returns ``response_text`` wrapped in an ``LLMResponse`` on
    every call unless the instance is configured to raise instead.
    """

    # Text placed in the LLMResponse returned by every successful classify call.
    response_text: str = ""
    # When not None, classify raises this exception instead of returning.
    # Quoted so the union syntax is not evaluated on older interpreters.
    should_raise: "Exception | None" = None
    # Number of classify calls so far; incremented before any raise.
    call_count: int = 0
    # 1-based call number on which classify raises KeyboardInterrupt. The
    # default of -1 never equals call_count, so no interrupt is simulated.
    interrupt_on_call: int = -1

    def classify(self, system, user, json_mode=True):
        """Count the call, then raise or return the canned response.

        The arguments are accepted for signature compatibility and ignored.

        Raises:
            KeyboardInterrupt: when this is call number ``interrupt_on_call``.
                This is checked before ``should_raise``.
            Exception: ``should_raise`` itself, when it is set.
        """
        self.call_count += 1
        if self.call_count == self.interrupt_on_call:
            raise KeyboardInterrupt()
        if self.should_raise is not None:
            raise self.should_raise
        return LLMResponse(text=self.response_text, model="fake", provider="fake", raw={})

    def check_available(self):
        """Return ``(True, "ok")`` unconditionally."""
        return True, "ok"
|
||||
|
||||
|
||||
# ── _collect_contexts ───────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_collect_contexts_finds_matches():
    """At most max_lines lines are returned and each one mentions the name."""
    corpus = [
        "Something about Alice",
        "Bob said hello",
        "Alice was here",
        "Alice walked by",
    ]
    hits = _collect_contexts(corpus, "Alice", max_lines=2)
    assert len(hits) == 2
    assert all("alice" in hit.lower() for hit in hits)
|
||||
|
||||
|
||||
def test_collect_contexts_case_insensitive():
    """A lowercase mention is found when searching with a capitalised name."""
    corpus = ["lowercase alice mention"]
    hits = _collect_contexts(corpus, "Alice")
    assert hits == ["lowercase alice mention"]
|
||||
|
||||
|
||||
def test_collect_contexts_uses_token_boundaries():
    """The name ``Go`` must not match inside the word ``forgot``."""
    corpus = [
        "forgot should not match",
        "Go is a language.",
        "go-v1 shipped.",
    ]
    hits = _collect_contexts(corpus, "Go", max_lines=5)
    expected = ["Go is a language.", "go-v1 shipped."]
    assert hits == expected
|
||||
|
||||
|
||||
def test_collect_contexts_dedupes_identical_lines():
    """Repeated identical lines are only returned once."""
    corpus = ["Alice", "Alice", "Alice was here"]
    hits = _collect_contexts(corpus, "Alice", max_lines=5)
    # Three input lines, but only two distinct ones.
    assert len(hits) == 2
|
||||
|
||||
|
||||
def test_collect_contexts_truncates_long_lines():
    """An over-long matching line comes back with at most 240 characters."""
    oversized = "Alice " + ("x" * 1000)
    hits = _collect_contexts([oversized], "Alice")
    first_hit = hits[0]
    assert len(first_hit) <= 240
|
||||
|
||||
|
||||
def test_collect_contexts_no_matches():
    """When no line mentions the name the result is an empty list."""
    hits = _collect_contexts(["nothing here"], "Alice")
    assert hits == []
|
||||
|
||||
|
||||
# ── _build_user_prompt ──────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_build_user_prompt_numbers_and_includes_contexts():
    """Candidates are numbered and contexts (or a placeholder) appear in the prompt."""
    candidates = [
        ("Alice", "uncertain", ["Alice said hi"]),
        ("Bob", "project", []),
    ]
    rendered = _build_user_prompt(candidates)
    expected_fragments = (
        "1. Alice",
        "2. Bob",
        "Alice said hi",
        "(no context available)",
    )
    for fragment in expected_fragments:
        assert fragment in rendered
|
||||
|
||||
|
||||
# ── _parse_response ─────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_parse_response_canonicalizes_label():
    """A lowercase label is returned upper-cased, paired with its reason."""
    reply = '{"classifications": [{"name": "Alice", "label": "person", "reason": "x"}]}'
    parsed = _parse_response(reply, ["Alice"])
    expected = ("PERSON", "x")
    assert parsed["Alice"] == expected
|
||||
|
||||
|
||||
def test_parse_response_accepts_type_alias():
    """A reply using the key 'type' in place of 'label' is still understood."""
    reply = '{"classifications": [{"name": "Bob", "type": "PROJECT"}]}'
    parsed = _parse_response(reply, ["Bob"])
    label = parsed["Bob"][0]
    assert label == "PROJECT"
|
||||
|
||||
|
||||
def test_parse_response_maps_unknown_label_to_ambiguous():
    """An unrecognised label is reported as AMBIGUOUS."""
    reply = '{"classifications": [{"name": "X", "label": "WEIRD"}]}'
    parsed = _parse_response(reply, ["X"])
    label = parsed["X"][0]
    assert label == "AMBIGUOUS"
|
||||
|
||||
|
||||
def test_parse_response_restores_canonical_casing():
    """A lowercased name in the reply is keyed by the expected spelling."""
    reply = '{"classifications": [{"name": "mempalace", "label": "PROJECT"}]}'
    parsed = _parse_response(reply, ["MemPalace"])
    assert "MemPalace" in parsed
    label = parsed["MemPalace"][0]
    assert label == "PROJECT"
|
||||
|
||||
|
||||
def test_parse_response_strips_code_fences():
    """A reply wrapped in a Markdown json fence is parsed."""
    reply = '```json\n{"classifications": [{"name": "X", "label": "TOPIC"}]}\n```'
    parsed = _parse_response(reply, ["X"])
    label = parsed["X"][0]
    assert label == "TOPIC"
|
||||
|
||||
|
||||
def test_parse_response_extracts_json_after_prose():
    """JSON that follows an introductory sentence is still picked up."""
    reply = 'Sure, here is the JSON: {"classifications": [{"name": "X", "label": "TOPIC"}]}'
    parsed = _parse_response(reply, ["X"])
    label = parsed["X"][0]
    assert label == "TOPIC"
|
||||
|
||||
|
||||
def test_parse_response_extracts_fenced_json_after_prose():
    """A fenced JSON block that follows prose is still picked up."""
    reply = 'Sure:\n```json\n{"classifications": [{"name": "X", "label": "PROJECT"}]}\n```'
    parsed = _parse_response(reply, ["X"])
    label = parsed["X"][0]
    assert label == "PROJECT"
|
||||
|
||||
|
||||
def test_extract_json_candidates_handles_embedded_array():
    """A JSON array embedded between prose words is among the candidates."""
    reply = 'prefix [{"name": "Y", "label": "PERSON"}] suffix'
    found = _extract_json_candidates(reply)
    embedded_array = '[{"name": "Y", "label": "PERSON"}]'
    assert embedded_array in found
|
||||
|
||||
|
||||
def test_parse_response_ignores_non_json_brackets_before_payload():
    """A bracketed non-JSON word before the payload does not derail parsing."""
    reply = 'See [note] first. JSON: {"classifications": [{"name": "X", "label": "TOPIC"}]}'
    parsed = _parse_response(reply, ["X"])
    label = parsed["X"][0]
    assert label == "TOPIC"
|
||||
|
||||
|
||||
def test_parse_response_malformed_returns_empty():
    """Text containing no JSON yields an empty mapping."""
    parsed = _parse_response("not json at all", ["X"])
    assert parsed == {}
|
||||
|
||||
|
||||
def test_parse_response_accepts_top_level_list():
    """A bare list, without the wrapping object, is accepted as the reply."""
    reply = '[{"name": "Y", "label": "PERSON"}]'
    parsed = _parse_response(reply, ["Y"])
    label = parsed["Y"][0]
    assert label == "PERSON"
|
||||
|
||||
|
||||
# ── _apply_classifications ──────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_apply_classifications_moves_to_correct_bucket():
    """A PERSON decision moves an uncertain entry into ``people`` (one reclassification)."""
    foo = {
        "name": "Foo",
        "type": "project",
        "confidence": 0.8,
        "frequency": 3,
        "signals": ["old"],
    }
    alice = {"name": "Alice", "type": "uncertain", "confidence": 0.4, "frequency": 5, "signals": []}
    detected = {"people": [], "projects": [foo], "uncertain": [alice]}
    decisions = {
        "Foo": ("PROJECT", "real project name"),
        "Alice": ("PERSON", "clearly a person"),
    }

    merged, moved, removed = _apply_classifications(detected, decisions)

    assert len(merged["people"]) == 1
    promoted = merged["people"][0]
    assert promoted["name"] == "Alice"
    assert promoted["type"] == "person"
    assert moved == 1  # only Alice changed bucket: uncertain -> people
    assert removed == 0
|
||||
|
||||
|
||||
def test_apply_classifications_drops_common_word():
    """A COMMON_WORD decision removes the entry and counts as one drop."""
    never = {
        "name": "Never",
        "type": "uncertain",
        "confidence": 0.4,
        "frequency": 20,
        "signals": [],
    }
    detected = {"people": [], "projects": [], "uncertain": [never]}
    decisions = {"Never": ("COMMON_WORD", "adverb")}

    merged, _, removed = _apply_classifications(detected, decisions)

    assert removed == 1
    assert merged["uncertain"] == []
|
||||
|
||||
|
||||
def test_apply_classifications_keeps_unvisited_entries():
    """An entry without any decision stays in its bucket; nothing is counted."""
    igor = {
        "name": "Igor",
        "type": "person",
        "confidence": 0.99,
        "frequency": 100,
        "signals": ["git"],
    }
    detected = {"people": [igor], "projects": [], "uncertain": []}

    # Empty decisions: Igor was never put to the LLM.
    merged, moved, removed = _apply_classifications(detected, {})

    assert merged["people"][0]["name"] == "Igor"
    assert moved == 0
    assert removed == 0
|
||||
|
||||
|
||||
def test_apply_classifications_appends_reason_signal():
    """The moved entry's signals mention both the LLM label and its reason."""
    foo = {
        "name": "Foo",
        "type": "uncertain",
        "confidence": 0.4,
        "frequency": 5,
        "signals": ["regex"],
    }
    detected = {"people": [], "projects": [], "uncertain": [foo]}
    decisions = {"Foo": ("PERSON", "spoken of by name")}

    merged, _, _ = _apply_classifications(detected, decisions)

    for needle in ("LLM: person", "spoken of by name"):
        assert any(needle in s for s in merged["people"][0]["signals"])
|
||||
|
||||
|
||||
def test_apply_classifications_topic_goes_to_uncertain():
    """A TOPIC decision moves a project entry into ``uncertain``."""
    paris = {
        "name": "Paris",
        "type": "project",
        "confidence": 0.7,
        "frequency": 5,
        "signals": ["regex"],
    }
    detected = {"people": [], "projects": [paris], "uncertain": []}
    decisions = {"Paris": ("TOPIC", "city, not a project")}

    merged, moved, _ = _apply_classifications(detected, decisions)

    assert len(merged["projects"]) == 0
    assert len(merged["uncertain"]) == 1
    assert merged["uncertain"][0]["name"] == "Paris"
    assert moved == 1
|
||||
|
||||
|
||||
def test_apply_classifications_can_block_llm_only_project_promotion():
    """With ``allow_project_promotions=False`` a PROJECT decision leaves the entry uncertain."""
    terraform = {
        "name": "Terraform",
        "type": "uncertain",
        "confidence": 0.4,
        "frequency": 5,
        "signals": ["regex"],
    }
    detected = {"people": [], "projects": [], "uncertain": [terraform]}
    decisions = {"Terraform": ("PROJECT", "tool")}

    merged, moved, _ = _apply_classifications(
        detected,
        decisions,
        allow_project_promotions=False,
    )

    assert merged["projects"] == []
    kept = merged["uncertain"][0]
    assert kept["name"] == "Terraform"
    assert kept["type"] == "uncertain"
    assert moved == 0
|
||||
|
||||
|
||||
def test_apply_classifications_allows_project_promotion_for_prose_only_mode():
    """Called without the opt-out flag, a PROJECT decision promotes the entry."""
    aurora = {
        "name": "Project Aurora",
        "type": "uncertain",
        "confidence": 0.4,
        "frequency": 5,
        "signals": ["regex"],
    }
    detected = {"people": [], "projects": [], "uncertain": [aurora]}
    decisions = {"Project Aurora": ("PROJECT", "user effort")}

    merged, moved, _ = _apply_classifications(detected, decisions)

    promoted = merged["projects"][0]
    assert promoted["name"] == "Project Aurora"
    assert promoted["type"] == "project"
    assert moved == 1
|
||||
|
||||
|
||||
# ── authoritative source filters ────────────────────────────────────────
|
||||
|
||||
|
||||
def test_is_authoritative_person_requires_git_signal():
    """A commit-count signal is authoritative; a pronoun signal is not."""
    with_commits = {"signals": ["5 commits across 2 repos"]}
    with_pronoun = {"signals": ["pronoun nearby (5x)"]}
    assert _is_authoritative_person(with_commits)
    assert not _is_authoritative_person(with_pronoun)
|
||||
|
||||
|
||||
def test_is_authoritative_project_requires_manifest_or_git_signal():
    """Manifest or commit signals are authoritative; a code-file reference is not."""
    accepted_signals = (
        "package.json, 12 of your commits",
        "57 commits (none by you)",
    )
    for signal in accepted_signals:
        assert _is_authoritative_project({"signals": [signal]})
    assert not _is_authoritative_project({"signals": ["code file reference (5x)"]})
|
||||
|
||||
|
||||
# ── refine_entities ─────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def _sample_detected():
|
||||
return {
|
||||
"people": [
|
||||
{
|
||||
"name": "Igor",
|
||||
"type": "person",
|
||||
"confidence": 0.99,
|
||||
"frequency": 100,
|
||||
"signals": ["git"],
|
||||
}
|
||||
],
|
||||
"projects": [
|
||||
{
|
||||
"name": "Foo",
|
||||
"type": "project",
|
||||
"confidence": 0.7,
|
||||
"frequency": 5,
|
||||
"signals": ["regex"],
|
||||
}
|
||||
],
|
||||
"uncertain": [
|
||||
{
|
||||
"name": "Never",
|
||||
"type": "uncertain",
|
||||
"confidence": 0.4,
|
||||
"frequency": 10,
|
||||
"signals": [],
|
||||
},
|
||||
{
|
||||
"name": "Alice",
|
||||
"type": "uncertain",
|
||||
"confidence": 0.4,
|
||||
"frequency": 5,
|
||||
"signals": [],
|
||||
},
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
def test_refine_entities_end_to_end_with_fake_provider():
    """One batch against the fake provider: Alice is promoted and Never is dropped."""
    reply = (
        '{"classifications": ['
        '{"name": "Foo", "label": "PROJECT", "reason": "real"},'
        '{"name": "Never", "label": "COMMON_WORD"},'
        '{"name": "Alice", "label": "PERSON", "reason": "name"}'
        "]}"
    )
    provider = FakeProvider(response_text=reply)

    result = refine_entities(
        _sample_detected(),
        corpus_text="Alice said hi. Foo was shipped. Never gonna.",
        provider=provider,
        show_progress=False,
    )

    assert result.batches_total == 1
    assert result.batches_completed == 1
    assert not result.cancelled

    # Alice -> people, Never -> dropped; Igor was never a candidate.
    people_names = [person["name"] for person in result.merged["people"]]
    assert "Alice" in people_names
    assert "Igor" in people_names
    uncertain_names = [item["name"] for item in result.merged["uncertain"]]
    assert "Never" not in uncertain_names
    assert result.dropped == 1
|
||||
|
||||
|
||||
def test_refine_entities_skips_high_confidence_projects():
    """A manifest-backed, high-confidence project never triggers an LLM call."""
    manifest_project = {
        "name": "manifest-backed",
        "type": "project",
        "confidence": 0.99,
        "frequency": 50,
        "signals": ["pyproject.toml"],
    }
    detected = {"people": [], "projects": [manifest_project], "uncertain": []}
    provider = FakeProvider(response_text='{"classifications": []}')

    refine_entities(detected, "", provider, show_progress=False)

    # The provider must not have been consulted at all.
    assert provider.call_count == 0
|
||||
|
||||
|
||||
def test_refine_entities_refines_high_confidence_regex_projects():
    """A high-confidence project with only a code-reference signal is still sent to the LLM."""
    regex_project = {
        "name": "OpenAPI",
        "type": "project",
        "confidence": 0.99,
        "frequency": 5,
        "signals": ["code file reference (5x)"],
    }
    detected = {"people": [], "projects": [regex_project], "uncertain": []}
    reply = '{"classifications": [{"name": "OpenAPI", "label": "TOPIC", "reason": "technology"}]}'
    provider = FakeProvider(response_text=reply)

    result = refine_entities(detected, "OpenAPI schemas", provider, show_progress=False)

    assert provider.call_count == 1
    assert result.reclassified == 1
    assert result.merged["projects"] == []
    assert result.merged["uncertain"][0]["name"] == "OpenAPI"
|
||||
|
||||
|
||||
def test_refine_entities_refines_regex_people_but_skips_git_people():
    """The person with a commit signal is kept; the pronoun-only one is dropped via the LLM."""
    commit_person = {
        "name": "Igor Lins e Silva",
        "type": "person",
        "confidence": 0.99,
        "frequency": 100,
        "signals": ["100 commits across 3 repos"],
    }
    pronoun_person = {
        "name": "Tool",
        "type": "person",
        "confidence": 0.99,
        "frequency": 5,
        "signals": ["pronoun nearby (5x)"],
    }
    detected = {
        "people": [commit_person, pronoun_person],
        "projects": [],
        "uncertain": [],
    }
    reply = '{"classifications": [{"name": "Tool", "label": "COMMON_WORD"}]}'
    provider = FakeProvider(response_text=reply)

    result = refine_entities(detected, "Tool is a common noun.", provider, show_progress=False)

    assert provider.call_count == 1
    remaining = [person["name"] for person in result.merged["people"]]
    assert remaining == ["Igor Lins e Silva"]
    assert result.dropped == 1
|
||||
|
||||
|
||||
def test_refine_entities_can_keep_llm_only_project_in_uncertain():
    """With promotions disabled, a PROJECT verdict is recorded as a signal only."""
    terraform = {
        "name": "Terraform",
        "type": "uncertain",
        "confidence": 0.4,
        "frequency": 9,
        "signals": ["regex"],
    }
    detected = {"people": [], "projects": [], "uncertain": [terraform]}
    reply = '{"classifications": [{"name": "Terraform", "label": "PROJECT"}]}'
    provider = FakeProvider(response_text=reply)

    result = refine_entities(
        detected,
        "Terraform config",
        provider,
        show_progress=False,
        allow_project_promotions=False,
    )

    assert result.merged["projects"] == []
    assert result.merged["uncertain"][0]["name"] == "Terraform"
    kept_signals = result.merged["uncertain"][0]["signals"]
    assert any("LLM: project" in s for s in kept_signals)
|
||||
|
||||
|
||||
def test_refine_entities_empty_candidates_returns_noop():
    """With nothing detected there are no batches and the merged result equals the input."""
    detected = {bucket: [] for bucket in ("people", "projects", "uncertain")}
    provider = FakeProvider()

    result = refine_entities(detected, "", provider, show_progress=False)

    assert result.batches_total == 0
    assert result.reclassified == 0
    assert result.merged == detected
|
||||
|
||||
|
||||
def test_refine_entities_handles_batch_error_gracefully():
    """A provider ``LLMError`` is collected in ``errors``; the run is not marked cancelled."""
    failing_provider = FakeProvider(should_raise=LLMError("transport broke"))

    result = refine_entities(
        _sample_detected(),
        corpus_text="",
        provider=failing_provider,
        show_progress=False,
    )

    assert result.errors
    first_error = result.errors[0]
    assert "transport broke" in first_error
    # No decision succeeded, so nothing was reclassified.
    assert result.reclassified == 0
    assert result.cancelled is False
|
||||
|
||||
|
||||
def test_refine_entities_ctrl_c_returns_partial():
    """A KeyboardInterrupt on the second batch yields ``cancelled=True`` with one batch done."""
    # 50 candidates at batch_size=25 make two batches.
    candidates = []
    for i in range(50):
        candidates.append(
            {
                "name": f"Cand{i}",
                "type": "uncertain",
                "confidence": 0.4,
                "frequency": 3,
                "signals": [],
            }
        )
    detected = {"people": [], "projects": [], "uncertain": candidates}
    provider = FakeProvider(
        response_text='{"classifications": []}',
        interrupt_on_call=2,  # the second classify() call raises
    )

    result = refine_entities(detected, "", provider, batch_size=25, show_progress=False)

    assert result.cancelled is True
    assert result.batches_completed == 1  # first batch finished; second interrupted
    assert result.batches_total == 2
|
||||
|
||||
|
||||
def test_refine_entities_malformed_response_recorded_as_error():
    """An unparseable reply shows up as a "could not parse" error."""
    provider = FakeProvider(response_text="not json")
    result = refine_entities(_sample_detected(), "", provider, show_progress=False)
    parse_failures = [err for err in result.errors if "could not parse" in err]
    assert parse_failures
|
||||
|
||||
|
||||
# ── collect_corpus_text ─────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_collect_corpus_text_reads_prose_files(tmp_path):
    """``.md`` and ``.txt`` content is collected; ``.py`` content is left out."""
    fixtures = {
        "a.md": "hello world",
        "b.txt": "more prose",
        "c.py": "import os",  # not prose, expected to be skipped
    }
    for filename, body in fixtures.items():
        (tmp_path / filename).write_text(body)

    collected = collect_corpus_text(str(tmp_path))

    assert "hello world" in collected
    assert "more prose" in collected
    assert "import os" not in collected
|
||||
|
||||
|
||||
def test_collect_corpus_text_prefers_recent(tmp_path):
    """With ``max_files=1`` the newer file's text is collected and the older file's is not."""
    import os

    old = tmp_path / "old.md"
    old.write_text("OLD_CONTENT")
    new = tmp_path / "new.md"
    new.write_text("NEW_CONTENT")
    # Back-date the old file by an hour. This fixes the mtime ordering on its
    # own, so no sleep between the two writes is needed.
    old_mtime = old.stat().st_mtime - 3600
    os.utime(old, (old_mtime, old_mtime))

    text = collect_corpus_text(str(tmp_path), max_files=1)
    assert "NEW_CONTENT" in text
    assert "OLD_CONTENT" not in text
|
||||
|
||||
|
||||
def test_collect_corpus_text_missing_dir_returns_empty(tmp_path):
    """A directory that does not exist yields an empty string."""
    missing_dir = tmp_path / "nope"
    collected = collect_corpus_text(str(missing_dir))
    assert collected == ""
|
||||
|
||||
|
||||
def test_collect_corpus_text_caps_bytes_per_file(tmp_path):
    """A 100,000-character file is cut down to roughly ``max_bytes_per_file``."""
    oversized = tmp_path / "big.md"
    oversized.write_text("x" * 100_000)

    collected = collect_corpus_text(str(tmp_path), max_files=1, max_bytes_per_file=500)

    # 500 bytes of content plus some slack for newlines.
    assert len(collected) <= 600
|
||||
@@ -66,6 +66,16 @@ def test_load_config_uses_defaults_when_yaml_missing():
|
||||
shutil.rmtree(tmpdir)
|
||||
|
||||
|
||||
def test_scan_project_skips_mempalace_generated_files():
    """``entities.json`` and ``mempalace.yaml`` are not reported by the scan; ``notes.md`` is."""
    fixtures = (
        ("entities.json", '{"people": [], "projects": []}'),
        ("mempalace.yaml", "wing: test\nrooms: []\n"),
        ("notes.md", "real user content\n" * 10),
    )
    with tempfile.TemporaryDirectory() as tmpdir:
        project_root = Path(tmpdir).resolve()
        for relative_name, content in fixtures:
            write_file(project_root / relative_name, content)

        found = scanned_files(project_root)
        assert found == ["notes.md"]
|
||||
|
||||
|
||||
def test_scan_project_respects_gitignore():
|
||||
tmpdir = tempfile.mkdtemp()
|
||||
try:
|
||||
|
||||
@@ -5,6 +5,7 @@ import os
|
||||
import shutil
|
||||
import subprocess
|
||||
from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
|
||||
@@ -480,6 +481,49 @@ def test_discover_entities_prefers_real_signal_over_prose(tmp_path):
|
||||
assert "realproj" in proj_names
|
||||
|
||||
|
||||
def test_discover_entities_keeps_uncertain_for_llm_when_real_signal(tmp_path):
    """The prose candidate "Noise" reaches the provider once and is absent from the result."""

    class FakeProvider:
        """Records every user prompt and always labels Noise a common word."""

        def __init__(self):
            self.prompts = []

        def classify(self, _system, user, json_mode=True):
            self.prompts.append(user)
            reply = '{"classifications": [{"name": "Noise", "label": "COMMON_WORD"}]}'
            return SimpleNamespace(text=reply)

    manifest = json.dumps({"name": "realproj"})
    (tmp_path / "package.json").write_text(manifest)
    _init_git_repo(tmp_path)
    (tmp_path / "doc.md").write_text("Noise appeared. Noise repeated. Noise again.")

    provider = FakeProvider()
    d = discover_entities(str(tmp_path), llm_provider=provider, show_progress=False)

    assert len(provider.prompts) == 1
    assert "Noise" in provider.prompts[0]
    all_names = [e["name"] for cat in d.values() for e in cat]
    assert "Noise" not in all_names
|
||||
|
||||
|
||||
def test_discover_entities_keeps_llm_only_project_uncertain_when_real_signal(tmp_path):
    """An LLM-only PROJECT verdict stays in ``uncertain`` next to the manifest project."""

    class FakeProvider:
        """Always answers that Terraform is a PROJECT."""

        def classify(self, _system, _user, json_mode=True):
            reply = '{"classifications": [{"name": "Terraform", "label": "PROJECT"}]}'
            return SimpleNamespace(text=reply)

    manifest = json.dumps({"name": "realproj"})
    (tmp_path / "package.json").write_text(manifest)
    _init_git_repo(tmp_path)
    (tmp_path / "doc.md").write_text("Terraform shipped. Terraform changed. Terraform runs.")

    d = discover_entities(str(tmp_path), llm_provider=FakeProvider(), show_progress=False)

    assert "realproj" in [e["name"] for e in d["projects"]]
    assert "Terraform" not in [e["name"] for e in d["projects"]]
    assert "Terraform" in [e["name"] for e in d["uncertain"]]
|
||||
|
||||
|
||||
# ── _UnionFind basics ──────────────────────────────────────────────────
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user