Merge pull request #807 from sha2fiddy/fix/218-cosine-distance-metadata
Fix: set cosine distance metadata on all collection creation sites
This commit is contained in:
@@ -85,7 +85,9 @@ class ChromaBackend:
|
||||
_fix_blob_seq_ids(palace_path)
|
||||
client = chromadb.PersistentClient(path=palace_path)
|
||||
if create:
|
||||
collection = client.get_or_create_collection(collection_name)
|
||||
collection = client.get_or_create_collection(
|
||||
collection_name, metadata={"hnsw:space": "cosine"}
|
||||
)
|
||||
else:
|
||||
collection = client.get_collection(collection_name)
|
||||
return ChromaCollection(collection)
|
||||
|
||||
+17
-5
@@ -156,7 +156,11 @@ def cmd_migrate(args):
|
||||
from .migrate import migrate
|
||||
|
||||
palace_path = os.path.expanduser(args.palace) if args.palace else MempalaceConfig().palace_path
|
||||
migrate(palace_path=palace_path, dry_run=args.dry_run, confirm=getattr(args, "yes", False))
|
||||
migrate(
|
||||
palace_path=palace_path,
|
||||
dry_run=args.dry_run,
|
||||
confirm=getattr(args, "yes", False),
|
||||
)
|
||||
|
||||
|
||||
def cmd_status(args):
|
||||
@@ -240,7 +244,7 @@ def cmd_repair(args):
|
||||
|
||||
print(" Rebuilding collection...")
|
||||
client.delete_collection("mempalace_drawers")
|
||||
new_col = client.create_collection("mempalace_drawers")
|
||||
new_col = client.create_collection("mempalace_drawers", metadata={"hnsw:space": "cosine"})
|
||||
|
||||
filed = 0
|
||||
for i in range(0, len(all_ids), batch_size):
|
||||
@@ -328,7 +332,11 @@ def cmd_compress(args):
|
||||
offset = 0
|
||||
while True:
|
||||
try:
|
||||
kwargs = {"include": ["documents", "metadatas"], "limit": _BATCH, "offset": offset}
|
||||
kwargs = {
|
||||
"include": ["documents", "metadatas"],
|
||||
"limit": _BATCH,
|
||||
"offset": offset,
|
||||
}
|
||||
if where:
|
||||
kwargs["where"] = where
|
||||
batch = col.get(**kwargs)
|
||||
@@ -386,7 +394,9 @@ def cmd_compress(args):
|
||||
# Store compressed versions (unless dry-run)
|
||||
if not args.dry_run:
|
||||
try:
|
||||
comp_col = client.get_or_create_collection("mempalace_compressed")
|
||||
comp_col = client.get_or_create_collection(
|
||||
"mempalace_compressed", metadata={"hnsw:space": "cosine"}
|
||||
)
|
||||
for doc_id, compressed, meta, stats in compressed_entries:
|
||||
comp_meta = dict(meta)
|
||||
comp_meta["compression_ratio"] = round(stats["size_ratio"], 1)
|
||||
@@ -431,7 +441,9 @@ def main():
|
||||
p_init = sub.add_parser("init", help="Detect rooms from your folder structure")
|
||||
p_init.add_argument("dir", help="Project directory to set up")
|
||||
p_init.add_argument(
|
||||
"--yes", action="store_true", help="Auto-accept all detected entities (non-interactive)"
|
||||
"--yes",
|
||||
action="store_true",
|
||||
help="Auto-accept all detected entities (non-interactive)",
|
||||
)
|
||||
|
||||
# mine
|
||||
|
||||
@@ -33,13 +33,15 @@ def extract_drawers_from_sqlite(db_path: str) -> list:
|
||||
conn.row_factory = sqlite3.Row
|
||||
|
||||
# Get all embedding IDs and their documents
|
||||
rows = conn.execute("""
|
||||
rows = conn.execute(
|
||||
"""
|
||||
SELECT e.embedding_id,
|
||||
MAX(CASE WHEN em.key = 'chroma:document' THEN em.string_value END) as document
|
||||
FROM embeddings e
|
||||
JOIN embedding_metadata em ON em.id = e.id
|
||||
GROUP BY e.embedding_id
|
||||
""").fetchall()
|
||||
"""
|
||||
).fetchall()
|
||||
|
||||
drawers = []
|
||||
for row in rows:
|
||||
@@ -207,7 +209,7 @@ def migrate(palace_path: str, dry_run: bool = False, confirm: bool = False):
|
||||
temp_palace = tempfile.mkdtemp(prefix="mempalace_migrate_")
|
||||
print(f" Creating fresh palace in {temp_palace}...")
|
||||
client = chromadb.PersistentClient(path=temp_palace)
|
||||
col = client.get_or_create_collection("mempalace_drawers")
|
||||
col = client.get_or_create_collection("mempalace_drawers", metadata={"hnsw:space": "cosine"})
|
||||
|
||||
# Re-import in batches
|
||||
batch_size = 500
|
||||
|
||||
+1
-1
@@ -101,7 +101,7 @@ def config(tmp_dir, palace_path):
|
||||
def collection(palace_path):
|
||||
"""A ChromaDB collection pre-seeded in the temp palace."""
|
||||
client = chromadb.PersistentClient(path=palace_path)
|
||||
col = client.get_or_create_collection("mempalace_drawers")
|
||||
col = client.get_or_create_collection("mempalace_drawers", metadata={"hnsw:space": "cosine"})
|
||||
yield col
|
||||
client.delete_collection("mempalace_drawers")
|
||||
del client
|
||||
|
||||
@@ -82,6 +82,20 @@ def test_chroma_backend_create_true_creates_directory_and_collection(tmp_path):
|
||||
client.get_collection("mempalace_drawers")
|
||||
|
||||
|
||||
def test_chroma_backend_creates_collection_with_cosine_distance(tmp_path):
|
||||
palace_path = tmp_path / "palace"
|
||||
|
||||
ChromaBackend().get_collection(
|
||||
str(palace_path),
|
||||
collection_name="mempalace_drawers",
|
||||
create=True,
|
||||
)
|
||||
|
||||
client = chromadb.PersistentClient(path=str(palace_path))
|
||||
col = client.get_collection("mempalace_drawers")
|
||||
assert col.metadata.get("hnsw:space") == "cosine"
|
||||
|
||||
|
||||
def test_fix_blob_seq_ids_converts_blobs_to_integers(tmp_path):
|
||||
"""Simulate a ChromaDB 0.6.x database with BLOB seq_ids and verify repair."""
|
||||
db_path = tmp_path / "chroma.sqlite3"
|
||||
|
||||
@@ -31,7 +31,10 @@ def _get_collection(palace_path, create=False):
|
||||
|
||||
client = chromadb.PersistentClient(path=palace_path)
|
||||
if create:
|
||||
return client, client.get_or_create_collection("mempalace_drawers")
|
||||
return (
|
||||
client,
|
||||
client.get_or_create_collection("mempalace_drawers", metadata={"hnsw:space": "cosine"}),
|
||||
)
|
||||
return client, client.get_collection("mempalace_drawers")
|
||||
|
||||
|
||||
@@ -319,7 +322,7 @@ class TestSearchTool:
|
||||
_patch_mcp_server(monkeypatch, config, kg)
|
||||
from mempalace import mcp_server
|
||||
|
||||
monkeypatch.setattr(mcp_server, "_get_collection", lambda *args, **kwargs: pytest.fail())
|
||||
monkeypatch.setattr(mcp_server, "_get_collection", lambda: pytest.fail())
|
||||
|
||||
result = mcp_server.tool_list_rooms(wing="../etc/passwd")
|
||||
assert "error" in result
|
||||
@@ -328,7 +331,7 @@ class TestSearchTool:
|
||||
_patch_mcp_server(monkeypatch, config, kg)
|
||||
from mempalace import mcp_server
|
||||
|
||||
monkeypatch.setattr(mcp_server, "search_memories", lambda *args, **kwargs: pytest.fail())
|
||||
monkeypatch.setattr(mcp_server, "search_memories", lambda: pytest.fail())
|
||||
|
||||
result = mcp_server.tool_search(query="JWT", room="../backend")
|
||||
assert "error" in result
|
||||
@@ -337,7 +340,7 @@ class TestSearchTool:
|
||||
_patch_mcp_server(monkeypatch, config, kg)
|
||||
from mempalace import mcp_server
|
||||
|
||||
monkeypatch.setattr(mcp_server, "_get_collection", lambda *args, **kwargs: pytest.fail())
|
||||
monkeypatch.setattr(mcp_server, "_get_collection", lambda: pytest.fail())
|
||||
|
||||
result = mcp_server.tool_list_drawers(wing="../notes")
|
||||
assert "error" in result
|
||||
@@ -346,7 +349,7 @@ class TestSearchTool:
|
||||
_patch_mcp_server(monkeypatch, config, kg)
|
||||
from mempalace import mcp_server
|
||||
|
||||
monkeypatch.setattr(mcp_server, "_get_collection", lambda *args, **kwargs: pytest.fail())
|
||||
monkeypatch.setattr(mcp_server, "_get_collection", lambda: pytest.fail())
|
||||
|
||||
result = mcp_server.tool_find_tunnels(wing_a="../project")
|
||||
assert "error" in result
|
||||
|
||||
+5
-2
@@ -27,7 +27,8 @@ def test_project_mining():
|
||||
os.makedirs(project_root / "backend")
|
||||
|
||||
write_file(
|
||||
project_root / "backend" / "app.py", "def main():\n print('hello world')\n" * 20
|
||||
project_root / "backend" / "app.py",
|
||||
"def main():\n print('hello world')\n" * 20,
|
||||
)
|
||||
with open(project_root / "mempalace.yaml", "w") as f:
|
||||
yaml.dump(
|
||||
@@ -215,7 +216,9 @@ def test_file_already_mined_check_mtime():
|
||||
palace_path = os.path.join(tmpdir, "palace")
|
||||
os.makedirs(palace_path)
|
||||
client = chromadb.PersistentClient(path=palace_path)
|
||||
col = client.get_or_create_collection("mempalace_drawers")
|
||||
col = client.get_or_create_collection(
|
||||
"mempalace_drawers", metadata={"hnsw:space": "cosine"}
|
||||
)
|
||||
|
||||
test_file = os.path.join(tmpdir, "test.txt")
|
||||
with open(test_file, "w") as f:
|
||||
|
||||
Reference in New Issue
Block a user