3406 lines
117 KiB
Python
3406 lines
117 KiB
Python
|
|
#!/usr/bin/env python3
|
|||
|
|
"""
|
|||
|
|
MemPal × LongMemEval Benchmark
|
|||
|
|
================================
|
|||
|
|
|
|||
|
|
Evaluates MemPal's retrieval against the LongMemEval benchmark.
|
|||
|
|
No modifications to LongMemEval's code required.
|
|||
|
|
|
|||
|
|
For each of the 500 questions:
|
|||
|
|
1. Ingest all haystack sessions into a fresh MemPal palace
|
|||
|
|
2. Query the palace with the question
|
|||
|
|
3. Score retrieval against ground-truth answer sessions
|
|||
|
|
|
|||
|
|
Outputs:
|
|||
|
|
- Recall@k and NDCG@k at session and turn level
|
|||
|
|
- Per-question-type breakdown
|
|||
|
|
- JSONL log compatible with LongMemEval's evaluation scripts
|
|||
|
|
|
|||
|
|
Modes:
|
|||
|
|
raw — baseline: raw text into ChromaDB (default)
|
|||
|
|
aaak — AAAK dialect compression before ingestion
|
|||
|
|
rooms — topic-based room detection + room-filtered search
|
|||
|
|
|
|||
|
|
Usage:
|
|||
|
|
python benchmarks/longmemeval_bench.py data/longmemeval_s_cleaned.json
|
|||
|
|
python benchmarks/longmemeval_bench.py data/longmemeval_s_cleaned.json --mode aaak
|
|||
|
|
python benchmarks/longmemeval_bench.py data/longmemeval_s_cleaned.json --mode rooms
|
|||
|
|
python benchmarks/longmemeval_bench.py data/longmemeval_s_cleaned.json --granularity turn
|
|||
|
|
python benchmarks/longmemeval_bench.py data/longmemeval_s_cleaned.json --limit 20
|
|||
|
|
"""
|
|||
|
|
|
|||
|
|
import os
|
|||
|
|
import sys
|
|||
|
|
import re
|
|||
|
|
import json
|
|||
|
|
import argparse
|
|||
|
|
import math
|
|||
|
|
from pathlib import Path
|
|||
|
|
from collections import defaultdict
|
|||
|
|
from datetime import datetime
|
|||
|
|
|
|||
|
|
import chromadb
|
|||
|
|
|
|||
|
|
# Add mempal to path
|
|||
|
|
sys.path.insert(0, str(Path(__file__).parent.parent))
|
|||
|
|
|
|||
|
|
|
|||
|
|
# =============================================================================
|
|||
|
|
# METRICS (reimplemented to avoid LongMemEval dependency)
|
|||
|
|
# =============================================================================
|
|||
|
|
|
|||
|
|
|
|||
|
|
def dcg(relevances, k):
    """
    Discounted Cumulative Gain over the first k relevance values.

    The item at zero-based position i contributes rel / log2(i + 2), so the
    first item is undiscounted (log2(2) == 1).

    Args:
        relevances: Sequence of numeric relevance grades in ranked order.
        k: Cut-off; only relevances[:k] are scored.

    Returns:
        The accumulated gain as a float (0.0 when nothing is scored).
    """
    total = 0.0
    # Starting the counter at 2 yields the log2(i + 2) discount directly.
    for discount_arg, gain in enumerate(relevances[:k], start=2):
        total += gain / math.log2(discount_arg)
    return total
|
|||
|
|
|
|||
|
|
|
|||
|
|
def ndcg(rankings, correct_ids, corpus_ids, k):
    """
    Normalized DCG at rank k with binary relevance.

    Args:
        rankings: Indices into corpus_ids, ordered from most to least relevant.
        correct_ids: Container of corpus IDs that count as relevant.
        corpus_ids: Corpus IDs; rankings entries index into this sequence.
        k: Rank cut-off.

    Returns:
        DCG of the top-k ranking divided by the DCG of an ideal ranking,
        or 0.0 when the corpus contains no relevant document.
    """
    relevances = [1.0 if corpus_ids[idx] in correct_ids else 0.0 for idx in rankings[:k]]

    # The ideal ranking must be built from every relevant document in the
    # corpus, not just the ones that happened to be retrieved. Sorting only the
    # retrieved relevances would score any "hits first" result as 1.0 even when
    # most relevant documents were missed.
    relevant_in_corpus = sum(1 for cid in corpus_ids if cid in correct_ids)
    ideal = [1.0] * min(relevant_in_corpus, max(k, 0))

    idcg = dcg(ideal, k)
    if idcg == 0:
        return 0.0
    return dcg(relevances, k) / idcg
|
|||
|
|
|
|||
|
|
|
|||
|
|
def evaluate_retrieval(rankings, correct_ids, corpus_ids, k):
    """
    Score one ranked result list at cut-off k.

    Args:
        rankings: Indices into corpus_ids, best match first.
        correct_ids: IDs of the ground-truth documents.
        corpus_ids: Corpus IDs addressed by the indices in rankings.
        k: Rank cut-off.

    Returns:
        (recall_any, recall_all, ndcg_score) where recall_any is 1.0 if at
        least one correct ID is in the top k, recall_all is 1.0 if every
        correct ID is in the top k, and ndcg_score comes from ndcg().
    """
    retrieved = {corpus_ids[pos] for pos in rankings[:k]}

    hit_any = any(cid in retrieved for cid in correct_ids)
    hit_all = all(cid in retrieved for cid in correct_ids)

    return (
        1.0 if hit_any else 0.0,
        1.0 if hit_all else 0.0,
        ndcg(rankings, correct_ids, corpus_ids, k),
    )
|
|||
|
|
|
|||
|
|
|
|||
|
|
def session_id_from_corpus_id(corpus_id):
    """
    Return the session ID embedded in a corpus ID.

    Turn-level IDs have the form "<session>_turn_<n>" (e.g. "sess_123_turn_4"
    -> "sess_123"); the text before the last "_turn_" is returned. IDs without
    that marker are already session IDs and are returned unchanged.
    """
    prefix, marker, _suffix = corpus_id.rpartition("_turn_")
    # rpartition leaves marker empty when "_turn_" does not occur at all.
    return prefix if marker else corpus_id
|
|||
|
|
|
|||
|
|
|
|||
|
|
# =============================================================================
|
|||
|
|
# SHARED EPHEMERAL CLIENT
|
|||
|
|
# EphemeralClient instances share state in this ChromaDB version — use one
|
|||
|
|
# shared client and delete+recreate the collection between queries.
|
|||
|
|
# =============================================================================
|
|||
|
|
|
|||
|
|
# The one client instance used by _fresh_collection() to drop and recreate the
# benchmark collection; created once at import time.
_bench_client = chromadb.EphemeralClient()

# Global embedding function — set by --embed-model arg before benchmark runs.
# None = use ChromaDB default (all-MiniLM-L6-v2).
# Read by _fresh_collection() each time a collection is created.
_bench_embed_fn = None
|
|||
|
|
|
|||
|
|
|
|||
|
|
def _make_embed_fn(model_name: str):
|
|||
|
|
"""
|
|||
|
|
Return a ChromaDB-compatible embedding function for the given model.
|
|||
|
|
|
|||
|
|
Supported:
|
|||
|
|
default — ChromaDB default (all-MiniLM-L6-v2, 384-dim)
|
|||
|
|
bge-base — BAAI/bge-base-en-v1.5 (768-dim) via fastembed
|
|||
|
|
bge-large — BAAI/bge-large-en-v1.5 (1024-dim) via fastembed
|
|||
|
|
nomic — nomic-ai/nomic-embed-text-v1.5 (768-dim) via fastembed
|
|||
|
|
mxbai — mixedbread-ai/mxbai-embed-large-v1 (1024-dim) via fastembed
|
|||
|
|
"""
|
|||
|
|
if model_name == "default" or not model_name:
|
|||
|
|
return None # ChromaDB default
|
|||
|
|
|
|||
|
|
MODEL_MAP = {
|
|||
|
|
"bge-base": "BAAI/bge-base-en-v1.5",
|
|||
|
|
"bge-large": "BAAI/bge-large-en-v1.5",
|
|||
|
|
"nomic": "nomic-ai/nomic-embed-text-v1.5",
|
|||
|
|
"mxbai": "mixedbread-ai/mxbai-embed-large-v1",
|
|||
|
|
}
|
|||
|
|
hf_name = MODEL_MAP.get(model_name, model_name)
|
|||
|
|
|
|||
|
|
try:
|
|||
|
|
from fastembed import TextEmbedding
|
|||
|
|
from chromadb.api.types import EmbeddingFunction, Documents, Embeddings
|
|||
|
|
|
|||
|
|
class _FastEmbedFn(EmbeddingFunction):
|
|||
|
|
def __init__(self, name):
|
|||
|
|
print(f" Loading embedding model: {name} (first run downloads ~300-1300MB)...")
|
|||
|
|
self._model = TextEmbedding(name)
|
|||
|
|
print(" Model ready.")
|
|||
|
|
|
|||
|
|
def __call__(self, input: Documents) -> Embeddings:
|
|||
|
|
return [list(vec) for vec in self._model.embed(input)]
|
|||
|
|
|
|||
|
|
return _FastEmbedFn(hf_name)
|
|||
|
|
except ImportError:
|
|||
|
|
print("ERROR: fastembed not installed. Run: pip install fastembed")
|
|||
|
|
print(" Falling back to default embedding model.")
|
|||
|
|
return None
|
|||
|
|
|
|||
|
|
|
|||
|
|
def _fresh_collection(name="mempal_drawers"):
    """
    Drop the named collection (if any) and create it again, empty.

    Uses the shared module-level client. When a custom embedding function has
    been configured in _bench_embed_fn it is attached to the new collection;
    otherwise the collection is created with ChromaDB's defaults.

    Returns:
        The newly created collection.
    """
    try:
        _bench_client.delete_collection(name)
    except Exception:
        # Best-effort cleanup: any failure to delete is deliberately ignored
        # and creation is attempted regardless.
        pass

    create_kwargs = {}
    if _bench_embed_fn is not None:
        create_kwargs["embedding_function"] = _bench_embed_fn
    return _bench_client.create_collection(name, **create_kwargs)
|
|||
|
|
|
|||
|
|
|
|||
|
|
# =============================================================================
|
|||
|
|
# MEMPAL RETRIEVER
|
|||
|
|
# =============================================================================
|
|||
|
|
|
|||
|
|
|
|||
|
|
def build_palace_and_retrieve(entry, granularity="session", n_results=50):
    """
    Raw mode: ingest user text into a fresh collection and query it.

    Args:
        entry: One LongMemEval question entry; reads "haystack_sessions",
            "haystack_session_ids", "haystack_dates" and "question".
        granularity: "session" builds one document per session (all user
            turns joined by newlines); any other value builds one document
            per user turn with ID "<session_id>_turn_<n>".
        n_results: Maximum number of results requested from the collection.

    Returns:
        (rankings, corpus, corpus_ids, corpus_timestamps). rankings is a list
        of indices into corpus, retrieved documents first in the order the
        collection returned them, followed by every non-retrieved index in
        ascending order. rankings is empty when no document could be built.
    """
    corpus, corpus_ids, corpus_timestamps = [], [], []

    haystack = zip(
        entry["haystack_sessions"],
        entry["haystack_session_ids"],
        entry["haystack_dates"],
    )
    for turns, sid, stamp in haystack:
        if granularity == "session":
            user_texts = [t["content"] for t in turns if t["role"] == "user"]
            if not user_texts:
                continue  # sessions without user text contribute nothing
            corpus.append("\n".join(user_texts))
            corpus_ids.append(sid)
            corpus_timestamps.append(stamp)
        else:
            # Turn numbers count user turns only.
            user_only = (t for t in turns if t["role"] == "user")
            for turn_no, t in enumerate(user_only):
                corpus.append(t["content"])
                corpus_ids.append(f"{sid}_turn_{turn_no}")
                corpus_timestamps.append(stamp)

    if not corpus:
        return [], corpus, corpus_ids, corpus_timestamps

    total = len(corpus)
    doc_keys = [f"doc_{i}" for i in range(total)]

    collection = _fresh_collection()
    collection.add(
        documents=corpus,
        ids=doc_keys,
        metadatas=[
            {"corpus_id": cid, "timestamp": ts}
            for cid, ts in zip(corpus_ids, corpus_timestamps)
        ],
    )

    hits = collection.query(
        query_texts=[entry["question"]],
        n_results=min(n_results, total),
        include=["distances", "metadatas"],
    )

    # Translate returned document keys back to corpus positions, keeping the
    # order in which the collection returned them.
    position_of = {key: pos for pos, key in enumerate(doc_keys)}
    ranked = [position_of[key] for key in hits["ids"][0]]

    # The query may return fewer documents than the corpus holds; append the
    # rest so the ranking covers every corpus index.
    retrieved = set(ranked)
    ranked.extend(pos for pos in range(total) if pos not in retrieved)

    return ranked, corpus, corpus_ids, corpus_timestamps
|
|||
|
|
|
|||
|
|
|
|||
|
|
def build_palace_and_retrieve_aaak(entry, granularity="session", n_results=50):
    """
    AAAK mode: ingest Dialect-compressed text, query with the raw question.

    Each document is passed through Dialect.compress() before being added to
    the collection, while the returned corpus keeps the uncompressed text.
    Session documents are compressed with metadata {"date": <session date>};
    turn documents are compressed without metadata.

    Args:
        entry: One LongMemEval question entry.
        granularity: "session" for one document per session (user turns
            joined by newlines); any other value for one document per user turn.
        n_results: Maximum number of results requested from the collection.

    Returns:
        (rankings, corpus, corpus_ids, corpus_timestamps) where corpus holds
        the original text and rankings lists indices into it, retrieved
        documents first, then the remaining indices in ascending order.
    """
    from mempalace.dialect import Dialect

    compressor = Dialect()

    corpus = []  # original text (for output)
    corpus_compressed = []  # AAAK compressed (for ingestion)
    corpus_ids = []
    corpus_timestamps = []

    haystack = zip(
        entry["haystack_sessions"],
        entry["haystack_session_ids"],
        entry["haystack_dates"],
    )
    for turns, sid, stamp in haystack:
        if granularity == "session":
            user_texts = [t["content"] for t in turns if t["role"] == "user"]
            if not user_texts:
                continue
            raw_doc = "\n".join(user_texts)
            packed = compressor.compress(raw_doc, metadata={"date": stamp})
            corpus.append(raw_doc)
            corpus_compressed.append(packed)
            corpus_ids.append(sid)
            corpus_timestamps.append(stamp)
        else:
            user_only = (t for t in turns if t["role"] == "user")
            for turn_no, t in enumerate(user_only):
                packed = compressor.compress(t["content"])
                corpus.append(t["content"])
                corpus_compressed.append(packed)
                corpus_ids.append(f"{sid}_turn_{turn_no}")
                corpus_timestamps.append(stamp)

    if not corpus:
        return [], corpus, corpus_ids, corpus_timestamps

    collection = _fresh_collection()

    # Only the compressed form is stored in the collection.
    collection.add(
        documents=corpus_compressed,
        ids=[f"doc_{i}" for i in range(len(corpus_compressed))],
        metadatas=[
            {"corpus_id": cid, "timestamp": ts}
            for cid, ts in zip(corpus_ids, corpus_timestamps)
        ],
    )

    # The question itself is not compressed.
    hits = collection.query(
        query_texts=[entry["question"]],
        n_results=min(n_results, len(corpus)),
        include=["distances", "metadatas"],
    )

    position_of = {f"doc_{pos}": pos for pos in range(len(corpus))}
    ranked = [position_of[key] for key in hits["ids"][0]]

    # Pad with non-retrieved indices so every document has a rank.
    retrieved = set(ranked)
    ranked.extend(pos for pos in range(len(corpus)) if pos not in retrieved)

    return ranked, corpus, corpus_ids, corpus_timestamps
|
|||
|
|
|
|||
|
|
|
|||
|
|
# Topic keywords for room detection (same as convo_miner.py).
# Maps room name -> lowercase keywords matched as substrings by
# detect_room_for_text(). Key order matters: it decides ties between rooms.
TOPIC_KEYWORDS = {
    "technical": [
        "code", "python", "function", "bug", "error", "api", "database",
        "server", "deploy", "git", "test", "debug", "refactor",
    ],
    "planning": [
        "plan", "roadmap", "milestone", "deadline", "priority",
        "sprint", "backlog", "scope", "requirement", "spec",
    ],
    "decisions": [
        "decided", "chose", "picked", "switched", "migrated",
        "replaced", "trade-off", "alternative", "option", "approach",
    ],
    "personal": [
        "family", "friend", "birthday", "vacation", "hobby",
        "health", "feeling", "love", "home", "weekend",
    ],
    "knowledge": [
        "learn", "study", "degree", "school", "university",
        "course", "research", "paper", "book", "reading",
    ],
}
|
|||
|
|
|
|||
|
|
|
|||
|
|
def detect_room_for_text(text):
    """
    Pick the topic room whose keywords occur most often in the text.

    Only the first 3000 characters are inspected, case-insensitively, and a
    keyword counts once if it appears anywhere as a substring. When several
    rooms tie, the one listed first in TOPIC_KEYWORDS wins. Returns "general"
    when no keyword matches.
    """
    sample = text[:3000].lower()

    best_room = "general"
    best_hits = 0
    for room, keywords in TOPIC_KEYWORDS.items():
        hits = sum(kw in sample for kw in keywords)
        # Strict comparison keeps the earliest room on ties.
        if hits > best_hits:
            best_room, best_hits = room, hits
    return best_room
