fix(chroma): harden HNSW startup preflight

This commit is contained in:
Mika Cohen
2026-04-30 09:31:32 -06:00
parent 5e540da06b
commit c3e1104e75
2 changed files with 395 additions and 19 deletions
+279 -4
View File
@@ -1,4 +1,5 @@
import os
import pickle
import sqlite3
from pathlib import Path
@@ -18,6 +19,7 @@ from mempalace.backends.chroma import (
ChromaCollection,
_fix_blob_seq_ids,
_pin_hnsw_threads,
quarantine_invalid_hnsw_metadata,
quarantine_stale_hnsw,
)
@@ -708,7 +710,10 @@ def test_make_client_quarantines_only_on_first_call_per_palace(tmp_path, monkeyp
"""Quarantine fires on first ``make_client()`` for a palace, then is
skipped on subsequent calls — prevents runtime thrash where a daemon's
own steady writes bump ``chroma.sqlite3`` faster than HNSW flushes,
making the mtime heuristic falsely trigger every reconnect."""
making the mtime heuristic falsely trigger every reconnect.
Invalid metadata quarantine shares the same cold-start gate here; the
more aggressive refresh path lives in ``_client()``."""
from mempalace.backends.chroma import ChromaBackend
palace_path = str(tmp_path / "palace")
@@ -730,9 +735,37 @@ def test_make_client_quarantines_only_on_first_call_per_palace(tmp_path, monkeyp
ChromaBackend.make_client(palace_path)
ChromaBackend.make_client(palace_path)
assert calls == [
palace_path
], "quarantine_stale_hnsw should fire once per palace per process, not on every reconnect"
assert calls == [palace_path], (
"quarantine_stale_hnsw should fire once per palace per process, not on every reconnect"
)
def test_make_client_gates_invalid_metadata_on_first_call(tmp_path, monkeypatch):
"""Invalid metadata quarantine is gated on the first make_client() call."""
from mempalace.backends.chroma import ChromaBackend
palace_path = str(tmp_path / "palace")
os.makedirs(palace_path, exist_ok=True)
(Path(palace_path) / "chroma.sqlite3").write_text("")
monkeypatch.setattr(ChromaBackend, "_quarantined_paths", set())
calls: list[str] = []
def _invalid(path, *args, **kwargs):
calls.append(path)
return []
def _stale(path, stale_seconds=300.0):
return []
monkeypatch.setattr("mempalace.backends.chroma.quarantine_invalid_hnsw_metadata", _invalid)
monkeypatch.setattr("mempalace.backends.chroma.quarantine_stale_hnsw", _stale)
ChromaBackend.make_client(palace_path)
ChromaBackend.make_client(palace_path)
assert calls == [palace_path]
def test_make_client_quarantines_each_palace_independently(tmp_path, monkeypatch):
@@ -811,3 +844,245 @@ def test_get_collection_applies_retrofit_on_existing_palace(tmp_path):
)
assert wrapper._collection.configuration_json["hnsw"]["num_threads"] == 1
def test_quarantine_invalid_hnsw_metadata_renames_missing_dimensionality(tmp_path):
palace = tmp_path / "palace"
palace.mkdir()
seg = palace / "abcd-1234-5678"
seg.mkdir()
with open(seg / "index_metadata.pickle", "wb") as f:
pickle.dump({"dimensionality": None, "id_to_label": {"a": 1}}, f)
moved = quarantine_invalid_hnsw_metadata(str(palace))
assert len(moved) == 1
assert ".corrupt-" in moved[0]
assert not seg.exists()
def test_quarantine_invalid_hnsw_metadata_allows_uninitialized_segment(tmp_path):
palace = tmp_path / "palace"
palace.mkdir()
seg = palace / "abcd-1234-5678"
seg.mkdir()
with open(seg / "index_metadata.pickle", "wb") as f:
pickle.dump({"dimensionality": None, "id_to_label": {}}, f)
moved = quarantine_invalid_hnsw_metadata(str(palace))
assert moved == []
assert seg.exists()
def test_quarantine_invalid_hnsw_metadata_rejects_non_dict_id_to_label(tmp_path):
palace = tmp_path / "palace"
palace.mkdir()
seg = palace / "abcd-1234-5678"
seg.mkdir()
with open(seg / "index_metadata.pickle", "wb") as f:
pickle.dump({"dimensionality": 8, "id_to_label": ["a", "b"]}, f)
moved = quarantine_invalid_hnsw_metadata(str(palace))
assert len(moved) == 1
assert ".corrupt-" in moved[0]
assert not seg.exists()
def test_quarantine_invalid_hnsw_metadata_rejects_non_schema_payload(tmp_path):
palace = tmp_path / "palace"
palace.mkdir()
seg = palace / "abcd-1234-5678"
seg.mkdir()
with open(seg / "index_metadata.pickle", "wb") as f:
pickle.dump(["not", "a", "metadata", "object"], f)
moved = quarantine_invalid_hnsw_metadata(str(palace))
assert len(moved) == 1
assert ".corrupt-" in moved[0]
assert not seg.exists()
def _dangerous_pickle_payload_executed():
raise AssertionError("unsafe pickle payload executed")
class _DangerousPickle:
def __reduce__(self):
return (_dangerous_pickle_payload_executed, ())
def test_quarantine_invalid_hnsw_metadata_rejects_unsafe_pickle(tmp_path):
palace = tmp_path / "palace"
palace.mkdir()
seg = palace / "abcd-1234-5678"
seg.mkdir()
with open(seg / "index_metadata.pickle", "wb") as f:
pickle.dump(_DangerousPickle(), f)
moved = quarantine_invalid_hnsw_metadata(str(palace))
assert len(moved) == 1
assert ".corrupt-" in moved[0]
assert not seg.exists()
def test_quarantine_invalid_hnsw_metadata_skips_transient_read_errors(tmp_path, monkeypatch):
palace = tmp_path / "palace"
palace.mkdir()
seg = palace / "abcd-1234-5678"
seg.mkdir()
meta = seg / "index_metadata.pickle"
meta.write_bytes(b"partial")
monkeypatch.setattr(
"mempalace.backends.chroma._SafePersistentDataUnpickler.load",
lambda path: (_ for _ in ()).throw(EOFError("flush in progress")),
)
moved = quarantine_invalid_hnsw_metadata(str(palace))
assert moved == []
assert seg.exists()
def test_quarantine_invalid_hnsw_metadata_skips_truncated_pickle(tmp_path, monkeypatch):
palace = tmp_path / "palace"
palace.mkdir()
seg = palace / "abcd-1234-5678"
seg.mkdir()
meta = seg / "index_metadata.pickle"
meta.write_bytes(b"partial")
monkeypatch.setattr(
"mempalace.backends.chroma._SafePersistentDataUnpickler.load",
lambda path: (_ for _ in ()).throw(pickle.UnpicklingError("pickle data was truncated")),
)
moved = quarantine_invalid_hnsw_metadata(str(palace))
assert moved == []
assert seg.exists()
def test_chroma_backend_preflights_metadata_before_persistent_client(tmp_path, monkeypatch):
palace = tmp_path / "palace"
palace.mkdir()
calls = []
def _record(name):
def inner(path, *args, **kwargs):
calls.append((name, path))
return [] if name != "blob" else None
return inner
monkeypatch.setattr("mempalace.backends.chroma._fix_blob_seq_ids", _record("blob"))
monkeypatch.setattr(
"mempalace.backends.chroma.quarantine_invalid_hnsw_metadata", _record("invalid")
)
monkeypatch.setattr("mempalace.backends.chroma.quarantine_stale_hnsw", _record("stale"))
class DummyClient:
pass
monkeypatch.setattr(
"mempalace.backends.chroma.chromadb.PersistentClient", lambda path: DummyClient()
)
backend = ChromaBackend()
backend._client(str(palace))
assert calls == [
("blob", str(palace)),
("invalid", str(palace)),
("stale", str(palace)),
]
def test_chroma_backend_stale_quarantine_is_cold_start_only_on_refresh(tmp_path, monkeypatch):
palace = tmp_path / "palace"
palace.mkdir()
(palace / "chroma.sqlite3").write_text("")
calls = []
def _record(name):
def inner(path, *args, **kwargs):
calls.append((name, path))
return [] if name != "blob" else None
return inner
monkeypatch.setattr(ChromaBackend, "_quarantined_paths", set())
monkeypatch.setattr("mempalace.backends.chroma._fix_blob_seq_ids", _record("blob"))
monkeypatch.setattr(
"mempalace.backends.chroma.quarantine_invalid_hnsw_metadata", _record("invalid")
)
monkeypatch.setattr("mempalace.backends.chroma.quarantine_stale_hnsw", _record("stale"))
class DummyClient:
pass
monkeypatch.setattr(
"mempalace.backends.chroma.chromadb.PersistentClient", lambda path: DummyClient()
)
backend = ChromaBackend()
stats = iter([(1, 1.0), (1, 1.0), (1, 2.0), (1, 2.0)])
monkeypatch.setattr(backend, "_db_stat", lambda path: next(stats))
backend._client(str(palace))
backend._client(str(palace))
assert calls == [
("blob", str(palace)),
("invalid", str(palace)),
("stale", str(palace)),
("blob", str(palace)),
]
def test_chroma_backend_requarantines_after_inode_replacement(tmp_path, monkeypatch):
palace = tmp_path / "palace"
palace.mkdir()
(palace / "chroma.sqlite3").write_text("")
calls = []
def _record(name):
def inner(path, *args, **kwargs):
calls.append((name, path))
return [] if name != "blob" else None
return inner
monkeypatch.setattr(ChromaBackend, "_quarantined_paths", set())
monkeypatch.setattr("mempalace.backends.chroma._fix_blob_seq_ids", _record("blob"))
monkeypatch.setattr(
"mempalace.backends.chroma.quarantine_invalid_hnsw_metadata", _record("invalid")
)
monkeypatch.setattr("mempalace.backends.chroma.quarantine_stale_hnsw", _record("stale"))
class DummyClient:
pass
monkeypatch.setattr(
"mempalace.backends.chroma.chromadb.PersistentClient", lambda path: DummyClient()
)
backend = ChromaBackend()
stats = iter([(1, 1.0), (1, 1.0), (2, 2.0), (2, 2.0)])
monkeypatch.setattr(backend, "_db_stat", lambda path: next(stats))
backend._client(str(palace))
backend._client(str(palace))
assert calls == [
("blob", str(palace)),
("invalid", str(palace)),
("stale", str(palace)),
("blob", str(palace)),
("invalid", str(palace)),
("stale", str(palace)),
]