f895bc58e6
Python's \b is a \w/non-\w transition. Devanagari vowel signs (matras) like ा ी ु are Unicode category Mc (Mark, Spacing Combining) — not \w. This means \b splits mid-word on every matra: names like अनीता (Anita) truncate to अनीत, and person-verb patterns like \bराज\s+ने\s+कहा\b never match because \b fails after the final matra of कहा. Same issue affects Arabic, Hebrew, Thai, Tamil, and every other script whose words contain combining marks. Fix: locales with combining-mark scripts declare a boundary_chars field in their entity section (e.g. "\\w\\u0900-\\u097F" for Hindi). The i18n loader replaces every \b in that locale's patterns with a script-aware lookaround that treats the declared characters as "inside-word", and pre-wraps candidate/multi_word patterns with the same boundary. Default behavior (no boundary_chars) keeps standard \b — en, pt-br, ru, it are unchanged. Changes: - mempalace/i18n/__init__.py: add _script_boundary, _expand_b, _wrap_candidate, _collect_entity_section; candidate_patterns are now returned fully-wrapped (boundary + capture group applied) - mempalace/entity_detector.py: extract_candidates compiles pre-wrapped candidate patterns directly instead of re-wrapping with \b - tests/test_entity_detector.py: 5 new tests for Devanagari boundaries (name extraction with/without boundary_chars, person-verb firing, English regression)
591 lines
20 KiB
Python
591 lines
20 KiB
Python
#!/usr/bin/env python3
|
|
"""
|
|
entity_detector.py — Auto-detect people and projects from file content.
|
|
|
|
Two-pass approach:
|
|
Pass 1: scan files, extract entity candidates with signal counts
|
|
Pass 2: score and classify each candidate as person, project, or uncertain
|
|
|
|
Used by mempalace init before mining begins.
|
|
The confirmed entity map feeds the miner as the taxonomy.
|
|
|
|
Multi-language support:
|
|
All lexical patterns (person verbs, pronouns, dialogue markers, project
|
|
verbs, stopwords, and the candidate-extraction character class) live in
|
|
the ``entity`` section of ``mempalace/i18n/<lang>.json``. Every public
|
|
function accepts a ``languages`` tuple and applies the union of the
|
|
requested locales' patterns. The default is ``("en",)`` — existing
|
|
English-only callers behave exactly as before.
|
|
|
|
To add a new language: add an ``entity`` section to that locale's JSON.
|
|
No code changes required.
|
|
|
|
Usage:
|
|
from mempalace.entity_detector import detect_entities, confirm_entities
|
|
candidates = detect_entities(file_paths) # English only
|
|
candidates = detect_entities(paths, languages=("en", "pt-br"))
|
|
confirmed = confirm_entities(candidates) # interactive review
|
|
"""
|
|
|
|
import re
|
|
import os
|
|
import functools
|
|
from pathlib import Path
|
|
from collections import defaultdict
|
|
|
|
from mempalace.i18n import get_entity_patterns
|
|
|
|
|
|
# ==================== LANGUAGE-AWARE PATTERN LOADING ====================
|
|
|
|
|
|
def _normalize_langs(languages) -> tuple:
|
|
"""Coerce a language input into a non-empty hashable tuple."""
|
|
if not languages:
|
|
return ("en",)
|
|
if isinstance(languages, str):
|
|
return (languages,)
|
|
return tuple(languages)
|
|
|
|
|
|
@functools.lru_cache(maxsize=32)
|
|
def _get_stopwords(languages: tuple) -> frozenset:
|
|
"""Return the union of stopwords across the given languages."""
|
|
patterns = get_entity_patterns(languages)
|
|
return frozenset(patterns["stopwords"])
|
|
|
|
|
|
# ==================== BACKWARD-COMPAT MODULE CONSTANTS ====================
|
|
#
|
|
# These mirror the old module-level constants so existing imports keep working.
|
|
# They reflect the English defaults and are populated at import time from
|
|
# ``mempalace/i18n/en.json``. Callers that need multi-language behavior should
|
|
# pass the ``languages`` parameter to the public functions below.
|
|
|
|
_EN = get_entity_patterns(("en",))
|
|
|
|
PERSON_VERB_PATTERNS = list(_EN["person_verb_patterns"])
|
|
PRONOUN_PATTERNS = list(_EN["pronoun_patterns"])
|
|
PRONOUN_RE = re.compile("|".join(PRONOUN_PATTERNS), re.IGNORECASE) if PRONOUN_PATTERNS else None
|
|
DIALOGUE_PATTERNS = list(_EN["dialogue_patterns"])
|
|
PROJECT_VERB_PATTERNS = list(_EN["project_verb_patterns"])
|
|
STOPWORDS = set(_EN["stopwords"])
|
|
|
|
|
|
# ==================== EXTENSION POINTS (not language-scoped) ====================
|
|
|
|
# For entity detection — prose only, no code files
|
|
# Code files have too many capitalized names (classes, functions) that aren't entities
|
|
PROSE_EXTENSIONS = {
|
|
".txt",
|
|
".md",
|
|
".rst",
|
|
".csv",
|
|
}
|
|
|
|
READABLE_EXTENSIONS = {
|
|
".txt",
|
|
".md",
|
|
".py",
|
|
".js",
|
|
".ts",
|
|
".json",
|
|
".yaml",
|
|
".yml",
|
|
".csv",
|
|
".rst",
|
|
".toml",
|
|
".sh",
|
|
".rb",
|
|
".go",
|
|
".rs",
|
|
}
|
|
|
|
SKIP_DIRS = {
|
|
".git",
|
|
"node_modules",
|
|
"__pycache__",
|
|
".venv",
|
|
"venv",
|
|
"env",
|
|
"dist",
|
|
"build",
|
|
".next",
|
|
"coverage",
|
|
".mempalace",
|
|
}
|
|
|
|
|
|
# ==================== CANDIDATE EXTRACTION ====================
|
|
|
|
|
|
def extract_candidates(text: str, languages=("en",)) -> dict:
|
|
"""
|
|
Extract all capitalized proper noun candidates from text.
|
|
Returns {name: frequency} for names appearing 3+ times.
|
|
|
|
Each language contributes its own character-class pattern (e.g. ASCII
|
|
for English, Latin+diacritics for pt-br, Cyrillic for Russian,
|
|
Devanagari for Hindi). Matches from all languages are unioned.
|
|
"""
|
|
langs = _normalize_langs(languages)
|
|
patterns = get_entity_patterns(langs)
|
|
stopwords = _get_stopwords(langs)
|
|
|
|
counts: defaultdict = defaultdict(int)
|
|
|
|
# Single-word candidates — one pre-wrapped pattern per language
|
|
for wrapped_pat in patterns["candidate_patterns"]:
|
|
try:
|
|
rx = re.compile(wrapped_pat)
|
|
except re.error:
|
|
continue
|
|
for word in rx.findall(text):
|
|
if word.lower() in stopwords:
|
|
continue
|
|
if len(word) < 2:
|
|
continue
|
|
counts[word] += 1
|
|
|
|
# Multi-word candidates — one pre-wrapped pattern per language
|
|
for wrapped_pat in patterns["multi_word_patterns"]:
|
|
try:
|
|
rx = re.compile(wrapped_pat)
|
|
except re.error:
|
|
continue
|
|
for phrase in rx.findall(text):
|
|
if any(w.lower() in stopwords for w in phrase.split()):
|
|
continue
|
|
counts[phrase] += 1
|
|
|
|
return {name: count for name, count in counts.items() if count >= 3}
|
|
|
|
|
|
# ==================== SIGNAL SCORING ====================
|
|
|
|
|
|
@functools.lru_cache(maxsize=256)
|
|
def _build_patterns(name: str, languages: tuple = ("en",)) -> dict:
|
|
"""Pre-compile all regex patterns for a single entity name, per language set."""
|
|
n = re.escape(name)
|
|
langs = _normalize_langs(languages)
|
|
sources = get_entity_patterns(langs)
|
|
|
|
def _compile_each(raw_patterns, flags=re.IGNORECASE):
|
|
compiled = []
|
|
for p in raw_patterns:
|
|
try:
|
|
compiled.append(re.compile(p.format(name=n), flags))
|
|
except (re.error, KeyError, IndexError):
|
|
continue
|
|
return compiled
|
|
|
|
direct_sources = sources.get("direct_address_patterns") or []
|
|
direct_compiled = []
|
|
for raw in direct_sources:
|
|
try:
|
|
direct_compiled.append(re.compile(raw.format(name=n), re.IGNORECASE))
|
|
except (re.error, KeyError, IndexError):
|
|
continue
|
|
|
|
return {
|
|
"dialogue": _compile_each(sources["dialogue_patterns"], re.MULTILINE | re.IGNORECASE),
|
|
"person_verbs": _compile_each(sources["person_verb_patterns"]),
|
|
"project_verbs": _compile_each(sources["project_verb_patterns"]),
|
|
"direct": direct_compiled,
|
|
"versioned": re.compile(rf"\b{n}[-v]\w+", re.IGNORECASE),
|
|
"code_ref": re.compile(rf"\b{n}\.(py|js|ts|yaml|yml|json|sh)\b", re.IGNORECASE),
|
|
}
|
|
|
|
|
|
@functools.lru_cache(maxsize=32)
|
|
def _pronoun_re(languages: tuple):
|
|
"""Compile a combined pronoun regex for the given languages."""
|
|
langs = _normalize_langs(languages)
|
|
patterns = get_entity_patterns(langs)
|
|
pronouns = patterns.get("pronoun_patterns") or []
|
|
if not pronouns:
|
|
return None
|
|
try:
|
|
return re.compile("|".join(pronouns), re.IGNORECASE)
|
|
except re.error:
|
|
return None
|
|
|
|
|
|
def score_entity(name: str, text: str, lines: list, languages=("en",)) -> dict:
|
|
"""
|
|
Score a candidate entity as person vs project.
|
|
Returns scores and the signals that fired.
|
|
"""
|
|
langs = _normalize_langs(languages)
|
|
patterns = _build_patterns(name, langs)
|
|
pronoun_re = _pronoun_re(langs)
|
|
person_score = 0
|
|
project_score = 0
|
|
person_signals = []
|
|
project_signals = []
|
|
|
|
# --- Person signals ---
|
|
|
|
# Dialogue markers (strong signal)
|
|
for rx in patterns["dialogue"]:
|
|
matches = len(rx.findall(text))
|
|
if matches > 0:
|
|
person_score += matches * 3
|
|
person_signals.append(f"dialogue marker ({matches}x)")
|
|
|
|
# Person verbs
|
|
for rx in patterns["person_verbs"]:
|
|
matches = len(rx.findall(text))
|
|
if matches > 0:
|
|
person_score += matches * 2
|
|
person_signals.append(f"'{name} ...' action ({matches}x)")
|
|
|
|
# Pronoun proximity — pronouns within 3 lines of the name
|
|
if pronoun_re is not None:
|
|
name_lower = name.lower()
|
|
name_line_indices = [i for i, line in enumerate(lines) if name_lower in line.lower()]
|
|
pronoun_hits = 0
|
|
for idx in name_line_indices:
|
|
window_text = " ".join(lines[max(0, idx - 2) : idx + 3])
|
|
if pronoun_re.search(window_text):
|
|
pronoun_hits += 1
|
|
if pronoun_hits > 0:
|
|
person_score += pronoun_hits * 2
|
|
person_signals.append(f"pronoun nearby ({pronoun_hits}x)")
|
|
|
|
# Direct address
|
|
direct_hits = 0
|
|
for rx in patterns["direct"]:
|
|
direct_hits += len(rx.findall(text))
|
|
if direct_hits > 0:
|
|
person_score += direct_hits * 4
|
|
person_signals.append(f"addressed directly ({direct_hits}x)")
|
|
|
|
# --- Project signals ---
|
|
|
|
for rx in patterns["project_verbs"]:
|
|
matches = len(rx.findall(text))
|
|
if matches > 0:
|
|
project_score += matches * 2
|
|
project_signals.append(f"project verb ({matches}x)")
|
|
|
|
versioned = len(patterns["versioned"].findall(text))
|
|
if versioned > 0:
|
|
project_score += versioned * 3
|
|
project_signals.append(f"versioned/hyphenated ({versioned}x)")
|
|
|
|
code_ref = len(patterns["code_ref"].findall(text))
|
|
if code_ref > 0:
|
|
project_score += code_ref * 3
|
|
project_signals.append(f"code file reference ({code_ref}x)")
|
|
|
|
return {
|
|
"person_score": person_score,
|
|
"project_score": project_score,
|
|
"person_signals": person_signals[:3],
|
|
"project_signals": project_signals[:3],
|
|
}
|
|
|
|
|
|
# ==================== CLASSIFY ====================
|
|
|
|
|
|
def classify_entity(name: str, frequency: int, scores: dict) -> dict:
|
|
"""
|
|
Given scores, classify as person / project / uncertain.
|
|
Returns entity dict with confidence.
|
|
"""
|
|
ps = scores["person_score"]
|
|
prs = scores["project_score"]
|
|
total = ps + prs
|
|
|
|
if total == 0:
|
|
# No strong signals — frequency-only candidate, uncertain
|
|
confidence = min(0.4, frequency / 50)
|
|
return {
|
|
"name": name,
|
|
"type": "uncertain",
|
|
"confidence": round(confidence, 2),
|
|
"frequency": frequency,
|
|
"signals": [f"appears {frequency}x, no strong type signals"],
|
|
}
|
|
|
|
person_ratio = ps / total if total > 0 else 0
|
|
|
|
# Require TWO different signal categories to confidently classify as a person.
|
|
# One signal type with many hits (e.g. "Click, click, click...") is not enough —
|
|
# it just means that word appears often in a particular syntactic position.
|
|
signal_categories = set()
|
|
for s in scores["person_signals"]:
|
|
if "dialogue" in s:
|
|
signal_categories.add("dialogue")
|
|
elif "action" in s:
|
|
signal_categories.add("action")
|
|
elif "pronoun" in s:
|
|
signal_categories.add("pronoun")
|
|
elif "addressed" in s:
|
|
signal_categories.add("addressed")
|
|
|
|
has_two_signal_types = len(signal_categories) >= 2
|
|
_ = signal_categories - {"pronoun"} # reserved for future thresholds
|
|
|
|
if person_ratio >= 0.7 and has_two_signal_types and ps >= 5:
|
|
entity_type = "person"
|
|
confidence = min(0.99, 0.5 + person_ratio * 0.5)
|
|
signals = scores["person_signals"] or [f"appears {frequency}x"]
|
|
elif person_ratio >= 0.7 and (not has_two_signal_types or ps < 5):
|
|
# Pronoun-only match — downgrade to uncertain
|
|
entity_type = "uncertain"
|
|
confidence = 0.4
|
|
signals = scores["person_signals"] + [f"appears {frequency}x — pronoun-only match"]
|
|
elif person_ratio <= 0.3:
|
|
entity_type = "project"
|
|
confidence = min(0.99, 0.5 + (1 - person_ratio) * 0.5)
|
|
signals = scores["project_signals"] or [f"appears {frequency}x"]
|
|
else:
|
|
entity_type = "uncertain"
|
|
confidence = 0.5
|
|
signals = (scores["person_signals"] + scores["project_signals"])[:3]
|
|
signals.append("mixed signals — needs review")
|
|
|
|
return {
|
|
"name": name,
|
|
"type": entity_type,
|
|
"confidence": round(confidence, 2),
|
|
"frequency": frequency,
|
|
"signals": signals,
|
|
}
|
|
|
|
|
|
# ==================== MAIN DETECT ====================
|
|
|
|
|
|
def detect_entities(file_paths: list, max_files: int = 10, languages=("en",)) -> dict:
|
|
"""
|
|
Scan files and detect entity candidates.
|
|
|
|
Args:
|
|
file_paths: List of Path objects to scan
|
|
max_files: Max files to read (for speed)
|
|
languages: Tuple of language codes whose entity patterns should be
|
|
applied (union). Defaults to ``("en",)``.
|
|
|
|
Returns:
|
|
{
|
|
"people": [...entity dicts...],
|
|
"projects": [...entity dicts...],
|
|
"uncertain":[...entity dicts...],
|
|
}
|
|
"""
|
|
langs = _normalize_langs(languages)
|
|
|
|
# Collect text from files
|
|
all_text = []
|
|
all_lines = []
|
|
files_read = 0
|
|
|
|
MAX_BYTES_PER_FILE = 5_000 # first 5KB per file — enough to catch recurring entities
|
|
|
|
for filepath in file_paths:
|
|
if files_read >= max_files:
|
|
break
|
|
try:
|
|
with open(filepath, encoding="utf-8", errors="replace") as f:
|
|
content = f.read(MAX_BYTES_PER_FILE)
|
|
all_text.append(content)
|
|
all_lines.extend(content.splitlines())
|
|
files_read += 1
|
|
except OSError:
|
|
continue
|
|
|
|
combined_text = "\n".join(all_text)
|
|
|
|
# Extract candidates
|
|
candidates = extract_candidates(combined_text, languages=langs)
|
|
|
|
if not candidates:
|
|
return {"people": [], "projects": [], "uncertain": []}
|
|
|
|
# Score and classify each candidate
|
|
people = []
|
|
projects = []
|
|
uncertain = []
|
|
|
|
for name, frequency in sorted(candidates.items(), key=lambda x: x[1], reverse=True):
|
|
scores = score_entity(name, combined_text, all_lines, languages=langs)
|
|
entity = classify_entity(name, frequency, scores)
|
|
|
|
if entity["type"] == "person":
|
|
people.append(entity)
|
|
elif entity["type"] == "project":
|
|
projects.append(entity)
|
|
else:
|
|
uncertain.append(entity)
|
|
|
|
# Sort by confidence descending
|
|
people.sort(key=lambda x: x["confidence"], reverse=True)
|
|
projects.sort(key=lambda x: x["confidence"], reverse=True)
|
|
uncertain.sort(key=lambda x: x["frequency"], reverse=True)
|
|
|
|
# Cap results to most relevant
|
|
return {
|
|
"people": people[:15],
|
|
"projects": projects[:10],
|
|
"uncertain": uncertain[:8],
|
|
}
|
|
|
|
|
|
# ==================== INTERACTIVE CONFIRM ====================
|
|
|
|
|
|
def _print_entity_list(entities: list, label: str):
|
|
print(f"\n {label}:")
|
|
if not entities:
|
|
print(" (none detected)")
|
|
return
|
|
for i, e in enumerate(entities):
|
|
confidence_bar = "●" * int(e["confidence"] * 5) + "○" * (5 - int(e["confidence"] * 5))
|
|
signals_str = ", ".join(e["signals"][:2]) if e["signals"] else ""
|
|
print(f" {i + 1:2}. {e['name']:20} [{confidence_bar}] {signals_str}")
|
|
|
|
|
|
def confirm_entities(detected: dict, yes: bool = False) -> dict:
|
|
"""
|
|
Interactive confirmation step.
|
|
User reviews detected entities, removes wrong ones, adds missing ones.
|
|
Returns confirmed {people: [names], projects: [names]}
|
|
|
|
Pass yes=True to auto-accept all detected entities without prompting.
|
|
"""
|
|
print(f"\n{'=' * 58}")
|
|
print(" MemPalace — Entity Detection")
|
|
print(f"{'=' * 58}")
|
|
print("\n Scanned your files. Here's what we found:\n")
|
|
|
|
_print_entity_list(detected["people"], "PEOPLE")
|
|
_print_entity_list(detected["projects"], "PROJECTS")
|
|
|
|
if detected["uncertain"]:
|
|
_print_entity_list(detected["uncertain"], "UNCERTAIN (need your call)")
|
|
|
|
confirmed_people = [e["name"] for e in detected["people"]]
|
|
confirmed_projects = [e["name"] for e in detected["projects"]]
|
|
|
|
if yes:
|
|
# Auto-accept: include all detected (skip uncertain — ambiguous without user input)
|
|
print(
|
|
f"\n Auto-accepting {len(confirmed_people)} people, {len(confirmed_projects)} projects."
|
|
)
|
|
return {"people": confirmed_people, "projects": confirmed_projects}
|
|
|
|
print(f"\n{'─' * 58}")
|
|
print(" Options:")
|
|
print(" [enter] Accept all")
|
|
print(" [edit] Remove wrong entries or reclassify uncertain")
|
|
print(" [add] Add missing people or projects")
|
|
print()
|
|
|
|
choice = input(" Your choice [enter/edit/add]: ").strip().lower()
|
|
|
|
confirmed_people = [e["name"] for e in detected["people"]]
|
|
confirmed_projects = [e["name"] for e in detected["projects"]]
|
|
|
|
if choice == "edit":
|
|
# Handle uncertain first
|
|
if detected["uncertain"]:
|
|
print("\n Uncertain entities — classify each:")
|
|
for e in detected["uncertain"]:
|
|
ans = input(f" {e['name']} — (p)erson, (r)project, or (s)kip? ").strip().lower()
|
|
if ans == "p":
|
|
confirmed_people.append(e["name"])
|
|
elif ans == "r":
|
|
confirmed_projects.append(e["name"])
|
|
|
|
# Remove wrong people
|
|
print(f"\n Current people: {', '.join(confirmed_people) or '(none)'}")
|
|
remove = input(
|
|
" Numbers to REMOVE from people (comma-separated, or enter to skip): "
|
|
).strip()
|
|
if remove:
|
|
to_remove = {int(x.strip()) - 1 for x in remove.split(",") if x.strip().isdigit()}
|
|
confirmed_people = [p for i, p in enumerate(confirmed_people) if i not in to_remove]
|
|
|
|
# Remove wrong projects
|
|
print(f"\n Current projects: {', '.join(confirmed_projects) or '(none)'}")
|
|
remove = input(
|
|
" Numbers to REMOVE from projects (comma-separated, or enter to skip): "
|
|
).strip()
|
|
if remove:
|
|
to_remove = {int(x.strip()) - 1 for x in remove.split(",") if x.strip().isdigit()}
|
|
confirmed_projects = [p for i, p in enumerate(confirmed_projects) if i not in to_remove]
|
|
|
|
if choice == "add" or input("\n Add any missing? [y/N]: ").strip().lower() == "y":
|
|
while True:
|
|
name = input(" Name (or enter to stop): ").strip()
|
|
if not name:
|
|
break
|
|
kind = input(f" Is '{name}' a (p)erson or p(r)oject? ").strip().lower()
|
|
if kind == "p":
|
|
confirmed_people.append(name)
|
|
elif kind == "r":
|
|
confirmed_projects.append(name)
|
|
|
|
print(f"\n{'=' * 58}")
|
|
print(" Confirmed:")
|
|
print(f" People: {', '.join(confirmed_people) or '(none)'}")
|
|
print(f" Projects: {', '.join(confirmed_projects) or '(none)'}")
|
|
print(f"{'=' * 58}\n")
|
|
|
|
return {
|
|
"people": confirmed_people,
|
|
"projects": confirmed_projects,
|
|
}
|
|
|
|
|
|
# ==================== SCAN HELPER ====================
|
|
|
|
|
|
def scan_for_detection(project_dir: str, max_files: int = 10) -> list:
|
|
"""
|
|
Collect prose file paths for entity detection.
|
|
Prose only (.txt, .md, .rst, .csv) — code files produce too many false positives.
|
|
Falls back to all readable files if no prose found.
|
|
"""
|
|
project_path = Path(project_dir).expanduser().resolve()
|
|
prose_files = []
|
|
all_files = []
|
|
|
|
for root, dirs, filenames in os.walk(project_path):
|
|
dirs[:] = [d for d in dirs if d not in SKIP_DIRS]
|
|
for filename in filenames:
|
|
filepath = Path(root) / filename
|
|
ext = filepath.suffix.lower()
|
|
if ext in PROSE_EXTENSIONS:
|
|
prose_files.append(filepath)
|
|
elif ext in READABLE_EXTENSIONS:
|
|
all_files.append(filepath)
|
|
|
|
# Prefer prose files — fall back to all readable if too few prose files
|
|
files = prose_files if len(prose_files) >= 3 else prose_files + all_files
|
|
return files[:max_files]
|
|
|
|
|
|
# ==================== CLI ====================
|
|
|
|
if __name__ == "__main__":
|
|
import sys
|
|
|
|
if len(sys.argv) < 2:
|
|
print("Usage: python entity_detector.py <directory> [lang1,lang2,...]")
|
|
sys.exit(1)
|
|
|
|
project_dir = sys.argv[1]
|
|
langs = tuple(sys.argv[2].split(",")) if len(sys.argv) >= 3 else ("en",)
|
|
print(f"Scanning: {project_dir} (languages: {', '.join(langs)})")
|
|
files = scan_for_detection(project_dir)
|
|
print(f"Reading {len(files)} files...")
|
|
detected = detect_entities(files, languages=langs)
|
|
confirmed = confirm_entities(detected)
|
|
print("Confirmed entities:", confirmed)
|