diff --git a/mempalace/cli.py b/mempalace/cli.py index 27c81d3..9a1e8e4 100644 --- a/mempalace/cli.py +++ b/mempalace/cli.py @@ -760,7 +760,7 @@ def cmd_repair(args): if getattr(e, "live_replaced", False): print(" Live collection was already replaced; restoring from backup...") try: - _close_chroma_handles(palace_path) + _close_chroma_handles(palace_path, backend=backend) if os.path.exists(palace_path): shutil.rmtree(palace_path) shutil.copytree(backup_path, palace_path) diff --git a/mempalace/repair.py b/mempalace/repair.py index 49d6abe..0585405 100644 --- a/mempalace/repair.py +++ b/mempalace/repair.py @@ -517,7 +517,7 @@ def rebuild_index(palace_path=None, confirm_truncation_ok: bool = False): if e.live_replaced and os.path.exists(backup_path): print(f" Restoring from backup: {backup_path}") try: - _close_chroma_handles(palace_path) + _close_chroma_handles(palace_path, backend=backend) _delete_collection_if_exists(backend, palace_path, COLLECTION_NAME) shutil.copy2(backup_path, sqlite_path) print(" Backup restored. Palace is back to pre-repair state.") @@ -593,12 +593,18 @@ def status(palace_path=None) -> dict: # --------------------------------------------------------------------------- -def _close_chroma_handles(palace_path: str) -> None: - """Drop ChromaBackend + chromadb singleton caches so OS mmap handles release.""" +def _close_chroma_handles(palace_path: str, backend: ChromaBackend | None = None) -> None: + """Drop ChromaBackend + chromadb singleton caches so OS mmap handles release. + + When ``backend`` is provided, close the live instance so rollback/restore + releases the handles it was already using. Otherwise fall back to a + transient backend instance for the max-seq-id repair path. + """ import gc try: - ChromaBackend().close_palace(palace_path) + closer = backend if backend is not None else ChromaBackend() + closer.close_palace(palace_path) except Exception: pass try: diff --git a/tests/test_cli.py b/tests/test_cli.py index 11845fe..6572f1d 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -799,17 +799,14 @@ def test_cmd_repair_restores_backup_on_live_rebuild_failure(mock_config_cls, tmp mock_temp_col.count.return_value = 2 mock_backend = _mock_backend_for(col=mock_col) mock_backend.create_collection.side_effect = [mock_temp_col, RuntimeError("live build failed")] - with ( - patch("mempalace.backends.chroma.ChromaBackend", return_value=mock_backend), - patch("mempalace.repair._close_chroma_handles") as mock_close_handles, - ): + with patch("mempalace.backends.chroma.ChromaBackend", return_value=mock_backend): with pytest.raises(SystemExit) as excinfo: cmd_repair(args) out = capsys.readouterr().out assert excinfo.value.code == 1 assert "Repair failed" in out assert "restoring from backup" in out - mock_close_handles.assert_called_once_with(str(palace_dir)) + mock_backend.close_palace.assert_called_once_with(str(palace_dir)) assert mock_backend.delete_collection.call_args_list == [ call(str(palace_dir), "mempalace_drawers__repair_tmp"), call(str(palace_dir), "mempalace_drawers"), diff --git a/tests/test_repair.py b/tests/test_repair.py index 33daad9..9cd12dd 100644 --- a/tests/test_repair.py +++ b/tests/test_repair.py @@ -499,20 +499,25 @@ def test_rebuild_index_live_failure_restores_backup(mock_backend_cls, mock_shuti mock_temp_col.count.return_value = 2 mock_new_col = MagicMock() mock_new_col.upsert.side_effect = RuntimeError("live upsert failed") - mock_backend = _install_mock_backend(mock_backend_cls, mock_col) - mock_backend.create_collection.side_effect = [mock_temp_col, mock_new_col] + active_backend = MagicMock() + active_backend.get_collection.return_value = mock_col + active_backend.create_collection.side_effect = [mock_temp_col, mock_new_col] + helper_backend = MagicMock() + mock_backend_cls.side_effect = [active_backend, helper_backend] with pytest.raises(repair.RebuildCollectionError) as excinfo: repair.rebuild_index(palace_path=str(tmp_path)) assert excinfo.value.live_replaced is True assert mock_shutil.copy2.call_count == 2 - assert mock_backend.delete_collection.call_args_list == [ + assert active_backend.delete_collection.call_args_list == [ call(str(tmp_path), "mempalace_drawers__repair_tmp"), call(str(tmp_path), "mempalace_drawers"), call(str(tmp_path), "mempalace_drawers__repair_tmp"), call(str(tmp_path), "mempalace_drawers"), ] + active_backend.close_palace.assert_called_once_with(str(tmp_path)) + helper_backend.close_palace.assert_not_called() @patch("mempalace.repair.shutil")