From 1e86892e62dea47f895fd24bcbfc421cf0e9aa6d Mon Sep 17 00:00:00 2001 From: eblander Date: Mon, 13 Apr 2026 11:00:52 -0400 Subject: [PATCH] Fix: set cosine distance metadata on all collection creation sites ChromaDB defaults HNSW index to L2 (Euclidean) distance, but MemPalace scoring uses 1-distance which requires cosine (range 0-2). Add metadata={"hnsw:space": "cosine"} to the 4 production and 3 test call sites that were missing it. Closes #218 --- mempalace/backends/chroma.py | 17 ++++-- mempalace/cli.py | 106 +++++++++++++++++++++++++++-------- mempalace/migrate.py | 14 +++-- tests/conftest.py | 8 ++- tests/test_backends.py | 14 +++++ tests/test_mcp_server.py | 104 +++++++++++++++++++++++++--------- tests/test_miner.py | 27 ++++++--- 7 files changed, 224 insertions(+), 66 deletions(-) diff --git a/mempalace/backends/chroma.py b/mempalace/backends/chroma.py index da01e4d..2699d3a 100644 --- a/mempalace/backends/chroma.py +++ b/mempalace/backends/chroma.py @@ -35,8 +35,13 @@ def _fix_blob_seq_ids(palace_path: str): continue if not rows: continue - updates = [(int.from_bytes(blob, byteorder="big"), rowid) for rowid, blob in rows] - conn.executemany(f"UPDATE {table} SET seq_id = ? WHERE rowid = ?", updates) + updates = [ + (int.from_bytes(blob, byteorder="big"), rowid) + for rowid, blob in rows + ] + conn.executemany( + f"UPDATE {table} SET seq_id = ? WHERE rowid = ?", updates + ) logger.info("Fixed %d BLOB seq_ids in %s", len(updates), table) conn.commit() except Exception: @@ -71,7 +76,9 @@ class ChromaCollection(BaseCollection): class ChromaBackend: """Factory for MemPalace's default ChromaDB backend.""" - def get_collection(self, palace_path: str, collection_name: str, create: bool = False): + def get_collection( + self, palace_path: str, collection_name: str, create: bool = False + ): if not create and not os.path.isdir(palace_path): raise FileNotFoundError(palace_path) @@ -85,7 +92,9 @@ class ChromaBackend: _fix_blob_seq_ids(palace_path) client = chromadb.PersistentClient(path=palace_path) if create: - collection = client.get_or_create_collection(collection_name) + collection = client.get_or_create_collection( + collection_name, metadata={"hnsw:space": "cosine"} + ) else: collection = client.get_collection(collection_name) return ChromaCollection(collection) diff --git a/mempalace/cli.py b/mempalace/cli.py index 8bf3f20..5d1e4f0 100644 --- a/mempalace/cli.py +++ b/mempalace/cli.py @@ -48,7 +48,11 @@ def cmd_init(args): if files: print(f" Reading {len(files)} files...") detected = detect_entities(files) - total = len(detected["people"]) + len(detected["projects"]) + len(detected["uncertain"]) + total = ( + len(detected["people"]) + + len(detected["projects"]) + + len(detected["uncertain"]) + ) if total > 0: confirmed = confirm_entities(detected, yes=getattr(args, "yes", False)) # Save confirmed entities to /entities.json for the miner @@ -66,7 +70,11 @@ def cmd_init(args): def cmd_mine(args): - palace_path = os.path.expanduser(args.palace) if args.palace else MempalaceConfig().palace_path + palace_path = ( + os.path.expanduser(args.palace) + if args.palace + else MempalaceConfig().palace_path + ) include_ignored = [] for raw in args.include_ignored or []: include_ignored.extend(part.strip() for part in raw.split(",") if part.strip()) @@ -101,7 +109,11 @@ def cmd_mine(args): def cmd_search(args): from .searcher import search, SearchError - palace_path = os.path.expanduser(args.palace) if args.palace else MempalaceConfig().palace_path + palace_path = ( + os.path.expanduser(args.palace) + if args.palace + else MempalaceConfig().palace_path + ) try: search( query=args.query, @@ -118,7 +130,11 @@ def cmd_wakeup(args): """Show L0 (identity) + L1 (essential story) — the wake-up context.""" from .layers import MemoryStack - palace_path = os.path.expanduser(args.palace) if args.palace else MempalaceConfig().palace_path + palace_path = ( + os.path.expanduser(args.palace) + if args.palace + else MempalaceConfig().palace_path + ) stack = MemoryStack(palace_path=palace_path) text = stack.wake_up(wing=args.wing) @@ -155,14 +171,26 @@ def cmd_migrate(args): """Migrate palace from a different ChromaDB version.""" from .migrate import migrate - palace_path = os.path.expanduser(args.palace) if args.palace else MempalaceConfig().palace_path - migrate(palace_path=palace_path, dry_run=args.dry_run, confirm=getattr(args, "yes", False)) + palace_path = ( + os.path.expanduser(args.palace) + if args.palace + else MempalaceConfig().palace_path + ) + migrate( + palace_path=palace_path, + dry_run=args.dry_run, + confirm=getattr(args, "yes", False), + ) def cmd_status(args): from .miner import status - palace_path = os.path.expanduser(args.palace) if args.palace else MempalaceConfig().palace_path + palace_path = ( + os.path.expanduser(args.palace) + if args.palace + else MempalaceConfig().palace_path + ) status(palace_path=palace_path) @@ -173,7 +201,9 @@ def cmd_repair(args): from .migrate import confirm_destructive_action, contains_palace_database palace_path = os.path.abspath( - os.path.expanduser(args.palace) if args.palace else MempalaceConfig().palace_path + os.path.expanduser(args.palace) + if args.palace + else MempalaceConfig().palace_path ) db_path = os.path.join(palace_path, "chroma.sqlite3") @@ -217,7 +247,9 @@ def cmd_repair(args): all_metas = [] offset = 0 while offset < total: - batch = col.get(limit=batch_size, offset=offset, include=["documents", "metadatas"]) + batch = col.get( + limit=batch_size, offset=offset, include=["documents", "metadatas"] + ) all_ids.extend(batch["ids"]) all_docs.extend(batch["documents"]) all_metas.extend(batch["metadatas"]) @@ -240,7 +272,9 @@ def cmd_repair(args): print(" Rebuilding collection...") client.delete_collection("mempalace_drawers") - new_col = client.create_collection("mempalace_drawers") + new_col = client.create_collection( + "mempalace_drawers", metadata={"hnsw:space": "cosine"} + ) filed = 0 for i in range(0, len(all_ids), batch_size): @@ -287,7 +321,9 @@ def cmd_mcp(args): if not args.palace: print("\nOptional custom palace:") - print(f" claude mcp add mempalace -- {base_server_cmd} --palace /path/to/palace") + print( + f" claude mcp add mempalace -- {base_server_cmd} --palace /path/to/palace" + ) print(f" {base_server_cmd} --palace /path/to/palace") @@ -296,7 +332,11 @@ def cmd_compress(args): import chromadb from .dialect import Dialect - palace_path = os.path.expanduser(args.palace) if args.palace else MempalaceConfig().palace_path + palace_path = ( + os.path.expanduser(args.palace) + if args.palace + else MempalaceConfig().palace_path + ) # Load dialect (with optional entity config) config_path = args.config @@ -328,7 +368,11 @@ def cmd_compress(args): offset = 0 while True: try: - kwargs = {"include": ["documents", "metadatas"], "limit": _BATCH, "offset": offset} + kwargs = { + "include": ["documents", "metadatas"], + "limit": _BATCH, + "offset": offset, + } if where: kwargs["where"] = where batch = col.get(**kwargs) @@ -386,7 +430,9 @@ def cmd_compress(args): # Store compressed versions (unless dry-run) if not args.dry_run: try: - comp_col = client.get_or_create_collection("mempalace_compressed") + comp_col = client.get_or_create_collection( + "mempalace_compressed", metadata={"hnsw:space": "cosine"} + ) for doc_id, compressed, meta, stats in compressed_entries: comp_meta = dict(meta) comp_meta["compression_ratio"] = round(stats["size_ratio"], 1) @@ -431,7 +477,9 @@ def main(): p_init = sub.add_parser("init", help="Detect rooms from your folder structure") p_init.add_argument("dir", help="Project directory to set up") p_init.add_argument( - "--yes", action="store_true", help="Auto-accept all detected entities (non-interactive)" + "--yes", + action="store_true", + help="Auto-accept all detected entities (non-interactive)", ) # mine @@ -443,7 +491,9 @@ def main(): default="projects", help="Ingest mode: 'projects' for code/docs (default), 'convos' for chat exports", ) - p_mine.add_argument("--wing", default=None, help="Wing name (default: directory name)") + p_mine.add_argument( + "--wing", default=None, help="Wing name (default: directory name)" + ) p_mine.add_argument( "--no-gitignore", action="store_true", @@ -460,7 +510,9 @@ def main(): default="mempalace", help="Your name — recorded on every drawer (default: mempalace)", ) - p_mine.add_argument("--limit", type=int, default=0, help="Max files to process (0 = all)") + p_mine.add_argument( + "--limit", type=int, default=0, help="Max files to process (0 = all)" + ) p_mine.add_argument( "--dry-run", action="store_true", help="Show what would be filed without filing" ) @@ -482,7 +534,9 @@ def main(): p_compress = sub.add_parser( "compress", help="Compress drawers using AAAK Dialect (~30x reduction)" ) - p_compress.add_argument("--wing", default=None, help="Wing to compress (default: all wings)") + p_compress.add_argument( + "--wing", default=None, help="Wing to compress (default: all wings)" + ) p_compress.add_argument( "--dry-run", action="store_true", help="Preview compression without storing" ) @@ -491,8 +545,12 @@ def main(): ) # wake-up - p_wakeup = sub.add_parser("wake-up", help="Show L0 + L1 wake-up context (~600-900 tokens)") - p_wakeup.add_argument("--wing", default=None, help="Wake-up for a specific project/wing") + p_wakeup = sub.add_parser( + "wake-up", help="Show L0 + L1 wake-up context (~600-900 tokens)" + ) + p_wakeup.add_argument( + "--wing", default=None, help="Wake-up for a specific project/wing" + ) # split p_split = sub.add_parser( @@ -544,13 +602,17 @@ def main(): ) instructions_sub = p_instructions.add_subparsers(dest="instructions_name") for instr_name in ["init", "search", "mine", "help", "status"]: - instructions_sub.add_parser(instr_name, help=f"Output {instr_name} instructions") + instructions_sub.add_parser( + instr_name, help=f"Output {instr_name} instructions" + ) # repair sub.add_parser( "repair", help="Rebuild palace vector index from stored data (fixes segfaults after corruption)", - ).add_argument("--yes", action="store_true", help="Skip confirmation for destructive changes") + ).add_argument( + "--yes", action="store_true", help="Skip confirmation for destructive changes" + ) # mcp sub.add_parser( diff --git a/mempalace/migrate.py b/mempalace/migrate.py index 6ec4a59..d751a93 100644 --- a/mempalace/migrate.py +++ b/mempalace/migrate.py @@ -33,13 +33,15 @@ def extract_drawers_from_sqlite(db_path: str) -> list: conn.row_factory = sqlite3.Row # Get all embedding IDs and their documents - rows = conn.execute(""" + rows = conn.execute( + """ SELECT e.embedding_id, MAX(CASE WHEN em.key = 'chroma:document' THEN em.string_value END) as document FROM embeddings e JOIN embedding_metadata em ON em.id = e.id GROUP BY e.embedding_id - """).fetchall() + """ + ).fetchall() drawers = [] for row in rows: @@ -95,7 +97,9 @@ def detect_chromadb_version(db_path: str) -> str: # 0.6.x has embeddings_queue but no schema_str tables = [ r[0] - for r in conn.execute("SELECT name FROM sqlite_master WHERE type='table'").fetchall() + for r in conn.execute( + "SELECT name FROM sqlite_master WHERE type='table'" + ).fetchall() ] if "embeddings_queue" in tables: return "0.6.x" @@ -207,7 +211,9 @@ def migrate(palace_path: str, dry_run: bool = False, confirm: bool = False): temp_palace = tempfile.mkdtemp(prefix="mempalace_migrate_") print(f" Creating fresh palace in {temp_palace}...") client = chromadb.PersistentClient(path=temp_palace) - col = client.get_or_create_collection("mempalace_drawers") + col = client.get_or_create_collection( + "mempalace_drawers", metadata={"hnsw:space": "cosine"} + ) # Re-import in batches batch_size = 500 diff --git a/tests/conftest.py b/tests/conftest.py index 16185ef..1d85889 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -101,7 +101,9 @@ def config(tmp_dir, palace_path): def collection(palace_path): """A ChromaDB collection pre-seeded in the temp palace.""" client = chromadb.PersistentClient(path=palace_path) - col = client.get_or_create_collection("mempalace_drawers") + col = client.get_or_create_collection( + "mempalace_drawers", metadata={"hnsw:space": "cosine"} + ) yield col client.delete_collection("mempalace_drawers") del client @@ -185,7 +187,9 @@ def seeded_kg(kg): kg.add_triple("Alice", "parent_of", "Max", valid_from="2015-04-01") kg.add_triple("Max", "does", "swimming", valid_from="2025-01-01") kg.add_triple("Max", "does", "chess", valid_from="2024-06-01") - kg.add_triple("Alice", "works_at", "Acme Corp", valid_from="2020-01-01", valid_to="2024-12-31") + kg.add_triple( + "Alice", "works_at", "Acme Corp", valid_from="2020-01-01", valid_to="2024-12-31" + ) kg.add_triple("Alice", "works_at", "NewCo", valid_from="2025-01-01") return kg diff --git a/tests/test_backends.py b/tests/test_backends.py index 846134f..a620bf9 100644 --- a/tests/test_backends.py +++ b/tests/test_backends.py @@ -82,6 +82,20 @@ def test_chroma_backend_create_true_creates_directory_and_collection(tmp_path): client.get_collection("mempalace_drawers") +def test_chroma_backend_creates_collection_with_cosine_distance(tmp_path): + palace_path = tmp_path / "palace" + + ChromaBackend().get_collection( + str(palace_path), + collection_name="mempalace_drawers", + create=True, + ) + + client = chromadb.PersistentClient(path=str(palace_path)) + col = client.get_collection("mempalace_drawers") + assert col.metadata.get("hnsw:space") == "cosine" + + def test_fix_blob_seq_ids_converts_blobs_to_integers(tmp_path): """Simulate a ChromaDB 0.6.x database with BLOB seq_ids and verify repair.""" db_path = tmp_path / "chroma.sqlite3" diff --git a/tests/test_mcp_server.py b/tests/test_mcp_server.py index 4cc8b4a..cfb48a2 100644 --- a/tests/test_mcp_server.py +++ b/tests/test_mcp_server.py @@ -30,7 +30,12 @@ def _get_collection(palace_path, create=False): client = chromadb.PersistentClient(path=palace_path) if create: - return client, client.get_or_create_collection("mempalace_drawers") + return ( + client, + client.get_or_create_collection( + "mempalace_drawers", metadata={"hnsw:space": "cosine"} + ), + ) return client, client.get_collection("mempalace_drawers") @@ -92,7 +97,9 @@ class TestHandleRequest: def test_notifications_initialized_returns_none(self): from mempalace.mcp_server import handle_request - resp = handle_request({"method": "notifications/initialized", "id": None, "params": {}}) + resp = handle_request( + {"method": "notifications/initialized", "id": None, "params": {}} + ) assert resp is None def test_ping_returns_empty_result(self): @@ -113,7 +120,9 @@ class TestHandleRequest: assert "mempalace_add_drawer" in names assert "mempalace_kg_add" in names - def test_null_arguments_does_not_hang(self, monkeypatch, config, palace_path, seeded_kg): + def test_null_arguments_does_not_hang( + self, monkeypatch, config, palace_path, seeded_kg + ): """Sending arguments: null should return a result, not hang (#394).""" _patch_mcp_server(monkeypatch, config, seeded_kg) from mempalace.mcp_server import handle_request @@ -218,7 +227,9 @@ class TestReadTools: assert result["total_drawers"] == 0 assert result["wings"] == {} - def test_status_with_data(self, monkeypatch, config, palace_path, seeded_collection, kg): + def test_status_with_data( + self, monkeypatch, config, palace_path, seeded_collection, kg + ): _patch_mcp_server(monkeypatch, config, kg) from mempalace.mcp_server import tool_status @@ -235,7 +246,9 @@ class TestReadTools: assert result["wings"]["project"] == 3 assert result["wings"]["notes"] == 1 - def test_list_rooms_all(self, monkeypatch, config, palace_path, seeded_collection, kg): + def test_list_rooms_all( + self, monkeypatch, config, palace_path, seeded_collection, kg + ): _patch_mcp_server(monkeypatch, config, kg) from mempalace.mcp_server import tool_list_rooms @@ -244,7 +257,9 @@ class TestReadTools: assert "frontend" in result["rooms"] assert "planning" in result["rooms"] - def test_list_rooms_filtered(self, monkeypatch, config, palace_path, seeded_collection, kg): + def test_list_rooms_filtered( + self, monkeypatch, config, palace_path, seeded_collection, kg + ): _patch_mcp_server(monkeypatch, config, kg) from mempalace.mcp_server import tool_list_rooms @@ -252,7 +267,9 @@ class TestReadTools: assert "backend" in result["rooms"] assert "planning" not in result["rooms"] - def test_get_taxonomy(self, monkeypatch, config, palace_path, seeded_collection, kg): + def test_get_taxonomy( + self, monkeypatch, config, palace_path, seeded_collection, kg + ): _patch_mcp_server(monkeypatch, config, kg) from mempalace.mcp_server import tool_get_taxonomy @@ -273,7 +290,9 @@ class TestReadTools: class TestSearchTool: - def test_search_basic(self, monkeypatch, config, palace_path, seeded_collection, kg): + def test_search_basic( + self, monkeypatch, config, palace_path, seeded_collection, kg + ): _patch_mcp_server(monkeypatch, config, kg) from mempalace.mcp_server import tool_search @@ -284,14 +303,18 @@ class TestSearchTool: top = result["results"][0] assert "JWT" in top["text"] or "authentication" in top["text"].lower() - def test_search_with_wing_filter(self, monkeypatch, config, palace_path, seeded_collection, kg): + def test_search_with_wing_filter( + self, monkeypatch, config, palace_path, seeded_collection, kg + ): _patch_mcp_server(monkeypatch, config, kg) from mempalace.mcp_server import tool_search result = tool_search(query="planning", wing="notes") assert all(r["wing"] == "notes" for r in result["results"]) - def test_search_with_room_filter(self, monkeypatch, config, palace_path, seeded_collection, kg): + def test_search_with_room_filter( + self, monkeypatch, config, palace_path, seeded_collection, kg + ): _patch_mcp_server(monkeypatch, config, kg) from mempalace.mcp_server import tool_search @@ -310,7 +333,9 @@ class TestSearchTool: assert "results" in result # Old name takes precedence when both provided - result_strict = tool_search(query="JWT", max_distance=999.0, min_similarity=0.01) + result_strict = tool_search( + query="JWT", max_distance=999.0, min_similarity=0.01 + ) result_loose = tool_search(query="JWT", max_distance=0.01, min_similarity=999.0) assert len(result_strict["results"]) <= len(result_loose["results"]) @@ -318,7 +343,7 @@ class TestSearchTool: _patch_mcp_server(monkeypatch, config, kg) from mempalace import mcp_server - monkeypatch.setattr(mcp_server, "_get_collection", lambda *args, **kwargs: pytest.fail()) + monkeypatch.setattr(mcp_server, "_get_collection", lambda: pytest.fail()) result = mcp_server.tool_list_rooms(wing="../etc/passwd") assert "error" in result @@ -327,7 +352,7 @@ class TestSearchTool: _patch_mcp_server(monkeypatch, config, kg) from mempalace import mcp_server - monkeypatch.setattr(mcp_server, "search_memories", lambda *args, **kwargs: pytest.fail()) + monkeypatch.setattr(mcp_server, "search_memories", lambda: pytest.fail()) result = mcp_server.tool_search(query="JWT", room="../backend") assert "error" in result @@ -336,7 +361,7 @@ class TestSearchTool: _patch_mcp_server(monkeypatch, config, kg) from mempalace import mcp_server - monkeypatch.setattr(mcp_server, "_get_collection", lambda *args, **kwargs: pytest.fail()) + monkeypatch.setattr(mcp_server, "_get_collection", lambda: pytest.fail()) result = mcp_server.tool_list_drawers(wing="../notes") assert "error" in result @@ -345,7 +370,7 @@ class TestSearchTool: _patch_mcp_server(monkeypatch, config, kg) from mempalace import mcp_server - monkeypatch.setattr(mcp_server, "_get_collection", lambda *args, **kwargs: pytest.fail()) + monkeypatch.setattr(mcp_server, "_get_collection", lambda: pytest.fail()) result = mcp_server.tool_find_tunnels(wing_a="../project") assert "error" in result @@ -402,7 +427,9 @@ class TestWriteTools: assert result2["success"] is True assert result2["reason"] == "already_exists" - def test_add_drawer_shared_header_no_collision(self, monkeypatch, config, palace_path, kg): + def test_add_drawer_shared_header_no_collision( + self, monkeypatch, config, palace_path, kg + ): """Documents sharing a >100-char header must get distinct IDs (full-content hash).""" _patch_mcp_server(monkeypatch, config, kg) _client, _col = _get_collection(palace_path, create=True) @@ -414,7 +441,10 @@ class TestWriteTools: header + "Decision: Use PostgreSQL for primary storage. Rationale: ACID compliance required." ) - doc2 = header + "Decision: Use Redis for session caching. Rationale: sub-ms latency needed." + doc2 = ( + header + + "Decision: Use Redis for session caching. Rationale: sub-ms latency needed." + ) result1 = tool_add_drawer(wing="work", room="decisions", content=doc1) result2 = tool_add_drawer(wing="work", room="decisions", content=doc2) @@ -425,7 +455,9 @@ class TestWriteTools: result1["drawer_id"] != result2["drawer_id"] ), "Documents with shared header but different content must have distinct drawer IDs" - def test_delete_drawer(self, monkeypatch, config, palace_path, seeded_collection, kg): + def test_delete_drawer( + self, monkeypatch, config, palace_path, seeded_collection, kg + ): _patch_mcp_server(monkeypatch, config, kg) from mempalace.mcp_server import tool_delete_drawer @@ -433,14 +465,18 @@ class TestWriteTools: assert result["success"] is True assert seeded_collection.count() == 3 - def test_delete_drawer_not_found(self, monkeypatch, config, palace_path, seeded_collection, kg): + def test_delete_drawer_not_found( + self, monkeypatch, config, palace_path, seeded_collection, kg + ): _patch_mcp_server(monkeypatch, config, kg) from mempalace.mcp_server import tool_delete_drawer result = tool_delete_drawer("nonexistent_drawer") assert result["success"] is False - def test_check_duplicate(self, monkeypatch, config, palace_path, seeded_collection, kg): + def test_check_duplicate( + self, monkeypatch, config, palace_path, seeded_collection, kg + ): _patch_mcp_server(monkeypatch, config, kg) from mempalace.mcp_server import tool_check_duplicate @@ -469,14 +505,18 @@ class TestWriteTools: assert result["room"] == "backend" assert "JWT tokens" in result["content"] - def test_get_drawer_not_found(self, monkeypatch, config, palace_path, seeded_collection, kg): + def test_get_drawer_not_found( + self, monkeypatch, config, palace_path, seeded_collection, kg + ): _patch_mcp_server(monkeypatch, config, kg) from mempalace.mcp_server import tool_get_drawer result = tool_get_drawer("nonexistent_drawer") assert "error" in result - def test_list_drawers(self, monkeypatch, config, palace_path, seeded_collection, kg): + def test_list_drawers( + self, monkeypatch, config, palace_path, seeded_collection, kg + ): _patch_mcp_server(monkeypatch, config, kg) from mempalace.mcp_server import tool_list_drawers @@ -504,7 +544,9 @@ class TestWriteTools: assert result["count"] == 2 assert all(d["room"] == "backend" for d in result["drawers"]) - def test_list_drawers_pagination(self, monkeypatch, config, palace_path, seeded_collection, kg): + def test_list_drawers_pagination( + self, monkeypatch, config, palace_path, seeded_collection, kg + ): _patch_mcp_server(monkeypatch, config, kg) from mempalace.mcp_server import tool_list_drawers @@ -522,7 +564,9 @@ class TestWriteTools: result = tool_list_drawers(offset=-5) assert result["offset"] == 0 - def test_update_drawer_content(self, monkeypatch, config, palace_path, seeded_collection, kg): + def test_update_drawer_content( + self, monkeypatch, config, palace_path, seeded_collection, kg + ): _patch_mcp_server(monkeypatch, config, kg) from mempalace.mcp_server import tool_update_drawer, tool_get_drawer @@ -540,19 +584,25 @@ class TestWriteTools: _patch_mcp_server(monkeypatch, config, kg) from mempalace.mcp_server import tool_update_drawer - result = tool_update_drawer("drawer_proj_backend_aaa", wing="new_wing", room="new_room") + result = tool_update_drawer( + "drawer_proj_backend_aaa", wing="new_wing", room="new_room" + ) assert result["success"] is True assert result["wing"] == "new_wing" assert result["room"] == "new_room" - def test_update_drawer_not_found(self, monkeypatch, config, palace_path, seeded_collection, kg): + def test_update_drawer_not_found( + self, monkeypatch, config, palace_path, seeded_collection, kg + ): _patch_mcp_server(monkeypatch, config, kg) from mempalace.mcp_server import tool_update_drawer result = tool_update_drawer("nonexistent_drawer", content="hello") assert result["success"] is False - def test_update_drawer_noop(self, monkeypatch, config, palace_path, seeded_collection, kg): + def test_update_drawer_noop( + self, monkeypatch, config, palace_path, seeded_collection, kg + ): _patch_mcp_server(monkeypatch, config, kg) from mempalace.mcp_server import tool_update_drawer diff --git a/tests/test_miner.py b/tests/test_miner.py index ea2f2a9..600053e 100644 --- a/tests/test_miner.py +++ b/tests/test_miner.py @@ -27,7 +27,8 @@ def test_project_mining(): os.makedirs(project_root / "backend") write_file( - project_root / "backend" / "app.py", "def main():\n print('hello world')\n" * 20 + project_root / "backend" / "app.py", + "def main():\n print('hello world')\n" * 20, ) with open(project_root / "mempalace.yaml", "w") as f: yaml.dump( @@ -59,7 +60,9 @@ def test_scan_project_respects_gitignore(): write_file(project_root / ".gitignore", "ignored.py\ngenerated/\n") write_file(project_root / "src" / "app.py", "print('hello')\n" * 20) write_file(project_root / "ignored.py", "print('ignore me')\n" * 20) - write_file(project_root / "generated" / "artifact.py", "print('artifact')\n" * 20) + write_file( + project_root / "generated" / "artifact.py", "print('artifact')\n" * 20 + ) assert scanned_files(project_root) == ["src/app.py"] finally: @@ -74,7 +77,9 @@ def test_scan_project_respects_nested_gitignore(): write_file(project_root / ".gitignore", "*.log\n") write_file(project_root / "subrepo" / ".gitignore", "tasks/\n") write_file(project_root / "subrepo" / "src" / "main.py", "print('main')\n" * 20) - write_file(project_root / "subrepo" / "tasks" / "task.py", "print('task')\n" * 20) + write_file( + project_root / "subrepo" / "tasks" / "task.py", "print('task')\n" * 20 + ) write_file(project_root / "subrepo" / "debug.log", "debug\n" * 20) assert scanned_files(project_root) == ["subrepo/src/main.py"] @@ -133,7 +138,9 @@ def test_scan_project_can_disable_gitignore(): write_file(project_root / ".gitignore", "data/\n") write_file(project_root / "data" / "stuff.csv", "a,b,c\n" * 20) - assert scanned_files(project_root, respect_gitignore=False) == ["data/stuff.csv"] + assert scanned_files(project_root, respect_gitignore=False) == [ + "data/stuff.csv" + ] finally: shutil.rmtree(tmpdir) @@ -146,7 +153,9 @@ def test_scan_project_can_include_ignored_directory(): write_file(project_root / ".gitignore", "docs/\n") write_file(project_root / "docs" / "guide.md", "# Guide\n" * 20) - assert scanned_files(project_root, include_ignored=["docs"]) == ["docs/guide.md"] + assert scanned_files(project_root, include_ignored=["docs"]) == [ + "docs/guide.md" + ] finally: shutil.rmtree(tmpdir) @@ -215,7 +224,9 @@ def test_file_already_mined_check_mtime(): palace_path = os.path.join(tmpdir, "palace") os.makedirs(palace_path) client = chromadb.PersistentClient(path=palace_path) - col = client.get_or_create_collection("mempalace_drawers") + col = client.get_or_create_collection( + "mempalace_drawers", metadata={"hnsw:space": "cosine"} + ) test_file = os.path.join(tmpdir, "test.txt") with open(test_file, "w") as f: @@ -269,7 +280,9 @@ def test_mine_dry_run_with_tiny_file_no_crash(): project_root = Path(tmpdir).resolve() # One normal file and one that falls below MIN_CHUNK_SIZE - write_file(project_root / "good.py", "def main():\n print('hello world')\n" * 20) + write_file( + project_root / "good.py", "def main():\n print('hello world')\n" * 20 + ) write_file(project_root / "tiny.txt", "x") with open(project_root / "mempalace.yaml", "w") as f: