264 lines
8.3 KiB
Python
264 lines
8.3 KiB
Python
|
|
"""Tests for mempalace.entity_detector."""
|
||
|
|
|
||
|
|
import os
|
||
|
|
|
||
|
|
from mempalace.entity_detector import (
|
||
|
|
PROSE_EXTENSIONS,
|
||
|
|
STOPWORDS,
|
||
|
|
classify_entity,
|
||
|
|
detect_entities,
|
||
|
|
extract_candidates,
|
||
|
|
scan_for_detection,
|
||
|
|
score_entity,
|
||
|
|
)
|
||
|
|
|
||
|
|
|
||
|
|
# ── extract_candidates ──────────────────────────────────────────────────
|
||
|
|
|
||
|
|
|
||
|
|
def test_extract_candidates_finds_frequent_names():
    """A capitalised name used four times is reported with a count of at least 3."""
    candidates = extract_candidates(
        "Riley said hello. Riley laughed. Riley smiled. Riley waved."
    )
    assert "Riley" in candidates
    assert candidates["Riley"] >= 3
|
||
|
|
|
||
|
|
|
||
|
|
def test_extract_candidates_ignores_stopwords():
    """A stopword is never returned as a candidate, however often it repeats."""
    # "The" is repeated six times, but it is a stopword.
    repeated_stopword = "The The The The The The"
    candidates = extract_candidates(repeated_stopword)
    assert "The" not in candidates
|
||
|
|
|
||
|
|
|
||
|
|
def test_extract_candidates_requires_min_frequency():
    """Names that are mentioned only once are not reported as candidates."""
    candidates = extract_candidates("Riley said hi. Devon waved.")
    # Both names occur a single time, which is under the threshold of 3.
    for name in ("Riley", "Devon"):
        assert name not in candidates
|
||
|
|
|
||
|
|
|
||
|
|
def test_extract_candidates_finds_multi_word_names():
    """A two-word capitalised name repeated four times is reported as one candidate."""
    # Multi-word names need 3+ occurrences and must not contain stopwords.
    sentences = (
        "Claude Code is great. Claude Code rocks. "
        "Claude Code works. Claude Code rules."
    )
    assert "Claude Code" in extract_candidates(sentences)
|
||
|
|
|
||
|
|
|
||
|
|
def test_extract_candidates_empty_text():
    """Empty input produces a result that compares equal to an empty dict."""
    assert extract_candidates("") == {}
|
||
|
|
|
||
|
|
|
||
|
|
# ── score_entity ────────────────────────────────────────────────────────
|
||
|
|
|
||
|
|
|
||
|
|
def test_score_entity_person_verbs():
    """Person-style verbs (said/asked/told) after a name give a positive person score."""
    passage = "Riley said hello. Riley asked why. Riley told me."
    scored = score_entity("Riley", passage, passage.splitlines())
    assert scored["person_score"] > 0
    assert len(scored["person_signals"]) > 0
|
||
|
|
|
||
|
|
|
||
|
|
def test_score_entity_project_verbs():
    """Build/deploy/install phrasing around a name gives a positive project score."""
    passage = "We are building ChromaDB. We deployed ChromaDB. Install ChromaDB."
    scored = score_entity("ChromaDB", passage, passage.splitlines())
    assert scored["project_score"] > 0
    assert len(scored["project_signals"]) > 0
|
||
|
|
|
||
|
|
|
||
|
|
def test_score_entity_dialogue_markers():
    """Lines shaped like ``Name: ...`` give the name a positive person score."""
    transcript = "Riley: Hey, how are you?\nRiley: I'm fine."
    scored = score_entity("Riley", transcript, transcript.splitlines())
    assert scored["person_score"] > 0
|
||
|
|
|
||
|
|
|
||
|
|
def test_score_entity_code_ref():
    """File-style mentions such as ``Name.py`` give the name a positive project score."""
    passage = "Check out ChromaDB.py for details. Also ChromaDB.js is good."
    scored = score_entity("ChromaDB", passage, passage.splitlines())
    assert scored["project_score"] > 0
|
||
|
|
|
||
|
|
|
||
|
|
def test_score_entity_no_signals():
    """A name that never appears in the text scores zero on both axes."""
    passage = "Nothing interesting here at all."
    scored = score_entity("Riley", passage, passage.splitlines())
    assert scored["person_score"] == 0
    assert scored["project_score"] == 0
|
||
|
|
|
||
|
|
|
||
|
|
# ── classify_entity ─────────────────────────────────────────────────────
|
||
|
|
|
||
|
|
|
||
|
|
def test_classify_entity_no_signals_gives_uncertain():
    """Zero scores with no signals classify as "uncertain" and echo the name back."""
    verdict = classify_entity(
        "Foo",
        10,
        {
            "person_score": 0,
            "project_score": 0,
            "person_signals": [],
            "project_signals": [],
        },
    )
    assert verdict["type"] == "uncertain"
    assert verdict["name"] == "Foo"
|
||
|
|
|
||
|
|
|
||
|
|
def test_classify_entity_strong_project():
    """A project score of 10 against a person score of 0 classifies as "project"."""
    project_evidence = ["project verb (5x)", "code file reference (2x)"]
    verdict = classify_entity(
        "ChromaDB",
        5,
        {
            "person_score": 0,
            "project_score": 10,
            "person_signals": [],
            "project_signals": project_evidence,
        },
    )
    assert verdict["type"] == "project"
|
||
|
|
|
||
|
|
|
||
|
|
def test_classify_entity_strong_person_needs_two_signal_types():
    """A high person score backed by two different signal kinds classifies as "person"."""
    # Two distinct kinds of person evidence: a dialogue marker and an action.
    person_evidence = [
        "dialogue marker (3x)",
        "'Riley ...' action (4x)",
    ]
    verdict = classify_entity(
        "Riley",
        8,
        {
            "person_score": 10,
            "project_score": 0,
            "person_signals": person_evidence,
            "project_signals": [],
        },
    )
    assert verdict["type"] == "person"
|
||
|
|
|
||
|
|
|
||
|
|
def test_classify_entity_pronoun_only_is_uncertain():
    """A person score supported only by a nearby-pronoun signal stays "uncertain"."""
    verdict = classify_entity(
        "Riley",
        5,
        {
            "person_score": 8,
            "project_score": 0,
            "person_signals": ["pronoun nearby (4x)"],
            "project_signals": [],
        },
    )
    assert verdict["type"] == "uncertain"
|
||
|
|
|
||
|
|
|
||
|
|
def test_classify_entity_mixed_signals():
    """Equal person and project scores give "uncertain" with a mixed-signals note last."""
    verdict = classify_entity(
        "Lantern",
        5,
        {
            "person_score": 5,
            "project_score": 5,
            "person_signals": ["pronoun nearby (2x)"],
            "project_signals": ["project verb (2x)"],
        },
    )
    assert verdict["type"] == "uncertain"
    last_signal = verdict["signals"][-1]
    assert "mixed signals" in last_signal
|
||
|
|
|
||
|
|
|
||
|
|
# ── detect_entities (integration) ───────────────────────────────────────
|
||
|
|
|
||
|
|
|
||
|
|
def test_detect_entities_with_person_file(tmp_path):
    """A file full of person-style mentions of "Riley" surfaces that name in some category."""
    lines = [
        "Riley said hello today.",
        "Riley asked about the project.",
        "Riley told me she was happy.",
        "Riley: I think we should go.",
        "Hey Riley, thanks for the help.",
        "Riley laughed and smiled.",
        "Riley decided to join.",
        "Riley pushed the change.",
    ]
    notes = tmp_path / "notes.txt"
    notes.write_text("\n".join(lines))

    detected = detect_entities([notes])

    # Flatten the names from every category; the test does not care which one matched.
    names = []
    for group in detected.values():
        names.extend([entity["name"] for entity in group])
    assert "Riley" in names
|
||
|
|
|
||
|
|
|
||
|
|
def test_detect_entities_with_project_file(tmp_path):
    """A file full of project-style mentions of "Lantern" surfaces that name in some category."""
    # "ChromaDB" mixes upper and lower case, but extract_candidates looks
    # for /[A-Z][a-z]{1,19}/, so the name used here has to match that regex.
    # "Lantern" fits the capitalized-word pattern.
    lines = [
        "The Lantern project is great.",
        "Building Lantern was fun.",
        "We deployed Lantern today.",
        "Install Lantern with pip install Lantern.",
        "Check Lantern.py for the source.",
        "Lantern v2 is faster.",
    ]
    readme = tmp_path / "readme.txt"
    readme.write_text("\n".join(lines))

    detected = detect_entities([readme])

    # Flatten the names from every category; the test does not care which one matched.
    names = []
    for group in detected.values():
        names.extend([entity["name"] for entity in group])
    assert "Lantern" in names
|
||
|
|
|
||
|
|
|
||
|
|
def test_detect_entities_empty_files(tmp_path):
    """An empty input file yields three empty categories."""
    blank = tmp_path / "empty.txt"
    blank.write_text("")
    nothing_found = {"people": [], "projects": [], "uncertain": []}
    assert detect_entities([blank]) == nothing_found
|
||
|
|
|
||
|
|
|
||
|
|
def test_detect_entities_handles_missing_file(tmp_path):
    """A path that does not exist yields three empty categories rather than an error."""
    # The file is deliberately never created.
    absent = tmp_path / "nonexistent.txt"
    nothing_found = {"people": [], "projects": [], "uncertain": []}
    assert detect_entities([absent]) == nothing_found
|
||
|
|
|
||
|
|
|
||
|
|
def test_detect_entities_respects_max_files(tmp_path):
    """Smoke test: five input files with ``max_files=2`` still return a dict.

    This only checks that the call completes and returns a dict; it does not
    verify how many files were actually read.
    """
    paths = [tmp_path / f"file{i}.txt" for i in range(5)]
    for path in paths:
        path.write_text("Riley said hello. " * 10)

    # max_files=2 should only read 2 files
    detected = detect_entities(paths, max_files=2)

    # Should still work without error
    assert isinstance(detected, dict)
|
||
|
|
|
||
|
|
|
||
|
|
# ── scan_for_detection ──────────────────────────────────────────────────
|
||
|
|
|
||
|
|
|
||
|
|
def test_scan_for_detection_finds_prose(tmp_path):
    """Scanning a directory with .md, .txt and .py files returns at least one prose file."""
    fixtures = (
        ("notes.md", "hello"),
        ("data.txt", "world"),
        ("code.py", "import os"),
    )
    for filename, body in fixtures:
        (tmp_path / filename).write_text(body)

    found = scan_for_detection(str(tmp_path))
    suffixes = {os.path.splitext(str(path))[1] for path in found}

    # Prose files should be found
    assert ".md" in suffixes or ".txt" in suffixes
|
||
|
|
|
||
|
|
|
||
|
|
def test_scan_for_detection_skips_git_dir(tmp_path):
    """Files inside a ``.git`` directory are not returned by the scan.

    The check compares whole path components, taken relative to the scanned
    root, against ``".git"``. A plain substring test on the full path string
    would also trip on unrelated names such as ``.github`` or ``.gitignore``,
    including ones that appear in the temporary directory's own path.
    """
    git_dir = tmp_path / ".git"
    git_dir.mkdir()
    (git_dir / "config.txt").write_text("git config")
    (tmp_path / "readme.md").write_text("hello")

    # Resolve both sides so a symlinked temp directory cannot skew the
    # relative path that is inspected below.
    root = os.path.realpath(str(tmp_path))
    for found in scan_for_detection(str(tmp_path)):
        relative = os.path.relpath(os.path.realpath(str(found)), root)
        assert ".git" not in relative.split(os.sep), (
            f"scan returned a file from inside .git: {found}"
        )
|
||
|
|
|
||
|
|
|
||
|
|
# ── module-level constants ──────────────────────────────────────────────
|
||
|
|
|
||
|
|
|
||
|
|
def test_stopwords_contains_common_words():
    """STOPWORDS includes an English article and two code keywords."""
    for word in ("the", "import", "class"):
        assert word in STOPWORDS
|
||
|
|
|
||
|
|
|
||
|
|
def test_prose_extensions():
    """PROSE_EXTENSIONS includes the plain-text and Markdown suffixes."""
    for suffix in (".txt", ".md"):
        assert suffix in PROSE_EXTENSIONS
|