diff --git a/mempalace/entity_detector.py b/mempalace/entity_detector.py index a20e2af..754c65d 100644 --- a/mempalace/entity_detector.py +++ b/mempalace/entity_detector.py @@ -134,10 +134,10 @@ def extract_candidates(text: str, languages=("en",)) -> dict: counts: defaultdict = defaultdict(int) - # Single-word candidates — one pattern per language - for raw_pat in patterns["candidate_patterns"]: + # Single-word candidates — one pre-wrapped pattern per language + for wrapped_pat in patterns["candidate_patterns"]: try: - rx = re.compile(rf"\b({raw_pat})\b") + rx = re.compile(wrapped_pat) except re.error: continue for word in rx.findall(text): @@ -147,10 +147,10 @@ def extract_candidates(text: str, languages=("en",)) -> dict: continue counts[word] += 1 - # Multi-word candidates — one pattern per language - for raw_pat in patterns["multi_word_patterns"]: + # Multi-word candidates — one pre-wrapped pattern per language + for wrapped_pat in patterns["multi_word_patterns"]: try: - rx = re.compile(rf"\b({raw_pat})\b") + rx = re.compile(wrapped_pat) except re.error: continue for phrase in rx.findall(text): diff --git a/mempalace/i18n/__init__.py b/mempalace/i18n/__init__.py index 671e0a1..f7b55a0 100644 --- a/mempalace/i18n/__init__.py +++ b/mempalace/i18n/__init__.py @@ -91,6 +91,90 @@ def _load_entity_section(lang: str) -> dict: return data.get("entity", {}) or {} +def _script_boundary(chars: str) -> str: + """Build a lookaround-based word boundary expression. + + Python's built-in ``\\b`` is a transition between ``\\w`` and non-``\\w``. + ``\\w`` covers Unicode Letter and Number categories but NOT Marks (category + Mc/Mn), so for scripts whose words contain combining vowel signs — Devanagari + (ा ी ु), Arabic (ـَ ـِ ـُ), Hebrew (ִ ֵ), Thai, Tamil, Burmese, Khmer — the + default ``\\b`` drops the trailing mark, truncating names like ``अनीता`` to + ``अनीत`` and failing to match ``\\bकहा\\b`` because the trailing matra is + not a word character. + + Locales with such scripts declare ``boundary_chars`` in their entity section + (e.g. ``"\\\\w\\\\u0900-\\\\u097F"`` for Hindi). This function returns a + regex fragment equivalent to ``\\b`` but where the "word" side is defined + as any char matching ``[chars]`` rather than just ``\\w``. + """ + return ( + rf"(?:(?<=[{chars}])(?=[^{chars}])" + rf"|(?<=[^{chars}])(?=[{chars}])" + rf"|^(?=[{chars}])" + rf"|(?<=[{chars}])$)" + ) + + +def _expand_b(pattern: str, boundary_chars: str) -> str: + """Replace every literal ``\\b`` in ``pattern`` with a script-aware boundary. + + ``boundary_chars`` is the inside-word character class (without brackets). + If it's falsy, the pattern is returned unchanged so ``\\b`` keeps its + default Python ``re`` semantics. + """ + if not boundary_chars: + return pattern + return pattern.replace(r"\b", _script_boundary(boundary_chars)) + + +def _wrap_candidate(raw_pat: str, boundary_chars: str) -> str: + """Wrap a candidate/multi-word extraction pattern with a capture group + and word boundaries appropriate for its locale. + + Default: ``\\b(raw)\\b``. With ``boundary_chars``: the script-aware + equivalent, so names ending in combining marks are matched in full. + """ + if boundary_chars: + b = _script_boundary(boundary_chars) + return f"{b}({raw_pat}){b}" + return rf"\b({raw_pat})\b" + + +def _collect_entity_section(section: dict, acc: dict) -> None: + """Merge one language's entity section into the running accumulator. + + Handles boundary expansion in-place so the caller merges already-expanded + strings: `candidate_patterns` and `multi_word_patterns` are pre-wrapped + with the locale's boundary (capture group included, ready to compile); + every ``\\b`` inside person/pronoun/dialogue/project/direct patterns is + replaced with the locale's script-aware boundary. + """ + boundary_chars = section.get("boundary_chars") + if section.get("candidate_pattern"): + acc["candidate_patterns"].append( + _wrap_candidate(section["candidate_pattern"], boundary_chars) + ) + if section.get("multi_word_pattern"): + acc["multi_word_patterns"].append( + _wrap_candidate(section["multi_word_pattern"], boundary_chars) + ) + if section.get("direct_address_pattern"): + acc["direct_address"].append(_expand_b(section["direct_address_pattern"], boundary_chars)) + acc["person_verbs"].extend( + _expand_b(p, boundary_chars) for p in section.get("person_verb_patterns", []) + ) + acc["pronouns"].extend( + _expand_b(p, boundary_chars) for p in section.get("pronoun_patterns", []) + ) + acc["dialogue"].extend( + _expand_b(p, boundary_chars) for p in section.get("dialogue_patterns", []) + ) + acc["project_verbs"].extend( + _expand_b(p, boundary_chars) for p in section.get("project_verb_patterns", []) + ) + acc["stopwords"].update(w.lower() for w in section.get("stopwords", [])) + + def get_entity_patterns(languages=("en",)) -> dict: """Return merged entity detection patterns for the requested languages. @@ -105,11 +189,17 @@ def get_entity_patterns(languages=("en",)) -> dict: - ``stopwords`` is the set union across all languages, returned as a sorted list. - ``candidate_patterns`` and ``multi_word_patterns`` are returned as - lists (one per language) since they use different character classes; - callers run each pattern independently and union the matches. + **fully-wrapped regex strings** (boundary + capture group applied); + the consumer compiles them directly with no further wrapping. - ``direct_address_pattern`` is returned as a list of per-language alternation patterns (not concatenated — each is applied separately). + Locales with combining-mark scripts can declare ``boundary_chars`` in + their entity section (e.g. ``"\\\\w\\\\u0900-\\\\u097F"`` for Hindi); + every ``\\b`` inside that locale's patterns — plus the candidate/multi- + word wrapping — is expanded to a script-aware lookaround boundary that + treats the declared characters as "inside-word". + If ``languages`` is empty or no requested language declares entity data, English is used as a fallback so callers always get a working config. """ @@ -119,14 +209,16 @@ def get_entity_patterns(languages=("en",)) -> dict: if key in _entity_cache: return _entity_cache[key] - candidate_patterns: list[str] = [] - multi_word_patterns: list[str] = [] - person_verbs: list[str] = [] - pronouns: list[str] = [] - dialogue: list[str] = [] - direct_address: list[str] = [] - project_verbs: list[str] = [] - stopwords: set = set() + acc = { + "candidate_patterns": [], + "multi_word_patterns": [], + "person_verbs": [], + "pronouns": [], + "dialogue": [], + "direct_address": [], + "project_verbs": [], + "stopwords": set(), + } found_any = False for lang in languages: @@ -134,42 +226,21 @@ def get_entity_patterns(languages=("en",)) -> dict: if not section: continue found_any = True - if section.get("candidate_pattern"): - candidate_patterns.append(section["candidate_pattern"]) - if section.get("multi_word_pattern"): - multi_word_patterns.append(section["multi_word_pattern"]) - if section.get("direct_address_pattern"): - direct_address.append(section["direct_address_pattern"]) - person_verbs.extend(section.get("person_verb_patterns", [])) - pronouns.extend(section.get("pronoun_patterns", [])) - dialogue.extend(section.get("dialogue_patterns", [])) - project_verbs.extend(section.get("project_verb_patterns", [])) - stopwords.update(w.lower() for w in section.get("stopwords", [])) + _collect_entity_section(section, acc) if not found_any: - # Fallback: load English directly - section = _load_entity_section("en") - if section.get("candidate_pattern"): - candidate_patterns.append(section["candidate_pattern"]) - if section.get("multi_word_pattern"): - multi_word_patterns.append(section["multi_word_pattern"]) - if section.get("direct_address_pattern"): - direct_address.append(section["direct_address_pattern"]) - person_verbs.extend(section.get("person_verb_patterns", [])) - pronouns.extend(section.get("pronoun_patterns", [])) - dialogue.extend(section.get("dialogue_patterns", [])) - project_verbs.extend(section.get("project_verb_patterns", [])) - stopwords.update(w.lower() for w in section.get("stopwords", [])) + # Fallback: load English directly so callers always get a working config. + _collect_entity_section(_load_entity_section("en"), acc) merged = { - "candidate_patterns": candidate_patterns, - "multi_word_patterns": multi_word_patterns, - "person_verb_patterns": _dedupe(person_verbs), - "pronoun_patterns": _dedupe(pronouns), - "dialogue_patterns": _dedupe(dialogue), - "direct_address_patterns": direct_address, - "project_verb_patterns": _dedupe(project_verbs), - "stopwords": sorted(stopwords), + "candidate_patterns": acc["candidate_patterns"], + "multi_word_patterns": acc["multi_word_patterns"], + "person_verb_patterns": _dedupe(acc["person_verbs"]), + "pronoun_patterns": _dedupe(acc["pronouns"]), + "dialogue_patterns": _dedupe(acc["dialogue"]), + "direct_address_patterns": acc["direct_address"], + "project_verb_patterns": _dedupe(acc["project_verbs"]), + "stopwords": sorted(acc["stopwords"]), } _entity_cache[key] = merged return merged diff --git a/tests/test_entity_detector.py b/tests/test_entity_detector.py index 50cb7d1..05a0923 100644 --- a/tests/test_entity_detector.py +++ b/tests/test_entity_detector.py @@ -589,3 +589,75 @@ def test_config_set_entity_languages_empty_falls_back_to_english(tmp_path, monke result = cfg.set_entity_languages([]) assert result == ["en"] assert cfg.entity_languages == ["en"] + + +# ── boundary_chars for combining-mark scripts ───────────────────────── + +# Devanagari vowel signs (matras) are Unicode Mc — not matched by \w. +# Without boundary_chars, \b truncates names like अनीता → अनीत and +# person_verb patterns never fire. With boundary_chars, the i18n loader +# replaces \b with a script-aware lookaround, fixing both. + +_DEVANAGARI_ENTITY = { + "boundary_chars": "\\w\\u0900-\\u097F", + "candidate_pattern": "[\\u0900-\\u097F]{2,20}", + "multi_word_pattern": "[\\u0900-\\u097F]+(?:\\s+[\\u0900-\\u097F]+)+", + "person_verb_patterns": [ + "\\b{name}\\s+ने\\s+कहा\\b", + "\\b{name}\\s+हँसा\\b", + ], + "pronoun_patterns": ["\\bवह\\b", "\\bउसने\\b"], + "dialogue_patterns": ["^{name}:\\s"], + "direct_address_pattern": "\\bनमस्ते\\s+{name}\\b", + "project_verb_patterns": [], + "stopwords": ["यह", "वह", "और", "का", "के", "की"], +} + + +def test_devanagari_candidate_extraction_with_boundary_chars(): + """Names ending in matras are extracted in full with boundary_chars.""" + with _temp_locale("zz-test-hindi", _DEVANAGARI_ENTITY): + text = "अनीता ने कहा। अनीता हँसा। अनीता सोचा। अनीता बोला।" + result = extract_candidates(text, languages=("en", "zz-test-hindi")) + assert "अनीता" in result, f"expected अनीता in {result}" + assert result["अनीता"] >= 3 + + +def test_devanagari_candidate_without_boundary_chars_truncates(): + """Without boundary_chars, a matra-ending name gets truncated.""" + locale_no_boundary = dict(_DEVANAGARI_ENTITY) + del locale_no_boundary["boundary_chars"] + with _temp_locale("zz-test-hindi-no-b", locale_no_boundary): + text = "अनीता ने कहा। अनीता हँसा। अनीता सोचा।" + result = extract_candidates(text, languages=("en", "zz-test-hindi-no-b")) + # Without boundary_chars, \b splits on the matra — full name won't appear + assert "अनीता" not in result + + +def test_devanagari_person_verb_fires_with_boundary_chars(): + """Hindi person-verb patterns fire when boundary_chars extends \\b.""" + with _temp_locale("zz-test-hindi", _DEVANAGARI_ENTITY): + text = "राज ने कहा कुछ। राज हँसा।" + lines = text.splitlines() + scores = score_entity("राज", text, lines, languages=("en", "zz-test-hindi")) + assert scores["person_score"] > 0, f"expected person_score > 0, got {scores}" + assert any("action" in s for s in scores["person_signals"]) + + +def test_devanagari_person_verb_silent_without_boundary_chars(): + """Without boundary_chars, Hindi person verbs don't fire.""" + locale_no_boundary = dict(_DEVANAGARI_ENTITY) + del locale_no_boundary["boundary_chars"] + with _temp_locale("zz-test-hindi-no-b", locale_no_boundary): + text = "राज ने कहा कुछ। राज हँसा।" + lines = text.splitlines() + scores = score_entity("राज", text, lines, languages=("en", "zz-test-hindi-no-b")) + assert scores["person_score"] == 0 + + +def test_boundary_chars_english_regression(): + """English patterns (no boundary_chars) still work identically.""" + text = "Riley said hello. Riley laughed. Riley smiled. Riley waved." + result = extract_candidates(text, languages=("en",)) + assert "Riley" in result + assert result["Riley"] >= 3