Fix: set cosine distance metadata on all collection creation sites
ChromaDB defaults the HNSW index to L2 (Euclidean) distance, but
MemPalace scoring computes similarity as `1 - distance`, which requires cosine distance (range 0-2).
Add metadata={"hnsw:space": "cosine"} to the 4 production and 3 test
call sites that were missing it.
Closes #218
This commit is contained in:
@@ -35,8 +35,13 @@ def _fix_blob_seq_ids(palace_path: str):
|
||||
continue
|
||||
if not rows:
|
||||
continue
|
||||
updates = [(int.from_bytes(blob, byteorder="big"), rowid) for rowid, blob in rows]
|
||||
conn.executemany(f"UPDATE {table} SET seq_id = ? WHERE rowid = ?", updates)
|
||||
updates = [
|
||||
(int.from_bytes(blob, byteorder="big"), rowid)
|
||||
for rowid, blob in rows
|
||||
]
|
||||
conn.executemany(
|
||||
f"UPDATE {table} SET seq_id = ? WHERE rowid = ?", updates
|
||||
)
|
||||
logger.info("Fixed %d BLOB seq_ids in %s", len(updates), table)
|
||||
conn.commit()
|
||||
except Exception:
|
||||
@@ -71,7 +76,9 @@ class ChromaCollection(BaseCollection):
|
||||
class ChromaBackend:
|
||||
"""Factory for MemPalace's default ChromaDB backend."""
|
||||
|
||||
def get_collection(self, palace_path: str, collection_name: str, create: bool = False):
|
||||
def get_collection(
|
||||
self, palace_path: str, collection_name: str, create: bool = False
|
||||
):
|
||||
if not create and not os.path.isdir(palace_path):
|
||||
raise FileNotFoundError(palace_path)
|
||||
|
||||
@@ -85,7 +92,9 @@ class ChromaBackend:
|
||||
_fix_blob_seq_ids(palace_path)
|
||||
client = chromadb.PersistentClient(path=palace_path)
|
||||
if create:
|
||||
collection = client.get_or_create_collection(collection_name)
|
||||
collection = client.get_or_create_collection(
|
||||
collection_name, metadata={"hnsw:space": "cosine"}
|
||||
)
|
||||
else:
|
||||
collection = client.get_collection(collection_name)
|
||||
return ChromaCollection(collection)
|
||||
|
||||
Reference in New Issue
Block a user