benchmarks: add --llm-backend ollama for non-Anthropic rerank

The rerank pipeline was hardcoded to Anthropic's /v1/messages.
Add a backend flag so the same code path can be exercised with
any OpenAI-compatible endpoint — local Ollama, Ollama Cloud,
or any gateway that speaks /v1/chat/completions.

Enables independent verification of the "100% with Haiku rerank"
claim by running the full benchmark with a different LLM family
(e.g. minimax-m2.7:cloud) and zero Anthropic dependency.

Both longmemeval_bench.py and locomo_bench.py:
 - llm_rerank*() gain backend= / base_url= kwargs
 - CLI: --llm-backend {anthropic,ollama}, --llm-base-url
 - API key required for rerank only when backend=anthropic; diary/palace modes always require it, regardless of backend
 - Parse the last integer in the response instead of the first (reasoning models often emit several integers before the final answer)
 - Fall back to message.reasoning when content is empty
 - Raise max_tokens to 1024 for reasoning models
This commit is contained in:
Igor Lins e Silva
2026-04-14 21:20:14 -03:00
parent 4aa7e1eebd
commit 8df7b9bf2c
3 changed files with 169 additions and 66 deletions
+61 -15
View File
@@ -510,11 +510,20 @@ def palace_assign_rooms(sessions, sample_id, api_key, cache, model="claude-haiku
def llm_rerank_locomo( def llm_rerank_locomo(
question, retrieved_ids, retrieved_docs, api_key, top_k=10, model="claude-sonnet-4-6" question,
retrieved_ids,
retrieved_docs,
api_key,
top_k=10,
model="claude-sonnet-4-6",
backend="anthropic",
base_url="",
): ):
""" """
Ask LLM to pick the single most relevant document for this question. Ask LLM to pick the single most relevant document for this question.
Returns reordered retrieved_ids with the best candidate first. Returns reordered retrieved_ids with the best candidate first.
Supports backend="anthropic" (default) or "ollama" (OpenAI-compat endpoint).
""" """
candidates = retrieved_ids[:top_k] candidates = retrieved_ids[:top_k]
candidate_docs = retrieved_docs[:top_k] candidate_docs = retrieved_docs[:top_k]
@@ -522,7 +531,6 @@ def llm_rerank_locomo(
if len(candidates) <= 1: if len(candidates) <= 1:
return retrieved_ids return retrieved_ids
# Build numbered list of candidates
lines = [] lines = []
for i, (cid, doc) in enumerate(zip(candidates, candidate_docs), 1): for i, (cid, doc) in enumerate(zip(candidates, candidate_docs), 1):
snippet = doc[:300].replace("\n", " ") snippet = doc[:300].replace("\n", " ")
@@ -534,6 +542,21 @@ def llm_rerank_locomo(
f"Reply with just the number (1-{len(candidates)}).\n\n" + "\n".join(lines) f"Reply with just the number (1-{len(candidates)}).\n\n" + "\n".join(lines)
) )
if backend == "ollama":
url = (base_url or "http://localhost:11434").rstrip("/") + "/v1/chat/completions"
payload = json.dumps(
{
"model": model,
"messages": [{"role": "user", "content": prompt}],
"max_tokens": 1024,
"temperature": 0.0,
}
).encode("utf-8")
headers = {"content-type": "application/json"}
if api_key:
headers["authorization"] = f"Bearer {api_key}"
else:
url = "https://api.anthropic.com/v1/messages"
payload = json.dumps( payload = json.dumps(
{ {
"model": model, "model": model,
@@ -541,28 +564,29 @@ def llm_rerank_locomo(
"messages": [{"role": "user", "content": prompt}], "messages": [{"role": "user", "content": prompt}],
} }
).encode("utf-8") ).encode("utf-8")
headers = {
req = urllib.request.Request(
"https://api.anthropic.com/v1/messages",
data=payload,
headers={
"x-api-key": api_key, "x-api-key": api_key,
"anthropic-version": "2023-06-01", "anthropic-version": "2023-06-01",
"content-type": "application/json", "content-type": "application/json",
}, }
method="POST",
) req = urllib.request.Request(url, data=payload, headers=headers, method="POST")
import socket as _socket import socket as _socket
for _attempt in range(3): for _attempt in range(3):
try: try:
with urllib.request.urlopen(req, timeout=30) as resp: with urllib.request.urlopen(req, timeout=120 if backend == "ollama" else 30) as resp:
result = json.loads(resp.read()) result = json.loads(resp.read())
if backend == "ollama":
msg = result["choices"][0]["message"]
raw = (msg.get("content") or "").strip() or (msg.get("reasoning") or "").strip()
else:
raw = result["content"][0]["text"].strip() raw = result["content"][0]["text"].strip()
m = re.search(r"\b(\d+)\b", raw) # Take LAST integer — reasoning models often count candidates first
m = re.search(r"\b(\d+)\b", raw[::-1])
if m: if m:
pick = int(m.group(1)) pick = int(m.group(1)[::-1])
if 1 <= pick <= len(candidates): if 1 <= pick <= len(candidates):
chosen_id = candidates[pick - 1] chosen_id = candidates[pick - 1]
reordered = [chosen_id] + [cid for cid in retrieved_ids if cid != chosen_id] reordered = [chosen_id] + [cid for cid in retrieved_ids if cid != chosen_id]
@@ -608,6 +632,8 @@ def run_benchmark(
palace_cache_file=None, palace_cache_file=None,
palace_model="claude-haiku-4-5-20251001", palace_model="claude-haiku-4-5-20251001",
embed_model="default", embed_model="default",
llm_backend="anthropic",
llm_base_url="",
): ):
"""Run LoCoMo retrieval benchmark.""" """Run LoCoMo retrieval benchmark."""
with open(data_file) as f: with open(data_file) as f:
@@ -619,8 +645,12 @@ def run_benchmark(
api_key = "" api_key = ""
if llm_rerank_enabled or mode == "palace": if llm_rerank_enabled or mode == "palace":
api_key = _load_api_key(llm_key) api_key = _load_api_key(llm_key)
if not api_key: # Ollama backend doesn't require an Anthropic key. Palace mode still does
print(f"ERROR: --mode {mode} requires an API key (--llm-key or ANTHROPIC_API_KEY).") # (it uses Anthropic for room-assignment indexing) — so only relax the
# requirement when rerank is the ONLY llm use and backend is ollama.
needs_key = mode == "palace" or (llm_rerank_enabled and llm_backend == "anthropic")
if needs_key and not api_key:
print(f"ERROR: --mode {mode} / --llm-rerank (anthropic) requires an API key.")
sys.exit(1) sys.exit(1)
# Palace mode: load or create room assignment cache # Palace mode: load or create room assignment cache
@@ -888,6 +918,8 @@ def run_benchmark(
api_key, api_key,
top_k=rerank_pool, top_k=rerank_pool,
model=llm_model, model=llm_model,
backend=llm_backend,
base_url=llm_base_url,
) )
# Compute recall # Compute recall
@@ -1013,6 +1045,18 @@ if __name__ == "__main__":
help="Model for LLM rerank (default: claude-sonnet-4-6)", help="Model for LLM rerank (default: claude-sonnet-4-6)",
) )
parser.add_argument("--llm-key", default="", help="API key (or set ANTHROPIC_API_KEY env var)") parser.add_argument("--llm-key", default="", help="API key (or set ANTHROPIC_API_KEY env var)")
parser.add_argument(
"--llm-backend",
choices=["anthropic", "ollama"],
default="anthropic",
help="Which API for --llm-rerank. 'anthropic' (default) or 'ollama' "
"(OpenAI-compat /v1/chat/completions — works for local + Ollama Cloud).",
)
parser.add_argument(
"--llm-base-url",
default="",
help="Override base URL for --llm-backend ollama. Default: http://localhost:11434.",
)
parser.add_argument( parser.add_argument(
"--hybrid-weight", "--hybrid-weight",
type=float, type=float,
@@ -1049,4 +1093,6 @@ if __name__ == "__main__":
palace_cache_file=args.palace_cache, palace_cache_file=args.palace_cache,
palace_model=args.palace_model, palace_model=args.palace_model,
embed_model=args.embed_model, embed_model=args.embed_model,
llm_backend=args.llm_backend,
llm_base_url=args.llm_base_url,
) )
+86 -29
View File
@@ -2763,7 +2763,15 @@ def build_palace_and_retrieve_diary(
def llm_rerank( def llm_rerank(
question, rankings, corpus, corpus_ids, api_key, top_k=10, model="claude-haiku-4-5-20251001" question,
rankings,
corpus,
corpus_ids,
api_key,
top_k=10,
model="claude-haiku-4-5-20251001",
backend="anthropic",
base_url="",
): ):
""" """
Use an LLM to re-rank the top-k retrieved sessions. Use an LLM to re-rank the top-k retrieved sessions.
@@ -2772,19 +2780,22 @@ def llm_rerank(
which single session is most relevant to the question. That session which single session is most relevant to the question. That session
is promoted to rank 1; the rest stay in their existing order. is promoted to rank 1; the rest stay in their existing order.
This closes the gap for "preference" and jargon-dense "assistant" Supports two backends:
failures where the right session is in top-10 semantically but not - "anthropic": hits https://api.anthropic.com/v1/messages with x-api-key.
top-5 — because the semantic gap (battery life ↔ phone hardware) is - "ollama": hits {base_url}/v1/chat/completions (OpenAI-compat) —
too large for embeddings to bridge. works for local Ollama (default http://localhost:11434)
and Ollama Cloud (:cloud model tags).
Args: Args:
question: The benchmark question string question: The benchmark question string
rankings: Current ranked list of corpus indices (from any mode) rankings: Current ranked list of corpus indices (from any mode)
corpus: List of document strings corpus: List of document strings
corpus_ids: List of corpus IDs (parallel to corpus) corpus_ids: List of corpus IDs (parallel to corpus)
api_key: Anthropic API key string api_key: Anthropic API key (only required for backend="anthropic")
top_k: How many top sessions to send to LLM (default: 10) top_k: How many top sessions to send to LLM (default: 10)
model: Claude model ID for reranking (default: haiku) model: Model id (Claude model for anthropic, e.g. "minimax-m2.7:cloud" for ollama)
backend: "anthropic" or "ollama"
base_url: Override base URL (ollama default: http://localhost:11434)
Returns: Returns:
Reordered rankings list with LLM's best pick promoted to rank 1. Reordered rankings list with LLM's best pick promoted to rank 1.
@@ -2796,7 +2807,6 @@ def llm_rerank(
if not candidates: if not candidates:
return rankings return rankings
# Format sessions for the prompt — first 500 chars each, labelled 1..N
session_blocks = [] session_blocks = []
for rank, idx in enumerate(candidates): for rank, idx in enumerate(candidates):
text = corpus[idx][:500].replace("\n", " ").strip() text = corpus[idx][:500].replace("\n", " ").strip()
@@ -2813,6 +2823,21 @@ def llm_rerank(
f"Most relevant session number:" f"Most relevant session number:"
) )
if backend == "ollama":
url = (base_url or "http://localhost:11434").rstrip("/") + "/v1/chat/completions"
payload = json.dumps(
{
"model": model,
"messages": [{"role": "user", "content": prompt}],
"max_tokens": 1024,
"temperature": 0.0,
}
).encode("utf-8")
headers = {"content-type": "application/json"}
if api_key:
headers["authorization"] = f"Bearer {api_key}"
else:
url = "https://api.anthropic.com/v1/messages"
payload = json.dumps( payload = json.dumps(
{ {
"model": model, "model": model,
@@ -2820,42 +2845,44 @@ def llm_rerank(
"messages": [{"role": "user", "content": prompt}], "messages": [{"role": "user", "content": prompt}],
} }
).encode("utf-8") ).encode("utf-8")
headers = {
req = urllib.request.Request(
"https://api.anthropic.com/v1/messages",
data=payload,
headers={
"x-api-key": api_key, "x-api-key": api_key,
"anthropic-version": "2023-06-01", "anthropic-version": "2023-06-01",
"content-type": "application/json", "content-type": "application/json",
}, }
method="POST",
) req = urllib.request.Request(url, data=payload, headers=headers, method="POST")
import socket as _socket import socket as _socket
for _attempt in range(3): for _attempt in range(3):
try: try:
with urllib.request.urlopen(req, timeout=20) as resp: with urllib.request.urlopen(req, timeout=120 if backend == "ollama" else 20) as resp:
result = json.loads(resp.read()) result = json.loads(resp.read())
if backend == "ollama":
msg = result["choices"][0]["message"]
# Reasoning models (e.g. minimax-m2.7) may emit final answer in "content"
# or embed it in "reasoning". Try content first, fall back to reasoning.
raw = (msg.get("content") or "").strip()
if not raw:
raw = (msg.get("reasoning") or "").strip()
else:
raw = result["content"][0]["text"].strip() raw = result["content"][0]["text"].strip()
# Parse just the first integer from Haiku's response m = re.search(r"\b(\d+)\b", raw[::-1]) # take LAST integer (rerank models often reason first)
m = re.search(r"\b(\d+)\b", raw)
if m: if m:
pick = int(m.group(1)) pick = int(m.group(1)[::-1])
if 1 <= pick <= len(candidates): if 1 <= pick <= len(candidates):
chosen_idx = candidates[pick - 1] chosen_idx = candidates[pick - 1]
reordered = [chosen_idx] + [i for i in rankings if i != chosen_idx] reordered = [chosen_idx] + [i for i in rankings if i != chosen_idx]
return reordered return reordered
break # Got a response, even if unparseable — don't retry break
except (_socket.timeout, TimeoutError): except (_socket.timeout, TimeoutError):
if _attempt < 2: if _attempt < 2:
import time as _time import time as _time
_time.sleep(3) # brief pause then retry _time.sleep(3)
# else fall through to return rankings
except (urllib.error.URLError, KeyError, ValueError, IndexError, OSError): except (urllib.error.URLError, KeyError, ValueError, IndexError, OSError):
break # Non-timeout error — fall back immediately break
return rankings return rankings
@@ -2919,6 +2946,8 @@ def run_benchmark(
skip_precompute=False, skip_precompute=False,
split_file=None, split_file=None,
split_subset=None, split_subset=None,
llm_backend="anthropic",
llm_base_url="",
): ):
"""Run the full benchmark. """Run the full benchmark.
@@ -2947,10 +2976,14 @@ def run_benchmark(
api_key = "" api_key = ""
if llm_rerank_enabled or mode == "diary": if llm_rerank_enabled or mode == "diary":
api_key = _load_api_key(llm_key) api_key = _load_api_key(llm_key)
if not api_key: # Ollama backend doesn't require an Anthropic API key; a local/cloud Ollama
# daemon with the requested model pulled is enough. Diary mode is always anthropic.
needs_key = (llm_backend == "anthropic") or (mode == "diary")
if needs_key and not api_key:
print( print(
"ERROR: --llm-rerank / --mode diary requires an API key. " "ERROR: --llm-rerank (anthropic backend) / --mode diary requires an API key. "
"Set ANTHROPIC_API_KEY or use --llm-key." "Set ANTHROPIC_API_KEY or use --llm-key. For ollama backend, pass "
"--llm-backend ollama."
) )
sys.exit(1) sys.exit(1)
@@ -3100,7 +3133,15 @@ def run_benchmark(
if llm_rerank_enabled: if llm_rerank_enabled:
rerank_pool = 20 if mode in ("hybrid_v3", "hybrid_v4", "palace") else 10 rerank_pool = 20 if mode in ("hybrid_v3", "hybrid_v4", "palace") else 10
rankings = llm_rerank( rankings = llm_rerank(
question, rankings, corpus, corpus_ids, api_key, top_k=rerank_pool, model=llm_model question,
rankings,
corpus,
corpus_ids,
api_key,
top_k=rerank_pool,
model=llm_model,
backend=llm_backend,
base_url=llm_base_url,
) )
# Evaluate at session level # Evaluate at session level
@@ -3276,7 +3317,21 @@ if __name__ == "__main__":
default="claude-haiku-4-5-20251001", default="claude-haiku-4-5-20251001",
help="Model for LLM re-ranking and diary ingest " help="Model for LLM re-ranking and diary ingest "
"(default: claude-haiku-4-5-20251001). " "(default: claude-haiku-4-5-20251001). "
"Use 'claude-sonnet-4-6' for Sonnet comparison.", "Use 'claude-sonnet-4-6' for Sonnet comparison. "
"With --llm-backend ollama, use an Ollama model tag like 'minimax-m2.7:cloud'.",
)
parser.add_argument(
"--llm-backend",
choices=["anthropic", "ollama"],
default="anthropic",
help="Which API to hit for --llm-rerank. 'anthropic' (default) uses Anthropic's "
"/v1/messages endpoint. 'ollama' uses Ollama's OpenAI-compatible "
"/v1/chat/completions endpoint (works with local Ollama and Ollama Cloud).",
)
parser.add_argument(
"--llm-base-url",
default="",
help="Override base URL for --llm-backend ollama. Defaults to http://localhost:11434.",
) )
parser.add_argument( parser.add_argument(
"--diary-cache", "--diary-cache",
@@ -3380,4 +3435,6 @@ if __name__ == "__main__":
args.skip_precompute, args.skip_precompute,
split_file=args.split_file, split_file=args.split_file,
split_subset=split_subset, split_subset=split_subset,
llm_backend=args.llm_backend,
llm_base_url=args.llm_base_url,
) )
Generated
+1 -1
View File
@@ -1239,7 +1239,7 @@ dev = [
[package.metadata] [package.metadata]
requires-dist = [ requires-dist = [
{ name = "autocorrect", marker = "extra == 'spellcheck'", specifier = ">=2.0" }, { name = "autocorrect", marker = "extra == 'spellcheck'", specifier = ">=2.0" },
{ name = "chromadb", specifier = ">=0.5.0,<0.7" }, { name = "chromadb", specifier = ">=0.5.0" },
{ name = "psutil", marker = "extra == 'dev'", specifier = ">=5.9" }, { name = "psutil", marker = "extra == 'dev'", specifier = ">=5.9" },
{ name = "pytest", marker = "extra == 'dev'", specifier = ">=7.0" }, { name = "pytest", marker = "extra == 'dev'", specifier = ">=7.0" },
{ name = "pytest-cov", marker = "extra == 'dev'", specifier = ">=4.0" }, { name = "pytest-cov", marker = "extra == 'dev'", specifier = ">=4.0" },