|
|||
|
|
|
|||
|
|
|
|||
|
|
def build_palace_and_retrieve_rooms(entry, granularity="session", n_results=50):
    """
    Room-structured mode: tag documents with a topic room and soft-boost matches.

    Every document gets a room from detect_room_for_text(), stored in its
    metadata. The question is assigned a room the same way. A single global
    query is run; results whose room equals the question's room have their
    distance multiplied by 0.8 before the results are re-sorted.

    Args:
        entry: One LongMemEval question entry.
        granularity: "session" for one document per session (user turns
            joined by newlines); any other value for one document per user turn.
        n_results: Maximum number of results requested from the collection.

    Returns:
        (rankings, corpus, corpus_ids, corpus_timestamps); rankings lists
        indices into corpus ordered by boosted distance, followed by the
        non-retrieved indices in ascending order.
    """
    corpus, corpus_ids, corpus_timestamps, corpus_rooms = [], [], [], []

    haystack = zip(
        entry["haystack_sessions"],
        entry["haystack_session_ids"],
        entry["haystack_dates"],
    )
    for turns, sid, stamp in haystack:
        if granularity == "session":
            user_texts = [t["content"] for t in turns if t["role"] == "user"]
            if not user_texts:
                continue
            joined = "\n".join(user_texts)
            corpus_rooms.append(detect_room_for_text(joined))
            corpus.append(joined)
            corpus_ids.append(sid)
            corpus_timestamps.append(stamp)
        else:
            user_only = (t for t in turns if t["role"] == "user")
            for turn_no, t in enumerate(user_only):
                corpus_rooms.append(detect_room_for_text(t["content"]))
                corpus.append(t["content"])
                corpus_ids.append(f"{sid}_turn_{turn_no}")
                corpus_timestamps.append(stamp)

    if not corpus:
        return [], corpus, corpus_ids, corpus_timestamps

    total = len(corpus)
    collection = _fresh_collection()
    collection.add(
        documents=corpus,
        ids=[f"doc_{i}" for i in range(total)],
        metadatas=[
            {"corpus_id": cid, "timestamp": ts, "room": room}
            for cid, ts, room in zip(corpus_ids, corpus_timestamps, corpus_rooms)
        ],
    )

    question = entry["question"]
    question_room = detect_room_for_text(question)

    # Global search with room-based reranking (soft boost, not hard filter)
    hits = collection.query(
        query_texts=[question],
        n_results=min(n_results, total),
        include=["distances", "metadatas"],
    )

    position_of = {f"doc_{pos}": pos for pos in range(total)}
    rescored = []
    for key, distance, meta in zip(
        hits["ids"][0], hits["distances"][0], hits["metadatas"][0]
    ):
        # Same-room results get a 20% distance reduction.
        if meta.get("room") == question_room:
            distance = distance * 0.8
        rescored.append((position_of[key], distance))

    # Ascending distance = most relevant first; the sort is stable, so equal
    # distances keep the collection's original order.
    ranked = [pos for pos, _ in sorted(rescored, key=lambda pair: pair[1])]

    retrieved = set(ranked)
    ranked.extend(pos for pos in range(total) if pos not in retrieved)

    return ranked, corpus, corpus_ids, corpus_timestamps
|
|||
|
|
|
|||
|
|
|
|||
|
|
def build_palace_and_retrieve_hybrid(
    entry, granularity="session", n_results=50, hybrid_weight=0.30
):
    """
    Hybrid mode: semantic search followed by keyword-overlap re-ranking.

    Stage 1 retrieves the top results from the collection exactly like raw
    mode. Stage 2 extracts keywords from the question (lowercase alphabetic
    words of 3+ letters, minus stop words) and, for each retrieved document,
    computes the fraction of those keywords found in it as substrings. The
    distance is then scaled by (1 - hybrid_weight * fraction) and the results
    are re-sorted by that fused distance.

    This helps when a very specific answer term ("Business Administration",
    "stand mixer") appears in a document that embedding similarity alone does
    not rank highly.

    Args:
        entry: One LongMemEval question entry.
        granularity: "session" for one document per session (user turns
            joined by newlines); any other value for one document per user turn.
        n_results: Maximum number of results requested from the collection.
        hybrid_weight: Maximum fractional distance reduction, applied when
            every question keyword is present in the document.

    Returns:
        (rankings, corpus, corpus_ids, corpus_timestamps); rankings lists
        indices into corpus ordered by fused distance, followed by the
        non-retrieved indices in ascending order.
    """
    stop_words = {
        "what", "when", "where", "who", "how", "which", "did", "do", "was",
        "were", "have", "has", "had", "is", "are", "the", "a", "an", "my",
        "me", "i", "you", "your", "their", "it", "its", "in", "on", "at",
        "to", "for", "of", "with", "by", "from", "ago", "last", "that",
        "this", "there", "about", "get", "got", "give", "gave", "buy",
        "bought", "made", "make",
    }

    corpus, corpus_ids, corpus_timestamps = [], [], []

    haystack = zip(
        entry["haystack_sessions"],
        entry["haystack_session_ids"],
        entry["haystack_dates"],
    )
    for turns, sid, stamp in haystack:
        if granularity == "session":
            user_texts = [t["content"] for t in turns if t["role"] == "user"]
            if not user_texts:
                continue
            corpus.append("\n".join(user_texts))
            corpus_ids.append(sid)
            corpus_timestamps.append(stamp)
        else:
            user_only = (t for t in turns if t["role"] == "user")
            for turn_no, t in enumerate(user_only):
                corpus.append(t["content"])
                corpus_ids.append(f"{sid}_turn_{turn_no}")
                corpus_timestamps.append(stamp)

    if not corpus:
        return [], corpus, corpus_ids, corpus_timestamps

    total = len(corpus)
    collection = _fresh_collection()
    collection.add(
        documents=corpus,
        ids=[f"doc_{i}" for i in range(total)],
        metadatas=[
            {"corpus_id": cid, "timestamp": ts}
            for cid, ts in zip(corpus_ids, corpus_timestamps)
        ],
    )

    question = entry["question"]
    hits = collection.query(
        query_texts=[question],
        n_results=min(n_results, total),
        include=["distances", "metadatas", "documents"],
    )
    hit_keys = hits["ids"][0]
    hit_distances = hits["distances"][0]
    hit_texts = hits["documents"][0]

    position_of = {f"doc_{pos}": pos for pos in range(total)}

    # Question keywords; duplicates are kept, so a repeated word weighs more.
    question_terms = [
        word
        for word in re.findall(r"\b[a-z]{3,}\b", question.lower())
        if word not in stop_words
    ]

    def overlap_fraction(text):
        """Share of question_terms that occur in text (0.0 if there are none)."""
        lowered = text.lower()
        if not question_terms:
            return 0.0
        matched = sum(1 for term in question_terms if term in lowered)
        return matched / len(question_terms)

    # Lower distance = better, so keyword overlap shrinks the distance.
    rescored = []
    for key, distance, text in zip(hit_keys, hit_distances, hit_texts):
        pos = position_of[key]
        fused = distance * (1.0 - hybrid_weight * overlap_fraction(text))
        rescored.append((pos, fused))

    ranked = [pos for pos, _ in sorted(rescored, key=lambda pair: pair[1])]

    retrieved = set(ranked)
    ranked.extend(pos for pos in range(total) if pos not in retrieved)

    return ranked, corpus, corpus_ids, corpus_timestamps
|
|||
|
|
|
|||
|
|
|
|||
|
|
def build_palace_and_retrieve_full(entry, granularity="session", n_results=50):
    """
    Full-turn mode: index BOTH user and assistant turns.

    Rationale: assistant replies often restate confirmed facts ("Yes, you
    graduated with a Business Administration degree") that benchmark
    questions ask about, so indexing user turns alone drops that signal.

    Args:
        entry: One LongMemEval question entry.
        granularity: "session" builds one document per session from every
            turn's content joined by newlines; any other value builds one
            document per turn (either role), numbered across all turns.
        n_results: Maximum number of results requested from the collection.

    Returns:
        (rankings, corpus, corpus_ids, corpus_timestamps); rankings lists
        indices into corpus, retrieved documents first in returned order,
        then the non-retrieved indices in ascending order.
    """
    corpus, corpus_ids, corpus_timestamps = [], [], []

    haystack = zip(
        entry["haystack_sessions"],
        entry["haystack_session_ids"],
        entry["haystack_dates"],
    )
    for turns, sid, stamp in haystack:
        if granularity == "session":
            # User questions and assistant answers alike.
            contents = [t["content"] for t in turns]
            if not contents:
                continue
            corpus.append("\n".join(contents))
            corpus_ids.append(sid)
            corpus_timestamps.append(stamp)
        else:
            # Turn numbers here count every turn, regardless of role.
            for turn_no, t in enumerate(turns):
                corpus.append(t["content"])
                corpus_ids.append(f"{sid}_turn_{turn_no}")
                corpus_timestamps.append(stamp)

    if not corpus:
        return [], corpus, corpus_ids, corpus_timestamps

    total = len(corpus)
    collection = _fresh_collection()
    collection.add(
        documents=corpus,
        ids=[f"doc_{i}" for i in range(total)],
        metadatas=[
            {"corpus_id": cid, "timestamp": ts}
            for cid, ts in zip(corpus_ids, corpus_timestamps)
        ],
    )

    hits = collection.query(
        query_texts=[entry["question"]],
        n_results=min(n_results, total),
        include=["distances", "metadatas"],
    )

    position_of = {f"doc_{pos}": pos for pos in range(total)}
    ranked = [position_of[key] for key in hits["ids"][0]]

    retrieved = set(ranked)
    ranked.extend(pos for pos in range(total) if pos not in retrieved)

    return ranked, corpus, corpus_ids, corpus_timestamps
