fix(search): hybrid closet+drawer retrieval — closets boost, never gate (#795)
This commit is contained in:
committed by
GitHub
parent
4d581cbb73
commit
8e446f904c
+129
-111
@@ -183,138 +183,156 @@ def search_memories(
|
|||||||
|
|
||||||
where = build_where_filter(wing, room)
|
where = build_where_filter(wing, room)
|
||||||
|
|
||||||
# Try closet-first search: search the compact index, then hydrate drawers
|
# Hybrid retrieval: always query drawers directly (the floor), then use
|
||||||
closet_hits = []
|
# closet hits to boost rankings. Closets are a ranking SIGNAL, never a
|
||||||
|
# GATE — direct drawer search is always the baseline.
|
||||||
|
#
|
||||||
|
# This avoids the "weak-closets regression" where narrative content
|
||||||
|
# produces low-signal closets (regex extraction matches few topics)
|
||||||
|
# and closet-first routing hides drawers that direct search would find.
|
||||||
|
try:
|
||||||
|
dkwargs = {
|
||||||
|
"query_texts": [query],
|
||||||
|
"n_results": n_results * 3, # over-fetch for re-ranking
|
||||||
|
"include": ["documents", "metadatas", "distances"],
|
||||||
|
}
|
||||||
|
if where:
|
||||||
|
dkwargs["where"] = where
|
||||||
|
drawer_results = drawers_col.query(**dkwargs)
|
||||||
|
except Exception as e:
|
||||||
|
return {"error": f"Search error: {e}"}
|
||||||
|
|
||||||
|
# Gather closet hits (best-per-source) to build a boost lookup.
|
||||||
|
closet_boost_by_source = {} # source_file -> (rank, closet_dist, preview)
|
||||||
try:
|
try:
|
||||||
closets_col = get_closets_collection(palace_path, create=False)
|
closets_col = get_closets_collection(palace_path, create=False)
|
||||||
ckwargs = {
|
ckwargs = {
|
||||||
"query_texts": [query],
|
"query_texts": [query],
|
||||||
"n_results": n_results * 2, # over-fetch closets to find best drawers
|
"n_results": n_results * 2,
|
||||||
"include": ["documents", "metadatas", "distances"],
|
"include": ["documents", "metadatas", "distances"],
|
||||||
}
|
}
|
||||||
if where:
|
if where:
|
||||||
ckwargs["where"] = where
|
ckwargs["where"] = where
|
||||||
closet_results = closets_col.query(**ckwargs)
|
closet_results = closets_col.query(**ckwargs)
|
||||||
if closet_results["documents"][0]:
|
for rank, (doc, meta, dist) in enumerate(
|
||||||
closet_hits = list(zip(
|
zip(
|
||||||
closet_results["documents"][0],
|
closet_results["documents"][0],
|
||||||
closet_results["metadatas"][0],
|
closet_results["metadatas"][0],
|
||||||
closet_results["distances"][0],
|
closet_results["distances"][0],
|
||||||
))
|
)
|
||||||
|
):
|
||||||
|
source = meta.get("source_file", "")
|
||||||
|
if source and source not in closet_boost_by_source:
|
||||||
|
closet_boost_by_source[source] = (rank, dist, doc[:200])
|
||||||
except Exception:
|
except Exception:
|
||||||
pass # no closets yet — fall through to direct drawer search
|
pass # no closets yet — hybrid degrades to pure drawer search
|
||||||
|
|
||||||
# If closets found results, hydrate the referenced drawers
|
# Rank-based boost. Ordinal signal (which closet matched best) is more
|
||||||
MAX_HYDRATION_CHARS = 10000 # cap to prevent blowup on large source files
|
# reliable than absolute distance on narrative content.
|
||||||
|
CLOSET_RANK_BOOSTS = [0.40, 0.25, 0.15, 0.08, 0.04]
|
||||||
|
CLOSET_DISTANCE_CAP = 1.5 # cosine dist > 1.5 = too weak to use as signal
|
||||||
|
|
||||||
if closet_hits:
|
scored = []
|
||||||
import re
|
for doc, meta, dist in zip(
|
||||||
seen_sources = set()
|
drawer_results["documents"][0],
|
||||||
hits = []
|
drawer_results["metadatas"][0],
|
||||||
for closet_doc, closet_meta, closet_dist in closet_hits:
|
drawer_results["distances"][0],
|
||||||
source = closet_meta.get("source_file", "")
|
):
|
||||||
if source in seen_sources:
|
|
||||||
continue
|
|
||||||
seen_sources.add(source)
|
|
||||||
|
|
||||||
# Find drawers for this source file, grep for most relevant chunk
|
|
||||||
try:
|
|
||||||
drawer_results = drawers_col.get(
|
|
||||||
where={"source_file": source},
|
|
||||||
include=["documents", "metadatas"],
|
|
||||||
)
|
|
||||||
if drawer_results.get("ids"):
|
|
||||||
# Drawer-grep: score each chunk against the query,
|
|
||||||
# return the best-matching chunk first + surrounding context
|
|
||||||
query_terms = set(re.findall(r'\w{2,}', query.lower()))
|
|
||||||
best_idx = 0
|
|
||||||
best_score = -1
|
|
||||||
for idx, doc in enumerate(drawer_results["documents"]):
|
|
||||||
doc_lower = doc.lower()
|
|
||||||
score = sum(1 for t in query_terms if t in doc_lower)
|
|
||||||
if score > best_score:
|
|
||||||
best_score = score
|
|
||||||
best_idx = idx
|
|
||||||
|
|
||||||
# Build result: best chunk first, then neighbors
|
|
||||||
docs = drawer_results["documents"]
|
|
||||||
n_docs = len(docs)
|
|
||||||
# Include best chunk + 1 before + 1 after for context
|
|
||||||
start = max(0, best_idx - 1)
|
|
||||||
end = min(n_docs, best_idx + 2)
|
|
||||||
relevant_text = "\n\n".join(docs[start:end])
|
|
||||||
|
|
||||||
if len(relevant_text) > MAX_HYDRATION_CHARS:
|
|
||||||
relevant_text = relevant_text[:MAX_HYDRATION_CHARS] + f"\n\n[...truncated. {n_docs} total drawers. Use mempalace_get_drawer for full content.]"
|
|
||||||
|
|
||||||
meta = drawer_results["metadatas"][best_idx]
|
|
||||||
hits.append({
|
|
||||||
"text": relevant_text,
|
|
||||||
"wing": meta.get("wing", "unknown"),
|
|
||||||
"room": meta.get("room", "unknown"),
|
|
||||||
"source_file": Path(source).name,
|
|
||||||
"similarity": round(max(0.0, 1 - closet_dist), 3),
|
|
||||||
"distance": round(closet_dist, 4),
|
|
||||||
"matched_via": "closet",
|
|
||||||
"closet_preview": closet_doc[:200],
|
|
||||||
"drawer_index": best_idx,
|
|
||||||
"total_drawers": n_docs,
|
|
||||||
})
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
|
|
||||||
if len(hits) >= n_results:
|
|
||||||
break
|
|
||||||
|
|
||||||
if hits:
|
|
||||||
# Re-rank with BM25 hybrid scoring
|
|
||||||
hits = _hybrid_rank(hits, query)
|
|
||||||
return {
|
|
||||||
"query": query,
|
|
||||||
"filters": {"wing": wing, "room": room},
|
|
||||||
"total_before_filter": len(closet_hits),
|
|
||||||
"results": hits,
|
|
||||||
}
|
|
||||||
|
|
||||||
# Fallback: direct drawer search (no closets yet, or closets empty)
|
|
||||||
try:
|
|
||||||
kwargs = {
|
|
||||||
"query_texts": [query],
|
|
||||||
"n_results": n_results,
|
|
||||||
"include": ["documents", "metadatas", "distances"],
|
|
||||||
}
|
|
||||||
if where:
|
|
||||||
kwargs["where"] = where
|
|
||||||
|
|
||||||
results = drawers_col.query(**kwargs)
|
|
||||||
except Exception as e:
|
|
||||||
return {"error": f"Search error: {e}"}
|
|
||||||
|
|
||||||
docs = results["documents"][0]
|
|
||||||
metas = results["metadatas"][0]
|
|
||||||
dists = results["distances"][0]
|
|
||||||
|
|
||||||
hits = []
|
|
||||||
for doc, meta, dist in zip(docs, metas, dists):
|
|
||||||
# Filter on raw distance before rounding to avoid precision loss
|
|
||||||
if max_distance > 0.0 and dist > max_distance:
|
if max_distance > 0.0 and dist > max_distance:
|
||||||
continue
|
continue
|
||||||
hits.append(
|
|
||||||
{
|
|
||||||
"text": doc,
|
|
||||||
"wing": meta.get("wing", "unknown"),
|
|
||||||
"room": meta.get("room", "unknown"),
|
|
||||||
"source_file": Path(meta.get("source_file", "?")).name,
|
|
||||||
"similarity": round(max(0.0, 1 - dist), 3),
|
|
||||||
"distance": round(dist, 4),
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
# Re-rank with BM25 hybrid scoring
|
source = meta.get("source_file", "")
|
||||||
|
boost = 0.0
|
||||||
|
matched_via = "drawer"
|
||||||
|
closet_preview = None
|
||||||
|
if source in closet_boost_by_source:
|
||||||
|
c_rank, c_dist, c_preview = closet_boost_by_source[source]
|
||||||
|
if c_dist <= CLOSET_DISTANCE_CAP and c_rank < len(CLOSET_RANK_BOOSTS):
|
||||||
|
boost = CLOSET_RANK_BOOSTS[c_rank]
|
||||||
|
matched_via = "drawer+closet"
|
||||||
|
closet_preview = c_preview
|
||||||
|
|
||||||
|
effective_dist = dist - boost
|
||||||
|
entry = {
|
||||||
|
"text": doc,
|
||||||
|
"wing": meta.get("wing", "unknown"),
|
||||||
|
"room": meta.get("room", "unknown"),
|
||||||
|
"source_file": Path(source).name if source else "?",
|
||||||
|
"similarity": round(max(0.0, 1 - effective_dist), 3),
|
||||||
|
"distance": round(dist, 4),
|
||||||
|
"effective_distance": round(effective_dist, 4),
|
||||||
|
"closet_boost": round(boost, 3),
|
||||||
|
"matched_via": matched_via,
|
||||||
|
"_sort_key": effective_dist,
|
||||||
|
}
|
||||||
|
if closet_preview:
|
||||||
|
entry["closet_preview"] = closet_preview
|
||||||
|
scored.append(entry)
|
||||||
|
|
||||||
|
scored.sort(key=lambda h: h["_sort_key"])
|
||||||
|
hits = scored[:n_results]
|
||||||
|
|
||||||
|
# Drawer-grep enrichment: for top hits whose source file has multiple
|
||||||
|
# drawers, return the best-matching chunk + its immediate neighbors
|
||||||
|
# instead of just the single drawer. Preserves the chunk-expansion
|
||||||
|
# behavior users relied on in the closet-first path.
|
||||||
|
MAX_HYDRATION_CHARS = 10000
|
||||||
|
import re as _re
|
||||||
|
|
||||||
|
for h in hits:
|
||||||
|
if h["matched_via"] == "drawer":
|
||||||
|
continue
|
||||||
|
# Only enrich closet-matched hits (cheap: we already know source matters)
|
||||||
|
source_name = h["source_file"]
|
||||||
|
# Look up full source_file by matching suffix in candidate pool
|
||||||
|
full_source = next(
|
||||||
|
(
|
||||||
|
m.get("source_file", "")
|
||||||
|
for m in drawer_results["metadatas"][0]
|
||||||
|
if m.get("source_file", "").endswith(source_name)
|
||||||
|
),
|
||||||
|
"",
|
||||||
|
)
|
||||||
|
if not full_source:
|
||||||
|
continue
|
||||||
|
try:
|
||||||
|
source_drawers = drawers_col.get(
|
||||||
|
where={"source_file": full_source}, include=["documents"]
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
continue
|
||||||
|
docs = source_drawers.get("documents") or []
|
||||||
|
if len(docs) <= 1:
|
||||||
|
continue
|
||||||
|
|
||||||
|
query_terms = set(_re.findall(r"\w{2,}", query.lower()))
|
||||||
|
best_idx, best_score = 0, -1
|
||||||
|
for idx, d in enumerate(docs):
|
||||||
|
d_lower = d.lower()
|
||||||
|
s = sum(1 for t in query_terms if t in d_lower)
|
||||||
|
if s > best_score:
|
||||||
|
best_score, best_idx = s, idx
|
||||||
|
|
||||||
|
start = max(0, best_idx - 1)
|
||||||
|
end = min(len(docs), best_idx + 2)
|
||||||
|
expanded = "\n\n".join(docs[start:end])
|
||||||
|
if len(expanded) > MAX_HYDRATION_CHARS:
|
||||||
|
expanded = (
|
||||||
|
expanded[:MAX_HYDRATION_CHARS]
|
||||||
|
+ f"\n\n[...truncated. {len(docs)} total drawers. Use mempalace_get_drawer for full content.]"
|
||||||
|
)
|
||||||
|
h["text"] = expanded
|
||||||
|
h["drawer_index"] = best_idx
|
||||||
|
h["total_drawers"] = len(docs)
|
||||||
|
|
||||||
|
# BM25 hybrid re-rank within the final candidate set
|
||||||
hits = _hybrid_rank(hits, query)
|
hits = _hybrid_rank(hits, query)
|
||||||
|
for h in hits:
|
||||||
|
h.pop("_sort_key", None)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"query": query,
|
"query": query,
|
||||||
"filters": {"wing": wing, "room": room},
|
"filters": {"wing": wing, "room": room},
|
||||||
"total_before_filter": len(docs),
|
"total_before_filter": len(drawer_results["documents"][0]),
|
||||||
"results": hits,
|
"results": hits,
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -0,0 +1,141 @@
|
|||||||
|
"""Tests for the hybrid closet+drawer retrieval in search_memories.
|
||||||
|
|
||||||
|
The hybrid path queries drawers directly (the floor) AND closets, applying a
|
||||||
|
rank-based boost to drawers whose source_file appears in top closet hits.
|
||||||
|
This avoids the "weak-closets regression" where low-signal closets (from
|
||||||
|
regex extraction on narrative content) could hide drawers that direct
|
||||||
|
search would have found.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import tempfile
|
||||||
|
|
||||||
|
import chromadb
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from mempalace.palace import (
|
||||||
|
get_collection,
|
||||||
|
get_closets_collection,
|
||||||
|
upsert_closet_lines,
|
||||||
|
)
|
||||||
|
from mempalace.searcher import search_memories
|
||||||
|
|
||||||
|
|
||||||
|
def _seed_drawers(palace_path):
|
||||||
|
"""Insert 4 short drawers with deterministic content."""
|
||||||
|
col = get_collection(palace_path, create=True)
|
||||||
|
col.upsert(
|
||||||
|
ids=["D1", "D2", "D3", "D4"],
|
||||||
|
documents=[
|
||||||
|
"We switched the auth service to use JWT tokens with a 24h expiry.",
|
||||||
|
"Database migration to PostgreSQL 15 completed last Tuesday.",
|
||||||
|
"The frontend team is debating whether to adopt TanStack Query.",
|
||||||
|
"Kafka consumer rebalance timeout set to 45 seconds after incident.",
|
||||||
|
],
|
||||||
|
metadatas=[
|
||||||
|
{"wing": "backend", "room": "auth", "source_file": "fixture_D1.md"},
|
||||||
|
{"wing": "backend", "room": "db", "source_file": "fixture_D2.md"},
|
||||||
|
{"wing": "frontend", "room": "state", "source_file": "fixture_D3.md"},
|
||||||
|
{"wing": "backend", "room": "queue", "source_file": "fixture_D4.md"},
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _seed_strong_closet_for(palace_path, drawer_id, source_file, topics):
|
||||||
|
"""Insert a closet whose content strongly overlaps the query keywords."""
|
||||||
|
col = get_closets_collection(palace_path)
|
||||||
|
lines = [f"{t}||→{drawer_id}" for t in topics]
|
||||||
|
upsert_closet_lines(
|
||||||
|
col,
|
||||||
|
closet_id_base=f"closet_{drawer_id}",
|
||||||
|
lines=lines,
|
||||||
|
metadata={
|
||||||
|
"wing": "backend",
|
||||||
|
"room": "auth",
|
||||||
|
"source_file": source_file,
|
||||||
|
"generated_by": "test",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ── core invariant: closets can only HELP, never HIDE ─────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestHybridInvariant:
|
||||||
|
def test_no_closets_degrades_to_direct_drawer_search(self, tmp_path):
|
||||||
|
palace = str(tmp_path / "palace")
|
||||||
|
_seed_drawers(palace)
|
||||||
|
# No closets created.
|
||||||
|
result = search_memories("Kafka rebalance timeout", palace, n_results=3)
|
||||||
|
ids = [h["source_file"] for h in result["results"]]
|
||||||
|
assert ids, "should return results"
|
||||||
|
assert "fixture_D4.md" in ids, (
|
||||||
|
"direct drawer search alone should surface the Kafka drawer"
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_weak_closets_do_not_hide_direct_drawer_hits(self, tmp_path):
|
||||||
|
"""A closet that points at a wrong drawer must NOT suppress the
|
||||||
|
drawer that direct search would have ranked first."""
|
||||||
|
palace = str(tmp_path / "palace")
|
||||||
|
_seed_drawers(palace)
|
||||||
|
# Seed a misleading closet: it matches a generic phrase but points at D3.
|
||||||
|
_seed_strong_closet_for(
|
||||||
|
palace,
|
||||||
|
drawer_id="D3",
|
||||||
|
source_file="fixture_D3.md",
|
||||||
|
topics=["Kafka queue tuning", "consumer rebalance config"],
|
||||||
|
)
|
||||||
|
result = search_memories("Kafka consumer rebalance timeout", palace, n_results=5)
|
||||||
|
ids = [h["source_file"] for h in result["results"]]
|
||||||
|
assert "fixture_D4.md" in ids, (
|
||||||
|
"D4 must appear — direct drawer search alone would rank it first. "
|
||||||
|
"Closet pointing to D3 should only boost D3, never hide D4."
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_closet_boost_lifts_matching_drawer(self, tmp_path):
|
||||||
|
"""When a closet agrees with direct search, the matching drawer
|
||||||
|
should be boosted to rank 1."""
|
||||||
|
palace = str(tmp_path / "palace")
|
||||||
|
_seed_drawers(palace)
|
||||||
|
_seed_strong_closet_for(
|
||||||
|
palace,
|
||||||
|
drawer_id="D1",
|
||||||
|
source_file="fixture_D1.md",
|
||||||
|
topics=["JWT auth tokens", "session expiry", "authentication service"],
|
||||||
|
)
|
||||||
|
result = search_memories("JWT auth tokens expiry", palace, n_results=3)
|
||||||
|
ids = [h["source_file"] for h in result["results"]]
|
||||||
|
assert ids[0] == "fixture_D1.md"
|
||||||
|
top = result["results"][0]
|
||||||
|
assert top["matched_via"] == "drawer+closet"
|
||||||
|
assert top["closet_boost"] > 0
|
||||||
|
|
||||||
|
|
||||||
|
# ── closet_boost metadata ────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestClosetMetadata:
|
||||||
|
def test_closet_preview_exposed_when_boosted(self, tmp_path):
|
||||||
|
palace = str(tmp_path / "palace")
|
||||||
|
_seed_drawers(palace)
|
||||||
|
_seed_strong_closet_for(
|
||||||
|
palace,
|
||||||
|
drawer_id="D1",
|
||||||
|
source_file="fixture_D1.md",
|
||||||
|
topics=["JWT auth tokens", "24h expiry", "authentication"],
|
||||||
|
)
|
||||||
|
result = search_memories("JWT authentication", palace, n_results=2)
|
||||||
|
top = result["results"][0]
|
||||||
|
assert top["source_file"] == "fixture_D1.md"
|
||||||
|
assert "closet_preview" in top
|
||||||
|
|
||||||
|
def test_drawer_only_hits_have_no_closet_preview(self, tmp_path):
|
||||||
|
palace = str(tmp_path / "palace")
|
||||||
|
_seed_drawers(palace)
|
||||||
|
# No closets
|
||||||
|
result = search_memories("TanStack Query", palace, n_results=2)
|
||||||
|
assert result["results"]
|
||||||
|
for h in result["results"]:
|
||||||
|
assert h["matched_via"] == "drawer"
|
||||||
|
assert "closet_preview" not in h
|
||||||
|
assert h["closet_boost"] == 0.0
|
||||||
Reference in New Issue
Block a user