diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS new file mode 100644 index 0000000..b112254 --- /dev/null +++ b/.github/CODEOWNERS @@ -0,0 +1,13 @@ +# Default owners for everything +* @milla-jovovich @bensig @igorls + +# Core library +mempalace/ @milla-jovovich @bensig + +# CI and workflows +.github/ @bensig + +# Plugins and integrations +.claude-plugin/ @bensig +.codex-plugin/ @bensig +integrations/ @bensig diff --git a/.github/dependabot.yml b/.github/dependabot.yml new file mode 100644 index 0000000..220218c --- /dev/null +++ b/.github/dependabot.yml @@ -0,0 +1,12 @@ +version: 2 +updates: + - package-ecosystem: "pip" + directory: "/" + schedule: + interval: "weekly" + open-pull-requests-limit: 5 + - package-ecosystem: "github-actions" + directory: "/" + schedule: + interval: "weekly" + open-pull-requests-limit: 3 diff --git a/.github/workflows/bump-plugin-version.yml b/.github/workflows/bump-plugin-version.yml.disabled similarity index 100% rename from .github/workflows/bump-plugin-version.yml rename to .github/workflows/bump-plugin-version.yml.disabled diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 88f14cf..9c96883 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -7,7 +7,7 @@ on: branches: [main] jobs: - test: + test-linux: runs-on: ubuntu-latest strategy: matrix: @@ -18,8 +18,27 @@ jobs: with: python-version: ${{ matrix.python-version }} - run: pip install -e ".[dev]" - - run: python -m pytest tests/ -v --ignore=tests/benchmarks --cov=mempalace --cov-report=term-missing --cov-fail-under=30 + - run: python -m pytest tests/ -v --ignore=tests/benchmarks --cov=mempalace --cov-report=term-missing --cov-fail-under=80 --durations=10 + test-windows: + runs-on: windows-latest + steps: + - uses: actions/checkout@v6 + - uses: actions/setup-python@v6 + with: + python-version: "3.9" + - run: pip install -e ".[dev]" + - run: python -m pytest tests/ -v --ignore=tests/benchmarks --cov=mempalace --cov-report=term-missing --cov-fail-under=80 --durations=10 + + test-macos: + runs-on: macos-latest + steps: + - uses: actions/checkout@v6 + - uses: actions/setup-python@v6 + with: + python-version: "3.9" + - run: pip install -e ".[dev]" + - run: python -m pytest tests/ -v --ignore=tests/benchmarks --cov=mempalace --cov-report=term-missing --cov-fail-under=80 --durations=10 lint: runs-on: ubuntu-latest steps: @@ -27,6 +46,6 @@ jobs: - uses: actions/setup-python@v6 with: python-version: "3.11" - - run: pip install ruff + - run: pip install "ruff>=0.4.0,<0.5" - run: ruff check . - run: ruff format --check . diff --git a/.gitignore b/.gitignore index c8b10cc..1f3b03e 100644 --- a/.gitignore +++ b/.gitignore @@ -6,3 +6,30 @@ __pycache__/ .pytest_cache/ mempal.yaml .a5c/ + +# Environment +.env +.env.* + +# OS +.DS_Store +Thumbs.db + +# IDEs +.idea/ +.vscode/ +*.swp +*.swo +*~ + +# Coverage +htmlcov/ +.coverage +coverage.xml + +# Virtual environments +.venv/ +venv/ + +# ChromaDB local data +*.sqlite3-journal diff --git a/AGENTS.md b/AGENTS.md new file mode 100644 index 0000000..3026013 --- /dev/null +++ b/AGENTS.md @@ -0,0 +1,78 @@ +# AGENTS.md + +> How to build, test, and contribute to MemPalace. + +## Setup + +```bash +pip install -e ".[dev]" +``` + +## Commands + +```bash +# Run tests +python -m pytest tests/ -v --ignore=tests/benchmarks + +# Run tests with coverage +python -m pytest tests/ -v --ignore=tests/benchmarks --cov=mempalace --cov-report=term-missing + +# Lint +ruff check . + +# Format +ruff format . + +# Format check (CI mode) +ruff format --check . +``` + +## Project structure + +``` +mempalace/ +├── mcp_server.py # MCP server — all read/write tools +├── miner.py # Project file miner +├── convo_miner.py # Conversation transcript miner +├── searcher.py # Semantic search +├── knowledge_graph.py # Temporal entity-relationship graph (SQLite) +├── palace.py # Shared palace operations (ChromaDB access) +├── config.py # Configuration + input validation +├── normalize.py # Transcript format detection + normalization +├── cli.py # CLI dispatcher +├── dialect.py # AAAK compression dialect +├── palace_graph.py # Room traversal + cross-wing tunnels +├── hooks_cli.py # Hook system for auto-save +└── version.py # Single source of truth for version +``` + +## Conventions + +- **Python style**: snake_case for functions/variables, PascalCase for classes +- **Linter**: ruff with E/F/W rules +- **Formatter**: ruff format, double quotes +- **Commits**: conventional commits (`fix:`, `feat:`, `test:`, `docs:`, `ci:`) +- **Tests**: `tests/test_*.py`, fixtures in `tests/conftest.py` +- **Coverage**: 85% threshold (80% on Windows due to ChromaDB file lock cleanup) + +## Architecture + +``` +User → CLI / MCP Server → ChromaDB (vector store) + SQLite (knowledge graph) + +Palace structure: + WING (person/project) + └── ROOM (topic) + └── DRAWER (verbatim text chunk) + +Knowledge Graph: + ENTITY → PREDICATE → ENTITY (with valid_from / valid_to dates) +``` + +## Key files for common tasks + +- **Adding an MCP tool**: `mempalace/mcp_server.py` — add handler function + TOOLS dict entry +- **Changing search**: `mempalace/searcher.py` +- **Modifying mining**: `mempalace/miner.py` (project files) or `mempalace/convo_miner.py` (transcripts) +- **Input validation**: `mempalace/config.py` — `sanitize_name()` / `sanitize_content()` +- **Tests**: mirror source structure in `tests/test_.py` diff --git a/README.md b/README.md index a1a7ccb..c3540e5 100644 --- a/README.md +++ b/README.md @@ -585,6 +585,9 @@ mempalace compress --wing myapp # AAAK compress # Status mempalace status # palace overview + +# MCP +mempalace mcp # show MCP setup command ``` All commands accept `--palace ` to override the default location. @@ -707,7 +710,7 @@ PRs welcome. See [CONTRIBUTING.md](CONTRIBUTING.md) for setup and guidelines. MIT — see [LICENSE](LICENSE). -[version-shield]: https://img.shields.io/badge/version-3.0.0-4dc9f6?style=flat-square&labelColor=0a0e14 +[version-shield]: https://img.shields.io/badge/version-3.1.0-4dc9f6?style=flat-square&labelColor=0a0e14 [release-link]: https://github.com/milla-jovovich/mempalace/releases [python-shield]: https://img.shields.io/badge/python-3.9+-7dd8f8?style=flat-square&labelColor=0a0e14&logo=python&logoColor=7dd8f8 [python-link]: https://www.python.org/ diff --git a/docs/schema.sql b/docs/schema.sql new file mode 100644 index 0000000..740db70 --- /dev/null +++ b/docs/schema.sql @@ -0,0 +1,36 @@ +-- MemPalace Knowledge Graph Schema +-- SQLite database at ~/.mempalace/knowledge_graph.db + +CREATE TABLE IF NOT EXISTS entities ( + id TEXT PRIMARY KEY, + name TEXT NOT NULL, + type TEXT DEFAULT 'unknown', + properties TEXT DEFAULT '{}' +); + +CREATE TABLE IF NOT EXISTS triples ( + id TEXT PRIMARY KEY, + subject TEXT NOT NULL, + predicate TEXT NOT NULL, + object TEXT NOT NULL, + valid_from TEXT, + valid_to TEXT, + confidence REAL DEFAULT 1.0, + source_closet TEXT, + source_file TEXT +); + +CREATE TABLE IF NOT EXISTS attributes ( + entity_id TEXT NOT NULL, + key TEXT NOT NULL, + value TEXT, + valid_from TEXT, + valid_to TEXT, + PRIMARY KEY (entity_id, key, valid_from) +); + +-- Indexes +CREATE INDEX IF NOT EXISTS idx_triples_subject ON triples(subject); +CREATE INDEX IF NOT EXISTS idx_triples_object ON triples(object); +CREATE INDEX IF NOT EXISTS idx_triples_predicate ON triples(predicate); +CREATE INDEX IF NOT EXISTS idx_triples_valid ON triples(valid_from, valid_to); diff --git a/hooks/mempal_save_hook.sh b/hooks/mempal_save_hook.sh index 75abfc8..a0e4681 100755 --- a/hooks/mempal_save_hook.sh +++ b/hooks/mempal_save_hook.sh @@ -64,13 +64,20 @@ MEMPAL_DIR="" # Read JSON input from stdin INPUT=$(cat) -# Parse fields from Claude Code's JSON -SESSION_ID=$(echo "$INPUT" | python3 -c "import sys,json; print(json.load(sys.stdin).get('session_id','unknown'))" 2>/dev/null) -# Sanitize SESSION_ID to prevent path traversal (only allow alnum, dash, underscore) -SESSION_ID=$(echo "$SESSION_ID" | tr -cd 'a-zA-Z0-9_-') -[ -z "$SESSION_ID" ] && SESSION_ID="unknown" -STOP_HOOK_ACTIVE=$(echo "$INPUT" | python3 -c "import sys,json; print(json.load(sys.stdin).get('stop_hook_active', False))" 2>/dev/null) -TRANSCRIPT_PATH=$(echo "$INPUT" | python3 -c "import sys,json; print(json.load(sys.stdin).get('transcript_path',''))" 2>/dev/null) +# Parse all fields in a single Python call (3x faster than separate invocations) +eval $(echo "$INPUT" | python3 -c " +import sys, json +data = json.load(sys.stdin) +sid = data.get('session_id', 'unknown') +sha = data.get('stop_hook_active', False) +tp = data.get('transcript_path', '') +# Shell-safe output — only allow alphanumeric, underscore, hyphen, slash, dot, tilde +import re +safe = lambda s: re.sub(r'[^a-zA-Z0-9_/.\-~]', '', str(s)) +print(f'SESSION_ID=\"{safe(sid)}\"') +print(f'STOP_HOOK_ACTIVE=\"{sha}\"') +print(f'TRANSCRIPT_PATH=\"{safe(tp)}\"') +" 2>/dev/null) # Expand ~ in path TRANSCRIPT_PATH="${TRANSCRIPT_PATH/#\~/$HOME}" @@ -83,6 +90,7 @@ if [ "$STOP_HOOK_ACTIVE" = "True" ] || [ "$STOP_HOOK_ACTIVE" = "true" ]; then fi # Count human messages in the JSONL transcript +# SECURITY: Pass transcript path as sys.argv to avoid shell injection via crafted paths if [ -f "$TRANSCRIPT_PATH" ]; then EXCHANGE_COUNT=$(python3 - "$TRANSCRIPT_PATH" <<'PYEOF' import json, sys @@ -94,7 +102,6 @@ with open(sys.argv[1]) as f: msg = entry.get('message', {}) if isinstance(msg, dict) and msg.get('role') == 'user': content = msg.get('content', '') - # Skip system/command messages — only count real human input if isinstance(content, str) and '' in content: continue count += 1 diff --git a/integrations/openclaw/SKILL.md b/integrations/openclaw/SKILL.md new file mode 100644 index 0000000..88f0b2f --- /dev/null +++ b/integrations/openclaw/SKILL.md @@ -0,0 +1,154 @@ +--- +name: mempalace +description: "MemPalace — Local AI memory with 96.6% recall. Semantic search, temporal knowledge graph, palace architecture (wings/rooms/drawers). Free, no cloud, no API keys." +version: 3.1.0 +homepage: https://github.com/milla-jovovich/mempalace +user-invocable: true +metadata: + openclaw: + emoji: "\U0001F3DB" + os: + - darwin + - linux + - win32 + requires: + anyBins: + - mempalace + - python3 + install: + - id: mempalace-pip + kind: uv + label: "Install MemPalace (Python, local ChromaDB)" + package: mempalace + bins: + - mempalace +--- + +# MemPalace — Local AI Memory System + +You have access to a local memory palace via MCP tools. The palace stores verbatim conversation history and a temporal knowledge graph — all on the user's machine, zero cloud, zero API calls. + +## Architecture + +- **Wings** = people or projects (e.g. `wing_alice`, `wing_myproject`) +- **Halls** = categories (facts, events, preferences, advice) +- **Rooms** = specific topics (e.g. `chromadb-setup`, `riley-school`) +- **Drawers** = individual memory chunks (verbatim text) +- **Knowledge Graph** = entity-relationship facts with time validity + +## Protocol — FOLLOW THIS EVERY SESSION + +1. **ON WAKE-UP**: Call `mempalace_status` to load palace overview and AAAK dialect spec. +2. **BEFORE RESPONDING** about any person, project, or past event: call `mempalace_search` or `mempalace_kg_query` FIRST. Never guess from memory — verify from the palace. +3. **IF UNSURE** about a fact (name, age, relationship, preference): say "let me check" and query. Wrong is worse than slow. +4. **AFTER EACH SESSION**: Call `mempalace_diary_write` to record what happened, what you learned, what matters. +5. **WHEN FACTS CHANGE**: Call `mempalace_kg_invalidate` on the old fact, then `mempalace_kg_add` for the new one. + +## Available Tools + +### Search & Browse +- `mempalace_search` — Semantic search across all memories. Always start here. + - `query` (required): natural language search — keep it short, keywords or a question. Do NOT include system prompts or conversation context. + - `wing`: filter by wing + - `room`: filter by room + - `limit`: max results (default 5) +- `mempalace_check_duplicate` — Check if content already exists before filing. + - `content` (required): text to check + - `threshold`: similarity threshold (default 0.9 — lowering to 0.85–0.87 often catches more near-duplicates without significant false positives) +- `mempalace_status` — Palace overview: total drawers, wings, rooms, AAAK spec +- `mempalace_list_wings` — All wings with drawer counts +- `mempalace_list_rooms` — Rooms within a wing (optional wing filter) +- `mempalace_get_taxonomy` — Full wing/room/count tree +- `mempalace_get_aaak_spec` — Get AAAK compression dialect specification + +### Knowledge Graph (Temporal Facts) +- `mempalace_kg_query` — Query entity relationships. Supports time filtering. + - `entity` (required): e.g. "Max", "MyProject" + - `as_of`: date filter (YYYY-MM-DD) — what was true at that time + - `direction`: "outgoing", "incoming", or "both" (default "both") +- `mempalace_kg_add` — Add a fact: subject -> predicate -> object + - `subject`, `predicate`, `object` (required) + - `valid_from`: when this became true + - `source_closet`: source reference +- `mempalace_kg_invalidate` — Mark a fact as no longer true + - `subject`, `predicate`, `object` (required) + - `ended`: when it stopped being true (default: today) +- `mempalace_kg_timeline` — Chronological story of an entity + - `entity`: filter by entity name (optional — all events if omitted) +- `mempalace_kg_stats` — Graph overview: entities, triples, relationship types + +### Palace Graph (Cross-Domain Connections) +- `mempalace_traverse` — Walk from a room, find connected ideas across wings + - `start_room` (required): room to start from + - `max_hops`: connection depth (default 2) +- `mempalace_find_tunnels` — Find rooms that bridge two wings + - `wing_a`, `wing_b` (required) +- `mempalace_graph_stats` — Graph connectivity overview + +### Write +- `mempalace_add_drawer` — Store verbatim content into a wing/room + - `wing`, `room`, `content` (required) + - `source_file`: optional source reference + - Checks for duplicates automatically +- `mempalace_delete_drawer` — Remove a drawer by ID + - `drawer_id` (required) +- `mempalace_diary_write` — Write a session diary entry + - `agent_name` (required): your name/identifier + - `entry` (required): what happened, what you learned, what matters + - `topic`: category tag (default "general") +- `mempalace_diary_read` — Read recent diary entries + - `agent_name` (required) + - `last_n`: number of entries (default 10) + +## Setup + +Install MemPalace and populate the palace: + +```bash +pip install mempalace +mempalace init ~/my-convos +mempalace mine ~/my-convos +``` + +### OpenClaw MCP config + +Add to your OpenClaw MCP configuration: + +```json +{ + "mcpServers": { + "mempalace": { + "command": "python3", + "args": ["-m", "mempalace.mcp_server"] + } + } +} +``` + +Or via CLI: + +```bash +openclaw mcp set mempalace '{"command":"python3","args":["-m","mempalace.mcp_server"]}' +``` + +### Other MCP hosts + +```bash +# Claude Code +claude mcp add mempalace -- python -m mempalace.mcp_server + +# Cursor — add to .cursor/mcp.json +# Codex — add to .codex/mcp.json +``` + +## Tips + +- Search is semantic (meaning-based), not keyword. "What did we discuss about database performance?" works better than "database". +- The knowledge graph stores typed relationships with time windows. Use it for facts about people and projects — it knows WHEN things were true. +- Diary entries accumulate across sessions. Write one at the end of each conversation to build continuity. +- Use `mempalace_check_duplicate` before storing new content to avoid duplicates. +- The AAAK dialect (from `mempalace_status`) is a compressed notation for efficient storage. Read it naturally — expand codes mentally, treat *markers* as emotional context. + +## License + +[MemPalace](https://github.com/milla-jovovich/mempalace) is MIT licensed. Created by Milla Jovovich, Ben Sigman, Igor Lins e Silva, and contributors. diff --git a/mempalace/cli.py b/mempalace/cli.py index 0a24abf..1d106ca 100644 --- a/mempalace/cli.py +++ b/mempalace/cli.py @@ -14,6 +14,7 @@ Commands: mempalace mine Mine project files (default) mempalace mine --mode convos Mine conversation exports mempalace search "query" Find anything, exact words + mempalace mcp Show MCP setup command mempalace wake-up Show L0 + L1 wake-up context mempalace wake-up --wing my_app Wake-up for a specific project mempalace status Show what's been filed @@ -28,6 +29,7 @@ Examples: import os import sys +import shlex import argparse from pathlib import Path @@ -148,6 +150,14 @@ def cmd_split(args): sys.argv = old_argv +def cmd_migrate(args): + """Migrate palace from a different ChromaDB version.""" + from .migrate import migrate + + palace_path = os.path.expanduser(args.palace) if args.palace else MempalaceConfig().palace_path + migrate(palace_path=palace_path, dry_run=args.dry_run) + + def cmd_status(args): from .miner import status @@ -202,6 +212,7 @@ def cmd_repair(args): print(f" Extracted {len(all_ids)} drawers") # Backup and rebuild + palace_path = palace_path.rstrip(os.sep) backup_path = palace_path + ".backup" if os.path.exists(backup_path): shutil.rmtree(backup_path) @@ -240,6 +251,27 @@ def cmd_instructions(args): run_instructions(name=args.name) +def cmd_mcp(args): + """Show how to wire MemPalace into MCP-capable hosts.""" + base_server_cmd = "python -m mempalace.mcp_server" + + if args.palace: + resolved_palace = str(Path(args.palace).expanduser()) + server_cmd = f"{base_server_cmd} --palace {shlex.quote(resolved_palace)}" + else: + server_cmd = base_server_cmd + + print("MemPalace MCP quick setup:") + print(f" claude mcp add mempalace -- {server_cmd}") + print("\nRun the server directly:") + print(f" {server_cmd}") + + if not args.palace: + print("\nOptional custom palace:") + print(f" claude mcp add mempalace -- {base_server_cmd} --palace /path/to/palace") + print(f" {base_server_cmd} --palace /path/to/palace") + + def cmd_compress(args): """Compress drawers in a wing using AAAK Dialect.""" import chromadb @@ -500,7 +532,24 @@ def main(): help="Rebuild palace vector index from stored data (fixes segfaults after corruption)", ) + # mcp + sub.add_parser( + "mcp", + help="Show MCP setup command for connecting MemPalace to your AI client", + ) + # status + # migrate + p_migrate = sub.add_parser( + "migrate", + help="Migrate palace from a different ChromaDB version (fixes 3.0.0 → 3.1.0 upgrade)", + ) + p_migrate.add_argument( + "--dry-run", + action="store_true", + help="Show what would be migrated without changing anything", + ) + sub.add_parser("status", help="Show what's been filed") args = parser.parse_args() @@ -531,9 +580,11 @@ def main(): "mine": cmd_mine, "split": cmd_split, "search": cmd_search, + "mcp": cmd_mcp, "compress": cmd_compress, "wake-up": cmd_wakeup, "repair": cmd_repair, + "migrate": cmd_migrate, "status": cmd_status, } dispatch[args.command](args) diff --git a/mempalace/config.py b/mempalace/config.py index 5a73650..fcfb2c8 100644 --- a/mempalace/config.py +++ b/mempalace/config.py @@ -6,8 +6,58 @@ Priority: env vars > config file (~/.mempalace/config.json) > defaults import json import os +import re from pathlib import Path + +# ── Input validation ────────────────────────────────────────────────────────── +# Shared sanitizers for wing/room/entity names. Prevents path traversal, +# excessively long strings, and special characters that could cause issues +# in file paths, SQLite, or ChromaDB metadata. + +MAX_NAME_LENGTH = 128 +_SAFE_NAME_RE = re.compile(r"^[a-zA-Z0-9][a-zA-Z0-9_ .'-]{0,126}[a-zA-Z0-9]?$") + + +def sanitize_name(value: str, field_name: str = "name") -> str: + """Validate and sanitize a wing/room/entity name. + + Raises ValueError if the name is invalid. + """ + if not isinstance(value, str) or not value.strip(): + raise ValueError(f"{field_name} must be a non-empty string") + + value = value.strip() + + if len(value) > MAX_NAME_LENGTH: + raise ValueError(f"{field_name} exceeds maximum length of {MAX_NAME_LENGTH} characters") + + # Block path traversal + if ".." in value or "/" in value or "\\" in value: + raise ValueError(f"{field_name} contains invalid path characters") + + # Block null bytes + if "\x00" in value: + raise ValueError(f"{field_name} contains null bytes") + + # Enforce safe character set + if not _SAFE_NAME_RE.match(value): + raise ValueError(f"{field_name} contains invalid characters") + + return value + + +def sanitize_content(value: str, max_length: int = 100_000) -> str: + """Validate drawer/diary content length.""" + if not isinstance(value, str) or not value.strip(): + raise ValueError("content must be a non-empty string") + if len(value) > max_length: + raise ValueError(f"content exceeds maximum length of {max_length} characters") + if "\x00" in value: + raise ValueError("content contains null bytes") + return value + + DEFAULT_PALACE_PATH = os.path.expanduser("~/.mempalace/palace") DEFAULT_COLLECTION_NAME = "mempalace_drawers" @@ -126,6 +176,11 @@ class MempalaceConfig: def init(self): """Create config directory and write default config.json if it doesn't exist.""" self._config_dir.mkdir(parents=True, exist_ok=True) + # Restrict directory permissions to owner only (Unix) + try: + self._config_dir.chmod(0o700) + except (OSError, NotImplementedError): + pass # Windows doesn't support Unix permissions if not self._config_file.exists(): default_config = { "palace_path": DEFAULT_PALACE_PATH, @@ -135,6 +190,11 @@ class MempalaceConfig: } with open(self._config_file, "w") as f: json.dump(default_config, f, indent=2) + # Restrict config file to owner read/write only + try: + self._config_file.chmod(0o600) + except (OSError, NotImplementedError): + pass return self._config_file def save_people_map(self, people_map): diff --git a/mempalace/convo_miner.py b/mempalace/convo_miner.py index c316407..7879f96 100644 --- a/mempalace/convo_miner.py +++ b/mempalace/convo_miner.py @@ -15,9 +15,8 @@ from pathlib import Path from datetime import datetime from collections import defaultdict -import chromadb - from .normalize import normalize +from .palace import SKIP_DIRS, get_collection, file_already_mined # File types that might contain conversations @@ -28,22 +27,8 @@ CONVO_EXTENSIONS = { ".jsonl", } -SKIP_DIRS = { - ".git", - "node_modules", - "__pycache__", - ".venv", - "venv", - "env", - "dist", - "build", - ".next", - ".mempalace", - "tool-results", - "memory", -} - MIN_CHUNK_SIZE = 30 +MAX_FILE_SIZE = 10 * 1024 * 1024 # 10 MB — skip files larger than this # ============================================================================= @@ -211,23 +196,6 @@ def detect_convo_room(content: str) -> str: # ============================================================================= -def get_collection(palace_path: str): - os.makedirs(palace_path, exist_ok=True) - client = chromadb.PersistentClient(path=palace_path) - try: - return client.get_collection("mempalace_drawers") - except Exception: - return client.create_collection("mempalace_drawers") - - -def file_already_mined(collection, source_file: str) -> bool: - try: - results = collection.get(where={"source_file": source_file}, limit=1) - return len(results.get("ids", [])) > 0 - except Exception: - return False - - # ============================================================================= # SCAN FOR CONVERSATION FILES # ============================================================================= @@ -244,6 +212,14 @@ def scan_convos(convo_dir: str) -> list: continue filepath = Path(root) / filename if filepath.suffix.lower() in CONVO_EXTENSIONS: + # Skip symlinks and oversized files + if filepath.is_symlink(): + continue + try: + if filepath.stat().st_size > MAX_FILE_SIZE: + continue + except OSError: + continue files.append(filepath) return files @@ -356,7 +332,7 @@ def mine_convos( chunk_room = chunk.get("memory_type", room) if extract_mode == "general" else room if extract_mode == "general": room_counts[chunk_room] += 1 - drawer_id = f"drawer_{wing}_{chunk_room}_{hashlib.md5((source_file + str(chunk['chunk_index'])).encode(), usedforsecurity=False).hexdigest()[:16]}" + drawer_id = f"drawer_{wing}_{chunk_room}_{hashlib.sha256((source_file + str(chunk['chunk_index'])).encode()).hexdigest()[:24]}" try: collection.add( documents=[chunk["content"]], diff --git a/mempalace/entity_registry.py b/mempalace/entity_registry.py index 24fef0a..2a4ad8d 100644 --- a/mempalace/entity_registry.py +++ b/mempalace/entity_registry.py @@ -309,7 +309,7 @@ class EntityRegistry: def save(self): self._path.parent.mkdir(parents=True, exist_ok=True) - self._path.write_text(json.dumps(self._data, indent=2)) + self._path.write_text(json.dumps(self._data, indent=2), encoding="utf-8") @staticmethod def _empty() -> dict: diff --git a/mempalace/hooks_cli.py b/mempalace/hooks_cli.py index d9408ac..b6d2290 100644 --- a/mempalace/hooks_cli.py +++ b/mempalace/hooks_cli.py @@ -158,7 +158,7 @@ def hook_stop(data: dict, harness: str): if since_last >= SAVE_INTERVAL and exchange_count > 0: # Update last save point try: - last_save_file.write_text(str(exchange_count)) + last_save_file.write_text(str(exchange_count), encoding="utf-8") except OSError: pass diff --git a/mempalace/instructions/help.md b/mempalace/instructions/help.md index f18c1de..5cb70fa 100644 --- a/mempalace/instructions/help.md +++ b/mempalace/instructions/help.md @@ -60,6 +60,7 @@ AI memory system. Store everything, find anything. Local, free, no API key. mempalace compress Compress palace storage mempalace status Show palace status mempalace repair Rebuild vector index + mempalace mcp Show MCP setup command mempalace hook run Run hook logic (for harness integration) mempalace instructions Output skill instructions diff --git a/mempalace/knowledge_graph.py b/mempalace/knowledge_graph.py index 226c92d..b094f06 100644 --- a/mempalace/knowledge_graph.py +++ b/mempalace/knowledge_graph.py @@ -50,11 +50,14 @@ class KnowledgeGraph: def __init__(self, db_path: str = None): self.db_path = db_path or DEFAULT_KG_PATH Path(self.db_path).parent.mkdir(parents=True, exist_ok=True) + self._connection = None self._init_db() def _init_db(self): conn = self._conn() conn.executescript(""" + PRAGMA journal_mode=WAL; + CREATE TABLE IF NOT EXISTS entities ( id TEXT PRIMARY KEY, name TEXT NOT NULL, @@ -84,12 +87,19 @@ class KnowledgeGraph: CREATE INDEX IF NOT EXISTS idx_triples_valid ON triples(valid_from, valid_to); """) conn.commit() - conn.close() def _conn(self): - conn = sqlite3.connect(self.db_path, timeout=10) - conn.execute("PRAGMA journal_mode=WAL") - return conn + if self._connection is None: + self._connection = sqlite3.connect(self.db_path, timeout=10, check_same_thread=False) + self._connection.execute("PRAGMA journal_mode=WAL") + self._connection.row_factory = sqlite3.Row + return self._connection + + def close(self): + """Close the database connection.""" + if self._connection is not None: + self._connection.close() + self._connection = None def _entity_id(self, name: str) -> str: return name.lower().replace(" ", "_").replace("'", "") @@ -101,12 +111,11 @@ class KnowledgeGraph: eid = self._entity_id(name) props = json.dumps(properties or {}) conn = self._conn() - conn.execute( - "INSERT OR REPLACE INTO entities (id, name, type, properties) VALUES (?, ?, ?, ?)", - (eid, name, entity_type, props), - ) - conn.commit() - conn.close() + with conn: + conn.execute( + "INSERT OR REPLACE INTO entities (id, name, type, properties) VALUES (?, ?, ?, ?)", + (eid, name, entity_type, props), + ) return eid def add_triple( @@ -134,38 +143,38 @@ class KnowledgeGraph: # Auto-create entities if they don't exist conn = self._conn() - conn.execute("INSERT OR IGNORE INTO entities (id, name) VALUES (?, ?)", (sub_id, subject)) - conn.execute("INSERT OR IGNORE INTO entities (id, name) VALUES (?, ?)", (obj_id, obj)) + with conn: + conn.execute( + "INSERT OR IGNORE INTO entities (id, name) VALUES (?, ?)", (sub_id, subject) + ) + conn.execute("INSERT OR IGNORE INTO entities (id, name) VALUES (?, ?)", (obj_id, obj)) - # Check for existing identical triple - existing = conn.execute( - "SELECT id FROM triples WHERE subject=? AND predicate=? AND object=? AND valid_to IS NULL", - (sub_id, pred, obj_id), - ).fetchone() + # Check for existing identical triple + existing = conn.execute( + "SELECT id FROM triples WHERE subject=? AND predicate=? AND object=? AND valid_to IS NULL", + (sub_id, pred, obj_id), + ).fetchone() - if existing: - conn.close() - return existing[0] # Already exists and still valid + if existing: + return existing["id"] # Already exists and still valid - triple_id = f"t_{sub_id}_{pred}_{obj_id}_{hashlib.md5(f'{valid_from}{datetime.now().isoformat()}'.encode()).hexdigest()[:8]}" + triple_id = f"t_{sub_id}_{pred}_{obj_id}_{hashlib.sha256(f'{valid_from}{datetime.now().isoformat()}'.encode()).hexdigest()[:12]}" - conn.execute( - """INSERT INTO triples (id, subject, predicate, object, valid_from, valid_to, confidence, source_closet, source_file) - VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)""", - ( - triple_id, - sub_id, - pred, - obj_id, - valid_from, - valid_to, - confidence, - source_closet, - source_file, - ), - ) - conn.commit() - conn.close() + conn.execute( + """INSERT INTO triples (id, subject, predicate, object, valid_from, valid_to, confidence, source_closet, source_file) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)""", + ( + triple_id, + sub_id, + pred, + obj_id, + valid_from, + valid_to, + confidence, + source_closet, + source_file, + ), + ) return triple_id def invalidate(self, subject: str, predicate: str, obj: str, ended: str = None): @@ -176,12 +185,11 @@ class KnowledgeGraph: ended = ended or date.today().isoformat() conn = self._conn() - conn.execute( - "UPDATE triples SET valid_to=? WHERE subject=? AND predicate=? AND object=? AND valid_to IS NULL", - (ended, sub_id, pred, obj_id), - ) - conn.commit() - conn.close() + with conn: + conn.execute( + "UPDATE triples SET valid_to=? WHERE subject=? AND predicate=? AND object=? AND valid_to IS NULL", + (ended, sub_id, pred, obj_id), + ) # ── Query operations ────────────────────────────────────────────────── @@ -208,13 +216,13 @@ class KnowledgeGraph: { "direction": "outgoing", "subject": name, - "predicate": row[2], - "object": row[10], # obj_name - "valid_from": row[4], - "valid_to": row[5], - "confidence": row[6], - "source_closet": row[7], - "current": row[5] is None, + "predicate": row["predicate"], + "object": row["obj_name"], + "valid_from": row["valid_from"], + "valid_to": row["valid_to"], + "confidence": row["confidence"], + "source_closet": row["source_closet"], + "current": row["valid_to"] is None, } ) @@ -228,18 +236,17 @@ class KnowledgeGraph: results.append( { "direction": "incoming", - "subject": row[10], # sub_name - "predicate": row[2], + "subject": row["sub_name"], + "predicate": row["predicate"], "object": name, - "valid_from": row[4], - "valid_to": row[5], - "confidence": row[6], - "source_closet": row[7], - "current": row[5] is None, + "valid_from": row["valid_from"], + "valid_to": row["valid_to"], + "confidence": row["confidence"], + "source_closet": row["source_closet"], + "current": row["valid_to"] is None, } ) - conn.close() return results def query_relationship(self, predicate: str, as_of: str = None): @@ -262,15 +269,14 @@ class KnowledgeGraph: for row in conn.execute(query, params).fetchall(): results.append( { - "subject": row[10], + "subject": row["sub_name"], "predicate": pred, - "object": row[11], - "valid_from": row[4], - "valid_to": row[5], - "current": row[5] is None, + "object": row["obj_name"], + "valid_from": row["valid_from"], + "valid_to": row["valid_to"], + "current": row["valid_to"] is None, } ) - conn.close() return results def timeline(self, entity_name: str = None): @@ -300,15 +306,14 @@ class KnowledgeGraph: LIMIT 100 """).fetchall() - conn.close() return [ { - "subject": r[10], - "predicate": r[2], - "object": r[11], - "valid_from": r[4], - "valid_to": r[5], - "current": r[5] is None, + "subject": r["sub_name"], + "predicate": r["predicate"], + "object": r["obj_name"], + "valid_from": r["valid_from"], + "valid_to": r["valid_to"], + "current": r["valid_to"] is None, } for r in rows ] @@ -317,17 +322,18 @@ class KnowledgeGraph: def stats(self): conn = self._conn() - entities = conn.execute("SELECT COUNT(*) FROM entities").fetchone()[0] - triples = conn.execute("SELECT COUNT(*) FROM triples").fetchone()[0] - current = conn.execute("SELECT COUNT(*) FROM triples WHERE valid_to IS NULL").fetchone()[0] + entities = conn.execute("SELECT COUNT(*) as cnt FROM entities").fetchone()["cnt"] + triples = conn.execute("SELECT COUNT(*) as cnt FROM triples").fetchone()["cnt"] + current = conn.execute( + "SELECT COUNT(*) as cnt FROM triples WHERE valid_to IS NULL" + ).fetchone()["cnt"] expired = triples - current predicates = [ - r[0] + r["predicate"] for r in conn.execute( "SELECT DISTINCT predicate FROM triples ORDER BY predicate" ).fetchall() ] - conn.close() return { "entities": entities, "triples": triples, diff --git a/mempalace/mcp_server.py b/mempalace/mcp_server.py index 7d263a6..bffd3b2 100644 --- a/mempalace/mcp_server.py +++ b/mempalace/mcp_server.py @@ -24,8 +24,9 @@ import json import logging import hashlib from datetime import datetime +from pathlib import Path -from .config import MempalaceConfig +from .config import MempalaceConfig, sanitize_name, sanitize_content from .version import __version__ from .searcher import search_memories from .palace_graph import traverse, find_tunnels, graph_stats @@ -44,7 +45,9 @@ def _parse_args(): metavar="PATH", help="Path to the palace directory (overrides config file and env var)", ) - args, _ = parser.parse_known_args() + args, unknown = parser.parse_known_args() + if unknown: + logger.debug("Ignoring unknown args: %s", unknown) return args @@ -64,16 +67,60 @@ _client_cache = None _collection_cache = None +# ==================== WRITE-AHEAD LOG ==================== +# Every write operation is logged to a JSONL file before execution. +# This provides an audit trail for detecting memory poisoning and +# enables review/rollback of writes from external or untrusted sources. + +_WAL_DIR = Path(os.path.expanduser("~/.mempalace/wal")) +_WAL_DIR.mkdir(parents=True, exist_ok=True) +try: + _WAL_DIR.chmod(0o700) +except (OSError, NotImplementedError): + pass +_WAL_FILE = _WAL_DIR / "write_log.jsonl" + + +def _wal_log(operation: str, params: dict, result: dict = None): + """Append a write operation to the write-ahead log.""" + entry = { + "timestamp": datetime.now().isoformat(), + "operation": operation, + "params": params, + "result": result, + } + try: + with open(_WAL_FILE, "a", encoding="utf-8") as f: + f.write(json.dumps(entry, default=str) + "\n") + try: + _WAL_FILE.chmod(0o600) + except (OSError, NotImplementedError): + pass + except Exception as e: + logger.error(f"WAL write failed: {e}") + + +_client_cache = None +_collection_cache = None + + +def _get_client(): + """Return a singleton ChromaDB PersistentClient.""" + global _client_cache + if _client_cache is None: + _client_cache = chromadb.PersistentClient(path=_config.palace_path) + return _client_cache + + def _get_collection(create=False): """Return the ChromaDB collection, caching the client between calls.""" - global _client_cache, _collection_cache + global _collection_cache try: - if _client_cache is None: - _client_cache = chromadb.PersistentClient(path=_config.palace_path) + client = _get_client() if create: - _collection_cache = _client_cache.get_or_create_collection(_config.collection_name) + _collection_cache = client.get_or_create_collection(_config.collection_name) elif _collection_cache is None: - _collection_cache = _client_cache.get_collection(_config.collection_name) + _collection_cache = client.get_collection(_config.collection_name) return _collection_cache except Exception: return None @@ -280,11 +327,30 @@ def tool_add_drawer( wing: str, room: str, content: str, source_file: str = None, added_by: str = "mcp" ): """File verbatim content into a wing/room. Checks for duplicates first.""" + try: + wing = sanitize_name(wing, "wing") + room = sanitize_name(room, "room") + content = sanitize_content(content) + except ValueError as e: + return {"success": False, "error": str(e)} + col = _get_collection(create=True) if not col: return _no_palace() - drawer_id = f"drawer_{wing}_{room}_{hashlib.md5(content.encode()).hexdigest()[:16]}" + drawer_id = f"drawer_{wing}_{room}_{hashlib.sha256((wing + room + content[:100]).encode()).hexdigest()[:24]}" + + _wal_log( + "add_drawer", + { + "drawer_id": drawer_id, + "wing": wing, + "room": room, + "added_by": added_by, + "content_length": len(content), + "content_preview": content[:200], + }, + ) # Idempotency: if the deterministic ID already exists, return success as a no-op. try: @@ -323,6 +389,19 @@ def tool_delete_drawer(drawer_id: str): existing = col.get(ids=[drawer_id]) if not existing["ids"]: return {"success": False, "error": f"Drawer not found: {drawer_id}"} + + # Log the deletion with the content being removed for audit trail + deleted_content = existing.get("documents", [""])[0] if existing.get("documents") else "" + deleted_meta = existing.get("metadatas", [{}])[0] if existing.get("metadatas") else {} + _wal_log( + "delete_drawer", + { + "drawer_id": drawer_id, + "deleted_meta": deleted_meta, + "content_preview": deleted_content[:200], + }, + ) + try: col.delete(ids=[drawer_id]) logger.info(f"Deleted drawer: {drawer_id}") @@ -344,6 +423,23 @@ def tool_kg_add( subject: str, predicate: str, object: str, valid_from: str = None, source_closet: str = None ): """Add a relationship to the knowledge graph.""" + try: + subject = sanitize_name(subject, "subject") + predicate = sanitize_name(predicate, "predicate") + object = sanitize_name(object, "object") + except ValueError as e: + return {"success": False, "error": str(e)} + + _wal_log( + "kg_add", + { + "subject": subject, + "predicate": predicate, + "object": object, + "valid_from": valid_from, + "source_closet": source_closet, + }, + ) triple_id = _kg.add_triple( subject, predicate, object, valid_from=valid_from, source_closet=source_closet ) @@ -352,6 +448,10 @@ def tool_kg_add( def tool_kg_invalidate(subject: str, predicate: str, object: str, ended: str = None): """Mark a fact as no longer true (set end date).""" + _wal_log( + "kg_invalidate", + {"subject": subject, "predicate": predicate, "object": object, "ended": ended}, + ) _kg.invalidate(subject, predicate, object, ended=ended) return { "success": True, @@ -382,6 +482,12 @@ def tool_diary_write(agent_name: str, entry: str, topic: str = "general"): This is the agent's personal journal — observations, thoughts, what it worked on, what it noticed, what it thinks matters. """ + try: + agent_name = sanitize_name(agent_name, "agent_name") + entry = sanitize_content(entry) + except ValueError as e: + return {"success": False, "error": str(e)} + wing = f"wing_{agent_name.lower().replace(' ', '_')}" room = "diary" col = _get_collection(create=True) @@ -389,9 +495,23 @@ def tool_diary_write(agent_name: str, entry: str, topic: str = "general"): return _no_palace() now = datetime.now() - entry_id = f"diary_{wing}_{now.strftime('%Y%m%d_%H%M%S')}_{hashlib.md5(entry[:50].encode()).hexdigest()[:8]}" + entry_id = f"diary_{wing}_{now.strftime('%Y%m%d_%H%M%S')}_{hashlib.sha256(entry[:50].encode()).hexdigest()[:12]}" + + _wal_log( + "diary_write", + { + "agent_name": agent_name, + "topic": topic, + "entry_id": entry_id, + "entry_preview": entry[:200], + }, + ) try: + # TODO: Future versions should expand AAAK before embedding to improve + # semantic search quality. For now, store raw AAAK in metadata so it's + # preserved, and keep the document as-is for embedding (even though + # compressed AAAK degrades embedding quality). col.add( ids=[entry_id], documents=[entry], @@ -717,17 +837,31 @@ TOOLS = { } +SUPPORTED_PROTOCOL_VERSIONS = [ + "2025-11-25", + "2025-06-18", + "2025-03-26", + "2024-11-05", +] + + def handle_request(request): method = request.get("method", "") params = request.get("params", {}) req_id = request.get("id") if method == "initialize": + client_version = params.get("protocolVersion", SUPPORTED_PROTOCOL_VERSIONS[-1]) + negotiated = ( + client_version + if client_version in SUPPORTED_PROTOCOL_VERSIONS + else SUPPORTED_PROTOCOL_VERSIONS[0] + ) return { "jsonrpc": "2.0", "id": req_id, "result": { - "protocolVersion": "2024-11-05", + "protocolVersion": negotiated, "capabilities": {"tools": {}}, "serverInfo": {"name": "mempalace", "version": __version__}, }, @@ -747,7 +881,7 @@ def handle_request(request): } elif method == "tools/call": tool_name = params.get("name") - tool_args = params.get("arguments", {}) + tool_args = params.get("arguments") or {} if tool_name not in TOOLS: return { "jsonrpc": "2.0", diff --git a/mempalace/migrate.py b/mempalace/migrate.py new file mode 100644 index 0000000..848ab67 --- /dev/null +++ b/mempalace/migrate.py @@ -0,0 +1,214 @@ +#!/usr/bin/env python3 +""" +mempalace migrate — Recover a palace created with a different ChromaDB version. + +Reads documents and metadata directly from the palace's SQLite database +(bypassing ChromaDB's API, which fails on version-mismatched palaces), +then re-imports everything into a fresh palace using the currently installed +ChromaDB version. + +This fixes the 3.0.0 → 3.1.0 upgrade path where chromadb was downgraded +from 1.5.x to 0.6.x, breaking the on-disk storage format. + +Usage: + mempalace migrate # migrate default palace + mempalace migrate --palace /path/to/palace # migrate specific palace + mempalace migrate --dry-run # show what would be migrated +""" + +import os +import shutil +import sqlite3 +from collections import defaultdict +from datetime import datetime + + +def extract_drawers_from_sqlite(db_path: str) -> list: + """Read all drawers directly from ChromaDB's SQLite, bypassing the API. + + Works regardless of which ChromaDB version created the database. + Returns list of dicts with 'id', 'document', and 'metadata' keys. + """ + conn = sqlite3.connect(db_path) + conn.row_factory = sqlite3.Row + + # Get all embedding IDs and their documents + rows = conn.execute(""" + SELECT e.embedding_id, + MAX(CASE WHEN em.key = 'chroma:document' THEN em.string_value END) as document + FROM embeddings e + JOIN embedding_metadata em ON em.id = e.id + GROUP BY e.embedding_id + """).fetchall() + + drawers = [] + for row in rows: + embedding_id = row["embedding_id"] + document = row["document"] + if not document: + continue + + # Get metadata for this embedding + meta_rows = conn.execute( + """ + SELECT em.key, em.string_value, em.int_value, em.float_value, em.bool_value + FROM embedding_metadata em + JOIN embeddings e ON e.id = em.id + WHERE e.embedding_id = ? + AND em.key NOT LIKE 'chroma:%' + """, + (embedding_id,), + ).fetchall() + + metadata = {} + for mr in meta_rows: + key = mr["key"] + if mr["string_value"] is not None: + metadata[key] = mr["string_value"] + elif mr["int_value"] is not None: + metadata[key] = mr["int_value"] + elif mr["float_value"] is not None: + metadata[key] = mr["float_value"] + elif mr["bool_value"] is not None: + metadata[key] = bool(mr["bool_value"]) + + drawers.append( + { + "id": embedding_id, + "document": document, + "metadata": metadata, + } + ) + + conn.close() + return drawers + + +def detect_chromadb_version(db_path: str) -> str: + """Detect which ChromaDB version created the database by checking schema.""" + conn = sqlite3.connect(db_path) + try: + # 1.x has schema_str column in collections table + cols = [r[1] for r in conn.execute("PRAGMA table_info(collections)").fetchall()] + if "schema_str" in cols: + return "1.x" + # 0.6.x has embeddings_queue but no schema_str + tables = [ + r[0] + for r in conn.execute("SELECT name FROM sqlite_master WHERE type='table'").fetchall() + ] + if "embeddings_queue" in tables: + return "0.6.x" + return "unknown" + finally: + conn.close() + + +def migrate(palace_path: str, dry_run: bool = False): + """Migrate a palace to the currently installed ChromaDB version.""" + import chromadb + + palace_path = os.path.expanduser(palace_path) + db_path = os.path.join(palace_path, "chroma.sqlite3") + + if not os.path.isfile(db_path): + print(f"\n No palace database found at {db_path}") + return False + + print(f"\n{'=' * 60}") + print(" MemPalace Migrate") + print(f"{'=' * 60}\n") + print(f" Palace: {palace_path}") + print(f" Database: {db_path}") + print(f" DB size: {os.path.getsize(db_path) / 1024 / 1024:.1f} MB") + + # Detect version + source_version = detect_chromadb_version(db_path) + print(f" Source: ChromaDB {source_version}") + print(f" Target: ChromaDB {chromadb.__version__}") + + # Try reading with current chromadb first + try: + client = chromadb.PersistentClient(path=palace_path) + col = client.get_collection("mempalace_drawers") + count = col.count() + print(f"\n Palace is already readable by chromadb {chromadb.__version__}.") + print(f" {count} drawers found. No migration needed.") + return True + except Exception: + print(f"\n Palace is NOT readable by chromadb {chromadb.__version__}.") + print(" Extracting from SQLite directly...") + + # Extract all drawers via raw SQL + drawers = extract_drawers_from_sqlite(db_path) + print(f" Extracted {len(drawers)} drawers from SQLite") + + if not drawers: + print(" Nothing to migrate.") + return True + + # Show summary + wings = defaultdict(lambda: defaultdict(int)) + for d in drawers: + w = d["metadata"].get("wing", "?") + r = d["metadata"].get("room", "?") + wings[w][r] += 1 + + print("\n Summary:") + for wing, rooms in sorted(wings.items()): + total = sum(rooms.values()) + print(f" WING: {wing} ({total} drawers)") + for room, count in sorted(rooms.items(), key=lambda x: -x[1]): + print(f" ROOM: {room:30} {count:5}") + + if dry_run: + print("\n DRY RUN — no changes made.") + print(f" Would migrate {len(drawers)} drawers.") + return True + + # Backup the old palace + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + backup_path = f"{palace_path}.pre-migrate.{timestamp}" + print(f"\n Backing up to {backup_path}...") + shutil.copytree(palace_path, backup_path) + + # Build fresh palace in a temp directory (avoids chromadb reading old state) + import tempfile + + temp_palace = tempfile.mkdtemp(prefix="mempalace_migrate_") + print(f" Creating fresh palace in {temp_palace}...") + client = chromadb.PersistentClient(path=temp_palace) + col = client.get_or_create_collection("mempalace_drawers") + + # Re-import in batches + batch_size = 500 + imported = 0 + for i in range(0, len(drawers), batch_size): + batch = drawers[i : i + batch_size] + col.add( + ids=[d["id"] for d in batch], + documents=[d["document"] for d in batch], + metadatas=[d["metadata"] for d in batch], + ) + imported += len(batch) + print(f" Imported {imported}/{len(drawers)} drawers...") + + # Verify before swapping + final_count = col.count() + del col + del client + + # Swap: remove old palace, move new one into place + print(" Swapping old palace for migrated version...") + shutil.rmtree(palace_path) + shutil.move(temp_palace, palace_path) + + print("\n Migration complete.") + print(f" Drawers migrated: {final_count}") + print(f" Backup at: {backup_path}") + + if final_count != len(drawers): + print(f" WARNING: Expected {len(drawers)}, got {final_count}") + + print(f"\n{'=' * 60}\n") + return True diff --git a/mempalace/miner.py b/mempalace/miner.py index 66fbe03..b52e6f7 100644 --- a/mempalace/miner.py +++ b/mempalace/miner.py @@ -17,6 +17,8 @@ from collections import defaultdict import chromadb +from .palace import SKIP_DIRS, get_collection, file_already_mined + READABLE_EXTENSIONS = { ".txt", ".md", @@ -40,32 +42,6 @@ READABLE_EXTENSIONS = { ".toml", } -SKIP_DIRS = { - ".git", - "node_modules", - "__pycache__", - ".venv", - "venv", - "env", - "dist", - "build", - ".next", - "coverage", - ".mempalace", - ".ruff_cache", - ".mypy_cache", - ".pytest_cache", - ".cache", - ".tox", - ".nox", - ".idea", - ".vscode", - ".ipynb_checkpoints", - ".eggs", - "htmlcov", - "target", -} - SKIP_FILENAMES = { "mempalace.yaml", "mempalace.yml", @@ -78,6 +54,7 @@ SKIP_FILENAMES = { CHUNK_SIZE = 800 # chars per drawer CHUNK_OVERLAP = 100 # overlap between chunks MIN_CHUNK_SIZE = 50 # skip tiny chunks +MAX_FILE_SIZE = 10 * 1024 * 1024 # 10 MB — skip files larger than this # ============================================================================= @@ -393,41 +370,11 @@ def chunk_text(content: str, source_file: str) -> list: # ============================================================================= -def get_collection(palace_path: str): - os.makedirs(palace_path, exist_ok=True) - client = chromadb.PersistentClient(path=palace_path) - try: - return client.get_collection("mempalace_drawers") - except Exception: - return client.create_collection("mempalace_drawers") - - -def file_already_mined(collection, source_file: str) -> bool: - """Fast check: has this file been filed before and is unchanged? - - Compares the stored mtime in drawer metadata against the file's current - mtime. Returns False (needs re-mining) when the file has been modified - since it was last mined, or when no mtime was stored. - """ - try: - results = collection.get(where={"source_file": source_file}, limit=1) - if not results.get("ids"): - return False - stored_meta = results["metadatas"][0] if results.get("metadatas") else {} - stored_mtime = stored_meta.get("source_mtime") - if stored_mtime is None: - return False - current_mtime = os.path.getmtime(source_file) - return float(stored_mtime) == current_mtime - except Exception: - return False - - def add_drawer( collection, wing: str, room: str, content: str, source_file: str, chunk_index: int, agent: str ): """Add one drawer to the palace.""" - drawer_id = f"drawer_{wing}_{room}_{hashlib.md5((source_file + str(chunk_index)).encode(), usedforsecurity=False).hexdigest()[:16]}" + drawer_id = f"drawer_{wing}_{room}_{hashlib.sha256((source_file + str(chunk_index)).encode()).hexdigest()[:24]}" try: metadata = { "wing": wing, @@ -470,7 +417,7 @@ def process_file( # Skip if already filed source_file = str(filepath) - if not dry_run and file_already_mined(collection, source_file): + if not dry_run and file_already_mined(collection, source_file, check_mtime=True): return 0, None try: @@ -562,6 +509,15 @@ def scan_project( if respect_gitignore and active_matchers and not force_include: if is_gitignored(filepath, active_matchers, is_dir=False): continue + # Skip symlinks — prevents following links to /dev/urandom, etc. + if filepath.is_symlink(): + continue + # Skip files exceeding size limit + try: + if filepath.stat().st_size > MAX_FILE_SIZE: + continue + except OSError: + continue files.append(filepath) return files diff --git a/mempalace/normalize.py b/mempalace/normalize.py index ac11469..a894500 100644 --- a/mempalace/normalize.py +++ b/mempalace/normalize.py @@ -25,6 +25,12 @@ def normalize(filepath: str) -> str: Load a file and normalize to transcript format if it's a chat export. Plain text files pass through unchanged. """ + try: + file_size = os.path.getsize(filepath) + except OSError as e: + raise IOError(f"Could not read {filepath}: {e}") + if file_size > 500 * 1024 * 1024: # 500 MB safety limit + raise IOError(f"File too large ({file_size // (1024*1024)} MB): {filepath}") try: with open(filepath, "r", encoding="utf-8", errors="replace") as f: content = f.read() diff --git a/mempalace/onboarding.py b/mempalace/onboarding.py index f578d91..70f7b54 100644 --- a/mempalace/onboarding.py +++ b/mempalace/onboarding.py @@ -312,7 +312,7 @@ def _generate_aaak_bootstrap( ] ) - (mempalace_dir / "aaak_entities.md").write_text("\n".join(registry_lines)) + (mempalace_dir / "aaak_entities.md").write_text("\n".join(registry_lines), encoding="utf-8") # Critical facts bootstrap (pre-palace — before any mining) facts_lines = [ @@ -359,7 +359,7 @@ def _generate_aaak_bootstrap( ] ) - (mempalace_dir / "critical_facts.md").write_text("\n".join(facts_lines)) + (mempalace_dir / "critical_facts.md").write_text("\n".join(facts_lines), encoding="utf-8") def run_onboarding( diff --git a/mempalace/palace.py b/mempalace/palace.py new file mode 100644 index 0000000..6ddf190 --- /dev/null +++ b/mempalace/palace.py @@ -0,0 +1,71 @@ +""" +palace.py — Shared palace operations. + +Consolidates ChromaDB access patterns used by both miners and the MCP server. +""" + +import os +import chromadb + +SKIP_DIRS = { + ".git", + "node_modules", + "__pycache__", + ".venv", + "venv", + "env", + "dist", + "build", + ".next", + "coverage", + ".mempalace", + ".ruff_cache", + ".mypy_cache", + ".pytest_cache", + ".cache", + ".tox", + ".nox", + ".idea", + ".vscode", + ".ipynb_checkpoints", + ".eggs", + "htmlcov", + "target", +} + + +def get_collection(palace_path: str, collection_name: str = "mempalace_drawers"): + """Get or create the palace ChromaDB collection.""" + os.makedirs(palace_path, exist_ok=True) + try: + os.chmod(palace_path, 0o700) + except (OSError, NotImplementedError): + pass + client = chromadb.PersistentClient(path=palace_path) + try: + return client.get_collection(collection_name) + except Exception: + return client.create_collection(collection_name) + + +def file_already_mined(collection, source_file: str, check_mtime: bool = False) -> bool: + """Check if a file has already been filed in the palace. + + When check_mtime=True (used by project miner), returns False if the file + has been modified since it was last mined, so it gets re-mined. + When check_mtime=False (used by convo miner), just checks existence. + """ + try: + results = collection.get(where={"source_file": source_file}, limit=1) + if not results.get("ids"): + return False + if check_mtime: + stored_meta = results.get("metadatas", [{}])[0] + stored_mtime = stored_meta.get("source_mtime") + if stored_mtime is None: + return False + current_mtime = os.path.getmtime(source_file) + return float(stored_mtime) == current_mtime + return True + except Exception: + return False diff --git a/mempalace/split_mega_files.py b/mempalace/split_mega_files.py index 80bbae4..24b5956 100644 --- a/mempalace/split_mega_files.py +++ b/mempalace/split_mega_files.py @@ -182,6 +182,10 @@ def split_file(filepath, output_dir, dry_run=False): Returns list of output paths written (or would be written if dry_run). """ path = Path(filepath) + max_size = 500 * 1024 * 1024 # 500 MB safety limit + if path.stat().st_size > max_size: + print(f" SKIP: {path.name} exceeds {max_size // (1024*1024)} MB limit") + return [] lines = path.read_text(errors="replace").splitlines(keepends=True) boundaries = find_session_boundaries(lines) @@ -219,7 +223,7 @@ def split_file(filepath, output_dir, dry_run=False): if dry_run: print(f" [{i + 1}/{len(boundaries) - 1}] {name} ({len(chunk)} lines)") else: - out_path.write_text("".join(chunk)) + out_path.write_text("".join(chunk), encoding="utf-8") print(f" ✓ {name} ({len(chunk)} lines)") written.append(out_path) @@ -266,7 +270,11 @@ def main(): files = sorted(src_dir.glob("*.txt")) mega_files = [] + max_scan_size = 500 * 1024 * 1024 # 500 MB for f in files: + if f.stat().st_size > max_scan_size: + print(f" SKIP: {f.name} exceeds {max_scan_size // (1024*1024)} MB limit") + continue lines = f.read_text(errors="replace").splitlines(keepends=True) boundaries = find_session_boundaries(lines) if len(boundaries) >= args.min_sessions: diff --git a/mempalace/version.py b/mempalace/version.py index e56289e..1eb21a2 100644 --- a/mempalace/version.py +++ b/mempalace/version.py @@ -1,3 +1,3 @@ """Single source of truth for the MemPalace package version.""" -__version__ = "3.0.14" +__version__ = "3.1.0" diff --git a/pyproject.toml b/pyproject.toml index 3aaa765..cd47f98 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "mempalace" -version = "3.0.14" +version = "3.1.0" description = "Give your AI a memory — mine projects and conversations into a searchable palace. No API key required." readme = "README.md" requires-python = ">=3.9" @@ -26,7 +26,7 @@ classifiers = [ ] dependencies = [ "chromadb>=0.5.0,<0.7", - "pyyaml>=6.0", + "pyyaml>=6.0,<7", ] [project.urls] @@ -54,11 +54,15 @@ packages = ["mempalace"] [tool.ruff] line-length = 100 target-version = "py39" +extend-exclude = ["benchmarks"] [tool.ruff.lint] -select = ["E", "F", "W"] +select = ["E", "F", "W", "C901"] ignore = ["E501"] +[tool.ruff.lint.mccabe] +max-complexity = 25 + [tool.ruff.format] quote-style = "double" @@ -76,7 +80,7 @@ markers = [ source = ["mempalace"] [tool.coverage.report] -fail_under = 30 +fail_under = 85 show_missing = true exclude_lines = [ "if __name__", diff --git a/tests/benchmarks/test_layers_bench.py b/tests/benchmarks/test_layers_bench.py index b9604d7..7496588 100644 --- a/tests/benchmarks/test_layers_bench.py +++ b/tests/benchmarks/test_layers_bench.py @@ -148,9 +148,9 @@ class TestWakeUpTokenBudget: record_metric("wakeup_budget", f"tokens_at_{n_drawers}", token_estimate) record_metric("wakeup_budget", f"chars_at_{n_drawers}", len(text)) - assert token_estimate < 1200, ( - f"Wake-up exceeded budget: ~{token_estimate} tokens at {n_drawers} drawers" - ) + assert ( + token_estimate < 1200 + ), f"Wake-up exceeded budget: ~{token_estimate} tokens at {n_drawers} drawers" @pytest.mark.benchmark diff --git a/tests/test_cli.py b/tests/test_cli.py new file mode 100644 index 0000000..e3c68f9 --- /dev/null +++ b/tests/test_cli.py @@ -0,0 +1,652 @@ +"""Tests for mempalace.cli — the main CLI dispatcher.""" + +import argparse +import sys +from pathlib import Path +from unittest.mock import MagicMock, patch + +import pytest + +from mempalace.cli import ( + cmd_compress, + cmd_hook, + cmd_init, + cmd_instructions, + cmd_mine, + cmd_repair, + cmd_search, + cmd_split, + cmd_status, + cmd_wakeup, + main, +) + + +# ── cmd_status ───────────────────────────────────────────────────────── + + +@patch("mempalace.cli.MempalaceConfig") +def test_cmd_status_default_palace(mock_config_cls): + mock_config_cls.return_value.palace_path = "/fake/palace" + args = argparse.Namespace(palace=None) + mock_miner = MagicMock() + with patch.dict("sys.modules", {"mempalace.miner": mock_miner}): + cmd_status(args) + mock_miner.status.assert_called_once_with(palace_path="/fake/palace") + + +@patch("mempalace.cli.MempalaceConfig") +def test_cmd_status_custom_palace(mock_config_cls): + args = argparse.Namespace(palace="~/my_palace") + mock_miner = MagicMock() + with patch.dict("sys.modules", {"mempalace.miner": mock_miner}): + cmd_status(args) + import os + + expected = os.path.expanduser("~/my_palace") + mock_miner.status.assert_called_once_with(palace_path=expected) + + +# ── cmd_search ───────────────────────────────────────────────────────── + + +@patch("mempalace.cli.MempalaceConfig") +def test_cmd_search_calls_search(mock_config_cls): + mock_config_cls.return_value.palace_path = "/fake/palace" + args = argparse.Namespace( + palace=None, query="test query", wing="mywing", room="myroom", results=3 + ) + with patch("mempalace.searcher.search") as mock_search: + cmd_search(args) + mock_search.assert_called_once_with( + query="test query", + palace_path="/fake/palace", + wing="mywing", + room="myroom", + n_results=3, + ) + + +@patch("mempalace.cli.MempalaceConfig") +def test_cmd_search_error_exits(mock_config_cls): + mock_config_cls.return_value.palace_path = "/fake/palace" + args = argparse.Namespace(palace=None, query="q", wing=None, room=None, results=5) + from mempalace.searcher import SearchError + + with patch("mempalace.searcher.search", side_effect=SearchError("fail")): + with pytest.raises(SystemExit) as exc_info: + cmd_search(args) + assert exc_info.value.code == 1 + + +# ── cmd_instructions ─────────────────────────────────────────────────── + + +def test_cmd_instructions_calls_run_instructions(): + args = argparse.Namespace(name="help") + with patch("mempalace.instructions_cli.run_instructions") as mock_run: + cmd_instructions(args) + mock_run.assert_called_once_with(name="help") + + +# ── cmd_hook ─────────────────────────────────────────────────────────── + + +def test_cmd_hook_calls_run_hook(): + args = argparse.Namespace(hook="session-start", harness="claude-code") + with patch("mempalace.hooks_cli.run_hook") as mock_run: + cmd_hook(args) + mock_run.assert_called_once_with(hook_name="session-start", harness="claude-code") + + +# ── cmd_init ─────────────────────────────────────────────────────────── + + +@patch("mempalace.cli.MempalaceConfig") +def test_cmd_init_no_entities(mock_config_cls, tmp_path): + args = argparse.Namespace(dir=str(tmp_path), yes=True) + with ( + patch("mempalace.entity_detector.scan_for_detection", return_value=[]), + patch("mempalace.room_detector_local.detect_rooms_local") as mock_rooms, + ): + cmd_init(args) + mock_rooms.assert_called_once_with(project_dir=str(tmp_path), yes=True) + mock_config_cls.return_value.init.assert_called_once() + + +@patch("mempalace.cli.MempalaceConfig") +def test_cmd_init_with_entities(mock_config_cls, tmp_path): + fake_files = [tmp_path / "a.txt"] + detected = {"people": [{"name": "Alice"}], "projects": [], "uncertain": []} + confirmed = {"people": ["Alice"], "projects": []} + args = argparse.Namespace(dir=str(tmp_path), yes=True) + with ( + patch("mempalace.entity_detector.scan_for_detection", return_value=fake_files), + patch("mempalace.entity_detector.detect_entities", return_value=detected), + patch("mempalace.entity_detector.confirm_entities", return_value=confirmed), + patch("mempalace.room_detector_local.detect_rooms_local"), + patch("builtins.open", MagicMock()), + ): + cmd_init(args) + + +@patch("mempalace.cli.MempalaceConfig") +def test_cmd_init_with_entities_zero_total(mock_config_cls, tmp_path, capsys): + """When entities detected but total is 0, prints 'No entities' message.""" + fake_files = [tmp_path / "a.txt"] + detected = {"people": [], "projects": [], "uncertain": []} + args = argparse.Namespace(dir=str(tmp_path), yes=False) + with ( + patch("mempalace.entity_detector.scan_for_detection", return_value=fake_files), + patch("mempalace.entity_detector.detect_entities", return_value=detected), + patch("mempalace.room_detector_local.detect_rooms_local"), + ): + cmd_init(args) + out = capsys.readouterr().out + assert "No entities detected" in out + + +# ── cmd_mine ─────────────────────────────────────────────────────────── + + +@patch("mempalace.cli.MempalaceConfig") +def test_cmd_mine_projects_mode(mock_config_cls): + mock_config_cls.return_value.palace_path = "/fake/palace" + args = argparse.Namespace( + dir="/src", + palace=None, + mode="projects", + wing=None, + agent="mempalace", + limit=0, + dry_run=False, + no_gitignore=False, + include_ignored=[], + extract="exchange", + ) + with patch("mempalace.miner.mine") as mock_mine: + cmd_mine(args) + mock_mine.assert_called_once_with( + project_dir="/src", + palace_path="/fake/palace", + wing_override=None, + agent="mempalace", + limit=0, + dry_run=False, + respect_gitignore=True, + include_ignored=[], + ) + + +@patch("mempalace.cli.MempalaceConfig") +def test_cmd_mine_convos_mode(mock_config_cls): + mock_config_cls.return_value.palace_path = "/fake/palace" + args = argparse.Namespace( + dir="/chats", + palace=None, + mode="convos", + wing="mywing", + agent="me", + limit=10, + dry_run=True, + no_gitignore=False, + include_ignored=[], + extract="general", + ) + with patch("mempalace.convo_miner.mine_convos") as mock_mine: + cmd_mine(args) + mock_mine.assert_called_once_with( + convo_dir="/chats", + palace_path="/fake/palace", + wing="mywing", + agent="me", + limit=10, + dry_run=True, + extract_mode="general", + ) + + +@patch("mempalace.cli.MempalaceConfig") +def test_cmd_mine_include_ignored_comma_split(mock_config_cls): + mock_config_cls.return_value.palace_path = "/fake/palace" + args = argparse.Namespace( + dir="/src", + palace=None, + mode="projects", + wing=None, + agent="mempalace", + limit=0, + dry_run=False, + no_gitignore=False, + include_ignored=["a.txt,b.txt", "c.txt"], + extract="exchange", + ) + with patch("mempalace.miner.mine") as mock_mine: + cmd_mine(args) + mock_mine.assert_called_once() + call_kwargs = mock_mine.call_args[1] + assert call_kwargs["include_ignored"] == ["a.txt", "b.txt", "c.txt"] + + +# ── cmd_wakeup ───────────────────────────────────────────────────────── + + +@patch("mempalace.cli.MempalaceConfig") +def test_cmd_wakeup(mock_config_cls, capsys): + mock_config_cls.return_value.palace_path = "/fake/palace" + args = argparse.Namespace(palace=None, wing=None) + mock_stack = MagicMock() + mock_stack.wake_up.return_value = "Hello world context" + with patch("mempalace.layers.MemoryStack", return_value=mock_stack): + cmd_wakeup(args) + out = capsys.readouterr().out + assert "Hello world context" in out + assert "tokens" in out + + +# ── cmd_split ────────────────────────────────────────────────────────── + + +def test_cmd_split_basic(): + args = argparse.Namespace(dir="/chats", output_dir=None, dry_run=False, min_sessions=2) + with patch("mempalace.split_mega_files.main") as mock_main: + cmd_split(args) + mock_main.assert_called_once() + + +def test_cmd_split_all_options(): + args = argparse.Namespace(dir="/chats", output_dir="/out", dry_run=True, min_sessions=5) + with patch("mempalace.split_mega_files.main") as mock_main: + cmd_split(args) + mock_main.assert_called_once() + # sys.argv should be restored + assert sys.argv[0] != "mempalace split" + + +# ── main() argparse dispatch ────────────────────────────────────────── + + +def test_main_no_args_prints_help(capsys): + with patch("sys.argv", ["mempalace"]): + main() + out = capsys.readouterr().out + assert "MemPalace" in out + + +def test_main_status_dispatches(): + with ( + patch("sys.argv", ["mempalace", "status"]), + patch("mempalace.cli.cmd_status") as mock_cmd, + ): + main() + mock_cmd.assert_called_once() + + +def test_main_search_dispatches(): + with ( + patch("sys.argv", ["mempalace", "search", "my query"]), + patch("mempalace.cli.cmd_search") as mock_cmd, + ): + main() + mock_cmd.assert_called_once() + + +def test_main_init_dispatches(): + with ( + patch("sys.argv", ["mempalace", "init", "/some/dir"]), + patch("mempalace.cli.cmd_init") as mock_cmd, + ): + main() + mock_cmd.assert_called_once() + + +def test_main_mine_dispatches(): + with ( + patch("sys.argv", ["mempalace", "mine", "/some/dir"]), + patch("mempalace.cli.cmd_mine") as mock_cmd, + ): + main() + mock_cmd.assert_called_once() + + +def test_main_wakeup_dispatches(): + with ( + patch("sys.argv", ["mempalace", "wake-up"]), + patch("mempalace.cli.cmd_wakeup") as mock_cmd, + ): + main() + mock_cmd.assert_called_once() + + +def test_main_split_dispatches(): + with ( + patch("sys.argv", ["mempalace", "split", "/chats"]), + patch("mempalace.cli.cmd_split") as mock_cmd, + ): + main() + mock_cmd.assert_called_once() + + +def test_mcp_command_prints_setup_guidance(monkeypatch, capsys): + monkeypatch.setattr(sys, "argv", ["mempalace", "mcp"]) + + main() + + captured = capsys.readouterr() + assert "MemPalace MCP quick setup:" in captured.out + assert "claude mcp add mempalace -- python -m mempalace.mcp_server" in captured.out + assert "\nOptional custom palace:\n" in captured.out + assert "python -m mempalace.mcp_server --palace /path/to/palace" in captured.out + assert "[--palace /path/to/palace]" not in captured.out + assert captured.err == "" + + +def test_mcp_command_uses_custom_palace_path_when_provided(monkeypatch, capsys): + monkeypatch.setattr(sys, "argv", ["mempalace", "--palace", "~/tmp/my palace", "mcp"]) + + main() + + captured = capsys.readouterr() + expanded = str(Path("~/tmp/my palace").expanduser()) + + assert "python -m mempalace.mcp_server --palace" in captured.out + assert expanded in captured.out + assert "Optional custom palace:" not in captured.out + assert "[--palace /path/to/palace]" not in captured.out + assert captured.err == "" + + +def test_main_hook_no_subcommand_prints_help(capsys): + with patch("sys.argv", ["mempalace", "hook"]): + main() + out = capsys.readouterr().out + assert "hook" in out.lower() or "run" in out.lower() + + +def test_main_hook_run_dispatches(): + with ( + patch( + "sys.argv", + ["mempalace", "hook", "run", "--hook", "session-start", "--harness", "claude-code"], + ), + patch("mempalace.cli.cmd_hook") as mock_cmd, + ): + main() + mock_cmd.assert_called_once() + + +def test_main_instructions_no_subcommand_prints_help(capsys): + with patch("sys.argv", ["mempalace", "instructions"]): + main() + out = capsys.readouterr().out + assert "instructions" in out.lower() or "init" in out.lower() + + +def test_main_instructions_dispatches(): + with ( + patch("sys.argv", ["mempalace", "instructions", "help"]), + patch("mempalace.cli.cmd_instructions") as mock_cmd, + ): + main() + mock_cmd.assert_called_once() + + +def test_main_repair_dispatches(): + with ( + patch("sys.argv", ["mempalace", "repair"]), + patch("mempalace.cli.cmd_repair") as mock_cmd, + ): + main() + mock_cmd.assert_called_once() + + +def test_main_compress_dispatches(): + with ( + patch("sys.argv", ["mempalace", "compress"]), + patch("mempalace.cli.cmd_compress") as mock_cmd, + ): + main() + mock_cmd.assert_called_once() + + +# ── cmd_repair ───────────────────────────────────────────────────────── + + +@patch("mempalace.cli.MempalaceConfig") +def test_cmd_repair_no_palace(mock_config_cls, tmp_path, capsys): + mock_config_cls.return_value.palace_path = str(tmp_path / "nonexistent") + args = argparse.Namespace(palace=None) + mock_chromadb = MagicMock() + with patch.dict("sys.modules", {"chromadb": mock_chromadb}): + cmd_repair(args) + out = capsys.readouterr().out + assert "No palace found" in out + + +@patch("mempalace.cli.MempalaceConfig") +def test_cmd_repair_error_reading(mock_config_cls, tmp_path, capsys): + palace_dir = tmp_path / "palace" + palace_dir.mkdir() + mock_config_cls.return_value.palace_path = str(palace_dir) + args = argparse.Namespace(palace=None) + mock_chromadb = MagicMock() + mock_client = MagicMock() + mock_client.get_collection.side_effect = Exception("corrupt db") + mock_chromadb.PersistentClient.return_value = mock_client + with patch.dict("sys.modules", {"chromadb": mock_chromadb}): + cmd_repair(args) + out = capsys.readouterr().out + assert "Error reading palace" in out + + +@patch("mempalace.cli.MempalaceConfig") +def test_cmd_repair_zero_drawers(mock_config_cls, tmp_path, capsys): + palace_dir = tmp_path / "palace" + palace_dir.mkdir() + mock_config_cls.return_value.palace_path = str(palace_dir) + args = argparse.Namespace(palace=None) + mock_chromadb = MagicMock() + mock_col = MagicMock() + mock_col.count.return_value = 0 + mock_client = MagicMock() + mock_client.get_collection.return_value = mock_col + mock_chromadb.PersistentClient.return_value = mock_client + with patch.dict("sys.modules", {"chromadb": mock_chromadb}): + cmd_repair(args) + out = capsys.readouterr().out + assert "Nothing to repair" in out + + +@patch("mempalace.cli.MempalaceConfig") +def test_cmd_repair_success(mock_config_cls, tmp_path, capsys): + palace_dir = tmp_path / "palace" + palace_dir.mkdir() + mock_config_cls.return_value.palace_path = str(palace_dir) + args = argparse.Namespace(palace=None) + mock_chromadb = MagicMock() + mock_col = MagicMock() + mock_col.count.return_value = 2 + mock_col.get.return_value = { + "ids": ["id1", "id2"], + "documents": ["doc1", "doc2"], + "metadatas": [{"wing": "a"}, {"wing": "b"}], + } + mock_client = MagicMock() + mock_client.get_collection.return_value = mock_col + mock_new_col = MagicMock() + mock_client.create_collection.return_value = mock_new_col + mock_chromadb.PersistentClient.return_value = mock_client + with patch.dict("sys.modules", {"chromadb": mock_chromadb}): + cmd_repair(args) + out = capsys.readouterr().out + assert "Repair complete" in out + assert "2 drawers rebuilt" in out + + +# ── cmd_compress ─────────────────────────────────────────────────────── + + +@patch("mempalace.cli.MempalaceConfig") +def test_cmd_compress_no_palace(mock_config_cls, capsys): + mock_config_cls.return_value.palace_path = "/fake/palace" + args = argparse.Namespace(palace=None, wing=None, dry_run=False, config=None) + mock_chromadb = MagicMock() + mock_chromadb.PersistentClient.side_effect = Exception("no palace") + with ( + patch.dict("sys.modules", {"chromadb": mock_chromadb}), + pytest.raises(SystemExit), + ): + cmd_compress(args) + + +@patch("mempalace.cli.MempalaceConfig") +def test_cmd_compress_no_drawers(mock_config_cls, capsys): + mock_config_cls.return_value.palace_path = "/fake/palace" + args = argparse.Namespace(palace=None, wing="mywing", dry_run=False, config=None) + mock_chromadb = MagicMock() + mock_col = MagicMock() + mock_col.get.return_value = {"documents": [], "metadatas": [], "ids": []} + mock_client = MagicMock() + mock_client.get_collection.return_value = mock_col + mock_chromadb.PersistentClient.return_value = mock_client + with patch.dict("sys.modules", {"chromadb": mock_chromadb}): + cmd_compress(args) + out = capsys.readouterr().out + assert "No drawers found" in out + + +def _make_mock_dialect_module(dialect_instance): + """Create a mock dialect module with a Dialect class that returns the given instance.""" + mock_mod = MagicMock() + mock_mod.Dialect.return_value = dialect_instance + mock_mod.Dialect.from_config.return_value = dialect_instance + mock_mod.Dialect.count_tokens = MagicMock(side_effect=lambda x: len(x) // 4) + return mock_mod + + +@patch("mempalace.cli.MempalaceConfig") +def test_cmd_compress_dry_run(mock_config_cls, capsys): + mock_config_cls.return_value.palace_path = "/fake/palace" + args = argparse.Namespace(palace=None, wing=None, dry_run=True, config=None) + mock_chromadb = MagicMock() + mock_col = MagicMock() + mock_col.get.side_effect = [ + { + "documents": ["some long text here for testing"], + "metadatas": [{"wing": "test", "room": "general", "source_file": "test.txt"}], + "ids": ["id1"], + }, + {"documents": [], "metadatas": [], "ids": []}, + ] + mock_client = MagicMock() + mock_client.get_collection.return_value = mock_col + mock_chromadb.PersistentClient.return_value = mock_client + + mock_dialect = MagicMock() + mock_dialect.compress.return_value = "compressed" + mock_dialect.compression_stats.return_value = { + "original_chars": 100, + "compressed_chars": 30, + "original_tokens": 25, + "compressed_tokens": 8, + "ratio": 3.3, + } + mock_dialect_mod = _make_mock_dialect_module(mock_dialect) + + with patch.dict( + "sys.modules", + { + "chromadb": mock_chromadb, + "mempalace.dialect": mock_dialect_mod, + }, + ): + cmd_compress(args) + out = capsys.readouterr().out + assert "dry run" in out.lower() + assert "Compressing" in out + + +@patch("mempalace.cli.MempalaceConfig") +def test_cmd_compress_with_config(mock_config_cls, tmp_path, capsys): + mock_config_cls.return_value.palace_path = "/fake/palace" + config_file = tmp_path / "entities.json" + config_file.write_text('{"people": [], "projects": []}') + args = argparse.Namespace(palace=None, wing=None, dry_run=True, config=str(config_file)) + mock_chromadb = MagicMock() + mock_col = MagicMock() + mock_col.get.return_value = {"documents": [], "metadatas": [], "ids": []} + mock_client = MagicMock() + mock_client.get_collection.return_value = mock_col + mock_chromadb.PersistentClient.return_value = mock_client + + mock_dialect = MagicMock() + mock_dialect_mod = _make_mock_dialect_module(mock_dialect) + + with patch.dict( + "sys.modules", + { + "chromadb": mock_chromadb, + "mempalace.dialect": mock_dialect_mod, + }, + ): + cmd_compress(args) + out = capsys.readouterr().out + assert "Loaded entity config" in out + + +@patch("mempalace.cli.MempalaceConfig") +def test_cmd_compress_stores_results(mock_config_cls, capsys): + """Non-dry-run compress stores to mempalace_compressed collection.""" + mock_config_cls.return_value.palace_path = "/fake/palace" + args = argparse.Namespace(palace=None, wing=None, dry_run=False, config=None) + mock_chromadb = MagicMock() + mock_col = MagicMock() + mock_col.get.side_effect = [ + { + "documents": ["text"], + "metadatas": [{"wing": "w", "room": "r", "source_file": "f.txt"}], + "ids": ["id1"], + }, + {"documents": [], "metadatas": [], "ids": []}, + ] + mock_client = MagicMock() + mock_client.get_collection.return_value = mock_col + mock_comp_col = MagicMock() + mock_client.get_or_create_collection.return_value = mock_comp_col + mock_chromadb.PersistentClient.return_value = mock_client + + mock_dialect = MagicMock() + mock_dialect.compress.return_value = "compressed" + mock_dialect.compression_stats.return_value = { + "original_chars": 100, + "compressed_chars": 30, + "original_tokens": 25, + "compressed_tokens": 8, + "ratio": 3.3, + } + mock_dialect_mod = _make_mock_dialect_module(mock_dialect) + + with patch.dict( + "sys.modules", + { + "chromadb": mock_chromadb, + "mempalace.dialect": mock_dialect_mod, + }, + ): + cmd_compress(args) + out = capsys.readouterr().out + assert "Stored" in out + mock_comp_col.upsert.assert_called_once() + + +def test_cmd_repair_trailing_slash_does_not_recurse(): + """Repair with trailing slash should put backup outside palace dir (#395).""" + import os + + args = argparse.Namespace(palace="/tmp/fake_palace/") + with patch("mempalace.cli.os.path.isdir", return_value=False): + cmd_repair(args) + # Verify the rstrip logic: palace_path should not end with separator + palace_path = os.path.expanduser(args.palace).rstrip(os.sep) + backup_path = palace_path + ".backup" + assert not backup_path.startswith(palace_path + os.sep) diff --git a/tests/test_config_extra.py b/tests/test_config_extra.py new file mode 100644 index 0000000..d0d9b5d --- /dev/null +++ b/tests/test_config_extra.py @@ -0,0 +1,79 @@ +"""Extra tests for mempalace.config to cover remaining gaps.""" + +import json +import os + +from mempalace.config import MempalaceConfig + + +def test_config_bad_json(tmp_path): + """Bad JSON in config file falls back to empty.""" + (tmp_path / "config.json").write_text("not json", encoding="utf-8") + cfg = MempalaceConfig(config_dir=str(tmp_path)) + assert cfg.palace_path # still returns default + + +def test_people_map_from_file(tmp_path): + (tmp_path / "people_map.json").write_text(json.dumps({"bob": "Robert"}), encoding="utf-8") + cfg = MempalaceConfig(config_dir=str(tmp_path)) + assert cfg.people_map == {"bob": "Robert"} + + +def test_people_map_bad_json(tmp_path): + (tmp_path / "people_map.json").write_text("bad", encoding="utf-8") + cfg = MempalaceConfig(config_dir=str(tmp_path)) + assert cfg.people_map == {} + + +def test_people_map_missing(tmp_path): + cfg = MempalaceConfig(config_dir=str(tmp_path)) + assert cfg.people_map == {} + + +def test_topic_wings_default(tmp_path): + cfg = MempalaceConfig(config_dir=str(tmp_path)) + assert isinstance(cfg.topic_wings, list) + assert "emotions" in cfg.topic_wings + + +def test_hall_keywords_default(tmp_path): + cfg = MempalaceConfig(config_dir=str(tmp_path)) + assert isinstance(cfg.hall_keywords, dict) + assert "technical" in cfg.hall_keywords + + +def test_init_idempotent(tmp_path): + cfg = MempalaceConfig(config_dir=str(tmp_path)) + cfg.init() + cfg.init() # second call should not overwrite + with open(tmp_path / "config.json") as f: + data = json.load(f) + assert "palace_path" in data + + +def test_save_people_map(tmp_path): + cfg = MempalaceConfig(config_dir=str(tmp_path)) + result = cfg.save_people_map({"alice": "Alice Smith"}) + assert result.exists() + with open(result) as f: + data = json.load(f) + assert data["alice"] == "Alice Smith" + + +def test_env_mempal_palace_path(tmp_path): + """MEMPAL_PALACE_PATH (legacy) should also work.""" + os.environ.pop("MEMPALACE_PALACE_PATH", None) + os.environ["MEMPAL_PALACE_PATH"] = "/legacy/path" + try: + cfg = MempalaceConfig(config_dir=str(tmp_path)) + assert cfg.palace_path == "/legacy/path" + finally: + del os.environ["MEMPAL_PALACE_PATH"] + + +def test_collection_name_from_config(tmp_path): + (tmp_path / "config.json").write_text( + json.dumps({"collection_name": "custom_col"}), encoding="utf-8" + ) + cfg = MempalaceConfig(config_dir=str(tmp_path)) + assert cfg.collection_name == "custom_col" diff --git a/tests/test_convo_miner.py b/tests/test_convo_miner.py index 788c46d..0ac0019 100644 --- a/tests/test_convo_miner.py +++ b/tests/test_convo_miner.py @@ -23,4 +23,4 @@ def test_convo_mining(): results = col.query(query_texts=["memory persistence"], n_results=1) assert len(results["documents"][0]) > 0 - shutil.rmtree(tmpdir) + shutil.rmtree(tmpdir, ignore_errors=True) diff --git a/tests/test_convo_miner_unit.py b/tests/test_convo_miner_unit.py new file mode 100644 index 0000000..3c7e8f2 --- /dev/null +++ b/tests/test_convo_miner_unit.py @@ -0,0 +1,102 @@ +"""Unit tests for convo_miner pure functions (no chromadb needed).""" + +from mempalace.convo_miner import ( + chunk_exchanges, + detect_convo_room, + scan_convos, +) + + +class TestChunkExchanges: + def test_exchange_chunking(self): + content = ( + "> What is memory?\n" + "Memory is persistence of information over time.\n\n" + "> Why does it matter?\n" + "It enables continuity across sessions and conversations.\n\n" + "> How do we build it?\n" + "With structured storage and retrieval mechanisms.\n" + ) + chunks = chunk_exchanges(content) + assert len(chunks) >= 2 + assert all("content" in c and "chunk_index" in c for c in chunks) + + def test_paragraph_fallback(self): + """Content without '>' lines falls back to paragraph chunking.""" + content = ( + "This is a long paragraph about memory systems. " * 10 + "\n\n" + "This is another paragraph about storage. " * 10 + "\n\n" + "And a third paragraph about retrieval. " * 10 + ) + chunks = chunk_exchanges(content) + assert len(chunks) >= 2 + + def test_paragraph_line_group_fallback(self): + """Long content with no paragraph breaks chunks by line groups.""" + lines = [f"Line {i}: some content that is meaningful" for i in range(60)] + content = "\n".join(lines) + chunks = chunk_exchanges(content) + assert len(chunks) >= 1 + + def test_empty_content(self): + chunks = chunk_exchanges("") + assert chunks == [] + + def test_short_content_skipped(self): + chunks = chunk_exchanges("> hi\nbye") + # Too short to produce chunks (below MIN_CHUNK_SIZE) + assert isinstance(chunks, list) + + +class TestDetectConvoRoom: + def test_technical_room(self): + content = "Let me debug this python function and fix the code error in the api" + assert detect_convo_room(content) == "technical" + + def test_planning_room(self): + content = "We need to plan the roadmap for the next sprint and set milestone deadlines" + assert detect_convo_room(content) == "planning" + + def test_architecture_room(self): + content = "The architecture uses a service layer with component interface and module design" + assert detect_convo_room(content) == "architecture" + + def test_decisions_room(self): + content = "We decided to switch and migrated to the new framework after we chose it" + assert detect_convo_room(content) == "decisions" + + def test_general_fallback(self): + content = "Hello, how are you doing today? The weather is nice." + assert detect_convo_room(content) == "general" + + +class TestScanConvos: + def test_scan_finds_txt_and_md(self, tmp_path): + (tmp_path / "chat.txt").write_text("hello", encoding="utf-8") + (tmp_path / "notes.md").write_text("world", encoding="utf-8") + (tmp_path / "image.png").write_bytes(b"fake") + files = scan_convos(str(tmp_path)) + extensions = {f.suffix for f in files} + assert ".txt" in extensions + assert ".md" in extensions + assert ".png" not in extensions + + def test_scan_skips_git_dir(self, tmp_path): + git_dir = tmp_path / ".git" + git_dir.mkdir() + (git_dir / "config.txt").write_text("git stuff", encoding="utf-8") + (tmp_path / "chat.txt").write_text("hello", encoding="utf-8") + files = scan_convos(str(tmp_path)) + assert len(files) == 1 + + def test_scan_skips_meta_json(self, tmp_path): + (tmp_path / "chat.meta.json").write_text("{}", encoding="utf-8") + (tmp_path / "chat.json").write_text("{}", encoding="utf-8") + files = scan_convos(str(tmp_path)) + names = [f.name for f in files] + assert "chat.json" in names + assert "chat.meta.json" not in names + + def test_scan_empty_dir(self, tmp_path): + files = scan_convos(str(tmp_path)) + assert files == [] diff --git a/tests/test_entity_detector.py b/tests/test_entity_detector.py new file mode 100644 index 0000000..91f0e29 --- /dev/null +++ b/tests/test_entity_detector.py @@ -0,0 +1,380 @@ +"""Tests for mempalace.entity_detector.""" + +import os +from unittest.mock import patch + +from mempalace.entity_detector import ( + PROSE_EXTENSIONS, + STOPWORDS, + _print_entity_list, + classify_entity, + confirm_entities, + detect_entities, + extract_candidates, + scan_for_detection, + score_entity, +) + + +# ── extract_candidates ────────────────────────────────────────────────── + + +def test_extract_candidates_finds_frequent_names(): + text = "Riley said hello. Riley laughed. Riley smiled. Riley waved." + result = extract_candidates(text) + assert "Riley" in result + assert result["Riley"] >= 3 + + +def test_extract_candidates_ignores_stopwords(): + # "The" appears many times but is a stopword + text = "The The The The The The" + result = extract_candidates(text) + assert "The" not in result + + +def test_extract_candidates_requires_min_frequency(): + text = "Riley said hi. Devon waved." + result = extract_candidates(text) + # Each name appears only once, below the threshold of 3 + assert "Riley" not in result + assert "Devon" not in result + + +def test_extract_candidates_finds_multi_word_names(): + # Multi-word names need 3+ occurrences and no stopwords + text = "Claude Code is great. Claude Code rocks. Claude Code works. Claude Code rules." + result = extract_candidates(text) + assert "Claude Code" in result + + +def test_extract_candidates_empty_text(): + result = extract_candidates("") + assert result == {} + + +# ── score_entity ──────────────────────────────────────────────────────── + + +def test_score_entity_person_verbs(): + text = "Riley said hello. Riley asked why. Riley told me." + lines = text.splitlines() + result = score_entity("Riley", text, lines) + assert result["person_score"] > 0 + assert len(result["person_signals"]) > 0 + + +def test_score_entity_project_verbs(): + text = "We are building ChromaDB. We deployed ChromaDB. Install ChromaDB." + lines = text.splitlines() + result = score_entity("ChromaDB", text, lines) + assert result["project_score"] > 0 + assert len(result["project_signals"]) > 0 + + +def test_score_entity_dialogue_markers(): + text = "Riley: Hey, how are you?\nRiley: I'm fine." + lines = text.splitlines() + result = score_entity("Riley", text, lines) + assert result["person_score"] > 0 + + +def test_score_entity_code_ref(): + text = "Check out ChromaDB.py for details. Also ChromaDB.js is good." + lines = text.splitlines() + result = score_entity("ChromaDB", text, lines) + assert result["project_score"] > 0 + + +def test_score_entity_no_signals(): + text = "Nothing interesting here at all." + lines = text.splitlines() + result = score_entity("Riley", text, lines) + assert result["person_score"] == 0 + assert result["project_score"] == 0 + + +# ── classify_entity ───────────────────────────────────────────────────── + + +def test_classify_entity_no_signals_gives_uncertain(): + scores = { + "person_score": 0, + "project_score": 0, + "person_signals": [], + "project_signals": [], + } + result = classify_entity("Foo", 10, scores) + assert result["type"] == "uncertain" + assert result["name"] == "Foo" + + +def test_classify_entity_strong_project(): + scores = { + "person_score": 0, + "project_score": 10, + "person_signals": [], + "project_signals": ["project verb (5x)", "code file reference (2x)"], + } + result = classify_entity("ChromaDB", 5, scores) + assert result["type"] == "project" + + +def test_classify_entity_strong_person_needs_two_signal_types(): + scores = { + "person_score": 10, + "project_score": 0, + "person_signals": [ + "dialogue marker (3x)", + "'Riley ...' action (4x)", + ], + "project_signals": [], + } + result = classify_entity("Riley", 8, scores) + assert result["type"] == "person" + + +def test_classify_entity_pronoun_only_is_uncertain(): + scores = { + "person_score": 8, + "project_score": 0, + "person_signals": ["pronoun nearby (4x)"], + "project_signals": [], + } + result = classify_entity("Riley", 5, scores) + assert result["type"] == "uncertain" + + +def test_classify_entity_mixed_signals(): + scores = { + "person_score": 5, + "project_score": 5, + "person_signals": ["pronoun nearby (2x)"], + "project_signals": ["project verb (2x)"], + } + result = classify_entity("Lantern", 5, scores) + assert result["type"] == "uncertain" + assert "mixed signals" in result["signals"][-1] + + +# ── detect_entities (integration) ─────────────────────────────────────── + + +def test_detect_entities_with_person_file(tmp_path): + f = tmp_path / "notes.txt" + content = "\n".join( + [ + "Riley said hello today.", + "Riley asked about the project.", + "Riley told me she was happy.", + "Riley: I think we should go.", + "Hey Riley, thanks for the help.", + "Riley laughed and smiled.", + "Riley decided to join.", + "Riley pushed the change.", + ] + ) + f.write_text(content) + result = detect_entities([f]) + all_names = [e["name"] for cat in result.values() for e in cat] + assert "Riley" in all_names + + +def test_detect_entities_with_project_file(tmp_path): + f = tmp_path / "readme.txt" + # "ChromaDB" has uppercase+lowercase mix but extract_candidates looks + # for /[A-Z][a-z]{1,19}/ — so we need a name that matches that regex. + # Use "Lantern" which matches the capitalized-word pattern. + content = "\n".join( + [ + "The Lantern project is great.", + "Building Lantern was fun.", + "We deployed Lantern today.", + "Install Lantern with pip install Lantern.", + "Check Lantern.py for the source.", + "Lantern v2 is faster.", + ] + ) + f.write_text(content) + result = detect_entities([f]) + all_names = [e["name"] for cat in result.values() for e in cat] + assert "Lantern" in all_names + + +def test_detect_entities_empty_files(tmp_path): + f = tmp_path / "empty.txt" + f.write_text("") + result = detect_entities([f]) + assert result == {"people": [], "projects": [], "uncertain": []} + + +def test_detect_entities_handles_missing_file(tmp_path): + missing = tmp_path / "nonexistent.txt" + result = detect_entities([missing]) + assert result == {"people": [], "projects": [], "uncertain": []} + + +def test_detect_entities_respects_max_files(tmp_path): + files = [] + for i in range(5): + f = tmp_path / f"file{i}.txt" + f.write_text("Riley said hello. " * 10) + files.append(f) + # max_files=2 should only read 2 files + result = detect_entities(files, max_files=2) + # Should still work without error + assert isinstance(result, dict) + + +# ── scan_for_detection ────────────────────────────────────────────────── + + +def test_scan_for_detection_finds_prose(tmp_path): + (tmp_path / "notes.md").write_text("hello") + (tmp_path / "data.txt").write_text("world") + (tmp_path / "code.py").write_text("import os") + files = scan_for_detection(str(tmp_path)) + extensions = {os.path.splitext(str(f))[1] for f in files} + # Prose files should be found + assert ".md" in extensions or ".txt" in extensions + + +def test_scan_for_detection_skips_git_dir(tmp_path): + git_dir = tmp_path / ".git" + git_dir.mkdir() + (git_dir / "config.txt").write_text("git config") + (tmp_path / "readme.md").write_text("hello") + files = scan_for_detection(str(tmp_path)) + file_strs = [str(f) for f in files] + assert not any(".git" in f for f in file_strs) + + +# ── module-level constants ────────────────────────────────────────────── + + +def test_stopwords_contains_common_words(): + assert "the" in STOPWORDS + assert "import" in STOPWORDS + assert "class" in STOPWORDS + + +def test_prose_extensions(): + assert ".txt" in PROSE_EXTENSIONS + assert ".md" in PROSE_EXTENSIONS + + +# ── _print_entity_list ───────────────────────────────────────────────── + + +def test_print_entity_list_with_entities(capsys): + entities = [ + {"name": "Alice", "confidence": 0.9, "signals": ["dialogue marker (3x)"]}, + {"name": "Bob", "confidence": 0.5, "signals": []}, + ] + _print_entity_list(entities, "PEOPLE") + out = capsys.readouterr().out + assert "PEOPLE" in out + assert "Alice" in out + assert "Bob" in out + + +def test_print_entity_list_empty(capsys): + _print_entity_list([], "PEOPLE") + out = capsys.readouterr().out + assert "none detected" in out + + +# ── confirm_entities ─────────────────────────────────────────────────── + + +def test_confirm_entities_yes_mode(): + detected = { + "people": [{"name": "Alice", "confidence": 0.9, "signals": ["test"]}], + "projects": [{"name": "Acme", "confidence": 0.8, "signals": ["test"]}], + "uncertain": [{"name": "Foo", "confidence": 0.4, "signals": ["test"]}], + } + result = confirm_entities(detected, yes=True) + assert result["people"] == ["Alice"] + assert result["projects"] == ["Acme"] + + +def test_confirm_entities_accept_all(): + detected = { + "people": [{"name": "Alice", "confidence": 0.9, "signals": ["test"]}], + "projects": [], + "uncertain": [], + } + with patch("builtins.input", side_effect=["", "n"]): + result = confirm_entities(detected, yes=False) + assert "Alice" in result["people"] + + +def test_confirm_entities_edit_reclassify_uncertain(): + detected = { + "people": [], + "projects": [], + "uncertain": [ + {"name": "Foo", "confidence": 0.4, "signals": ["test"]}, + {"name": "Bar", "confidence": 0.4, "signals": ["test"]}, + ], + } + with patch( + "builtins.input", + side_effect=[ + "edit", # choice + "p", # Foo -> person + "s", # Bar -> skip + "", # no removals from people + "", # no removals from projects + "n", # don't add missing + ], + ): + result = confirm_entities(detected, yes=False) + assert "Foo" in result["people"] + assert "Bar" not in result["people"] + assert "Bar" not in result["projects"] + + +def test_confirm_entities_add_mode(): + detected = { + "people": [], + "projects": [], + "uncertain": [], + } + with patch( + "builtins.input", + side_effect=[ + "add", # choice = add + "NewPerson", # name + "p", # person + "NewProj", # name + "r", # project + "", # stop adding + ], + ): + result = confirm_entities(detected, yes=False) + assert "NewPerson" in result["people"] + assert "NewProj" in result["projects"] + + +# ── scan_for_detection fallback ──────────────────────────────────────── + + +def test_scan_for_detection_fallback_to_all_readable(tmp_path): + """When fewer than 3 prose files, falls back to include all readable files.""" + (tmp_path / "one.md").write_text("hello") + (tmp_path / "two.txt").write_text("world") + # Only 2 prose files, so it should also include code files + (tmp_path / "code.py").write_text("import os") + (tmp_path / "app.js").write_text("console.log()") + files = scan_for_detection(str(tmp_path)) + extensions = {os.path.splitext(str(f))[1] for f in files} + assert ".py" in extensions or ".js" in extensions + + +def test_scan_for_detection_max_files(tmp_path): + """Caps to max_files.""" + for i in range(20): + (tmp_path / f"note{i}.md").write_text(f"content {i}") + files = scan_for_detection(str(tmp_path), max_files=5) + assert len(files) <= 5 diff --git a/tests/test_entity_registry.py b/tests/test_entity_registry.py new file mode 100644 index 0000000..b92bf84 --- /dev/null +++ b/tests/test_entity_registry.py @@ -0,0 +1,313 @@ +"""Tests for mempalace.entity_registry.""" + +from unittest.mock import patch + +from mempalace.entity_registry import ( + COMMON_ENGLISH_WORDS, + PERSON_CONTEXT_PATTERNS, + EntityRegistry, +) + + +# ── COMMON_ENGLISH_WORDS ──────────────────────────────────────────────── + + +def test_common_english_words_has_expected_entries(): + assert "ever" in COMMON_ENGLISH_WORDS + assert "grace" in COMMON_ENGLISH_WORDS + assert "will" in COMMON_ENGLISH_WORDS + assert "may" in COMMON_ENGLISH_WORDS + assert "monday" in COMMON_ENGLISH_WORDS + + +def test_common_english_words_is_lowercase(): + for word in COMMON_ENGLISH_WORDS: + assert word == word.lower(), f"{word} should be lowercase" + + +# ── PERSON_CONTEXT_PATTERNS ───────────────────────────────────────────── + + +def test_person_context_patterns_is_nonempty(): + assert len(PERSON_CONTEXT_PATTERNS) > 0 + + +# ── EntityRegistry creation and empty state ───────────────────────────── + + +def test_load_from_nonexistent_dir(tmp_path): + registry = EntityRegistry.load(config_dir=tmp_path) + assert registry.people == {} + assert registry.projects == [] + assert registry.mode == "personal" + assert registry.ambiguous_flags == [] + + +def test_save_and_load_roundtrip(tmp_path): + registry = EntityRegistry.load(config_dir=tmp_path) + registry.seed( + mode="work", + people=[{"name": "Alice", "relationship": "colleague", "context": "work"}], + projects=["MemPalace"], + ) + # Load again from same dir + loaded = EntityRegistry.load(config_dir=tmp_path) + assert loaded.mode == "work" + assert "Alice" in loaded.people + assert "MemPalace" in loaded.projects + + +def test_save_creates_file(tmp_path): + registry = EntityRegistry.load(config_dir=tmp_path) + registry.save() + assert (tmp_path / "entity_registry.json").exists() + + +# ── seed ──────────────────────────────────────────────────────────────── + + +def test_seed_registers_people(tmp_path): + registry = EntityRegistry.load(config_dir=tmp_path) + registry.seed( + mode="personal", + people=[ + {"name": "Riley", "relationship": "daughter", "context": "personal"}, + {"name": "Devon", "relationship": "friend", "context": "personal"}, + ], + projects=["MemPalace"], + ) + assert "Riley" in registry.people + assert "Devon" in registry.people + assert registry.people["Riley"]["relationship"] == "daughter" + assert registry.people["Riley"]["source"] == "onboarding" + assert registry.people["Riley"]["confidence"] == 1.0 + + +def test_seed_registers_projects(tmp_path): + registry = EntityRegistry.load(config_dir=tmp_path) + registry.seed(mode="work", people=[], projects=["Acme", "Widget"]) + assert registry.projects == ["Acme", "Widget"] + + +def test_seed_sets_mode(tmp_path): + registry = EntityRegistry.load(config_dir=tmp_path) + registry.seed(mode="combo", people=[], projects=[]) + assert registry.mode == "combo" + + +def test_seed_flags_ambiguous_names(tmp_path): + registry = EntityRegistry.load(config_dir=tmp_path) + registry.seed( + mode="personal", + people=[ + {"name": "Grace", "relationship": "friend", "context": "personal"}, + {"name": "Riley", "relationship": "daughter", "context": "personal"}, + ], + projects=[], + ) + assert "grace" in registry.ambiguous_flags + # Riley is not a common English word + assert "riley" not in registry.ambiguous_flags + + +def test_seed_with_aliases(tmp_path): + registry = EntityRegistry.load(config_dir=tmp_path) + registry.seed( + mode="personal", + people=[{"name": "Maxwell", "relationship": "friend", "context": "personal"}], + projects=[], + aliases={"Max": "Maxwell"}, + ) + assert "Maxwell" in registry.people + assert "Max" in registry.people + assert registry.people["Max"].get("canonical") == "Maxwell" + + +def test_seed_skips_empty_names(tmp_path): + registry = EntityRegistry.load(config_dir=tmp_path) + registry.seed( + mode="personal", + people=[{"name": "", "relationship": "", "context": "personal"}], + projects=[], + ) + assert len(registry.people) == 0 + + +# ── lookup ────────────────────────────────────────────────────────────── + + +def test_lookup_known_person(tmp_path): + registry = EntityRegistry.load(config_dir=tmp_path) + registry.seed( + mode="personal", + people=[{"name": "Riley", "relationship": "daughter", "context": "personal"}], + projects=[], + ) + result = registry.lookup("Riley") + assert result["type"] == "person" + assert result["confidence"] == 1.0 + assert result["name"] == "Riley" + + +def test_lookup_known_project(tmp_path): + registry = EntityRegistry.load(config_dir=tmp_path) + registry.seed(mode="work", people=[], projects=["MemPalace"]) + result = registry.lookup("MemPalace") + assert result["type"] == "project" + assert result["confidence"] == 1.0 + + +def test_lookup_unknown_word(tmp_path): + registry = EntityRegistry.load(config_dir=tmp_path) + registry.seed(mode="personal", people=[], projects=[]) + result = registry.lookup("Xyzzy") + assert result["type"] == "unknown" + assert result["confidence"] == 0.0 + + +def test_lookup_case_insensitive(tmp_path): + registry = EntityRegistry.load(config_dir=tmp_path) + registry.seed( + mode="personal", + people=[{"name": "Riley", "relationship": "daughter", "context": "personal"}], + projects=[], + ) + result = registry.lookup("riley") + assert result["type"] == "person" + + +def test_lookup_alias(tmp_path): + registry = EntityRegistry.load(config_dir=tmp_path) + registry.seed( + mode="personal", + people=[{"name": "Maxwell", "relationship": "friend", "context": "personal"}], + projects=[], + aliases={"Max": "Maxwell"}, + ) + result = registry.lookup("Max") + assert result["type"] == "person" + + +# ── disambiguation ────────────────────────────────────────────────────── + + +def test_lookup_ambiguous_word_as_person(tmp_path): + registry = EntityRegistry.load(config_dir=tmp_path) + registry.seed( + mode="personal", + people=[{"name": "Grace", "relationship": "friend", "context": "personal"}], + projects=[], + ) + result = registry.lookup("Grace", context="I went with Grace today") + assert result["type"] == "person" + + +def test_lookup_ambiguous_word_as_concept(tmp_path): + registry = EntityRegistry.load(config_dir=tmp_path) + registry.seed( + mode="personal", + people=[{"name": "Ever", "relationship": "friend", "context": "personal"}], + projects=[], + ) + result = registry.lookup("Ever", context="have you ever tried this") + assert result["type"] == "concept" + + +# ── research (Wikipedia) — mocked ────────────────────────────────────── + + +def test_research_caches_result(tmp_path): + registry = EntityRegistry.load(config_dir=tmp_path) + registry.seed(mode="personal", people=[], projects=[]) + + mock_result = { + "inferred_type": "person", + "confidence": 0.80, + "wiki_summary": "Saoirse is an Irish given name.", + "wiki_title": "Saoirse", + } + + with patch("mempalace.entity_registry._wikipedia_lookup", return_value=mock_result): + result = registry.research("Saoirse", auto_confirm=True) + assert result["inferred_type"] == "person" + + # Second call should use cache, not call Wikipedia again + with patch( + "mempalace.entity_registry._wikipedia_lookup", + side_effect=AssertionError("should not be called"), + ): + cached = registry.research("Saoirse") + assert cached["inferred_type"] == "person" + + +def test_confirm_research_adds_to_people(tmp_path): + registry = EntityRegistry.load(config_dir=tmp_path) + registry.seed(mode="personal", people=[], projects=[]) + + mock_result = { + "inferred_type": "person", + "confidence": 0.80, + "wiki_summary": "Saoirse is a name", + "wiki_title": "Saoirse", + } + with patch("mempalace.entity_registry._wikipedia_lookup", return_value=mock_result): + registry.research("Saoirse", auto_confirm=False) + + registry.confirm_research("Saoirse", entity_type="person", relationship="friend") + assert "Saoirse" in registry.people + assert registry.people["Saoirse"]["source"] == "wiki" + + +# ── extract_people_from_query ─────────────────────────────────────────── + + +def test_extract_people_from_query(tmp_path): + registry = EntityRegistry.load(config_dir=tmp_path) + registry.seed( + mode="personal", + people=[ + {"name": "Riley", "relationship": "daughter", "context": "personal"}, + {"name": "Devon", "relationship": "friend", "context": "personal"}, + ], + projects=[], + ) + found = registry.extract_people_from_query("What did Riley say about the weather?") + assert "Riley" in found + assert "Devon" not in found + + +# ── extract_unknown_candidates ────────────────────────────────────────── + + +def test_extract_unknown_candidates(tmp_path): + registry = EntityRegistry.load(config_dir=tmp_path) + registry.seed(mode="personal", people=[], projects=[]) + unknowns = registry.extract_unknown_candidates("Saoirse went to the store") + assert "Saoirse" in unknowns + + +def test_extract_unknown_candidates_skips_known(tmp_path): + registry = EntityRegistry.load(config_dir=tmp_path) + registry.seed( + mode="personal", + people=[{"name": "Riley", "relationship": "daughter", "context": "personal"}], + projects=[], + ) + unknowns = registry.extract_unknown_candidates("Riley went to the store") + assert "Riley" not in unknowns + + +# ── summary ───────────────────────────────────────────────────────────── + + +def test_summary(tmp_path): + registry = EntityRegistry.load(config_dir=tmp_path) + registry.seed( + mode="personal", + people=[{"name": "Riley", "relationship": "daughter", "context": "personal"}], + projects=["MemPalace"], + ) + s = registry.summary() + assert "personal" in s + assert "Riley" in s + assert "MemPalace" in s diff --git a/tests/test_general_extractor.py b/tests/test_general_extractor.py new file mode 100644 index 0000000..0f5d46c --- /dev/null +++ b/tests/test_general_extractor.py @@ -0,0 +1,248 @@ +"""Tests for mempalace.general_extractor.""" + +from mempalace.general_extractor import ( + ALL_MARKERS, + NEGATIVE_WORDS, + POSITIVE_WORDS, + _extract_prose, + _get_sentiment, + _has_resolution, + _is_code_line, + _score_markers, + _split_into_segments, + extract_memories, +) + + +# ── extract_memories — empty / no markers ─────────────────────────────── + + +def test_extract_memories_empty_text(): + result = extract_memories("") + assert result == [] + + +def test_extract_memories_no_markers(): + result = extract_memories("The quick brown fox jumped over the lazy dog.") + assert result == [] + + +def test_extract_memories_short_text_skipped(): + # Paragraphs shorter than 20 chars are skipped + result = extract_memories("ok sure") + assert result == [] + + +# ── extract_memories — decision markers ───────────────────────────────── + + +def test_extract_memories_decision(): + text = ( + "We decided to go with PostgreSQL instead of MySQL " + "because the performance was better for our use case. " + "The trade-off was more complexity in setup." + ) + result = extract_memories(text) + assert len(result) >= 1 + assert any(m["memory_type"] == "decision" for m in result) + + +# ── extract_memories — preference markers ─────────────────────────────── + + +def test_extract_memories_preference(): + text = ( + "I prefer using snake_case in Python code. " + "Please always use type hints. " + "Never use wildcard imports." + ) + result = extract_memories(text) + assert len(result) >= 1 + assert any(m["memory_type"] == "preference" for m in result) + + +# ── extract_memories — milestone markers ──────────────────────────────── + + +def test_extract_memories_milestone(): + text = ( + "It finally works! After three days of debugging, " + "I figured out the issue. The breakthrough was realizing " + "the config file was cached. Got it working at 2am." + ) + result = extract_memories(text) + assert len(result) >= 1 + assert any(m["memory_type"] == "milestone" for m in result) + + +# ── extract_memories — problem markers ────────────────────────────────── + + +def test_extract_memories_problem(): + text = ( + "There's a critical bug in the auth module. " + "The error keeps crashing the server. " + "The root cause was a missing null check. " + "The problem is that tokens expire silently." + ) + result = extract_memories(text) + assert len(result) >= 1 + types = {m["memory_type"] for m in result} + assert "problem" in types or "milestone" in types # resolved problems become milestones + + +# ── extract_memories — emotional markers ──────────────────────────────── + + +def test_extract_memories_emotional(): + text = ( + "I feel so proud of what we built together. " + "I love working on this project, it makes me happy. " + "I'm grateful for the team and the beautiful code we wrote." + ) + result = extract_memories(text) + assert len(result) >= 1 + assert any(m["memory_type"] == "emotional" for m in result) + + +# ── extract_memories — chunk_index ────────────────────────────────────── + + +def test_extract_memories_chunk_index_increments(): + text = ( + "We decided to use React because it fits our team.\n\n" + "I prefer functional components always.\n\n" + "It works! We finally shipped the v1.0 release." + ) + result = extract_memories(text) + if len(result) >= 2: + indices = [m["chunk_index"] for m in result] + assert indices == list(range(len(result))) + + +# ── _score_markers ────────────────────────────────────────────────────── + + +def test_score_markers_with_matches(): + score, keywords = _score_markers( + "we decided to go with postgres because it is faster", + ALL_MARKERS["decision"], + ) + assert score > 0 + assert len(keywords) > 0 + + +def test_score_markers_no_matches(): + score, keywords = _score_markers("nothing relevant here", ALL_MARKERS["decision"]) + assert score == 0.0 + + +# ── _get_sentiment ────────────────────────────────────────────────────── + + +def test_get_sentiment_positive(): + assert _get_sentiment("I am so happy and proud of this breakthrough") == "positive" + + +def test_get_sentiment_negative(): + assert _get_sentiment("This bug caused a crash and total failure") == "negative" + + +def test_get_sentiment_neutral(): + assert _get_sentiment("The meeting is at three") == "neutral" + + +# ── _has_resolution ───────────────────────────────────────────────────── + + +def test_has_resolution_true(): + assert _has_resolution("I fixed the auth bug and it works now") is True + + +def test_has_resolution_false(): + assert _has_resolution("The server keeps crashing") is False + + +# ── _is_code_line ─────────────────────────────────────────────────────── + + +def test_is_code_line_detects_code(): + assert _is_code_line(" import os") is True + assert _is_code_line(" $ pip install flask") is True + assert _is_code_line(" ```python") is True + + +def test_is_code_line_allows_prose(): + assert _is_code_line("This is a regular sentence about coding.") is False + assert _is_code_line("") is False + + +# ── _extract_prose ────────────────────────────────────────────────────── + + +def test_extract_prose_strips_code_blocks(): + text = "Hello world\n```\nimport os\nprint('hi')\n```\nGoodbye" + result = _extract_prose(text) + assert "import os" not in result + assert "Hello world" in result + assert "Goodbye" in result + + +def test_extract_prose_returns_original_if_all_code(): + text = "import os\nfrom sys import argv" + result = _extract_prose(text) + # Falls back to original text if nothing left + assert len(result) > 0 + + +# ── _split_into_segments ─────────────────────────────────────────────── + + +def test_split_into_segments_by_paragraph(): + text = "First paragraph.\n\nSecond paragraph.\n\nThird paragraph." + result = _split_into_segments(text) + assert len(result) == 3 + + +def test_split_into_segments_by_turns(): + lines = [] + for i in range(5): + lines.append(f"Human: Question {i}") + lines.append(f"Assistant: Answer {i}") + text = "\n".join(lines) + result = _split_into_segments(text) + assert len(result) >= 3 # turn-based splitting should fire + + +def test_split_into_segments_single_block(): + # Many lines without double-newline produces chunked segments + lines = [f"Line {i} of the document" for i in range(30)] + text = "\n".join(lines) + result = _split_into_segments(text) + assert len(result) >= 1 + + +# ── ALL_MARKERS constant ─────────────────────────────────────────────── + + +def test_all_markers_has_five_types(): + assert set(ALL_MARKERS.keys()) == { + "decision", + "preference", + "milestone", + "problem", + "emotional", + } + + +# ── POSITIVE_WORDS / NEGATIVE_WORDS ──────────────────────────────────── + + +def test_positive_words(): + assert "happy" in POSITIVE_WORDS + assert "proud" in POSITIVE_WORDS + + +def test_negative_words(): + assert "bug" in NEGATIVE_WORDS + assert "crash" in NEGATIVE_WORDS diff --git a/tests/test_hooks_cli.py b/tests/test_hooks_cli.py index d6951e2..5a1870e 100644 --- a/tests/test_hooks_cli.py +++ b/tests/test_hooks_cli.py @@ -1,17 +1,24 @@ import contextlib +import io import json from pathlib import Path from unittest.mock import patch +import pytest + from mempalace.hooks_cli import ( SAVE_INTERVAL, STOP_BLOCK_REASON, PRECOMPACT_BLOCK_REASON, _count_human_messages, + _log, + _maybe_auto_ingest, + _parse_harness_input, _sanitize_session_id, hook_stop, hook_session_start, hook_precompact, + run_hook, ) @@ -205,3 +212,209 @@ def test_precompact_always_blocks(tmp_path): ) assert result["decision"] == "block" assert result["reason"] == PRECOMPACT_BLOCK_REASON + + +# --- _log --- + + +def test_log_writes_to_hook_log(tmp_path): + with patch("mempalace.hooks_cli.STATE_DIR", tmp_path): + _log("test message") + log_path = tmp_path / "hook.log" + assert log_path.is_file() + content = log_path.read_text() + assert "test message" in content + + +def test_log_oserror_is_silenced(tmp_path): + """_log should not raise if the directory cannot be created.""" + with patch("mempalace.hooks_cli.STATE_DIR", Path("/nonexistent/deeply/nested/dir")): + # Should not raise + _log("this will fail silently") + + +# --- _maybe_auto_ingest --- + + +def test_maybe_auto_ingest_no_env(tmp_path): + """Without MEMPAL_DIR set, does nothing.""" + with patch.dict("os.environ", {}, clear=True): + with patch("mempalace.hooks_cli.STATE_DIR", tmp_path): + _maybe_auto_ingest() # should not raise + + +def test_maybe_auto_ingest_with_env(tmp_path): + """With MEMPAL_DIR set to a valid directory, spawns subprocess.""" + mempal_dir = tmp_path / "project" + mempal_dir.mkdir() + with patch.dict("os.environ", {"MEMPAL_DIR": str(mempal_dir)}): + with patch("mempalace.hooks_cli.STATE_DIR", tmp_path): + with patch("mempalace.hooks_cli.subprocess.Popen") as mock_popen: + _maybe_auto_ingest() + mock_popen.assert_called_once() + + +def test_maybe_auto_ingest_oserror(tmp_path): + """OSError during subprocess spawn is silenced.""" + mempal_dir = tmp_path / "project" + mempal_dir.mkdir() + with patch.dict("os.environ", {"MEMPAL_DIR": str(mempal_dir)}): + with patch("mempalace.hooks_cli.STATE_DIR", tmp_path): + with patch("mempalace.hooks_cli.subprocess.Popen", side_effect=OSError("fail")): + _maybe_auto_ingest() # should not raise + + +# --- _parse_harness_input --- + + +def test_parse_harness_input_unknown(): + """Unknown harness should sys.exit(1).""" + with pytest.raises(SystemExit) as exc_info: + _parse_harness_input({"session_id": "test"}, "unknown-harness") + assert exc_info.value.code == 1 + + +def test_parse_harness_input_valid(): + result = _parse_harness_input( + {"session_id": "abc-123", "stop_hook_active": True, "transcript_path": "/tmp/t.jsonl"}, + "claude-code", + ) + assert result["session_id"] == "abc-123" + assert result["stop_hook_active"] is True + + +# --- hook_stop with OSError on write --- + + +def test_stop_hook_oserror_on_last_save_read(tmp_path): + """When last_save_file has invalid content, falls back to 0.""" + transcript = tmp_path / "t.jsonl" + _write_transcript( + transcript, + [{"message": {"role": "user", "content": f"msg {i}"}} for i in range(SAVE_INTERVAL)], + ) + # Write invalid content to last save file + (tmp_path / "test_last_save").write_text("not_a_number") + result = _capture_hook_output( + hook_stop, + {"session_id": "test", "stop_hook_active": False, "transcript_path": str(transcript)}, + state_dir=tmp_path, + ) + assert result["decision"] == "block" + + +def test_stop_hook_oserror_on_write(tmp_path): + """When write to last_save_file fails, hook still outputs correctly.""" + transcript = tmp_path / "t.jsonl" + _write_transcript( + transcript, + [{"message": {"role": "user", "content": f"msg {i}"}} for i in range(SAVE_INTERVAL)], + ) + + def bad_write_text(*args, **kwargs): + raise OSError("disk full") + + with patch("mempalace.hooks_cli.STATE_DIR", tmp_path): + with patch.object(Path, "write_text", bad_write_text): + result = _capture_hook_output( + hook_stop, + { + "session_id": "test", + "stop_hook_active": False, + "transcript_path": str(transcript), + }, + state_dir=tmp_path, + ) + assert result["decision"] == "block" + + +# --- hook_precompact with MEMPAL_DIR --- + + +def test_precompact_with_mempal_dir(tmp_path): + """Precompact runs subprocess.run when MEMPAL_DIR is set.""" + mempal_dir = tmp_path / "project" + mempal_dir.mkdir() + with patch.dict("os.environ", {"MEMPAL_DIR": str(mempal_dir)}): + with patch("mempalace.hooks_cli.subprocess.run") as mock_run: + result = _capture_hook_output( + hook_precompact, + {"session_id": "test"}, + state_dir=tmp_path, + ) + assert result["decision"] == "block" + mock_run.assert_called_once() + + +def test_precompact_with_mempal_dir_oserror(tmp_path): + """Precompact handles OSError from subprocess gracefully.""" + mempal_dir = tmp_path / "project" + mempal_dir.mkdir() + with patch.dict("os.environ", {"MEMPAL_DIR": str(mempal_dir)}): + with patch("mempalace.hooks_cli.subprocess.run", side_effect=OSError("fail")): + result = _capture_hook_output( + hook_precompact, + {"session_id": "test"}, + state_dir=tmp_path, + ) + assert result["decision"] == "block" + + +# --- run_hook --- + + +def test_run_hook_dispatches_session_start(tmp_path): + """run_hook reads stdin JSON and dispatches to correct handler.""" + stdin_data = json.dumps({"session_id": "run-test"}) + with patch("sys.stdin", io.StringIO(stdin_data)): + with patch("mempalace.hooks_cli.STATE_DIR", tmp_path): + with patch("mempalace.hooks_cli._output") as mock_output: + run_hook("session-start", "claude-code") + mock_output.assert_called_once_with({}) + + +def test_run_hook_dispatches_stop(tmp_path): + transcript = tmp_path / "t.jsonl" + _write_transcript( + transcript, [{"message": {"role": "user", "content": f"msg {i}"}} for i in range(3)] + ) + stdin_data = json.dumps( + { + "session_id": "run-test", + "stop_hook_active": False, + "transcript_path": str(transcript), + } + ) + with patch("sys.stdin", io.StringIO(stdin_data)): + with patch("mempalace.hooks_cli.STATE_DIR", tmp_path): + with patch("mempalace.hooks_cli._output") as mock_output: + run_hook("stop", "claude-code") + mock_output.assert_called_once_with({}) + + +def test_run_hook_dispatches_precompact(tmp_path): + stdin_data = json.dumps({"session_id": "run-test"}) + with patch("sys.stdin", io.StringIO(stdin_data)): + with patch("mempalace.hooks_cli.STATE_DIR", tmp_path): + with patch("mempalace.hooks_cli._output") as mock_output: + run_hook("precompact", "claude-code") + mock_output.assert_called_once() + call_args = mock_output.call_args[0][0] + assert call_args["decision"] == "block" + + +def test_run_hook_unknown_hook(): + stdin_data = json.dumps({"session_id": "test"}) + with patch("sys.stdin", io.StringIO(stdin_data)): + with pytest.raises(SystemExit) as exc_info: + run_hook("nonexistent", "claude-code") + assert exc_info.value.code == 1 + + +def test_run_hook_invalid_json(tmp_path): + """Invalid stdin JSON should not crash — falls back to empty dict.""" + with patch("sys.stdin", io.StringIO("not valid json")): + with patch("mempalace.hooks_cli.STATE_DIR", tmp_path): + with patch("mempalace.hooks_cli._output") as mock_output: + run_hook("session-start", "claude-code") + mock_output.assert_called_once_with({}) diff --git a/tests/test_instructions_cli.py b/tests/test_instructions_cli.py new file mode 100644 index 0000000..c99ed14 --- /dev/null +++ b/tests/test_instructions_cli.py @@ -0,0 +1,45 @@ +"""Tests for mempalace.instructions_cli — instruction text output.""" + +from unittest.mock import patch + +import pytest + +from mempalace.instructions_cli import AVAILABLE, INSTRUCTIONS_DIR, run_instructions + + +def test_run_instructions_valid_name(capsys): + """Valid name prints the .md file content.""" + name = "init" + expected = (INSTRUCTIONS_DIR / f"{name}.md").read_text() + run_instructions(name) + captured = capsys.readouterr() + assert captured.out.strip() == expected.strip() + + +def test_run_instructions_all_available(capsys): + """Every name in AVAILABLE should succeed without error.""" + for name in AVAILABLE: + run_instructions(name) + out = capsys.readouterr().out + assert len(out) > 0 + + +def test_run_instructions_invalid_name(capsys): + """Invalid name should sys.exit(1) and print error to stderr.""" + with pytest.raises(SystemExit) as exc_info: + run_instructions("nonexistent") + assert exc_info.value.code == 1 + captured = capsys.readouterr() + assert "Unknown instructions: nonexistent" in captured.err + assert "Available:" in captured.err + + +def test_run_instructions_missing_md_file(capsys, tmp_path): + """If the .md file is missing on disk, should sys.exit(1).""" + with patch("mempalace.instructions_cli.INSTRUCTIONS_DIR", tmp_path): + with patch("mempalace.instructions_cli.AVAILABLE", ["fakecmd"]): + with pytest.raises(SystemExit) as exc_info: + run_instructions("fakecmd") + assert exc_info.value.code == 1 + captured = capsys.readouterr() + assert "Instructions file not found" in captured.err diff --git a/tests/test_knowledge_graph_extra.py b/tests/test_knowledge_graph_extra.py new file mode 100644 index 0000000..29605bb --- /dev/null +++ b/tests/test_knowledge_graph_extra.py @@ -0,0 +1,105 @@ +"""Extra knowledge graph tests for seed_from_entity_facts and query_relationship.""" + +import pytest + +from mempalace.knowledge_graph import KnowledgeGraph + + +@pytest.fixture +def kg(tmp_path): + return KnowledgeGraph(db_path=str(tmp_path / "kg.db")) + + +class TestSeedFromEntityFacts: + def test_seed_person_with_partner(self, kg): + facts = { + "alice": { + "full_name": "Alice Smith", + "type": "person", + "gender": "female", + "partner": "bob", + "relationship": "husband", + } + } + kg.seed_from_entity_facts(facts) + stats = kg.stats() + assert stats["entities"] >= 1 + results = kg.query_entity("Alice Smith", direction="outgoing") + predicates = {r["predicate"] for r in results} + assert "married_to" in predicates + assert "is_partner_of" in predicates + + def test_seed_child(self, kg): + facts = { + "max": { + "full_name": "Max", + "type": "person", + "birthday": "2015-04-01", + "parent": "alice", + "relationship": "daughter", + } + } + kg.seed_from_entity_facts(facts) + results = kg.query_entity("Max", direction="outgoing") + predicates = {r["predicate"] for r in results} + assert "child_of" in predicates + assert "is_child_of" in predicates + + def test_seed_sibling(self, kg): + facts = { + "emma": { + "full_name": "Emma", + "type": "person", + "relationship": "brother", + "sibling": "max", + } + } + kg.seed_from_entity_facts(facts) + results = kg.query_entity("Emma", direction="outgoing") + predicates = {r["predicate"] for r in results} + assert "is_sibling_of" in predicates + + def test_seed_dog(self, kg): + facts = { + "rex": { + "full_name": "Rex", + "type": "animal", + "relationship": "dog", + "owner": "alice", + } + } + kg.seed_from_entity_facts(facts) + results = kg.query_entity("Rex", direction="outgoing") + predicates = {r["predicate"] for r in results} + assert "is_pet_of" in predicates + + def test_seed_with_interests(self, kg): + facts = { + "max": { + "full_name": "Max", + "type": "person", + "interests": ["swimming", "chess"], + } + } + kg.seed_from_entity_facts(facts) + results = kg.query_entity("Max", direction="outgoing") + objects = {r["object"] for r in results if r["predicate"] == "loves"} + assert "Swimming" in objects + assert "Chess" in objects + + def test_seed_minimal_facts(self, kg): + """Facts with no relationships just create entities.""" + facts = {"bob": {"full_name": "Bob"}} + kg.seed_from_entity_facts(facts) + stats = kg.stats() + assert stats["entities"] >= 1 + + +class TestQueryRelationshipWithTime: + def test_query_relationship_with_as_of(self, kg): + kg.add_triple("Alice", "works_at", "Acme", valid_from="2020-01-01", valid_to="2024-12-31") + kg.add_triple("Alice", "works_at", "NewCo", valid_from="2025-01-01") + results = kg.query_relationship("works_at", as_of="2023-06-01") + objects = [r["object"] for r in results] + assert "Acme" in objects + assert "NewCo" not in objects diff --git a/tests/test_layers.py b/tests/test_layers.py new file mode 100644 index 0000000..46b60e9 --- /dev/null +++ b/tests/test_layers.py @@ -0,0 +1,719 @@ +"""Tests for mempalace.layers — Layer0, Layer1, Layer2, Layer3, MemoryStack.""" + +import os +from unittest.mock import MagicMock, patch + +from mempalace.layers import Layer0, Layer1, Layer2, Layer3, MemoryStack + + +# ── Layer0 — with identity file ───────────────────────────────────────── + + +def test_layer0_reads_identity_file(tmp_path): + identity_file = tmp_path / "identity.txt" + identity_file.write_text("I am Atlas, a personal AI assistant for Alice.") + layer = Layer0(identity_path=str(identity_file)) + text = layer.render() + assert "Atlas" in text + assert "Alice" in text + + +def test_layer0_caches_text(tmp_path): + identity_file = tmp_path / "identity.txt" + identity_file.write_text("Hello world") + layer = Layer0(identity_path=str(identity_file)) + first = layer.render() + identity_file.write_text("Changed content") + second = layer.render() + assert first == second + assert second == "Hello world" + + +def test_layer0_missing_file_returns_default(tmp_path): + missing = str(tmp_path / "nonexistent.txt") + layer = Layer0(identity_path=missing) + text = layer.render() + assert "No identity configured" in text + assert "identity.txt" in text + + +def test_layer0_token_estimate(tmp_path): + identity_file = tmp_path / "identity.txt" + content = "A" * 400 + identity_file.write_text(content) + layer = Layer0(identity_path=str(identity_file)) + estimate = layer.token_estimate() + assert estimate == 100 + + +def test_layer0_token_estimate_empty(tmp_path): + identity_file = tmp_path / "identity.txt" + identity_file.write_text("") + layer = Layer0(identity_path=str(identity_file)) + assert layer.token_estimate() == 0 + + +def test_layer0_strips_whitespace(tmp_path): + identity_file = tmp_path / "identity.txt" + identity_file.write_text(" Hello world \n\n") + layer = Layer0(identity_path=str(identity_file)) + text = layer.render() + assert text == "Hello world" + + +def test_layer0_default_path(): + layer = Layer0() + expected = os.path.expanduser("~/.mempalace/identity.txt") + assert layer.path == expected + + +# ── Layer1 — mocked chromadb ──────────────────────────────────────────── + + +def _mock_chromadb_for_layer(docs, metas, monkeypatch=None): + """Return a mock PersistentClient whose collection.get returns docs/metas.""" + mock_col = MagicMock() + # First batch returns data, second batch returns empty (end of pagination) + mock_col.get.side_effect = [ + {"documents": docs, "metadatas": metas}, + {"documents": [], "metadatas": []}, + ] + mock_client = MagicMock() + mock_client.get_collection.return_value = mock_col + return mock_client + + +def test_layer1_no_palace(): + """Layer1 returns helpful message when no palace exists.""" + with patch("mempalace.layers.MempalaceConfig") as mock_cfg: + mock_cfg.return_value.palace_path = "/nonexistent/palace" + layer = Layer1(palace_path="/nonexistent/palace") + result = layer.generate() + assert "No palace found" in result or "No memories" in result + + +def test_layer1_generates_essential_story(): + docs = [ + "Important memory about project decisions", + "Key architectural choice for the backend", + ] + metas = [ + {"room": "decisions", "source_file": "meeting.txt", "importance": 5}, + {"room": "architecture", "source_file": "design.txt", "importance": 4}, + ] + mock_client = _mock_chromadb_for_layer(docs, metas) + + with ( + patch("mempalace.layers.MempalaceConfig") as mock_cfg, + patch("mempalace.layers.chromadb.PersistentClient", return_value=mock_client), + ): + mock_cfg.return_value.palace_path = "/fake" + layer = Layer1(palace_path="/fake") + result = layer.generate() + + assert "ESSENTIAL STORY" in result + assert "project decisions" in result + + +def test_layer1_empty_palace(): + mock_col = MagicMock() + mock_col.get.return_value = {"documents": [], "metadatas": []} + mock_client = MagicMock() + mock_client.get_collection.return_value = mock_col + + with ( + patch("mempalace.layers.MempalaceConfig") as mock_cfg, + patch("mempalace.layers.chromadb.PersistentClient", return_value=mock_client), + ): + mock_cfg.return_value.palace_path = "/fake" + layer = Layer1(palace_path="/fake") + result = layer.generate() + + assert "No memories" in result + + +def test_layer1_with_wing_filter(): + docs = ["Memory about project X"] + metas = [{"room": "general", "source_file": "x.txt", "importance": 3}] + mock_client = _mock_chromadb_for_layer(docs, metas) + + with ( + patch("mempalace.layers.MempalaceConfig") as mock_cfg, + patch("mempalace.layers.chromadb.PersistentClient", return_value=mock_client), + ): + mock_cfg.return_value.palace_path = "/fake" + layer = Layer1(palace_path="/fake", wing="project_x") + result = layer.generate() + + assert "ESSENTIAL STORY" in result + # Verify wing filter was passed + call_kwargs = mock_client.get_collection.return_value.get.call_args_list[0][1] + assert call_kwargs.get("where") == {"wing": "project_x"} + + +def test_layer1_truncates_long_snippets(): + docs = ["A" * 300] + metas = [{"room": "general", "source_file": "long.txt"}] + mock_client = _mock_chromadb_for_layer(docs, metas) + + with ( + patch("mempalace.layers.MempalaceConfig") as mock_cfg, + patch("mempalace.layers.chromadb.PersistentClient", return_value=mock_client), + ): + mock_cfg.return_value.palace_path = "/fake" + layer = Layer1(palace_path="/fake") + result = layer.generate() + + assert "..." in result + + +def test_layer1_respects_max_chars(): + """L1 stops adding entries once MAX_CHARS is reached.""" + docs = [f"Memory number {i} with substantial content padding here" for i in range(30)] + metas = [{"room": "general", "source_file": f"f{i}.txt", "importance": 5} for i in range(30)] + mock_client = _mock_chromadb_for_layer(docs, metas) + + with ( + patch("mempalace.layers.MempalaceConfig") as mock_cfg, + patch("mempalace.layers.chromadb.PersistentClient", return_value=mock_client), + ): + mock_cfg.return_value.palace_path = "/fake" + layer = Layer1(palace_path="/fake") + layer.MAX_CHARS = 200 # Very low cap to trigger truncation + result = layer.generate() + + assert "more in L3 search" in result + + +def test_layer1_importance_from_various_keys(): + """Layer1 tries importance, emotional_weight, weight keys.""" + docs = ["mem1", "mem2", "mem3"] + metas = [ + {"room": "r", "emotional_weight": 5}, + {"room": "r", "weight": 1}, + {"room": "r"}, # no weight key, defaults to 3 + ] + mock_client = _mock_chromadb_for_layer(docs, metas) + + with ( + patch("mempalace.layers.MempalaceConfig") as mock_cfg, + patch("mempalace.layers.chromadb.PersistentClient", return_value=mock_client), + ): + mock_cfg.return_value.palace_path = "/fake" + layer = Layer1(palace_path="/fake") + result = layer.generate() + + assert "ESSENTIAL STORY" in result + + +def test_layer1_batch_exception_breaks(): + """If col.get raises on a batch, loop breaks gracefully.""" + mock_col = MagicMock() + mock_col.get.side_effect = [ + {"documents": ["doc1"], "metadatas": [{"room": "r"}]}, + RuntimeError("batch error"), + ] + mock_client = MagicMock() + mock_client.get_collection.return_value = mock_col + + with ( + patch("mempalace.layers.MempalaceConfig") as mock_cfg, + patch("mempalace.layers.chromadb.PersistentClient", return_value=mock_client), + ): + mock_cfg.return_value.palace_path = "/fake" + layer = Layer1(palace_path="/fake") + result = layer.generate() + + assert "ESSENTIAL STORY" in result + + +# ── Layer2 — mocked chromadb ──────────────────────────────────────────── + + +def test_layer2_no_palace(): + with patch("mempalace.layers.MempalaceConfig") as mock_cfg: + mock_cfg.return_value.palace_path = "/nonexistent/palace" + layer = Layer2(palace_path="/nonexistent/palace") + result = layer.retrieve(wing="test") + assert "No palace found" in result + + +def test_layer2_retrieve_with_wing(): + mock_col = MagicMock() + mock_col.get.return_value = { + "documents": ["Some memory about the project"], + "metadatas": [{"room": "backend", "source_file": "notes.txt"}], + } + mock_client = MagicMock() + mock_client.get_collection.return_value = mock_col + + with ( + patch("mempalace.layers.MempalaceConfig") as mock_cfg, + patch("mempalace.layers.chromadb.PersistentClient", return_value=mock_client), + ): + mock_cfg.return_value.palace_path = "/fake" + layer = Layer2(palace_path="/fake") + result = layer.retrieve(wing="project") + + assert "ON-DEMAND" in result + assert "memory about the project" in result + + +def test_layer2_retrieve_with_room(): + mock_col = MagicMock() + mock_col.get.return_value = { + "documents": ["Backend architecture notes"], + "metadatas": [{"room": "architecture", "source_file": "arch.txt"}], + } + mock_client = MagicMock() + mock_client.get_collection.return_value = mock_col + + with ( + patch("mempalace.layers.MempalaceConfig") as mock_cfg, + patch("mempalace.layers.chromadb.PersistentClient", return_value=mock_client), + ): + mock_cfg.return_value.palace_path = "/fake" + layer = Layer2(palace_path="/fake") + result = layer.retrieve(room="architecture") + + assert "ON-DEMAND" in result + + +def test_layer2_retrieve_wing_and_room(): + mock_col = MagicMock() + mock_col.get.return_value = { + "documents": ["Filtered result"], + "metadatas": [{"room": "backend", "source_file": "x.txt"}], + } + mock_client = MagicMock() + mock_client.get_collection.return_value = mock_col + + with ( + patch("mempalace.layers.MempalaceConfig") as mock_cfg, + patch("mempalace.layers.chromadb.PersistentClient", return_value=mock_client), + ): + mock_cfg.return_value.palace_path = "/fake" + layer = Layer2(palace_path="/fake") + result = layer.retrieve(wing="proj", room="backend") + + assert "ON-DEMAND" in result + call_kwargs = mock_col.get.call_args[1] + assert "$and" in call_kwargs.get("where", {}) + + +def test_layer2_retrieve_empty(): + mock_col = MagicMock() + mock_col.get.return_value = {"documents": [], "metadatas": []} + mock_client = MagicMock() + mock_client.get_collection.return_value = mock_col + + with ( + patch("mempalace.layers.MempalaceConfig") as mock_cfg, + patch("mempalace.layers.chromadb.PersistentClient", return_value=mock_client), + ): + mock_cfg.return_value.palace_path = "/fake" + layer = Layer2(palace_path="/fake") + result = layer.retrieve(wing="missing") + + assert "No drawers found" in result + + +def test_layer2_retrieve_no_filter(): + mock_col = MagicMock() + mock_col.get.return_value = {"documents": [], "metadatas": []} + mock_client = MagicMock() + mock_client.get_collection.return_value = mock_col + + with ( + patch("mempalace.layers.MempalaceConfig") as mock_cfg, + patch("mempalace.layers.chromadb.PersistentClient", return_value=mock_client), + ): + mock_cfg.return_value.palace_path = "/fake" + layer = Layer2(palace_path="/fake") + layer.retrieve() + + # No where filter should be passed + call_kwargs = mock_col.get.call_args[1] + assert "where" not in call_kwargs + + +def test_layer2_retrieve_error(): + mock_col = MagicMock() + mock_col.get.side_effect = RuntimeError("db error") + mock_client = MagicMock() + mock_client.get_collection.return_value = mock_col + + with ( + patch("mempalace.layers.MempalaceConfig") as mock_cfg, + patch("mempalace.layers.chromadb.PersistentClient", return_value=mock_client), + ): + mock_cfg.return_value.palace_path = "/fake" + layer = Layer2(palace_path="/fake") + result = layer.retrieve(wing="test") + + assert "Retrieval error" in result + + +def test_layer2_truncates_long_snippets(): + mock_col = MagicMock() + mock_col.get.return_value = { + "documents": ["B" * 400], + "metadatas": [{"room": "r", "source_file": "s.txt"}], + } + mock_client = MagicMock() + mock_client.get_collection.return_value = mock_col + + with ( + patch("mempalace.layers.MempalaceConfig") as mock_cfg, + patch("mempalace.layers.chromadb.PersistentClient", return_value=mock_client), + ): + mock_cfg.return_value.palace_path = "/fake" + layer = Layer2(palace_path="/fake") + result = layer.retrieve(wing="test") + + assert "..." in result + + +# ── Layer3 — mocked chromadb ──────────────────────────────────────────── + + +def _mock_query_results(docs, metas, dists): + return { + "documents": [docs], + "metadatas": [metas], + "distances": [dists], + } + + +def test_layer3_no_palace(): + with patch("mempalace.layers.MempalaceConfig") as mock_cfg: + mock_cfg.return_value.palace_path = "/nonexistent/palace" + layer = Layer3(palace_path="/nonexistent/palace") + result = layer.search("test query") + assert "No palace found" in result + + +def test_layer3_search_raw_no_palace(): + with patch("mempalace.layers.MempalaceConfig") as mock_cfg: + mock_cfg.return_value.palace_path = "/nonexistent/palace" + layer = Layer3(palace_path="/nonexistent/palace") + result = layer.search_raw("test query") + assert result == [] + + +def test_layer3_search_with_results(): + mock_col = MagicMock() + mock_col.query.return_value = _mock_query_results( + ["Found this important memory"], + [{"wing": "project", "room": "backend", "source_file": "notes.txt"}], + [0.2], + ) + mock_client = MagicMock() + mock_client.get_collection.return_value = mock_col + + with ( + patch("mempalace.layers.MempalaceConfig") as mock_cfg, + patch("mempalace.layers.chromadb.PersistentClient", return_value=mock_client), + ): + mock_cfg.return_value.palace_path = "/fake" + layer = Layer3(palace_path="/fake") + result = layer.search("important") + + assert "SEARCH RESULTS" in result + assert "important memory" in result + assert "sim=0.8" in result + + +def test_layer3_search_no_results(): + mock_col = MagicMock() + mock_col.query.return_value = _mock_query_results([], [], []) + mock_client = MagicMock() + mock_client.get_collection.return_value = mock_col + + with ( + patch("mempalace.layers.MempalaceConfig") as mock_cfg, + patch("mempalace.layers.chromadb.PersistentClient", return_value=mock_client), + ): + mock_cfg.return_value.palace_path = "/fake" + layer = Layer3(palace_path="/fake") + result = layer.search("nothing") + + assert "No results found" in result + + +def test_layer3_search_with_wing_filter(): + mock_col = MagicMock() + mock_col.query.return_value = _mock_query_results( + ["result"], + [{"wing": "proj", "room": "r"}], + [0.1], + ) + mock_client = MagicMock() + mock_client.get_collection.return_value = mock_col + + with ( + patch("mempalace.layers.MempalaceConfig") as mock_cfg, + patch("mempalace.layers.chromadb.PersistentClient", return_value=mock_client), + ): + mock_cfg.return_value.palace_path = "/fake" + layer = Layer3(palace_path="/fake") + layer.search("q", wing="proj") + + call_kwargs = mock_col.query.call_args[1] + assert call_kwargs["where"] == {"wing": "proj"} + + +def test_layer3_search_with_room_filter(): + mock_col = MagicMock() + mock_col.query.return_value = _mock_query_results( + ["result"], + [{"wing": "w", "room": "backend"}], + [0.1], + ) + mock_client = MagicMock() + mock_client.get_collection.return_value = mock_col + + with ( + patch("mempalace.layers.MempalaceConfig") as mock_cfg, + patch("mempalace.layers.chromadb.PersistentClient", return_value=mock_client), + ): + mock_cfg.return_value.palace_path = "/fake" + layer = Layer3(palace_path="/fake") + layer.search("q", room="backend") + + call_kwargs = mock_col.query.call_args[1] + assert call_kwargs["where"] == {"room": "backend"} + + +def test_layer3_search_with_wing_and_room(): + mock_col = MagicMock() + mock_col.query.return_value = _mock_query_results( + ["result"], + [{"wing": "proj", "room": "backend"}], + [0.1], + ) + mock_client = MagicMock() + mock_client.get_collection.return_value = mock_col + + with ( + patch("mempalace.layers.MempalaceConfig") as mock_cfg, + patch("mempalace.layers.chromadb.PersistentClient", return_value=mock_client), + ): + mock_cfg.return_value.palace_path = "/fake" + layer = Layer3(palace_path="/fake") + layer.search("q", wing="proj", room="backend") + + call_kwargs = mock_col.query.call_args[1] + assert "$and" in call_kwargs["where"] + + +def test_layer3_search_error(): + mock_col = MagicMock() + mock_col.query.side_effect = RuntimeError("search failed") + mock_client = MagicMock() + mock_client.get_collection.return_value = mock_col + + with ( + patch("mempalace.layers.MempalaceConfig") as mock_cfg, + patch("mempalace.layers.chromadb.PersistentClient", return_value=mock_client), + ): + mock_cfg.return_value.palace_path = "/fake" + layer = Layer3(palace_path="/fake") + result = layer.search("q") + + assert "Search error" in result + + +def test_layer3_search_truncates_long_docs(): + mock_col = MagicMock() + mock_col.query.return_value = _mock_query_results( + ["C" * 400], + [{"wing": "w", "room": "r", "source_file": "s.txt"}], + [0.1], + ) + mock_client = MagicMock() + mock_client.get_collection.return_value = mock_col + + with ( + patch("mempalace.layers.MempalaceConfig") as mock_cfg, + patch("mempalace.layers.chromadb.PersistentClient", return_value=mock_client), + ): + mock_cfg.return_value.palace_path = "/fake" + layer = Layer3(palace_path="/fake") + result = layer.search("q") + + assert "..." in result + + +def test_layer3_search_raw_returns_dicts(): + mock_col = MagicMock() + mock_col.query.return_value = _mock_query_results( + ["doc text"], + [{"wing": "proj", "room": "backend", "source_file": "f.txt"}], + [0.3], + ) + mock_client = MagicMock() + mock_client.get_collection.return_value = mock_col + + with ( + patch("mempalace.layers.MempalaceConfig") as mock_cfg, + patch("mempalace.layers.chromadb.PersistentClient", return_value=mock_client), + ): + mock_cfg.return_value.palace_path = "/fake" + layer = Layer3(palace_path="/fake") + hits = layer.search_raw("q") + + assert len(hits) == 1 + assert hits[0]["text"] == "doc text" + assert hits[0]["wing"] == "proj" + assert hits[0]["similarity"] == 0.7 + assert "metadata" in hits[0] + + +def test_layer3_search_raw_with_filters(): + mock_col = MagicMock() + mock_col.query.return_value = _mock_query_results( + ["doc"], + [{"wing": "w", "room": "r"}], + [0.1], + ) + mock_client = MagicMock() + mock_client.get_collection.return_value = mock_col + + with ( + patch("mempalace.layers.MempalaceConfig") as mock_cfg, + patch("mempalace.layers.chromadb.PersistentClient", return_value=mock_client), + ): + mock_cfg.return_value.palace_path = "/fake" + layer = Layer3(palace_path="/fake") + layer.search_raw("q", wing="w", room="r") + + call_kwargs = mock_col.query.call_args[1] + assert "$and" in call_kwargs["where"] + + +def test_layer3_search_raw_error(): + mock_col = MagicMock() + mock_col.query.side_effect = RuntimeError("fail") + mock_client = MagicMock() + mock_client.get_collection.return_value = mock_col + + with ( + patch("mempalace.layers.MempalaceConfig") as mock_cfg, + patch("mempalace.layers.chromadb.PersistentClient", return_value=mock_client), + ): + mock_cfg.return_value.palace_path = "/fake" + layer = Layer3(palace_path="/fake") + result = layer.search_raw("q") + + assert result == [] + + +# ── MemoryStack ───────────────────────────────────────────────────────── + + +def test_memory_stack_wake_up(tmp_path): + identity_file = tmp_path / "identity.txt" + identity_file.write_text("I am Atlas.") + + with patch("mempalace.layers.MempalaceConfig") as mock_cfg: + mock_cfg.return_value.palace_path = "/nonexistent" + stack = MemoryStack( + palace_path="/nonexistent", + identity_path=str(identity_file), + ) + result = stack.wake_up() + + assert "Atlas" in result + # L1 will say no palace found + assert "No palace" in result or "No memories" in result + + +def test_memory_stack_wake_up_with_wing(tmp_path): + identity_file = tmp_path / "identity.txt" + identity_file.write_text("I am Atlas.") + + with patch("mempalace.layers.MempalaceConfig") as mock_cfg: + mock_cfg.return_value.palace_path = "/nonexistent" + stack = MemoryStack( + palace_path="/nonexistent", + identity_path=str(identity_file), + ) + result = stack.wake_up(wing="my_project") + + assert stack.l1.wing == "my_project" + assert "Atlas" in result + + +def test_memory_stack_recall(tmp_path): + identity_file = tmp_path / "identity.txt" + identity_file.write_text("I am Atlas.") + + with patch("mempalace.layers.MempalaceConfig") as mock_cfg: + mock_cfg.return_value.palace_path = "/nonexistent" + stack = MemoryStack( + palace_path="/nonexistent", + identity_path=str(identity_file), + ) + result = stack.recall(wing="test") + + assert "No palace found" in result + + +def test_memory_stack_search(tmp_path): + identity_file = tmp_path / "identity.txt" + identity_file.write_text("I am Atlas.") + + with patch("mempalace.layers.MempalaceConfig") as mock_cfg: + mock_cfg.return_value.palace_path = "/nonexistent" + stack = MemoryStack( + palace_path="/nonexistent", + identity_path=str(identity_file), + ) + result = stack.search("test query") + + assert "No palace found" in result + + +def test_memory_stack_status(tmp_path): + identity_file = tmp_path / "identity.txt" + identity_file.write_text("I am Atlas.") + + with patch("mempalace.layers.MempalaceConfig") as mock_cfg: + mock_cfg.return_value.palace_path = "/nonexistent" + stack = MemoryStack( + palace_path="/nonexistent", + identity_path=str(identity_file), + ) + result = stack.status() + + assert result["palace_path"] == "/nonexistent" + assert result["total_drawers"] == 0 + assert "L0_identity" in result + assert "L1_essential" in result + assert "L2_on_demand" in result + assert "L3_deep_search" in result + + +def test_memory_stack_status_with_palace(tmp_path): + identity_file = tmp_path / "identity.txt" + identity_file.write_text("I am Atlas.") + + mock_col = MagicMock() + mock_col.count.return_value = 42 + mock_client = MagicMock() + mock_client.get_collection.return_value = mock_col + + with ( + patch("mempalace.layers.MempalaceConfig") as mock_cfg, + patch("mempalace.layers.chromadb.PersistentClient", return_value=mock_client), + ): + mock_cfg.return_value.palace_path = "/fake" + stack = MemoryStack( + palace_path="/fake", + identity_path=str(identity_file), + ) + result = stack.status() + + assert result["total_drawers"] == 42 + assert result["L0_identity"]["exists"] is True diff --git a/tests/test_mcp_server.py b/tests/test_mcp_server.py index 24258a9..96fe80c 100644 --- a/tests/test_mcp_server.py +++ b/tests/test_mcp_server.py @@ -42,6 +42,50 @@ class TestHandleRequest: assert resp["result"]["serverInfo"]["name"] == "mempalace" assert resp["id"] == 1 + def test_initialize_negotiates_client_version(self): + from mempalace.mcp_server import handle_request + + resp = handle_request( + { + "method": "initialize", + "id": 1, + "params": {"protocolVersion": "2025-11-25"}, + } + ) + assert resp["result"]["protocolVersion"] == "2025-11-25" + + def test_initialize_negotiates_older_supported_version(self): + from mempalace.mcp_server import handle_request + + resp = handle_request( + { + "method": "initialize", + "id": 1, + "params": {"protocolVersion": "2025-03-26"}, + } + ) + assert resp["result"]["protocolVersion"] == "2025-03-26" + + def test_initialize_unknown_version_falls_back_to_latest(self): + from mempalace.mcp_server import handle_request + + resp = handle_request( + { + "method": "initialize", + "id": 1, + "params": {"protocolVersion": "9999-12-31"}, + } + ) + from mempalace.mcp_server import SUPPORTED_PROTOCOL_VERSIONS + + assert resp["result"]["protocolVersion"] == SUPPORTED_PROTOCOL_VERSIONS[0] + + def test_initialize_missing_version_uses_oldest(self): + from mempalace.mcp_server import handle_request, SUPPORTED_PROTOCOL_VERSIONS + + resp = handle_request({"method": "initialize", "id": 1, "params": {}}) + assert resp["result"]["protocolVersion"] == SUPPORTED_PROTOCOL_VERSIONS[-1] + def test_notifications_initialized_returns_none(self): from mempalace.mcp_server import handle_request @@ -59,6 +103,23 @@ class TestHandleRequest: assert "mempalace_add_drawer" in names assert "mempalace_kg_add" in names + def test_null_arguments_does_not_hang(self, monkeypatch, config, palace_path, seeded_kg): + """Sending arguments: null should return a result, not hang (#394).""" + _patch_mcp_server(monkeypatch, config, seeded_kg) + from mempalace.mcp_server import handle_request + + _client, _col = _get_collection(palace_path, create=True) + del _client + resp = handle_request( + { + "method": "tools/call", + "id": 10, + "params": {"name": "mempalace_status", "arguments": None}, + } + ) + assert "error" not in resp + assert resp["result"] is not None + def test_unknown_tool(self): from mempalace.mcp_server import handle_request diff --git a/tests/test_miner.py b/tests/test_miner.py index 337e949..c013d7c 100644 --- a/tests/test_miner.py +++ b/tests/test_miner.py @@ -7,6 +7,7 @@ import chromadb import yaml from mempalace.miner import mine, scan_project +from mempalace.palace import file_already_mined def write_file(path: Path, content: str): @@ -47,7 +48,7 @@ def test_project_mining(): col = client.get_collection("mempalace_drawers") assert col.count() > 0 finally: - shutil.rmtree(tmpdir) + shutil.rmtree(tmpdir, ignore_errors=True) def test_scan_project_respects_gitignore(): @@ -206,3 +207,56 @@ def test_scan_project_skip_dirs_still_apply_without_override(): assert scanned_files(project_root, respect_gitignore=False) == ["main.py"] finally: shutil.rmtree(tmpdir) + + +def test_file_already_mined_check_mtime(): + tmpdir = tempfile.mkdtemp() + try: + palace_path = os.path.join(tmpdir, "palace") + os.makedirs(palace_path) + client = chromadb.PersistentClient(path=palace_path) + col = client.get_or_create_collection("mempalace_drawers") + + test_file = os.path.join(tmpdir, "test.txt") + with open(test_file, "w") as f: + f.write("hello world") + + mtime = os.path.getmtime(test_file) + + # Not mined yet + assert file_already_mined(col, test_file) is False + assert file_already_mined(col, test_file, check_mtime=True) is False + + # Add it with mtime + col.add( + ids=["d1"], + documents=["hello world"], + metadatas=[{"source_file": test_file, "source_mtime": str(mtime)}], + ) + + # Already mined (no mtime check) + assert file_already_mined(col, test_file) is True + # Already mined (mtime matches) + assert file_already_mined(col, test_file, check_mtime=True) is True + + # Modify file and force a different mtime (Windows has low mtime resolution) + with open(test_file, "w") as f: + f.write("modified content") + os.utime(test_file, (mtime + 10, mtime + 10)) + + # Still mined without mtime check + assert file_already_mined(col, test_file) is True + # Needs re-mining with mtime check + assert file_already_mined(col, test_file, check_mtime=True) is False + + # Record with no mtime stored should return False for check_mtime + col.add( + ids=["d2"], + documents=["other"], + metadatas=[{"source_file": "/fake/no_mtime.txt"}], + ) + assert file_already_mined(col, "/fake/no_mtime.txt", check_mtime=True) is False + finally: + # Release ChromaDB file handles before cleanup (required on Windows) + del col, client + shutil.rmtree(tmpdir, ignore_errors=True) diff --git a/tests/test_normalize.py b/tests/test_normalize.py index c304c9d..959668f 100644 --- a/tests/test_normalize.py +++ b/tests/test_normalize.py @@ -1,31 +1,511 @@ -import os import json -import tempfile -from mempalace.normalize import normalize +from unittest.mock import patch + +from mempalace.normalize import ( + _extract_content, + _messages_to_transcript, + _try_chatgpt_json, + _try_claude_ai_json, + _try_claude_code_jsonl, + _try_codex_jsonl, + _try_normalize_json, + _try_slack_json, + normalize, +) -def test_plain_text(): - f = tempfile.NamedTemporaryFile(mode="w", suffix=".txt", delete=False) - f.write("Hello world\nSecond line\n") - f.close() - result = normalize(f.name) +# ── normalize() top-level ────────────────────────────────────────────── + + +def test_plain_text(tmp_path): + f = tmp_path / "plain.txt" + f.write_text("Hello world\nSecond line\n") + result = normalize(str(f)) assert "Hello world" in result - os.unlink(f.name) -def test_claude_json(): +def test_claude_json(tmp_path): data = [{"role": "user", "content": "Hi"}, {"role": "assistant", "content": "Hello"}] - f = tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) - json.dump(data, f) - f.close() - result = normalize(f.name) + f = tmp_path / "claude.json" + f.write_text(json.dumps(data)) + result = normalize(str(f)) assert "Hi" in result - os.unlink(f.name) -def test_empty(): - f = tempfile.NamedTemporaryFile(mode="w", suffix=".txt", delete=False) - f.close() - result = normalize(f.name) +def test_empty(tmp_path): + f = tmp_path / "empty.txt" + f.write_text("") + result = normalize(str(f)) assert result.strip() == "" - os.unlink(f.name) + + +def test_normalize_io_error(): + """normalize raises IOError for unreadable file.""" + try: + normalize("/nonexistent/path/file.txt") + assert False, "Should have raised" + except IOError as e: + assert "Could not read" in str(e) + + +def test_normalize_already_has_markers(tmp_path): + """Files with >= 3 '>' lines pass through unchanged.""" + content = "> question 1\nanswer 1\n> question 2\nanswer 2\n> question 3\nanswer 3\n" + f = tmp_path / "markers.txt" + f.write_text(content) + result = normalize(str(f)) + assert result == content + + +def test_normalize_json_content_detected_by_brace(tmp_path): + """A .txt file starting with [ triggers JSON parsing.""" + data = [{"role": "user", "content": "Hey"}, {"role": "assistant", "content": "Hi there"}] + f = tmp_path / "chat.txt" + f.write_text(json.dumps(data)) + result = normalize(str(f)) + assert "Hey" in result + + +def test_normalize_whitespace_only(tmp_path): + f = tmp_path / "ws.txt" + f.write_text(" \n \n ") + result = normalize(str(f)) + assert result.strip() == "" + + +# ── _extract_content ─────────────────────────────────────────────────── + + +def test_extract_content_string(): + assert _extract_content("hello") == "hello" + + +def test_extract_content_list_of_strings(): + assert _extract_content(["hello", "world"]) == "hello world" + + +def test_extract_content_list_of_blocks(): + blocks = [{"type": "text", "text": "hello"}, {"type": "image", "url": "x"}] + assert _extract_content(blocks) == "hello" + + +def test_extract_content_dict(): + assert _extract_content({"text": "hello"}) == "hello" + + +def test_extract_content_none(): + assert _extract_content(None) == "" + + +def test_extract_content_mixed_list(): + blocks = ["plain", {"type": "text", "text": "block"}] + assert _extract_content(blocks) == "plain block" + + +# ── _try_claude_code_jsonl ───────────────────────────────────────────── + + +def test_claude_code_jsonl_valid(): + lines = [ + json.dumps({"type": "human", "message": {"content": "What is X?"}}), + json.dumps({"type": "assistant", "message": {"content": "X is Y."}}), + ] + result = _try_claude_code_jsonl("\n".join(lines)) + assert result is not None + assert "> What is X?" in result + assert "X is Y." in result + + +def test_claude_code_jsonl_user_type(): + lines = [ + json.dumps({"type": "user", "message": {"content": "Q"}}), + json.dumps({"type": "assistant", "message": {"content": "A"}}), + ] + result = _try_claude_code_jsonl("\n".join(lines)) + assert result is not None + assert "> Q" in result + + +def test_claude_code_jsonl_too_few_messages(): + lines = [json.dumps({"type": "human", "message": {"content": "only one"}})] + result = _try_claude_code_jsonl("\n".join(lines)) + assert result is None + + +def test_claude_code_jsonl_invalid_json_lines(): + lines = [ + "not json", + json.dumps({"type": "human", "message": {"content": "Q"}}), + json.dumps({"type": "assistant", "message": {"content": "A"}}), + ] + result = _try_claude_code_jsonl("\n".join(lines)) + assert result is not None + + +def test_claude_code_jsonl_non_dict_entries(): + lines = [ + json.dumps([1, 2, 3]), + json.dumps({"type": "human", "message": {"content": "Q"}}), + json.dumps({"type": "assistant", "message": {"content": "A"}}), + ] + result = _try_claude_code_jsonl("\n".join(lines)) + assert result is not None + + +# ── _try_codex_jsonl ─────────────────────────────────────────────────── + + +def test_codex_jsonl_valid(): + lines = [ + json.dumps({"type": "session_meta", "payload": {}}), + json.dumps({"type": "event_msg", "payload": {"type": "user_message", "message": "Q"}}), + json.dumps({"type": "event_msg", "payload": {"type": "agent_message", "message": "A"}}), + ] + result = _try_codex_jsonl("\n".join(lines)) + assert result is not None + assert "> Q" in result + + +def test_codex_jsonl_no_session_meta(): + """Without session_meta, codex parser returns None.""" + lines = [ + json.dumps({"type": "event_msg", "payload": {"type": "user_message", "message": "Q"}}), + json.dumps({"type": "event_msg", "payload": {"type": "agent_message", "message": "A"}}), + ] + result = _try_codex_jsonl("\n".join(lines)) + assert result is None + + +def test_codex_jsonl_skips_non_event_msg(): + lines = [ + json.dumps({"type": "session_meta"}), + json.dumps({"type": "response_item", "payload": {"type": "user_message", "message": "X"}}), + json.dumps({"type": "event_msg", "payload": {"type": "user_message", "message": "Q"}}), + json.dumps({"type": "event_msg", "payload": {"type": "agent_message", "message": "A"}}), + ] + result = _try_codex_jsonl("\n".join(lines)) + assert result is not None + assert "X" not in result.split("> Q")[0] + + +def test_codex_jsonl_non_string_message(): + lines = [ + json.dumps({"type": "session_meta"}), + json.dumps({"type": "event_msg", "payload": {"type": "user_message", "message": 123}}), + json.dumps({"type": "event_msg", "payload": {"type": "user_message", "message": "Q"}}), + json.dumps({"type": "event_msg", "payload": {"type": "agent_message", "message": "A"}}), + ] + result = _try_codex_jsonl("\n".join(lines)) + assert result is not None + + +def test_codex_jsonl_empty_text_skipped(): + lines = [ + json.dumps({"type": "session_meta"}), + json.dumps({"type": "event_msg", "payload": {"type": "user_message", "message": " "}}), + json.dumps({"type": "event_msg", "payload": {"type": "user_message", "message": "Q"}}), + json.dumps({"type": "event_msg", "payload": {"type": "agent_message", "message": "A"}}), + ] + result = _try_codex_jsonl("\n".join(lines)) + assert result is not None + + +def test_codex_jsonl_payload_not_dict(): + lines = [ + json.dumps({"type": "session_meta"}), + json.dumps({"type": "event_msg", "payload": "not a dict"}), + json.dumps({"type": "event_msg", "payload": {"type": "user_message", "message": "Q"}}), + json.dumps({"type": "event_msg", "payload": {"type": "agent_message", "message": "A"}}), + ] + result = _try_codex_jsonl("\n".join(lines)) + assert result is not None + + +# ── _try_claude_ai_json ─────────────────────────────────────────────── + + +def test_claude_ai_flat_messages(): + data = [ + {"role": "user", "content": "Hello"}, + {"role": "assistant", "content": "Hi there"}, + ] + result = _try_claude_ai_json(data) + assert result is not None + assert "> Hello" in result + + +def test_claude_ai_dict_with_messages_key(): + data = { + "messages": [ + {"role": "user", "content": "Hello"}, + {"role": "assistant", "content": "Hi"}, + ] + } + result = _try_claude_ai_json(data) + assert result is not None + + +def test_claude_ai_privacy_export(): + data = [ + { + "chat_messages": [ + {"role": "human", "content": "Q1"}, + {"role": "ai", "content": "A1"}, + ] + } + ] + result = _try_claude_ai_json(data) + assert result is not None + assert "> Q1" in result + + +def test_claude_ai_not_a_list(): + result = _try_claude_ai_json("not a list") + assert result is None + + +def test_claude_ai_too_few_messages(): + data = [{"role": "user", "content": "Hello"}] + result = _try_claude_ai_json(data) + assert result is None + + +def test_claude_ai_dict_with_chat_messages_key(): + data = { + "chat_messages": [ + {"role": "user", "content": "Hello"}, + {"role": "assistant", "content": "World"}, + ] + } + result = _try_claude_ai_json(data) + assert result is not None + + +def test_claude_ai_privacy_export_non_dict_items(): + """Non-dict items in privacy export are skipped.""" + data = [ + { + "chat_messages": [ + "not a dict", + {"role": "user", "content": "Q"}, + {"role": "assistant", "content": "A"}, + ] + }, + "not a convo", + ] + result = _try_claude_ai_json(data) + assert result is not None + + +# ── _try_chatgpt_json ───────────────────────────────────────────────── + + +def test_chatgpt_json_valid(): + data = { + "mapping": { + "root": { + "parent": None, + "message": None, + "children": ["msg1"], + }, + "msg1": { + "parent": "root", + "message": { + "author": {"role": "user"}, + "content": {"parts": ["Hello ChatGPT"]}, + }, + "children": ["msg2"], + }, + "msg2": { + "parent": "msg1", + "message": { + "author": {"role": "assistant"}, + "content": {"parts": ["Hello! How can I help?"]}, + }, + "children": [], + }, + } + } + result = _try_chatgpt_json(data) + assert result is not None + assert "> Hello ChatGPT" in result + + +def test_chatgpt_json_no_mapping(): + result = _try_chatgpt_json({"data": []}) + assert result is None + + +def test_chatgpt_json_not_dict(): + result = _try_chatgpt_json([1, 2, 3]) + assert result is None + + +def test_chatgpt_json_fallback_root(): + """Root node has a message (no synthetic root), uses fallback.""" + data = { + "mapping": { + "root": { + "parent": None, + "message": { + "author": {"role": "system"}, + "content": {"parts": ["system prompt"]}, + }, + "children": ["msg1"], + }, + "msg1": { + "parent": "root", + "message": { + "author": {"role": "user"}, + "content": {"parts": ["Hello"]}, + }, + "children": ["msg2"], + }, + "msg2": { + "parent": "msg1", + "message": { + "author": {"role": "assistant"}, + "content": {"parts": ["Hi there"]}, + }, + "children": [], + }, + } + } + result = _try_chatgpt_json(data) + assert result is not None + + +def test_chatgpt_json_too_few_messages(): + data = { + "mapping": { + "root": { + "parent": None, + "message": None, + "children": ["msg1"], + }, + "msg1": { + "parent": "root", + "message": { + "author": {"role": "user"}, + "content": {"parts": ["Only one"]}, + }, + "children": [], + }, + } + } + result = _try_chatgpt_json(data) + assert result is None + + +# ── _try_slack_json ──────────────────────────────────────────────────── + + +def test_slack_json_valid(): + data = [ + {"type": "message", "user": "U1", "text": "Hello"}, + {"type": "message", "user": "U2", "text": "Hi there"}, + ] + result = _try_slack_json(data) + assert result is not None + assert "Hello" in result + + +def test_slack_json_not_a_list(): + result = _try_slack_json({"type": "message"}) + assert result is None + + +def test_slack_json_too_few_messages(): + data = [{"type": "message", "user": "U1", "text": "Hello"}] + result = _try_slack_json(data) + assert result is None + + +def test_slack_json_skips_non_message_types(): + data = [ + {"type": "channel_join", "user": "U1", "text": "joined"}, + {"type": "message", "user": "U1", "text": "Hello"}, + {"type": "message", "user": "U2", "text": "Hi"}, + ] + result = _try_slack_json(data) + assert result is not None + + +def test_slack_json_three_users(): + """Three speakers get alternating roles.""" + data = [ + {"type": "message", "user": "U1", "text": "Hello"}, + {"type": "message", "user": "U2", "text": "Hi"}, + {"type": "message", "user": "U3", "text": "Hey"}, + ] + result = _try_slack_json(data) + assert result is not None + + +def test_slack_json_empty_text_skipped(): + data = [ + {"type": "message", "user": "U1", "text": ""}, + {"type": "message", "user": "U1", "text": "Hello"}, + {"type": "message", "user": "U2", "text": "Hi"}, + ] + result = _try_slack_json(data) + assert result is not None + + +def test_slack_json_username_fallback(): + data = [ + {"type": "message", "username": "bot1", "text": "Hello"}, + {"type": "message", "username": "bot2", "text": "Hi"}, + ] + result = _try_slack_json(data) + assert result is not None + + +# ── _try_normalize_json ──────────────────────────────────────────────── + + +def test_try_normalize_json_invalid_json(): + result = _try_normalize_json("not json at all {{{") + assert result is None + + +def test_try_normalize_json_valid_but_unknown_schema(): + result = _try_normalize_json(json.dumps({"random": "data"})) + assert result is None + + +# ── _messages_to_transcript ──────────────────────────────────────────── + + +def test_messages_to_transcript_basic(): + msgs = [("user", "Q"), ("assistant", "A")] + with patch("mempalace.normalize.spellcheck_user_text", side_effect=lambda x: x, create=True): + result = _messages_to_transcript(msgs, spellcheck=False) + assert "> Q" in result + assert "A" in result + + +def test_messages_to_transcript_consecutive_users(): + """Two user messages in a row (no assistant between).""" + msgs = [("user", "Q1"), ("user", "Q2"), ("assistant", "A")] + result = _messages_to_transcript(msgs, spellcheck=False) + assert "> Q1" in result + assert "> Q2" in result + + +def test_messages_to_transcript_assistant_first(): + """Leading assistant message (no user before it).""" + msgs = [("assistant", "preamble"), ("user", "Q"), ("assistant", "A")] + result = _messages_to_transcript(msgs, spellcheck=False) + assert "preamble" in result + assert "> Q" in result + + +def test_normalize_rejects_large_file(): + """Files over 500 MB should raise IOError before reading.""" + with patch("mempalace.normalize.os.path.getsize", return_value=600 * 1024 * 1024): + try: + normalize("/fake/huge_file.txt") + assert False, "Should have raised IOError" + except IOError as e: + assert "too large" in str(e).lower() diff --git a/tests/test_onboarding.py b/tests/test_onboarding.py new file mode 100644 index 0000000..ea7a37b --- /dev/null +++ b/tests/test_onboarding.py @@ -0,0 +1,452 @@ +"""Tests for mempalace.onboarding.""" + +import os +from unittest.mock import patch + +from mempalace.onboarding import ( + DEFAULT_WINGS, + _ask, + _ask_mode, + _ask_people, + _ask_projects, + _ask_wings, + _auto_detect, + _generate_aaak_bootstrap, + _header, + _hr, + _warn_ambiguous, + _yn, + quick_setup, + run_onboarding, +) + +# Force UTF-8 for Windows (source file contains Unicode symbols like hearts/stars) +os.environ["PYTHONUTF8"] = "1" + + +# ── DEFAULT_WINGS ─────────────────────────────────────────────────────── + + +def test_default_wings_has_expected_keys(): + assert "work" in DEFAULT_WINGS + assert "personal" in DEFAULT_WINGS + assert "combo" in DEFAULT_WINGS + + +def test_default_wings_work_has_projects(): + assert "projects" in DEFAULT_WINGS["work"] + + +def test_default_wings_personal_has_family(): + assert "family" in DEFAULT_WINGS["personal"] + + +def test_default_wings_combo_has_both(): + wings = DEFAULT_WINGS["combo"] + assert "family" in wings + assert "work" in wings + + +def test_default_wings_values_are_lists(): + for mode, wings in DEFAULT_WINGS.items(): + assert isinstance(wings, list), f"{mode} wings should be a list" + assert len(wings) >= 3, f"{mode} should have at least 3 wings" + + +# ── _warn_ambiguous ───────────────────────────────────────────────────── + + +def test_warn_ambiguous_flags_common_words(): + people = [ + {"name": "Grace", "relationship": "friend"}, + {"name": "Riley", "relationship": "daughter"}, + ] + result = _warn_ambiguous(people) + assert "Grace" in result + # Riley is not a common English word + assert "Riley" not in result + + +def test_warn_ambiguous_empty_list(): + result = _warn_ambiguous([]) + assert result == [] + + +def test_warn_ambiguous_no_ambiguous_names(): + people = [ + {"name": "Riley", "relationship": "daughter"}, + {"name": "Devon", "relationship": "friend"}, + ] + result = _warn_ambiguous(people) + assert result == [] + + +def test_warn_ambiguous_multiple_hits(): + people = [ + {"name": "Grace", "relationship": "friend"}, + {"name": "May", "relationship": "aunt"}, + {"name": "Joy", "relationship": "sister"}, + ] + result = _warn_ambiguous(people) + assert "Grace" in result + assert "May" in result + assert "Joy" in result + + +# ── quick_setup ───────────────────────────────────────────────────────── + + +def test_quick_setup_creates_registry(tmp_path): + registry = quick_setup( + mode="personal", + people=[{"name": "Riley", "relationship": "daughter", "context": "personal"}], + projects=["MemPalace"], + config_dir=tmp_path, + ) + assert "Riley" in registry.people + assert "MemPalace" in registry.projects + assert registry.mode == "personal" + + +def test_quick_setup_work_mode(tmp_path): + registry = quick_setup( + mode="work", + people=[{"name": "Alice", "relationship": "colleague", "context": "work"}], + projects=["Acme"], + config_dir=tmp_path, + ) + assert registry.mode == "work" + assert "Alice" in registry.people + assert "Acme" in registry.projects + + +def test_quick_setup_empty(tmp_path): + registry = quick_setup(mode="personal", people=[], config_dir=tmp_path) + assert len(registry.people) == 0 + assert len(registry.projects) == 0 + + +def test_quick_setup_saves_to_disk(tmp_path): + quick_setup( + mode="personal", + people=[{"name": "Riley", "relationship": "daughter", "context": "personal"}], + config_dir=tmp_path, + ) + assert (tmp_path / "entity_registry.json").exists() + + +# ── _generate_aaak_bootstrap ─────────────────────────────────────────── + + +def test_generate_aaak_bootstrap_creates_files(tmp_path): + people = [ + {"name": "Riley", "relationship": "daughter", "context": "personal"}, + {"name": "Devon", "relationship": "friend", "context": "personal"}, + ] + projects = ["MemPalace"] + wings = ["family", "creative"] + _generate_aaak_bootstrap(people, projects, wings, "personal", config_dir=tmp_path) + + assert (tmp_path / "aaak_entities.md").exists() + assert (tmp_path / "critical_facts.md").exists() + + +def test_generate_aaak_bootstrap_entities_content(tmp_path): + people = [{"name": "Riley", "relationship": "daughter", "context": "personal"}] + projects = ["MemPalace"] + wings = ["family"] + _generate_aaak_bootstrap(people, projects, wings, "personal", config_dir=tmp_path) + + content = (tmp_path / "aaak_entities.md").read_text() + assert "Riley" in content + assert "RIL" in content # entity code + assert "MemPalace" in content + + +def test_generate_aaak_bootstrap_facts_content(tmp_path): + people = [ + {"name": "Alice", "relationship": "colleague", "context": "work"}, + ] + projects = ["Acme"] + wings = ["projects"] + _generate_aaak_bootstrap(people, projects, wings, "work", config_dir=tmp_path) + + content = (tmp_path / "critical_facts.md").read_text() + assert "Alice" in content + assert "Acme" in content + assert "work" in content.lower() + + +def test_generate_aaak_bootstrap_empty_people(tmp_path): + _generate_aaak_bootstrap([], [], ["general"], "personal", config_dir=tmp_path) + assert (tmp_path / "aaak_entities.md").exists() + assert (tmp_path / "critical_facts.md").exists() + + +def test_generate_aaak_bootstrap_collision(tmp_path): + """Two people with same 3-letter code get different codes.""" + people = [ + {"name": "Alice", "relationship": "friend", "context": "work"}, + {"name": "Alison", "relationship": "coworker", "context": "work"}, + ] + _generate_aaak_bootstrap(people, [], ["work"], "work", config_dir=tmp_path) + content = (tmp_path / "aaak_entities.md").read_text() + assert "ALI" in content + assert "ALIS" in content + + +def test_generate_aaak_bootstrap_no_relationship(tmp_path): + """Person without relationship string still generates entry.""" + people = [{"name": "Bob", "context": "work"}] + _generate_aaak_bootstrap(people, [], ["work"], "work", config_dir=tmp_path) + content = (tmp_path / "aaak_entities.md").read_text() + assert "BOB=Bob" in content + + +# ── _hr, _header ────────────────────────────────────────────────────── + + +def test_hr_prints_line(capsys): + _hr() + out = capsys.readouterr().out + assert "─" in out + + +def test_header_prints_banner(capsys): + _header("Test Title") + out = capsys.readouterr().out + assert "Test Title" in out + assert "=" in out + + +# ── _ask ────────────────────────────────────────────────────────────── + + +def test_ask_with_default_uses_default(): + with patch("builtins.input", return_value=""): + result = _ask("prompt", default="fallback") + assert result == "fallback" + + +def test_ask_with_default_uses_input(): + with patch("builtins.input", return_value="custom"): + result = _ask("prompt", default="fallback") + assert result == "custom" + + +def test_ask_no_default(): + with patch("builtins.input", return_value="answer"): + result = _ask("prompt") + assert result == "answer" + + +# ── _yn ─────────────────────────────────────────────────────────────── + + +def test_yn_default_yes_empty_input(): + with patch("builtins.input", return_value=""): + assert _yn("continue?") is True + + +def test_yn_default_no_empty_input(): + with patch("builtins.input", return_value=""): + assert _yn("continue?", default="n") is False + + +def test_yn_explicit_yes(): + with patch("builtins.input", return_value="yes"): + assert _yn("continue?", default="n") is True + + +def test_yn_explicit_no(): + with patch("builtins.input", return_value="no"): + assert _yn("continue?") is False + + +# ── _ask_mode ───────────────────────────────────────────────────────── + + +def test_ask_mode_work(): + with patch("builtins.input", return_value="1"): + assert _ask_mode() == "work" + + +def test_ask_mode_personal(): + with patch("builtins.input", return_value="2"): + assert _ask_mode() == "personal" + + +def test_ask_mode_combo(): + with patch("builtins.input", return_value="3"): + assert _ask_mode() == "combo" + + +def test_ask_mode_retries_on_bad_input(): + with patch("builtins.input", side_effect=["x", "bad", "1"]): + assert _ask_mode() == "work" + + +# ── _ask_people ─────────────────────────────────────────────────────── + + +def test_ask_people_personal_mode(): + with patch("builtins.input", side_effect=["Alice, daughter", "", "done"]): + people, aliases = _ask_people("personal") + assert len(people) == 1 + assert people[0]["name"] == "Alice" + assert people[0]["relationship"] == "daughter" + + +def test_ask_people_work_mode(): + with patch("builtins.input", side_effect=["Bob, manager", "", "done"]): + people, aliases = _ask_people("work") + assert len(people) == 1 + assert people[0]["name"] == "Bob" + assert people[0]["context"] == "work" + + +def test_ask_people_combo_mode(): + with patch( + "builtins.input", + side_effect=[ + "Alice, daughter", + "", + "done", # personal + "Bob, boss", + "done", # work + ], + ): + people, aliases = _ask_people("combo") + assert len(people) == 2 + + +def test_ask_people_with_nickname(): + with patch("builtins.input", side_effect=["Alice, daughter", "Ali", "done"]): + people, aliases = _ask_people("personal") + assert aliases == {"Ali": "Alice"} + + +def test_ask_people_empty_name_skipped(): + with patch("builtins.input", side_effect=["", "done"]): + people, aliases = _ask_people("personal") + assert len(people) == 0 + + +# ── _ask_projects ───────────────────────────────────────────────────── + + +def test_ask_projects_personal_returns_empty(): + result = _ask_projects("personal") + assert result == [] + + +def test_ask_projects_work_mode(): + with patch("builtins.input", side_effect=["Acme", "BigCo", "done"]): + result = _ask_projects("work") + assert result == ["Acme", "BigCo"] + + +def test_ask_projects_empty_entry_stops(): + with patch("builtins.input", side_effect=["Acme", ""]): + result = _ask_projects("work") + assert result == ["Acme"] + + +# ── _ask_wings ──────────────────────────────────────────────────────── + + +def test_ask_wings_accept_defaults(): + with patch("builtins.input", return_value=""): + result = _ask_wings("work") + assert result == DEFAULT_WINGS["work"] + + +def test_ask_wings_custom(): + with patch("builtins.input", return_value="alpha, beta, gamma"): + result = _ask_wings("personal") + assert result == ["alpha", "beta", "gamma"] + + +# ── _auto_detect ────────────────────────────────────────────────────── + + +def test_auto_detect_no_files(tmp_path): + result = _auto_detect(str(tmp_path), []) + assert result == [] + + +def test_auto_detect_filters_known(tmp_path): + known = [{"name": "Alice"}] + fake_detected = { + "people": [ + {"name": "Alice", "confidence": 0.9, "signals": ["test"]}, + {"name": "Bob", "confidence": 0.8, "signals": ["test"]}, + ], + "projects": [], + "uncertain": [], + } + with ( + patch("mempalace.onboarding.scan_for_detection", return_value=["file.txt"]), + patch("mempalace.onboarding.detect_entities", return_value=fake_detected), + ): + result = _auto_detect(str(tmp_path), known) + names = [p["name"] for p in result] + assert "Alice" not in names + assert "Bob" in names + + +def test_auto_detect_filters_low_confidence(tmp_path): + fake_detected = { + "people": [{"name": "Bob", "confidence": 0.5, "signals": ["test"]}], + "projects": [], + "uncertain": [], + } + with ( + patch("mempalace.onboarding.scan_for_detection", return_value=["file.txt"]), + patch("mempalace.onboarding.detect_entities", return_value=fake_detected), + ): + result = _auto_detect(str(tmp_path), []) + assert len(result) == 0 + + +def test_auto_detect_handles_exception(tmp_path): + with patch("mempalace.onboarding.scan_for_detection", side_effect=Exception("boom")): + result = _auto_detect(str(tmp_path), []) + assert result == [] + + +# ── run_onboarding ──────────────────────────────────────────────────── + + +def test_run_onboarding_basic_flow(tmp_path): + """Test the full onboarding flow with minimal mocking.""" + with ( + patch("mempalace.onboarding._ask_mode", return_value="work"), + patch( + "mempalace.onboarding._ask_people", + return_value=([{"name": "Bob", "relationship": "boss", "context": "work"}], {}), + ), + patch("mempalace.onboarding._ask_projects", return_value=["Acme"]), + patch("mempalace.onboarding._ask_wings", return_value=["projects", "team"]), + patch("mempalace.onboarding._yn", return_value=False), + patch("mempalace.onboarding._warn_ambiguous", return_value=[]), + ): + registry = run_onboarding(directory=".", config_dir=tmp_path, auto_detect=False) + assert "Bob" in registry.people + assert "Acme" in registry.projects + + +def test_run_onboarding_with_ambiguous_names(tmp_path): + """Onboarding prints a warning for ambiguous names.""" + with ( + patch("mempalace.onboarding._ask_mode", return_value="personal"), + patch( + "mempalace.onboarding._ask_people", + return_value=([{"name": "Grace", "relationship": "friend", "context": "personal"}], {}), + ), + patch("mempalace.onboarding._ask_projects", return_value=[]), + patch("mempalace.onboarding._ask_wings", return_value=["family"]), + patch("mempalace.onboarding._yn", return_value=False), + ): + registry = run_onboarding(directory=".", config_dir=tmp_path, auto_detect=False) + assert "Grace" in registry.people diff --git a/tests/test_palace_graph.py b/tests/test_palace_graph.py new file mode 100644 index 0000000..ddda272 --- /dev/null +++ b/tests/test_palace_graph.py @@ -0,0 +1,244 @@ +"""Tests for mempalace.palace_graph — graph traversal layer. + +All ChromaDB access is mocked — no real database needed. +""" + +from unittest.mock import MagicMock, patch + + +def _make_fake_collection(metadatas, ids=None): + """Create a mock collection that returns the given metadata in batches.""" + if ids is None: + ids = [f"id_{i}" for i in range(len(metadatas))] + + col = MagicMock() + col.count.return_value = len(metadatas) + + def fake_get(limit=1000, offset=0, include=None): + batch_meta = metadatas[offset : offset + limit] + batch_ids = ids[offset : offset + limit] + return {"ids": batch_ids, "metadatas": batch_meta} + + col.get.side_effect = fake_get + return col + + +# Patch chromadb at import time so palace_graph can be imported +with patch.dict("sys.modules", {"chromadb": MagicMock()}): + from mempalace.palace_graph import ( + _fuzzy_match, + build_graph, + find_tunnels, + graph_stats, + traverse, + ) + + +# --- build_graph --- + + +class TestBuildGraph: + def test_empty_collection(self): + col = _make_fake_collection([]) + nodes, edges = build_graph(col=col) + assert nodes == {} + assert edges == [] + + def test_falsy_collection(self): + """When col is explicitly falsy, build_graph returns empty.""" + nodes, edges = build_graph(col=0) + assert nodes == {} + assert edges == [] + + def test_single_wing_no_edges(self): + col = _make_fake_collection( + [ + {"room": "auth", "wing": "wing_code", "hall": "security", "date": "2026-01-01"}, + {"room": "auth", "wing": "wing_code", "hall": "security", "date": "2026-01-02"}, + ] + ) + nodes, edges = build_graph(col=col) + assert "auth" in nodes + assert nodes["auth"]["count"] == 2 + assert edges == [] + + def test_multi_wing_creates_edges(self): + col = _make_fake_collection( + [ + { + "room": "chromadb", + "wing": "wing_code", + "hall": "databases", + "date": "2026-01-01", + }, + { + "room": "chromadb", + "wing": "wing_project", + "hall": "databases", + "date": "2026-01-02", + }, + ] + ) + nodes, edges = build_graph(col=col) + assert "chromadb" in nodes + assert len(edges) == 1 + assert edges[0]["wing_a"] == "wing_code" + assert edges[0]["wing_b"] == "wing_project" + assert edges[0]["hall"] == "databases" + + def test_general_room_excluded(self): + col = _make_fake_collection( + [ + {"room": "general", "wing": "wing_code", "hall": "misc", "date": ""}, + ] + ) + nodes, edges = build_graph(col=col) + assert "general" not in nodes + + def test_missing_wing_excluded(self): + col = _make_fake_collection( + [ + {"room": "orphan", "wing": "", "hall": "misc", "date": ""}, + ] + ) + nodes, edges = build_graph(col=col) + assert "orphan" not in nodes + + def test_dates_capped_at_five(self): + col = _make_fake_collection( + [ + {"room": "busy", "wing": "w", "hall": "h", "date": f"2026-01-{i:02d}"} + for i in range(1, 10) + ] + ) + nodes, _ = build_graph(col=col) + assert len(nodes["busy"]["dates"]) <= 5 + + +# --- traverse --- + + +class TestTraverse: + def _build_col(self): + return _make_fake_collection( + [ + {"room": "auth", "wing": "wing_code", "hall": "security", "date": "2026-01-01"}, + {"room": "login", "wing": "wing_code", "hall": "security", "date": "2026-01-01"}, + {"room": "deploy", "wing": "wing_ops", "hall": "infra", "date": "2026-01-01"}, + ] + ) + + def test_traverse_known_room(self): + col = self._build_col() + result = traverse("auth", col=col) + assert isinstance(result, list) + rooms = [r["room"] for r in result] + assert "auth" in rooms + # login shares wing_code with auth + assert "login" in rooms + + def test_traverse_unknown_room(self): + col = self._build_col() + result = traverse("nonexistent", col=col) + assert isinstance(result, dict) + assert "error" in result + assert "suggestions" in result + + def test_traverse_max_hops(self): + col = self._build_col() + result = traverse("auth", col=col, max_hops=0) + # Only the start room itself at hop 0 + assert len(result) == 1 + assert result[0]["room"] == "auth" + + +# --- find_tunnels --- + + +class TestFindTunnels: + def _build_tunnel_col(self): + return _make_fake_collection( + [ + {"room": "chromadb", "wing": "wing_code", "hall": "db", "date": "2026-01-01"}, + {"room": "chromadb", "wing": "wing_project", "hall": "db", "date": "2026-01-02"}, + {"room": "auth", "wing": "wing_code", "hall": "security", "date": "2026-01-01"}, + ] + ) + + def test_find_all_tunnels(self): + col = self._build_tunnel_col() + tunnels = find_tunnels(col=col) + assert len(tunnels) == 1 + assert tunnels[0]["room"] == "chromadb" + + def test_find_tunnels_with_wing_filter(self): + col = self._build_tunnel_col() + tunnels = find_tunnels(wing_a="wing_code", col=col) + assert len(tunnels) == 1 + + def test_find_tunnels_no_match(self): + col = self._build_tunnel_col() + tunnels = find_tunnels(wing_a="wing_nonexistent", col=col) + assert tunnels == [] + + def test_find_tunnels_both_wings(self): + col = self._build_tunnel_col() + tunnels = find_tunnels(wing_a="wing_code", wing_b="wing_project", col=col) + assert len(tunnels) == 1 + assert tunnels[0]["room"] == "chromadb" + + +# --- graph_stats --- + + +class TestGraphStats: + def test_empty_graph(self): + col = _make_fake_collection([]) + stats = graph_stats(col=col) + assert stats["total_rooms"] == 0 + assert stats["tunnel_rooms"] == 0 + assert stats["total_edges"] == 0 + + def test_stats_with_data(self): + col = _make_fake_collection( + [ + {"room": "chromadb", "wing": "wing_code", "hall": "db", "date": "2026-01-01"}, + {"room": "chromadb", "wing": "wing_project", "hall": "db", "date": "2026-01-02"}, + {"room": "auth", "wing": "wing_code", "hall": "security", "date": "2026-01-01"}, + ] + ) + stats = graph_stats(col=col) + assert stats["total_rooms"] == 2 + assert stats["tunnel_rooms"] == 1 + assert stats["total_edges"] == 1 + assert "wing_code" in stats["rooms_per_wing"] + + +# --- _fuzzy_match --- + + +class TestFuzzyMatch: + def test_exact_substring(self): + nodes = {"chromadb-setup": {}, "auth-module": {}, "deploy-config": {}} + result = _fuzzy_match("chromadb", nodes) + assert "chromadb-setup" in result + + def test_partial_word_match(self): + nodes = {"chromadb-setup": {}, "auth-module": {}, "deploy-config": {}} + result = _fuzzy_match("auth", nodes) + assert "auth-module" in result + + def test_no_match(self): + nodes = {"chromadb-setup": {}, "auth-module": {}} + result = _fuzzy_match("zzzzz", nodes) + assert result == [] + + def test_hyphenated_query(self): + nodes = {"riley-college-apps": {}, "college-prep": {}} + result = _fuzzy_match("riley-college", nodes) + assert "riley-college-apps" in result + + def test_max_results(self): + nodes = {f"room-{i}": {} for i in range(20)} + result = _fuzzy_match("room", nodes, n=3) + assert len(result) <= 3 diff --git a/tests/test_room_detector_local.py b/tests/test_room_detector_local.py new file mode 100644 index 0000000..11963e4 --- /dev/null +++ b/tests/test_room_detector_local.py @@ -0,0 +1,264 @@ +"""Tests for mempalace.room_detector_local.""" + +from unittest.mock import MagicMock, patch + +from mempalace.room_detector_local import ( + FOLDER_ROOM_MAP, + detect_rooms_from_files, + detect_rooms_from_folders, + detect_rooms_local, + get_user_approval, + print_proposed_structure, + save_config, +) + + +# ── FOLDER_ROOM_MAP ──────────────────────────────────────────────────── + + +def test_folder_room_map_has_expected_mappings(): + assert FOLDER_ROOM_MAP["frontend"] == "frontend" + assert FOLDER_ROOM_MAP["backend"] == "backend" + assert FOLDER_ROOM_MAP["docs"] == "documentation" + assert FOLDER_ROOM_MAP["tests"] == "testing" + assert FOLDER_ROOM_MAP["config"] == "configuration" + + +def test_folder_room_map_alternative_names(): + assert FOLDER_ROOM_MAP["front-end"] == "frontend" + assert FOLDER_ROOM_MAP["back-end"] == "backend" + assert FOLDER_ROOM_MAP["server"] == "backend" + assert FOLDER_ROOM_MAP["client"] == "frontend" + assert FOLDER_ROOM_MAP["api"] == "backend" + + +# ── detect_rooms_from_folders ─────────────────────────────────────────── + + +def test_detect_rooms_from_folders_standard_layout(tmp_path): + (tmp_path / "frontend").mkdir() + (tmp_path / "backend").mkdir() + (tmp_path / "docs").mkdir() + rooms = detect_rooms_from_folders(str(tmp_path)) + room_names = {r["name"] for r in rooms} + assert "frontend" in room_names + assert "backend" in room_names + assert "documentation" in room_names + + +def test_detect_rooms_from_folders_always_has_general(tmp_path): + rooms = detect_rooms_from_folders(str(tmp_path)) + room_names = {r["name"] for r in rooms} + assert "general" in room_names + + +def test_detect_rooms_from_folders_empty_dir(tmp_path): + rooms = detect_rooms_from_folders(str(tmp_path)) + # Should at least have "general" + assert len(rooms) >= 1 + assert any(r["name"] == "general" for r in rooms) + + +def test_detect_rooms_from_folders_skips_git(tmp_path): + (tmp_path / ".git").mkdir() + (tmp_path / "node_modules").mkdir() + (tmp_path / "frontend").mkdir() + rooms = detect_rooms_from_folders(str(tmp_path)) + room_names = {r["name"] for r in rooms} + assert ".git" not in room_names + assert "node_modules" not in room_names + + +def test_detect_rooms_from_folders_nested_dirs(tmp_path): + src = tmp_path / "src" + src.mkdir() + (src / "components").mkdir() + (src / "routes").mkdir() + rooms = detect_rooms_from_folders(str(tmp_path)) + room_names = {r["name"] for r in rooms} + # Nested dirs should be detected at one level deep + assert "frontend" in room_names or "backend" in room_names + + +def test_detect_rooms_from_folders_room_has_description(tmp_path): + (tmp_path / "docs").mkdir() + rooms = detect_rooms_from_folders(str(tmp_path)) + doc_room = next((r for r in rooms if r["name"] == "documentation"), None) + assert doc_room is not None + assert "description" in doc_room + assert "docs" in doc_room["description"] + + +def test_detect_rooms_from_folders_room_has_keywords(tmp_path): + (tmp_path / "frontend").mkdir() + rooms = detect_rooms_from_folders(str(tmp_path)) + fe_room = next((r for r in rooms if r["name"] == "frontend"), None) + assert fe_room is not None + assert "keywords" in fe_room + assert len(fe_room["keywords"]) > 0 + + +def test_detect_rooms_from_folders_custom_named_dirs(tmp_path): + (tmp_path / "mylib").mkdir() + rooms = detect_rooms_from_folders(str(tmp_path)) + room_names = {r["name"] for r in rooms} + # Custom dir names that don't match FOLDER_ROOM_MAP get added as-is + assert "mylib" in room_names or "general" in room_names + + +# ── detect_rooms_from_files ───────────────────────────────────────────── + + +def test_detect_rooms_from_files_with_matching_filenames(tmp_path): + # Create files whose names contain room keywords + for name in ["test_auth.py", "test_login.py", "test_api.py"]: + (tmp_path / name).write_text("content") + rooms = detect_rooms_from_files(str(tmp_path)) + room_names = {r["name"] for r in rooms} + assert "testing" in room_names or "general" in room_names + + +def test_detect_rooms_from_files_empty_dir(tmp_path): + rooms = detect_rooms_from_files(str(tmp_path)) + assert len(rooms) >= 1 + assert any(r["name"] == "general" for r in rooms) + + +def test_detect_rooms_from_files_caps_at_six(tmp_path): + # Create many files with different keywords to hit the cap + for keyword in ["test", "doc", "api", "config", "frontend", "backend", "design", "meeting"]: + for i in range(3): + (tmp_path / f"{keyword}_file_{i}.txt").write_text("content") + rooms = detect_rooms_from_files(str(tmp_path)) + assert len(rooms) <= 6 + + +# ── save_config ───────────────────────────────────────────────────────── + + +def test_save_config_creates_yaml(tmp_path): + rooms = [ + {"name": "frontend", "description": "UI files", "keywords": ["frontend"]}, + {"name": "backend", "description": "Server files", "keywords": ["backend"]}, + ] + save_config(str(tmp_path), "myproject", rooms) + config_file = tmp_path / "mempalace.yaml" + assert config_file.exists() + content = config_file.read_text() + assert "myproject" in content + assert "frontend" in content + assert "backend" in content + + +def test_save_config_valid_yaml(tmp_path): + import yaml + + rooms = [{"name": "general", "description": "All files", "keywords": []}] + save_config(str(tmp_path), "test_proj", rooms) + config_file = tmp_path / "mempalace.yaml" + data = yaml.safe_load(config_file.read_text()) + assert data["wing"] == "test_proj" + assert len(data["rooms"]) == 1 + assert data["rooms"][0]["name"] == "general" + + +# ── print_proposed_structure ────────────────────────────────────────── + + +def test_print_proposed_structure(capsys): + rooms = [ + {"name": "frontend", "description": "UI files"}, + {"name": "general", "description": "Everything else"}, + ] + print_proposed_structure("myapp", rooms, 42, "folder structure") + out = capsys.readouterr().out + assert "myapp" in out + assert "frontend" in out + assert "42 files" in out + assert "folder structure" in out + + +# ── get_user_approval ───────────────────────────────────────────────── + + +def test_get_user_approval_accept_all(): + rooms = [{"name": "frontend", "description": "UI"}] + with patch("builtins.input", return_value=""): + result = get_user_approval(rooms) + assert result == rooms + + +def test_get_user_approval_edit_remove(): + rooms = [ + {"name": "frontend", "description": "UI"}, + {"name": "backend", "description": "Server"}, + ] + with patch("builtins.input", side_effect=["edit", "1", "n"]): + result = get_user_approval(rooms) + # Room 1 (frontend) removed + assert len(result) == 1 + assert result[0]["name"] == "backend" + + +def test_get_user_approval_add_room(): + rooms = [{"name": "general", "description": "All files"}] + with patch( + "builtins.input", + side_effect=[ + "add", + "custom_room", + "My custom room", + "", + ], + ): + result = get_user_approval(rooms) + names = [r["name"] for r in result] + assert "custom_room" in names + + +# ── detect_rooms_local ──────────────────────────────────────────────── + + +def test_detect_rooms_local_yes_mode(tmp_path): + (tmp_path / "docs").mkdir() + (tmp_path / "docs" / "readme.md").write_text("hello") + mock_miner = MagicMock() + mock_miner.scan_project.return_value = ["file1.py"] + with patch.dict("sys.modules", {"mempalace.miner": mock_miner}): + detect_rooms_local(str(tmp_path), yes=True) + assert (tmp_path / "mempalace.yaml").exists() + + +def test_detect_rooms_local_fallback_to_files(tmp_path): + """When folder detection gives only 'general', falls back to file patterns.""" + for i in range(3): + (tmp_path / f"test_file_{i}.py").write_text("content") + mock_miner = MagicMock() + mock_miner.scan_project.return_value = ["f1", "f2"] + with patch.dict("sys.modules", {"mempalace.miner": mock_miner}): + detect_rooms_local(str(tmp_path), yes=True) + assert (tmp_path / "mempalace.yaml").exists() + + +def test_detect_rooms_local_missing_dir(): + """Non-existent directory causes sys.exit.""" + import pytest + + with pytest.raises(SystemExit): + detect_rooms_local("/nonexistent/path/that/does/not/exist", yes=True) + + +def test_detect_rooms_local_interactive(tmp_path): + (tmp_path / "src").mkdir() + (tmp_path / "src" / "main.py").write_text("code") + mock_miner = MagicMock() + mock_miner.scan_project.return_value = ["f1"] + with ( + patch.dict("sys.modules", {"mempalace.miner": mock_miner}), + patch( + "mempalace.room_detector_local.get_user_approval", + return_value=[{"name": "general", "description": "All files", "keywords": []}], + ), + ): + detect_rooms_local(str(tmp_path), yes=False) + assert (tmp_path / "mempalace.yaml").exists() diff --git a/tests/test_searcher.py b/tests/test_searcher.py index 1c2687d..94f22b4 100644 --- a/tests/test_searcher.py +++ b/tests/test_searcher.py @@ -1,10 +1,18 @@ """ -test_searcher.py — Tests for the programmatic search_memories API. +test_searcher.py -- Tests for both search() (CLI) and search_memories() (API). -Tests the library-facing search interface (not the CLI print variant). +Uses the real ChromaDB fixtures from conftest.py for integration tests, +plus mock-based tests for error paths. """ -from mempalace.searcher import search_memories +from unittest.mock import MagicMock, patch + +import pytest + +from mempalace.searcher import SearchError, search, search_memories + + +# ── search_memories (API) ────────────────────────────────────────────── class TestSearchMemories: @@ -43,3 +51,75 @@ class TestSearchMemories: assert "source_file" in hit assert "similarity" in hit assert isinstance(hit["similarity"], float) + + def test_search_memories_query_error(self): + """search_memories returns error dict when query raises.""" + mock_col = MagicMock() + mock_col.query.side_effect = RuntimeError("query failed") + mock_client = MagicMock() + mock_client.get_collection.return_value = mock_col + + with patch("mempalace.searcher.chromadb.PersistentClient", return_value=mock_client): + result = search_memories("test", "/fake/path") + assert "error" in result + assert "query failed" in result["error"] + + def test_search_memories_filters_in_result(self, palace_path, seeded_collection): + result = search_memories("test", palace_path, wing="project", room="backend") + assert result["filters"]["wing"] == "project" + assert result["filters"]["room"] == "backend" + + +# ── search() (CLI print function) ───────────────────────────────────── + + +class TestSearchCLI: + def test_search_prints_results(self, palace_path, seeded_collection, capsys): + search("JWT authentication", palace_path) + captured = capsys.readouterr() + assert "JWT" in captured.out or "authentication" in captured.out + + def test_search_with_wing_filter(self, palace_path, seeded_collection, capsys): + search("planning", palace_path, wing="notes") + captured = capsys.readouterr() + assert "Results for" in captured.out + + def test_search_with_room_filter(self, palace_path, seeded_collection, capsys): + search("database", palace_path, room="backend") + captured = capsys.readouterr() + assert "Room:" in captured.out + + def test_search_with_wing_and_room(self, palace_path, seeded_collection, capsys): + search("code", palace_path, wing="project", room="frontend") + captured = capsys.readouterr() + assert "Wing:" in captured.out + assert "Room:" in captured.out + + def test_search_no_palace_raises(self, tmp_path): + with pytest.raises(SearchError, match="No palace found"): + search("anything", str(tmp_path / "missing")) + + def test_search_no_results(self, palace_path, collection, capsys): + """Empty collection returns no results message.""" + # collection is empty (no seeded data) + result = search("xyzzy_nonexistent_query", palace_path, n_results=1) + captured = capsys.readouterr() + # Either prints "No results" or returns None + assert result is None or "No results" in captured.out + + def test_search_query_error_raises(self): + """search raises SearchError when query fails.""" + mock_col = MagicMock() + mock_col.query.side_effect = RuntimeError("boom") + mock_client = MagicMock() + mock_client.get_collection.return_value = mock_col + + with patch("mempalace.searcher.chromadb.PersistentClient", return_value=mock_client): + with pytest.raises(SearchError, match="Search error"): + search("test", "/fake/path") + + def test_search_n_results(self, palace_path, seeded_collection, capsys): + search("code", palace_path, n_results=1) + captured = capsys.readouterr() + # Should have output with at least one result block + assert "[1]" in captured.out diff --git a/tests/test_spellcheck.py b/tests/test_spellcheck.py new file mode 100644 index 0000000..f2c7484 --- /dev/null +++ b/tests/test_spellcheck.py @@ -0,0 +1,160 @@ +"""Tests for mempalace.spellcheck — spell-correction utilities.""" + +from unittest.mock import patch + +from mempalace.spellcheck import ( + _edit_distance, + _get_system_words, + _should_skip, + spellcheck_transcript, + spellcheck_transcript_line, + spellcheck_user_text, +) + + +# --- _should_skip --- + + +class TestShouldSkip: + """Token-level skip logic.""" + + def test_short_tokens_skipped(self): + assert _should_skip("hi", set()) is True + assert _should_skip("ok", set()) is True + assert _should_skip("I", set()) is True + + def test_digits_skipped(self): + assert _should_skip("3am", set()) is True + assert _should_skip("top10", set()) is True + assert _should_skip("bge-large-v1.5", set()) is True + + def test_camelcase_skipped(self): + assert _should_skip("ChromaDB", set()) is True + assert _should_skip("MemPalace", set()) is True + + def test_allcaps_skipped(self): + assert _should_skip("NDCG", set()) is True + assert _should_skip("MAX_RESULTS", set()) is True + + def test_technical_skipped(self): + assert _should_skip("bge-large", set()) is True + assert _should_skip("train_test", set()) is True + + def test_url_skipped(self): + assert _should_skip("https://example.com", set()) is True + assert _should_skip("www.google.com", set()) is True + + def test_code_or_emoji_skipped(self): + assert _should_skip("`code`", set()) is True + assert _should_skip("**bold**", set()) is True + + def test_known_name_skipped(self): + assert _should_skip("mempalace", {"mempalace"}) is True + + def test_normal_word_not_skipped(self): + assert _should_skip("hello", set()) is False + assert _should_skip("question", set()) is False + + +# --- _edit_distance --- + + +class TestEditDistance: + def test_identical(self): + assert _edit_distance("hello", "hello") == 0 + + def test_empty_strings(self): + assert _edit_distance("", "abc") == 3 + assert _edit_distance("abc", "") == 3 + assert _edit_distance("", "") == 0 + + def test_single_edit(self): + assert _edit_distance("cat", "bat") == 1 # substitution + assert _edit_distance("cat", "cats") == 1 # insertion + assert _edit_distance("cats", "cat") == 1 # deletion + + def test_known_distance(self): + assert _edit_distance("kitten", "sitting") == 3 + + +# --- _get_system_words --- + + +def test_get_system_words_returns_set(): + result = _get_system_words() + assert isinstance(result, set) + + +# --- spellcheck_user_text --- + + +def test_spellcheck_user_text_passthrough_no_autocorrect(): + """When autocorrect is not installed, text passes through unchanged.""" + with patch("mempalace.spellcheck._get_speller", return_value=None): + text = "somee misspeledd textt" + assert spellcheck_user_text(text) == text + + +def test_spellcheck_user_text_with_speller(): + """When a speller is available, it corrects words.""" + + def fake_speller(word): + corrections = {"knoe": "know", "befor": "before"} + return corrections.get(word, word) + + with patch("mempalace.spellcheck._get_speller", return_value=fake_speller): + with patch("mempalace.spellcheck._get_system_words", return_value=set()): + with patch("mempalace.spellcheck._load_known_names", return_value=set()): + result = spellcheck_user_text("knoe the question befor") + assert "know" in result + assert "before" in result + + +def test_spellcheck_preserves_technical_terms(): + """Technical terms should never be touched even with a speller.""" + + def fake_speller(word): + return "WRONG" + + with patch("mempalace.spellcheck._get_speller", return_value=fake_speller): + with patch("mempalace.spellcheck._get_system_words", return_value=set()): + result = spellcheck_user_text("ChromaDB bge-large", known_names=set()) + assert "ChromaDB" in result + assert "bge-large" in result + assert "WRONG" not in result + + +# --- spellcheck_transcript_line --- + + +def test_transcript_line_user_turn(): + """Lines starting with '>' should be processed.""" + with patch("mempalace.spellcheck.spellcheck_user_text", return_value="corrected"): + result = spellcheck_transcript_line("> hello world") + assert "corrected" in result + + +def test_transcript_line_assistant_turn(): + """Lines not starting with '>' should pass through unchanged.""" + line = "This is an assistant response" + assert spellcheck_transcript_line(line) == line + + +def test_transcript_line_empty_user_turn(): + """A '> ' line with no message content should pass through.""" + line = "> " + assert spellcheck_transcript_line(line) == line + + +# --- spellcheck_transcript --- + + +def test_spellcheck_transcript_processes_content(): + """Full transcript: only '>' lines are touched.""" + content = "Assistant line\n> user line\nAnother assistant line" + with patch("mempalace.spellcheck.spellcheck_user_text", return_value="fixed"): + result = spellcheck_transcript(content) + lines = result.split("\n") + assert lines[0] == "Assistant line" + assert "fixed" in lines[1] + assert lines[2] == "Another assistant line" diff --git a/tests/test_spellcheck_extra.py b/tests/test_spellcheck_extra.py new file mode 100644 index 0000000..567cb01 --- /dev/null +++ b/tests/test_spellcheck_extra.py @@ -0,0 +1,72 @@ +"""Extra spellcheck tests covering _load_known_names and speller edge cases.""" + +from unittest.mock import patch, MagicMock + +from mempalace.spellcheck import ( + _load_known_names, + spellcheck_user_text, +) + + +class TestLoadKnownNames: + def test_returns_names_from_registry(self): + mock_reg = MagicMock() + mock_reg._data = { + "entities": { + "e1": {"canonical": "Alice", "aliases": ["ali"]}, + "e2": {"canonical": "Bob", "aliases": []}, + } + } + with patch("mempalace.entity_registry.EntityRegistry") as MockER: + MockER.load.return_value = mock_reg + names = _load_known_names() + assert "alice" in names + assert "ali" in names + assert "bob" in names + + def test_returns_empty_on_exception(self): + with patch( + "mempalace.entity_registry.EntityRegistry.load", + side_effect=Exception("no registry"), + ): + names = _load_known_names() + assert names == set() + + +class TestSpellerEdgeCases: + def test_capitalized_word_skipped(self): + """Capitalized words (likely proper nouns) are not corrected.""" + + def fake_speller(word): + return "WRONG" + + with patch("mempalace.spellcheck._get_speller", return_value=fake_speller): + with patch("mempalace.spellcheck._get_system_words", return_value=set()): + with patch("mempalace.spellcheck._load_known_names", return_value=set()): + result = spellcheck_user_text("Alice went home") + assert "Alice" in result + assert "WRONG" not in result + + def test_system_word_not_corrected(self): + """Words in system dict should not be corrected.""" + + def fake_speller(word): + return "WRONG" + + with patch("mempalace.spellcheck._get_speller", return_value=fake_speller): + with patch("mempalace.spellcheck._get_system_words", return_value={"coherently"}): + with patch("mempalace.spellcheck._load_known_names", return_value=set()): + result = spellcheck_user_text("coherently") + assert "coherently" in result + + def test_high_edit_distance_rejected(self): + """Corrections with too many edits are rejected.""" + + def fake_speller(word): + return "completely_different_word" + + with patch("mempalace.spellcheck._get_speller", return_value=fake_speller): + with patch("mempalace.spellcheck._get_system_words", return_value=set()): + with patch("mempalace.spellcheck._load_known_names", return_value=set()): + result = spellcheck_user_text("hello") + assert "hello" in result diff --git a/tests/test_split_mega_files.py b/tests/test_split_mega_files.py index 70c7f84..c1db02b 100644 --- a/tests/test_split_mega_files.py +++ b/tests/test_split_mega_files.py @@ -3,6 +3,9 @@ import json from mempalace import split_mega_files as smf +# ── Config loading ───────────────────────────────────────────────────── + + def test_load_known_people_falls_back_when_config_missing(monkeypatch, tmp_path): monkeypatch.setattr(smf, "_KNOWN_NAMES_PATH", tmp_path / "missing.json") smf._KNOWN_NAMES_CACHE = None @@ -46,3 +49,244 @@ def test_extract_people_detects_names_from_content(monkeypatch): monkeypatch.setattr(smf, "KNOWN_PEOPLE", ["Alice", "Ben"]) people = smf.extract_people(["> Alice reviewed the change with Ben\n"]) assert people == ["Alice", "Ben"] + + +# ── Config: force_reload and invalid JSON ────────────────────────────── + + +def test_load_known_names_force_reload(monkeypatch, tmp_path): + config_path = tmp_path / "known_names.json" + config_path.write_text(json.dumps(["Alice"])) + monkeypatch.setattr(smf, "_KNOWN_NAMES_PATH", config_path) + smf._KNOWN_NAMES_CACHE = None + + smf._load_known_names_config() + assert smf._KNOWN_NAMES_CACHE == ["Alice"] + + config_path.write_text(json.dumps(["Bob"])) + smf._load_known_names_config(force_reload=True) + assert smf._KNOWN_NAMES_CACHE == ["Bob"] + + +def test_load_known_names_invalid_json(monkeypatch, tmp_path): + config_path = tmp_path / "known_names.json" + config_path.write_text("not json {{{") + monkeypatch.setattr(smf, "_KNOWN_NAMES_PATH", config_path) + smf._KNOWN_NAMES_CACHE = None + + result = smf._load_known_names_config() + assert result is None + + +def test_load_known_names_caching(monkeypatch, tmp_path): + config_path = tmp_path / "known_names.json" + config_path.write_text(json.dumps(["Alice"])) + monkeypatch.setattr(smf, "_KNOWN_NAMES_PATH", config_path) + smf._KNOWN_NAMES_CACHE = None + + smf._load_known_names_config() + # Second call returns cached value without re-reading + config_path.write_text(json.dumps(["Changed"])) + result = smf._load_known_names_config() + assert result == ["Alice"] + + +# ── is_true_session_start ────────────────────────────────────────────── + + +def test_is_true_session_start_yes(): + lines = ["Claude Code v1.0", "Some content", "More content", "", "", ""] + assert smf.is_true_session_start(lines, 0) is True + + +def test_is_true_session_start_no_ctrl_e(): + lines = [ + "Claude Code v1.0", + "Ctrl+E to show 5 previous messages", + "", + "", + "", + "", + ] + assert smf.is_true_session_start(lines, 0) is False + + +def test_is_true_session_start_no_previous_messages(): + lines = [ + "Claude Code v1.0", + "Some text", + "previous messages here", + "", + "", + "", + ] + assert smf.is_true_session_start(lines, 0) is False + + +# ── find_session_boundaries ──────────────────────────────────────────── + + +def test_find_session_boundaries_two_sessions(): + lines = [ + "Claude Code v1.0", + "content 1", + "", + "", + "", + "", + "", + "Claude Code v1.0", + "content 2", + "", + "", + "", + "", + "", + ] + boundaries = smf.find_session_boundaries(lines) + assert boundaries == [0, 7] + + +def test_find_session_boundaries_none(): + lines = ["Just some text", "No sessions here"] + assert smf.find_session_boundaries(lines) == [] + + +def test_find_session_boundaries_context_restore_skipped(): + lines = [ + "Claude Code v1.0", + "content", + "", + "", + "", + "", + "", + "Claude Code v1.0", + "Ctrl+E to show 5 previous messages", + "", + "", + "", + "", + ] + boundaries = smf.find_session_boundaries(lines) + assert len(boundaries) == 1 + + +# ── extract_timestamp ────────────────────────────────────────────────── + + +def test_extract_timestamp_found(): + lines = ["⏺ 2:30 PM Wednesday, March 25, 2026"] + human, iso = smf.extract_timestamp(lines) + assert human == "2026-03-25_230PM" + assert iso == "2026-03-25" + + +def test_extract_timestamp_not_found(): + lines = ["No timestamp here"] + human, iso = smf.extract_timestamp(lines) + assert human is None + assert iso is None + + +def test_extract_timestamp_only_checks_first_50(): + lines = ["filler\n"] * 51 + ["⏺ 1:00 AM Monday, January 01, 2026"] + human, iso = smf.extract_timestamp(lines) + assert human is None + + +# ── extract_subject ──────────────────────────────────────────────────── + + +def test_extract_subject_found(): + lines = ["> How do we handle authentication?"] + subject = smf.extract_subject(lines) + assert "authentication" in subject.lower() + + +def test_extract_subject_skips_commands(): + lines = ["> cd /some/dir", "> git status", "> What is the plan?"] + subject = smf.extract_subject(lines) + assert "plan" in subject.lower() + + +def test_extract_subject_fallback(): + lines = ["No prompts at all", "Just text"] + subject = smf.extract_subject(lines) + assert subject == "session" + + +def test_extract_subject_short_prompt_skipped(): + lines = ["> ok", "> yes", "> What about the deployment strategy?"] + subject = smf.extract_subject(lines) + assert "deployment" in subject.lower() + + +def test_extract_subject_truncated(): + lines = ["> " + "a" * 100] + subject = smf.extract_subject(lines) + assert len(subject) <= 60 + + +# ── split_file ───────────────────────────────────────────────────────── + + +def _make_mega_file(tmp_path, n_sessions=3, lines_per_session=15): + """Create a mega-file with N sessions.""" + content = "" + for i in range(n_sessions): + content += f"Claude Code v1.{i}\n" + content += f"> What about topic {i} and how it works?\n" + for j in range(lines_per_session - 2): + content += f"Line {j} of session {i}\n" + path = tmp_path / "mega.txt" + path.write_text(content) + return path + + +def test_split_file_creates_output(tmp_path): + mega = _make_mega_file(tmp_path) + out_dir = tmp_path / "output" + out_dir.mkdir() + written = smf.split_file(str(mega), str(out_dir)) + assert len(written) >= 2 + for p in written: + assert p.exists() + + +def test_split_file_dry_run(tmp_path): + mega = _make_mega_file(tmp_path) + out_dir = tmp_path / "output" + out_dir.mkdir() + written = smf.split_file(str(mega), str(out_dir), dry_run=True) + assert len(written) >= 2 + for p in written: + assert not p.exists() + + +def test_split_file_not_mega(tmp_path): + """File with fewer than 2 sessions is not split.""" + path = tmp_path / "single.txt" + path.write_text("Claude Code v1.0\nJust one session\n" + "line\n" * 20) + written = smf.split_file(str(path), str(tmp_path)) + assert written == [] + + +def test_split_file_output_dir_none(tmp_path): + """When output_dir is None, writes to same dir as source.""" + mega = _make_mega_file(tmp_path) + written = smf.split_file(str(mega), None) + assert len(written) >= 2 + for p in written: + assert str(p.parent) == str(tmp_path) + + +def test_split_file_tiny_fragments_skipped(tmp_path): + """Tiny chunks (< 10 lines) are skipped.""" + content = "Claude Code v1.0\nline\n" * 2 + "Claude Code v1.0\n" + "line\n" * 20 + path = tmp_path / "tiny.txt" + path.write_text(content) + written = smf.split_file(str(path), str(tmp_path)) + # The first chunk is very small, should be skipped + for p in written: + assert p.stat().st_size > 0