diff --git a/mempalace/backends/chroma.py b/mempalace/backends/chroma.py index da01e4d..28fe55f 100644 --- a/mempalace/backends/chroma.py +++ b/mempalace/backends/chroma.py @@ -85,7 +85,9 @@ class ChromaBackend: _fix_blob_seq_ids(palace_path) client = chromadb.PersistentClient(path=palace_path) if create: - collection = client.get_or_create_collection(collection_name) + collection = client.get_or_create_collection( + collection_name, metadata={"hnsw:space": "cosine"} + ) else: collection = client.get_collection(collection_name) return ChromaCollection(collection) diff --git a/mempalace/cli.py b/mempalace/cli.py index 8bf3f20..fa92ed6 100644 --- a/mempalace/cli.py +++ b/mempalace/cli.py @@ -156,7 +156,11 @@ def cmd_migrate(args): from .migrate import migrate palace_path = os.path.expanduser(args.palace) if args.palace else MempalaceConfig().palace_path - migrate(palace_path=palace_path, dry_run=args.dry_run, confirm=getattr(args, "yes", False)) + migrate( + palace_path=palace_path, + dry_run=args.dry_run, + confirm=getattr(args, "yes", False), + ) def cmd_status(args): @@ -240,7 +244,7 @@ def cmd_repair(args): print(" Rebuilding collection...") client.delete_collection("mempalace_drawers") - new_col = client.create_collection("mempalace_drawers") + new_col = client.create_collection("mempalace_drawers", metadata={"hnsw:space": "cosine"}) filed = 0 for i in range(0, len(all_ids), batch_size): @@ -328,7 +332,11 @@ def cmd_compress(args): offset = 0 while True: try: - kwargs = {"include": ["documents", "metadatas"], "limit": _BATCH, "offset": offset} + kwargs = { + "include": ["documents", "metadatas"], + "limit": _BATCH, + "offset": offset, + } if where: kwargs["where"] = where batch = col.get(**kwargs) @@ -386,7 +394,9 @@ def cmd_compress(args): # Store compressed versions (unless dry-run) if not args.dry_run: try: - comp_col = client.get_or_create_collection("mempalace_compressed") + comp_col = client.get_or_create_collection( + "mempalace_compressed", metadata={"hnsw:space": "cosine"} + ) for doc_id, compressed, meta, stats in compressed_entries: comp_meta = dict(meta) comp_meta["compression_ratio"] = round(stats["size_ratio"], 1) @@ -431,7 +441,9 @@ def main(): p_init = sub.add_parser("init", help="Detect rooms from your folder structure") p_init.add_argument("dir", help="Project directory to set up") p_init.add_argument( - "--yes", action="store_true", help="Auto-accept all detected entities (non-interactive)" + "--yes", + action="store_true", + help="Auto-accept all detected entities (non-interactive)", ) # mine diff --git a/mempalace/migrate.py b/mempalace/migrate.py index 6ec4a59..319c670 100644 --- a/mempalace/migrate.py +++ b/mempalace/migrate.py @@ -33,13 +33,15 @@ def extract_drawers_from_sqlite(db_path: str) -> list: conn.row_factory = sqlite3.Row # Get all embedding IDs and their documents - rows = conn.execute(""" + rows = conn.execute( + """ SELECT e.embedding_id, MAX(CASE WHEN em.key = 'chroma:document' THEN em.string_value END) as document FROM embeddings e JOIN embedding_metadata em ON em.id = e.id GROUP BY e.embedding_id - """).fetchall() + """ + ).fetchall() drawers = [] for row in rows: @@ -207,7 +209,7 @@ def migrate(palace_path: str, dry_run: bool = False, confirm: bool = False): temp_palace = tempfile.mkdtemp(prefix="mempalace_migrate_") print(f" Creating fresh palace in {temp_palace}...") client = chromadb.PersistentClient(path=temp_palace) - col = client.get_or_create_collection("mempalace_drawers") + col = client.get_or_create_collection("mempalace_drawers", metadata={"hnsw:space": "cosine"}) # Re-import in batches batch_size = 500 diff --git a/tests/conftest.py b/tests/conftest.py index 16185ef..7b2bb77 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -101,7 +101,7 @@ def config(tmp_dir, palace_path): def collection(palace_path): """A ChromaDB collection pre-seeded in the temp palace.""" client = chromadb.PersistentClient(path=palace_path) - col = client.get_or_create_collection("mempalace_drawers") + col = client.get_or_create_collection("mempalace_drawers", metadata={"hnsw:space": "cosine"}) yield col client.delete_collection("mempalace_drawers") del client diff --git a/tests/test_backends.py b/tests/test_backends.py index 846134f..a620bf9 100644 --- a/tests/test_backends.py +++ b/tests/test_backends.py @@ -82,6 +82,20 @@ def test_chroma_backend_create_true_creates_directory_and_collection(tmp_path): client.get_collection("mempalace_drawers") +def test_chroma_backend_creates_collection_with_cosine_distance(tmp_path): + palace_path = tmp_path / "palace" + + ChromaBackend().get_collection( + str(palace_path), + collection_name="mempalace_drawers", + create=True, + ) + + client = chromadb.PersistentClient(path=str(palace_path)) + col = client.get_collection("mempalace_drawers") + assert col.metadata.get("hnsw:space") == "cosine" + + def test_fix_blob_seq_ids_converts_blobs_to_integers(tmp_path): """Simulate a ChromaDB 0.6.x database with BLOB seq_ids and verify repair.""" db_path = tmp_path / "chroma.sqlite3" diff --git a/tests/test_mcp_server.py b/tests/test_mcp_server.py index a8189ae..9584f36 100644 --- a/tests/test_mcp_server.py +++ b/tests/test_mcp_server.py @@ -31,7 +31,10 @@ def _get_collection(palace_path, create=False): client = chromadb.PersistentClient(path=palace_path) if create: - return client, client.get_or_create_collection("mempalace_drawers") + return ( + client, + client.get_or_create_collection("mempalace_drawers", metadata={"hnsw:space": "cosine"}), + ) return client, client.get_collection("mempalace_drawers") @@ -319,7 +322,7 @@ class TestSearchTool: _patch_mcp_server(monkeypatch, config, kg) from mempalace import mcp_server - monkeypatch.setattr(mcp_server, "_get_collection", lambda *args, **kwargs: pytest.fail()) + monkeypatch.setattr(mcp_server, "_get_collection", lambda: pytest.fail()) result = mcp_server.tool_list_rooms(wing="../etc/passwd") assert "error" in result @@ -328,7 +331,7 @@ class TestSearchTool: _patch_mcp_server(monkeypatch, config, kg) from mempalace import mcp_server - monkeypatch.setattr(mcp_server, "search_memories", lambda *args, **kwargs: pytest.fail()) + monkeypatch.setattr(mcp_server, "search_memories", lambda: pytest.fail()) result = mcp_server.tool_search(query="JWT", room="../backend") assert "error" in result @@ -337,7 +340,7 @@ class TestSearchTool: _patch_mcp_server(monkeypatch, config, kg) from mempalace import mcp_server - monkeypatch.setattr(mcp_server, "_get_collection", lambda *args, **kwargs: pytest.fail()) + monkeypatch.setattr(mcp_server, "_get_collection", lambda: pytest.fail()) result = mcp_server.tool_list_drawers(wing="../notes") assert "error" in result @@ -346,7 +349,7 @@ class TestSearchTool: _patch_mcp_server(monkeypatch, config, kg) from mempalace import mcp_server - monkeypatch.setattr(mcp_server, "_get_collection", lambda *args, **kwargs: pytest.fail()) + monkeypatch.setattr(mcp_server, "_get_collection", lambda: pytest.fail()) result = mcp_server.tool_find_tunnels(wing_a="../project") assert "error" in result diff --git a/tests/test_miner.py b/tests/test_miner.py index 020d5bd..d16c97c 100644 --- a/tests/test_miner.py +++ b/tests/test_miner.py @@ -27,7 +27,8 @@ def test_project_mining(): os.makedirs(project_root / "backend") write_file( - project_root / "backend" / "app.py", "def main():\n print('hello world')\n" * 20 + project_root / "backend" / "app.py", + "def main():\n print('hello world')\n" * 20, ) with open(project_root / "mempalace.yaml", "w") as f: yaml.dump( @@ -215,7 +216,9 @@ def test_file_already_mined_check_mtime(): palace_path = os.path.join(tmpdir, "palace") os.makedirs(palace_path) client = chromadb.PersistentClient(path=palace_path) - col = client.get_or_create_collection("mempalace_drawers") + col = client.get_or_create_collection( + "mempalace_drawers", metadata={"hnsw:space": "cosine"} + ) test_file = os.path.join(tmpdir, "test.txt") with open(test_file, "w") as f: