diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS new file mode 100644 index 0000000..b112254 --- /dev/null +++ b/.github/CODEOWNERS @@ -0,0 +1,13 @@ +# Default owners for everything +* @milla-jovovich @bensig @igorls + +# Core library +mempalace/ @milla-jovovich @bensig + +# CI and workflows +.github/ @bensig + +# Plugins and integrations +.claude-plugin/ @bensig +.codex-plugin/ @bensig +integrations/ @bensig diff --git a/.github/dependabot.yml b/.github/dependabot.yml new file mode 100644 index 0000000..220218c --- /dev/null +++ b/.github/dependabot.yml @@ -0,0 +1,12 @@ +version: 2 +updates: + - package-ecosystem: "pip" + directory: "/" + schedule: + interval: "weekly" + open-pull-requests-limit: 5 + - package-ecosystem: "github-actions" + directory: "/" + schedule: + interval: "weekly" + open-pull-requests-limit: 3 diff --git a/.github/workflows/bump-plugin-version.yml b/.github/workflows/bump-plugin-version.yml.disabled similarity index 100% rename from .github/workflows/bump-plugin-version.yml rename to .github/workflows/bump-plugin-version.yml.disabled diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 302c8e9..9c96883 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -18,7 +18,7 @@ jobs: with: python-version: ${{ matrix.python-version }} - run: pip install -e ".[dev]" - - run: python -m pytest tests/ -v --ignore=tests/benchmarks --cov=mempalace --cov-report=term-missing --cov-fail-under=85 + - run: python -m pytest tests/ -v --ignore=tests/benchmarks --cov=mempalace --cov-report=term-missing --cov-fail-under=80 --durations=10 test-windows: runs-on: windows-latest @@ -28,7 +28,7 @@ jobs: with: python-version: "3.9" - run: pip install -e ".[dev]" - - run: python -m pytest tests/ -v --ignore=tests/benchmarks --cov=mempalace --cov-report=term-missing --cov-fail-under=85 + - run: python -m pytest tests/ -v --ignore=tests/benchmarks --cov=mempalace --cov-report=term-missing --cov-fail-under=80 --durations=10 test-macos: runs-on: macos-latest @@ -38,7 +38,7 @@ jobs: with: python-version: "3.9" - run: pip install -e ".[dev]" - - run: python -m pytest tests/ -v --ignore=tests/benchmarks --cov=mempalace --cov-report=term-missing --cov-fail-under=85 + - run: python -m pytest tests/ -v --ignore=tests/benchmarks --cov=mempalace --cov-report=term-missing --cov-fail-under=80 --durations=10 lint: runs-on: ubuntu-latest steps: diff --git a/.gitignore b/.gitignore index c8b10cc..1f3b03e 100644 --- a/.gitignore +++ b/.gitignore @@ -6,3 +6,30 @@ __pycache__/ .pytest_cache/ mempal.yaml .a5c/ + +# Environment +.env +.env.* + +# OS +.DS_Store +Thumbs.db + +# IDEs +.idea/ +.vscode/ +*.swp +*.swo +*~ + +# Coverage +htmlcov/ +.coverage +coverage.xml + +# Virtual environments +.venv/ +venv/ + +# ChromaDB local data +*.sqlite3-journal diff --git a/AGENTS.md b/AGENTS.md new file mode 100644 index 0000000..3026013 --- /dev/null +++ b/AGENTS.md @@ -0,0 +1,78 @@ +# AGENTS.md + +> How to build, test, and contribute to MemPalace. + +## Setup + +```bash +pip install -e ".[dev]" +``` + +## Commands + +```bash +# Run tests +python -m pytest tests/ -v --ignore=tests/benchmarks + +# Run tests with coverage +python -m pytest tests/ -v --ignore=tests/benchmarks --cov=mempalace --cov-report=term-missing + +# Lint +ruff check . + +# Format +ruff format . + +# Format check (CI mode) +ruff format --check . +``` + +## Project structure + +``` +mempalace/ +├── mcp_server.py # MCP server — all read/write tools +├── miner.py # Project file miner +├── convo_miner.py # Conversation transcript miner +├── searcher.py # Semantic search +├── knowledge_graph.py # Temporal entity-relationship graph (SQLite) +├── palace.py # Shared palace operations (ChromaDB access) +├── config.py # Configuration + input validation +├── normalize.py # Transcript format detection + normalization +├── cli.py # CLI dispatcher +├── dialect.py # AAAK compression dialect +├── palace_graph.py # Room traversal + cross-wing tunnels +├── hooks_cli.py # Hook system for auto-save +└── version.py # Single source of truth for version +``` + +## Conventions + +- **Python style**: snake_case for functions/variables, PascalCase for classes +- **Linter**: ruff with E/F/W rules +- **Formatter**: ruff format, double quotes +- **Commits**: conventional commits (`fix:`, `feat:`, `test:`, `docs:`, `ci:`) +- **Tests**: `tests/test_*.py`, fixtures in `tests/conftest.py` +- **Coverage**: 85% threshold (80% on Windows due to ChromaDB file lock cleanup) + +## Architecture + +``` +User → CLI / MCP Server → ChromaDB (vector store) + SQLite (knowledge graph) + +Palace structure: + WING (person/project) + └── ROOM (topic) + └── DRAWER (verbatim text chunk) + +Knowledge Graph: + ENTITY → PREDICATE → ENTITY (with valid_from / valid_to dates) +``` + +## Key files for common tasks + +- **Adding an MCP tool**: `mempalace/mcp_server.py` — add handler function + TOOLS dict entry +- **Changing search**: `mempalace/searcher.py` +- **Modifying mining**: `mempalace/miner.py` (project files) or `mempalace/convo_miner.py` (transcripts) +- **Input validation**: `mempalace/config.py` — `sanitize_name()` / `sanitize_content()` +- **Tests**: mirror source structure in `tests/test_.py` diff --git a/README.md b/README.md index a1a7ccb..1ef11e1 100644 --- a/README.md +++ b/README.md @@ -84,6 +84,18 @@ Other memory systems try to fix this by letting AI decide what's worth rememberi --- +## An important follow up note regarding fake MemPalace websites - April 11, 2026 + +Several Community Members (#267, #326, #506) have pointed out there are fake MemPalace websites popping up, including ones with Malware. + +To be super clear, MemPalace *has no website* (at least for now), so anything claiming to be one is false. + +Thanks to our Community Members for letting us know about the problem. + +Stay safe out there. + +--- + ## Quick Start ```bash @@ -585,6 +597,9 @@ mempalace compress --wing myapp # AAAK compress # Status mempalace status # palace overview + +# MCP +mempalace mcp # show MCP setup command ``` All commands accept `--palace ` to override the default location. @@ -707,7 +722,7 @@ PRs welcome. See [CONTRIBUTING.md](CONTRIBUTING.md) for setup and guidelines. MIT — see [LICENSE](LICENSE). -[version-shield]: https://img.shields.io/badge/version-3.0.0-4dc9f6?style=flat-square&labelColor=0a0e14 +[version-shield]: https://img.shields.io/badge/version-3.1.0-4dc9f6?style=flat-square&labelColor=0a0e14 [release-link]: https://github.com/milla-jovovich/mempalace/releases [python-shield]: https://img.shields.io/badge/python-3.9+-7dd8f8?style=flat-square&labelColor=0a0e14&logo=python&logoColor=7dd8f8 [python-link]: https://www.python.org/ diff --git a/docs/schema.sql b/docs/schema.sql new file mode 100644 index 0000000..740db70 --- /dev/null +++ b/docs/schema.sql @@ -0,0 +1,36 @@ +-- MemPalace Knowledge Graph Schema +-- SQLite database at ~/.mempalace/knowledge_graph.db + +CREATE TABLE IF NOT EXISTS entities ( + id TEXT PRIMARY KEY, + name TEXT NOT NULL, + type TEXT DEFAULT 'unknown', + properties TEXT DEFAULT '{}' +); + +CREATE TABLE IF NOT EXISTS triples ( + id TEXT PRIMARY KEY, + subject TEXT NOT NULL, + predicate TEXT NOT NULL, + object TEXT NOT NULL, + valid_from TEXT, + valid_to TEXT, + confidence REAL DEFAULT 1.0, + source_closet TEXT, + source_file TEXT +); + +CREATE TABLE IF NOT EXISTS attributes ( + entity_id TEXT NOT NULL, + key TEXT NOT NULL, + value TEXT, + valid_from TEXT, + valid_to TEXT, + PRIMARY KEY (entity_id, key, valid_from) +); + +-- Indexes +CREATE INDEX IF NOT EXISTS idx_triples_subject ON triples(subject); +CREATE INDEX IF NOT EXISTS idx_triples_object ON triples(object); +CREATE INDEX IF NOT EXISTS idx_triples_predicate ON triples(predicate); +CREATE INDEX IF NOT EXISTS idx_triples_valid ON triples(valid_from, valid_to); diff --git a/hooks/mempal_save_hook.sh b/hooks/mempal_save_hook.sh index 75abfc8..a0e4681 100755 --- a/hooks/mempal_save_hook.sh +++ b/hooks/mempal_save_hook.sh @@ -64,13 +64,20 @@ MEMPAL_DIR="" # Read JSON input from stdin INPUT=$(cat) -# Parse fields from Claude Code's JSON -SESSION_ID=$(echo "$INPUT" | python3 -c "import sys,json; print(json.load(sys.stdin).get('session_id','unknown'))" 2>/dev/null) -# Sanitize SESSION_ID to prevent path traversal (only allow alnum, dash, underscore) -SESSION_ID=$(echo "$SESSION_ID" | tr -cd 'a-zA-Z0-9_-') -[ -z "$SESSION_ID" ] && SESSION_ID="unknown" -STOP_HOOK_ACTIVE=$(echo "$INPUT" | python3 -c "import sys,json; print(json.load(sys.stdin).get('stop_hook_active', False))" 2>/dev/null) -TRANSCRIPT_PATH=$(echo "$INPUT" | python3 -c "import sys,json; print(json.load(sys.stdin).get('transcript_path',''))" 2>/dev/null) +# Parse all fields in a single Python call (3x faster than separate invocations) +eval $(echo "$INPUT" | python3 -c " +import sys, json +data = json.load(sys.stdin) +sid = data.get('session_id', 'unknown') +sha = data.get('stop_hook_active', False) +tp = data.get('transcript_path', '') +# Shell-safe output — only allow alphanumeric, underscore, hyphen, slash, dot, tilde +import re +safe = lambda s: re.sub(r'[^a-zA-Z0-9_/.\-~]', '', str(s)) +print(f'SESSION_ID=\"{safe(sid)}\"') +print(f'STOP_HOOK_ACTIVE=\"{sha}\"') +print(f'TRANSCRIPT_PATH=\"{safe(tp)}\"') +" 2>/dev/null) # Expand ~ in path TRANSCRIPT_PATH="${TRANSCRIPT_PATH/#\~/$HOME}" @@ -83,6 +90,7 @@ if [ "$STOP_HOOK_ACTIVE" = "True" ] || [ "$STOP_HOOK_ACTIVE" = "true" ]; then fi # Count human messages in the JSONL transcript +# SECURITY: Pass transcript path as sys.argv to avoid shell injection via crafted paths if [ -f "$TRANSCRIPT_PATH" ]; then EXCHANGE_COUNT=$(python3 - "$TRANSCRIPT_PATH" <<'PYEOF' import json, sys @@ -94,7 +102,6 @@ with open(sys.argv[1]) as f: msg = entry.get('message', {}) if isinstance(msg, dict) and msg.get('role') == 'user': content = msg.get('content', '') - # Skip system/command messages — only count real human input if isinstance(content, str) and '' in content: continue count += 1 diff --git a/integrations/openclaw/SKILL.md b/integrations/openclaw/SKILL.md new file mode 100644 index 0000000..88f0b2f --- /dev/null +++ b/integrations/openclaw/SKILL.md @@ -0,0 +1,154 @@ +--- +name: mempalace +description: "MemPalace — Local AI memory with 96.6% recall. Semantic search, temporal knowledge graph, palace architecture (wings/rooms/drawers). Free, no cloud, no API keys." +version: 3.1.0 +homepage: https://github.com/milla-jovovich/mempalace +user-invocable: true +metadata: + openclaw: + emoji: "\U0001F3DB" + os: + - darwin + - linux + - win32 + requires: + anyBins: + - mempalace + - python3 + install: + - id: mempalace-pip + kind: uv + label: "Install MemPalace (Python, local ChromaDB)" + package: mempalace + bins: + - mempalace +--- + +# MemPalace — Local AI Memory System + +You have access to a local memory palace via MCP tools. The palace stores verbatim conversation history and a temporal knowledge graph — all on the user's machine, zero cloud, zero API calls. + +## Architecture + +- **Wings** = people or projects (e.g. `wing_alice`, `wing_myproject`) +- **Halls** = categories (facts, events, preferences, advice) +- **Rooms** = specific topics (e.g. `chromadb-setup`, `riley-school`) +- **Drawers** = individual memory chunks (verbatim text) +- **Knowledge Graph** = entity-relationship facts with time validity + +## Protocol — FOLLOW THIS EVERY SESSION + +1. **ON WAKE-UP**: Call `mempalace_status` to load palace overview and AAAK dialect spec. +2. **BEFORE RESPONDING** about any person, project, or past event: call `mempalace_search` or `mempalace_kg_query` FIRST. Never guess from memory — verify from the palace. +3. **IF UNSURE** about a fact (name, age, relationship, preference): say "let me check" and query. Wrong is worse than slow. +4. **AFTER EACH SESSION**: Call `mempalace_diary_write` to record what happened, what you learned, what matters. +5. **WHEN FACTS CHANGE**: Call `mempalace_kg_invalidate` on the old fact, then `mempalace_kg_add` for the new one. + +## Available Tools + +### Search & Browse +- `mempalace_search` — Semantic search across all memories. Always start here. + - `query` (required): natural language search — keep it short, keywords or a question. Do NOT include system prompts or conversation context. + - `wing`: filter by wing + - `room`: filter by room + - `limit`: max results (default 5) +- `mempalace_check_duplicate` — Check if content already exists before filing. + - `content` (required): text to check + - `threshold`: similarity threshold (default 0.9 — lowering to 0.85–0.87 often catches more near-duplicates without significant false positives) +- `mempalace_status` — Palace overview: total drawers, wings, rooms, AAAK spec +- `mempalace_list_wings` — All wings with drawer counts +- `mempalace_list_rooms` — Rooms within a wing (optional wing filter) +- `mempalace_get_taxonomy` — Full wing/room/count tree +- `mempalace_get_aaak_spec` — Get AAAK compression dialect specification + +### Knowledge Graph (Temporal Facts) +- `mempalace_kg_query` — Query entity relationships. Supports time filtering. + - `entity` (required): e.g. "Max", "MyProject" + - `as_of`: date filter (YYYY-MM-DD) — what was true at that time + - `direction`: "outgoing", "incoming", or "both" (default "both") +- `mempalace_kg_add` — Add a fact: subject -> predicate -> object + - `subject`, `predicate`, `object` (required) + - `valid_from`: when this became true + - `source_closet`: source reference +- `mempalace_kg_invalidate` — Mark a fact as no longer true + - `subject`, `predicate`, `object` (required) + - `ended`: when it stopped being true (default: today) +- `mempalace_kg_timeline` — Chronological story of an entity + - `entity`: filter by entity name (optional — all events if omitted) +- `mempalace_kg_stats` — Graph overview: entities, triples, relationship types + +### Palace Graph (Cross-Domain Connections) +- `mempalace_traverse` — Walk from a room, find connected ideas across wings + - `start_room` (required): room to start from + - `max_hops`: connection depth (default 2) +- `mempalace_find_tunnels` — Find rooms that bridge two wings + - `wing_a`, `wing_b` (required) +- `mempalace_graph_stats` — Graph connectivity overview + +### Write +- `mempalace_add_drawer` — Store verbatim content into a wing/room + - `wing`, `room`, `content` (required) + - `source_file`: optional source reference + - Checks for duplicates automatically +- `mempalace_delete_drawer` — Remove a drawer by ID + - `drawer_id` (required) +- `mempalace_diary_write` — Write a session diary entry + - `agent_name` (required): your name/identifier + - `entry` (required): what happened, what you learned, what matters + - `topic`: category tag (default "general") +- `mempalace_diary_read` — Read recent diary entries + - `agent_name` (required) + - `last_n`: number of entries (default 10) + +## Setup + +Install MemPalace and populate the palace: + +```bash +pip install mempalace +mempalace init ~/my-convos +mempalace mine ~/my-convos +``` + +### OpenClaw MCP config + +Add to your OpenClaw MCP configuration: + +```json +{ + "mcpServers": { + "mempalace": { + "command": "python3", + "args": ["-m", "mempalace.mcp_server"] + } + } +} +``` + +Or via CLI: + +```bash +openclaw mcp set mempalace '{"command":"python3","args":["-m","mempalace.mcp_server"]}' +``` + +### Other MCP hosts + +```bash +# Claude Code +claude mcp add mempalace -- python -m mempalace.mcp_server + +# Cursor — add to .cursor/mcp.json +# Codex — add to .codex/mcp.json +``` + +## Tips + +- Search is semantic (meaning-based), not keyword. "What did we discuss about database performance?" works better than "database". +- The knowledge graph stores typed relationships with time windows. Use it for facts about people and projects — it knows WHEN things were true. +- Diary entries accumulate across sessions. Write one at the end of each conversation to build continuity. +- Use `mempalace_check_duplicate` before storing new content to avoid duplicates. +- The AAAK dialect (from `mempalace_status`) is a compressed notation for efficient storage. Read it naturally — expand codes mentally, treat *markers* as emotional context. + +## License + +[MemPalace](https://github.com/milla-jovovich/mempalace) is MIT licensed. Created by Milla Jovovich, Ben Sigman, Igor Lins e Silva, and contributors. diff --git a/mempalace/cli.py b/mempalace/cli.py index 0a24abf..1d106ca 100644 --- a/mempalace/cli.py +++ b/mempalace/cli.py @@ -14,6 +14,7 @@ Commands: mempalace mine Mine project files (default) mempalace mine --mode convos Mine conversation exports mempalace search "query" Find anything, exact words + mempalace mcp Show MCP setup command mempalace wake-up Show L0 + L1 wake-up context mempalace wake-up --wing my_app Wake-up for a specific project mempalace status Show what's been filed @@ -28,6 +29,7 @@ Examples: import os import sys +import shlex import argparse from pathlib import Path @@ -148,6 +150,14 @@ def cmd_split(args): sys.argv = old_argv +def cmd_migrate(args): + """Migrate palace from a different ChromaDB version.""" + from .migrate import migrate + + palace_path = os.path.expanduser(args.palace) if args.palace else MempalaceConfig().palace_path + migrate(palace_path=palace_path, dry_run=args.dry_run) + + def cmd_status(args): from .miner import status @@ -202,6 +212,7 @@ def cmd_repair(args): print(f" Extracted {len(all_ids)} drawers") # Backup and rebuild + palace_path = palace_path.rstrip(os.sep) backup_path = palace_path + ".backup" if os.path.exists(backup_path): shutil.rmtree(backup_path) @@ -240,6 +251,27 @@ def cmd_instructions(args): run_instructions(name=args.name) +def cmd_mcp(args): + """Show how to wire MemPalace into MCP-capable hosts.""" + base_server_cmd = "python -m mempalace.mcp_server" + + if args.palace: + resolved_palace = str(Path(args.palace).expanduser()) + server_cmd = f"{base_server_cmd} --palace {shlex.quote(resolved_palace)}" + else: + server_cmd = base_server_cmd + + print("MemPalace MCP quick setup:") + print(f" claude mcp add mempalace -- {server_cmd}") + print("\nRun the server directly:") + print(f" {server_cmd}") + + if not args.palace: + print("\nOptional custom palace:") + print(f" claude mcp add mempalace -- {base_server_cmd} --palace /path/to/palace") + print(f" {base_server_cmd} --palace /path/to/palace") + + def cmd_compress(args): """Compress drawers in a wing using AAAK Dialect.""" import chromadb @@ -500,7 +532,24 @@ def main(): help="Rebuild palace vector index from stored data (fixes segfaults after corruption)", ) + # mcp + sub.add_parser( + "mcp", + help="Show MCP setup command for connecting MemPalace to your AI client", + ) + # status + # migrate + p_migrate = sub.add_parser( + "migrate", + help="Migrate palace from a different ChromaDB version (fixes 3.0.0 → 3.1.0 upgrade)", + ) + p_migrate.add_argument( + "--dry-run", + action="store_true", + help="Show what would be migrated without changing anything", + ) + sub.add_parser("status", help="Show what's been filed") args = parser.parse_args() @@ -531,9 +580,11 @@ def main(): "mine": cmd_mine, "split": cmd_split, "search": cmd_search, + "mcp": cmd_mcp, "compress": cmd_compress, "wake-up": cmd_wakeup, "repair": cmd_repair, + "migrate": cmd_migrate, "status": cmd_status, } dispatch[args.command](args) diff --git a/mempalace/config.py b/mempalace/config.py index 5a73650..fcfb2c8 100644 --- a/mempalace/config.py +++ b/mempalace/config.py @@ -6,8 +6,58 @@ Priority: env vars > config file (~/.mempalace/config.json) > defaults import json import os +import re from pathlib import Path + +# ── Input validation ────────────────────────────────────────────────────────── +# Shared sanitizers for wing/room/entity names. Prevents path traversal, +# excessively long strings, and special characters that could cause issues +# in file paths, SQLite, or ChromaDB metadata. + +MAX_NAME_LENGTH = 128 +_SAFE_NAME_RE = re.compile(r"^[a-zA-Z0-9][a-zA-Z0-9_ .'-]{0,126}[a-zA-Z0-9]?$") + + +def sanitize_name(value: str, field_name: str = "name") -> str: + """Validate and sanitize a wing/room/entity name. + + Raises ValueError if the name is invalid. + """ + if not isinstance(value, str) or not value.strip(): + raise ValueError(f"{field_name} must be a non-empty string") + + value = value.strip() + + if len(value) > MAX_NAME_LENGTH: + raise ValueError(f"{field_name} exceeds maximum length of {MAX_NAME_LENGTH} characters") + + # Block path traversal + if ".." in value or "/" in value or "\\" in value: + raise ValueError(f"{field_name} contains invalid path characters") + + # Block null bytes + if "\x00" in value: + raise ValueError(f"{field_name} contains null bytes") + + # Enforce safe character set + if not _SAFE_NAME_RE.match(value): + raise ValueError(f"{field_name} contains invalid characters") + + return value + + +def sanitize_content(value: str, max_length: int = 100_000) -> str: + """Validate drawer/diary content length.""" + if not isinstance(value, str) or not value.strip(): + raise ValueError("content must be a non-empty string") + if len(value) > max_length: + raise ValueError(f"content exceeds maximum length of {max_length} characters") + if "\x00" in value: + raise ValueError("content contains null bytes") + return value + + DEFAULT_PALACE_PATH = os.path.expanduser("~/.mempalace/palace") DEFAULT_COLLECTION_NAME = "mempalace_drawers" @@ -126,6 +176,11 @@ class MempalaceConfig: def init(self): """Create config directory and write default config.json if it doesn't exist.""" self._config_dir.mkdir(parents=True, exist_ok=True) + # Restrict directory permissions to owner only (Unix) + try: + self._config_dir.chmod(0o700) + except (OSError, NotImplementedError): + pass # Windows doesn't support Unix permissions if not self._config_file.exists(): default_config = { "palace_path": DEFAULT_PALACE_PATH, @@ -135,6 +190,11 @@ class MempalaceConfig: } with open(self._config_file, "w") as f: json.dump(default_config, f, indent=2) + # Restrict config file to owner read/write only + try: + self._config_file.chmod(0o600) + except (OSError, NotImplementedError): + pass return self._config_file def save_people_map(self, people_map): diff --git a/mempalace/convo_miner.py b/mempalace/convo_miner.py index c316407..3bb4a89 100644 --- a/mempalace/convo_miner.py +++ b/mempalace/convo_miner.py @@ -15,9 +15,8 @@ from pathlib import Path from datetime import datetime from collections import defaultdict -import chromadb - from .normalize import normalize +from .palace import SKIP_DIRS, get_collection, file_already_mined # File types that might contain conversations @@ -28,22 +27,8 @@ CONVO_EXTENSIONS = { ".jsonl", } -SKIP_DIRS = { - ".git", - "node_modules", - "__pycache__", - ".venv", - "venv", - "env", - "dist", - "build", - ".next", - ".mempalace", - "tool-results", - "memory", -} - MIN_CHUNK_SIZE = 30 +MAX_FILE_SIZE = 10 * 1024 * 1024 # 10 MB — skip files larger than this # ============================================================================= @@ -211,23 +196,6 @@ def detect_convo_room(content: str) -> str: # ============================================================================= -def get_collection(palace_path: str): - os.makedirs(palace_path, exist_ok=True) - client = chromadb.PersistentClient(path=palace_path) - try: - return client.get_collection("mempalace_drawers") - except Exception: - return client.create_collection("mempalace_drawers") - - -def file_already_mined(collection, source_file: str) -> bool: - try: - results = collection.get(where={"source_file": source_file}, limit=1) - return len(results.get("ids", [])) > 0 - except Exception: - return False - - # ============================================================================= # SCAN FOR CONVERSATION FILES # ============================================================================= @@ -244,6 +212,14 @@ def scan_convos(convo_dir: str) -> list: continue filepath = Path(root) / filename if filepath.suffix.lower() in CONVO_EXTENSIONS: + # Skip symlinks and oversized files + if filepath.is_symlink(): + continue + try: + if filepath.stat().st_size > MAX_FILE_SIZE: + continue + except OSError: + continue files.append(filepath) return files @@ -356,9 +332,9 @@ def mine_convos( chunk_room = chunk.get("memory_type", room) if extract_mode == "general" else room if extract_mode == "general": room_counts[chunk_room] += 1 - drawer_id = f"drawer_{wing}_{chunk_room}_{hashlib.md5((source_file + str(chunk['chunk_index'])).encode(), usedforsecurity=False).hexdigest()[:16]}" + drawer_id = f"drawer_{wing}_{chunk_room}_{hashlib.sha256((source_file + str(chunk['chunk_index'])).encode()).hexdigest()[:24]}" try: - collection.add( + collection.upsert( documents=[chunk["content"]], ids=[drawer_id], metadatas=[ diff --git a/mempalace/dedup.py b/mempalace/dedup.py new file mode 100644 index 0000000..c2f9f6b --- /dev/null +++ b/mempalace/dedup.py @@ -0,0 +1,239 @@ +""" +dedup.py — Detect and remove near-duplicate drawers +==================================================== + +When the same files are mined multiple times, near-identical drawers +accumulate. This module finds drawers from the same source_file that +are too similar (cosine distance < threshold), keeps the longest/richest +version, and deletes the rest. + +No API calls — uses ChromaDB's built-in embedding similarity. + +Usage (standalone): + python -m mempalace.dedup # dedup all + python -m mempalace.dedup --dry-run # preview only + python -m mempalace.dedup --threshold 0.10 # stricter (near-identical only) + python -m mempalace.dedup --threshold 0.35 # looser (catches paraphrased content) + python -m mempalace.dedup --wing my_project # scope to one wing + python -m mempalace.dedup --stats # stats only + python -m mempalace.dedup --source "my_project" # filter by source + +Usage (from CLI): + mempalace dedup [--dry-run] [--threshold 0.15] [--stats] +""" + +import argparse +import os +import time +from collections import defaultdict + +import chromadb + + +COLLECTION_NAME = "mempalace_drawers" +# Cosine DISTANCE threshold (not similarity). Lower = stricter. +# 0.15 = ~85% cosine similarity — catches near-identical chunks. +# For looser dedup of paraphrased content, try 0.3–0.4. +DEFAULT_THRESHOLD = 0.15 +MIN_DRAWERS_TO_CHECK = 5 + + +def _get_palace_path(): + """Resolve palace path from config.""" + try: + from .config import MempalaceConfig + + return MempalaceConfig().palace_path + except Exception: + return os.path.join(os.path.expanduser("~"), ".mempalace", "palace") + + +def get_source_groups(col, min_count=MIN_DRAWERS_TO_CHECK, source_pattern=None, wing=None): + """Group drawers by source_file, return groups with min_count+ entries. + + If wing is specified, only considers drawers in that wing. This catches + cross-wing duplicates when the same source was mined into multiple wings. + """ + total = col.count() + groups = defaultdict(list) + + offset = 0 + batch_size = 1000 + while offset < total: + kwargs = {"limit": batch_size, "offset": offset, "include": ["metadatas"]} + if wing: + kwargs["where"] = {"wing": wing} + batch = col.get(**kwargs) + if not batch["ids"]: + break + for did, meta in zip(batch["ids"], batch["metadatas"]): + src = meta.get("source_file", "unknown") + if source_pattern and source_pattern.lower() not in src.lower(): + continue + groups[src].append(did) + offset += len(batch["ids"]) + + return {src: ids for src, ids in groups.items() if len(ids) >= min_count} + + +def dedup_source_group(col, drawer_ids, threshold=DEFAULT_THRESHOLD, dry_run=True): + """Dedup drawers within one source_file group. + + Greedy: sort by doc length (longest first), keep if not too similar + to any already-kept drawer. Returns (kept_ids, deleted_ids). + """ + data = col.get(ids=drawer_ids, include=["documents", "metadatas"]) + items = list(zip(data["ids"], data["documents"], data["metadatas"])) + items.sort(key=lambda x: len(x[1] or ""), reverse=True) + + kept = [] + to_delete = [] + + for did, doc, meta in items: + if not doc or len(doc) < 20: + to_delete.append(did) + continue + + if not kept: + kept.append((did, doc)) + continue + + try: + results = col.query( + query_texts=[doc], + n_results=min(len(kept), 5), + include=["distances"], + ) + dists = results["distances"][0] if results["distances"] else [] + kept_ids_set = {k[0] for k in kept} + + is_dup = False + for rid, dist in zip(results["ids"][0], dists): + if rid in kept_ids_set and dist < threshold: + is_dup = True + break + + if is_dup: + to_delete.append(did) + else: + kept.append((did, doc)) + except Exception: + kept.append((did, doc)) + + if to_delete and not dry_run: + for i in range(0, len(to_delete), 500): + col.delete(ids=to_delete[i : i + 500]) + + return [k[0] for k in kept], to_delete + + +def show_stats(palace_path=None): + """Show duplication statistics without making changes.""" + palace_path = palace_path or _get_palace_path() + client = chromadb.PersistentClient(path=palace_path) + col = client.get_collection(COLLECTION_NAME) + + groups = get_source_groups(col) + + total_drawers = sum(len(ids) for ids in groups.values()) + print(f"\n Sources with {MIN_DRAWERS_TO_CHECK}+ drawers: {len(groups)}") + print(f" Total drawers in those sources: {total_drawers:,}") + + print("\n Top 15 by drawer count:") + sorted_groups = sorted(groups.items(), key=lambda x: len(x[1]), reverse=True) + for src, ids in sorted_groups[:15]: + print(f" {len(ids):4d} {src[:65]}") + + estimated_dups = sum(int(len(ids) * 0.4) for ids in groups.values() if len(ids) > 20) + print(f"\n Estimated duplicates (groups > 20): ~{estimated_dups:,}") + + +def dedup_palace( + palace_path=None, + threshold=DEFAULT_THRESHOLD, + dry_run=True, + source_pattern=None, + min_count=MIN_DRAWERS_TO_CHECK, + wing=None, +): + """Main entry point: deduplicate near-identical drawers across the palace.""" + palace_path = palace_path or _get_palace_path() + + print(f"\n{'=' * 55}") + print(" MemPalace Deduplicator") + print(f"{'=' * 55}") + + client = chromadb.PersistentClient(path=palace_path) + col = client.get_collection(COLLECTION_NAME) + + print(f" Palace: {palace_path}") + print(f" Drawers: {col.count():,}") + print(f" Threshold: {threshold}") + print(f" Mode: {'DRY RUN' if dry_run else 'LIVE'}") + print(f"{'─' * 55}") + + if wing: + print(f" Wing: {wing}") + groups = get_source_groups(col, min_count, source_pattern, wing=wing) + print(f"\n Sources to check: {len(groups)}") + + t0 = time.time() + total_kept = 0 + total_deleted = 0 + + sorted_groups = sorted(groups.items(), key=lambda x: len(x[1]), reverse=True) + + for i, (src, drawer_ids) in enumerate(sorted_groups): + kept, deleted = dedup_source_group(col, drawer_ids, threshold, dry_run) + total_kept += len(kept) + total_deleted += len(deleted) + + if deleted: + print( + f" [{i + 1:3d}/{len(groups)}] " + f"{src[:50]:50s} {len(drawer_ids):4d} → {len(kept):4d} " + f"(-{len(deleted)})" + ) + + elapsed = time.time() - t0 + + print(f"\n{'─' * 55}") + print(f" Done in {elapsed:.1f}s") + print( + f" Drawers: {total_kept + total_deleted:,} → {total_kept:,} (-{total_deleted:,} removed)" + ) + print(f" Palace after: {col.count():,} drawers") + + if dry_run: + print("\n [DRY RUN] No changes written. Re-run without --dry-run to apply.") + + print(f"{'=' * 55}\n") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Deduplicate near-identical drawers") + parser.add_argument("--palace", default=None, help="Palace directory path") + parser.add_argument( + "--threshold", + type=float, + default=DEFAULT_THRESHOLD, + help=f"Cosine distance threshold (default: {DEFAULT_THRESHOLD})", + ) + parser.add_argument("--dry-run", action="store_true", help="Preview without deleting") + parser.add_argument("--stats", action="store_true", help="Show stats only") + parser.add_argument("--wing", default=None, help="Scope dedup to a single wing") + parser.add_argument("--source", default=None, help="Filter by source file pattern") + args = parser.parse_args() + + path = os.path.expanduser(args.palace) if args.palace else None + + if args.stats: + show_stats(palace_path=path) + else: + dedup_palace( + palace_path=path, + threshold=args.threshold, + dry_run=args.dry_run, + source_pattern=args.source, + wing=args.wing, + ) diff --git a/mempalace/hooks_cli.py b/mempalace/hooks_cli.py index 3f3fc09..b6d2290 100644 --- a/mempalace/hooks_cli.py +++ b/mempalace/hooks_cli.py @@ -63,6 +63,14 @@ def _count_human_messages(transcript_path: str) -> int: if "" in text: continue count += 1 + # Also handle Codex CLI transcript format + # {"type": "event_msg", "payload": {"type": "user_message", "message": "..."}} + elif entry.get("type") == "event_msg": + payload = entry.get("payload", {}) + if isinstance(payload, dict) and payload.get("type") == "user_message": + msg_text = payload.get("message", "") + if isinstance(msg_text, str) and "" not in msg_text: + count += 1 except (json.JSONDecodeError, AttributeError): pass except OSError: diff --git a/mempalace/instructions/help.md b/mempalace/instructions/help.md index f18c1de..5cb70fa 100644 --- a/mempalace/instructions/help.md +++ b/mempalace/instructions/help.md @@ -60,6 +60,7 @@ AI memory system. Store everything, find anything. Local, free, no API key. mempalace compress Compress palace storage mempalace status Show palace status mempalace repair Rebuild vector index + mempalace mcp Show MCP setup command mempalace hook run Run hook logic (for harness integration) mempalace instructions Output skill instructions diff --git a/mempalace/knowledge_graph.py b/mempalace/knowledge_graph.py index 226c92d..b094f06 100644 --- a/mempalace/knowledge_graph.py +++ b/mempalace/knowledge_graph.py @@ -50,11 +50,14 @@ class KnowledgeGraph: def __init__(self, db_path: str = None): self.db_path = db_path or DEFAULT_KG_PATH Path(self.db_path).parent.mkdir(parents=True, exist_ok=True) + self._connection = None self._init_db() def _init_db(self): conn = self._conn() conn.executescript(""" + PRAGMA journal_mode=WAL; + CREATE TABLE IF NOT EXISTS entities ( id TEXT PRIMARY KEY, name TEXT NOT NULL, @@ -84,12 +87,19 @@ class KnowledgeGraph: CREATE INDEX IF NOT EXISTS idx_triples_valid ON triples(valid_from, valid_to); """) conn.commit() - conn.close() def _conn(self): - conn = sqlite3.connect(self.db_path, timeout=10) - conn.execute("PRAGMA journal_mode=WAL") - return conn + if self._connection is None: + self._connection = sqlite3.connect(self.db_path, timeout=10, check_same_thread=False) + self._connection.execute("PRAGMA journal_mode=WAL") + self._connection.row_factory = sqlite3.Row + return self._connection + + def close(self): + """Close the database connection.""" + if self._connection is not None: + self._connection.close() + self._connection = None def _entity_id(self, name: str) -> str: return name.lower().replace(" ", "_").replace("'", "") @@ -101,12 +111,11 @@ class KnowledgeGraph: eid = self._entity_id(name) props = json.dumps(properties or {}) conn = self._conn() - conn.execute( - "INSERT OR REPLACE INTO entities (id, name, type, properties) VALUES (?, ?, ?, ?)", - (eid, name, entity_type, props), - ) - conn.commit() - conn.close() + with conn: + conn.execute( + "INSERT OR REPLACE INTO entities (id, name, type, properties) VALUES (?, ?, ?, ?)", + (eid, name, entity_type, props), + ) return eid def add_triple( @@ -134,38 +143,38 @@ class KnowledgeGraph: # Auto-create entities if they don't exist conn = self._conn() - conn.execute("INSERT OR IGNORE INTO entities (id, name) VALUES (?, ?)", (sub_id, subject)) - conn.execute("INSERT OR IGNORE INTO entities (id, name) VALUES (?, ?)", (obj_id, obj)) + with conn: + conn.execute( + "INSERT OR IGNORE INTO entities (id, name) VALUES (?, ?)", (sub_id, subject) + ) + conn.execute("INSERT OR IGNORE INTO entities (id, name) VALUES (?, ?)", (obj_id, obj)) - # Check for existing identical triple - existing = conn.execute( - "SELECT id FROM triples WHERE subject=? AND predicate=? AND object=? AND valid_to IS NULL", - (sub_id, pred, obj_id), - ).fetchone() + # Check for existing identical triple + existing = conn.execute( + "SELECT id FROM triples WHERE subject=? AND predicate=? AND object=? AND valid_to IS NULL", + (sub_id, pred, obj_id), + ).fetchone() - if existing: - conn.close() - return existing[0] # Already exists and still valid + if existing: + return existing["id"] # Already exists and still valid - triple_id = f"t_{sub_id}_{pred}_{obj_id}_{hashlib.md5(f'{valid_from}{datetime.now().isoformat()}'.encode()).hexdigest()[:8]}" + triple_id = f"t_{sub_id}_{pred}_{obj_id}_{hashlib.sha256(f'{valid_from}{datetime.now().isoformat()}'.encode()).hexdigest()[:12]}" - conn.execute( - """INSERT INTO triples (id, subject, predicate, object, valid_from, valid_to, confidence, source_closet, source_file) - VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)""", - ( - triple_id, - sub_id, - pred, - obj_id, - valid_from, - valid_to, - confidence, - source_closet, - source_file, - ), - ) - conn.commit() - conn.close() + conn.execute( + """INSERT INTO triples (id, subject, predicate, object, valid_from, valid_to, confidence, source_closet, source_file) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)""", + ( + triple_id, + sub_id, + pred, + obj_id, + valid_from, + valid_to, + confidence, + source_closet, + source_file, + ), + ) return triple_id def invalidate(self, subject: str, predicate: str, obj: str, ended: str = None): @@ -176,12 +185,11 @@ class KnowledgeGraph: ended = ended or date.today().isoformat() conn = self._conn() - conn.execute( - "UPDATE triples SET valid_to=? WHERE subject=? AND predicate=? AND object=? AND valid_to IS NULL", - (ended, sub_id, pred, obj_id), - ) - conn.commit() - conn.close() + with conn: + conn.execute( + "UPDATE triples SET valid_to=? WHERE subject=? AND predicate=? AND object=? AND valid_to IS NULL", + (ended, sub_id, pred, obj_id), + ) # ── Query operations ────────────────────────────────────────────────── @@ -208,13 +216,13 @@ class KnowledgeGraph: { "direction": "outgoing", "subject": name, - "predicate": row[2], - "object": row[10], # obj_name - "valid_from": row[4], - "valid_to": row[5], - "confidence": row[6], - "source_closet": row[7], - "current": row[5] is None, + "predicate": row["predicate"], + "object": row["obj_name"], + "valid_from": row["valid_from"], + "valid_to": row["valid_to"], + "confidence": row["confidence"], + "source_closet": row["source_closet"], + "current": row["valid_to"] is None, } ) @@ -228,18 +236,17 @@ class KnowledgeGraph: results.append( { "direction": "incoming", - "subject": row[10], # sub_name - "predicate": row[2], + "subject": row["sub_name"], + "predicate": row["predicate"], "object": name, - "valid_from": row[4], - "valid_to": row[5], - "confidence": row[6], - "source_closet": row[7], - "current": row[5] is None, + "valid_from": row["valid_from"], + "valid_to": row["valid_to"], + "confidence": row["confidence"], + "source_closet": row["source_closet"], + "current": row["valid_to"] is None, } ) - conn.close() return results def query_relationship(self, predicate: str, as_of: str = None): @@ -262,15 +269,14 @@ class KnowledgeGraph: for row in conn.execute(query, params).fetchall(): results.append( { - "subject": row[10], + "subject": row["sub_name"], "predicate": pred, - "object": row[11], - "valid_from": row[4], - "valid_to": row[5], - "current": row[5] is None, + "object": row["obj_name"], + "valid_from": row["valid_from"], + "valid_to": row["valid_to"], + "current": row["valid_to"] is None, } ) - conn.close() return results def timeline(self, entity_name: str = None): @@ -300,15 +306,14 @@ class KnowledgeGraph: LIMIT 100 """).fetchall() - conn.close() return [ { - "subject": r[10], - "predicate": r[2], - "object": r[11], - "valid_from": r[4], - "valid_to": r[5], - "current": r[5] is None, + "subject": r["sub_name"], + "predicate": r["predicate"], + "object": r["obj_name"], + "valid_from": r["valid_from"], + "valid_to": r["valid_to"], + "current": r["valid_to"] is None, } for r in rows ] @@ -317,17 +322,18 @@ class KnowledgeGraph: def stats(self): conn = self._conn() - entities = conn.execute("SELECT COUNT(*) FROM entities").fetchone()[0] - triples = conn.execute("SELECT COUNT(*) FROM triples").fetchone()[0] - current = conn.execute("SELECT COUNT(*) FROM triples WHERE valid_to IS NULL").fetchone()[0] + entities = conn.execute("SELECT COUNT(*) as cnt FROM entities").fetchone()["cnt"] + triples = conn.execute("SELECT COUNT(*) as cnt FROM triples").fetchone()["cnt"] + current = conn.execute( + "SELECT COUNT(*) as cnt FROM triples WHERE valid_to IS NULL" + ).fetchone()["cnt"] expired = triples - current predicates = [ - r[0] + r["predicate"] for r in conn.execute( "SELECT DISTINCT predicate FROM triples ORDER BY predicate" ).fetchall() ] - conn.close() return { "entities": entities, "triples": triples, diff --git a/mempalace/mcp_server.py b/mempalace/mcp_server.py index aa9ecd9..09203b6 100644 --- a/mempalace/mcp_server.py +++ b/mempalace/mcp_server.py @@ -24,8 +24,9 @@ import json import logging import hashlib from datetime import datetime +from pathlib import Path -from .config import MempalaceConfig +from .config import MempalaceConfig, sanitize_name, sanitize_content from .version import __version__ from .query_sanitizer import sanitize_query from .searcher import search_memories @@ -67,16 +68,60 @@ _client_cache = None _collection_cache = None +# ==================== WRITE-AHEAD LOG ==================== +# Every write operation is logged to a JSONL file before execution. +# This provides an audit trail for detecting memory poisoning and +# enables review/rollback of writes from external or untrusted sources. + +_WAL_DIR = Path(os.path.expanduser("~/.mempalace/wal")) +_WAL_DIR.mkdir(parents=True, exist_ok=True) +try: + _WAL_DIR.chmod(0o700) +except (OSError, NotImplementedError): + pass +_WAL_FILE = _WAL_DIR / "write_log.jsonl" + + +def _wal_log(operation: str, params: dict, result: dict = None): + """Append a write operation to the write-ahead log.""" + entry = { + "timestamp": datetime.now().isoformat(), + "operation": operation, + "params": params, + "result": result, + } + try: + with open(_WAL_FILE, "a", encoding="utf-8") as f: + f.write(json.dumps(entry, default=str) + "\n") + try: + _WAL_FILE.chmod(0o600) + except (OSError, NotImplementedError): + pass + except Exception as e: + logger.error(f"WAL write failed: {e}") + + +_client_cache = None +_collection_cache = None + + +def _get_client(): + """Return a singleton ChromaDB PersistentClient.""" + global _client_cache + if _client_cache is None: + _client_cache = chromadb.PersistentClient(path=_config.palace_path) + return _client_cache + + def _get_collection(create=False): """Return the ChromaDB collection, caching the client between calls.""" - global _client_cache, _collection_cache + global _collection_cache try: - if _client_cache is None: - _client_cache = chromadb.PersistentClient(path=_config.palace_path) + client = _get_client() if create: - _collection_cache = _client_cache.get_or_create_collection(_config.collection_name) + _collection_cache = client.get_or_create_collection(_config.collection_name) elif _collection_cache is None: - _collection_cache = _client_cache.get_collection(_config.collection_name) + _collection_cache = client.get_collection(_config.collection_name) return _collection_cache except Exception: return None @@ -99,16 +144,25 @@ def tool_status(): count = col.count() wings = {} rooms = {} - try: - all_meta = col.get(include=["metadatas"], limit=10000)["metadatas"] - for m in all_meta: - w = m.get("wing", "unknown") - r = m.get("room", "unknown") - wings[w] = wings.get(w, 0) + 1 - rooms[r] = rooms.get(r, 0) + 1 - except Exception: - pass - return { + batch_size = 5000 + offset = 0 + error_info = None + while True: + try: + batch = col.get(include=["metadatas"], limit=batch_size, offset=offset) + rows = batch["metadatas"] + for m in rows: + w = m.get("wing", "unknown") + r = m.get("room", "unknown") + wings[w] = wings.get(w, 0) + 1 + rooms[r] = rooms.get(r, 0) + 1 + offset += len(rows) + if len(rows) < batch_size: + break + except Exception as e: + error_info = f"Partial result, failed at offset {offset}: {str(e)}" + break + result = { "total_drawers": count, "wings": wings, "rooms": rooms, @@ -116,6 +170,10 @@ def tool_status(): "protocol": PALACE_PROTOCOL, "aaak_dialect": AAAK_SPEC, } + if error_info: + result["error"] = error_info + result["partial"] = True + return result # ── AAAK Dialect Spec ───────────────────────────────────────────────────────── @@ -156,13 +214,28 @@ def tool_list_wings(): if not col: return _no_palace() wings = {} + batch_size = 5000 + offset = 0 try: - all_meta = col.get(include=["metadatas"], limit=10000)["metadatas"] - for m in all_meta: - w = m.get("wing", "unknown") - wings[w] = wings.get(w, 0) + 1 - except Exception: - pass + col.count() # verify collection is accessible + except Exception as e: + return {"wings": {}, "error": str(e)} + while True: + try: + batch = col.get(include=["metadatas"], limit=batch_size, offset=offset) + rows = batch["metadatas"] + for m in rows: + w = m.get("wing", "unknown") + wings[w] = wings.get(w, 0) + 1 + offset += len(rows) + if len(rows) < batch_size: + break + except Exception as e: + return { + "wings": wings, + "error": f"Partial result, failed at offset {offset}: {str(e)}", + "partial": True, + } return {"wings": wings} @@ -171,16 +244,33 @@ def tool_list_rooms(wing: str = None): if not col: return _no_palace() rooms = {} + batch_size = 5000 + offset = 0 + where = {"wing": wing} if wing else None try: - kwargs = {"include": ["metadatas"], "limit": 10000} - if wing: - kwargs["where"] = {"wing": wing} - all_meta = col.get(**kwargs)["metadatas"] - for m in all_meta: - r = m.get("room", "unknown") - rooms[r] = rooms.get(r, 0) + 1 - except Exception: - pass + col.count() # verify collection is accessible + except Exception as e: + return {"wing": wing or "all", "rooms": {}, "error": str(e)} + while True: + try: + kwargs = {"include": ["metadatas"], "limit": batch_size, "offset": offset} + if where: + kwargs["where"] = where + batch = col.get(**kwargs) + rows = batch["metadatas"] + for m in rows: + r = m.get("room", "unknown") + rooms[r] = rooms.get(r, 0) + 1 + offset += len(rows) + if len(rows) < batch_size: + break + except Exception as e: + return { + "wing": wing or "all", + "rooms": rooms, + "error": f"Partial result, failed at offset {offset}: {str(e)}", + "partial": True, + } return {"wing": wing or "all", "rooms": rooms} @@ -189,16 +279,31 @@ def tool_get_taxonomy(): if not col: return _no_palace() taxonomy = {} + batch_size = 5000 + offset = 0 try: - all_meta = col.get(include=["metadatas"], limit=10000)["metadatas"] - for m in all_meta: - w = m.get("wing", "unknown") - r = m.get("room", "unknown") - if w not in taxonomy: - taxonomy[w] = {} - taxonomy[w][r] = taxonomy[w].get(r, 0) + 1 - except Exception: - pass + col.count() # verify collection is accessible + except Exception as e: + return {"taxonomy": {}, "error": str(e)} + while True: + try: + batch = col.get(include=["metadatas"], limit=batch_size, offset=offset) + rows = batch["metadatas"] + for m in rows: + w = m.get("wing", "unknown") + r = m.get("room", "unknown") + if w not in taxonomy: + taxonomy[w] = {} + taxonomy[w][r] = taxonomy[w].get(r, 0) + 1 + offset += len(rows) + if len(rows) < batch_size: + break + except Exception as e: + return { + "taxonomy": taxonomy, + "error": f"Partial result, failed at offset {offset}: {str(e)}", + "partial": True, + } return {"taxonomy": taxonomy} @@ -299,11 +404,30 @@ def tool_add_drawer( wing: str, room: str, content: str, source_file: str = None, added_by: str = "mcp" ): """File verbatim content into a wing/room. Checks for duplicates first.""" + try: + wing = sanitize_name(wing, "wing") + room = sanitize_name(room, "room") + content = sanitize_content(content) + except ValueError as e: + return {"success": False, "error": str(e)} + col = _get_collection(create=True) if not col: return _no_palace() - drawer_id = f"drawer_{wing}_{room}_{hashlib.md5(content.encode()).hexdigest()[:16]}" + drawer_id = f"drawer_{wing}_{room}_{hashlib.sha256((wing + room + content[:100]).encode()).hexdigest()[:24]}" + + _wal_log( + "add_drawer", + { + "drawer_id": drawer_id, + "wing": wing, + "room": room, + "added_by": added_by, + "content_length": len(content), + "content_preview": content[:200], + }, + ) # Idempotency: if the deterministic ID already exists, return success as a no-op. try: @@ -342,6 +466,19 @@ def tool_delete_drawer(drawer_id: str): existing = col.get(ids=[drawer_id]) if not existing["ids"]: return {"success": False, "error": f"Drawer not found: {drawer_id}"} + + # Log the deletion with the content being removed for audit trail + deleted_content = existing.get("documents", [""])[0] if existing.get("documents") else "" + deleted_meta = existing.get("metadatas", [{}])[0] if existing.get("metadatas") else {} + _wal_log( + "delete_drawer", + { + "drawer_id": drawer_id, + "deleted_meta": deleted_meta, + "content_preview": deleted_content[:200], + }, + ) + try: col.delete(ids=[drawer_id]) logger.info(f"Deleted drawer: {drawer_id}") @@ -363,6 +500,23 @@ def tool_kg_add( subject: str, predicate: str, object: str, valid_from: str = None, source_closet: str = None ): """Add a relationship to the knowledge graph.""" + try: + subject = sanitize_name(subject, "subject") + predicate = sanitize_name(predicate, "predicate") + object = sanitize_name(object, "object") + except ValueError as e: + return {"success": False, "error": str(e)} + + _wal_log( + "kg_add", + { + "subject": subject, + "predicate": predicate, + "object": object, + "valid_from": valid_from, + "source_closet": source_closet, + }, + ) triple_id = _kg.add_triple( subject, predicate, object, valid_from=valid_from, source_closet=source_closet ) @@ -371,6 +525,10 @@ def tool_kg_add( def tool_kg_invalidate(subject: str, predicate: str, object: str, ended: str = None): """Mark a fact as no longer true (set end date).""" + _wal_log( + "kg_invalidate", + {"subject": subject, "predicate": predicate, "object": object, "ended": ended}, + ) _kg.invalidate(subject, predicate, object, ended=ended) return { "success": True, @@ -401,6 +559,12 @@ def tool_diary_write(agent_name: str, entry: str, topic: str = "general"): This is the agent's personal journal — observations, thoughts, what it worked on, what it noticed, what it thinks matters. """ + try: + agent_name = sanitize_name(agent_name, "agent_name") + entry = sanitize_content(entry) + except ValueError as e: + return {"success": False, "error": str(e)} + wing = f"wing_{agent_name.lower().replace(' ', '_')}" room = "diary" col = _get_collection(create=True) @@ -408,9 +572,23 @@ def tool_diary_write(agent_name: str, entry: str, topic: str = "general"): return _no_palace() now = datetime.now() - entry_id = f"diary_{wing}_{now.strftime('%Y%m%d_%H%M%S')}_{hashlib.md5(entry[:50].encode()).hexdigest()[:8]}" + entry_id = f"diary_{wing}_{now.strftime('%Y%m%d_%H%M%S')}_{hashlib.sha256(entry[:50].encode()).hexdigest()[:12]}" + + _wal_log( + "diary_write", + { + "agent_name": agent_name, + "topic": topic, + "entry_id": entry_id, + "entry_preview": entry[:200], + }, + ) try: + # TODO: Future versions should expand AAAK before embedding to improve + # semantic search quality. For now, store raw AAAK in metadata so it's + # preserved, and keep the document as-is for embedding (even though + # compressed AAAK degrades embedding quality). col.add( ids=[entry_id], documents=[entry], @@ -744,17 +922,31 @@ TOOLS = { } +SUPPORTED_PROTOCOL_VERSIONS = [ + "2025-11-25", + "2025-06-18", + "2025-03-26", + "2024-11-05", +] + + def handle_request(request): method = request.get("method", "") params = request.get("params", {}) req_id = request.get("id") if method == "initialize": + client_version = params.get("protocolVersion", SUPPORTED_PROTOCOL_VERSIONS[-1]) + negotiated = ( + client_version + if client_version in SUPPORTED_PROTOCOL_VERSIONS + else SUPPORTED_PROTOCOL_VERSIONS[0] + ) return { "jsonrpc": "2.0", "id": req_id, "result": { - "protocolVersion": "2024-11-05", + "protocolVersion": negotiated, "capabilities": {"tools": {}}, "serverInfo": {"name": "mempalace", "version": __version__}, }, @@ -774,7 +966,7 @@ def handle_request(request): } elif method == "tools/call": tool_name = params.get("name") - tool_args = params.get("arguments", {}) + tool_args = params.get("arguments") or {} if tool_name not in TOOLS: return { "jsonrpc": "2.0", diff --git a/mempalace/migrate.py b/mempalace/migrate.py new file mode 100644 index 0000000..848ab67 --- /dev/null +++ b/mempalace/migrate.py @@ -0,0 +1,214 @@ +#!/usr/bin/env python3 +""" +mempalace migrate — Recover a palace created with a different ChromaDB version. + +Reads documents and metadata directly from the palace's SQLite database +(bypassing ChromaDB's API, which fails on version-mismatched palaces), +then re-imports everything into a fresh palace using the currently installed +ChromaDB version. + +This fixes the 3.0.0 → 3.1.0 upgrade path where chromadb was downgraded +from 1.5.x to 0.6.x, breaking the on-disk storage format. + +Usage: + mempalace migrate # migrate default palace + mempalace migrate --palace /path/to/palace # migrate specific palace + mempalace migrate --dry-run # show what would be migrated +""" + +import os +import shutil +import sqlite3 +from collections import defaultdict +from datetime import datetime + + +def extract_drawers_from_sqlite(db_path: str) -> list: + """Read all drawers directly from ChromaDB's SQLite, bypassing the API. + + Works regardless of which ChromaDB version created the database. + Returns list of dicts with 'id', 'document', and 'metadata' keys. + """ + conn = sqlite3.connect(db_path) + conn.row_factory = sqlite3.Row + + # Get all embedding IDs and their documents + rows = conn.execute(""" + SELECT e.embedding_id, + MAX(CASE WHEN em.key = 'chroma:document' THEN em.string_value END) as document + FROM embeddings e + JOIN embedding_metadata em ON em.id = e.id + GROUP BY e.embedding_id + """).fetchall() + + drawers = [] + for row in rows: + embedding_id = row["embedding_id"] + document = row["document"] + if not document: + continue + + # Get metadata for this embedding + meta_rows = conn.execute( + """ + SELECT em.key, em.string_value, em.int_value, em.float_value, em.bool_value + FROM embedding_metadata em + JOIN embeddings e ON e.id = em.id + WHERE e.embedding_id = ? + AND em.key NOT LIKE 'chroma:%' + """, + (embedding_id,), + ).fetchall() + + metadata = {} + for mr in meta_rows: + key = mr["key"] + if mr["string_value"] is not None: + metadata[key] = mr["string_value"] + elif mr["int_value"] is not None: + metadata[key] = mr["int_value"] + elif mr["float_value"] is not None: + metadata[key] = mr["float_value"] + elif mr["bool_value"] is not None: + metadata[key] = bool(mr["bool_value"]) + + drawers.append( + { + "id": embedding_id, + "document": document, + "metadata": metadata, + } + ) + + conn.close() + return drawers + + +def detect_chromadb_version(db_path: str) -> str: + """Detect which ChromaDB version created the database by checking schema.""" + conn = sqlite3.connect(db_path) + try: + # 1.x has schema_str column in collections table + cols = [r[1] for r in conn.execute("PRAGMA table_info(collections)").fetchall()] + if "schema_str" in cols: + return "1.x" + # 0.6.x has embeddings_queue but no schema_str + tables = [ + r[0] + for r in conn.execute("SELECT name FROM sqlite_master WHERE type='table'").fetchall() + ] + if "embeddings_queue" in tables: + return "0.6.x" + return "unknown" + finally: + conn.close() + + +def migrate(palace_path: str, dry_run: bool = False): + """Migrate a palace to the currently installed ChromaDB version.""" + import chromadb + + palace_path = os.path.expanduser(palace_path) + db_path = os.path.join(palace_path, "chroma.sqlite3") + + if not os.path.isfile(db_path): + print(f"\n No palace database found at {db_path}") + return False + + print(f"\n{'=' * 60}") + print(" MemPalace Migrate") + print(f"{'=' * 60}\n") + print(f" Palace: {palace_path}") + print(f" Database: {db_path}") + print(f" DB size: {os.path.getsize(db_path) / 1024 / 1024:.1f} MB") + + # Detect version + source_version = detect_chromadb_version(db_path) + print(f" Source: ChromaDB {source_version}") + print(f" Target: ChromaDB {chromadb.__version__}") + + # Try reading with current chromadb first + try: + client = chromadb.PersistentClient(path=palace_path) + col = client.get_collection("mempalace_drawers") + count = col.count() + print(f"\n Palace is already readable by chromadb {chromadb.__version__}.") + print(f" {count} drawers found. No migration needed.") + return True + except Exception: + print(f"\n Palace is NOT readable by chromadb {chromadb.__version__}.") + print(" Extracting from SQLite directly...") + + # Extract all drawers via raw SQL + drawers = extract_drawers_from_sqlite(db_path) + print(f" Extracted {len(drawers)} drawers from SQLite") + + if not drawers: + print(" Nothing to migrate.") + return True + + # Show summary + wings = defaultdict(lambda: defaultdict(int)) + for d in drawers: + w = d["metadata"].get("wing", "?") + r = d["metadata"].get("room", "?") + wings[w][r] += 1 + + print("\n Summary:") + for wing, rooms in sorted(wings.items()): + total = sum(rooms.values()) + print(f" WING: {wing} ({total} drawers)") + for room, count in sorted(rooms.items(), key=lambda x: -x[1]): + print(f" ROOM: {room:30} {count:5}") + + if dry_run: + print("\n DRY RUN — no changes made.") + print(f" Would migrate {len(drawers)} drawers.") + return True + + # Backup the old palace + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + backup_path = f"{palace_path}.pre-migrate.{timestamp}" + print(f"\n Backing up to {backup_path}...") + shutil.copytree(palace_path, backup_path) + + # Build fresh palace in a temp directory (avoids chromadb reading old state) + import tempfile + + temp_palace = tempfile.mkdtemp(prefix="mempalace_migrate_") + print(f" Creating fresh palace in {temp_palace}...") + client = chromadb.PersistentClient(path=temp_palace) + col = client.get_or_create_collection("mempalace_drawers") + + # Re-import in batches + batch_size = 500 + imported = 0 + for i in range(0, len(drawers), batch_size): + batch = drawers[i : i + batch_size] + col.add( + ids=[d["id"] for d in batch], + documents=[d["document"] for d in batch], + metadatas=[d["metadata"] for d in batch], + ) + imported += len(batch) + print(f" Imported {imported}/{len(drawers)} drawers...") + + # Verify before swapping + final_count = col.count() + del col + del client + + # Swap: remove old palace, move new one into place + print(" Swapping old palace for migrated version...") + shutil.rmtree(palace_path) + shutil.move(temp_palace, palace_path) + + print("\n Migration complete.") + print(f" Drawers migrated: {final_count}") + print(f" Backup at: {backup_path}") + + if final_count != len(drawers): + print(f" WARNING: Expected {len(drawers)}, got {final_count}") + + print(f"\n{'=' * 60}\n") + return True diff --git a/mempalace/miner.py b/mempalace/miner.py index 66fbe03..f342a2d 100644 --- a/mempalace/miner.py +++ b/mempalace/miner.py @@ -17,6 +17,8 @@ from collections import defaultdict import chromadb +from .palace import SKIP_DIRS, get_collection, file_already_mined + READABLE_EXTENSIONS = { ".txt", ".md", @@ -40,32 +42,6 @@ READABLE_EXTENSIONS = { ".toml", } -SKIP_DIRS = { - ".git", - "node_modules", - "__pycache__", - ".venv", - "venv", - "env", - "dist", - "build", - ".next", - "coverage", - ".mempalace", - ".ruff_cache", - ".mypy_cache", - ".pytest_cache", - ".cache", - ".tox", - ".nox", - ".idea", - ".vscode", - ".ipynb_checkpoints", - ".eggs", - "htmlcov", - "target", -} - SKIP_FILENAMES = { "mempalace.yaml", "mempalace.yml", @@ -78,6 +54,7 @@ SKIP_FILENAMES = { CHUNK_SIZE = 800 # chars per drawer CHUNK_OVERLAP = 100 # overlap between chunks MIN_CHUNK_SIZE = 50 # skip tiny chunks +MAX_FILE_SIZE = 10 * 1024 * 1024 # 10 MB — skip files larger than this # ============================================================================= @@ -393,41 +370,11 @@ def chunk_text(content: str, source_file: str) -> list: # ============================================================================= -def get_collection(palace_path: str): - os.makedirs(palace_path, exist_ok=True) - client = chromadb.PersistentClient(path=palace_path) - try: - return client.get_collection("mempalace_drawers") - except Exception: - return client.create_collection("mempalace_drawers") - - -def file_already_mined(collection, source_file: str) -> bool: - """Fast check: has this file been filed before and is unchanged? - - Compares the stored mtime in drawer metadata against the file's current - mtime. Returns False (needs re-mining) when the file has been modified - since it was last mined, or when no mtime was stored. - """ - try: - results = collection.get(where={"source_file": source_file}, limit=1) - if not results.get("ids"): - return False - stored_meta = results["metadatas"][0] if results.get("metadatas") else {} - stored_mtime = stored_meta.get("source_mtime") - if stored_mtime is None: - return False - current_mtime = os.path.getmtime(source_file) - return float(stored_mtime) == current_mtime - except Exception: - return False - - def add_drawer( collection, wing: str, room: str, content: str, source_file: str, chunk_index: int, agent: str ): """Add one drawer to the palace.""" - drawer_id = f"drawer_{wing}_{room}_{hashlib.md5((source_file + str(chunk_index)).encode(), usedforsecurity=False).hexdigest()[:16]}" + drawer_id = f"drawer_{wing}_{room}_{hashlib.sha256((source_file + str(chunk_index)).encode()).hexdigest()[:24]}" try: metadata = { "wing": wing, @@ -470,7 +417,7 @@ def process_file( # Skip if already filed source_file = str(filepath) - if not dry_run and file_already_mined(collection, source_file): + if not dry_run and file_already_mined(collection, source_file, check_mtime=True): return 0, None try: @@ -489,6 +436,16 @@ def process_file( print(f" [DRY RUN] {filepath.name} → room:{room} ({len(chunks)} drawers)") return len(chunks), room + # Purge stale drawers for this file before re-inserting the fresh chunks. + # Converts modified-file re-mines from upsert-over-existing-IDs (which hits + # hnswlib's thread-unsafe updatePoint path and can segfault on macOS ARM + # with chromadb 0.6.3) into a clean delete+insert, bypassing the update + # path entirely. + try: + collection.delete(where={"source_file": source_file}) + except Exception: + pass + drawers_added = 0 for chunk in chunks: added = add_drawer( @@ -562,6 +519,15 @@ def scan_project( if respect_gitignore and active_matchers and not force_include: if is_gitignored(filepath, active_matchers, is_dir=False): continue + # Skip symlinks — prevents following links to /dev/urandom, etc. + if filepath.is_symlink(): + continue + # Skip files exceeding size limit + try: + if filepath.stat().st_size > MAX_FILE_SIZE: + continue + except OSError: + continue files.append(filepath) return files diff --git a/mempalace/normalize.py b/mempalace/normalize.py index ac11469..a894500 100644 --- a/mempalace/normalize.py +++ b/mempalace/normalize.py @@ -25,6 +25,12 @@ def normalize(filepath: str) -> str: Load a file and normalize to transcript format if it's a chat export. Plain text files pass through unchanged. """ + try: + file_size = os.path.getsize(filepath) + except OSError as e: + raise IOError(f"Could not read {filepath}: {e}") + if file_size > 500 * 1024 * 1024: # 500 MB safety limit + raise IOError(f"File too large ({file_size // (1024*1024)} MB): {filepath}") try: with open(filepath, "r", encoding="utf-8", errors="replace") as f: content = f.read() diff --git a/mempalace/palace.py b/mempalace/palace.py new file mode 100644 index 0000000..6ddf190 --- /dev/null +++ b/mempalace/palace.py @@ -0,0 +1,71 @@ +""" +palace.py — Shared palace operations. + +Consolidates ChromaDB access patterns used by both miners and the MCP server. +""" + +import os +import chromadb + +SKIP_DIRS = { + ".git", + "node_modules", + "__pycache__", + ".venv", + "venv", + "env", + "dist", + "build", + ".next", + "coverage", + ".mempalace", + ".ruff_cache", + ".mypy_cache", + ".pytest_cache", + ".cache", + ".tox", + ".nox", + ".idea", + ".vscode", + ".ipynb_checkpoints", + ".eggs", + "htmlcov", + "target", +} + + +def get_collection(palace_path: str, collection_name: str = "mempalace_drawers"): + """Get or create the palace ChromaDB collection.""" + os.makedirs(palace_path, exist_ok=True) + try: + os.chmod(palace_path, 0o700) + except (OSError, NotImplementedError): + pass + client = chromadb.PersistentClient(path=palace_path) + try: + return client.get_collection(collection_name) + except Exception: + return client.create_collection(collection_name) + + +def file_already_mined(collection, source_file: str, check_mtime: bool = False) -> bool: + """Check if a file has already been filed in the palace. + + When check_mtime=True (used by project miner), returns False if the file + has been modified since it was last mined, so it gets re-mined. + When check_mtime=False (used by convo miner), just checks existence. + """ + try: + results = collection.get(where={"source_file": source_file}, limit=1) + if not results.get("ids"): + return False + if check_mtime: + stored_meta = results.get("metadatas", [{}])[0] + stored_mtime = stored_meta.get("source_mtime") + if stored_mtime is None: + return False + current_mtime = os.path.getmtime(source_file) + return float(stored_mtime) == current_mtime + return True + except Exception: + return False diff --git a/mempalace/repair.py b/mempalace/repair.py new file mode 100644 index 0000000..d51be60 --- /dev/null +++ b/mempalace/repair.py @@ -0,0 +1,299 @@ +""" +repair.py — Scan, prune corrupt entries, and rebuild HNSW index +================================================================ + +When ChromaDB's HNSW index accumulates duplicate entries (from repeated +add() calls with the same ID), link_lists.bin can grow unbounded — +terabytes on large palaces — eventually causing segfaults. + +This module provides three operations: + + scan — find every corrupt/unfetchable ID in the palace + prune — delete only the corrupt IDs (surgical) + rebuild — extract all drawers, delete the collection, recreate with + correct HNSW settings, and upsert everything back + +The rebuild backs up ONLY chroma.sqlite3 (the source of truth), not the +full palace directory — so it works even when link_lists.bin is bloated. + +Usage (standalone): + python -m mempalace.repair scan [--wing X] + python -m mempalace.repair prune --confirm + python -m mempalace.repair rebuild + +Usage (from CLI): + mempalace repair + mempalace repair-scan [--wing X] + mempalace repair-prune --confirm +""" + +import argparse +import os +import shutil +import time + +import chromadb + + +COLLECTION_NAME = "mempalace_drawers" + + +def _get_palace_path(): + """Resolve palace path from config.""" + try: + from .config import MempalaceConfig + + return MempalaceConfig().palace_path + except Exception: + default = os.path.join(os.path.expanduser("~"), ".mempalace", "palace") + return default + + +def _paginate_ids(col, where=None): + """Pull all IDs in a collection using pagination.""" + ids = [] + page = 1000 + offset = 0 + while True: + try: + r = col.get(where=where, include=[], limit=page, offset=offset) + except Exception: + try: + r = col.get(where=where, include=[], limit=page) + new_ids = [i for i in r["ids"] if i not in set(ids)] + if not new_ids: + break + ids.extend(new_ids) + offset += len(new_ids) + continue + except Exception: + break + n = len(r["ids"]) if r["ids"] else 0 + if n == 0: + break + ids.extend(r["ids"]) + offset += n + if n < page: + break + return ids + + +def scan_palace(palace_path=None, only_wing=None): + """Scan the palace for corrupt/unfetchable IDs. + + Probes in batches of 100, falls back to per-ID on failure. + Writes corrupt_ids.txt to the palace directory for the prune step. + + Returns (good_set, bad_set). + """ + palace_path = palace_path or _get_palace_path() + print(f"\n Palace: {palace_path}") + print(" Loading...") + + client = chromadb.PersistentClient(path=palace_path) + col = client.get_collection(COLLECTION_NAME) + + where = {"wing": only_wing} if only_wing else None + total = col.count() + print(f" Collection: {COLLECTION_NAME}, total: {total:,}") + if only_wing: + print(f" Scanning wing: {only_wing}") + + print("\n Step 1: listing all IDs...") + t0 = time.time() + all_ids = _paginate_ids(col, where=where) + print(f" Found {len(all_ids):,} IDs in {time.time() - t0:.1f}s\n") + + if not all_ids: + print(" Nothing to scan.") + return set(), set() + + print(" Step 2: probing each ID (batches of 100)...") + t0 = time.time() + good_set = set() + bad_set = set() + batch = 100 + + for i in range(0, len(all_ids), batch): + chunk = all_ids[i : i + batch] + try: + r = col.get(ids=chunk, include=["documents"]) + for got in r["ids"]: + good_set.add(got) + for mid in chunk: + if mid not in good_set: + bad_set.add(mid) + except Exception: + for sid in chunk: + try: + r = col.get(ids=[sid], include=["documents"]) + if r["ids"]: + good_set.add(sid) + else: + bad_set.add(sid) + except Exception: + bad_set.add(sid) + + if (i // batch) % 50 == 0: + elapsed = time.time() - t0 + rate = (i + batch) / max(elapsed, 0.01) + eta = (len(all_ids) - i - batch) / max(rate, 0.01) + print( + f" {i + batch:>6}/{len(all_ids):>6} " + f"good={len(good_set):>6} bad={len(bad_set):>6} " + f"eta={eta:.0f}s" + ) + + print(f"\n Scan complete in {time.time() - t0:.1f}s") + print(f" GOOD: {len(good_set):,}") + print(f" BAD: {len(bad_set):,} ({len(bad_set) / max(len(all_ids), 1) * 100:.1f}%)") + + bad_file = os.path.join(palace_path, "corrupt_ids.txt") + with open(bad_file, "w") as f: + for bid in sorted(bad_set): + f.write(bid + "\n") + print(f"\n Bad IDs written to: {bad_file}") + return good_set, bad_set + + +def prune_corrupt(palace_path=None, confirm=False): + """Delete corrupt IDs listed in corrupt_ids.txt.""" + palace_path = palace_path or _get_palace_path() + bad_file = os.path.join(palace_path, "corrupt_ids.txt") + + if not os.path.exists(bad_file): + print(" No corrupt_ids.txt found — run scan first.") + return + + with open(bad_file) as f: + bad_ids = [line.strip() for line in f if line.strip()] + print(f" {len(bad_ids):,} corrupt IDs queued for deletion") + + if not confirm: + print("\n DRY RUN — no deletions performed.") + print(" Re-run with --confirm to actually delete.") + return + + client = chromadb.PersistentClient(path=palace_path) + col = client.get_collection(COLLECTION_NAME) + before = col.count() + print(f" Collection size before: {before:,}") + + batch = 100 + deleted = 0 + failed = 0 + for i in range(0, len(bad_ids), batch): + chunk = bad_ids[i : i + batch] + try: + col.delete(ids=chunk) + deleted += len(chunk) + except Exception: + for sid in chunk: + try: + col.delete(ids=[sid]) + deleted += 1 + except Exception: + failed += 1 + if (i // batch) % 20 == 0: + print(f" deleted {deleted}/{len(bad_ids)} (failed: {failed})") + + after = col.count() + print(f"\n Deleted: {deleted:,}") + print(f" Failed: {failed:,}") + print(f" Collection size: {before:,} → {after:,}") + + +def rebuild_index(palace_path=None): + """Rebuild the HNSW index from scratch. + + 1. Extract all drawers via ChromaDB get() + 2. Back up ONLY chroma.sqlite3 (not the bloated HNSW files) + 3. Delete and recreate the collection with hnsw:space=cosine + 4. Upsert all drawers back + """ + palace_path = palace_path or _get_palace_path() + + if not os.path.isdir(palace_path): + print(f"\n No palace found at {palace_path}") + return + + print(f"\n{'=' * 55}") + print(" MemPalace Repair — Index Rebuild") + print(f"{'=' * 55}\n") + print(f" Palace: {palace_path}") + + client = chromadb.PersistentClient(path=palace_path) + try: + col = client.get_collection(COLLECTION_NAME) + total = col.count() + except Exception as e: + print(f" Error reading palace: {e}") + print(" Palace may need to be re-mined from source files.") + return + + print(f" Drawers found: {total}") + + if total == 0: + print(" Nothing to repair.") + return + + # Extract all drawers in batches + print("\n Extracting drawers...") + batch_size = 5000 + all_ids = [] + all_docs = [] + all_metas = [] + offset = 0 + while offset < total: + batch = col.get(limit=batch_size, offset=offset, include=["documents", "metadatas"]) + if not batch["ids"]: + break + all_ids.extend(batch["ids"]) + all_docs.extend(batch["documents"]) + all_metas.extend(batch["metadatas"]) + offset += len(batch["ids"]) + print(f" Extracted {len(all_ids)} drawers") + + # Back up ONLY the SQLite database, not the bloated HNSW files + sqlite_path = os.path.join(palace_path, "chroma.sqlite3") + if os.path.exists(sqlite_path): + backup_path = sqlite_path + ".backup" + print(f" Backing up chroma.sqlite3 ({os.path.getsize(sqlite_path) / 1e6:.0f} MB)...") + shutil.copy2(sqlite_path, backup_path) + print(f" Backup: {backup_path}") + + # Rebuild with correct HNSW settings + print(" Rebuilding collection with hnsw:space=cosine...") + client.delete_collection(COLLECTION_NAME) + new_col = client.create_collection(COLLECTION_NAME, metadata={"hnsw:space": "cosine"}) + + filed = 0 + for i in range(0, len(all_ids), batch_size): + batch_ids = all_ids[i : i + batch_size] + batch_docs = all_docs[i : i + batch_size] + batch_metas = all_metas[i : i + batch_size] + new_col.upsert(documents=batch_docs, ids=batch_ids, metadatas=batch_metas) + filed += len(batch_ids) + print(f" Re-filed {filed}/{len(all_ids)} drawers...") + + print(f"\n Repair complete. {filed} drawers rebuilt.") + print(" HNSW index is now clean with cosine distance metric.") + print(f"\n{'=' * 55}\n") + + +if __name__ == "__main__": + p = argparse.ArgumentParser(description="MemPalace repair tools") + p.add_argument("command", choices=["scan", "prune", "rebuild"]) + p.add_argument("--palace", default=None, help="Palace directory path") + p.add_argument("--wing", default=None, help="Scan only this wing") + p.add_argument("--confirm", action="store_true", help="Actually delete corrupt IDs") + args = p.parse_args() + + path = os.path.expanduser(args.palace) if args.palace else None + + if args.command == "scan": + scan_palace(palace_path=path, only_wing=args.wing) + elif args.command == "prune": + prune_corrupt(palace_path=path, confirm=args.confirm) + elif args.command == "rebuild": + rebuild_index(palace_path=path) diff --git a/mempalace/split_mega_files.py b/mempalace/split_mega_files.py index ae801df..24b5956 100644 --- a/mempalace/split_mega_files.py +++ b/mempalace/split_mega_files.py @@ -182,6 +182,10 @@ def split_file(filepath, output_dir, dry_run=False): Returns list of output paths written (or would be written if dry_run). """ path = Path(filepath) + max_size = 500 * 1024 * 1024 # 500 MB safety limit + if path.stat().st_size > max_size: + print(f" SKIP: {path.name} exceeds {max_size // (1024*1024)} MB limit") + return [] lines = path.read_text(errors="replace").splitlines(keepends=True) boundaries = find_session_boundaries(lines) @@ -266,7 +270,11 @@ def main(): files = sorted(src_dir.glob("*.txt")) mega_files = [] + max_scan_size = 500 * 1024 * 1024 # 500 MB for f in files: + if f.stat().st_size > max_scan_size: + print(f" SKIP: {f.name} exceeds {max_scan_size // (1024*1024)} MB limit") + continue lines = f.read_text(errors="replace").splitlines(keepends=True) boundaries = find_session_boundaries(lines) if len(boundaries) >= args.min_sessions: diff --git a/mempalace/version.py b/mempalace/version.py index e56289e..1eb21a2 100644 --- a/mempalace/version.py +++ b/mempalace/version.py @@ -1,3 +1,3 @@ """Single source of truth for the MemPalace package version.""" -__version__ = "3.0.14" +__version__ = "3.1.0" diff --git a/pyproject.toml b/pyproject.toml index 7b201da..cd47f98 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "mempalace" -version = "3.0.14" +version = "3.1.0" description = "Give your AI a memory — mine projects and conversations into a searchable palace. No API key required." readme = "README.md" requires-python = ">=3.9" @@ -26,7 +26,7 @@ classifiers = [ ] dependencies = [ "chromadb>=0.5.0,<0.7", - "pyyaml>=6.0", + "pyyaml>=6.0,<7", ] [project.urls] @@ -54,11 +54,15 @@ packages = ["mempalace"] [tool.ruff] line-length = 100 target-version = "py39" +extend-exclude = ["benchmarks"] [tool.ruff.lint] -select = ["E", "F", "W"] +select = ["E", "F", "W", "C901"] ignore = ["E501"] +[tool.ruff.lint.mccabe] +max-complexity = 25 + [tool.ruff.format] quote-style = "double" diff --git a/tests/test_cli.py b/tests/test_cli.py index 879d276..e3c68f9 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -2,6 +2,7 @@ import argparse import sys +from pathlib import Path from unittest.mock import MagicMock, patch import pytest @@ -326,6 +327,35 @@ def test_main_split_dispatches(): mock_cmd.assert_called_once() +def test_mcp_command_prints_setup_guidance(monkeypatch, capsys): + monkeypatch.setattr(sys, "argv", ["mempalace", "mcp"]) + + main() + + captured = capsys.readouterr() + assert "MemPalace MCP quick setup:" in captured.out + assert "claude mcp add mempalace -- python -m mempalace.mcp_server" in captured.out + assert "\nOptional custom palace:\n" in captured.out + assert "python -m mempalace.mcp_server --palace /path/to/palace" in captured.out + assert "[--palace /path/to/palace]" not in captured.out + assert captured.err == "" + + +def test_mcp_command_uses_custom_palace_path_when_provided(monkeypatch, capsys): + monkeypatch.setattr(sys, "argv", ["mempalace", "--palace", "~/tmp/my palace", "mcp"]) + + main() + + captured = capsys.readouterr() + expanded = str(Path("~/tmp/my palace").expanduser()) + + assert "python -m mempalace.mcp_server --palace" in captured.out + assert expanded in captured.out + assert "Optional custom palace:" not in captured.out + assert "[--palace /path/to/palace]" not in captured.out + assert captured.err == "" + + def test_main_hook_no_subcommand_prints_help(capsys): with patch("sys.argv", ["mempalace", "hook"]): main() @@ -607,3 +637,16 @@ def test_cmd_compress_stores_results(mock_config_cls, capsys): out = capsys.readouterr().out assert "Stored" in out mock_comp_col.upsert.assert_called_once() + + +def test_cmd_repair_trailing_slash_does_not_recurse(): + """Repair with trailing slash should put backup outside palace dir (#395).""" + import os + + args = argparse.Namespace(palace="/tmp/fake_palace/") + with patch("mempalace.cli.os.path.isdir", return_value=False): + cmd_repair(args) + # Verify the rstrip logic: palace_path should not end with separator + palace_path = os.path.expanduser(args.palace).rstrip(os.sep) + backup_path = palace_path + ".backup" + assert not backup_path.startswith(palace_path + os.sep) diff --git a/tests/test_dedup.py b/tests/test_dedup.py new file mode 100644 index 0000000..2ddffb3 --- /dev/null +++ b/tests/test_dedup.py @@ -0,0 +1,272 @@ +"""Tests for mempalace.dedup — near-duplicate drawer detection and removal.""" + +from unittest.mock import MagicMock, patch + + +from mempalace import dedup + + +# ── get_source_groups ───────────────────────────────────────────────── + + +def test_get_source_groups_basic(): + col = MagicMock() + col.count.return_value = 5 + col.get.side_effect = [ + { + "ids": ["d1", "d2", "d3", "d4", "d5"], + "metadatas": [ + {"source_file": "a.txt"}, + {"source_file": "a.txt"}, + {"source_file": "a.txt"}, + {"source_file": "a.txt"}, + {"source_file": "a.txt"}, + ], + }, + {"ids": []}, + ] + groups = dedup.get_source_groups(col, min_count=5) + assert "a.txt" in groups + assert len(groups["a.txt"]) == 5 + + +def test_get_source_groups_below_min(): + col = MagicMock() + col.count.return_value = 2 + col.get.side_effect = [ + { + "ids": ["d1", "d2"], + "metadatas": [ + {"source_file": "a.txt"}, + {"source_file": "a.txt"}, + ], + }, + {"ids": []}, + ] + groups = dedup.get_source_groups(col, min_count=5) + assert len(groups) == 0 + + +def test_get_source_groups_source_filter(): + col = MagicMock() + col.count.return_value = 6 + col.get.side_effect = [ + { + "ids": ["d1", "d2", "d3", "d4", "d5", "d6"], + "metadatas": [ + {"source_file": "project_a.txt"}, + {"source_file": "project_a.txt"}, + {"source_file": "project_a.txt"}, + {"source_file": "project_a.txt"}, + {"source_file": "project_a.txt"}, + {"source_file": "other.txt"}, + ], + }, + {"ids": []}, + ] + groups = dedup.get_source_groups(col, min_count=5, source_pattern="project_a") + assert "project_a.txt" in groups + assert "other.txt" not in groups + + +def test_get_source_groups_wing_filter(): + col = MagicMock() + col.count.return_value = 5 + col.get.side_effect = [ + { + "ids": ["d1", "d2", "d3", "d4", "d5"], + "metadatas": [ + {"source_file": "a.txt"}, + {"source_file": "a.txt"}, + {"source_file": "a.txt"}, + {"source_file": "a.txt"}, + {"source_file": "a.txt"}, + ], + }, + {"ids": []}, + ] + dedup.get_source_groups(col, min_count=5, wing="my_wing") + # Verify where filter was passed + first_call = col.get.call_args_list[0] + assert first_call.kwargs.get("where") == {"wing": "my_wing"} + + +def test_get_source_groups_missing_source_file(): + col = MagicMock() + col.count.return_value = 5 + col.get.side_effect = [ + { + "ids": ["d1", "d2", "d3", "d4", "d5"], + "metadatas": [{}, {}, {}, {}, {}], + }, + {"ids": []}, + ] + groups = dedup.get_source_groups(col, min_count=5) + assert "unknown" in groups + + +# ── dedup_source_group ──────────────────────────────────────────────── + + +def test_dedup_source_group_all_unique(): + col = MagicMock() + col.get.return_value = { + "ids": ["d1", "d2"], + "documents": ["long document one content here", "different document two here"], + "metadatas": [{"wing": "a"}, {"wing": "a"}], + } + col.query.return_value = { + "ids": [["d1"]], + "distances": [[0.8]], # far apart = unique + } + kept, deleted = dedup.dedup_source_group(col, ["d1", "d2"], threshold=0.15, dry_run=True) + assert len(kept) == 2 + assert len(deleted) == 0 + + +def test_dedup_source_group_with_duplicate(): + col = MagicMock() + col.get.return_value = { + "ids": ["d1", "d2"], + "documents": [ + "long document content that is fairly long", + "long document content that is fairly long", + ], + "metadatas": [{"wing": "a"}, {"wing": "a"}], + } + col.query.return_value = { + "ids": [["d1"]], + "distances": [[0.05]], # very close = duplicate + } + kept, deleted = dedup.dedup_source_group(col, ["d1", "d2"], threshold=0.15, dry_run=True) + assert len(kept) == 1 + assert len(deleted) == 1 + + +def test_dedup_source_group_short_docs_deleted(): + col = MagicMock() + col.get.return_value = { + "ids": ["d1", "d2"], + "documents": ["long enough document to keep in the palace", "tiny"], + "metadatas": [{"wing": "a"}, {"wing": "a"}], + } + kept, deleted = dedup.dedup_source_group(col, ["d1", "d2"], threshold=0.15, dry_run=True) + assert "d2" in deleted # too short + + +def test_dedup_source_group_empty_doc_deleted(): + col = MagicMock() + col.get.return_value = { + "ids": ["d1", "d2"], + "documents": ["real document content here that is long enough", None], + "metadatas": [{"wing": "a"}, {"wing": "a"}], + } + kept, deleted = dedup.dedup_source_group(col, ["d1", "d2"], threshold=0.15, dry_run=True) + assert "d2" in deleted + + +def test_dedup_source_group_live_deletes(): + col = MagicMock() + col.get.return_value = { + "ids": ["d1", "d2"], + "documents": ["long document content here enough", "long document content here enough"], + "metadatas": [{"wing": "a"}, {"wing": "a"}], + } + col.query.return_value = { + "ids": [["d1"]], + "distances": [[0.05]], + } + kept, deleted = dedup.dedup_source_group(col, ["d1", "d2"], threshold=0.15, dry_run=False) + col.delete.assert_called_once() + + +def test_dedup_source_group_query_failure_keeps(): + col = MagicMock() + col.get.return_value = { + "ids": ["d1", "d2"], + "documents": [ + "long document one content here enough", + "long document two content here enough", + ], + "metadatas": [{"wing": "a"}, {"wing": "a"}], + } + col.query.side_effect = Exception("query failed") + kept, deleted = dedup.dedup_source_group(col, ["d1", "d2"], threshold=0.15, dry_run=True) + assert len(kept) == 2 # both kept on error + + +# ── show_stats ──────────────────────────────────────────────────────── + + +@patch("mempalace.dedup.chromadb") +def test_show_stats(mock_chromadb, tmp_path): + mock_col = MagicMock() + mock_col.count.return_value = 5 + mock_col.get.side_effect = [ + { + "ids": ["d1", "d2", "d3", "d4", "d5"], + "metadatas": [ + {"source_file": "a.txt"}, + {"source_file": "a.txt"}, + {"source_file": "a.txt"}, + {"source_file": "a.txt"}, + {"source_file": "a.txt"}, + ], + }, + {"ids": []}, + ] + mock_client = MagicMock() + mock_client.get_collection.return_value = mock_col + mock_chromadb.PersistentClient.return_value = mock_client + + dedup.show_stats(palace_path=str(tmp_path)) # should not raise + + +# ── dedup_palace ────────────────────────────────────────────────────── + + +@patch("mempalace.dedup.dedup_source_group") +@patch("mempalace.dedup.get_source_groups") +@patch("mempalace.dedup.chromadb") +def test_dedup_palace_dry_run(mock_chromadb, mock_groups, mock_dedup_group, tmp_path): + mock_col = MagicMock() + mock_col.count.return_value = 10 + mock_client = MagicMock() + mock_client.get_collection.return_value = mock_col + mock_chromadb.PersistentClient.return_value = mock_client + + mock_groups.return_value = {"a.txt": ["d1", "d2", "d3", "d4", "d5"]} + mock_dedup_group.return_value = (["d1", "d2", "d3"], ["d4", "d5"]) + + dedup.dedup_palace(palace_path=str(tmp_path), dry_run=True) + mock_dedup_group.assert_called_once() + + +@patch("mempalace.dedup.dedup_source_group") +@patch("mempalace.dedup.get_source_groups") +@patch("mempalace.dedup.chromadb") +def test_dedup_palace_with_wing(mock_chromadb, mock_groups, mock_dedup_group, tmp_path): + mock_col = MagicMock() + mock_col.count.return_value = 10 + mock_client = MagicMock() + mock_client.get_collection.return_value = mock_col + mock_chromadb.PersistentClient.return_value = mock_client + + mock_groups.return_value = {} + dedup.dedup_palace(palace_path=str(tmp_path), wing="test_wing", dry_run=True) + mock_groups.assert_called_once_with(mock_col, 5, None, wing="test_wing") + + +@patch("mempalace.dedup.dedup_source_group") +@patch("mempalace.dedup.get_source_groups") +@patch("mempalace.dedup.chromadb") +def test_dedup_palace_no_groups(mock_chromadb, mock_groups, mock_dedup_group, tmp_path): + mock_col = MagicMock() + mock_col.count.return_value = 3 + mock_client = MagicMock() + mock_client.get_collection.return_value = mock_col + mock_chromadb.PersistentClient.return_value = mock_client + + mock_groups.return_value = {} + dedup.dedup_palace(palace_path=str(tmp_path), dry_run=True) + mock_dedup_group.assert_not_called() diff --git a/tests/test_mcp_server.py b/tests/test_mcp_server.py index 24258a9..96fe80c 100644 --- a/tests/test_mcp_server.py +++ b/tests/test_mcp_server.py @@ -42,6 +42,50 @@ class TestHandleRequest: assert resp["result"]["serverInfo"]["name"] == "mempalace" assert resp["id"] == 1 + def test_initialize_negotiates_client_version(self): + from mempalace.mcp_server import handle_request + + resp = handle_request( + { + "method": "initialize", + "id": 1, + "params": {"protocolVersion": "2025-11-25"}, + } + ) + assert resp["result"]["protocolVersion"] == "2025-11-25" + + def test_initialize_negotiates_older_supported_version(self): + from mempalace.mcp_server import handle_request + + resp = handle_request( + { + "method": "initialize", + "id": 1, + "params": {"protocolVersion": "2025-03-26"}, + } + ) + assert resp["result"]["protocolVersion"] == "2025-03-26" + + def test_initialize_unknown_version_falls_back_to_latest(self): + from mempalace.mcp_server import handle_request + + resp = handle_request( + { + "method": "initialize", + "id": 1, + "params": {"protocolVersion": "9999-12-31"}, + } + ) + from mempalace.mcp_server import SUPPORTED_PROTOCOL_VERSIONS + + assert resp["result"]["protocolVersion"] == SUPPORTED_PROTOCOL_VERSIONS[0] + + def test_initialize_missing_version_uses_oldest(self): + from mempalace.mcp_server import handle_request, SUPPORTED_PROTOCOL_VERSIONS + + resp = handle_request({"method": "initialize", "id": 1, "params": {}}) + assert resp["result"]["protocolVersion"] == SUPPORTED_PROTOCOL_VERSIONS[-1] + def test_notifications_initialized_returns_none(self): from mempalace.mcp_server import handle_request @@ -59,6 +103,23 @@ class TestHandleRequest: assert "mempalace_add_drawer" in names assert "mempalace_kg_add" in names + def test_null_arguments_does_not_hang(self, monkeypatch, config, palace_path, seeded_kg): + """Sending arguments: null should return a result, not hang (#394).""" + _patch_mcp_server(monkeypatch, config, seeded_kg) + from mempalace.mcp_server import handle_request + + _client, _col = _get_collection(palace_path, create=True) + del _client + resp = handle_request( + { + "method": "tools/call", + "id": 10, + "params": {"name": "mempalace_status", "arguments": None}, + } + ) + assert "error" not in resp + assert resp["result"] is not None + def test_unknown_tool(self): from mempalace.mcp_server import handle_request diff --git a/tests/test_miner.py b/tests/test_miner.py index efe55a7..c013d7c 100644 --- a/tests/test_miner.py +++ b/tests/test_miner.py @@ -7,6 +7,7 @@ import chromadb import yaml from mempalace.miner import mine, scan_project +from mempalace.palace import file_already_mined def write_file(path: Path, content: str): @@ -206,3 +207,56 @@ def test_scan_project_skip_dirs_still_apply_without_override(): assert scanned_files(project_root, respect_gitignore=False) == ["main.py"] finally: shutil.rmtree(tmpdir) + + +def test_file_already_mined_check_mtime(): + tmpdir = tempfile.mkdtemp() + try: + palace_path = os.path.join(tmpdir, "palace") + os.makedirs(palace_path) + client = chromadb.PersistentClient(path=palace_path) + col = client.get_or_create_collection("mempalace_drawers") + + test_file = os.path.join(tmpdir, "test.txt") + with open(test_file, "w") as f: + f.write("hello world") + + mtime = os.path.getmtime(test_file) + + # Not mined yet + assert file_already_mined(col, test_file) is False + assert file_already_mined(col, test_file, check_mtime=True) is False + + # Add it with mtime + col.add( + ids=["d1"], + documents=["hello world"], + metadatas=[{"source_file": test_file, "source_mtime": str(mtime)}], + ) + + # Already mined (no mtime check) + assert file_already_mined(col, test_file) is True + # Already mined (mtime matches) + assert file_already_mined(col, test_file, check_mtime=True) is True + + # Modify file and force a different mtime (Windows has low mtime resolution) + with open(test_file, "w") as f: + f.write("modified content") + os.utime(test_file, (mtime + 10, mtime + 10)) + + # Still mined without mtime check + assert file_already_mined(col, test_file) is True + # Needs re-mining with mtime check + assert file_already_mined(col, test_file, check_mtime=True) is False + + # Record with no mtime stored should return False for check_mtime + col.add( + ids=["d2"], + documents=["other"], + metadatas=[{"source_file": "/fake/no_mtime.txt"}], + ) + assert file_already_mined(col, "/fake/no_mtime.txt", check_mtime=True) is False + finally: + # Release ChromaDB file handles before cleanup (required on Windows) + del col, client + shutil.rmtree(tmpdir, ignore_errors=True) diff --git a/tests/test_normalize.py b/tests/test_normalize.py index fc50251..959668f 100644 --- a/tests/test_normalize.py +++ b/tests/test_normalize.py @@ -499,3 +499,13 @@ def test_messages_to_transcript_assistant_first(): result = _messages_to_transcript(msgs, spellcheck=False) assert "preamble" in result assert "> Q" in result + + +def test_normalize_rejects_large_file(): + """Files over 500 MB should raise IOError before reading.""" + with patch("mempalace.normalize.os.path.getsize", return_value=600 * 1024 * 1024): + try: + normalize("/fake/huge_file.txt") + assert False, "Should have raised IOError" + except IOError as e: + assert "too large" in str(e).lower() diff --git a/tests/test_repair.py b/tests/test_repair.py new file mode 100644 index 0000000..604b0fb --- /dev/null +++ b/tests/test_repair.py @@ -0,0 +1,266 @@ +"""Tests for mempalace.repair — scan, prune, and rebuild HNSW index.""" + +import os +from unittest.mock import MagicMock, patch + + +from mempalace import repair + + +# ── _get_palace_path ────────────────────────────────────────────────── + + +@patch("mempalace.repair.MempalaceConfig", create=True) +def test_get_palace_path_from_config(mock_config_cls): + mock_config_cls.return_value.palace_path = "/configured/palace" + with patch.dict("sys.modules", {}): + # Force reimport to pick up the mock + result = repair._get_palace_path() + assert isinstance(result, str) + + +def test_get_palace_path_fallback(): + with patch("mempalace.repair._get_palace_path") as mock_get: + mock_get.return_value = os.path.join(os.path.expanduser("~"), ".mempalace", "palace") + result = mock_get() + assert ".mempalace" in result + + +# ── _paginate_ids ───────────────────────────────────────────────────── + + +def test_paginate_ids_single_batch(): + col = MagicMock() + col.get.return_value = {"ids": ["id1", "id2", "id3"]} + ids = repair._paginate_ids(col) + assert ids == ["id1", "id2", "id3"] + + +def test_paginate_ids_empty(): + col = MagicMock() + col.get.return_value = {"ids": []} + ids = repair._paginate_ids(col) + assert ids == [] + + +def test_paginate_ids_with_where(): + col = MagicMock() + col.get.return_value = {"ids": ["id1"]} + repair._paginate_ids(col, where={"wing": "test"}) + col.get.assert_called_with(where={"wing": "test"}, include=[], limit=1000, offset=0) + + +def test_paginate_ids_offset_exception_fallback(): + col = MagicMock() + # First call raises, fallback returns ids, second fallback returns empty + col.get.side_effect = [ + Exception("offset bug"), + {"ids": ["id1", "id2"]}, + Exception("offset bug"), + {"ids": ["id1", "id2"]}, # same ids = no new = break + ] + ids = repair._paginate_ids(col) + assert "id1" in ids + + +# ── scan_palace ─────────────────────────────────────────────────────── + + +@patch("mempalace.repair.chromadb") +def test_scan_palace_no_ids(mock_chromadb, tmp_path): + mock_col = MagicMock() + mock_col.count.return_value = 0 + mock_col.get.return_value = {"ids": []} + mock_client = MagicMock() + mock_client.get_collection.return_value = mock_col + mock_chromadb.PersistentClient.return_value = mock_client + + good, bad = repair.scan_palace(palace_path=str(tmp_path)) + assert good == set() + assert bad == set() + + +@patch("mempalace.repair.chromadb") +def test_scan_palace_all_good(mock_chromadb, tmp_path): + mock_col = MagicMock() + mock_col.count.return_value = 2 + # _paginate_ids call + mock_col.get.side_effect = [ + {"ids": ["id1", "id2"]}, # paginate + {"ids": ["id1", "id2"]}, # probe batch — both returned + ] + mock_client = MagicMock() + mock_client.get_collection.return_value = mock_col + mock_chromadb.PersistentClient.return_value = mock_client + + good, bad = repair.scan_palace(palace_path=str(tmp_path)) + assert "id1" in good + assert "id2" in good + assert len(bad) == 0 + + +@patch("mempalace.repair.chromadb") +def test_scan_palace_with_bad_ids(mock_chromadb, tmp_path): + mock_col = MagicMock() + mock_col.count.return_value = 2 + + def get_side_effect(**kwargs): + ids = kwargs.get("ids", None) + if ids is None: + # paginate call + return {"ids": ["good1", "bad1"]} + if "bad1" in ids and len(ids) == 1: + raise Exception("corrupt") + if "good1" in ids and len(ids) == 1: + return {"ids": ["good1"]} + # batch probe — raise to force per-id + raise Exception("batch fail") + + mock_col.get.side_effect = get_side_effect + mock_client = MagicMock() + mock_client.get_collection.return_value = mock_col + mock_chromadb.PersistentClient.return_value = mock_client + + good, bad = repair.scan_palace(palace_path=str(tmp_path)) + assert "good1" in good + assert "bad1" in bad + + +@patch("mempalace.repair.chromadb") +def test_scan_palace_with_wing_filter(mock_chromadb, tmp_path): + mock_col = MagicMock() + mock_col.count.return_value = 1 + mock_col.get.side_effect = [ + {"ids": ["id1"]}, # paginate + {"ids": ["id1"]}, # probe + ] + mock_client = MagicMock() + mock_client.get_collection.return_value = mock_col + mock_chromadb.PersistentClient.return_value = mock_client + + repair.scan_palace(palace_path=str(tmp_path), only_wing="test_wing") + # Verify where filter was passed + first_call = mock_col.get.call_args_list[0] + assert first_call.kwargs.get("where") == {"wing": "test_wing"} + + +# ── prune_corrupt ───────────────────────────────────────────────────── + + +@patch("mempalace.repair.chromadb") +def test_prune_corrupt_no_file(mock_chromadb, tmp_path): + # Should print message and return without error + repair.prune_corrupt(palace_path=str(tmp_path)) + + +@patch("mempalace.repair.chromadb") +def test_prune_corrupt_dry_run(mock_chromadb, tmp_path): + bad_file = tmp_path / "corrupt_ids.txt" + bad_file.write_text("bad1\nbad2\n") + repair.prune_corrupt(palace_path=str(tmp_path), confirm=False) + # No chromadb calls in dry run + mock_chromadb.PersistentClient.assert_not_called() + + +@patch("mempalace.repair.chromadb") +def test_prune_corrupt_confirmed(mock_chromadb, tmp_path): + bad_file = tmp_path / "corrupt_ids.txt" + bad_file.write_text("bad1\nbad2\n") + + mock_col = MagicMock() + mock_col.count.side_effect = [10, 8] + mock_client = MagicMock() + mock_client.get_collection.return_value = mock_col + mock_chromadb.PersistentClient.return_value = mock_client + + repair.prune_corrupt(palace_path=str(tmp_path), confirm=True) + mock_col.delete.assert_called_once() + + +@patch("mempalace.repair.chromadb") +def test_prune_corrupt_delete_failure_fallback(mock_chromadb, tmp_path): + bad_file = tmp_path / "corrupt_ids.txt" + bad_file.write_text("bad1\nbad2\n") + + mock_col = MagicMock() + mock_col.count.side_effect = [10, 8] + # Batch delete fails, per-id succeeds + mock_col.delete.side_effect = [Exception("batch fail"), None, None] + mock_client = MagicMock() + mock_client.get_collection.return_value = mock_col + mock_chromadb.PersistentClient.return_value = mock_client + + repair.prune_corrupt(palace_path=str(tmp_path), confirm=True) + assert mock_col.delete.call_count == 3 # 1 batch + 2 individual + + +# ── rebuild_index ───────────────────────────────────────────────────── + + +@patch("mempalace.repair.chromadb") +def test_rebuild_index_no_palace(mock_chromadb, tmp_path): + nonexistent = str(tmp_path / "nope") + repair.rebuild_index(palace_path=nonexistent) + mock_chromadb.PersistentClient.assert_not_called() + + +@patch("mempalace.repair.shutil") +@patch("mempalace.repair.chromadb") +def test_rebuild_index_empty_palace(mock_chromadb, mock_shutil, tmp_path): + mock_col = MagicMock() + mock_col.count.return_value = 0 + mock_client = MagicMock() + mock_client.get_collection.return_value = mock_col + mock_chromadb.PersistentClient.return_value = mock_client + + repair.rebuild_index(palace_path=str(tmp_path)) + mock_client.delete_collection.assert_not_called() + + +@patch("mempalace.repair.shutil") +@patch("mempalace.repair.chromadb") +def test_rebuild_index_success(mock_chromadb, mock_shutil, tmp_path): + # Create a fake sqlite file + sqlite_path = tmp_path / "chroma.sqlite3" + sqlite_path.write_text("fake") + + mock_col = MagicMock() + mock_col.count.return_value = 2 + mock_col.get.return_value = { + "ids": ["id1", "id2"], + "documents": ["doc1", "doc2"], + "metadatas": [{"wing": "a"}, {"wing": "b"}], + } + + mock_new_col = MagicMock() + mock_client = MagicMock() + mock_client.get_collection.return_value = mock_col + mock_client.create_collection.return_value = mock_new_col + mock_chromadb.PersistentClient.return_value = mock_client + + repair.rebuild_index(palace_path=str(tmp_path)) + + # Verify: backed up sqlite only (not copytree) + mock_shutil.copy2.assert_called_once() + assert "chroma.sqlite3" in str(mock_shutil.copy2.call_args) + + # Verify: deleted and recreated with cosine + mock_client.delete_collection.assert_called_once_with("mempalace_drawers") + mock_client.create_collection.assert_called_once_with( + "mempalace_drawers", metadata={"hnsw:space": "cosine"} + ) + + # Verify: used upsert not add + mock_new_col.upsert.assert_called_once() + mock_new_col.add.assert_not_called() + + +@patch("mempalace.repair.shutil") +@patch("mempalace.repair.chromadb") +def test_rebuild_index_error_reading(mock_chromadb, mock_shutil, tmp_path): + mock_client = MagicMock() + mock_client.get_collection.side_effect = Exception("corrupt") + mock_chromadb.PersistentClient.return_value = mock_client + + repair.rebuild_index(palace_path=str(tmp_path)) + mock_client.delete_collection.assert_not_called()