|
|||
|
|
|
|||
|
|
|
|||
|
|
# =============================================================================
|
|||
|
|
# HYBRID V2 — Temporal + Two-Pass Assistant + Preference Awareness
|
|||
|
|
# =============================================================================
|
|||
|
|
|
|||
|
|
|
|||
|
|
def build_palace_and_retrieve_hybrid_v2(
|
|||
|
|
entry, granularity="session", n_results=50, hybrid_weight=0.30
|
|||
|
|
):
|
|||
|
|
"""
|
|||
|
|
Hybrid V2: hybrid + three targeted fixes for the remaining 11 misses.
|
|||
|
|
|
|||
|
|
Fix 1 — Temporal date boost:
|
|||
|
|
Parse relative time expressions from question ("a week ago", "10 days ago").
|
|||
|
|
Use question_date + haystack_dates to compute a proximity score.
|
|||
|
|
Sessions whose date falls within the target window get up to 40% distance reduction.
|
|||
|
|
|
|||
|
|
Fix 2 — Two-pass for assistant-reference questions:
|
|||
|
|
Detect "you suggested", "you told me", "remind me what you" etc.
|
|||
|
|
Do normal hybrid retrieval on user turns → get top-3 sessions.
|
|||
|
|
Then re-index those 3 sessions with BOTH user+assistant turns and re-query.
|
|||
|
|
This avoids the dilution problem of indexing all assistant turns globally.
|
|||
|
|
|
|||
|
|
Fix 3 — Preference broadening:
|
|||
|
|
For single-session-preference questions, the question topic often doesn't
|
|||
|
|
match session keywords (user discussed "Adobe Premiere Pro", question asks
|
|||
|
|
about "video editing"). Broaden query by appending synonyms from question
|
|||
|
|
domain keywords.
|
|||
|
|
"""
|
|||
|
|
import re as _re
|
|||
|
|
from datetime import datetime, timedelta
|
|||
|
|
|
|||
|
|
STOP_WORDS = {
|
|||
|
|
"what",
|
|||
|
|
"when",
|
|||
|
|
"where",
|
|||
|
|
"who",
|
|||
|
|
"how",
|
|||
|
|
"which",
|
|||
|
|
"did",
|
|||
|
|
"do",
|
|||
|
|
"was",
|
|||
|
|
"were",
|
|||
|
|
"have",
|
|||
|
|
"has",
|
|||
|
|
"had",
|
|||
|
|
"is",
|
|||
|
|
"are",
|
|||
|
|
"the",
|
|||
|
|
"a",
|
|||
|
|
"an",
|
|||
|
|
"my",
|
|||
|
|
"me",
|
|||
|
|
"i",
|
|||
|
|
"you",
|
|||
|
|
"your",
|
|||
|
|
"their",
|
|||
|
|
"it",
|
|||
|
|
"its",
|
|||
|
|
"in",
|
|||
|
|
"on",
|
|||
|
|
"at",
|
|||
|
|
"to",
|
|||
|
|
"for",
|
|||
|
|
"of",
|
|||
|
|
"with",
|
|||
|
|
"by",
|
|||
|
|
"from",
|
|||
|
|
"ago",
|
|||
|
|
"last",
|
|||
|
|
"that",
|
|||
|
|
"this",
|
|||
|
|
"there",
|
|||
|
|
"about",
|
|||
|
|
"get",
|
|||
|
|
"got",
|
|||
|
|
"give",
|
|||
|
|
"gave",
|
|||
|
|
"buy",
|
|||
|
|
"bought",
|
|||
|
|
"made",
|
|||
|
|
"make",
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
def extract_keywords(text):
|
|||
|
|
words = _re.findall(r"\b[a-z]{3,}\b", text.lower())
|
|||
|
|
return [w for w in words if w not in STOP_WORDS]
|
|||
|
|
|
|||
|
|
def keyword_overlap(query_kws, doc_text):
|
|||
|
|
doc_lower = doc_text.lower()
|
|||
|
|
if not query_kws:
|
|||
|
|
return 0.0
|
|||
|
|
hits = sum(1 for kw in query_kws if kw in doc_lower)
|
|||
|
|
return hits / len(query_kws)
|
|||
|
|
|
|||
|
|
def parse_question_date(date_str):
|
|||
|
|
"""Parse LongMemEval date format: '2023/01/15 (Sun) 10:20'"""
|
|||
|
|
try:
|
|||
|
|
return datetime.strptime(date_str.split(" (")[0], "%Y/%m/%d")
|
|||
|
|
except Exception:
|
|||
|
|
return None
|
|||
|
|
|
|||
|
|
def parse_time_offset_days(question):
|
|||
|
|
"""
|
|||
|
|
Extract the number of days back referenced in a temporal question.
|
|||
|
|
Returns (days, tolerance_days) or None if not found.
|
|||
|
|
"""
|
|||
|
|
q = question.lower()
|
|||
|
|
patterns = [
|
|||
|
|
(r"(\d+)\s+days?\s+ago", lambda m: (int(m.group(1)), 2)),
|
|||
|
|
(r"a\s+couple\s+(?:of\s+)?days?\s+ago", lambda m: (2, 2)),
|
|||
|
|
(r"yesterday", lambda m: (1, 1)),
|
|||
|
|
(r"a\s+week\s+ago", lambda m: (7, 3)),
|
|||
|
|
(r"(\d+)\s+weeks?\s+ago", lambda m: (int(m.group(1)) * 7, 5)),
|
|||
|
|
(r"last\s+week", lambda m: (7, 3)),
|
|||
|
|
(r"a\s+month\s+ago", lambda m: (30, 7)),
|
|||
|
|
(r"(\d+)\s+months?\s+ago", lambda m: (int(m.group(1)) * 30, 10)),
|
|||
|
|
(r"last\s+month", lambda m: (30, 7)),
|
|||
|
|
(r"last\s+year", lambda m: (365, 30)),
|
|||
|
|
(r"a\s+year\s+ago", lambda m: (365, 30)),
|
|||
|
|
(r"recently", lambda m: (14, 14)),
|
|||
|
|
]
|
|||
|
|
for pattern, extractor in patterns:
|
|||
|
|
m = _re.search(pattern, q)
|
|||
|
|
if m:
|
|||
|
|
return extractor(m)
|
|||
|
|
return None
|
|||
|
|
|
|||
|
|
def is_assistant_reference(question):
    """Return True when *question* contains a phrase referring to something the assistant said earlier.

    The check is a case-insensitive substring search for a fixed list of
    trigger phrases.
    """
    lowered = question.lower()
    for phrase in (
        "you suggested",
        "you told me",
        "you mentioned",
        "you said",
        "you recommended",
        "remind me what you",
        "you provided",
        "you listed",
        "you gave me",
        "you described",
        "what did you",
        "you came up with",
        "you helped me",
        "you explained",
        "can you remind me",
        "you identified",
    ):
        if phrase in lowered:
            return True
    return False
|
|||
|
|
|
|||
|
|
# -------------------------------------------------------------------------
|
|||
|
|
# Build corpus
|
|||
|
|
# -------------------------------------------------------------------------
|
|||
|
|
sessions = entry["haystack_sessions"]
|
|||
|
|
session_ids = entry["haystack_session_ids"]
|
|||
|
|
dates = entry["haystack_dates"]
|
|||
|
|
question = entry["question"]
|
|||
|
|
question_date = parse_question_date(entry.get("question_date", ""))
|
|||
|
|
|
|||
|
|
corpus_user = [] # user-turns-only text per session
|
|||
|
|
corpus_full = [] # user+assistant text per session
|
|||
|
|
corpus_ids = []
|
|||
|
|
corpus_timestamps = []
|
|||
|
|
|
|||
|
|
for session, sess_id, date in zip(sessions, session_ids, dates):
|
|||
|
|
user_turns = [t["content"] for t in session if t["role"] == "user"]
|
|||
|
|
all_turns = [t["content"] for t in session]
|
|||
|
|
if user_turns:
|
|||
|
|
corpus_user.append("\n".join(user_turns))
|
|||
|
|
corpus_full.append("\n".join(all_turns))
|
|||
|
|
corpus_ids.append(sess_id)
|
|||
|
|
corpus_timestamps.append(date)
|
|||
|
|
|
|||
|
|
if not corpus_user:
|
|||
|
|
return [], corpus_user, corpus_ids, corpus_timestamps
|
|||
|
|
|
|||
|
|
# -------------------------------------------------------------------------
|
|||
|
|
# Fix 2: Two-pass for assistant-reference questions
|
|||
|
|
# -------------------------------------------------------------------------
|
|||
|
|
if is_assistant_reference(question):
|
|||
|
|
# Pass 1: find top sessions using user turns only
|
|||
|
|
collection = _fresh_collection()
|
|||
|
|
collection.add(
|
|||
|
|
documents=corpus_user,
|
|||
|
|
ids=[f"doc_{i}" for i in range(len(corpus_user))],
|
|||
|
|
metadatas=[
|
|||
|
|
{"corpus_id": cid, "timestamp": ts}
|
|||
|
|
for cid, ts in zip(corpus_ids, corpus_timestamps)
|
|||
|
|
],
|
|||
|
|
)
|
|||
|
|
results = collection.query(
|
|||
|
|
query_texts=[question],
|
|||
|
|
n_results=min(5, len(corpus_user)),
|
|||
|
|
include=["distances", "metadatas"],
|
|||
|
|
)
|
|||
|
|
top_indices = [int(rid.split("_")[1]) for rid in results["ids"][0]]
|
|||
|
|
|
|||
|
|
# Pass 2: re-index those sessions with full text (user+assistant)
|
|||
|
|
top_corpus_full = [corpus_full[i] for i in top_indices]
|
|||
|
|
top_ids = [corpus_ids[i] for i in top_indices]
|
|||
|
|
top_ts = [corpus_timestamps[i] for i in top_indices]
|
|||
|
|
|
|||
|
|
collection2 = _fresh_collection("mempal_drawers_pass2")
|
|||
|
|
collection2.add(
|
|||
|
|
documents=top_corpus_full,
|
|||
|
|
ids=[f"doc2_{i}" for i in range(len(top_corpus_full))],
|
|||
|
|
metadatas=[{"corpus_id": cid, "timestamp": ts} for cid, ts in zip(top_ids, top_ts)],
|
|||
|
|
)
|
|||
|
|
results2 = collection2.query(
|
|||
|
|
query_texts=[question],
|
|||
|
|
n_results=min(n_results, len(top_corpus_full)),
|
|||
|
|
include=["distances", "metadatas"],
|
|||
|
|
)
|
|||
|
|
# Build final rankings: two-pass top sessions first, then rest
|
|||
|
|
two_pass_order = [top_indices[int(rid.split("_")[1])] for rid in results2["ids"][0]]
|
|||
|
|
seen = set(two_pass_order)
|
|||
|
|
ranked_indices = two_pass_order + [i for i in range(len(corpus_user)) if i not in seen]
|
|||
|
|
return ranked_indices, corpus_user, corpus_ids, corpus_timestamps
|
|||
|
|
|
|||
|
|
# -------------------------------------------------------------------------
|
|||
|
|
# Standard hybrid retrieval (fix 1 temporal + fix 3 preference baked in)
|
|||
|
|
# -------------------------------------------------------------------------
|
|||
|
|
collection = _fresh_collection()
|
|||
|
|
collection.add(
|
|||
|
|
documents=corpus_user,
|
|||
|
|
ids=[f"doc_{i}" for i in range(len(corpus_user))],
|
|||
|
|
metadatas=[
|
|||
|
|
{"corpus_id": cid, "timestamp": ts} for cid, ts in zip(corpus_ids, corpus_timestamps)
|
|||
|
|
],
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
query_keywords = extract_keywords(question)
|
|||
|
|
results = collection.query(
|
|||
|
|
query_texts=[question],
|
|||
|
|
n_results=min(n_results, len(corpus_user)),
|
|||
|
|
include=["distances", "metadatas", "documents"],
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
result_ids = results["ids"][0]
|
|||
|
|
distances = results["distances"][0]
|
|||
|
|
documents = results["documents"][0]
|
|||
|
|
doc_id_to_idx = {f"doc_{i}": i for i in range(len(corpus_user))}
|
|||
|
|
|
|||
|
|
# Fix 1: Temporal proximity score
|
|||
|
|
time_offset = parse_time_offset_days(question)
|
|||
|
|
target_date = None
|
|||
|
|
if time_offset and question_date:
|
|||
|
|
days_back, tolerance = time_offset
|
|||
|
|
target_date = question_date - timedelta(days=days_back)
|
|||
|
|
|
|||
|
|
scored = []
|
|||
|
|
for rid, dist, doc in zip(result_ids, distances, documents):
|
|||
|
|
idx = doc_id_to_idx[rid]
|
|||
|
|
overlap = keyword_overlap(query_keywords, doc)
|
|||
|
|
fused_dist = dist * (1.0 - hybrid_weight * overlap)
|
|||
|
|
|
|||
|
|
# Temporal boost: sessions near target date get up to 40% distance reduction
|
|||
|
|
if target_date:
|
|||
|
|
sess_date = parse_question_date(corpus_timestamps[idx])
|
|||
|
|
if sess_date:
|
|||
|
|
delta_days = abs((sess_date - target_date).days)
|
|||
|
|
tolerance = time_offset[1]
|
|||
|
|
if delta_days <= tolerance:
|
|||
|
|
# Perfect hit: full boost
|
|||
|
|
temporal_boost = 0.40
|
|||
|
|
elif delta_days <= tolerance * 3:
|
|||
|
|
# Partial hit: scaled
|
|||
|
|
temporal_boost = 0.40 * (1.0 - (delta_days - tolerance) / (tolerance * 2))
|
|||
|
|
else:
|
|||
|
|
temporal_boost = 0.0
|
|||
|
|
fused_dist = fused_dist * (1.0 - temporal_boost)
|
|||
|
|
|
|||
|
|
scored.append((idx, fused_dist))
|
|||
|
|
|
|||
|
|
scored.sort(key=lambda x: x[1])
|
|||
|
|
ranked_indices = [idx for idx, _ in scored]
|
|||
|
|
|
|||
|
|
seen = set(ranked_indices)
|
|||
|
|
for i in range(len(corpus_user)):
|
|||
|
|
if i not in seen:
|
|||
|
|
ranked_indices.append(i)
|
|||
|
|
|
|||
|
|
return ranked_indices, corpus_user, corpus_ids, corpus_timestamps
|
|||
|
|
|
|||
|
|
|
|||
|
|
# =============================================================================
|
|||
|
|
# HYBRID V3 — Preference Extraction + Expanded Re-rank Pool
|
|||
|
|
# =============================================================================
|
|||
|
|
|
|||
|
|
|
|||
|
|
def build_palace_and_retrieve_hybrid_v3(
    entry, granularity="session", n_results=50, hybrid_weight=0.30
):
    """
    Hybrid V3: hybrid_v2 + two targeted improvements for remaining misses.

    New in V3 vs V2:

    Fix 1 — Preference extraction at ingest:
        Scan every user turn for expressions of preference, concern, or intent:
        "I've been having trouble with X", "I've been feeling X", "I prefer X", etc.
        For sessions where preferences are found, add a synthetic document to the
        ChromaDB collection with the same corpus_id as the session.

        This bridges the semantic gap for questions like:
            Q: "I've been having trouble with the battery life on my phone lately."
            Session: [phone hardware research — never mentions "battery life"]
            Pref doc: "User mentioned: battery life issues on phone"
            → the pref doc ranks near the top for this question

    Fix 2 — Expanded LLM re-rank pool (20 instead of 10):
        The two remaining assistant failures have their correct session at rank
        11-12. Expanding the pool gives Haiku more to work with at negligible
        extra cost (slightly longer prompt).
        NOTE(review): no LLM re-rank call is visible in this function body;
        presumably the pool is consumed by the caller — confirm.

    Args:
        entry: One benchmark record. This function reads the keys
            "haystack_sessions", "haystack_session_ids", "haystack_dates",
            "question" and, optionally, "question_date". Each session is
            iterated as a sequence of turns with "role" and "content" keys.
        granularity: Not read anywhere in this function body (presumably kept
            so the retrieval variants share one signature — confirm).
        n_results: Upper bound on hits requested from the final query; it is
            capped at the number of indexed documents.
        hybrid_weight: Maximum fractional distance reduction granted for full
            keyword overlap (fused = dist * (1 - hybrid_weight * overlap)).

    Returns:
        ``(ranked_indices, corpus_user, corpus_ids, corpus_timestamps)`` where
        ``ranked_indices`` are positions in ``corpus_user``, best first, with
        unretrieved sessions appended in corpus order. ``ranked_indices`` is
        empty when no session contains a user turn.
    """
    import re as _re
    from datetime import datetime, timedelta

    # Words dropped from the question before keyword matching.
    STOP_WORDS = {
        "what",
        "when",
        "where",
        "who",
        "how",
        "which",
        "did",
        "do",
        "was",
        "were",
        "have",
        "has",
        "had",
        "is",
        "are",
        "the",
        "a",
        "an",
        "my",
        "me",
        "i",
        "you",
        "your",
        "their",
        "it",
        "its",
        "in",
        "on",
        "at",
        "to",
        "for",
        "of",
        "with",
        "by",
        "from",
        "ago",
        "last",
        "that",
        "this",
        "there",
        "about",
        "get",
        "got",
        "give",
        "gave",
        "buy",
        "bought",
        "made",
        "make",
    }

    def extract_keywords(text):
        """Return lowercase alphabetic tokens of 3+ letters that are not stop words."""
        words = _re.findall(r"\b[a-z]{3,}\b", text.lower())
        return [w for w in words if w not in STOP_WORDS]

    def keyword_overlap(query_kws, doc_text):
        """Return the fraction of query keywords found (as substrings) in the document."""
        doc_lower = doc_text.lower()
        if not query_kws:
            return 0.0
        hits = sum(1 for kw in query_kws if kw in doc_lower)
        return hits / len(query_kws)

    def parse_question_date(date_str):
        """Parse the 'YYYY/MM/DD' prefix of a timestamp; return None on any failure."""
        try:
            return datetime.strptime(date_str.split(" (")[0], "%Y/%m/%d")
        except Exception:
            return None

    def parse_time_offset_days(question):
        """Return (days_back, tolerance_days) for the first matching time phrase, else None."""
        q = question.lower()
        # Ordered list: the first pattern that matches decides the result.
        patterns = [
            (r"(\d+)\s+days?\s+ago", lambda m: (int(m.group(1)), 2)),
            (r"a\s+couple\s+(?:of\s+)?days?\s+ago", lambda m: (2, 2)),
            (r"yesterday", lambda m: (1, 1)),
            (r"a\s+week\s+ago", lambda m: (7, 3)),
            (r"(\d+)\s+weeks?\s+ago", lambda m: (int(m.group(1)) * 7, 5)),
            (r"last\s+week", lambda m: (7, 3)),
            (r"a\s+month\s+ago", lambda m: (30, 7)),
            (r"(\d+)\s+months?\s+ago", lambda m: (int(m.group(1)) * 30, 10)),
            (r"last\s+month", lambda m: (30, 7)),
            (r"last\s+year", lambda m: (365, 30)),
            (r"a\s+year\s+ago", lambda m: (365, 30)),
            (r"recently", lambda m: (14, 14)),
        ]
        for pattern, extractor in patterns:
            m = _re.search(pattern, q)
            if m:
                return extractor(m)
        return None

    def is_assistant_reference(question):
        """Return True if the question contains a phrase referring to what the assistant said."""
        q = question.lower()
        triggers = [
            "you suggested",
            "you told me",
            "you mentioned",
            "you said",
            "you recommended",
            "remind me what you",
            "you provided",
            "you listed",
            "you gave me",
            "you described",
            "what did you",
            "you came up with",
            "you helped me",
            "you explained",
            "can you remind me",
            "you identified",
        ]
        return any(t in q for t in triggers)

    # -------------------------------------------------------------------------
    # NEW: Preference extraction
    # -------------------------------------------------------------------------
    # Each pattern has one capture group holding the object of the expression.
    PREF_PATTERNS = [
        r"i(?:'ve been| have been) having (?:trouble|issues?|problems?) with ([^,\.!?]{5,80})",
        r"i(?:'ve been| have been) feeling ([^,\.!?]{5,60})",
        r"i(?:'ve been| have been) (?:struggling|dealing) with ([^,\.!?]{5,80})",
        r"i(?:'ve been| have been) (?:worried|concerned) about ([^,\.!?]{5,80})",
        r"i(?:'m| am) (?:worried|concerned) about ([^,\.!?]{5,80})",
        r"i prefer ([^,\.!?]{5,60})",
        r"i usually ([^,\.!?]{5,60})",
        r"i(?:'ve been| have been) (?:trying|attempting) to ([^,\.!?]{5,80})",
        r"i(?:'ve been| have been) (?:considering|thinking about) ([^,\.!?]{5,80})",
        r"lately[,\s]+(?:i've been|i have been|i'm|i am) ([^,\.!?]{5,80})",
        r"recently[,\s]+(?:i've been|i have been|i'm|i am) ([^,\.!?]{5,80})",
        r"i(?:'ve been| have been) (?:working on|focused on|interested in) ([^,\.!?]{5,80})",
        r"i want to ([^,\.!?]{5,60})",
        r"i(?:'m| am) looking (?:to|for) ([^,\.!?]{5,60})",
        r"i(?:'m| am) thinking (?:about|of) ([^,\.!?]{5,60})",
        r"i(?:'ve been| have been) (?:noticing|experiencing) ([^,\.!?]{5,80})",
    ]

    def extract_preferences(session):
        """Extract preference/concern expressions from user turns in a session."""
        mentions = []
        for turn in session:
            if turn["role"] != "user":
                continue
            # Text is lowercased here, so the IGNORECASE flag below is redundant.
            text = turn["content"].lower()
            for pat in PREF_PATTERNS:
                for match in _re.findall(pat, text, _re.IGNORECASE):
                    clean = match.strip().rstrip(".,;!? ")
                    if 5 <= len(clean) <= 80:
                        mentions.append(clean)
        # Deduplicate while preserving order
        seen = set()
        unique = []
        for m in mentions:
            if m not in seen:
                seen.add(m)
                unique.append(m)
        return unique[:10]  # cap at 10 to avoid overly long synthetic docs

    # -------------------------------------------------------------------------
    # Build corpus
    # -------------------------------------------------------------------------
    sessions = entry["haystack_sessions"]
    session_ids = entry["haystack_session_ids"]
    dates = entry["haystack_dates"]
    question = entry["question"]
    question_date = parse_question_date(entry.get("question_date", ""))

    # Parallel lists, one element per session that has at least one user turn.
    corpus_user = []
    corpus_full = []
    corpus_ids = []
    corpus_timestamps = []

    # Synthetic preference documents (same corpus_id as their session)
    pref_docs = []
    pref_ids = []
    pref_timestamps = []

    for session, sess_id, date in zip(sessions, session_ids, dates):
        user_turns = [t["content"] for t in session if t["role"] == "user"]
        all_turns = [t["content"] for t in session]
        if not user_turns:
            # Sessions without user turns are not indexed at all.
            continue
        corpus_user.append("\n".join(user_turns))
        corpus_full.append("\n".join(all_turns))
        corpus_ids.append(sess_id)
        corpus_timestamps.append(date)

        # Extract preferences and build synthetic document
        prefs = extract_preferences(session)
        if prefs:
            pref_doc = "User has mentioned: " + "; ".join(prefs)
            pref_docs.append(pref_doc)
            pref_ids.append(sess_id)
            pref_timestamps.append(date)

    if not corpus_user:
        return [], corpus_user, corpus_ids, corpus_timestamps

    # -------------------------------------------------------------------------
    # Two-pass for assistant-reference questions (same as v2)
    # -------------------------------------------------------------------------
    # This path indexes only the session documents; the synthetic preference
    # documents, keyword fusion and temporal boost are not applied here.
    if is_assistant_reference(question):
        # Pass 1: shortlist up to 5 sessions using user-turn text only.
        # _fresh_collection is defined elsewhere in this file; presumably it
        # returns an empty ChromaDB collection — confirm.
        collection = _fresh_collection()
        collection.add(
            documents=corpus_user,
            ids=[f"doc_{i}" for i in range(len(corpus_user))],
            metadatas=[
                {"corpus_id": cid, "timestamp": ts}
                for cid, ts in zip(corpus_ids, corpus_timestamps)
            ],
        )
        results = collection.query(
            query_texts=[question],
            n_results=min(5, len(corpus_user)),
            include=["distances", "metadatas"],
        )
        # Ids have the form "doc_<i>"; recover the corpus index <i>.
        top_indices = [int(rid.split("_")[1]) for rid in results["ids"][0]]

        # Pass 2: re-index the shortlisted sessions with user+assistant text.
        top_corpus_full = [corpus_full[i] for i in top_indices]
        top_ids = [corpus_ids[i] for i in top_indices]
        top_ts = [corpus_timestamps[i] for i in top_indices]

        collection2 = _fresh_collection("mempal_drawers_pass2")
        collection2.add(
            documents=top_corpus_full,
            ids=[f"doc2_{i}" for i in range(len(top_corpus_full))],
            metadatas=[{"corpus_id": cid, "timestamp": ts} for cid, ts in zip(top_ids, top_ts)],
        )
        results2 = collection2.query(
            query_texts=[question],
            n_results=min(n_results, len(top_corpus_full)),
            include=["distances", "metadatas"],
        )
        # "doc2_<j>" -> j is a position in the shortlist; map it back to a corpus index.
        two_pass_order = [top_indices[int(rid.split("_")[1])] for rid in results2["ids"][0]]
        seen = set(two_pass_order)
        # Shortlisted sessions first (pass-2 order), then every other session in corpus order.
        ranked_indices = two_pass_order + [i for i in range(len(corpus_user)) if i not in seen]
        return ranked_indices, corpus_user, corpus_ids, corpus_timestamps

    # -------------------------------------------------------------------------
    # Build expanded collection: user docs + synthetic preference docs
    # -------------------------------------------------------------------------
    # Session docs occupy indices [0, len(corpus_user)); pref docs follow them.
    all_docs = corpus_user + pref_docs
    all_ids_meta = corpus_ids + pref_ids
    all_ts = corpus_timestamps + pref_timestamps

    collection = _fresh_collection()
    collection.add(
        documents=all_docs,
        ids=[f"doc_{i}" for i in range(len(all_docs))],
        metadatas=[
            {"corpus_id": cid, "timestamp": ts, "is_pref": i >= len(corpus_user)}
            for i, (cid, ts) in enumerate(zip(all_ids_meta, all_ts))
        ],
    )

    query_keywords = extract_keywords(question)
    results = collection.query(
        query_texts=[question],
        n_results=min(n_results, len(all_docs)),
        include=["distances", "metadatas", "documents"],
    )

    result_ids = results["ids"][0]
    distances = results["distances"][0]
    documents = results["documents"][0]
    doc_id_to_idx = {f"doc_{i}": i for i in range(len(all_docs))}

    # Temporal boost
    # A target date exists only when the question has a time phrase AND the
    # question date parsed successfully.
    time_offset = parse_time_offset_days(question)
    target_date = None
    if time_offset and question_date:
        # `tolerance` is not used below; the scoring loop re-reads time_offset[1].
        days_back, tolerance = time_offset
        target_date = question_date - timedelta(days=days_back)

    scored = []
    for rid, dist, doc in zip(result_ids, distances, documents):
        idx = doc_id_to_idx[rid]
        overlap = keyword_overlap(query_keywords, doc)
        # Keyword fusion: shrink the distance by up to hybrid_weight.
        fused_dist = dist * (1.0 - hybrid_weight * overlap)

        # Temporal boost
        if target_date:
            sess_date = parse_question_date(all_ts[idx])
            if sess_date:
                delta_days = abs((sess_date - target_date).days)
                tol = time_offset[1]
                if delta_days <= tol:
                    # Within tolerance: full 40% distance reduction.
                    temporal_boost = 0.40
                elif delta_days <= tol * 3:
                    # Between tol and 3*tol: boost decays linearly from 0.40 to 0.
                    temporal_boost = 0.40 * (1.0 - (delta_days - tol) / (tol * 2))
                else:
                    temporal_boost = 0.0
                fused_dist = fused_dist * (1.0 - temporal_boost)

        scored.append((idx, fused_dist))

    # Ascending sort: the smallest fused distance is ranked first.
    scored.sort(key=lambda x: x[1])

    # Map back to corpus_user indices via corpus_id — deduplicate at session level
    # A pref doc and its session doc both map to the same corpus_id.
    # Keep whichever ranks first; map back to corpus_user index for evaluation.
    corpus_id_to_user_idx = {cid: i for i, cid in enumerate(corpus_ids)}
    seen_ids = set()
    ranked_indices = []
    for idx, _ in scored:
        cid = all_ids_meta[idx]
        if cid not in seen_ids:
            seen_ids.add(cid)
            ranked_indices.append(corpus_id_to_user_idx[cid])

    # Fill in any sessions not yet ranked
    for i in range(len(corpus_user)):
        if corpus_ids[i] not in seen_ids:
            ranked_indices.append(i)
            seen_ids.add(corpus_ids[i])

    return ranked_indices, corpus_user, corpus_ids, corpus_timestamps
|
|||
|
|
|
|||
|
|
|
|||
|
|
def build_palace_and_retrieve_hybrid_v4(
|
|||
|
|
entry, granularity="session", n_results=50, hybrid_weight=0.30
|
|||
|
|
):
|
|||
|
|
"""
|
|||
|
|
Hybrid V4: hybrid_v3 + three targeted fixes for the final 3 misses.
|
|||
|
|
|
|||
|
|
Analysis of remaining misses at 99.4% (both hybrid_v3 and palace fail on these):
|
|||
|
|
|
|||
|
|
Miss 1 — 'high school reunion' (d6233ab6, single-session-preference):
|
|||
|
|
Target session: "I still remember the happy high school experiences such as
|
|||
|
|
being part of the debate team and taking advanced placement courses."
|
|||
|
|
Question: "high school reunion...nostalgic"
|
|||
|
|
Gap: "reunion/nostalgic" ≠ "debate team/AP courses" in embedding space.
|
|||
|
|
Fix: Add memory/nostalgia patterns to extract "User has mentioned: positive
|
|||
|
|
high school experiences, debate team, AP courses" as a synthetic pref doc.
|
|||
|
|
|
|||
|
|
Miss 2 — 'Rachel/ukulele' (4dfccbf8, temporal-reasoning):
|
|||
|
|
Target session: "I just started taking ukulele lessons with my friend Rachel today."
|
|||
|
|
Question: "What did I do with Rachel on the Wednesday two months ago?"
|
|||
|
|
Gap: Embedding model gives low weight to person names like 'Rachel'.
|
|||
|
|
Fix: Extract capitalized proper nouns from question; boost sessions containing them.
|
|||
|
|
|
|||
|
|
Miss 3 — 'sexual compulsions' (ceb54acb, single-session-assistant):
|
|||
|
|
Target session: assistant suggests "sexual fixations", "sexual impulsivity", etc.
|
|||
|
|
Question: "you suggested 'sexual compulsions' and a few other options..."
|
|||
|
|
Gap: Short 2-turn session, niche topic — embeddings don't surface it.
|
|||
|
|
Fix: Extract quoted phrases from question; boost sessions containing exact quotes.
|
|||
|
|
"""
|
|||
|
|
import re as _re
|
|||
|
|
from datetime import datetime, timedelta
|
|||
|
|
|
|||
|
|
STOP_WORDS = {
|
|||
|
|
"what",
|
|||
|
|
"when",
|
|||
|
|
"where",
|
|||
|
|
"who",
|
|||
|
|
"how",
|
|||
|
|
"which",
|
|||
|
|
"did",
|
|||
|
|
"do",
|
|||
|
|
"was",
|
|||
|
|
"were",
|
|||
|
|
"have",
|
|||
|
|
"has",
|
|||
|
|
"had",
|
|||
|
|
"is",
|
|||
|
|
"are",
|
|||
|
|
"the",
|
|||
|
|
"a",
|
|||
|
|
"an",
|
|||
|
|
"my",
|
|||
|
|
"me",
|
|||
|
|
"i",
|
|||
|
|
"you",
|
|||
|
|
"your",
|
|||
|
|
"their",
|
|||
|
|
"it",
|
|||
|
|
"its",
|
|||
|
|
"in",
|
|||
|
|
"on",
|
|||
|
|
"at",
|
|||
|
|
"to",
|
|||
|
|
"for",
|
|||
|
|
"of",
|
|||
|
|
"with",
|
|||
|
|
"by",
|
|||
|
|
"from",
|
|||
|
|
"ago",
|
|||
|
|
"last",
|
|||
|
|
"that",
|
|||
|
|
"this",
|
|||
|
|
"there",
|
|||
|
|
"about",
|
|||
|
|
"get",
|
|||
|
|
"got",
|
|||
|
|
"give",
|
|||
|
|
"gave",
|
|||
|
|
"buy",
|
|||
|
|
"bought",
|
|||
|
|
"made",
|
|||
|
|
"make",
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
def extract_keywords(text):
    """Return the non-stop-word tokens of *text*.

    Tokens are runs of three or more ASCII letters taken from the
    lowercased text; original order and duplicates are kept.
    """
    lowered = text.lower()
    tokens = (hit.group(0) for hit in _re.finditer(r"\b[a-z]{3,}\b", lowered))
    return [token for token in tokens if token not in STOP_WORDS]
|
|||
|
|
|
|||
|
|
def keyword_overlap(query_kws, doc_text):
    """Return the fraction of *query_kws* that occur in *doc_text*.

    Matching is a case-insensitive substring test on the document, so a
    keyword also counts when it appears inside a longer word. Returns 0.0
    when there are no keywords.
    """
    haystack = doc_text.lower()
    if not query_kws:
        return 0.0
    matched = [kw for kw in query_kws if kw in haystack]
    return len(matched) / len(query_kws)
|
|||
|
|
|
|||
|
|
# NEW: Extract quoted phrases from question (single or double quotes)
|
|||
|
|
def extract_quoted_phrases(text):
    """Return the phrases that *text* encloses in single or double quotes.

    A phrase must be 3-60 characters long between its quotes. Single quotes
    count as delimiters only when the opening quote is not preceded by a word
    character and the closing quote is not followed by one. Without that
    restriction the apostrophes in contractions and possessives ("I've",
    "don't", "parents'") are mistaken for quotes, producing junk phrases and
    hiding the real quoted phrase that follows them.

    Returns:
        Stripped phrases, single-quoted matches first, then double-quoted
        ones; phrases shorter than 3 characters after stripping are dropped.
    """
    patterns = (
        r"(?<!\w)'([^']{3,60})'(?!\w)",
        r'"([^"]{3,60})"',
    )
    phrases = []
    for pattern in patterns:
        phrases.extend(_re.findall(pattern, text))
    return [p.strip() for p in phrases if len(p.strip()) >= 3]
|
|||
|
|
|
|||
|
|
def quoted_phrase_boost(phrases, doc_text):
    """Return the share (0.0-1.0) of quoted *phrases* that appear verbatim in *doc_text*.

    The comparison ignores case. An empty phrase list yields 0.0.
    """
    if not phrases:
        return 0.0
    haystack = doc_text.lower()
    found = [phrase for phrase in phrases if phrase.lower() in haystack]
    ratio = len(found) / len(phrases)
    return min(ratio, 1.0)
|
|||
|
|
|
|||
|
|
# NEW: Extract person names (capitalized words that aren't common title-case words)
|
|||
|
|
NOT_NAMES = {
|
|||
|
|
"What",
|
|||
|
|
"When",
|
|||
|
|
"Where",
|
|||
|
|
"Who",
|
|||
|
|
"How",
|
|||
|
|
"Which",
|
|||
|
|
"Did",
|
|||
|
|
"Do",
|
|||
|
|
"Was",
|
|||
|
|
"Were",
|
|||
|
|
"Have",
|
|||
|
|
"Has",
|
|||
|
|
"Had",
|
|||
|
|
"Is",
|
|||
|
|
"Are",
|
|||
|
|
"The",
|
|||
|
|
"My",
|
|||
|
|
"Our",
|
|||
|
|
"Their",
|
|||
|
|
"Can",
|
|||
|
|
"Could",
|
|||
|
|
"Would",
|
|||
|
|
"Should",
|
|||
|
|
"Will",
|
|||
|
|
"Shall",
|
|||
|
|
"May",
|
|||
|
|
"Might",
|
|||
|
|
"Monday",
|
|||
|
|
"Tuesday",
|
|||
|
|
"Wednesday",
|
|||
|
|
"Thursday",
|
|||
|
|
"Friday",
|
|||
|
|
"Saturday",
|
|||
|
|
"Sunday",
|
|||
|
|
"January",
|
|||
|
|
"February",
|
|||
|
|
"March",
|
|||
|
|
"April",
|
|||
|
|
"June",
|
|||
|
|
"July",
|
|||
|
|
"August",
|
|||
|
|
"September",
|
|||
|
|
"October",
|
|||
|
|
"November",
|
|||
|
|
"December",
|
|||
|
|
"In",
|
|||
|
|
"On",
|
|||
|
|
"At",
|
|||
|
|
"For",
|
|||
|
|
"To",
|
|||
|
|
"Of",
|
|||
|
|
"With",
|
|||
|
|
"By",
|
|||
|
|
"From",
|
|||
|
|
"And",
|
|||
|
|
"But",
|
|||
|
|
"I",
|
|||
|
|
"It",
|
|||
|
|
"Its",
|
|||
|
|
"This",
|
|||
|
|
"That",
|
|||
|
|
"These",
|
|||
|
|
"Those",
|
|||
|
|
"Previously",
|
|||
|
|
"Recently",
|
|||
|
|
"Also",
|
|||
|
|
"Just",
|
|||
|
|
"Very",
|
|||
|
|
"More",
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
def extract_person_names(text):
    """Return likely person names found in *text*.

    A candidate is any word made of one uppercase ASCII letter followed by
    2-15 lowercase letters, wherever it appears in the text (including the
    start of a sentence). Candidates listed in NOT_NAMES are discarded.

    Returns:
        The distinct candidates in order of first appearance. The order is
        deterministic across runs, unlike a round trip through ``set``.
    """
    candidates = _re.findall(r"\b[A-Z][a-z]{2,15}\b", text)
    # dict.fromkeys de-duplicates while keeping first-seen order.
    return list(dict.fromkeys(w for w in candidates if w not in NOT_NAMES))
|
|||
|
|
|
|||
|
|
def person_name_boost(names, doc_text):
    """Return the share (0.0-1.0) of *names* that occur in *doc_text*.

    Each name is matched as a case-insensitive substring. An empty name
    list yields 0.0.
    """
    if not names:
        return 0.0
    haystack = doc_text.lower()
    present = [name for name in names if name.lower() in haystack]
    ratio = len(present) / len(names)
    return min(ratio, 1.0)
|
|||
|
|
|
|||
|
|
def parse_question_date(date_str):
    """Parse a LongMemEval timestamp such as '2023/01/15 (Sun) 10:20'.

    Only the calendar date in front of the " (" weekday marker is parsed;
    the weekday and time of day are discarded.

    Returns:
        A ``datetime`` at midnight of that day, or ``None`` when *date_str*
        is not a string or does not consist of a ``YYYY/MM/DD`` date before
        the marker.
    """
    try:
        date_part = date_str.split(" (")[0]
        return datetime.strptime(date_part, "%Y/%m/%d")
    except (AttributeError, TypeError, ValueError):
        # AttributeError/TypeError: input is not a str (e.g. None or bytes).
        # ValueError: the text before " (" is not a YYYY/MM/DD date.
        # Anything else is unexpected and is allowed to propagate.
        return None
|
|||
|
|
|
|||
|
|
def parse_time_offset_days(question):
    """Find a relative-time expression in *question* and convert it to days.

    Rules are tried in order and the first regex that matches wins.

    Returns:
        ``(days_back, tolerance_days)`` for the first matching rule, or
        ``None`` when no rule matches.
    """
    lowered = question.lower()
    # (regex, days per unit, tolerance in days). When the regex has a capture
    # group it holds the unit count; otherwise the count is 1.
    rules = (
        (r"(\d+)\s+days?\s+ago", 1, 2),
        (r"a\s+couple\s+(?:of\s+)?days?\s+ago", 2, 2),
        (r"yesterday", 1, 1),
        (r"a\s+week\s+ago", 7, 3),
        (r"(\d+)\s+weeks?\s+ago", 7, 5),
        (r"last\s+week", 7, 3),
        (r"a\s+month\s+ago", 30, 7),
        (r"(\d+)\s+months?\s+ago", 30, 10),
        (r"last\s+month", 30, 7),
        (r"last\s+year", 365, 30),
        (r"a\s+year\s+ago", 365, 30),
        (r"recently", 14, 14),
    )
    for regex, unit_days, tolerance in rules:
        found = _re.search(regex, lowered)
        if found is None:
            continue
        count = int(found.group(1)) if found.groups() else 1
        return (count * unit_days, tolerance)
    return None
|
|||
|
|
|
|||
|
|
def is_assistant_reference(question):
    """Return True when *question* contains a phrase referring to something the assistant said earlier.

    The check is a case-insensitive substring search for a fixed list of
    trigger phrases.
    """
    lowered = question.lower()
    for phrase in (
        "you suggested",
        "you told me",
        "you mentioned",
        "you said",
        "you recommended",
        "remind me what you",
        "you provided",
        "you listed",
        "you gave me",
        "you described",
        "what did you",
        "you came up with",
        "you helped me",
        "you explained",
        "can you remind me",
        "you identified",
    ):
        if phrase in lowered:
            return True
    return False
|
|||
|
|
|
|||
|
|
# -------------------------------------------------------------------------
|
|||
|
|
# V4: Expanded preference patterns (adds memory/nostalgia for Miss 1)
|
|||
|
|
# -------------------------------------------------------------------------
|
|||
|
|
PREF_PATTERNS = [
|
|||
|
|
r"i(?:'ve been| have been) having (?:trouble|issues?|problems?) with ([^,\.!?]{5,80})",
|
|||
|
|
r"i(?:'ve been| have been) feeling ([^,\.!?]{5,60})",
|
|||
|
|
r"i(?:'ve been| have been) (?:struggling|dealing) with ([^,\.!?]{5,80})",
|
|||
|
|
r"i(?:'ve been| have been) (?:worried|concerned) about ([^,\.!?]{5,80})",
|
|||
|
|
r"i(?:'m| am) (?:worried|concerned) about ([^,\.!?]{5,80})",
|
|||
|
|
r"i prefer ([^,\.!?]{5,60})",
|
|||
|
|
r"i usually ([^,\.!?]{5,60})",
|
|||
|
|
r"i(?:'ve been| have been) (?:trying|attempting) to ([^,\.!?]{5,80})",
|
|||
|
|
r"i(?:'ve been| have been) (?:considering|thinking about) ([^,\.!?]{5,80})",
|
|||
|
|
r"lately[,\s]+(?:i've been|i have been|i'm|i am) ([^,\.!?]{5,80})",
|
|||
|
|
r"recently[,\s]+(?:i've been|i have been|i'm|i am) ([^,\.!?]{5,80})",
|
|||
|
|
r"i(?:'ve been| have been) (?:working on|focused on|interested in) ([^,\.!?]{5,80})",
|
|||
|
|
r"i want to ([^,\.!?]{5,60})",
|
|||
|
|
r"i(?:'m| am) looking (?:to|for) ([^,\.!?]{5,60})",
|
|||
|
|
r"i(?:'m| am) thinking (?:about|of) ([^,\.!?]{5,60})",
|
|||
|
|
r"i(?:'ve been| have been) (?:noticing|experiencing) ([^,\.!?]{5,80})",
|
|||
|
|
# NEW in V4 — memory/nostalgia patterns (for high school reunion miss):
|
|||
|
|
r"i (?:still )?remember (?:the |my )?([^,\.!?]{5,80})",
|
|||
|
|
r"i used to ([^,\.!?]{5,60})",
|
|||
|
|
r"when i was (?:in high school|in college|young|a kid|growing up)[,\s]+([^,\.!?]{5,80})",
|
|||
|
|
r"growing up[,\s]+([^,\.!?]{5,80})",
|
|||
|
|
r"(?:happy|fond|good|positive) (?:high school|college|childhood|school) (?:experience|memory|memories|time)[^,\.!?]{0,60}",
|
|||
|
|
]
|
|||
|
|
|
|||
|
|
def extract_preferences(session):
|
|||
|
|
"""Extract preference/concern/memory expressions from user turns in a session."""
|
|||
|
|
mentions = []
|
|||
|
|
for turn in session:
|
|||
|
|
if turn["role"] != "user":
|
|||
|
|
continue
|
|||
|
|
text = turn["content"].lower()
|
|||
|
|
for pat in PREF_PATTERNS:
|
|||
|
|
for match in _re.findall(pat, text, _re.IGNORECASE):
|
|||
|
|
if isinstance(match, tuple):
|
|||
|
|
match = " ".join(match)
|
|||
|
|
clean = match.strip().rstrip(".,;!? ")
|
|||
|
|
if 5 <= len(clean) <= 80:
|
|||
|
|
mentions.append(clean)
|
|||
|
|
seen = set()
|
|||
|
|
unique = []
|
|||
|
|
for m in mentions:
|
|||
|
|
if m not in seen:
|
|||
|
|
seen.add(m)
|
|||
|
|
unique.append(m)
|
|||
|
|
return unique[:12]
|
|||
|
|
|
|||
|
|
# -------------------------------------------------------------------------
|
|||
|
|
# Build corpus
|
|||
|
|
# -------------------------------------------------------------------------
|
|||
|
|
sessions = entry["haystack_sessions"]
|
|||
|
|
session_ids = entry["haystack_session_ids"]
|
|||
|
|
dates = entry["haystack_dates"]
|
|||
|
|
question = entry["question"]
|
|||
|
|
question_date = parse_question_date(entry.get("question_date", ""))
|
|||
|
|
|
|||
|
|
# V4: Pre-extract question signals
|
|||
|
|
quoted_phrases = extract_quoted_phrases(question)
|
|||
|
|
person_names = extract_person_names(question)
|
|||
|
|
|
|||
|
|
corpus_user = []
|
|||
|
|
corpus_full = []
|
|||
|
|
corpus_ids = []
|
|||
|
|
corpus_timestamps = []
|
|||
|
|
|
|||
|
|
pref_docs = []
|
|||
|
|
pref_ids = []
|
|||
|
|
pref_timestamps = []
|
|||
|
|
|
|||
|
|
for session, sess_id, date in zip(sessions, session_ids, dates):
|
|||
|
|
user_turns = [t["content"] for t in session if t["role"] == "user"]
|
|||
|
|
all_turns = [t["content"] for t in session]
|
|||
|
|
if not user_turns:
|
|||
|
|
continue
|
|||
|
|
corpus_user.append("\n".join(user_turns))
|
|||
|
|
corpus_full.append("\n".join(all_turns))
|
|||
|
|
corpus_ids.append(sess_id)
|
|||
|
|
corpus_timestamps.append(date)
|
|||
|
|
|
|||
|
|
prefs = extract_preferences(session)
|
|||
|
|
if prefs:
|
|||
|
|
pref_doc = "User has mentioned: " + "; ".join(prefs)
|
|||
|
|
pref_docs.append(pref_doc)
|
|||
|
|
pref_ids.append(sess_id)
|
|||
|
|
pref_timestamps.append(date)
|
|||
|
|
|
|||
|
|
if not corpus_user:
|
|||
|
|
return [], corpus_user, corpus_ids, corpus_timestamps
|
|||
|
|
|
|||
|
|
# -------------------------------------------------------------------------
|
|||
|
|
# Two-pass for assistant-reference questions — V4 uses corpus_full for Pass 1
|
|||
|
|
# (ensures the quoted phrases appear in the indexed text)
|
|||
|
|
# -------------------------------------------------------------------------
|
|||
|
|
if is_assistant_reference(question):
|
|||
|
|
collection = _fresh_collection()
|
|||
|
|
# Index full turns (not just user) so assistant's exact words are searchable
|
|||
|
|
collection.add(
|
|||
|
|
documents=corpus_full,
|
|||
|
|
ids=[f"doc_{i}" for i in range(len(corpus_full))],
|
|||
|
|
metadatas=[
|
|||
|
|
{"corpus_id": cid, "timestamp": ts}
|
|||
|
|
for cid, ts in zip(corpus_ids, corpus_timestamps)
|
|||
|
|
],
|
|||
|
|
)
|
|||
|
|
results = collection.query(
|
|||
|
|
query_texts=[question],
|
|||
|
|
n_results=min(50, len(corpus_full)),
|
|||
|
|
include=["distances", "metadatas", "documents"],
|
|||
|
|
)
|
|||
|
|
result_ids = results["ids"][0]
|
|||
|
|
distances = results["distances"][0]
|
|||
|
|
documents = results["documents"][0]
|
|||
|
|
|
|||
|
|
# Apply quoted phrase + name boost in scoring
|
|||
|
|
scored = []
|
|||
|
|
for rid, dist, doc in zip(result_ids, distances, documents):
|
|||
|
|
idx = int(rid.split("_")[1])
|
|||
|
|
overlap = keyword_overlap(extract_keywords(question), doc)
|
|||
|
|
fused_dist = dist * (1.0 - hybrid_weight * overlap)
|
|||
|
|
# Quoted phrase boost — strong signal for assistant-recall questions
|
|||
|
|
q_boost = quoted_phrase_boost(quoted_phrases, doc)
|
|||
|
|
if q_boost > 0:
|
|||
|
|
fused_dist = fused_dist * (1.0 - 0.60 * q_boost)
|
|||
|
|
scored.append((idx, fused_dist))
|
|||
|
|
|
|||
|
|
scored.sort(key=lambda x: x[1])
|
|||
|
|
seen = set()
|
|||
|
|
ranked_indices = []
|
|||
|
|
for idx, _ in scored:
|
|||
|
|
if corpus_ids[idx] not in seen:
|
|||
|
|
seen.add(corpus_ids[idx])
|
|||
|
|
ranked_indices.append(idx)
|
|||
|
|
for i in range(len(corpus_user)):
|
|||
|
|
if corpus_ids[i] not in seen:
|
|||
|
|
ranked_indices.append(i)
|
|||
|
|
seen.add(corpus_ids[i])
|
|||
|
|
return ranked_indices, corpus_user, corpus_ids, corpus_timestamps
|
|||
|
|
|
|||
|
|
# -------------------------------------------------------------------------
|
|||
|
|
# Build expanded collection: user docs + synthetic preference docs
|
|||
|
|
# -------------------------------------------------------------------------
|
|||
|
|
all_docs = corpus_user + pref_docs
|
|||
|
|
all_ids_meta = corpus_ids + pref_ids
|
|||
|
|
all_ts = corpus_timestamps + pref_timestamps
|
|||
|
|
|
|||
|
|
collection = _fresh_collection()
|
|||
|
|
collection.add(
|
|||
|
|
documents=all_docs,
|
|||
|
|
ids=[f"doc_{i}" for i in range(len(all_docs))],
|
|||
|
|
metadatas=[
|
|||
|
|
{"corpus_id": cid, "timestamp": ts, "is_pref": i >= len(corpus_user)}
|
|||
|
|
for i, (cid, ts) in enumerate(zip(all_ids_meta, all_ts))
|
|||
|
|
],
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
query_keywords = extract_keywords(question)
|
|||
|
|
results = collection.query(
|
|||
|
|
query_texts=[question],
|
|||
|
|
n_results=min(n_results, len(all_docs)),
|
|||
|
|
include=["distances", "metadatas", "documents"],
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
result_ids = results["ids"][0]
|
|||
|
|
distances = results["distances"][0]
|
|||
|
|
documents = results["documents"][0]
|
|||
|
|
doc_id_to_idx = {f"doc_{i}": i for i in range(len(all_docs))}
|
|||
|
|
|
|||
|
|
time_offset = parse_time_offset_days(question)
|
|||
|
|
target_date = None
|
|||
|
|
if time_offset and question_date:
|
|||
|
|
days_back, tolerance = time_offset
|
|||
|
|
target_date = question_date - timedelta(days=days_back)
|
|||
|
|
|
|||
|
|
scored = []
|
|||
|
|
for rid, dist, doc in zip(result_ids, distances, documents):
|
|||
|
|
idx = doc_id_to_idx[rid]
|
|||
|
|
overlap = keyword_overlap(query_keywords, doc)
|
|||
|
|
fused_dist = dist * (1.0 - hybrid_weight * overlap)
|
|||
|
|
|
|||
|
|
# Temporal boost (same as v3)
|
|||
|
|
if target_date:
|
|||
|
|
sess_date = parse_question_date(all_ts[idx])
|
|||
|
|
if sess_date:
|
|||
|
|
delta_days = abs((sess_date - target_date).days)
|
|||
|
|
tol = time_offset[1]
|
|||
|
|
if delta_days <= tol:
|
|||
|
|
temporal_boost = 0.40
|
|||
|
|
elif delta_days <= tol * 3:
|
|||
|
|
temporal_boost = 0.40 * (1.0 - (delta_days - tol) / (tol * 2))
|
|||
|
|
else:
|
|||
|
|
temporal_boost = 0.0
|
|||
|
|
fused_dist = fused_dist * (1.0 - temporal_boost)
|
|||
|
|
|
|||
|
|
# V4: Person name boost (for temporal-reasoning + person name questions)
|
|||
|
|
if person_names:
|
|||
|
|
n_boost = person_name_boost(person_names, doc)
|
|||
|
|
if n_boost > 0:
|
|||
|
|
fused_dist = fused_dist * (1.0 - 0.40 * n_boost)
|
|||
|
|
|
|||
|
|
scored.append((idx, fused_dist))
|
|||
|
|
|
|||
|
|
scored.sort(key=lambda x: x[1])
|
|||
|
|
|
|||
|
|
corpus_id_to_user_idx = {cid: i for i, cid in enumerate(corpus_ids)}
|
|||
|
|
seen_ids = set()
|
|||
|
|
ranked_indices = []
|
|||
|
|
for idx, _ in scored:
|
|||
|
|
cid = all_ids_meta[idx]
|
|||
|
|
if cid not in seen_ids:
|
|||
|
|
seen_ids.add(cid)
|
|||
|
|
ranked_indices.append(corpus_id_to_user_idx[cid])
|
|||
|
|
|
|||
|
|
for i in range(len(corpus_user)):
|
|||
|
|
if corpus_ids[i] not in seen_ids:
|
|||
|
|
ranked_indices.append(i)
|
|||
|
|
seen_ids.add(corpus_ids[i])
|
|||
|
|
|
|||
|
|
return ranked_indices, corpus_user, corpus_ids, corpus_timestamps
|
|||
|
|
|
|||
|
|
|
|||
|
|
# =============================================================================
|
|||
|
|
# PALACE MODE — Hall classification + drawer indexing + hall-boosted retrieval
|
|||
|
|
# =============================================================================
|
|||
|
|
|
|||
|
|
# Hall names mirror the MemPal palace taxonomy
|
|||
|
|
HALL_PREFERENCES = "hall_preferences"
HALL_FACTS = "hall_facts"
HALL_EVENTS = "hall_events"
HALL_ASSISTANT = "hall_assistant_advice"
HALL_GENERAL = "hall_general"


def _compile_signals(*signals):
    """
    Compile literal signal phrases into patterns anchored at a word start.

    The lookbehind rejects hits that begin in the middle of a word
    ("ago" in "chicago", "city" in "electricity", "here are" in "there are").
    There is deliberately no trailing boundary, so inflected forms still
    match ("work" -> "working", "job" -> "jobs", "graduate" -> "graduated").
    """
    return tuple(re.compile(r"(?<!\w)" + re.escape(s)) for s in signals)


def _count_signal_hits(patterns, text):
    """Return how many distinct signal patterns occur at least once in text."""
    return sum(1 for pattern in patterns if pattern.search(text))


def _has_signal(patterns, text):
    """Return True if any signal pattern occurs in text."""
    return any(pattern.search(text) for pattern in patterns)


# --- Session-side signals (matched against lowercased session text) ----------

_SESSION_PREF_SIGNALS = _compile_signals(
    "i prefer",
    "i usually",
    "i've been having trouble",
    "i've been feeling",
    "i've been struggling",
    "i want to",
    "i'm worried",
    "i've been thinking",
    "i've been considering",
    "lately i",
    "recently i",
    "i tend to",
)

_SESSION_ADVICE_SIGNALS = _compile_signals(
    "i suggest",
    "i recommend",
    "here are",
    "you might want to",
    "option 1",
    "option 2",
    "1.",
    "2.",
    "3.",
    "first,",
    "second,",
    "you could try",
    "i would recommend",
    "my recommendation",
)

# NOTE(review): prefix matching still lets "won" hit "wonderful"/"won't";
# tighten that signal separately if event classification proves too greedy.
_SESSION_EVENT_SIGNALS = _compile_signals(
    "milestone",
    "graduation",
    "promotion",
    "anniversary",
    "birthday",
    "moved",
    "started",
    "finished",
    "completed",
    "launched",
    "opened",
    "achieved",
    "won",
    "accepted",
    "hired",
    "married",
    "born",
)

_SESSION_FACT_SIGNALS = _compile_signals(
    "degree",
    "major",
    "university",
    "college",
    "job",
    "position",
    "role",
    "company",
    "city",
    "country",
    "street",
    "born in",
    "grew up",
    "studied",
    "works at",
    "lives in",
    "years old",
    "salary",
    "budget",
)

# --- Question-side triggers (matched against the lowercased question) --------

_QUESTION_ASSISTANT_TRIGGERS = _compile_signals(
    "you suggested",
    "you told me",
    "you mentioned",
    "you said",
    "you recommended",
    "you provided",
    "you listed",
    "you gave",
    "remind me what you",
    "you came up with",
    "you explained",
)

_QUESTION_PREF_TRIGGERS = _compile_signals(
    "i've been having trouble",
    "i've been feeling",
    "i prefer",
    "i usually",
    "battery",
    "nostalgic",
    "reunion",
    "lately",
    "recently been",
    "struggling with",
)

_QUESTION_EVENT_TRIGGERS = _compile_signals(
    "milestone",
    "when did",
    "what happened",
    "achievement",
    "ago",
    "last week",
    "last month",
    "last year",
    "four weeks",
    "three months",
)

_QUESTION_FACT_TRIGGERS = _compile_signals(
    "degree",
    "study",
    "graduate",
    "major",
    "job",
    "work",
    "live",
    "born",
    "city",
    "country",
    "company",
    "school",
)


def classify_session_hall(session):
    """
    Assign a session to a palace hall based on its content.

    Args:
        session: list of turn dicts, each with "role" ("user"/"assistant")
            and "content" (str).

    Returns:
        One of the HALL_* constants.

    Heuristics (checked in priority order, first match wins):
        HALL_PREFERENCES — any preference/concern phrase in the user turns
        HALL_ASSISTANT   — at least two distinct advice/list markers in the
                           assistant turns
        HALL_EVENTS      — any event word in either side of the conversation
        HALL_FACTS       — at least two distinct fact words in either side
        HALL_GENERAL     — default

    Signals match case-insensitively and only at the start of a word, so
    "removed" does not count as "moved" and "there are" does not count as
    "here are".
    """
    user_text = " ".join(t["content"] for t in session if t["role"] == "user").lower()
    asst_text = " ".join(t["content"] for t in session if t["role"] == "assistant").lower()

    if _has_signal(_SESSION_PREF_SIGNALS, user_text):
        return HALL_PREFERENCES

    if _count_signal_hits(_SESSION_ADVICE_SIGNALS, asst_text) >= 2:
        return HALL_ASSISTANT

    # Built once, with a separator so that a signal cannot be fabricated
    # from the end of the user text glued to the start of the assistant text.
    combined = user_text + " " + asst_text

    if _has_signal(_SESSION_EVENT_SIGNALS, combined):
        return HALL_EVENTS

    if _count_signal_hits(_SESSION_FACT_SIGNALS, combined) >= 2:
        return HALL_FACTS

    return HALL_GENERAL


def classify_question_hall(question):
    """
    Infer which palace hall a question is asking about.

    Args:
        question: the question text.

    Returns:
        A list of HALL_* constants in priority order (first = most likely).
        HALL_GENERAL is always the last element.

    Trigger groups are checked in order: assistant-reference phrasing,
    preference phrasing, event/time phrasing, then fact vocabulary.
    Triggers match case-insensitively and only at the start of a word, so
    "Chicago" does not trigger "ago" and "electricity" does not trigger "city".
    """
    q = question.lower()

    if _has_signal(_QUESTION_ASSISTANT_TRIGGERS, q):
        return [HALL_ASSISTANT, HALL_GENERAL]

    if _has_signal(_QUESTION_PREF_TRIGGERS, q):
        return [HALL_PREFERENCES, HALL_GENERAL]

    if _has_signal(_QUESTION_EVENT_TRIGGERS, q):
        return [HALL_EVENTS, HALL_FACTS, HALL_GENERAL]

    if _has_signal(_QUESTION_FACT_TRIGGERS, q):
        return [HALL_FACTS, HALL_GENERAL]

    return [HALL_GENERAL]
|
|||
|
|
|
|||
|
|
|
|||
|
|
def build_palace_and_retrieve_palace(
    entry, granularity="session", n_results=50, hybrid_weight=0.30
):
    """
    Palace-mode retrieval: full-haystack ranking with hall-aware score bonuses.

    PALACE
      └── HALL (classified per session: preferences / facts / events / assistant / general)
            └── CLOSET (user turns per session — what the user said)
                  └── DRAWER (assistant turns — only opened for assistant-reference questions)
      └── PREFERENCE WING (synthetic docs from pref extraction — same session ID)

    Navigation:
      1. Classify the question → ordered list of target halls; first = primary.
      2. PASS 1 (skipped when the primary hall is general): query only the
         primary hall's documents (plus the preference wing for preference
         questions, plus assistant-turn drawers of that hall for
         assistant-reference questions). The sessions returned form a
         "hall-validated" set. Pass 1 never decides the final order.
      3. PASS 2: query the full haystack (user docs + preference wing) and
         rescale each distance:
           - keyword overlap with the question (weighted by hybrid_weight)
           - x0.75 if the session is in the primary (non-general) hall,
             else x0.90 if it is in any target hall
           - x0.85 if the session was also returned by Pass 1
           - temporal boost when the question contains a relative time
             expression and a question date is available
      4. Sessions are ranked by ascending adjusted distance; sessions not
         returned by Pass 2 are appended in corpus order.

    Args:
        entry: benchmark record with "haystack_sessions", "haystack_session_ids",
            "haystack_dates", "question" and optionally "question_date".
        granularity: accepted for signature parity with the other retrieval
            modes; not used by this function.
        n_results: maximum number of documents requested in Pass 2.
        hybrid_weight: strength of the keyword-overlap distance reduction.

    Returns:
        (ranked_indices, corpus_user, corpus_ids, corpus_timestamps) where
        ranked_indices are positions into the three parallel corpus lists.
        Sessions without any user turn are excluded from the corpus.
    """
    import re as _re
    from datetime import datetime, timedelta

    STOP_WORDS = {
        "what",
        "when",
        "where",
        "who",
        "how",
        "which",
        "did",
        "do",
        "was",
        "were",
        "have",
        "has",
        "had",
        "is",
        "are",
        "the",
        "a",
        "an",
        "my",
        "me",
        "i",
        "you",
        "your",
        "their",
        "it",
        "its",
        "in",
        "on",
        "at",
        "to",
        "for",
        "of",
        "with",
        "by",
        "from",
        "ago",
        "last",
        "that",
        "this",
        "there",
        "about",
        "get",
        "got",
        "give",
        "gave",
        "buy",
        "bought",
        "made",
        "make",
    }

    def extract_keywords(text):
        """Return lowercase alphabetic words (3+ letters) that are not stop words."""
        words = _re.findall(r"\b[a-z]{3,}\b", text.lower())
        return [w for w in words if w not in STOP_WORDS]

    def keyword_overlap(query_kws, doc_text):
        """Return the fraction of query keywords found as substrings of doc_text."""
        if not query_kws:
            return 0.0
        doc_lower = doc_text.lower()
        hits = sum(1 for kw in query_kws if kw in doc_lower)
        return hits / len(query_kws)

    def parse_question_date(date_str):
        """Parse a "YYYY/MM/DD" date (anything from " (" on is ignored); None on failure."""
        try:
            return datetime.strptime(date_str.split(" (")[0], "%Y/%m/%d")
        except (AttributeError, TypeError, ValueError):
            return None

    def parse_time_offset_days(question):
        """
        Map the first matching relative-time phrase to (days_back, tolerance_days).

        Patterns are tried in list order; returns None when none matches.
        """
        q = question.lower()
        patterns = [
            (r"(\d+)\s+days?\s+ago", lambda m: (int(m.group(1)), 2)),
            (r"a\s+couple\s+(?:of\s+)?days?\s+ago", lambda m: (2, 2)),
            (r"yesterday", lambda m: (1, 1)),
            (r"a\s+week\s+ago", lambda m: (7, 3)),
            (r"(\d+)\s+weeks?\s+ago", lambda m: (int(m.group(1)) * 7, 5)),
            (r"last\s+week", lambda m: (7, 3)),
            (r"a\s+month\s+ago", lambda m: (30, 7)),
            (r"(\d+)\s+months?\s+ago", lambda m: (int(m.group(1)) * 30, 10)),
            (r"last\s+month", lambda m: (30, 7)),
            (r"last\s+year", lambda m: (365, 30)),
            (r"a\s+year\s+ago", lambda m: (365, 30)),
            (r"recently", lambda m: (14, 14)),
        ]
        for pattern, extractor in patterns:
            m = _re.search(pattern, q)
            if m:
                return extractor(m)
        return None

    # Preference extraction (same as v3)
    PREF_PATTERNS = [
        r"i(?:'ve been| have been) having (?:trouble|issues?|problems?) with ([^,\.!?]{5,80})",
        r"i(?:'ve been| have been) feeling ([^,\.!?]{5,60})",
        r"i(?:'ve been| have been) (?:struggling|dealing) with ([^,\.!?]{5,80})",
        r"i(?:'ve been| have been) (?:worried|concerned) about ([^,\.!?]{5,80})",
        r"i(?:'m| am) (?:worried|concerned) about ([^,\.!?]{5,80})",
        r"i prefer ([^,\.!?]{5,60})",
        r"i usually ([^,\.!?]{5,60})",
        r"i(?:'ve been| have been) (?:trying|attempting) to ([^,\.!?]{5,80})",
        r"i(?:'ve been| have been) (?:considering|thinking about) ([^,\.!?]{5,80})",
        r"lately[,\s]+(?:i've been|i have been|i'm|i am) ([^,\.!?]{5,80})",
        r"recently[,\s]+(?:i've been|i have been|i'm|i am) ([^,\.!?]{5,80})",
        r"i(?:'ve been| have been) (?:working on|focused on|interested in) ([^,\.!?]{5,80})",
        r"i want to ([^,\.!?]{5,60})",
        r"i(?:'m| am) looking (?:to|for) ([^,\.!?]{5,60})",
        r"i(?:'m| am) thinking (?:about|of) ([^,\.!?]{5,60})",
        r"i(?:'ve been| have been) (?:noticing|experiencing) ([^,\.!?]{5,80})",
    ]

    def extract_preferences(session):
        """Return up to 10 distinct preference/concern phrases from the user turns."""
        mentions = []
        for turn in session:
            if turn["role"] != "user":
                continue
            text = turn["content"].lower()
            for pat in PREF_PATTERNS:
                for match in _re.findall(pat, text, _re.IGNORECASE):
                    clean = match.strip().rstrip(".,;!? ")
                    if 5 <= len(clean) <= 80:
                        mentions.append(clean)
        # Order-preserving de-duplication.
        seen = set()
        unique = []
        for m in mentions:
            if m not in seen:
                seen.add(m)
                unique.append(m)
        return unique[:10]

    # -------------------------------------------------------------------------
    # Build palace — classify sessions into halls, build per-hall closets
    # -------------------------------------------------------------------------
    sessions = entry["haystack_sessions"]
    session_ids = entry["haystack_session_ids"]
    dates = entry["haystack_dates"]
    question = entry["question"]
    question_date = parse_question_date(entry.get("question_date", ""))

    # Canonical corpus (user turns per session) — indices used for evaluation.
    # corpus_halls is kept parallel to these lists: sessions without user turns
    # are skipped, so positions here do NOT line up with `sessions`.
    corpus_user = []
    corpus_ids = []
    corpus_timestamps = []
    corpus_halls = []

    # Per-hall closet documents (user turns only — clean, no noise)
    hall_docs = {
        h: [] for h in [HALL_PREFERENCES, HALL_FACTS, HALL_EVENTS, HALL_ASSISTANT, HALL_GENERAL]
    }
    hall_meta = {h: [] for h in hall_docs}

    # Preference wing: synthetic docs for vocab-gap bridging (separate from halls)
    pref_wing_docs = []
    pref_wing_meta = []

    # Drawer index: assistant turns per session (only opened when needed)
    drawer_docs = []
    drawer_meta = []

    for session, sess_id, date in zip(sessions, session_ids, dates):
        user_turns = [t["content"] for t in session if t["role"] == "user"]
        asst_turns = [t["content"] for t in session if t["role"] == "assistant"]
        if not user_turns:
            continue

        hall = classify_session_hall(session)
        user_doc = "\n".join(user_turns)

        # Canonical entry
        corpus_user.append(user_doc)
        corpus_ids.append(sess_id)
        corpus_timestamps.append(date)
        corpus_halls.append(hall)

        # CLOSET: file into the correct hall (clean, targeted)
        hall_docs[hall].append(user_doc)
        hall_meta[hall].append({"corpus_id": sess_id, "timestamp": date, "hall": hall})

        # PREFERENCE WING: synthetic preference doc (same session, separate index)
        prefs = extract_preferences(session)
        if prefs:
            pref_wing_docs.append("User has mentioned: " + "; ".join(prefs))
            pref_wing_meta.append({"corpus_id": sess_id, "timestamp": date})

        # DRAWERS: assistant turns stored separately, only indexed on demand.
        # Very short assistant turns (30 chars or fewer) are not stored.
        for asst_turn in asst_turns:
            if len(asst_turn) > 30:
                drawer_docs.append(asst_turn)
                drawer_meta.append({"corpus_id": sess_id, "timestamp": date})

    if not corpus_user:
        return [], corpus_user, corpus_ids, corpus_timestamps

    # -------------------------------------------------------------------------
    # Navigate: classify question → primary hall
    # -------------------------------------------------------------------------
    target_halls = classify_question_hall(question)
    primary_hall = target_halls[0]
    query_keywords = extract_keywords(question)

    # Temporal setup: only active when both a relative-time phrase and a
    # parseable question date are present.
    time_offset = parse_time_offset_days(question)
    target_date = None
    if time_offset and question_date:
        target_date = question_date - timedelta(days=time_offset[0])

    def hybrid_score(dist, doc):
        """Shrink the embedding distance in proportion to keyword overlap."""
        overlap = keyword_overlap(query_keywords, doc)
        return dist * (1.0 - hybrid_weight * overlap)

    def apply_temporal(fused_dist, timestamp):
        """
        Shrink the distance for sessions dated near the question's target date.

        Full 40% reduction within the tolerance window, tapering linearly to
        zero at three times the tolerance; unchanged beyond that or when
        either date is unavailable.
        """
        if not target_date:
            return fused_dist
        sess_date = parse_question_date(timestamp)
        if not sess_date:
            return fused_dist
        delta_days = abs((sess_date - target_date).days)
        tol = time_offset[1]
        if delta_days <= tol:
            boost = 0.40
        elif delta_days <= tol * 3:
            boost = 0.40 * (1.0 - (delta_days - tol) / (tol * 2))
        else:
            boost = 0.0
        return fused_dist * (1.0 - boost)

    corpus_id_to_user_idx = {cid: i for i, cid in enumerate(corpus_ids)}

    # -------------------------------------------------------------------------
    # PASS 1: Navigate into primary hall — tight, focused search
    #   Builds a set of hall-validated session IDs for Pass 2 score bonus
    #   Does NOT pre-empt Pass 2 results — scores decide final order
    # -------------------------------------------------------------------------
    primary_hall_meta = hall_meta[primary_hall]
    pass1_docs = list(hall_docs[primary_hall])
    pass1_meta = list(primary_hall_meta)

    # Also include preference wing docs if question is preference-type
    if primary_hall == HALL_PREFERENCES and pref_wing_docs:
        pass1_docs += pref_wing_docs
        pass1_meta += pref_wing_meta

    # For assistant-reference: open drawers, but only those belonging to
    # sessions filed in the assistant hall.
    if primary_hall == HALL_ASSISTANT and drawer_docs:
        hall_session_ids = {m["corpus_id"] for m in primary_hall_meta}
        for ddoc, dmeta in zip(drawer_docs, drawer_meta):
            if dmeta["corpus_id"] in hall_session_ids:
                pass1_docs.append(ddoc)
                pass1_meta.append(dmeta)

    hall_validated_ids = set()  # sessions confirmed by tight hall search

    # Only do Pass 1 for specific halls (not GENERAL — too broad to be useful)
    if primary_hall != HALL_GENERAL and pass1_docs:
        coll1 = _fresh_collection("mempal_hall")
        coll1.add(
            documents=pass1_docs,
            ids=[f"h_{i}" for i in range(len(pass1_docs))],
            metadatas=pass1_meta,
        )
        r1 = coll1.query(
            query_texts=[question],
            n_results=min(10, len(pass1_docs)),
            include=["distances", "metadatas", "documents"],
        )
        # Only membership matters here; distances are not used in Pass 1.
        for meta in r1["metadatas"][0]:
            hall_validated_ids.add(meta["corpus_id"])

    # -------------------------------------------------------------------------
    # PASS 2: Full haystack search — primary ranking
    #   Hall bonus: sessions in primary hall get distance reduction
    #   Double-validation bonus: sessions also found in Pass 1 get extra boost
    # -------------------------------------------------------------------------
    full_docs = corpus_user + pref_wing_docs
    # Use the hall recorded while building the corpus. Re-deriving it from
    # sessions[i] would mislabel entries whenever a session without user
    # turns was skipped, because corpus positions then differ from session
    # positions.
    full_meta_list = [
        {"corpus_id": cid, "timestamp": ts, "hall": hall}
        for cid, ts, hall in zip(corpus_ids, corpus_timestamps, corpus_halls)
    ]
    # Preference-wing metadata carries no "hall" key, so those docs receive
    # no hall bonus below (only the double-validation and temporal factors).
    full_meta_list += pref_wing_meta

    coll2 = _fresh_collection()
    coll2.add(
        documents=full_docs,
        ids=[f"doc_{i}" for i in range(len(full_docs))],
        metadatas=full_meta_list,
    )
    r2 = coll2.query(
        query_texts=[question],
        n_results=min(n_results, len(full_docs)),
        include=["distances", "metadatas", "documents"],
    )

    full_scored = []
    for dist, doc, meta in zip(
        r2["distances"][0], r2["documents"][0], r2["metadatas"][0]
    ):
        fd = hybrid_score(dist, doc)
        cid = meta["corpus_id"]
        doc_hall = meta.get("hall")
        # Hall bonus: sessions in the primary hall get 25% distance reduction
        if doc_hall == primary_hall and primary_hall != HALL_GENERAL:
            fd = fd * 0.75
        elif doc_hall in target_halls:
            fd = fd * 0.90
        # Double-validation bonus: appeared in tight hall search → extra 15% boost
        if cid in hall_validated_ids:
            fd = fd * 0.85
        fd = apply_temporal(fd, meta.get("timestamp", ""))
        full_scored.append((cid, fd))

    full_scored.sort(key=lambda x: x[1])

    # Build final ranking purely by score — hall navigation boosts but never
    # overrides. A session may appear twice (user doc + preference doc); its
    # best-scoring occurrence determines its rank.
    ranked_indices = []
    seen_ids = set()
    for cid, _ in full_scored:
        if cid not in seen_ids and cid in corpus_id_to_user_idx:
            ranked_indices.append(corpus_id_to_user_idx[cid])
            seen_ids.add(cid)

    # Fill any stragglers (sessions Pass 2 did not return) in corpus order
    for i, cid in enumerate(corpus_ids):
        if cid not in seen_ids:
            ranked_indices.append(i)
            seen_ids.add(cid)

    return ranked_indices, corpus_user, corpus_ids, corpus_timestamps
|
|||
|
|
|
|||
|
|
|
|||
|
|
# =============================================================================
|
|||
|
|
# LLM RE-RANKER (optional third pass)
|
|||
|
|
# =============================================================================
|
|||
|
|
|
|||
|
|
|
|||
|
|
def diary_ingest_session(session, sess_id, api_key, model="claude-haiku-4-5-20251001"):
    """
    Call an LLM to extract topics and a summary from one session.

    This is the "LLM topic layer" — the core of diary mode.
    The model reads the session's user turns once and is asked to return:
      topics:  2-5 specific things discussed ("yoga classes", "job interview at fintech startup")
      summary: 1-2 sentences describing what the session was about

    These become synthetic documents added to the haystack with the same
    corpus_id as the session — bridging vocabulary gaps that embeddings miss.

    Example gap closed:
      Session:  "I went this morning, my instructor pushed me really hard"
      Question: "Where do I take yoga classes?"
      Without diary: no keyword overlap → miss
      With diary: topic doc "yoga classes, fitness routine" → hit

    Args:
        session: list of turn dicts with "role" and "content".
        sess_id: session identifier; accepted for the caller's bookkeeping,
            not used inside this function.
        api_key: API key sent in the "x-api-key" header.
        model: model name placed in the request payload.

    Returns:
        A dict containing at least "topics" and "summary", or None when the
        session has no user turns, no API key is given, the request fails,
        or the reply is not a JSON object with both fields.
    """
    import urllib.request as _urllib_request

    user_turns = [t["content"] for t in session if t["role"] == "user"]
    if not user_turns:
        return None

    # Without a key the request can only be rejected; skip the round-trip.
    if not api_key:
        return None

    # Only send first 1200 chars of user text — enough context, cheap prompt
    user_text = " | ".join(user_turns)[:1200]

    prompt = (
        "Read this conversation excerpt (user turns only) and extract:\n\n"
        f"USER SAID:\n{user_text}\n\n"
        "Return a JSON object with exactly two fields:\n"
        '{"topics": ["specific topic 1", "specific topic 2", ...], "summary": "1-2 sentences"}\n\n'
        "Rules:\n"
        "- topics: 2-5 SPECIFIC things discussed. Not 'work' — 'job interview at law firm'. "
        "Not 'health' — 'back pain from sitting at desk'. Not 'travel' — 'trip to Tokyo in March'.\n"
        "- summary: what this person was talking about, in plain language\n"
        "- Return ONLY valid JSON. No markdown, no explanation."
    )

    payload = json.dumps(
        {
            "model": model,
            "max_tokens": 200,
            "messages": [{"role": "user", "content": prompt}],
        }
    ).encode("utf-8")

    req = _urllib_request.Request(
        "https://api.anthropic.com/v1/messages",
        data=payload,
        headers={
            "x-api-key": api_key,
            "anthropic-version": "2023-06-01",
            "content-type": "application/json",
        },
        method="POST",
    )

    try:
        with _urllib_request.urlopen(req, timeout=25) as resp:
            result = json.loads(resp.read())
        raw = result["content"][0]["text"].strip()
        # The model sometimes wraps its JSON in a markdown fence despite the
        # instructions; strip a leading/trailing fence before parsing.
        raw = re.sub(r"^```(?:json)?\s*", "", raw)
        raw = re.sub(r"\s*```$", "", raw)
        data = json.loads(raw)
    except Exception:
        # Deliberately best-effort: timeout, network/HTTP error, unexpected
        # response shape or bad JSON must not abort the run — the session
        # simply gets no diary doc.
        return None

    # Require a JSON object: a bare JSON string that merely contains the
    # words "topics" and "summary" would otherwise pass the membership test.
    if isinstance(data, dict) and "topics" in data and "summary" in data:
        return data
    return None
|
|||
|
|
|
|||
|
|
|
|||
|
|
def build_palace_and_retrieve_diary(
    entry,
    granularity="session",
    n_results=50,
    hybrid_weight=0.30,
    diary_cache=None,
    api_key="",
    diary_model="claude-haiku-4-5-20251001",
):
    """
    Diary mode: palace retrieval + LLM topic layer at ingest.

    Three kinds of documents are indexed into one fresh collection, all
    carrying the ``corpus_id`` of the session they came from:

    RAW LAYER
        - The user turns of each session joined with newlines.
        - Sessions without any user turn are skipped entirely.

    DIARY LAYER (per session, computed once and cached)
        - ``diary_ingest_session`` yields ``{"topics": [...], "summary": "..."}``
          (or None on failure).
        - Synthetic doc: "Session topics: yoga classes, Tuesday routine. Summary: ..."

    PREFERENCE WING
        - Up to 10 regex-extracted first-person statements per session,
          stored as "User has mentioned: ...".

    The diary layer is meant to bridge vocabulary gaps that neither
    embeddings nor keyword matching can cross — e.g., "Where do I take yoga
    classes?" matching a session that only says "I went this morning, my
    instructor was great."

    Ranking: the embedding distance of every hit is multiplied by a series
    of factors (lower is better): keyword overlap with the question, a hall
    match bonus, a diary-layer bonus and a temporal-proximity bonus. Hits are
    then de-duplicated by ``corpus_id``; sessions that were not returned by
    the vector query are appended afterwards in corpus order.

    Args:
        entry: One benchmark record. Reads ``haystack_sessions``,
            ``haystack_session_ids``, ``haystack_dates``, ``question`` and the
            optional ``question_date``.
        granularity: Accepted for signature parity with the other retrieval
            modes; this function does not read it.
        n_results: Maximum number of hits requested from the vector query.
        hybrid_weight: Strength of the keyword-overlap discount.
        diary_cache: dict mapping sess_id → diary data (or None). It is
            mutated in place: sessions missing from it are added. Pass the
            same dict across questions to avoid redundant API calls.
        api_key: Anthropic API key. When empty, uncached sessions get a
            ``None`` diary entry and no LLM call is made.
        diary_model: Model ID forwarded to ``diary_ingest_session``.

    Returns:
        ``(ranked_indices, corpus_user, corpus_ids, corpus_timestamps)`` where
        ``ranked_indices`` indexes into the three parallel corpus lists. For
        an empty corpus ``ranked_indices`` is ``[]``.
    """
    import re as _re
    from datetime import datetime, timedelta

    STOP_WORDS = {
        "what", "when", "where", "who", "how", "which",
        "did", "do", "was", "were", "have", "has", "had", "is", "are",
        "the", "a", "an",
        "my", "me", "i", "you", "your", "their", "it", "its",
        "in", "on", "at", "to", "for", "of", "with", "by", "from",
        "ago", "last", "that", "this", "there", "about",
        "get", "got", "give", "gave", "buy", "bought", "made", "make",
    }

    def extract_keywords(text):
        # Lower-case alphabetic words of 3+ letters, minus stop words.
        words = _re.findall(r"\b[a-z]{3,}\b", text.lower())
        return [w for w in words if w not in STOP_WORDS]

    def keyword_overlap(query_kws, doc_text):
        # Fraction of query keywords occurring (as substrings) in the doc.
        if not query_kws:
            return 0.0
        doc_lower = doc_text.lower()
        hits = sum(1 for kw in query_kws if kw in doc_lower)
        return hits / len(query_kws)

    def parse_question_date(date_str):
        # Dates look like "2023/05/20 (Sat) 10:00"; only the part before
        # " (" is parsed. Anything unparseable yields None.
        try:
            return datetime.strptime(date_str.split(" (")[0], "%Y/%m/%d")
        except (AttributeError, TypeError, ValueError):
            return None

    def parse_time_offset_days(question):
        # Returns (days_back, tolerance_days) for the first matching
        # relative-time phrase, or None. Pattern order is significant.
        q = question.lower()
        patterns = [
            (r"(\d+)\s+days?\s+ago", lambda m: (int(m.group(1)), 2)),
            (r"a\s+couple\s+(?:of\s+)?days?\s+ago", lambda m: (2, 2)),
            (r"yesterday", lambda m: (1, 1)),
            (r"a\s+week\s+ago", lambda m: (7, 3)),
            (r"(\d+)\s+weeks?\s+ago", lambda m: (int(m.group(1)) * 7, 5)),
            (r"last\s+week", lambda m: (7, 3)),
            (r"a\s+month\s+ago", lambda m: (30, 7)),
            (r"(\d+)\s+months?\s+ago", lambda m: (int(m.group(1)) * 30, 10)),
            (r"last\s+month", lambda m: (30, 7)),
            (r"last\s+year", lambda m: (365, 30)),
            (r"a\s+year\s+ago", lambda m: (365, 30)),
            (r"recently", lambda m: (14, 14)),
        ]
        for pattern, extractor in patterns:
            m = _re.search(pattern, q)
            if m:
                return extractor(m)
        return None

    # Preference extraction (same 16 patterns as v3/palace)
    PREF_PATTERNS = [
        r"i(?:'ve been| have been) having (?:trouble|issues?|problems?) with ([^,\.!?]{5,80})",
        r"i(?:'ve been| have been) feeling ([^,\.!?]{5,60})",
        r"i(?:'ve been| have been) (?:struggling|dealing) with ([^,\.!?]{5,80})",
        r"i(?:'ve been| have been) (?:worried|concerned) about ([^,\.!?]{5,80})",
        r"i(?:'m| am) (?:worried|concerned) about ([^,\.!?]{5,80})",
        r"i prefer ([^,\.!?]{5,60})",
        r"i usually ([^,\.!?]{5,60})",
        r"i(?:'ve been| have been) (?:trying|attempting) to ([^,\.!?]{5,80})",
        r"i(?:'ve been| have been) (?:considering|thinking about) ([^,\.!?]{5,80})",
        r"lately[,\s]+(?:i've been|i have been|i'm|i am) ([^,\.!?]{5,80})",
        r"recently[,\s]+(?:i've been|i have been|i'm|i am) ([^,\.!?]{5,80})",
        r"i(?:'ve been| have been) (?:working on|focused on|interested in) ([^,\.!?]{5,80})",
        r"i want to ([^,\.!?]{5,60})",
        r"i(?:'m| am) looking (?:to|for) ([^,\.!?]{5,60})",
        r"i(?:'m| am) thinking (?:about|of) ([^,\.!?]{5,60})",
        r"i(?:'ve been| have been) (?:noticing|experiencing) ([^,\.!?]{5,80})",
    ]

    def extract_preferences(session):
        # Collect regex captures from user turns, de-duplicated in first-seen
        # order and capped at 10.
        mentions = []
        for turn in session:
            if turn["role"] != "user":
                continue
            text = turn["content"].lower()
            for pat in PREF_PATTERNS:
                for match in _re.findall(pat, text, _re.IGNORECASE):
                    clean = match.strip().rstrip(".,;!? ")
                    if 5 <= len(clean) <= 80:
                        mentions.append(clean)
        return list(dict.fromkeys(mentions))[:10]

    if diary_cache is None:
        diary_cache = {}

    sessions = entry["haystack_sessions"]
    session_ids = entry["haystack_session_ids"]
    dates = entry["haystack_dates"]
    question = entry["question"]
    question_date = parse_question_date(entry.get("question_date", ""))

    corpus_user = []
    corpus_ids = []
    corpus_timestamps = []
    # Hall of each *kept* session, parallel to corpus_user. Recorded here
    # rather than looked up later via sessions[i]: sessions without user
    # turns are skipped, so corpus index i and sessions index i diverge.
    corpus_halls = []
    diary_docs = []  # LLM topic layer docs (one per session with diary data)
    diary_meta = []
    pref_wing_docs = []
    pref_wing_meta = []

    for session, sess_id, date in zip(sessions, session_ids, dates):
        user_turns = [t["content"] for t in session if t["role"] == "user"]
        if not user_turns:
            continue

        hall = classify_session_hall(session)
        corpus_user.append("\n".join(user_turns))
        corpus_ids.append(sess_id)
        corpus_timestamps.append(date)
        corpus_halls.append(hall)

        # DIARY LAYER: get or compute LLM topic extraction
        if sess_id not in diary_cache:
            if api_key:
                # cache even if the call returned None
                diary_cache[sess_id] = diary_ingest_session(
                    session, sess_id, api_key, model=diary_model
                )
            else:
                diary_cache[sess_id] = None

        diary_data = diary_cache.get(sess_id)
        if diary_data:
            topics = diary_data.get("topics", [])
            summary = diary_data.get("summary", "")
            if topics or summary:
                topic_str = ", ".join(topics) if topics else ""
                diary_docs.append(f"Session topics: {topic_str}. Summary: {summary}")
                diary_meta.append({"corpus_id": sess_id, "timestamp": date, "hall": hall})

        # PREFERENCE WING (same as v3/palace)
        prefs = extract_preferences(session)
        if prefs:
            pref_wing_docs.append("User has mentioned: " + "; ".join(prefs))
            pref_wing_meta.append({"corpus_id": sess_id, "timestamp": date})

    if not corpus_user:
        return [], corpus_user, corpus_ids, corpus_timestamps

    # Hall navigation (same as palace)
    target_halls = classify_question_hall(question)
    primary_hall = target_halls[0]
    query_keywords = extract_keywords(question)

    def hybrid_score(dist, doc):
        # Shrink the distance proportionally to keyword overlap.
        overlap = keyword_overlap(query_keywords, doc)
        return dist * (1.0 - hybrid_weight * overlap)

    time_offset = parse_time_offset_days(question)
    target_date = None
    if time_offset and question_date:
        target_date = question_date - timedelta(days=time_offset[0])

    def apply_temporal(fused_dist, timestamp):
        # Up to 40% discount for sessions dated within the tolerance of the
        # target date, tapering linearly to zero at 3x the tolerance.
        if not target_date:
            return fused_dist
        sess_date = parse_question_date(timestamp)
        if not sess_date:
            return fused_dist
        delta_days = abs((sess_date - target_date).days)
        tol = time_offset[1]
        if delta_days <= tol:
            boost = 0.40
        elif delta_days <= tol * 3:
            boost = 0.40 * (1.0 - (delta_days - tol) / (tol * 2))
        else:
            boost = 0.0
        return fused_dist * (1.0 - boost)

    corpus_id_to_user_idx = {cid: i for i, cid in enumerate(corpus_ids)}

    # -------------------------------------------------------------------------
    # FULL SEARCH: raw user docs + diary topic docs + preference wing
    # Diary docs and pref docs share corpus_id with their session — same hit
    # -------------------------------------------------------------------------
    full_docs = corpus_user + diary_docs + pref_wing_docs
    full_meta = (
        [
            {"corpus_id": cid, "timestamp": ts, "hall": hall, "layer": "raw"}
            for cid, ts, hall in zip(corpus_ids, corpus_timestamps, corpus_halls)
        ]
        + [dict(m, layer="diary") for m in diary_meta]
        + [dict(m, layer="pref") for m in pref_wing_meta]
    )

    coll = _fresh_collection()
    coll.add(
        documents=full_docs,
        ids=[f"doc_{i}" for i in range(len(full_docs))],
        metadatas=full_meta,
    )
    r = coll.query(
        query_texts=[question],
        n_results=min(n_results, len(full_docs)),
        include=["distances", "metadatas", "documents"],
    )

    scored = []
    for dist, doc, meta in zip(r["distances"][0], r["documents"][0], r["metadatas"][0]):
        cid = meta["corpus_id"]
        fd = hybrid_score(dist, doc)
        # Hall bonus
        if meta.get("hall") == primary_hall and primary_hall != HALL_GENERAL:
            fd *= 0.75
        elif meta.get("hall") in target_halls:
            fd *= 0.90
        # Diary layer bonus: LLM topic doc that matches gets extra 20% boost
        # (it's a more precise signal than raw text)
        if meta.get("layer") == "diary":
            fd *= 0.80
        fd = apply_temporal(fd, meta.get("timestamp", ""))
        scored.append((cid, fd))

    scored.sort(key=lambda x: x[1])

    # Best-scoring document per session decides the session's rank.
    ranked_indices = []
    seen_ids = set()
    for cid, _ in scored:
        if cid not in seen_ids and cid in corpus_id_to_user_idx:
            ranked_indices.append(corpus_id_to_user_idx[cid])
            seen_ids.add(cid)

    # Sessions the vector query did not return go last, in corpus order.
    for i, cid in enumerate(corpus_ids):
        if cid not in seen_ids:
            ranked_indices.append(i)
            seen_ids.add(cid)

    return ranked_indices, corpus_user, corpus_ids, corpus_timestamps
|
|||
|
|
|
|||
|
|
|
|||
|
|
def llm_rerank(
    question, rankings, corpus, corpus_ids, api_key, top_k=10, model="claude-haiku-4-5-20251001"
):
    """
    Use an LLM to re-rank the top-k retrieved sessions.

    Takes the top-k sessions from any retrieval mode and asks the LLM
    which single session is most relevant to the question. That session
    is promoted to rank 1; the rest stay in their existing order.

    This closes the gap for "preference" and jargon-dense "assistant"
    failures where the right session is in top-10 semantically but not
    top-5 — because the semantic gap (battery life ↔ phone hardware) is
    too large for embeddings to bridge.

    The call is best-effort: on any network error, malformed response or
    out-of-range answer the input ``rankings`` is returned unchanged.
    Timeouts (including connect-phase timeouts, which urllib reports as a
    ``URLError`` wrapping the timeout) are retried up to 3 attempts with a
    3 second pause between attempts.

    Args:
        question:   The benchmark question string
        rankings:   Current ranked list of corpus indices (from any mode)
        corpus:     List of document strings
        corpus_ids: List of corpus IDs (parallel to corpus); not read here
        api_key:    Anthropic API key string
        top_k:      How many top sessions to send to LLM (default: 10)
        model:      Claude model ID for reranking (default: haiku)

    Returns:
        Reordered rankings list with LLM's best pick promoted to rank 1.
    """
    import socket
    import time
    import urllib.error
    import urllib.request

    candidates = rankings[:top_k]
    if not candidates:
        return rankings

    # Format sessions for the prompt — first 500 chars each, labelled 1..N
    session_blocks = []
    for rank, idx in enumerate(candidates):
        text = corpus[idx][:500].replace("\n", " ").strip()
        session_blocks.append(f"Session {rank + 1}:\n{text}")

    sessions_text = "\n\n".join(session_blocks)

    prompt = (
        f"Question: {question}\n\n"
        f"Below are {len(candidates)} conversation sessions from someone's memory. "
        f"Which single session is most likely to contain the answer to the question above? "
        f"Reply with ONLY a number between 1 and {len(candidates)}. Nothing else.\n\n"
        f"{sessions_text}\n\n"
        f"Most relevant session number:"
    )

    payload = json.dumps(
        {
            "model": model,
            "max_tokens": 8,
            "messages": [{"role": "user", "content": prompt}],
        }
    ).encode("utf-8")

    req = urllib.request.Request(
        "https://api.anthropic.com/v1/messages",
        data=payload,
        headers={
            "x-api-key": api_key,
            "anthropic-version": "2023-06-01",
            "content-type": "application/json",
        },
        method="POST",
    )

    def _is_timeout(exc):
        # Read timeouts surface directly; connect timeouts arrive wrapped in
        # URLError with the timeout stored on ``reason``.
        if isinstance(exc, (socket.timeout, TimeoutError)):
            return True
        return isinstance(getattr(exc, "reason", None), (socket.timeout, TimeoutError))

    max_attempts = 3
    for attempt in range(max_attempts):
        try:
            with urllib.request.urlopen(req, timeout=20) as resp:
                result = json.loads(resp.read())
            raw = result["content"][0]["text"].strip()
        except (
            urllib.error.URLError,
            OSError,
            KeyError,
            IndexError,
            TypeError,
            AttributeError,
            ValueError,
        ) as exc:
            if _is_timeout(exc) and attempt < max_attempts - 1:
                time.sleep(3)  # brief pause then retry
                continue
            break  # Non-timeout error or retries exhausted — fall back

        # Parse just the first integer from the model's response
        m = re.search(r"\b(\d+)\b", raw)
        if m:
            pick = int(m.group(1))
            if 1 <= pick <= len(candidates):
                chosen_idx = candidates[pick - 1]
                return [chosen_idx] + [i for i in rankings if i != chosen_idx]
        break  # Got a response, even if unparseable — don't retry

    return rankings
|
|||
|
|
|
|||
|
|
|
|||
|
|
def _load_api_key(key_arg):
|
|||
|
|
"""Load API key from --llm-key arg, env var, or ~/.config/lu/keys.json."""
|
|||
|
|
if key_arg:
|
|||
|
|
return key_arg
|
|||
|
|
env_key = os.environ.get("ANTHROPIC_API_KEY", "")
|
|||
|
|
if env_key:
|
|||
|
|
return env_key
|
|||
|
|
keys_path = os.path.expanduser("~/.config/lu/keys.json")
|
|||
|
|
if os.path.exists(keys_path):
|
|||
|
|
try:
|
|||
|
|
with open(keys_path) as f:
|
|||
|
|
keys = json.load(f)
|
|||
|
|
# Flat string keys
|
|||
|
|
for name in ("lu_key", "anthropic_milla", "anthropic_claude_code_main"):
|
|||
|
|
val = keys.get(name, "")
|
|||
|
|
if isinstance(val, str) and val.startswith("sk-ant-"):
|
|||
|
|
return val
|
|||
|
|
# Nested dict: keys["anthropic"]["lu_key"]
|
|||
|
|
for section in ("anthropic", "anthropic_milla", "anthropic_claude_code_main"):
|
|||
|
|
sec = keys.get(section, {})
|
|||
|
|
if isinstance(sec, dict):
|
|||
|
|
for subkey in ("lu_key", "key", "api_key"):
|
|||
|
|
val = sec.get(subkey, "")
|
|||
|
|
if isinstance(val, str) and val.startswith("sk-ant-"):
|
|||
|
|
return val
|
|||
|
|
except Exception:
|
|||
|
|
pass
|
|||
|
|
return ""
|
|||
|
|
|
|||
|
|
|
|||
|
|
# =============================================================================
|
|||
|
|
# BENCHMARK RUNNER
|
|||
|
|
# =============================================================================
|
|||
|
|
|
|||
|
|
|
|||
|
|
def _load_or_create_split(split_file: str, data: list, dev_size: int = 50, seed: int = 42) -> dict:
|
|||
|
|
"""
|
|||
|
|
Load an existing train/test split or create a new one.
|
|||
|
|
|
|||
|
|
Returns {"dev": [question_id, ...], "held_out": [question_id, ...]}
|
|||
|
|
|
|||
|
|
The split is stable: same split_file + same seed = same result.
|
|||
|
|
Creating a split is a one-time operation. After that, always load.
|
|||
|
|
"""
|
|||
|
|
import random
|
|||
|
|
|
|||
|
|
split_path = Path(split_file)
|
|||
|
|
if split_path.exists():
|
|||
|
|
with open(split_path) as f:
|
|||
|
|
return json.load(f)
|
|||
|
|
|
|||
|
|
# Create new split
|
|||
|
|
all_ids = [entry["question_id"] for entry in data]
|
|||
|
|
rng = random.Random(seed)
|
|||
|
|
rng.shuffle(all_ids)
|
|||
|
|
dev_ids = all_ids[:dev_size]
|
|||
|
|
held_out_ids = all_ids[dev_size:]
|
|||
|
|
split = {"dev": dev_ids, "held_out": held_out_ids, "seed": seed, "dev_size": dev_size}
|
|||
|
|
with open(split_path, "w") as f:
|
|||
|
|
json.dump(split, f, indent=2)
|
|||
|
|
print(f" Created new split: {len(dev_ids)} dev / {len(held_out_ids)} held-out → {split_path}")
|
|||
|
|
return split
|
|||
|
|
|
|||
|
|
|
|||
|
|
def run_benchmark(
    data_file,
    granularity="session",
    limit=0,
    out_file=None,
    mode="raw",
    skip=0,
    hybrid_weight=0.30,
    llm_rerank_enabled=False,
    llm_key="",
    llm_model="claude-haiku-4-5-20251001",
    diary_cache_file=None,
    skip_precompute=False,
    split_file=None,
    split_subset=None,
):
    """Run the full benchmark.

    Pipeline: load questions → optional split filter → ``limit`` → ``skip``
    → (diary mode) pre-compute the LLM diary cache → per-question retrieval
    with the selected ``mode`` → optional LLM re-rank → Recall/NDCG at
    k ∈ {1, 3, 5, 10, 30, 50} on session and turn level → printed report →
    optional cache / JSONL output.

    Args:
        data_file: Path to the benchmark JSON (a list of question records).
        granularity: Forwarded to the retrieval function of the chosen mode.
        limit: Keep only the first N questions (0 = all). Applied before skip.
        out_file: If set, one JSON line per evaluated question is written here.
        mode: Retrieval mode; unknown values fall back to the raw retriever.
        skip: Drop the first N questions (resume mode).
        hybrid_weight: Keyword boost weight for hybrid/palace/diary modes.
        llm_rerank_enabled: Run ``llm_rerank`` on each ranking.
        llm_key: Explicit API key; see ``_load_api_key`` for fallbacks.
        llm_model: Model ID used for re-ranking and diary ingest.
        diary_cache_file: JSON file used to load/persist the diary cache.
        skip_precompute: Diary mode only — use the cache as-is and make no
            LLM calls for sessions missing from it.
        split_file: path to a JSON split file. If provided, filters questions by subset.
        split_subset: "dev" (50 questions for tuning) or "held_out" (450 for final evaluation).
                      None = run all questions.

    Exits the process with status 1 when an API key is required (re-rank or
    diary mode) but none can be found. Returns None.
    """

    def _mean(values):
        # Average that tolerates "nothing was evaluated".
        return sum(values) / len(values) if values else 0.0

    def _save_diary_cache(path, cache):
        # Persist only sessions with real data: a None entry marks a failed
        # or skipped LLM call and must stay retryable on the next run.
        # Returns the number of entries written, or None if writing failed.
        real_cache = {k: v for k, v in cache.items() if v is not None}
        try:
            with open(path, "w") as f:
                json.dump(real_cache, f)
        except (OSError, TypeError, ValueError) as e:
            print(f" Warning: could not save diary cache: {e}")
            return None
        return len(real_cache)

    with open(data_file, encoding="utf-8") as f:
        data = json.load(f)

    # Apply train/test split filter before limit/skip
    if split_file and split_subset:
        split = _load_or_create_split(split_file, data)
        subset_ids = set(split[split_subset])
        before = len(data)
        data = [entry for entry in data if entry["question_id"] in subset_ids]
        print(f" Split filter ({split_subset}): {before} → {len(data)} questions")

    if limit > 0:
        data = data[:limit]

    if skip > 0:
        print(f" Skipping first {skip} questions (resume mode)")
        data = data[skip:]

    api_key = ""
    if llm_rerank_enabled or mode == "diary":
        api_key = _load_api_key(llm_key)
        if not api_key:
            print(
                "ERROR: --llm-rerank / --mode diary requires an API key. "
                "Set ANTHROPIC_API_KEY, use --llm-key, "
                "or store in ~/.config/lu/keys.json as 'lu_key'."
            )
            sys.exit(1)

    # Short label for banners, e.g. "haiku"; model IDs without a hyphen are
    # shown unchanged.
    model_short = llm_model.split("-")[1] if "-" in llm_model else llm_model

    # Diary mode: pre-compute LLM topic extraction for ALL unique sessions upfront
    # This means the main benchmark loop reads from cache only — no API calls mid-loop
    diary_cache = {}
    if mode == "diary":
        # Load existing cache first
        if diary_cache_file:
            cache_path = Path(diary_cache_file)
            if cache_path.exists():
                try:
                    with open(cache_path) as f:
                        loaded = json.load(f)
                except (OSError, ValueError) as e:
                    print(f" Warning: could not load diary cache: {e}")
                else:
                    if isinstance(loaded, dict):
                        diary_cache = loaded
                        print(
                            f" Diary cache: loaded {len(diary_cache)} sessions from {cache_path.name}"
                        )
                    else:
                        print(" Warning: diary cache is not a JSON object; ignoring it")

        # Collect all unique sessions not yet in cache
        unique_sessions = {}  # sess_id → session turns
        for entry in data:
            for session, sess_id in zip(entry["haystack_sessions"], entry["haystack_session_ids"]):
                if sess_id not in diary_cache and sess_id not in unique_sessions:
                    unique_sessions[sess_id] = session

        if unique_sessions and api_key and not skip_precompute:
            print(
                f" Diary ingest: pre-computing {len(unique_sessions)} sessions with {model_short}..."
            )
            done = 0
            cache_path = Path(diary_cache_file) if diary_cache_file else None
            for sess_id, session in unique_sessions.items():
                try:
                    result = diary_ingest_session(session, sess_id, api_key, model=llm_model)
                except Exception:
                    # One bad session must not abort the whole pre-compute.
                    result = None
                diary_cache[sess_id] = result
                done += 1
                if done % 50 == 0:
                    print(f" {done}/{len(unique_sessions)} sessions ingested...")
                    # Save progress in case of interruption
                    if cache_path:
                        _save_diary_cache(cache_path, diary_cache)
            print(f" Diary ingest complete: {done} sessions processed")
            # Final cache save
            if cache_path and _save_diary_cache(cache_path, diary_cache) is not None:
                print(f" Diary cache saved → {cache_path.name}")

    print(f"\n{'=' * 60}")
    print(" MemPal × LongMemEval Benchmark")
    print(f"{'=' * 60}")
    print(f" Data: {Path(data_file).name}")
    print(f" Questions: {len(data)}")
    print(f" Granularity: {granularity}")
    rerank_label = f" + LLM re-rank ({model_short})" if llm_rerank_enabled else ""
    diary_label = f" [diary ingest: {model_short}]" if mode == "diary" else ""
    print(f" Mode: {mode}{diary_label}{rerank_label}")
    print(f"{'─' * 60}\n")

    # Collect metrics
    ks = [1, 3, 5, 10, 30, 50]
    metrics_session = {f"recall_any@{k}": [] for k in ks}
    metrics_session.update({f"recall_all@{k}": [] for k in ks})
    metrics_session.update({f"ndcg_any@{k}": [] for k in ks})

    metrics_turn = {f"recall_any@{k}": [] for k in ks}
    metrics_turn.update({f"recall_all@{k}": [] for k in ks})
    metrics_turn.update({f"ndcg_any@{k}": [] for k in ks})

    per_type = defaultdict(lambda: defaultdict(list))

    results_log = []
    start_time = datetime.now()

    for i, entry in enumerate(data):
        qid = entry["question_id"]
        qtype = entry["question_type"]
        question = entry["question"]
        answer_sids = set(entry["answer_session_ids"])

        # Run retrieval with selected mode
        if mode == "aaak":
            rankings, corpus, corpus_ids, corpus_timestamps = build_palace_and_retrieve_aaak(
                entry, granularity=granularity
            )
        elif mode == "rooms":
            rankings, corpus, corpus_ids, corpus_timestamps = build_palace_and_retrieve_rooms(
                entry, granularity=granularity
            )
        elif mode == "hybrid":
            rankings, corpus, corpus_ids, corpus_timestamps = build_palace_and_retrieve_hybrid(
                entry, granularity=granularity, hybrid_weight=hybrid_weight
            )
        elif mode == "hybrid_v2":
            rankings, corpus, corpus_ids, corpus_timestamps = build_palace_and_retrieve_hybrid_v2(
                entry, granularity=granularity, hybrid_weight=hybrid_weight
            )
        elif mode == "hybrid_v3":
            rankings, corpus, corpus_ids, corpus_timestamps = build_palace_and_retrieve_hybrid_v3(
                entry, granularity=granularity, hybrid_weight=hybrid_weight
            )
        elif mode == "hybrid_v4":
            rankings, corpus, corpus_ids, corpus_timestamps = build_palace_and_retrieve_hybrid_v4(
                entry, granularity=granularity, hybrid_weight=hybrid_weight
            )
        elif mode == "palace":
            rankings, corpus, corpus_ids, corpus_timestamps = build_palace_and_retrieve_palace(
                entry, granularity=granularity, hybrid_weight=hybrid_weight
            )
        elif mode == "diary":
            # If skip_precompute, pass empty api_key to prevent inline Haiku calls
            _diary_api_key = "" if skip_precompute else api_key
            rankings, corpus, corpus_ids, corpus_timestamps = build_palace_and_retrieve_diary(
                entry,
                granularity=granularity,
                hybrid_weight=hybrid_weight,
                diary_cache=diary_cache,
                api_key=_diary_api_key,
                diary_model=llm_model,
            )
        elif mode == "full":
            rankings, corpus, corpus_ids, corpus_timestamps = build_palace_and_retrieve_full(
                entry, granularity=granularity
            )
        else:
            rankings, corpus, corpus_ids, corpus_timestamps = build_palace_and_retrieve(
                entry, granularity=granularity
            )

        if not rankings:
            print(f" [{i + 1:4}/{len(data)}] {qid[:30]:30} SKIP (empty corpus)")
            continue

        # Optional LLM re-ranking pass (larger pool for v3/palace to catch rank-11-12 misses)
        if llm_rerank_enabled:
            rerank_pool = 20 if mode in ("hybrid_v3", "hybrid_v4", "palace") else 10
            rankings = llm_rerank(
                question, rankings, corpus, corpus_ids, api_key, top_k=rerank_pool, model=llm_model
            )

        # Evaluate at session level
        # Map corpus_ids to session-level IDs for session metrics
        session_level_ids = [session_id_from_corpus_id(cid) for cid in corpus_ids]
        session_correct = answer_sids

        # Turn-level correct: any corpus_id whose session part is in answer_sids
        turn_correct = {
            cid for cid in corpus_ids if session_id_from_corpus_id(cid) in answer_sids
        }

        entry_metrics = {"session": {}, "turn": {}}

        for k in ks:
            # Session-level metrics
            ra, rl, nd = evaluate_retrieval(rankings, session_correct, session_level_ids, k)
            metrics_session[f"recall_any@{k}"].append(ra)
            metrics_session[f"recall_all@{k}"].append(rl)
            metrics_session[f"ndcg_any@{k}"].append(nd)
            entry_metrics["session"][f"recall_any@{k}"] = ra
            entry_metrics["session"][f"ndcg_any@{k}"] = nd

            # Turn-level metrics
            ra_t, rl_t, nd_t = evaluate_retrieval(rankings, turn_correct, corpus_ids, k)
            metrics_turn[f"recall_any@{k}"].append(ra_t)
            metrics_turn[f"recall_all@{k}"].append(rl_t)
            metrics_turn[f"ndcg_any@{k}"].append(nd_t)
            entry_metrics["turn"][f"recall_any@{k}"] = ra_t

        # Per-type tracking (values just appended for this question)
        per_type[qtype]["recall_any@5"].append(metrics_session["recall_any@5"][-1])
        per_type[qtype]["recall_any@10"].append(metrics_session["recall_any@10"][-1])
        per_type[qtype]["ndcg_any@10"].append(metrics_session["ndcg_any@10"][-1])

        # Log entry
        ranked_items = [
            {
                "corpus_id": corpus_ids[idx],
                "text": corpus[idx][:500],
                "timestamp": corpus_timestamps[idx],
            }
            for idx in rankings[:50]
        ]

        results_log.append(
            {
                "question_id": qid,
                "question_type": qtype,
                "question": question,
                "answer": entry["answer"],
                "retrieval_results": {
                    "query": question,
                    "ranked_items": ranked_items,
                    "metrics": entry_metrics,
                },
            }
        )

        # Progress
        r5 = metrics_session["recall_any@5"][-1]
        r10 = metrics_session["recall_any@10"][-1]
        status = "HIT" if r5 > 0 else "miss"
        print(f" [{i + 1:4}/{len(data)}] {qid[:30]:30} R@5={r5:.0f} R@10={r10:.0f} {status}")

    elapsed = (datetime.now() - start_time).total_seconds()
    # Guard against an empty question list (e.g. --skip past the end).
    per_question = elapsed / len(data) if data else 0.0

    # Print results
    print(f"\n{'=' * 60}")
    print(f" RESULTS — MemPal ({mode} mode, {granularity} granularity)")
    print(f"{'=' * 60}")
    print(f" Time: {elapsed:.1f}s ({per_question:.2f}s per question)\n")

    if not results_log:
        # Every question was filtered out or skipped for an empty corpus;
        # averages would be undefined.
        print(" No questions were evaluated — no metrics to report.")
    else:
        print(" SESSION-LEVEL METRICS:")
        for k in ks:
            ra = _mean(metrics_session[f"recall_any@{k}"])
            nd = _mean(metrics_session[f"ndcg_any@{k}"])
            print(f" Recall@{k:2}: {ra:.3f} NDCG@{k:2}: {nd:.3f}")

        print("\n TURN-LEVEL METRICS:")
        for k in ks:
            ra = _mean(metrics_turn[f"recall_any@{k}"])
            nd = _mean(metrics_turn[f"ndcg_any@{k}"])
            print(f" Recall@{k:2}: {ra:.3f} NDCG@{k:2}: {nd:.3f}")

        print("\n PER-TYPE BREAKDOWN (session recall_any@10):")
        for qtype, vals in sorted(per_type.items()):
            r10 = _mean(vals["recall_any@10"])
            n = len(vals["recall_any@10"])
            print(f" {qtype:35} R@10={r10:.3f} (n={n})")

    print(f"\n{'=' * 60}\n")

    # Save diary cache for reuse (Sonnet run tomorrow can skip re-ingesting)
    # Only save sessions with real data (None = skipped inline call, not worth persisting)
    if mode == "diary" and diary_cache and diary_cache_file:
        saved = _save_diary_cache(diary_cache_file, diary_cache)
        if saved is not None:
            print(f" Diary cache saved: {saved} sessions → {diary_cache_file}")

    # Save results
    if out_file:
        with open(out_file, "w") as f:
            for entry in results_log:
                f.write(json.dumps(entry) + "\n")
        print(f" Results saved to: {out_file}")
|
|||
|
|
|
|||
|
|
|
|||
|
|
# =============================================================================
|
|||
|
|
# CLI
|
|||
|
|
# =============================================================================
|
|||
|
|
|
|||
|
|
if __name__ == "__main__":
    # Command-line entry point: parse arguments, handle the split-management
    # flags, derive a default output path, optionally override the embedding
    # function, then hand everything to run_benchmark().
    parser = argparse.ArgumentParser(description="MemPal × LongMemEval Benchmark")
    parser.add_argument("data_file", help="Path to longmemeval_s_cleaned.json")
    parser.add_argument(
        "--granularity",
        choices=["session", "turn"],
        default="session",
        help="Retrieval granularity (default: session)",
    )
    parser.add_argument("--limit", type=int, default=0, help="Limit to N questions (0 = all)")

    # Single source of truth for the accepted modes; the help text is built
    # from this list so it cannot drift out of sync with `choices`.
    _mode_choices = [
        "raw",
        "aaak",
        "rooms",
        "hybrid",
        "hybrid_v2",
        "hybrid_v3",
        "hybrid_v4",
        "palace",
        "diary",
        "full",
    ]
    parser.add_argument(
        "--mode",
        choices=_mode_choices,
        default="raw",
        help="Retrieval mode (default: raw). One of: "
        + ", ".join(_mode_choices)
        + ". 'diary' = palace + LLM topic layer.",
    )
    parser.add_argument("--out", default=None, help="Output JSONL file path")
    parser.add_argument(
        "--skip", type=int, default=0, help="Skip first N questions (resume after hang)"
    )
    parser.add_argument(
        "--hybrid-weight",
        type=float,
        default=0.30,
        help="Keyword overlap boost weight for hybrid mode (default: 0.30). "
        "Full 500q tuning: 0.30 and 0.40 are equivalent (within noise). Try 0.10–0.60.",
    )
    parser.add_argument(
        "--llm-rerank",
        action="store_true",
        default=False,
        help="Enable LLM re-ranking pass using Claude Haiku (requires API key). "
        "Promotes the best session from top-10 to rank 1. Targets preference "
        "and jargon-dense failures that embeddings can't bridge semantically.",
    )
    parser.add_argument(
        "--llm-key",
        default="",
        help="Anthropic API key for LLM re-ranking. Falls back to ANTHROPIC_API_KEY "
        "env var or ~/.config/lu/keys.json 'lu_key' field if not provided.",
    )
    parser.add_argument(
        "--llm-model",
        default="claude-haiku-4-5-20251001",
        help="Model for LLM re-ranking and diary ingest "
        "(default: claude-haiku-4-5-20251001). "
        "Use 'claude-sonnet-4-6' for Sonnet comparison.",
    )
    parser.add_argument(
        "--diary-cache",
        default=None,
        help="Path to save/load diary ingest cache (JSON). "
        "Saves Haiku calls on re-runs. Sonnet run can reuse Haiku cache.",
    )
    parser.add_argument(
        "--skip-precompute",
        action="store_true",
        default=False,
        help="Skip diary pre-computation for sessions not in cache. "
        "Uses cache as-is; uncached sessions fall back to palace-only retrieval.",
    )
    parser.add_argument(
        "--embed-model",
        choices=["default", "bge-base", "bge-large", "nomic", "mxbai"],
        default="default",
        help="Embedding model. 'default'=all-MiniLM-L6-v2 (ChromaDB built-in, baseline). "
        "'bge-large'=BAAI/bge-large-en-v1.5 (best local, 1024-dim, ~1.3GB via fastembed). "
        "'nomic'=nomic-embed-text-v1.5 (768-dim, fast, ~274MB). "
        "'bge-base'=BAAI/bge-base-en-v1.5 (768-dim, balanced). "
        "'mxbai'=mxbai-embed-large-v1 (1024-dim). Requires: pip install fastembed.",
    )

    # ── Train / test split ──────────────────────────────────────────────────
    parser.add_argument(
        "--split-file",
        default=None,
        help="Path to a JSON split file. "
        "Use --create-split to generate one (50 dev / 450 held-out). "
        "Required when using --dev-only or --held-out.",
    )
    parser.add_argument(
        "--create-split",
        action="store_true",
        default=False,
        help="Create a new random 50/450 dev/held-out split and exit. "
        "Pass --split-file to specify where to save it.",
    )
    parser.add_argument(
        "--dev-only",
        action="store_true",
        default=False,
        help="Run only the 50 dev questions (safe for iterative tuning). Requires --split-file.",
    )
    parser.add_argument(
        "--held-out",
        action="store_true",
        default=False,
        help="Run only the 450 held-out questions (publishable final score). "
        "Use sparingly — looking at results contaminates the held-out set. "
        "Requires --split-file.",
    )

    args = parser.parse_args()

    # ── Handle --create-split ───────────────────────────────────────────────
    # Split creation is a standalone action: it never runs the benchmark.
    if args.create_split:
        if not args.split_file:
            args.split_file = "benchmarks/lme_split_50_450.json"
        # JSON is UTF-8 by specification; do not depend on the platform's
        # default text encoding (which is not UTF-8 on every OS).
        with open(args.data_file, encoding="utf-8") as f:
            _all_data = json.load(f)
        _load_or_create_split(args.split_file, _all_data)
        sys.exit(0)

    # ── Validate split flags ────────────────────────────────────────────────
    if (args.dev_only or args.held_out) and not args.split_file:
        parser.error(
            "--dev-only / --held-out require --split-file. "
            "Run with --create-split first to generate a split."
        )
    if args.dev_only and args.held_out:
        parser.error("--dev-only and --held-out are mutually exclusive.")

    split_subset = "dev" if args.dev_only else ("held_out" if args.held_out else None)

    # Derive a timestamped output path that encodes the run configuration
    # when the caller did not supply --out.
    if not args.out:
        embed_tag = f"_{args.embed_model}" if args.embed_model != "default" else ""
        suffix = "_llmrerank" if args.llm_rerank else ""
        subset_tag = f"_{split_subset}" if split_subset else ""
        timestamp = datetime.now().strftime("%Y%m%d_%H%M")
        args.out = (
            f"benchmarks/results_mempal_{args.mode}{embed_tag}{suffix}{subset_tag}"
            f"_{args.granularity}_{timestamp}.jsonl"
        )

    # Set global embedding function before running. The attribute is set on
    # this module object so code that reads the module-level `_bench_embed_fn`
    # picks up the override.
    if args.embed_model != "default":
        sys.modules[__name__]._bench_embed_fn = _make_embed_fn(args.embed_model)

    # Positional order is kept exactly as run_benchmark() is called upstream.
    run_benchmark(
        args.data_file,
        args.granularity,
        args.limit,
        args.out,
        args.mode,
        args.skip,
        args.hybrid_weight,
        args.llm_rerank,
        args.llm_key,
        args.llm_model,
        args.diary_cache,
        args.skip_precompute,
        split_file=args.split_file,
        split_subset=split_subset,
    )
|