Merge branch 'main' into fix/issue-347-codex-hook-message-counting

This commit is contained in:
Ben Sigman
2026-04-10 09:23:37 -07:00
committed by GitHub
48 changed files with 6034 additions and 231 deletions
+13
View File
@@ -0,0 +1,13 @@
# Default owners for everything
* @milla-jovovich @bensig @igorls
# Core library
mempalace/ @milla-jovovich @bensig
# CI and workflows
.github/ @bensig
# Plugins and integrations
.claude-plugin/ @bensig
.codex-plugin/ @bensig
integrations/ @bensig
+12
View File
@@ -0,0 +1,12 @@
version: 2
updates:
- package-ecosystem: "pip"
directory: "/"
schedule:
interval: "weekly"
open-pull-requests-limit: 5
- package-ecosystem: "github-actions"
directory: "/"
schedule:
interval: "weekly"
open-pull-requests-limit: 3
+22 -3
View File
@@ -7,7 +7,7 @@ on:
branches: [main]
jobs:
test:
test-linux:
runs-on: ubuntu-latest
strategy:
matrix:
@@ -18,8 +18,27 @@ jobs:
with:
python-version: ${{ matrix.python-version }}
- run: pip install -e ".[dev]"
- run: python -m pytest tests/ -v --ignore=tests/benchmarks --cov=mempalace --cov-report=term-missing --cov-fail-under=30
- run: python -m pytest tests/ -v --ignore=tests/benchmarks --cov=mempalace --cov-report=term-missing --cov-fail-under=80 --durations=10
test-windows:
runs-on: windows-latest
steps:
- uses: actions/checkout@v6
- uses: actions/setup-python@v6
with:
python-version: "3.9"
- run: pip install -e ".[dev]"
- run: python -m pytest tests/ -v --ignore=tests/benchmarks --cov=mempalace --cov-report=term-missing --cov-fail-under=80 --durations=10
test-macos:
runs-on: macos-latest
steps:
- uses: actions/checkout@v6
- uses: actions/setup-python@v6
with:
python-version: "3.9"
- run: pip install -e ".[dev]"
- run: python -m pytest tests/ -v --ignore=tests/benchmarks --cov=mempalace --cov-report=term-missing --cov-fail-under=80 --durations=10
lint:
runs-on: ubuntu-latest
steps:
@@ -27,6 +46,6 @@ jobs:
- uses: actions/setup-python@v6
with:
python-version: "3.11"
- run: pip install ruff
- run: pip install "ruff>=0.4.0,<0.5"
- run: ruff check .
- run: ruff format --check .
+27
View File
@@ -6,3 +6,30 @@ __pycache__/
.pytest_cache/
mempal.yaml
.a5c/
# Environment
.env
.env.*
# OS
.DS_Store
Thumbs.db
# IDEs
.idea/
.vscode/
*.swp
*.swo
*~
# Coverage
htmlcov/
.coverage
coverage.xml
# Virtual environments
.venv/
venv/
# ChromaDB local data
*.sqlite3-journal
+78
View File
@@ -0,0 +1,78 @@
# AGENTS.md
> How to build, test, and contribute to MemPalace.
## Setup
```bash
pip install -e ".[dev]"
```
## Commands
```bash
# Run tests
python -m pytest tests/ -v --ignore=tests/benchmarks
# Run tests with coverage
python -m pytest tests/ -v --ignore=tests/benchmarks --cov=mempalace --cov-report=term-missing
# Lint
ruff check .
# Format
ruff format .
# Format check (CI mode)
ruff format --check .
```
## Project structure
```
mempalace/
├── mcp_server.py # MCP server — all read/write tools
├── miner.py # Project file miner
├── convo_miner.py # Conversation transcript miner
├── searcher.py # Semantic search
├── knowledge_graph.py # Temporal entity-relationship graph (SQLite)
├── palace.py # Shared palace operations (ChromaDB access)
├── config.py # Configuration + input validation
├── normalize.py # Transcript format detection + normalization
├── cli.py # CLI dispatcher
├── dialect.py # AAAK compression dialect
├── palace_graph.py # Room traversal + cross-wing tunnels
├── hooks_cli.py # Hook system for auto-save
└── version.py # Single source of truth for version
```
## Conventions
- **Python style**: snake_case for functions/variables, PascalCase for classes
- **Linter**: ruff with E/F/W rules
- **Formatter**: ruff format, double quotes
- **Commits**: conventional commits (`fix:`, `feat:`, `test:`, `docs:`, `ci:`)
- **Tests**: `tests/test_*.py`, fixtures in `tests/conftest.py`
- **Coverage**: 85% threshold (80% on Windows due to ChromaDB file lock cleanup)
## Architecture
```
User → CLI / MCP Server → ChromaDB (vector store) + SQLite (knowledge graph)
Palace structure:
WING (person/project)
└── ROOM (topic)
└── DRAWER (verbatim text chunk)
Knowledge Graph:
ENTITY → PREDICATE → ENTITY (with valid_from / valid_to dates)
```
## Key files for common tasks
- **Adding an MCP tool**: `mempalace/mcp_server.py` — add handler function + TOOLS dict entry
- **Changing search**: `mempalace/searcher.py`
- **Modifying mining**: `mempalace/miner.py` (project files) or `mempalace/convo_miner.py` (transcripts)
- **Input validation**: `mempalace/config.py` — `sanitize_name()` / `sanitize_content()`
- **Tests**: mirror source structure in `tests/test_<module>.py`
+4 -1
View File
@@ -585,6 +585,9 @@ mempalace compress --wing myapp # AAAK compress
# Status
mempalace status # palace overview
# MCP
mempalace mcp # show MCP setup command
```
All commands accept `--palace <path>` to override the default location.
@@ -707,7 +710,7 @@ PRs welcome. See [CONTRIBUTING.md](CONTRIBUTING.md) for setup and guidelines.
MIT — see [LICENSE](LICENSE).
<!-- Link Definitions -->
[version-shield]: https://img.shields.io/badge/version-3.0.0-4dc9f6?style=flat-square&labelColor=0a0e14
[version-shield]: https://img.shields.io/badge/version-3.1.0-4dc9f6?style=flat-square&labelColor=0a0e14
[release-link]: https://github.com/milla-jovovich/mempalace/releases
[python-shield]: https://img.shields.io/badge/python-3.9+-7dd8f8?style=flat-square&labelColor=0a0e14&logo=python&logoColor=7dd8f8
[python-link]: https://www.python.org/
+36
View File
@@ -0,0 +1,36 @@
-- MemPalace Knowledge Graph Schema
-- SQLite database at ~/.mempalace/knowledge_graph.db
CREATE TABLE IF NOT EXISTS entities (
id TEXT PRIMARY KEY,
name TEXT NOT NULL,
type TEXT DEFAULT 'unknown',
properties TEXT DEFAULT '{}'
);
CREATE TABLE IF NOT EXISTS triples (
id TEXT PRIMARY KEY,
subject TEXT NOT NULL,
predicate TEXT NOT NULL,
object TEXT NOT NULL,
valid_from TEXT,
valid_to TEXT,
confidence REAL DEFAULT 1.0,
source_closet TEXT,
source_file TEXT
);
CREATE TABLE IF NOT EXISTS attributes (
entity_id TEXT NOT NULL,
key TEXT NOT NULL,
value TEXT,
valid_from TEXT,
valid_to TEXT,
PRIMARY KEY (entity_id, key, valid_from)
);
-- Indexes
CREATE INDEX IF NOT EXISTS idx_triples_subject ON triples(subject);
CREATE INDEX IF NOT EXISTS idx_triples_object ON triples(object);
CREATE INDEX IF NOT EXISTS idx_triples_predicate ON triples(predicate);
CREATE INDEX IF NOT EXISTS idx_triples_valid ON triples(valid_from, valid_to);
+15 -8
View File
@@ -64,13 +64,20 @@ MEMPAL_DIR=""
# Read JSON input from stdin
INPUT=$(cat)
# Parse fields from Claude Code's JSON
SESSION_ID=$(echo "$INPUT" | python3 -c "import sys,json; print(json.load(sys.stdin).get('session_id','unknown'))" 2>/dev/null)
# Sanitize SESSION_ID to prevent path traversal (only allow alnum, dash, underscore)
SESSION_ID=$(echo "$SESSION_ID" | tr -cd 'a-zA-Z0-9_-')
[ -z "$SESSION_ID" ] && SESSION_ID="unknown"
STOP_HOOK_ACTIVE=$(echo "$INPUT" | python3 -c "import sys,json; print(json.load(sys.stdin).get('stop_hook_active', False))" 2>/dev/null)
TRANSCRIPT_PATH=$(echo "$INPUT" | python3 -c "import sys,json; print(json.load(sys.stdin).get('transcript_path',''))" 2>/dev/null)
# Parse all fields in a single Python call (3x faster than separate invocations).
# SECURITY: the Python output is passed to eval, so every value is reduced to a
# fixed literal or a strict character allow-list before it is printed. The
# command substitution is quoted so the output is evaluated line by line.
eval "$(echo "$INPUT" | python3 -c "
import json, re, sys
try:
    data = json.load(sys.stdin)
except ValueError:
    data = {}
if not isinstance(data, dict):
    data = {}
# Session id: alphanumerics, underscore and hyphen only (no slash or dot, so it
# cannot be used for path traversal); fall back to unknown when nothing is left
sid = re.sub(r'[^a-zA-Z0-9_-]', '', str(data.get('session_id') or '')) or 'unknown'
# Collapse the flag to a fixed literal so arbitrary JSON content never reaches eval
sha = 'True' if data.get('stop_hook_active') in (True, 'True', 'true') else 'False'
# Transcript path: additionally allow slash, dot and tilde.
# NOTE(review): any other character (for example a space) is dropped, so such
# paths will not resolve to the real file
tp = re.sub(r'[^a-zA-Z0-9_/.\-~]', '', str(data.get('transcript_path') or ''))
print(f'SESSION_ID=\"{sid}\"')
print(f'STOP_HOOK_ACTIVE=\"{sha}\"')
print(f'TRANSCRIPT_PATH=\"{tp}\"')
" 2>/dev/null)"
# Expand ~ in path
TRANSCRIPT_PATH="${TRANSCRIPT_PATH/#\~/$HOME}"
@@ -83,6 +90,7 @@ if [ "$STOP_HOOK_ACTIVE" = "True" ] || [ "$STOP_HOOK_ACTIVE" = "true" ]; then
fi
# Count human messages in the JSONL transcript
# SECURITY: Pass transcript path as sys.argv to avoid shell injection via crafted paths
if [ -f "$TRANSCRIPT_PATH" ]; then
EXCHANGE_COUNT=$(python3 - "$TRANSCRIPT_PATH" <<'PYEOF'
import json, sys
@@ -94,7 +102,6 @@ with open(sys.argv[1]) as f:
msg = entry.get('message', {})
if isinstance(msg, dict) and msg.get('role') == 'user':
content = msg.get('content', '')
# Skip system/command messages — only count real human input
if isinstance(content, str) and '<command-message>' in content:
continue
count += 1
+154
View File
@@ -0,0 +1,154 @@
---
name: mempalace
description: "MemPalace — Local AI memory with 96.6% recall. Semantic search, temporal knowledge graph, palace architecture (wings/rooms/drawers). Free, no cloud, no API keys."
version: 3.1.0
homepage: https://github.com/milla-jovovich/mempalace
user-invocable: true
metadata:
openclaw:
emoji: "\U0001F3DB"
os:
- darwin
- linux
- win32
requires:
anyBins:
- mempalace
- python3
install:
- id: mempalace-pip
kind: uv
label: "Install MemPalace (Python, local ChromaDB)"
package: mempalace
bins:
- mempalace
---
# MemPalace — Local AI Memory System
You have access to a local memory palace via MCP tools. The palace stores verbatim conversation history and a temporal knowledge graph — all on the user's machine, zero cloud, zero API calls.
## Architecture
- **Wings** = people or projects (e.g. `wing_alice`, `wing_myproject`)
- **Halls** = categories (facts, events, preferences, advice)
- **Rooms** = specific topics (e.g. `chromadb-setup`, `riley-school`)
- **Drawers** = individual memory chunks (verbatim text)
- **Knowledge Graph** = entity-relationship facts with time validity
## Protocol — FOLLOW THIS EVERY SESSION
1. **ON WAKE-UP**: Call `mempalace_status` to load palace overview and AAAK dialect spec.
2. **BEFORE RESPONDING** about any person, project, or past event: call `mempalace_search` or `mempalace_kg_query` FIRST. Never guess from memory — verify from the palace.
3. **IF UNSURE** about a fact (name, age, relationship, preference): say "let me check" and query. Wrong is worse than slow.
4. **AFTER EACH SESSION**: Call `mempalace_diary_write` to record what happened, what you learned, what matters.
5. **WHEN FACTS CHANGE**: Call `mempalace_kg_invalidate` on the old fact, then `mempalace_kg_add` for the new one.
## Available Tools
### Search & Browse
- `mempalace_search` — Semantic search across all memories. Always start here.
- `query` (required): natural language search — keep it short, keywords or a question. Do NOT include system prompts or conversation context.
- `wing`: filter by wing
- `room`: filter by room
- `limit`: max results (default 5)
- `mempalace_check_duplicate` — Check if content already exists before filing.
- `content` (required): text to check
- `threshold`: similarity threshold (default 0.9 — lowering to 0.85–0.87 often catches more near-duplicates without significant false positives)
- `mempalace_status` — Palace overview: total drawers, wings, rooms, AAAK spec
- `mempalace_list_wings` — All wings with drawer counts
- `mempalace_list_rooms` — Rooms within a wing (optional wing filter)
- `mempalace_get_taxonomy` — Full wing/room/count tree
- `mempalace_get_aaak_spec` — Get AAAK compression dialect specification
### Knowledge Graph (Temporal Facts)
- `mempalace_kg_query` — Query entity relationships. Supports time filtering.
- `entity` (required): e.g. "Max", "MyProject"
- `as_of`: date filter (YYYY-MM-DD) — what was true at that time
- `direction`: "outgoing", "incoming", or "both" (default "both")
- `mempalace_kg_add` — Add a fact: subject -> predicate -> object
- `subject`, `predicate`, `object` (required)
- `valid_from`: when this became true
- `source_closet`: source reference
- `mempalace_kg_invalidate` — Mark a fact as no longer true
- `subject`, `predicate`, `object` (required)
- `ended`: when it stopped being true (default: today)
- `mempalace_kg_timeline` — Chronological story of an entity
- `entity`: filter by entity name (optional — all events if omitted)
- `mempalace_kg_stats` — Graph overview: entities, triples, relationship types
### Palace Graph (Cross-Domain Connections)
- `mempalace_traverse` — Walk from a room, find connected ideas across wings
- `start_room` (required): room to start from
- `max_hops`: connection depth (default 2)
- `mempalace_find_tunnels` — Find rooms that bridge two wings
- `wing_a`, `wing_b` (required)
- `mempalace_graph_stats` — Graph connectivity overview
### Write
- `mempalace_add_drawer` — Store verbatim content into a wing/room
- `wing`, `room`, `content` (required)
- `source_file`: optional source reference
- Checks for duplicates automatically
- `mempalace_delete_drawer` — Remove a drawer by ID
- `drawer_id` (required)
- `mempalace_diary_write` — Write a session diary entry
- `agent_name` (required): your name/identifier
- `entry` (required): what happened, what you learned, what matters
- `topic`: category tag (default "general")
- `mempalace_diary_read` — Read recent diary entries
- `agent_name` (required)
- `last_n`: number of entries (default 10)
## Setup
Install MemPalace and populate the palace:
```bash
pip install mempalace
mempalace init ~/my-convos
mempalace mine ~/my-convos
```
### OpenClaw MCP config
Add to your OpenClaw MCP configuration:
```json
{
"mcpServers": {
"mempalace": {
"command": "python3",
"args": ["-m", "mempalace.mcp_server"]
}
}
}
```
Or via CLI:
```bash
openclaw mcp set mempalace '{"command":"python3","args":["-m","mempalace.mcp_server"]}'
```
### Other MCP hosts
```bash
# Claude Code
claude mcp add mempalace -- python -m mempalace.mcp_server
# Cursor — add to .cursor/mcp.json
# Codex — add to .codex/mcp.json
```
## Tips
- Search is semantic (meaning-based), not keyword. "What did we discuss about database performance?" works better than "database".
- The knowledge graph stores typed relationships with time windows. Use it for facts about people and projects — it knows WHEN things were true.
- Diary entries accumulate across sessions. Write one at the end of each conversation to build continuity.
- Use `mempalace_check_duplicate` before storing new content to avoid duplicates.
- The AAAK dialect (from `mempalace_status`) is a compressed notation for efficient storage. Read it naturally — expand codes mentally, treat *markers* as emotional context.
## License
[MemPalace](https://github.com/milla-jovovich/mempalace) is MIT licensed. Created by Milla Jovovich, Ben Sigman, Igor Lins e Silva, and contributors.
+51
View File
@@ -14,6 +14,7 @@ Commands:
mempalace mine <dir> Mine project files (default)
mempalace mine <dir> --mode convos Mine conversation exports
mempalace search "query" Find anything, exact words
mempalace mcp Show MCP setup command
mempalace wake-up Show L0 + L1 wake-up context
mempalace wake-up --wing my_app Wake-up for a specific project
mempalace status Show what's been filed
@@ -28,6 +29,7 @@ Examples:
import os
import sys
import shlex
import argparse
from pathlib import Path
@@ -148,6 +150,14 @@ def cmd_split(args):
sys.argv = old_argv
def cmd_migrate(args):
    """Run the palace migration for the selected palace directory.

    The target is ``args.palace`` (with ``~`` expanded) when that option is
    set, otherwise the palace path reported by ``MempalaceConfig``.
    ``args.dry_run`` is passed straight through to ``migrate``.
    """
    from .migrate import migrate

    if args.palace:
        target = os.path.expanduser(args.palace)
    else:
        target = MempalaceConfig().palace_path
    migrate(palace_path=target, dry_run=args.dry_run)
def cmd_status(args):
from .miner import status
@@ -202,6 +212,7 @@ def cmd_repair(args):
print(f" Extracted {len(all_ids)} drawers")
# Backup and rebuild
palace_path = palace_path.rstrip(os.sep)
backup_path = palace_path + ".backup"
if os.path.exists(backup_path):
shutil.rmtree(backup_path)
@@ -240,6 +251,27 @@ def cmd_instructions(args):
run_instructions(name=args.name)
def cmd_mcp(args):
    """Print the commands that connect MemPalace to an MCP-capable host.

    When ``args.palace`` is set, the path is ``~``-expanded, shell-quoted and
    appended to the server command as ``--palace``. Otherwise the default
    command is shown, followed by a hint on how to pass a custom palace.
    """
    default_cmd = "python -m mempalace.mcp_server"
    launch_cmd = default_cmd
    if args.palace:
        palace_dir = str(Path(args.palace).expanduser())
        # Quote the path so the printed command can be pasted into a shell as-is
        launch_cmd = f"{default_cmd} --palace {shlex.quote(palace_dir)}"

    print("MemPalace MCP quick setup:")
    print(f" claude mcp add mempalace -- {launch_cmd}")
    print("\nRun the server directly:")
    print(f" {launch_cmd}")

    if args.palace:
        return

    # No explicit palace given: show how one could be supplied
    print("\nOptional custom palace:")
    print(f" claude mcp add mempalace -- {default_cmd} --palace /path/to/palace")
    print(f" {default_cmd} --palace /path/to/palace")
def cmd_compress(args):
"""Compress drawers in a wing using AAAK Dialect."""
import chromadb
@@ -500,7 +532,24 @@ def main():
help="Rebuild palace vector index from stored data (fixes segfaults after corruption)",
)
# mcp
sub.add_parser(
"mcp",
help="Show MCP setup command for connecting MemPalace to your AI client",
)
# status
# migrate
p_migrate = sub.add_parser(
"migrate",
help="Migrate palace from a different ChromaDB version (fixes 3.0.0 → 3.1.0 upgrade)",
)
p_migrate.add_argument(
"--dry-run",
action="store_true",
help="Show what would be migrated without changing anything",
)
sub.add_parser("status", help="Show what's been filed")
args = parser.parse_args()
@@ -531,9 +580,11 @@ def main():
"mine": cmd_mine,
"split": cmd_split,
"search": cmd_search,
"mcp": cmd_mcp,
"compress": cmd_compress,
"wake-up": cmd_wakeup,
"repair": cmd_repair,
"migrate": cmd_migrate,
"status": cmd_status,
}
dispatch[args.command](args)
+60
View File
@@ -6,8 +6,58 @@ Priority: env vars > config file (~/.mempalace/config.json) > defaults
import json
import os
import re
from pathlib import Path
# ── Input validation ──────────────────────────────────────────────────────────
# Shared sanitizers for wing/room/entity names. Prevents path traversal,
# excessively long strings, and special characters that could cause issues
# in file paths, SQLite, or ChromaDB metadata.
MAX_NAME_LENGTH = 128
_SAFE_NAME_RE = re.compile(r"^[a-zA-Z0-9][a-zA-Z0-9_ .'-]{0,126}[a-zA-Z0-9]?$")


def sanitize_name(value: str, field_name: str = "name") -> str:
    """Check a wing/room/entity name and return it with outer whitespace removed.

    Args:
        value: Candidate name.
        field_name: Label used at the start of every error message.

    Returns:
        The stripped name.

    Raises:
        ValueError: If the value is not a string, is blank, is longer than
            ``MAX_NAME_LENGTH``, contains ``..``, ``/``, ``\\`` or a null
            byte, or does not match ``_SAFE_NAME_RE``.
    """
    cleaned = value.strip() if isinstance(value, str) else ""
    if not cleaned:
        raise ValueError(f"{field_name} must be a non-empty string")

    if len(cleaned) > MAX_NAME_LENGTH:
        raise ValueError(f"{field_name} exceeds maximum length of {MAX_NAME_LENGTH} characters")

    # Path traversal and separators are rejected with their own message
    if any(token in cleaned for token in ("..", "/", "\\")):
        raise ValueError(f"{field_name} contains invalid path characters")

    if "\x00" in cleaned:
        raise ValueError(f"{field_name} contains null bytes")

    # Everything else must fit the allowed character set
    if _SAFE_NAME_RE.match(cleaned) is None:
        raise ValueError(f"{field_name} contains invalid characters")

    return cleaned
def sanitize_content(value: str, max_length: int = 100_000) -> str:
    """Validate drawer/diary content and return it unchanged.

    Args:
        value: Content to check.
        max_length: Maximum number of characters allowed.

    Returns:
        The original ``value`` (not stripped).

    Raises:
        ValueError: If the value is not a string, is blank, is longer than
            ``max_length``, or contains a null byte.
    """
    is_blank = not isinstance(value, str) or value.strip() == ""
    if is_blank:
        raise ValueError("content must be a non-empty string")

    if len(value) > max_length:
        raise ValueError(f"content exceeds maximum length of {max_length} characters")

    if "\x00" in value:
        raise ValueError("content contains null bytes")

    return value
DEFAULT_PALACE_PATH = os.path.expanduser("~/.mempalace/palace")
DEFAULT_COLLECTION_NAME = "mempalace_drawers"
@@ -126,6 +176,11 @@ class MempalaceConfig:
def init(self):
"""Create config directory and write default config.json if it doesn't exist."""
self._config_dir.mkdir(parents=True, exist_ok=True)
# Restrict directory permissions to owner only (Unix)
try:
self._config_dir.chmod(0o700)
except (OSError, NotImplementedError):
pass # Windows doesn't support Unix permissions
if not self._config_file.exists():
default_config = {
"palace_path": DEFAULT_PALACE_PATH,
@@ -135,6 +190,11 @@ class MempalaceConfig:
}
with open(self._config_file, "w") as f:
json.dump(default_config, f, indent=2)
# Restrict config file to owner read/write only
try:
self._config_file.chmod(0o600)
except (OSError, NotImplementedError):
pass
return self._config_file
def save_people_map(self, people_map):
+11 -35
View File
@@ -15,9 +15,8 @@ from pathlib import Path
from datetime import datetime
from collections import defaultdict
import chromadb
from .normalize import normalize
from .palace import SKIP_DIRS, get_collection, file_already_mined
# File types that might contain conversations
@@ -28,22 +27,8 @@ CONVO_EXTENSIONS = {
".jsonl",
}
SKIP_DIRS = {
".git",
"node_modules",
"__pycache__",
".venv",
"venv",
"env",
"dist",
"build",
".next",
".mempalace",
"tool-results",
"memory",
}
MIN_CHUNK_SIZE = 30
MAX_FILE_SIZE = 10 * 1024 * 1024 # 10 MB — skip files larger than this
# =============================================================================
@@ -211,23 +196,6 @@ def detect_convo_room(content: str) -> str:
# =============================================================================
def get_collection(palace_path: str):
os.makedirs(palace_path, exist_ok=True)
client = chromadb.PersistentClient(path=palace_path)
try:
return client.get_collection("mempalace_drawers")
except Exception:
return client.create_collection("mempalace_drawers")
def file_already_mined(collection, source_file: str) -> bool:
try:
results = collection.get(where={"source_file": source_file}, limit=1)
return len(results.get("ids", [])) > 0
except Exception:
return False
# =============================================================================
# SCAN FOR CONVERSATION FILES
# =============================================================================
@@ -244,6 +212,14 @@ def scan_convos(convo_dir: str) -> list:
continue
filepath = Path(root) / filename
if filepath.suffix.lower() in CONVO_EXTENSIONS:
# Skip symlinks and oversized files
if filepath.is_symlink():
continue
try:
if filepath.stat().st_size > MAX_FILE_SIZE:
continue
except OSError:
continue
files.append(filepath)
return files
@@ -356,7 +332,7 @@ def mine_convos(
chunk_room = chunk.get("memory_type", room) if extract_mode == "general" else room
if extract_mode == "general":
room_counts[chunk_room] += 1
drawer_id = f"drawer_{wing}_{chunk_room}_{hashlib.md5((source_file + str(chunk['chunk_index'])).encode(), usedforsecurity=False).hexdigest()[:16]}"
drawer_id = f"drawer_{wing}_{chunk_room}_{hashlib.sha256((source_file + str(chunk['chunk_index'])).encode()).hexdigest()[:24]}"
try:
collection.add(
documents=[chunk["content"]],
+1 -1
View File
@@ -309,7 +309,7 @@ class EntityRegistry:
def save(self):
self._path.parent.mkdir(parents=True, exist_ok=True)
self._path.write_text(json.dumps(self._data, indent=2))
self._path.write_text(json.dumps(self._data, indent=2), encoding="utf-8")
@staticmethod
def _empty() -> dict:
+1 -1
View File
@@ -158,7 +158,7 @@ def hook_stop(data: dict, harness: str):
if since_last >= SAVE_INTERVAL and exchange_count > 0:
# Update last save point
try:
last_save_file.write_text(str(exchange_count))
last_save_file.write_text(str(exchange_count), encoding="utf-8")
except OSError:
pass
+1
View File
@@ -60,6 +60,7 @@ AI memory system. Store everything, find anything. Local, free, no API key.
mempalace compress Compress palace storage
mempalace status Show palace status
mempalace repair Rebuild vector index
mempalace mcp Show MCP setup command
mempalace hook run Run hook logic (for harness integration)
mempalace instructions <name> Output skill instructions
+83 -77
View File
@@ -50,11 +50,14 @@ class KnowledgeGraph:
def __init__(self, db_path: str = None):
self.db_path = db_path or DEFAULT_KG_PATH
Path(self.db_path).parent.mkdir(parents=True, exist_ok=True)
self._connection = None
self._init_db()
def _init_db(self):
conn = self._conn()
conn.executescript("""
PRAGMA journal_mode=WAL;
CREATE TABLE IF NOT EXISTS entities (
id TEXT PRIMARY KEY,
name TEXT NOT NULL,
@@ -84,12 +87,19 @@ class KnowledgeGraph:
CREATE INDEX IF NOT EXISTS idx_triples_valid ON triples(valid_from, valid_to);
""")
conn.commit()
conn.close()
def _conn(self):
conn = sqlite3.connect(self.db_path, timeout=10)
conn.execute("PRAGMA journal_mode=WAL")
return conn
if self._connection is None:
self._connection = sqlite3.connect(self.db_path, timeout=10, check_same_thread=False)
self._connection.execute("PRAGMA journal_mode=WAL")
self._connection.row_factory = sqlite3.Row
return self._connection
def close(self):
    """Close the cached database connection, if one is open, and forget it."""
    active = self._connection
    if active is None:
        return
    active.close()
    self._connection = None
def _entity_id(self, name: str) -> str:
return name.lower().replace(" ", "_").replace("'", "")
@@ -101,12 +111,11 @@ class KnowledgeGraph:
eid = self._entity_id(name)
props = json.dumps(properties or {})
conn = self._conn()
conn.execute(
"INSERT OR REPLACE INTO entities (id, name, type, properties) VALUES (?, ?, ?, ?)",
(eid, name, entity_type, props),
)
conn.commit()
conn.close()
with conn:
conn.execute(
"INSERT OR REPLACE INTO entities (id, name, type, properties) VALUES (?, ?, ?, ?)",
(eid, name, entity_type, props),
)
return eid
def add_triple(
@@ -134,38 +143,38 @@ class KnowledgeGraph:
# Auto-create entities if they don't exist
conn = self._conn()
conn.execute("INSERT OR IGNORE INTO entities (id, name) VALUES (?, ?)", (sub_id, subject))
conn.execute("INSERT OR IGNORE INTO entities (id, name) VALUES (?, ?)", (obj_id, obj))
with conn:
conn.execute(
"INSERT OR IGNORE INTO entities (id, name) VALUES (?, ?)", (sub_id, subject)
)
conn.execute("INSERT OR IGNORE INTO entities (id, name) VALUES (?, ?)", (obj_id, obj))
# Check for existing identical triple
existing = conn.execute(
"SELECT id FROM triples WHERE subject=? AND predicate=? AND object=? AND valid_to IS NULL",
(sub_id, pred, obj_id),
).fetchone()
# Check for existing identical triple
existing = conn.execute(
"SELECT id FROM triples WHERE subject=? AND predicate=? AND object=? AND valid_to IS NULL",
(sub_id, pred, obj_id),
).fetchone()
if existing:
conn.close()
return existing[0] # Already exists and still valid
if existing:
return existing["id"] # Already exists and still valid
triple_id = f"t_{sub_id}_{pred}_{obj_id}_{hashlib.md5(f'{valid_from}{datetime.now().isoformat()}'.encode()).hexdigest()[:8]}"
triple_id = f"t_{sub_id}_{pred}_{obj_id}_{hashlib.sha256(f'{valid_from}{datetime.now().isoformat()}'.encode()).hexdigest()[:12]}"
conn.execute(
"""INSERT INTO triples (id, subject, predicate, object, valid_from, valid_to, confidence, source_closet, source_file)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)""",
(
triple_id,
sub_id,
pred,
obj_id,
valid_from,
valid_to,
confidence,
source_closet,
source_file,
),
)
conn.commit()
conn.close()
conn.execute(
"""INSERT INTO triples (id, subject, predicate, object, valid_from, valid_to, confidence, source_closet, source_file)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)""",
(
triple_id,
sub_id,
pred,
obj_id,
valid_from,
valid_to,
confidence,
source_closet,
source_file,
),
)
return triple_id
def invalidate(self, subject: str, predicate: str, obj: str, ended: str = None):
@@ -176,12 +185,11 @@ class KnowledgeGraph:
ended = ended or date.today().isoformat()
conn = self._conn()
conn.execute(
"UPDATE triples SET valid_to=? WHERE subject=? AND predicate=? AND object=? AND valid_to IS NULL",
(ended, sub_id, pred, obj_id),
)
conn.commit()
conn.close()
with conn:
conn.execute(
"UPDATE triples SET valid_to=? WHERE subject=? AND predicate=? AND object=? AND valid_to IS NULL",
(ended, sub_id, pred, obj_id),
)
# ── Query operations ──────────────────────────────────────────────────
@@ -208,13 +216,13 @@ class KnowledgeGraph:
{
"direction": "outgoing",
"subject": name,
"predicate": row[2],
"object": row[10], # obj_name
"valid_from": row[4],
"valid_to": row[5],
"confidence": row[6],
"source_closet": row[7],
"current": row[5] is None,
"predicate": row["predicate"],
"object": row["obj_name"],
"valid_from": row["valid_from"],
"valid_to": row["valid_to"],
"confidence": row["confidence"],
"source_closet": row["source_closet"],
"current": row["valid_to"] is None,
}
)
@@ -228,18 +236,17 @@ class KnowledgeGraph:
results.append(
{
"direction": "incoming",
"subject": row[10], # sub_name
"predicate": row[2],
"subject": row["sub_name"],
"predicate": row["predicate"],
"object": name,
"valid_from": row[4],
"valid_to": row[5],
"confidence": row[6],
"source_closet": row[7],
"current": row[5] is None,
"valid_from": row["valid_from"],
"valid_to": row["valid_to"],
"confidence": row["confidence"],
"source_closet": row["source_closet"],
"current": row["valid_to"] is None,
}
)
conn.close()
return results
def query_relationship(self, predicate: str, as_of: str = None):
@@ -262,15 +269,14 @@ class KnowledgeGraph:
for row in conn.execute(query, params).fetchall():
results.append(
{
"subject": row[10],
"subject": row["sub_name"],
"predicate": pred,
"object": row[11],
"valid_from": row[4],
"valid_to": row[5],
"current": row[5] is None,
"object": row["obj_name"],
"valid_from": row["valid_from"],
"valid_to": row["valid_to"],
"current": row["valid_to"] is None,
}
)
conn.close()
return results
def timeline(self, entity_name: str = None):
@@ -300,15 +306,14 @@ class KnowledgeGraph:
LIMIT 100
""").fetchall()
conn.close()
return [
{
"subject": r[10],
"predicate": r[2],
"object": r[11],
"valid_from": r[4],
"valid_to": r[5],
"current": r[5] is None,
"subject": r["sub_name"],
"predicate": r["predicate"],
"object": r["obj_name"],
"valid_from": r["valid_from"],
"valid_to": r["valid_to"],
"current": r["valid_to"] is None,
}
for r in rows
]
@@ -317,17 +322,18 @@ class KnowledgeGraph:
def stats(self):
conn = self._conn()
entities = conn.execute("SELECT COUNT(*) FROM entities").fetchone()[0]
triples = conn.execute("SELECT COUNT(*) FROM triples").fetchone()[0]
current = conn.execute("SELECT COUNT(*) FROM triples WHERE valid_to IS NULL").fetchone()[0]
entities = conn.execute("SELECT COUNT(*) as cnt FROM entities").fetchone()["cnt"]
triples = conn.execute("SELECT COUNT(*) as cnt FROM triples").fetchone()["cnt"]
current = conn.execute(
"SELECT COUNT(*) as cnt FROM triples WHERE valid_to IS NULL"
).fetchone()["cnt"]
expired = triples - current
predicates = [
r[0]
r["predicate"]
for r in conn.execute(
"SELECT DISTINCT predicate FROM triples ORDER BY predicate"
).fetchall()
]
conn.close()
return {
"entities": entities,
"triples": triples,
+145 -11
View File
@@ -24,8 +24,9 @@ import json
import logging
import hashlib
from datetime import datetime
from pathlib import Path
from .config import MempalaceConfig
from .config import MempalaceConfig, sanitize_name, sanitize_content
from .version import __version__
from .searcher import search_memories
from .palace_graph import traverse, find_tunnels, graph_stats
@@ -44,7 +45,9 @@ def _parse_args():
metavar="PATH",
help="Path to the palace directory (overrides config file and env var)",
)
args, _ = parser.parse_known_args()
args, unknown = parser.parse_known_args()
if unknown:
logger.debug("Ignoring unknown args: %s", unknown)
return args
@@ -64,16 +67,60 @@ _client_cache = None
_collection_cache = None
# ==================== WRITE-AHEAD LOG ====================
# Every write operation is logged to a JSONL file before execution.
# This provides an audit trail for detecting memory poisoning and
# enables review/rollback of writes from external or untrusted sources.
_WAL_DIR = Path(os.path.expanduser("~/.mempalace/wal"))
_WAL_DIR.mkdir(parents=True, exist_ok=True)
try:
_WAL_DIR.chmod(0o700)
except (OSError, NotImplementedError):
pass
_WAL_FILE = _WAL_DIR / "write_log.jsonl"
def _wal_log(operation: str, params: dict, result: dict = None):
    """Append one write operation to the JSONL write-ahead log.

    Args:
        operation: Name of the write operation being recorded.
        params: Parameters of the operation. Values that are not JSON
            serializable are stringified (``default=str``).
        result: Optional outcome to record alongside the operation.

    Logging is best-effort: any failure is reported through ``logger`` and
    swallowed, so it never propagates to the caller.
    """
    entry = {
        "timestamp": datetime.now().isoformat(),
        "operation": operation,
        "params": params,
        "result": result,
    }
    try:
        line = json.dumps(entry, default=str) + "\n"
        # Create the file owner-only from the start; a plain open() would use
        # the umask default and leave the entry readable until the chmod below.
        fd = os.open(_WAL_FILE, os.O_WRONLY | os.O_APPEND | os.O_CREAT, 0o600)
        with os.fdopen(fd, "a", encoding="utf-8") as f:
            f.write(line)
        # The mode passed to os.open only applies on creation, so also tighten
        # a log file that already existed with looser permissions.
        try:
            _WAL_FILE.chmod(0o600)
        except (OSError, NotImplementedError):
            pass
    except Exception as e:
        logger.error("WAL write failed: %s", e)
_client_cache = None
_collection_cache = None
def _get_client():
    """Return the shared ChromaDB PersistentClient, creating it on first use."""
    global _client_cache
    if _client_cache is not None:
        return _client_cache
    # First call: open the client at the configured palace path and cache it
    _client_cache = chromadb.PersistentClient(path=_config.palace_path)
    return _client_cache
def _get_collection(create=False):
"""Return the ChromaDB collection, caching the client between calls."""
global _client_cache, _collection_cache
global _collection_cache
try:
if _client_cache is None:
_client_cache = chromadb.PersistentClient(path=_config.palace_path)
client = _get_client()
if create:
_collection_cache = _client_cache.get_or_create_collection(_config.collection_name)
_collection_cache = client.get_or_create_collection(_config.collection_name)
elif _collection_cache is None:
_collection_cache = _client_cache.get_collection(_config.collection_name)
_collection_cache = client.get_collection(_config.collection_name)
return _collection_cache
except Exception:
return None
@@ -280,11 +327,30 @@ def tool_add_drawer(
wing: str, room: str, content: str, source_file: str = None, added_by: str = "mcp"
):
"""File verbatim content into a wing/room. Checks for duplicates first."""
try:
wing = sanitize_name(wing, "wing")
room = sanitize_name(room, "room")
content = sanitize_content(content)
except ValueError as e:
return {"success": False, "error": str(e)}
col = _get_collection(create=True)
if not col:
return _no_palace()
drawer_id = f"drawer_{wing}_{room}_{hashlib.md5(content.encode()).hexdigest()[:16]}"
drawer_id = f"drawer_{wing}_{room}_{hashlib.sha256((wing + room + content[:100]).encode()).hexdigest()[:24]}"
_wal_log(
"add_drawer",
{
"drawer_id": drawer_id,
"wing": wing,
"room": room,
"added_by": added_by,
"content_length": len(content),
"content_preview": content[:200],
},
)
# Idempotency: if the deterministic ID already exists, return success as a no-op.
try:
@@ -323,6 +389,19 @@ def tool_delete_drawer(drawer_id: str):
existing = col.get(ids=[drawer_id])
if not existing["ids"]:
return {"success": False, "error": f"Drawer not found: {drawer_id}"}
# Log the deletion with the content being removed for audit trail
deleted_content = existing.get("documents", [""])[0] if existing.get("documents") else ""
deleted_meta = existing.get("metadatas", [{}])[0] if existing.get("metadatas") else {}
_wal_log(
"delete_drawer",
{
"drawer_id": drawer_id,
"deleted_meta": deleted_meta,
"content_preview": deleted_content[:200],
},
)
try:
col.delete(ids=[drawer_id])
logger.info(f"Deleted drawer: {drawer_id}")
@@ -344,6 +423,23 @@ def tool_kg_add(
subject: str, predicate: str, object: str, valid_from: str = None, source_closet: str = None
):
"""Add a relationship to the knowledge graph."""
try:
subject = sanitize_name(subject, "subject")
predicate = sanitize_name(predicate, "predicate")
object = sanitize_name(object, "object")
except ValueError as e:
return {"success": False, "error": str(e)}
_wal_log(
"kg_add",
{
"subject": subject,
"predicate": predicate,
"object": object,
"valid_from": valid_from,
"source_closet": source_closet,
},
)
triple_id = _kg.add_triple(
subject, predicate, object, valid_from=valid_from, source_closet=source_closet
)
@@ -352,6 +448,10 @@ def tool_kg_add(
def tool_kg_invalidate(subject: str, predicate: str, object: str, ended: str = None):
"""Mark a fact as no longer true (set end date)."""
_wal_log(
"kg_invalidate",
{"subject": subject, "predicate": predicate, "object": object, "ended": ended},
)
_kg.invalidate(subject, predicate, object, ended=ended)
return {
"success": True,
@@ -382,6 +482,12 @@ def tool_diary_write(agent_name: str, entry: str, topic: str = "general"):
This is the agent's personal journal — observations, thoughts,
what it worked on, what it noticed, what it thinks matters.
"""
try:
agent_name = sanitize_name(agent_name, "agent_name")
entry = sanitize_content(entry)
except ValueError as e:
return {"success": False, "error": str(e)}
wing = f"wing_{agent_name.lower().replace(' ', '_')}"
room = "diary"
col = _get_collection(create=True)
@@ -389,9 +495,23 @@ def tool_diary_write(agent_name: str, entry: str, topic: str = "general"):
return _no_palace()
now = datetime.now()
entry_id = f"diary_{wing}_{now.strftime('%Y%m%d_%H%M%S')}_{hashlib.md5(entry[:50].encode()).hexdigest()[:8]}"
entry_id = f"diary_{wing}_{now.strftime('%Y%m%d_%H%M%S')}_{hashlib.sha256(entry[:50].encode()).hexdigest()[:12]}"
_wal_log(
"diary_write",
{
"agent_name": agent_name,
"topic": topic,
"entry_id": entry_id,
"entry_preview": entry[:200],
},
)
try:
# TODO: Future versions should expand AAAK before embedding to improve
# semantic search quality. For now, store raw AAAK in metadata so it's
# preserved, and keep the document as-is for embedding (even though
# compressed AAAK degrades embedding quality).
col.add(
ids=[entry_id],
documents=[entry],
@@ -717,17 +837,31 @@ TOOLS = {
}
SUPPORTED_PROTOCOL_VERSIONS = [
"2025-11-25",
"2025-06-18",
"2025-03-26",
"2024-11-05",
]
def handle_request(request):
method = request.get("method", "")
params = request.get("params", {})
req_id = request.get("id")
if method == "initialize":
client_version = params.get("protocolVersion", SUPPORTED_PROTOCOL_VERSIONS[-1])
negotiated = (
client_version
if client_version in SUPPORTED_PROTOCOL_VERSIONS
else SUPPORTED_PROTOCOL_VERSIONS[0]
)
return {
"jsonrpc": "2.0",
"id": req_id,
"result": {
"protocolVersion": "2024-11-05",
"protocolVersion": negotiated,
"capabilities": {"tools": {}},
"serverInfo": {"name": "mempalace", "version": __version__},
},
@@ -747,7 +881,7 @@ def handle_request(request):
}
elif method == "tools/call":
tool_name = params.get("name")
tool_args = params.get("arguments", {})
tool_args = params.get("arguments") or {}
if tool_name not in TOOLS:
return {
"jsonrpc": "2.0",
+214
View File
@@ -0,0 +1,214 @@
#!/usr/bin/env python3
"""
mempalace migrate — Recover a palace created with a different ChromaDB version.
Reads documents and metadata directly from the palace's SQLite database
(bypassing ChromaDB's API, which fails on version-mismatched palaces),
then re-imports everything into a fresh palace using the currently installed
ChromaDB version.
This fixes the 3.0.0 → 3.1.0 upgrade path where chromadb was downgraded
from 1.5.x to 0.6.x, breaking the on-disk storage format.
Usage:
mempalace migrate # migrate default palace
mempalace migrate --palace /path/to/palace # migrate specific palace
mempalace migrate --dry-run # show what would be migrated
"""
import os
import shutil
import sqlite3
from collections import defaultdict
from datetime import datetime
def extract_drawers_from_sqlite(db_path: str) -> list:
    """Read all drawers directly from ChromaDB's SQLite file, bypassing the API.

    Args:
        db_path: Path to the palace's ``chroma.sqlite3`` file.

    Returns:
        A list of dicts with ``id``, ``document`` and ``metadata`` keys.
        Embeddings without a ``chroma:document`` value are skipped, and
        ``chroma:``-prefixed keys are excluded from ``metadata``.

    Raises:
        sqlite3.Error: If the expected tables are missing or unreadable.
            The connection is closed before the error propagates.
    """
    conn = sqlite3.connect(db_path)
    conn.row_factory = sqlite3.Row
    try:
        # Get all embedding IDs and their documents
        rows = conn.execute("""
            SELECT e.embedding_id,
                   MAX(CASE WHEN em.key = 'chroma:document' THEN em.string_value END) as document
            FROM embeddings e
            JOIN embedding_metadata em ON em.id = e.id
            GROUP BY e.embedding_id
        """).fetchall()

        drawers = []
        for row in rows:
            embedding_id = row["embedding_id"]
            document = row["document"]
            if not document:
                continue

            # Get metadata for this embedding
            meta_rows = conn.execute(
                """
                SELECT em.key, em.string_value, em.int_value, em.float_value, em.bool_value
                FROM embedding_metadata em
                JOIN embeddings e ON e.id = em.id
                WHERE e.embedding_id = ?
                AND em.key NOT LIKE 'chroma:%'
                """,
                (embedding_id,),
            ).fetchall()

            metadata = {}
            for mr in meta_rows:
                key = mr["key"]
                # Each row carries its value in exactly one typed column;
                # take the first non-NULL one in string/int/float/bool order.
                if mr["string_value"] is not None:
                    metadata[key] = mr["string_value"]
                elif mr["int_value"] is not None:
                    metadata[key] = mr["int_value"]
                elif mr["float_value"] is not None:
                    metadata[key] = mr["float_value"]
                elif mr["bool_value"] is not None:
                    metadata[key] = bool(mr["bool_value"])

            drawers.append(
                {
                    "id": embedding_id,
                    "document": document,
                    "metadata": metadata,
                }
            )
        return drawers
    finally:
        # Always release the file handle, even when a query fails on an
        # unexpected schema.
        conn.close()
def detect_chromadb_version(db_path: str) -> str:
    """Guess which ChromaDB line wrote the database by inspecting its schema.

    Returns "1.x" when the ``collections`` table has a ``schema_str`` column,
    "0.6.x" when it does not but an ``embeddings_queue`` table exists, and
    "unknown" otherwise.
    """
    conn = sqlite3.connect(db_path)
    try:
        # PRAGMA table_info rows are (cid, name, type, ...); index 1 is the name.
        column_names = {info[1] for info in conn.execute("PRAGMA table_info(collections)")}
        if "schema_str" in column_names:
            return "1.x"

        table_names = {
            record[0]
            for record in conn.execute("SELECT name FROM sqlite_master WHERE type='table'")
        }
        return "0.6.x" if "embeddings_queue" in table_names else "unknown"
    finally:
        conn.close()
def migrate(palace_path: str, dry_run: bool = False):
    """Migrate a palace to the currently installed ChromaDB version.

    If the installed chromadb can already open the palace, nothing is changed.
    Otherwise every drawer is read straight from the SQLite file, the old
    palace directory is copied to a timestamped backup, a fresh palace is
    built in a temporary directory, and that directory replaces the old one.

    Args:
        palace_path: Palace directory (``~`` is expanded); it must contain
            ``chroma.sqlite3``.
        dry_run: When True, print the summary of what would be migrated and
            stop before the backup, import and swap steps.

    Returns:
        False when no palace database exists at ``palace_path``; True in all
        other non-raising cases (already readable, nothing to migrate, dry
        run, or migration performed).

    Raises:
        Exception: Errors from the backup copy, the re-import, or the final
            swap propagate to the caller. If the re-import fails, the
            temporary palace directory is removed before re-raising.
    """
    palace_path = os.path.expanduser(palace_path)
    db_path = os.path.join(palace_path, "chroma.sqlite3")

    if not os.path.isfile(db_path):
        print(f"\n No palace database found at {db_path}")
        return False

    # Imported only after the cheap path check so a wrong path is reported
    # without loading chromadb.
    import tempfile

    import chromadb

    print(f"\n{'=' * 60}")
    print(" MemPalace Migrate")
    print(f"{'=' * 60}\n")
    print(f" Palace: {palace_path}")
    print(f" Database: {db_path}")
    print(f" DB size: {os.path.getsize(db_path) / 1024 / 1024:.1f} MB")

    # Detect version
    source_version = detect_chromadb_version(db_path)
    print(f" Source: ChromaDB {source_version}")
    print(f" Target: ChromaDB {chromadb.__version__}")

    # Try reading with current chromadb first
    try:
        client = chromadb.PersistentClient(path=palace_path)
        col = client.get_collection("mempalace_drawers")
        count = col.count()
        print(f"\n Palace is already readable by chromadb {chromadb.__version__}.")
        print(f" {count} drawers found. No migration needed.")
        return True
    except Exception:
        # Any failure here is treated as "not readable" and triggers the
        # raw-SQLite recovery path below.
        print(f"\n Palace is NOT readable by chromadb {chromadb.__version__}.")
        print(" Extracting from SQLite directly...")

    # Extract all drawers via raw SQL
    drawers = extract_drawers_from_sqlite(db_path)
    print(f" Extracted {len(drawers)} drawers from SQLite")

    if not drawers:
        print(" Nothing to migrate.")
        return True

    # Show summary
    wings = defaultdict(lambda: defaultdict(int))
    for d in drawers:
        w = d["metadata"].get("wing", "?")
        r = d["metadata"].get("room", "?")
        wings[w][r] += 1

    print("\n Summary:")
    for wing, rooms in sorted(wings.items()):
        total = sum(rooms.values())
        print(f" WING: {wing} ({total} drawers)")
        for room, count in sorted(rooms.items(), key=lambda x: -x[1]):
            print(f" ROOM: {room:30} {count:5}")

    if dry_run:
        print("\n DRY RUN — no changes made.")
        print(f" Would migrate {len(drawers)} drawers.")
        return True

    # Backup the old palace
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    backup_path = f"{palace_path}.pre-migrate.{timestamp}"
    print(f"\n Backing up to {backup_path}...")
    shutil.copytree(palace_path, backup_path)

    # Build fresh palace in a temp directory (avoids chromadb reading old state)
    temp_palace = tempfile.mkdtemp(prefix="mempalace_migrate_")
    print(f" Creating fresh palace in {temp_palace}...")
    try:
        client = chromadb.PersistentClient(path=temp_palace)
        col = client.get_or_create_collection("mempalace_drawers")

        # Re-import in batches
        batch_size = 500
        imported = 0
        for i in range(0, len(drawers), batch_size):
            batch = drawers[i : i + batch_size]
            col.add(
                ids=[d["id"] for d in batch],
                documents=[d["document"] for d in batch],
                metadatas=[d["metadata"] for d in batch],
            )
            imported += len(batch)
            print(f" Imported {imported}/{len(drawers)} drawers...")

        # Record the imported count, then drop the references to the new
        # palace before its directory is moved.
        final_count = col.count()
        del col
        del client
    except BaseException:
        # Do not leave a half-built palace behind in the temp directory; the
        # original palace and its backup are still untouched at this point.
        shutil.rmtree(temp_palace, ignore_errors=True)
        raise

    # Swap: remove old palace, move new one into place
    print(" Swapping old palace for migrated version...")
    shutil.rmtree(palace_path)
    shutil.move(temp_palace, palace_path)

    print("\n Migration complete.")
    print(f" Drawers migrated: {final_count}")
    print(f" Backup at: {backup_path}")
    if final_count != len(drawers):
        print(f" WARNING: Expected {len(drawers)}, got {final_count}")
    print(f"\n{'=' * 60}\n")
    return True
+14 -58
View File
@@ -17,6 +17,8 @@ from collections import defaultdict
import chromadb
from .palace import SKIP_DIRS, get_collection, file_already_mined
READABLE_EXTENSIONS = {
".txt",
".md",
@@ -40,32 +42,6 @@ READABLE_EXTENSIONS = {
".toml",
}
SKIP_DIRS = {
".git",
"node_modules",
"__pycache__",
".venv",
"venv",
"env",
"dist",
"build",
".next",
"coverage",
".mempalace",
".ruff_cache",
".mypy_cache",
".pytest_cache",
".cache",
".tox",
".nox",
".idea",
".vscode",
".ipynb_checkpoints",
".eggs",
"htmlcov",
"target",
}
SKIP_FILENAMES = {
"mempalace.yaml",
"mempalace.yml",
@@ -78,6 +54,7 @@ SKIP_FILENAMES = {
CHUNK_SIZE = 800 # chars per drawer
CHUNK_OVERLAP = 100 # overlap between chunks
MIN_CHUNK_SIZE = 50 # skip tiny chunks
MAX_FILE_SIZE = 10 * 1024 * 1024 # 10 MB — skip files larger than this
# =============================================================================
@@ -393,41 +370,11 @@ def chunk_text(content: str, source_file: str) -> list:
# =============================================================================
def get_collection(palace_path: str):
os.makedirs(palace_path, exist_ok=True)
client = chromadb.PersistentClient(path=palace_path)
try:
return client.get_collection("mempalace_drawers")
except Exception:
return client.create_collection("mempalace_drawers")
def file_already_mined(collection, source_file: str) -> bool:
"""Fast check: has this file been filed before and is unchanged?
Compares the stored mtime in drawer metadata against the file's current
mtime. Returns False (needs re-mining) when the file has been modified
since it was last mined, or when no mtime was stored.
"""
try:
results = collection.get(where={"source_file": source_file}, limit=1)
if not results.get("ids"):
return False
stored_meta = results["metadatas"][0] if results.get("metadatas") else {}
stored_mtime = stored_meta.get("source_mtime")
if stored_mtime is None:
return False
current_mtime = os.path.getmtime(source_file)
return float(stored_mtime) == current_mtime
except Exception:
return False
def add_drawer(
collection, wing: str, room: str, content: str, source_file: str, chunk_index: int, agent: str
):
"""Add one drawer to the palace."""
drawer_id = f"drawer_{wing}_{room}_{hashlib.md5((source_file + str(chunk_index)).encode(), usedforsecurity=False).hexdigest()[:16]}"
drawer_id = f"drawer_{wing}_{room}_{hashlib.sha256((source_file + str(chunk_index)).encode()).hexdigest()[:24]}"
try:
metadata = {
"wing": wing,
@@ -470,7 +417,7 @@ def process_file(
# Skip if already filed
source_file = str(filepath)
if not dry_run and file_already_mined(collection, source_file):
if not dry_run and file_already_mined(collection, source_file, check_mtime=True):
return 0, None
try:
@@ -562,6 +509,15 @@ def scan_project(
if respect_gitignore and active_matchers and not force_include:
if is_gitignored(filepath, active_matchers, is_dir=False):
continue
# Skip symlinks — prevents following links to /dev/urandom, etc.
if filepath.is_symlink():
continue
# Skip files exceeding size limit
try:
if filepath.stat().st_size > MAX_FILE_SIZE:
continue
except OSError:
continue
files.append(filepath)
return files
+6
View File
@@ -25,6 +25,12 @@ def normalize(filepath: str) -> str:
Load a file and normalize to transcript format if it's a chat export.
Plain text files pass through unchanged.
"""
try:
file_size = os.path.getsize(filepath)
except OSError as e:
raise IOError(f"Could not read {filepath}: {e}")
if file_size > 500 * 1024 * 1024: # 500 MB safety limit
raise IOError(f"File too large ({file_size // (1024*1024)} MB): {filepath}")
try:
with open(filepath, "r", encoding="utf-8", errors="replace") as f:
content = f.read()
+2 -2
View File
@@ -312,7 +312,7 @@ def _generate_aaak_bootstrap(
]
)
(mempalace_dir / "aaak_entities.md").write_text("\n".join(registry_lines))
(mempalace_dir / "aaak_entities.md").write_text("\n".join(registry_lines), encoding="utf-8")
# Critical facts bootstrap (pre-palace — before any mining)
facts_lines = [
@@ -359,7 +359,7 @@ def _generate_aaak_bootstrap(
]
)
(mempalace_dir / "critical_facts.md").write_text("\n".join(facts_lines))
(mempalace_dir / "critical_facts.md").write_text("\n".join(facts_lines), encoding="utf-8")
def run_onboarding(
+71
View File
@@ -0,0 +1,71 @@
"""
palace.py — Shared palace operations.
Consolidates ChromaDB access patterns used by both miners and the MCP server.
"""
import os
import chromadb
SKIP_DIRS = {
".git",
"node_modules",
"__pycache__",
".venv",
"venv",
"env",
"dist",
"build",
".next",
"coverage",
".mempalace",
".ruff_cache",
".mypy_cache",
".pytest_cache",
".cache",
".tox",
".nox",
".idea",
".vscode",
".ipynb_checkpoints",
".eggs",
"htmlcov",
"target",
}
def get_collection(palace_path: str, collection_name: str = "mempalace_drawers"):
    """Get or create the palace ChromaDB collection.

    Args:
        palace_path: Directory holding the palace; created if missing and
            restricted to the owner (0o700) where the platform allows it.
        collection_name: Name of the collection to open or create.

    Returns:
        The ChromaDB collection object for ``collection_name``.
    """
    os.makedirs(palace_path, exist_ok=True)
    try:
        os.chmod(palace_path, 0o700)
    except (OSError, NotImplementedError):
        # Best-effort hardening; some platforms/filesystems cannot chmod.
        pass
    client = chromadb.PersistentClient(path=palace_path)
    # One call instead of get-then-create-on-any-exception: an unrelated
    # failure is no longer misread as "collection missing".
    return client.get_or_create_collection(collection_name)
def file_already_mined(collection, source_file: str, check_mtime: bool = False) -> bool:
    """Report whether ``source_file`` already has a drawer in ``collection``.

    With ``check_mtime=False`` only existence is checked. With
    ``check_mtime=True`` the ``source_mtime`` stored in the first matching
    drawer's metadata must also equal the file's current mtime; a missing
    stored mtime counts as "not mined" so the file is processed again.

    Any error while querying or comparing yields False.
    """
    try:
        hits = collection.get(where={"source_file": source_file}, limit=1)
        if not hits.get("ids"):
            return False
        if not check_mtime:
            return True

        first_meta = hits.get("metadatas", [{}])[0]
        recorded = first_meta.get("source_mtime")
        return recorded is not None and float(recorded) == os.path.getmtime(source_file)
    except Exception:
        return False
+9 -1
View File
@@ -182,6 +182,10 @@ def split_file(filepath, output_dir, dry_run=False):
Returns list of output paths written (or would be written if dry_run).
"""
path = Path(filepath)
max_size = 500 * 1024 * 1024 # 500 MB safety limit
if path.stat().st_size > max_size:
print(f" SKIP: {path.name} exceeds {max_size // (1024*1024)} MB limit")
return []
lines = path.read_text(errors="replace").splitlines(keepends=True)
boundaries = find_session_boundaries(lines)
@@ -219,7 +223,7 @@ def split_file(filepath, output_dir, dry_run=False):
if dry_run:
print(f" [{i + 1}/{len(boundaries) - 1}] {name} ({len(chunk)} lines)")
else:
out_path.write_text("".join(chunk))
out_path.write_text("".join(chunk), encoding="utf-8")
print(f"{name} ({len(chunk)} lines)")
written.append(out_path)
@@ -266,7 +270,11 @@ def main():
files = sorted(src_dir.glob("*.txt"))
mega_files = []
max_scan_size = 500 * 1024 * 1024 # 500 MB
for f in files:
if f.stat().st_size > max_scan_size:
print(f" SKIP: {f.name} exceeds {max_scan_size // (1024*1024)} MB limit")
continue
lines = f.read_text(errors="replace").splitlines(keepends=True)
boundaries = find_session_boundaries(lines)
if len(boundaries) >= args.min_sessions:
+1 -1
View File
@@ -1,3 +1,3 @@
"""Single source of truth for the MemPalace package version."""
__version__ = "3.0.14"
__version__ = "3.1.0"
+8 -4
View File
@@ -1,6 +1,6 @@
[project]
name = "mempalace"
version = "3.0.14"
version = "3.1.0"
description = "Give your AI a memory — mine projects and conversations into a searchable palace. No API key required."
readme = "README.md"
requires-python = ">=3.9"
@@ -26,7 +26,7 @@ classifiers = [
]
dependencies = [
"chromadb>=0.5.0,<0.7",
"pyyaml>=6.0",
"pyyaml>=6.0,<7",
]
[project.urls]
@@ -54,11 +54,15 @@ packages = ["mempalace"]
[tool.ruff]
line-length = 100
target-version = "py39"
extend-exclude = ["benchmarks"]
[tool.ruff.lint]
select = ["E", "F", "W"]
select = ["E", "F", "W", "C901"]
ignore = ["E501"]
[tool.ruff.lint.mccabe]
max-complexity = 25
[tool.ruff.format]
quote-style = "double"
@@ -76,7 +80,7 @@ markers = [
source = ["mempalace"]
[tool.coverage.report]
fail_under = 30
fail_under = 85
show_missing = true
exclude_lines = [
"if __name__",
+3 -3
View File
@@ -148,9 +148,9 @@ class TestWakeUpTokenBudget:
record_metric("wakeup_budget", f"tokens_at_{n_drawers}", token_estimate)
record_metric("wakeup_budget", f"chars_at_{n_drawers}", len(text))
assert token_estimate < 1200, (
f"Wake-up exceeded budget: ~{token_estimate} tokens at {n_drawers} drawers"
)
assert (
token_estimate < 1200
), f"Wake-up exceeded budget: ~{token_estimate} tokens at {n_drawers} drawers"
@pytest.mark.benchmark
+652
View File
@@ -0,0 +1,652 @@
"""Tests for mempalace.cli — the main CLI dispatcher."""
import argparse
import sys
from pathlib import Path
from unittest.mock import MagicMock, patch
import pytest
from mempalace.cli import (
cmd_compress,
cmd_hook,
cmd_init,
cmd_instructions,
cmd_mine,
cmd_repair,
cmd_search,
cmd_split,
cmd_status,
cmd_wakeup,
main,
)
# ── cmd_status ─────────────────────────────────────────────────────────
@patch("mempalace.cli.MempalaceConfig")
def test_cmd_status_default_palace(mock_config_cls):
mock_config_cls.return_value.palace_path = "/fake/palace"
args = argparse.Namespace(palace=None)
mock_miner = MagicMock()
with patch.dict("sys.modules", {"mempalace.miner": mock_miner}):
cmd_status(args)
mock_miner.status.assert_called_once_with(palace_path="/fake/palace")
@patch("mempalace.cli.MempalaceConfig")
def test_cmd_status_custom_palace(mock_config_cls):
args = argparse.Namespace(palace="~/my_palace")
mock_miner = MagicMock()
with patch.dict("sys.modules", {"mempalace.miner": mock_miner}):
cmd_status(args)
import os
expected = os.path.expanduser("~/my_palace")
mock_miner.status.assert_called_once_with(palace_path=expected)
# ── cmd_search ─────────────────────────────────────────────────────────
@patch("mempalace.cli.MempalaceConfig")
def test_cmd_search_calls_search(mock_config_cls):
mock_config_cls.return_value.palace_path = "/fake/palace"
args = argparse.Namespace(
palace=None, query="test query", wing="mywing", room="myroom", results=3
)
with patch("mempalace.searcher.search") as mock_search:
cmd_search(args)
mock_search.assert_called_once_with(
query="test query",
palace_path="/fake/palace",
wing="mywing",
room="myroom",
n_results=3,
)
@patch("mempalace.cli.MempalaceConfig")
def test_cmd_search_error_exits(mock_config_cls):
mock_config_cls.return_value.palace_path = "/fake/palace"
args = argparse.Namespace(palace=None, query="q", wing=None, room=None, results=5)
from mempalace.searcher import SearchError
with patch("mempalace.searcher.search", side_effect=SearchError("fail")):
with pytest.raises(SystemExit) as exc_info:
cmd_search(args)
assert exc_info.value.code == 1
# ── cmd_instructions ───────────────────────────────────────────────────
def test_cmd_instructions_calls_run_instructions():
args = argparse.Namespace(name="help")
with patch("mempalace.instructions_cli.run_instructions") as mock_run:
cmd_instructions(args)
mock_run.assert_called_once_with(name="help")
# ── cmd_hook ───────────────────────────────────────────────────────────
def test_cmd_hook_calls_run_hook():
args = argparse.Namespace(hook="session-start", harness="claude-code")
with patch("mempalace.hooks_cli.run_hook") as mock_run:
cmd_hook(args)
mock_run.assert_called_once_with(hook_name="session-start", harness="claude-code")
# ── cmd_init ───────────────────────────────────────────────────────────
@patch("mempalace.cli.MempalaceConfig")
def test_cmd_init_no_entities(mock_config_cls, tmp_path):
args = argparse.Namespace(dir=str(tmp_path), yes=True)
with (
patch("mempalace.entity_detector.scan_for_detection", return_value=[]),
patch("mempalace.room_detector_local.detect_rooms_local") as mock_rooms,
):
cmd_init(args)
mock_rooms.assert_called_once_with(project_dir=str(tmp_path), yes=True)
mock_config_cls.return_value.init.assert_called_once()
@patch("mempalace.cli.MempalaceConfig")
def test_cmd_init_with_entities(mock_config_cls, tmp_path):
fake_files = [tmp_path / "a.txt"]
detected = {"people": [{"name": "Alice"}], "projects": [], "uncertain": []}
confirmed = {"people": ["Alice"], "projects": []}
args = argparse.Namespace(dir=str(tmp_path), yes=True)
with (
patch("mempalace.entity_detector.scan_for_detection", return_value=fake_files),
patch("mempalace.entity_detector.detect_entities", return_value=detected),
patch("mempalace.entity_detector.confirm_entities", return_value=confirmed),
patch("mempalace.room_detector_local.detect_rooms_local"),
patch("builtins.open", MagicMock()),
):
cmd_init(args)
@patch("mempalace.cli.MempalaceConfig")
def test_cmd_init_with_entities_zero_total(mock_config_cls, tmp_path, capsys):
"""When entities detected but total is 0, prints 'No entities' message."""
fake_files = [tmp_path / "a.txt"]
detected = {"people": [], "projects": [], "uncertain": []}
args = argparse.Namespace(dir=str(tmp_path), yes=False)
with (
patch("mempalace.entity_detector.scan_for_detection", return_value=fake_files),
patch("mempalace.entity_detector.detect_entities", return_value=detected),
patch("mempalace.room_detector_local.detect_rooms_local"),
):
cmd_init(args)
out = capsys.readouterr().out
assert "No entities detected" in out
# ── cmd_mine ───────────────────────────────────────────────────────────
@patch("mempalace.cli.MempalaceConfig")
def test_cmd_mine_projects_mode(mock_config_cls):
mock_config_cls.return_value.palace_path = "/fake/palace"
args = argparse.Namespace(
dir="/src",
palace=None,
mode="projects",
wing=None,
agent="mempalace",
limit=0,
dry_run=False,
no_gitignore=False,
include_ignored=[],
extract="exchange",
)
with patch("mempalace.miner.mine") as mock_mine:
cmd_mine(args)
mock_mine.assert_called_once_with(
project_dir="/src",
palace_path="/fake/palace",
wing_override=None,
agent="mempalace",
limit=0,
dry_run=False,
respect_gitignore=True,
include_ignored=[],
)
@patch("mempalace.cli.MempalaceConfig")
def test_cmd_mine_convos_mode(mock_config_cls):
mock_config_cls.return_value.palace_path = "/fake/palace"
args = argparse.Namespace(
dir="/chats",
palace=None,
mode="convos",
wing="mywing",
agent="me",
limit=10,
dry_run=True,
no_gitignore=False,
include_ignored=[],
extract="general",
)
with patch("mempalace.convo_miner.mine_convos") as mock_mine:
cmd_mine(args)
mock_mine.assert_called_once_with(
convo_dir="/chats",
palace_path="/fake/palace",
wing="mywing",
agent="me",
limit=10,
dry_run=True,
extract_mode="general",
)
@patch("mempalace.cli.MempalaceConfig")
def test_cmd_mine_include_ignored_comma_split(mock_config_cls):
mock_config_cls.return_value.palace_path = "/fake/palace"
args = argparse.Namespace(
dir="/src",
palace=None,
mode="projects",
wing=None,
agent="mempalace",
limit=0,
dry_run=False,
no_gitignore=False,
include_ignored=["a.txt,b.txt", "c.txt"],
extract="exchange",
)
with patch("mempalace.miner.mine") as mock_mine:
cmd_mine(args)
mock_mine.assert_called_once()
call_kwargs = mock_mine.call_args[1]
assert call_kwargs["include_ignored"] == ["a.txt", "b.txt", "c.txt"]
# ── cmd_wakeup ─────────────────────────────────────────────────────────
@patch("mempalace.cli.MempalaceConfig")
def test_cmd_wakeup(mock_config_cls, capsys):
mock_config_cls.return_value.palace_path = "/fake/palace"
args = argparse.Namespace(palace=None, wing=None)
mock_stack = MagicMock()
mock_stack.wake_up.return_value = "Hello world context"
with patch("mempalace.layers.MemoryStack", return_value=mock_stack):
cmd_wakeup(args)
out = capsys.readouterr().out
assert "Hello world context" in out
assert "tokens" in out
# ── cmd_split ──────────────────────────────────────────────────────────
def test_cmd_split_basic():
args = argparse.Namespace(dir="/chats", output_dir=None, dry_run=False, min_sessions=2)
with patch("mempalace.split_mega_files.main") as mock_main:
cmd_split(args)
mock_main.assert_called_once()
def test_cmd_split_all_options():
    """cmd_split with output_dir, dry_run and min_sessions set calls split_mega_files.main once."""
    args = argparse.Namespace(dir="/chats", output_dir="/out", dry_run=True, min_sessions=5)
    # main() is patched out, so only the dispatch itself is exercised.
    with patch("mempalace.split_mega_files.main") as mock_main:
        cmd_split(args)
    mock_main.assert_called_once()
    # sys.argv should be restored
    # NOTE(review): presumably cmd_split rewrites sys.argv for main() — confirm.
    assert sys.argv[0] != "mempalace split"
# ── main() argparse dispatch ──────────────────────────────────────────
def test_main_no_args_prints_help(capsys):
with patch("sys.argv", ["mempalace"]):
main()
out = capsys.readouterr().out
assert "MemPalace" in out
def test_main_status_dispatches():
with (
patch("sys.argv", ["mempalace", "status"]),
patch("mempalace.cli.cmd_status") as mock_cmd,
):
main()
mock_cmd.assert_called_once()
def test_main_search_dispatches():
with (
patch("sys.argv", ["mempalace", "search", "my query"]),
patch("mempalace.cli.cmd_search") as mock_cmd,
):
main()
mock_cmd.assert_called_once()
def test_main_init_dispatches():
with (
patch("sys.argv", ["mempalace", "init", "/some/dir"]),
patch("mempalace.cli.cmd_init") as mock_cmd,
):
main()
mock_cmd.assert_called_once()
def test_main_mine_dispatches():
with (
patch("sys.argv", ["mempalace", "mine", "/some/dir"]),
patch("mempalace.cli.cmd_mine") as mock_cmd,
):
main()
mock_cmd.assert_called_once()
def test_main_wakeup_dispatches():
with (
patch("sys.argv", ["mempalace", "wake-up"]),
patch("mempalace.cli.cmd_wakeup") as mock_cmd,
):
main()
mock_cmd.assert_called_once()
def test_main_split_dispatches():
with (
patch("sys.argv", ["mempalace", "split", "/chats"]),
patch("mempalace.cli.cmd_split") as mock_cmd,
):
main()
mock_cmd.assert_called_once()
def test_mcp_command_prints_setup_guidance(monkeypatch, capsys):
monkeypatch.setattr(sys, "argv", ["mempalace", "mcp"])
main()
captured = capsys.readouterr()
assert "MemPalace MCP quick setup:" in captured.out
assert "claude mcp add mempalace -- python -m mempalace.mcp_server" in captured.out
assert "\nOptional custom palace:\n" in captured.out
assert "python -m mempalace.mcp_server --palace /path/to/palace" in captured.out
assert "[--palace /path/to/palace]" not in captured.out
assert captured.err == ""
def test_mcp_command_uses_custom_palace_path_when_provided(monkeypatch, capsys):
monkeypatch.setattr(sys, "argv", ["mempalace", "--palace", "~/tmp/my palace", "mcp"])
main()
captured = capsys.readouterr()
expanded = str(Path("~/tmp/my palace").expanduser())
assert "python -m mempalace.mcp_server --palace" in captured.out
assert expanded in captured.out
assert "Optional custom palace:" not in captured.out
assert "[--palace /path/to/palace]" not in captured.out
assert captured.err == ""
def test_main_hook_no_subcommand_prints_help(capsys):
with patch("sys.argv", ["mempalace", "hook"]):
main()
out = capsys.readouterr().out
assert "hook" in out.lower() or "run" in out.lower()
def test_main_hook_run_dispatches():
with (
patch(
"sys.argv",
["mempalace", "hook", "run", "--hook", "session-start", "--harness", "claude-code"],
),
patch("mempalace.cli.cmd_hook") as mock_cmd,
):
main()
mock_cmd.assert_called_once()
def test_main_instructions_no_subcommand_prints_help(capsys):
with patch("sys.argv", ["mempalace", "instructions"]):
main()
out = capsys.readouterr().out
assert "instructions" in out.lower() or "init" in out.lower()
def test_main_instructions_dispatches():
with (
patch("sys.argv", ["mempalace", "instructions", "help"]),
patch("mempalace.cli.cmd_instructions") as mock_cmd,
):
main()
mock_cmd.assert_called_once()
def test_main_repair_dispatches():
with (
patch("sys.argv", ["mempalace", "repair"]),
patch("mempalace.cli.cmd_repair") as mock_cmd,
):
main()
mock_cmd.assert_called_once()
def test_main_compress_dispatches():
with (
patch("sys.argv", ["mempalace", "compress"]),
patch("mempalace.cli.cmd_compress") as mock_cmd,
):
main()
mock_cmd.assert_called_once()
# ── cmd_repair ─────────────────────────────────────────────────────────
@patch("mempalace.cli.MempalaceConfig")
def test_cmd_repair_no_palace(mock_config_cls, tmp_path, capsys):
mock_config_cls.return_value.palace_path = str(tmp_path / "nonexistent")
args = argparse.Namespace(palace=None)
mock_chromadb = MagicMock()
with patch.dict("sys.modules", {"chromadb": mock_chromadb}):
cmd_repair(args)
out = capsys.readouterr().out
assert "No palace found" in out
@patch("mempalace.cli.MempalaceConfig")
def test_cmd_repair_error_reading(mock_config_cls, tmp_path, capsys):
palace_dir = tmp_path / "palace"
palace_dir.mkdir()
mock_config_cls.return_value.palace_path = str(palace_dir)
args = argparse.Namespace(palace=None)
mock_chromadb = MagicMock()
mock_client = MagicMock()
mock_client.get_collection.side_effect = Exception("corrupt db")
mock_chromadb.PersistentClient.return_value = mock_client
with patch.dict("sys.modules", {"chromadb": mock_chromadb}):
cmd_repair(args)
out = capsys.readouterr().out
assert "Error reading palace" in out
@patch("mempalace.cli.MempalaceConfig")
def test_cmd_repair_zero_drawers(mock_config_cls, tmp_path, capsys):
    """Repair reports "Nothing to repair" when the collection count is zero."""
    palace = tmp_path / "palace"
    palace.mkdir()
    mock_config_cls.return_value.palace_path = str(palace)

    empty_collection = MagicMock()
    empty_collection.count.return_value = 0
    client = MagicMock()
    client.get_collection.return_value = empty_collection
    chroma = MagicMock()
    chroma.PersistentClient.return_value = client

    with patch.dict("sys.modules", {"chromadb": chroma}):
        cmd_repair(argparse.Namespace(palace=None))

    assert "Nothing to repair" in capsys.readouterr().out
@patch("mempalace.cli.MempalaceConfig")
def test_cmd_repair_success(mock_config_cls, tmp_path, capsys):
    """Repair of a two-drawer collection prints the completion summary."""
    palace = tmp_path / "palace"
    palace.mkdir()
    mock_config_cls.return_value.palace_path = str(palace)

    # Existing collection that reports two stored drawers.
    old_collection = MagicMock()
    old_collection.count.return_value = 2
    old_collection.get.return_value = {
        "ids": ["id1", "id2"],
        "documents": ["doc1", "doc2"],
        "metadatas": [{"wing": "a"}, {"wing": "b"}],
    }

    client = MagicMock()
    client.get_collection.return_value = old_collection
    client.create_collection.return_value = MagicMock()
    chroma = MagicMock()
    chroma.PersistentClient.return_value = client

    with patch.dict("sys.modules", {"chromadb": chroma}):
        cmd_repair(argparse.Namespace(palace=None))

    printed = capsys.readouterr().out
    assert "Repair complete" in printed
    assert "2 drawers rebuilt" in printed
# ── cmd_compress ───────────────────────────────────────────────────────
@patch("mempalace.cli.MempalaceConfig")
def test_cmd_compress_no_palace(mock_config_cls, capsys):
    """Compress exits via SystemExit when the persistent client cannot be opened."""
    mock_config_cls.return_value.palace_path = "/fake/palace"

    chroma = MagicMock()
    chroma.PersistentClient.side_effect = Exception("no palace")
    cli_args = argparse.Namespace(palace=None, wing=None, dry_run=False, config=None)

    with patch.dict("sys.modules", {"chromadb": chroma}):
        with pytest.raises(SystemExit):
            cmd_compress(cli_args)
@patch("mempalace.cli.MempalaceConfig")
def test_cmd_compress_no_drawers(mock_config_cls, capsys):
    """Compress reports "No drawers found" when the collection returns nothing."""
    mock_config_cls.return_value.palace_path = "/fake/palace"

    collection = MagicMock()
    collection.get.return_value = {"documents": [], "metadatas": [], "ids": []}
    client = MagicMock()
    client.get_collection.return_value = collection
    chroma = MagicMock()
    chroma.PersistentClient.return_value = client

    cli_args = argparse.Namespace(palace=None, wing="mywing", dry_run=False, config=None)
    with patch.dict("sys.modules", {"chromadb": chroma}):
        cmd_compress(cli_args)

    assert "No drawers found" in capsys.readouterr().out
def _make_mock_dialect_module(dialect_instance):
"""Create a mock dialect module with a Dialect class that returns the given instance."""
mock_mod = MagicMock()
mock_mod.Dialect.return_value = dialect_instance
mock_mod.Dialect.from_config.return_value = dialect_instance
mock_mod.Dialect.count_tokens = MagicMock(side_effect=lambda x: len(x) // 4)
return mock_mod
@patch("mempalace.cli.MempalaceConfig")
def test_cmd_compress_dry_run(mock_config_cls, capsys):
    """A dry-run compress over one drawer announces the dry run and the compression."""
    mock_config_cls.return_value.palace_path = "/fake/palace"

    # First get() returns a single drawer, the second an empty page.
    one_drawer = {
        "documents": ["some long text here for testing"],
        "metadatas": [{"wing": "test", "room": "general", "source_file": "test.txt"}],
        "ids": ["id1"],
    }
    no_drawers = {"documents": [], "metadatas": [], "ids": []}
    collection = MagicMock()
    collection.get.side_effect = [one_drawer, no_drawers]

    client = MagicMock()
    client.get_collection.return_value = collection
    chroma = MagicMock()
    chroma.PersistentClient.return_value = client

    dialect = MagicMock()
    dialect.compress.return_value = "compressed"
    dialect.compression_stats.return_value = {
        "original_chars": 100,
        "compressed_chars": 30,
        "original_tokens": 25,
        "compressed_tokens": 8,
        "ratio": 3.3,
    }

    fake_modules = {
        "chromadb": chroma,
        "mempalace.dialect": _make_mock_dialect_module(dialect),
    }
    cli_args = argparse.Namespace(palace=None, wing=None, dry_run=True, config=None)
    with patch.dict("sys.modules", fake_modules):
        cmd_compress(cli_args)

    printed = capsys.readouterr().out
    assert "dry run" in printed.lower()
    assert "Compressing" in printed
@patch("mempalace.cli.MempalaceConfig")
def test_cmd_compress_with_config(mock_config_cls, tmp_path, capsys):
    """Passing an entity config file makes compress print "Loaded entity config"."""
    mock_config_cls.return_value.palace_path = "/fake/palace"

    entities_path = tmp_path / "entities.json"
    entities_path.write_text('{"people": [], "projects": []}')

    collection = MagicMock()
    collection.get.return_value = {"documents": [], "metadatas": [], "ids": []}
    client = MagicMock()
    client.get_collection.return_value = collection
    chroma = MagicMock()
    chroma.PersistentClient.return_value = client

    fake_modules = {
        "chromadb": chroma,
        "mempalace.dialect": _make_mock_dialect_module(MagicMock()),
    }
    cli_args = argparse.Namespace(
        palace=None, wing=None, dry_run=True, config=str(entities_path)
    )
    with patch.dict("sys.modules", fake_modules):
        cmd_compress(cli_args)

    assert "Loaded entity config" in capsys.readouterr().out
@patch("mempalace.cli.MempalaceConfig")
def test_cmd_compress_stores_results(mock_config_cls, capsys):
    """Non-dry-run compress stores to mempalace_compressed collection."""
    mock_config_cls.return_value.palace_path = "/fake/palace"

    # Source collection: one drawer on the first get(), nothing afterwards.
    source_collection = MagicMock()
    source_collection.get.side_effect = [
        {
            "documents": ["text"],
            "metadatas": [{"wing": "w", "room": "r", "source_file": "f.txt"}],
            "ids": ["id1"],
        },
        {"documents": [], "metadatas": [], "ids": []},
    ]

    # Destination collection handed out by get_or_create_collection.
    compressed_collection = MagicMock()
    client = MagicMock()
    client.get_collection.return_value = source_collection
    client.get_or_create_collection.return_value = compressed_collection
    chroma = MagicMock()
    chroma.PersistentClient.return_value = client

    dialect = MagicMock()
    dialect.compress.return_value = "compressed"
    dialect.compression_stats.return_value = {
        "original_chars": 100,
        "compressed_chars": 30,
        "original_tokens": 25,
        "compressed_tokens": 8,
        "ratio": 3.3,
    }

    fake_modules = {
        "chromadb": chroma,
        "mempalace.dialect": _make_mock_dialect_module(dialect),
    }
    cli_args = argparse.Namespace(palace=None, wing=None, dry_run=False, config=None)
    with patch.dict("sys.modules", fake_modules):
        cmd_compress(cli_args)

    assert "Stored" in capsys.readouterr().out
    compressed_collection.upsert.assert_called_once()
def test_cmd_repair_trailing_slash_does_not_recurse():
    """Repair with trailing slash should put backup outside palace dir (#395)."""
    import os

    # Build the path with the platform's own separator. With a hard-coded "/"
    # the rstrip(os.sep) below is a no-op on Windows, so the final assertion
    # would pass without checking anything. On POSIX this is "/tmp/fake_palace/".
    trailing_sep_path = os.sep.join(["", "tmp", "fake_palace", ""])
    args = argparse.Namespace(palace=trailing_sep_path)

    # isdir is forced to False so cmd_repair never touches a real directory.
    with patch("mempalace.cli.os.path.isdir", return_value=False):
        cmd_repair(args)

    # Verify the rstrip logic: palace_path should not end with separator
    # NOTE(review): this re-derives the path here rather than observing what
    # cmd_repair computed; confirm cmd_repair uses the same normalisation.
    palace_path = os.path.expanduser(args.palace).rstrip(os.sep)
    backup_path = palace_path + ".backup"
    assert not palace_path.endswith(os.sep)
    assert not backup_path.startswith(palace_path + os.sep)
    # The backup must be a sibling of the palace directory, not a child of it.
    assert os.path.dirname(backup_path) == os.path.dirname(palace_path)
+79
View File
@@ -0,0 +1,79 @@
"""Extra tests for mempalace.config to cover remaining gaps."""
import json
import os
from mempalace.config import MempalaceConfig
def test_config_bad_json(tmp_path):
    """A config.json that is not valid JSON still yields a truthy palace_path."""
    broken_file = tmp_path / "config.json"
    broken_file.write_text("not json", encoding="utf-8")
    config = MempalaceConfig(config_dir=str(tmp_path))
    assert config.palace_path  # still returns default
def test_people_map_from_file(tmp_path):
    """people_map is read from people_map.json in the config directory."""
    expected = {"bob": "Robert"}
    map_file = tmp_path / "people_map.json"
    map_file.write_text(json.dumps({"bob": "Robert"}), encoding="utf-8")
    config = MempalaceConfig(config_dir=str(tmp_path))
    assert config.people_map == expected
def test_people_map_bad_json(tmp_path):
    """An unparseable people_map.json yields an empty mapping."""
    map_file = tmp_path / "people_map.json"
    map_file.write_text("bad", encoding="utf-8")
    config = MempalaceConfig(config_dir=str(tmp_path))
    assert config.people_map == {}
def test_people_map_missing(tmp_path):
    """With no people_map.json present, people_map is an empty mapping."""
    people = MempalaceConfig(config_dir=str(tmp_path)).people_map
    assert people == {}
def test_topic_wings_default(tmp_path):
    """Default topic_wings is a list that includes "emotions"."""
    wings = MempalaceConfig(config_dir=str(tmp_path)).topic_wings
    assert isinstance(wings, list)
    assert "emotions" in wings
def test_hall_keywords_default(tmp_path):
    """Default hall_keywords is a dict that includes a "technical" key."""
    keywords = MempalaceConfig(config_dir=str(tmp_path)).hall_keywords
    assert isinstance(keywords, dict)
    assert "technical" in keywords
def test_init_idempotent(tmp_path):
    """Calling init() twice leaves the config.json from the first call untouched."""
    cfg = MempalaceConfig(config_dir=str(tmp_path))
    config_file = tmp_path / "config.json"

    cfg.init()
    after_first = config_file.read_bytes()
    cfg.init()  # second call should not overwrite
    # Compare raw bytes so any rewrite with different content is caught.
    assert config_file.read_bytes() == after_first

    with open(config_file) as f:
        data = json.load(f)
    assert "palace_path" in data
def test_save_people_map(tmp_path):
    """save_people_map returns the path of an existing JSON file holding the mapping."""
    config = MempalaceConfig(config_dir=str(tmp_path))
    saved_path = config.save_people_map({"alice": "Alice Smith"})
    assert saved_path.exists()

    with open(saved_path) as handle:
        stored = json.load(handle)
    assert stored["alice"] == "Alice Smith"
def test_env_mempal_palace_path(tmp_path):
    """MEMPAL_PALACE_PATH (legacy) should also work."""
    from unittest.mock import patch

    # patch.dict snapshots os.environ and restores it on exit, so both the
    # legacy variable set here and the MEMPALACE_PALACE_PATH removed below are
    # put back afterwards instead of leaking into later tests.
    with patch.dict(os.environ, {"MEMPAL_PALACE_PATH": "/legacy/path"}):
        # Remove the newer variable so only the legacy one is present.
        os.environ.pop("MEMPALACE_PALACE_PATH", None)
        cfg = MempalaceConfig(config_dir=str(tmp_path))
        assert cfg.palace_path == "/legacy/path"
def test_collection_name_from_config(tmp_path):
    """collection_name is taken from config.json when the key is present."""
    payload = json.dumps({"collection_name": "custom_col"})
    (tmp_path / "config.json").write_text(payload, encoding="utf-8")
    config = MempalaceConfig(config_dir=str(tmp_path))
    assert config.collection_name == "custom_col"
+1 -1
View File
@@ -23,4 +23,4 @@ def test_convo_mining():
results = col.query(query_texts=["memory persistence"], n_results=1)
assert len(results["documents"][0]) > 0
shutil.rmtree(tmpdir)
shutil.rmtree(tmpdir, ignore_errors=True)
+102
View File
@@ -0,0 +1,102 @@
"""Unit tests for convo_miner pure functions (no chromadb needed)."""
from mempalace.convo_miner import (
chunk_exchanges,
detect_convo_room,
scan_convos,
)
class TestChunkExchanges:
    """Checks on chunk_exchanges for quoted exchanges, plain prose and tiny input."""

    def test_exchange_chunking(self):
        """Three '>' exchanges produce at least two chunks with the expected keys."""
        content = (
            "> What is memory?\n"
            "Memory is persistence of information over time.\n\n"
            "> Why does it matter?\n"
            "It enables continuity across sessions and conversations.\n\n"
            "> How do we build it?\n"
            "With structured storage and retrieval mechanisms.\n"
        )
        pieces = chunk_exchanges(content)
        assert len(pieces) >= 2
        for piece in pieces:
            assert "content" in piece and "chunk_index" in piece

    def test_paragraph_fallback(self):
        """Content without '>' lines falls back to paragraph chunking."""
        # Each repetition of the second and third sentence starts with a blank
        # line, so the text holds many short paragraphs.
        content = (
            "This is a long paragraph about memory systems. " * 10
            + "\n\nThis is another paragraph about storage. " * 10
            + "\n\nAnd a third paragraph about retrieval. " * 10
        )
        pieces = chunk_exchanges(content)
        assert len(pieces) >= 2

    def test_paragraph_line_group_fallback(self):
        """Long content with no paragraph breaks chunks by line groups."""
        content = "\n".join(
            f"Line {i}: some content that is meaningful" for i in range(60)
        )
        pieces = chunk_exchanges(content)
        assert len(pieces) >= 1

    def test_empty_content(self):
        """An empty string yields no chunks."""
        assert chunk_exchanges("") == []

    def test_short_content_skipped(self):
        """Very short input still returns a list."""
        pieces = chunk_exchanges("> hi\nbye")
        # Too short to produce chunks (below MIN_CHUNK_SIZE)
        assert isinstance(pieces, list)
class TestDetectConvoRoom:
    """Checks that detect_convo_room maps sample sentences to the expected room."""

    def test_technical_room(self):
        """Debugging and code vocabulary maps to "technical"."""
        room = detect_convo_room(
            "Let me debug this python function and fix the code error in the api"
        )
        assert room == "technical"

    def test_planning_room(self):
        """Roadmap and sprint vocabulary maps to "planning"."""
        room = detect_convo_room(
            "We need to plan the roadmap for the next sprint and set milestone deadlines"
        )
        assert room == "planning"

    def test_architecture_room(self):
        """Service and component vocabulary maps to "architecture"."""
        room = detect_convo_room(
            "The architecture uses a service layer with component interface and module design"
        )
        assert room == "architecture"

    def test_decisions_room(self):
        """Decided / migrated / chose vocabulary maps to "decisions"."""
        room = detect_convo_room(
            "We decided to switch and migrated to the new framework after we chose it"
        )
        assert room == "decisions"

    def test_general_fallback(self):
        """Small talk with no topical keywords maps to "general"."""
        room = detect_convo_room("Hello, how are you doing today? The weather is nice.")
        assert room == "general"
class TestScanConvos:
    """Checks which files scan_convos picks up from a directory."""

    def test_scan_finds_txt_and_md(self, tmp_path):
        """.txt and .md files are returned; a .png file is not."""
        (tmp_path / "chat.txt").write_text("hello", encoding="utf-8")
        (tmp_path / "notes.md").write_text("world", encoding="utf-8")
        (tmp_path / "image.png").write_bytes(b"fake")

        found_suffixes = {path.suffix for path in scan_convos(str(tmp_path))}
        assert ".txt" in found_suffixes
        assert ".md" in found_suffixes
        assert ".png" not in found_suffixes

    def test_scan_skips_git_dir(self, tmp_path):
        """Files under .git are ignored, leaving only the top-level chat file."""
        dot_git = tmp_path / ".git"
        dot_git.mkdir()
        (dot_git / "config.txt").write_text("git stuff", encoding="utf-8")
        (tmp_path / "chat.txt").write_text("hello", encoding="utf-8")

        assert len(scan_convos(str(tmp_path))) == 1

    def test_scan_skips_meta_json(self, tmp_path):
        """chat.json is returned while chat.meta.json is left out."""
        for filename in ("chat.meta.json", "chat.json"):
            (tmp_path / filename).write_text("{}", encoding="utf-8")

        found_names = [path.name for path in scan_convos(str(tmp_path))]
        assert "chat.json" in found_names
        assert "chat.meta.json" not in found_names

    def test_scan_empty_dir(self, tmp_path):
        """An empty directory yields an empty list."""
        assert scan_convos(str(tmp_path)) == []
+380
View File
@@ -0,0 +1,380 @@
"""Tests for mempalace.entity_detector."""
import os
from unittest.mock import patch
from mempalace.entity_detector import (
PROSE_EXTENSIONS,
STOPWORDS,
_print_entity_list,
classify_entity,
confirm_entities,
detect_entities,
extract_candidates,
scan_for_detection,
score_entity,
)
# ── extract_candidates ──────────────────────────────────────────────────
def test_extract_candidates_finds_frequent_names():
    """A name repeated four times is reported with a count of at least three."""
    candidates = extract_candidates(
        "Riley said hello. Riley laughed. Riley smiled. Riley waved."
    )
    assert "Riley" in candidates
    assert candidates["Riley"] >= 3
def test_extract_candidates_ignores_stopwords():
    """A frequently repeated stopword is not reported as a candidate."""
    # "The" appears many times but is a stopword
    candidates = extract_candidates("The The The The The The")
    assert "The" not in candidates
def test_extract_candidates_requires_min_frequency():
    """Names that occur only once are not reported."""
    candidates = extract_candidates("Riley said hi. Devon waved.")
    # Each name appears only once, below the threshold of 3
    for name in ("Riley", "Devon"):
        assert name not in candidates
def test_extract_candidates_finds_multi_word_names():
    """A two-word capitalised name repeated four times is reported."""
    # Multi-word names need 3+ occurrences and no stopwords
    candidates = extract_candidates(
        "Claude Code is great. Claude Code rocks. Claude Code works. Claude Code rules."
    )
    assert "Claude Code" in candidates
def test_extract_candidates_empty_text():
    """Empty input yields an empty mapping."""
    assert extract_candidates("") == {}
# ── score_entity ────────────────────────────────────────────────────────
def test_score_entity_person_verbs():
    """Speech verbs after a name produce a positive person score with signals."""
    text = "Riley said hello. Riley asked why. Riley told me."
    scores = score_entity("Riley", text, text.splitlines())
    assert scores["person_score"] > 0
    assert len(scores["person_signals"]) > 0
def test_score_entity_project_verbs():
    """Build/deploy/install verbs around a name produce a positive project score."""
    text = "We are building ChromaDB. We deployed ChromaDB. Install ChromaDB."
    scores = score_entity("ChromaDB", text, text.splitlines())
    assert scores["project_score"] > 0
    assert len(scores["project_signals"]) > 0
def test_score_entity_dialogue_markers():
    """Lines of the form "Name: ..." produce a positive person score."""
    text = "Riley: Hey, how are you?\nRiley: I'm fine."
    scores = score_entity("Riley", text, text.splitlines())
    assert scores["person_score"] > 0
def test_score_entity_code_ref():
    """References such as Name.py / Name.js produce a positive project score."""
    text = "Check out ChromaDB.py for details. Also ChromaDB.js is good."
    scores = score_entity("ChromaDB", text, text.splitlines())
    assert scores["project_score"] > 0
def test_score_entity_no_signals():
    """A name absent from the text scores zero on both axes."""
    text = "Nothing interesting here at all."
    scores = score_entity("Riley", text, text.splitlines())
    assert scores["person_score"] == 0
    assert scores["project_score"] == 0
# ── classify_entity ─────────────────────────────────────────────────────
def test_classify_entity_no_signals_gives_uncertain():
    """Zero scores and no signals classify as "uncertain" and keep the name."""
    blank_scores = {
        "person_score": 0,
        "project_score": 0,
        "person_signals": [],
        "project_signals": [],
    }
    verdict = classify_entity("Foo", 10, blank_scores)
    assert verdict["type"] == "uncertain"
    assert verdict["name"] == "Foo"
def test_classify_entity_strong_project():
    """A project-only score with two project signals classifies as "project"."""
    project_scores = {
        "person_score": 0,
        "project_score": 10,
        "person_signals": [],
        "project_signals": ["project verb (5x)", "code file reference (2x)"],
    }
    verdict = classify_entity("ChromaDB", 5, project_scores)
    assert verdict["type"] == "project"
def test_classify_entity_strong_person_needs_two_signal_types():
    """A person-only score backed by two distinct signals classifies as "person"."""
    person_signals = [
        "dialogue marker (3x)",
        "'Riley ...' action (4x)",
    ]
    person_scores = {
        "person_score": 10,
        "project_score": 0,
        "person_signals": person_signals,
        "project_signals": [],
    }
    verdict = classify_entity("Riley", 8, person_scores)
    assert verdict["type"] == "person"
def test_classify_entity_pronoun_only_is_uncertain():
    """A person score resting on pronoun proximity alone stays "uncertain"."""
    pronoun_scores = {
        "person_score": 8,
        "project_score": 0,
        "person_signals": ["pronoun nearby (4x)"],
        "project_signals": [],
    }
    verdict = classify_entity("Riley", 5, pronoun_scores)
    assert verdict["type"] == "uncertain"
def test_classify_entity_mixed_signals():
    """Equal person and project scores are "uncertain" with a trailing mixed-signals note."""
    tied_scores = {
        "person_score": 5,
        "project_score": 5,
        "person_signals": ["pronoun nearby (2x)"],
        "project_signals": ["project verb (2x)"],
    }
    verdict = classify_entity("Lantern", 5, tied_scores)
    assert verdict["type"] == "uncertain"
    last_signal = verdict["signals"][-1]
    assert "mixed signals" in last_signal
# ── detect_entities (integration) ───────────────────────────────────────
def test_detect_entities_with_person_file(tmp_path):
    """A file full of person-style sentences about Riley surfaces "Riley" in some category."""
    sentences = [
        "Riley said hello today.",
        "Riley asked about the project.",
        "Riley told me she was happy.",
        "Riley: I think we should go.",
        "Hey Riley, thanks for the help.",
        "Riley laughed and smiled.",
        "Riley decided to join.",
        "Riley pushed the change.",
    ]
    notes = tmp_path / "notes.txt"
    notes.write_text("\n".join(sentences))

    detected = detect_entities([notes])
    # Flatten every category into one list of names.
    names = [entity["name"] for group in detected.values() for entity in group]
    assert "Riley" in names
def test_detect_entities_with_project_file(tmp_path):
    """A file full of project-style sentences about Lantern surfaces "Lantern"."""
    # "ChromaDB" has uppercase+lowercase mix but extract_candidates looks
    # for /[A-Z][a-z]{1,19}/ — so we need a name that matches that regex.
    # Use "Lantern" which matches the capitalized-word pattern.
    sentences = [
        "The Lantern project is great.",
        "Building Lantern was fun.",
        "We deployed Lantern today.",
        "Install Lantern with pip install Lantern.",
        "Check Lantern.py for the source.",
        "Lantern v2 is faster.",
    ]
    readme = tmp_path / "readme.txt"
    readme.write_text("\n".join(sentences))

    detected = detect_entities([readme])
    names = [entity["name"] for group in detected.values() for entity in group]
    assert "Lantern" in names
def test_detect_entities_empty_files(tmp_path):
    """An empty file yields three empty categories."""
    empty_file = tmp_path / "empty.txt"
    empty_file.write_text("")
    nothing = {"people": [], "projects": [], "uncertain": []}
    assert detect_entities([empty_file]) == nothing
def test_detect_entities_handles_missing_file(tmp_path):
    """A path that does not exist yields three empty categories instead of raising."""
    absent = tmp_path / "nonexistent.txt"
    nothing = {"people": [], "projects": [], "uncertain": []}
    assert detect_entities([absent]) == nothing
def test_detect_entities_respects_max_files(tmp_path):
    """Passing max_files=2 with five input files still returns a dict."""
    paths = [tmp_path / f"file{i}.txt" for i in range(5)]
    for path in paths:
        path.write_text("Riley said hello. " * 10)

    # max_files=2 should only read 2 files
    detected = detect_entities(paths, max_files=2)
    # Should still work without error
    assert isinstance(detected, dict)
# ── scan_for_detection ──────────────────────────────────────────────────
def test_scan_for_detection_finds_prose(tmp_path):
    """At least one of the .md / .txt files in the directory is returned."""
    samples = {"notes.md": "hello", "data.txt": "world", "code.py": "import os"}
    for filename, body in samples.items():
        (tmp_path / filename).write_text(body)

    found = scan_for_detection(str(tmp_path))
    suffixes = {os.path.splitext(str(path))[1] for path in found}
    # Prose files should be found
    assert ".md" in suffixes or ".txt" in suffixes
def test_scan_for_detection_skips_git_dir(tmp_path):
    """No returned path contains ".git"."""
    dot_git = tmp_path / ".git"
    dot_git.mkdir()
    (dot_git / "config.txt").write_text("git config")
    (tmp_path / "readme.md").write_text("hello")

    found_paths = [str(path) for path in scan_for_detection(str(tmp_path))]
    for found in found_paths:
        assert ".git" not in found
# ── module-level constants ──────────────────────────────────────────────
def test_stopwords_contains_common_words():
    """STOPWORDS includes "the", "import" and "class"."""
    for word in ("the", "import", "class"):
        assert word in STOPWORDS
def test_prose_extensions():
    """PROSE_EXTENSIONS includes ".txt" and ".md"."""
    for suffix in (".txt", ".md"):
        assert suffix in PROSE_EXTENSIONS
# ── _print_entity_list ─────────────────────────────────────────────────
def test_print_entity_list_with_entities(capsys):
    """The printed list contains the label and every entity name."""
    alice = {"name": "Alice", "confidence": 0.9, "signals": ["dialogue marker (3x)"]}
    bob = {"name": "Bob", "confidence": 0.5, "signals": []}
    _print_entity_list([alice, bob], "PEOPLE")

    printed = capsys.readouterr().out
    for expected in ("PEOPLE", "Alice", "Bob"):
        assert expected in printed
def test_print_entity_list_empty(capsys):
    """An empty entity list prints a "none detected" notice."""
    _print_entity_list([], "PEOPLE")
    assert "none detected" in capsys.readouterr().out
# ── confirm_entities ───────────────────────────────────────────────────
def test_confirm_entities_yes_mode():
    """With yes=True the detected people and projects are returned by name."""

    def _entity(name, confidence):
        return {"name": name, "confidence": confidence, "signals": ["test"]}

    detected = {
        "people": [_entity("Alice", 0.9)],
        "projects": [_entity("Acme", 0.8)],
        "uncertain": [_entity("Foo", 0.4)],
    }
    confirmed = confirm_entities(detected, yes=True)
    assert confirmed["people"] == ["Alice"]
    assert confirmed["projects"] == ["Acme"]
def test_confirm_entities_accept_all():
    """Answering with an empty line and then "n" keeps the detected person."""
    detected = {
        "people": [{"name": "Alice", "confidence": 0.9, "signals": ["test"]}],
        "projects": [],
        "uncertain": [],
    }
    answers = ["", "n"]
    with patch("builtins.input", side_effect=answers):
        confirmed = confirm_entities(detected, yes=False)
    assert "Alice" in confirmed["people"]
def test_confirm_entities_edit_reclassify_uncertain():
    """In edit mode, "p" moves an uncertain entity to people and "s" drops it."""
    uncertain = [
        {"name": "Foo", "confidence": 0.4, "signals": ["test"]},
        {"name": "Bar", "confidence": 0.4, "signals": ["test"]},
    ]
    detected = {"people": [], "projects": [], "uncertain": uncertain}

    scripted_answers = [
        "edit",  # choice
        "p",  # Foo -> person
        "s",  # Bar -> skip
        "",  # no removals from people
        "",  # no removals from projects
        "n",  # don't add missing
    ]
    with patch("builtins.input", side_effect=scripted_answers):
        confirmed = confirm_entities(detected, yes=False)

    assert "Foo" in confirmed["people"]
    assert "Bar" not in confirmed["people"]
    assert "Bar" not in confirmed["projects"]
def test_confirm_entities_add_mode():
    """In add mode, typed names land in people ("p") or projects ("r")."""
    detected = {"people": [], "projects": [], "uncertain": []}

    scripted_answers = [
        "add",  # choice = add
        "NewPerson",  # name
        "p",  # person
        "NewProj",  # name
        "r",  # project
        "",  # stop adding
    ]
    with patch("builtins.input", side_effect=scripted_answers):
        confirmed = confirm_entities(detected, yes=False)

    assert "NewPerson" in confirmed["people"]
    assert "NewProj" in confirmed["projects"]
# ── scan_for_detection fallback ────────────────────────────────────────
def test_scan_for_detection_fallback_to_all_readable(tmp_path):
    """When fewer than 3 prose files, falls back to include all readable files."""
    prose_files = {"one.md": "hello", "two.txt": "world"}
    # Only 2 prose files, so it should also include code files
    code_files = {"code.py": "import os", "app.js": "console.log()"}
    for filename, body in {**prose_files, **code_files}.items():
        (tmp_path / filename).write_text(body)

    found = scan_for_detection(str(tmp_path))
    suffixes = {os.path.splitext(str(path))[1] for path in found}
    assert ".py" in suffixes or ".js" in suffixes
def test_scan_for_detection_max_files(tmp_path):
    """Caps to max_files."""
    limit = 5
    for index in range(20):
        (tmp_path / f"note{index}.md").write_text(f"content {index}")
    found = scan_for_detection(str(tmp_path), max_files=limit)
    assert len(found) <= 5
+313
View File
@@ -0,0 +1,313 @@
"""Tests for mempalace.entity_registry."""
from unittest.mock import patch
from mempalace.entity_registry import (
COMMON_ENGLISH_WORDS,
PERSON_CONTEXT_PATTERNS,
EntityRegistry,
)
# ── COMMON_ENGLISH_WORDS ────────────────────────────────────────────────
def test_common_english_words_has_expected_entries():
    """COMMON_ENGLISH_WORDS includes several words that double as names or days."""
    for word in ("ever", "grace", "will", "may", "monday"):
        assert word in COMMON_ENGLISH_WORDS
def test_common_english_words_is_lowercase():
    """Every entry in COMMON_ENGLISH_WORDS equals its lowercase form."""
    for entry in COMMON_ENGLISH_WORDS:
        lowered = entry.lower()
        assert entry == lowered, f"{entry} should be lowercase"
# ── PERSON_CONTEXT_PATTERNS ─────────────────────────────────────────────
def test_person_context_patterns_is_nonempty():
    """PERSON_CONTEXT_PATTERNS holds at least one pattern."""
    pattern_count = len(PERSON_CONTEXT_PATTERNS)
    assert pattern_count > 0
# ── EntityRegistry creation and empty state ─────────────────────────────
def test_load_from_nonexistent_dir(tmp_path):
    """Loading from a directory with no registry file gives the empty defaults."""
    fresh = EntityRegistry.load(config_dir=tmp_path)
    assert fresh.people == {}
    assert fresh.projects == []
    assert fresh.mode == "personal"
    assert fresh.ambiguous_flags == []
def test_save_and_load_roundtrip(tmp_path):
    """Data seeded into one registry is visible after loading the same directory again."""
    colleague = {"name": "Alice", "relationship": "colleague", "context": "work"}
    original = EntityRegistry.load(config_dir=tmp_path)
    original.seed(mode="work", people=[colleague], projects=["MemPalace"])

    # Load again from same dir
    reloaded = EntityRegistry.load(config_dir=tmp_path)
    assert reloaded.mode == "work"
    assert "Alice" in reloaded.people
    assert "MemPalace" in reloaded.projects
def test_save_creates_file(tmp_path):
    """save() writes entity_registry.json into the config directory."""
    EntityRegistry.load(config_dir=tmp_path).save()
    registry_file = tmp_path / "entity_registry.json"
    assert registry_file.exists()
# ── seed ────────────────────────────────────────────────────────────────
def test_seed_registers_people(tmp_path):
    """Seeded people are stored with relationship, "onboarding" source and confidence 1.0."""
    family_and_friends = [
        {"name": "Riley", "relationship": "daughter", "context": "personal"},
        {"name": "Devon", "relationship": "friend", "context": "personal"},
    ]
    registry = EntityRegistry.load(config_dir=tmp_path)
    registry.seed(mode="personal", people=family_and_friends, projects=["MemPalace"])

    assert "Riley" in registry.people
    assert "Devon" in registry.people
    riley = registry.people["Riley"]
    assert riley["relationship"] == "daughter"
    assert riley["source"] == "onboarding"
    assert riley["confidence"] == 1.0
def test_seed_registers_projects(tmp_path):
    """Seeded project names are stored in order."""
    project_names = ["Acme", "Widget"]
    registry = EntityRegistry.load(config_dir=tmp_path)
    registry.seed(mode="work", people=[], projects=project_names)
    assert registry.projects == ["Acme", "Widget"]
def test_seed_sets_mode(tmp_path):
    """seed() records the mode it was given."""
    registry = EntityRegistry.load(config_dir=tmp_path)
    registry.seed(mode="combo", people=[], projects=[])
    assert registry.mode == "combo"
def test_seed_flags_ambiguous_names(tmp_path):
    """A seeded name that is also a common word is flagged; an uncommon one is not."""
    seeded_people = [
        {"name": "Grace", "relationship": "friend", "context": "personal"},
        {"name": "Riley", "relationship": "daughter", "context": "personal"},
    ]
    registry = EntityRegistry.load(config_dir=tmp_path)
    registry.seed(mode="personal", people=seeded_people, projects=[])

    flags = registry.ambiguous_flags
    assert "grace" in flags
    # Riley is not a common English word
    assert "riley" not in flags
def test_seed_with_aliases(tmp_path):
    """An alias is registered as its own entry pointing at the canonical name."""
    maxwell = {"name": "Maxwell", "relationship": "friend", "context": "personal"}
    registry = EntityRegistry.load(config_dir=tmp_path)
    registry.seed(
        mode="personal",
        people=[maxwell],
        projects=[],
        aliases={"Max": "Maxwell"},
    )

    assert "Maxwell" in registry.people
    assert "Max" in registry.people
    alias_entry = registry.people["Max"]
    assert alias_entry.get("canonical") == "Maxwell"
def test_seed_skips_empty_names(tmp_path):
    """A person entry with an empty name is not registered."""
    nameless = {"name": "", "relationship": "", "context": "personal"}
    registry = EntityRegistry.load(config_dir=tmp_path)
    registry.seed(mode="personal", people=[nameless], projects=[])
    assert len(registry.people) == 0
# ── lookup ──────────────────────────────────────────────────────────────
def test_lookup_known_person(tmp_path):
    """Looking up a seeded person returns type "person" with confidence 1.0."""
    riley = {"name": "Riley", "relationship": "daughter", "context": "personal"}
    registry = EntityRegistry.load(config_dir=tmp_path)
    registry.seed(mode="personal", people=[riley], projects=[])

    hit = registry.lookup("Riley")
    assert hit["type"] == "person"
    assert hit["confidence"] == 1.0
    assert hit["name"] == "Riley"
def test_lookup_known_project(tmp_path):
    """Looking up a seeded project returns type "project" with confidence 1.0."""
    registry = EntityRegistry.load(config_dir=tmp_path)
    registry.seed(mode="work", people=[], projects=["MemPalace"])

    hit = registry.lookup("MemPalace")
    assert hit["type"] == "project"
    assert hit["confidence"] == 1.0
def test_lookup_unknown_word(tmp_path):
    """A word that was never seeded returns type "unknown" with confidence 0.0."""
    registry = EntityRegistry.load(config_dir=tmp_path)
    registry.seed(mode="personal", people=[], projects=[])

    miss = registry.lookup("Xyzzy")
    assert miss["type"] == "unknown"
    assert miss["confidence"] == 0.0
def test_lookup_case_insensitive(tmp_path):
    """A lowercase query still matches a person seeded with a capitalised name."""
    riley = {"name": "Riley", "relationship": "daughter", "context": "personal"}
    registry = EntityRegistry.load(config_dir=tmp_path)
    registry.seed(mode="personal", people=[riley], projects=[])

    hit = registry.lookup("riley")
    assert hit["type"] == "person"
def test_lookup_alias(tmp_path):
    """Looking up a seeded alias returns type "person"."""
    maxwell = {"name": "Maxwell", "relationship": "friend", "context": "personal"}
    registry = EntityRegistry.load(config_dir=tmp_path)
    registry.seed(
        mode="personal",
        people=[maxwell],
        projects=[],
        aliases={"Max": "Maxwell"},
    )

    hit = registry.lookup("Max")
    assert hit["type"] == "person"
# ── disambiguation ──────────────────────────────────────────────────────
def test_lookup_ambiguous_word_as_person(tmp_path):
    """An ambiguous seeded name used in a person-like sentence resolves to "person"."""
    grace = {"name": "Grace", "relationship": "friend", "context": "personal"}
    registry = EntityRegistry.load(config_dir=tmp_path)
    registry.seed(mode="personal", people=[grace], projects=[])

    hit = registry.lookup("Grace", context="I went with Grace today")
    assert hit["type"] == "person"
def test_lookup_ambiguous_word_as_concept(tmp_path):
    """An ambiguous seeded name used as an ordinary word resolves to "concept"."""
    ever = {"name": "Ever", "relationship": "friend", "context": "personal"}
    registry = EntityRegistry.load(config_dir=tmp_path)
    registry.seed(mode="personal", people=[ever], projects=[])

    hit = registry.lookup("Ever", context="have you ever tried this")
    assert hit["type"] == "concept"
# ── research (Wikipedia) — mocked ──────────────────────────────────────
def test_research_caches_result(tmp_path):
    """A second research() call for the same word does not hit the Wikipedia lookup."""
    lookup_target = "mempalace.entity_registry._wikipedia_lookup"
    registry = EntityRegistry.load(config_dir=tmp_path)
    registry.seed(mode="personal", people=[], projects=[])

    wiki_answer = {
        "inferred_type": "person",
        "confidence": 0.80,
        "wiki_summary": "Saoirse is an Irish given name.",
        "wiki_title": "Saoirse",
    }
    with patch(lookup_target, return_value=wiki_answer):
        first = registry.research("Saoirse", auto_confirm=True)
    assert first["inferred_type"] == "person"

    # Second call should use cache, not call Wikipedia again
    tripwire = AssertionError("should not be called")
    with patch(lookup_target, side_effect=tripwire):
        second = registry.research("Saoirse")
    assert second["inferred_type"] == "person"
def test_confirm_research_adds_to_people(tmp_path):
    """Confirming a researched word as a person adds it to people with source "wiki"."""
    registry = EntityRegistry.load(config_dir=tmp_path)
    registry.seed(mode="personal", people=[], projects=[])

    wiki_answer = {
        "inferred_type": "person",
        "confidence": 0.80,
        "wiki_summary": "Saoirse is a name",
        "wiki_title": "Saoirse",
    }
    with patch(
        "mempalace.entity_registry._wikipedia_lookup", return_value=wiki_answer
    ):
        registry.research("Saoirse", auto_confirm=False)

    registry.confirm_research("Saoirse", entity_type="person", relationship="friend")
    people = registry.people
    assert "Saoirse" in people
    assert people["Saoirse"]["source"] == "wiki"
# ── extract_people_from_query ───────────────────────────────────────────
def test_extract_people_from_query(tmp_path):
    """Only seeded people actually mentioned in the query are returned."""
    seeded_people = [
        {"name": "Riley", "relationship": "daughter", "context": "personal"},
        {"name": "Devon", "relationship": "friend", "context": "personal"},
    ]
    registry = EntityRegistry.load(config_dir=tmp_path)
    registry.seed(mode="personal", people=seeded_people, projects=[])

    mentioned = registry.extract_people_from_query(
        "What did Riley say about the weather?"
    )
    assert "Riley" in mentioned
    assert "Devon" not in mentioned
# ── extract_unknown_candidates ──────────────────────────────────────────
def test_extract_unknown_candidates(tmp_path):
    """A capitalised word that is not in the registry is reported as unknown."""
    registry = EntityRegistry.load(config_dir=tmp_path)
    registry.seed(mode="personal", people=[], projects=[])

    candidates = registry.extract_unknown_candidates("Saoirse went to the store")
    assert "Saoirse" in candidates
def test_extract_unknown_candidates_skips_known(tmp_path):
    """A name already seeded as a person is not reported as unknown."""
    riley = {"name": "Riley", "relationship": "daughter", "context": "personal"}
    registry = EntityRegistry.load(config_dir=tmp_path)
    registry.seed(mode="personal", people=[riley], projects=[])

    candidates = registry.extract_unknown_candidates("Riley went to the store")
    assert "Riley" not in candidates
# ── summary ─────────────────────────────────────────────────────────────
def test_summary(tmp_path):
    """summary() mentions the mode, the seeded person and the seeded project."""
    riley = {"name": "Riley", "relationship": "daughter", "context": "personal"}
    registry = EntityRegistry.load(config_dir=tmp_path)
    registry.seed(mode="personal", people=[riley], projects=["MemPalace"])

    text = registry.summary()
    for expected in ("personal", "Riley", "MemPalace"):
        assert expected in text
+248
View File
@@ -0,0 +1,248 @@
"""Tests for mempalace.general_extractor."""
from mempalace.general_extractor import (
ALL_MARKERS,
NEGATIVE_WORDS,
POSITIVE_WORDS,
_extract_prose,
_get_sentiment,
_has_resolution,
_is_code_line,
_score_markers,
_split_into_segments,
extract_memories,
)
# ── extract_memories — empty / no markers ───────────────────────────────
def test_extract_memories_empty_text():
    """An empty string yields no memories."""
    assert extract_memories("") == []
def test_extract_memories_no_markers():
    """Text without any marker phrases yields no memories."""
    plain = "The quick brown fox jumped over the lazy dog."
    assert extract_memories(plain) == []
def test_extract_memories_short_text_skipped():
    """Very short input produces no memories."""
    # Paragraphs shorter than 20 chars are skipped
    assert extract_memories("ok sure") == []
# ── extract_memories — decision markers ─────────────────────────────────
def test_extract_memories_decision():
    """Decision-style wording yields at least one memory typed ``decision``."""
    text = (
        "We decided to go with PostgreSQL instead of MySQL "
        "because the performance was better for our use case. "
        "The trade-off was more complexity in setup."
    )
    memories = extract_memories(text)
    assert len(memories) >= 1
    kinds = [memory["memory_type"] for memory in memories]
    assert "decision" in kinds
# ── extract_memories — preference markers ───────────────────────────────
def test_extract_memories_preference():
    """Preference-style wording yields at least one memory typed ``preference``."""
    text = (
        "I prefer using snake_case in Python code. "
        "Please always use type hints. "
        "Never use wildcard imports."
    )
    memories = extract_memories(text)
    assert len(memories) >= 1
    kinds = [memory["memory_type"] for memory in memories]
    assert "preference" in kinds
# ── extract_memories — milestone markers ────────────────────────────────
def test_extract_memories_milestone():
    """Breakthrough-style wording yields at least one memory typed ``milestone``."""
    text = (
        "It finally works! After three days of debugging, "
        "I figured out the issue. The breakthrough was realizing "
        "the config file was cached. Got it working at 2am."
    )
    memories = extract_memories(text)
    assert len(memories) >= 1
    kinds = [memory["memory_type"] for memory in memories]
    assert "milestone" in kinds
# ── extract_memories — problem markers ──────────────────────────────────
def test_extract_memories_problem():
    """Problem-style wording yields a ``problem`` or ``milestone`` memory."""
    text = (
        "There's a critical bug in the auth module. "
        "The error keeps crashing the server. "
        "The root cause was a missing null check. "
        "The problem is that tokens expire silently."
    )
    memories = extract_memories(text)
    assert len(memories) >= 1
    kinds = {memory["memory_type"] for memory in memories}
    # resolved problems become milestones, so either type is accepted
    assert kinds & {"problem", "milestone"}
# ── extract_memories — emotional markers ────────────────────────────────
def test_extract_memories_emotional():
    """Feeling-laden wording yields at least one memory typed ``emotional``."""
    text = (
        "I feel so proud of what we built together. "
        "I love working on this project, it makes me happy. "
        "I'm grateful for the team and the beautiful code we wrote."
    )
    memories = extract_memories(text)
    assert len(memories) >= 1
    kinds = [memory["memory_type"] for memory in memories]
    assert "emotional" in kinds
# ── extract_memories — chunk_index ──────────────────────────────────────
def test_extract_memories_chunk_index_increments():
    """When two or more memories come back, ``chunk_index`` counts 0, 1, 2, ...

    The ordering check is skipped when fewer than two memories are extracted.
    """
    text = (
        "We decided to use React because it fits our team.\n\n"
        "I prefer functional components always.\n\n"
        "It works! We finally shipped the v1.0 release."
    )
    memories = extract_memories(text)
    if len(memories) < 2:
        return
    expected = list(range(len(memories)))
    assert [memory["chunk_index"] for memory in memories] == expected
# ── _score_markers ──────────────────────────────────────────────────────
def test_score_markers_with_matches():
    """Text containing decision markers gets a positive score and some keywords."""
    sample = "we decided to go with postgres because it is faster"
    score, keywords = _score_markers(sample, ALL_MARKERS["decision"])
    assert score > 0
    assert len(keywords) > 0
def test_score_markers_no_matches():
    """Text without decision markers scores exactly zero."""
    decision_markers = ALL_MARKERS["decision"]
    score, _keywords = _score_markers("nothing relevant here", decision_markers)
    assert score == 0.0
# ── _get_sentiment ──────────────────────────────────────────────────────
def test_get_sentiment_positive():
    """Upbeat wording is classified as ``positive``."""
    sentiment = _get_sentiment("I am so happy and proud of this breakthrough")
    assert sentiment == "positive"
def test_get_sentiment_negative():
    """Failure-related wording is classified as ``negative``."""
    sentiment = _get_sentiment("This bug caused a crash and total failure")
    assert sentiment == "negative"
def test_get_sentiment_neutral():
    """A plain factual sentence is classified as ``neutral``."""
    sentiment = _get_sentiment("The meeting is at three")
    assert sentiment == "neutral"
# ── _has_resolution ─────────────────────────────────────────────────────
def test_has_resolution_true():
    """Text describing a fix is reported as resolved."""
    resolved = _has_resolution("I fixed the auth bug and it works now")
    assert resolved is True
def test_has_resolution_false():
    """Text describing an ongoing failure is reported as unresolved."""
    resolved = _has_resolution("The server keeps crashing")
    assert resolved is False
# ── _is_code_line ───────────────────────────────────────────────────────
def test_is_code_line_detects_code():
    """An import, a shell prompt line and a code fence all count as code."""
    for line in (" import os", " $ pip install flask", " ```python"):
        assert _is_code_line(line) is True
def test_is_code_line_allows_prose():
    """An ordinary sentence and an empty line are not treated as code."""
    for line in ("This is a regular sentence about coding.", ""):
        assert _is_code_line(line) is False
# ── _extract_prose ──────────────────────────────────────────────────────
def test_extract_prose_strips_code_blocks():
    """Fenced code is removed while the surrounding prose is kept."""
    prose = _extract_prose("Hello world\n```\nimport os\nprint('hi')\n```\nGoodbye")
    assert "import os" not in prose
    for kept in ("Hello world", "Goodbye"):
        assert kept in prose
def test_extract_prose_returns_original_if_all_code():
    """Input made only of code lines still produces a non-empty result."""
    prose = _extract_prose("import os\nfrom sys import argv")
    # Falls back to original text if nothing left
    assert len(prose) > 0
# ── _split_into_segments ───────────────────────────────────────────────
def test_split_into_segments_by_paragraph():
    """Three blank-line-separated paragraphs become three segments."""
    segments = _split_into_segments("First paragraph.\n\nSecond paragraph.\n\nThird paragraph.")
    assert len(segments) == 3
def test_split_into_segments_by_turns():
    """A five-exchange Human/Assistant transcript splits into at least three segments."""
    turns = []
    for i in range(5):
        turns += [f"Human: Question {i}", f"Assistant: Answer {i}"]
    segments = _split_into_segments("\n".join(turns))
    assert len(segments) >= 3  # turn-based splitting should fire
def test_split_into_segments_single_block():
    """Thirty lines with no blank-line separators still yield at least one segment."""
    # Many lines without double-newline produces chunked segments
    document = "\n".join(f"Line {i} of the document" for i in range(30))
    segments = _split_into_segments(document)
    assert len(segments) >= 1
# ── ALL_MARKERS constant ───────────────────────────────────────────────
def test_all_markers_has_five_types():
    """ALL_MARKERS is keyed by exactly the five memory types."""
    expected_types = {"decision", "preference", "milestone", "problem", "emotional"}
    actual_types = set(ALL_MARKERS.keys())
    assert actual_types == expected_types
# ── POSITIVE_WORDS / NEGATIVE_WORDS ────────────────────────────────────
def test_positive_words():
    """POSITIVE_WORDS contains the words the sentiment tests rely on."""
    for word in ("happy", "proud"):
        assert word in POSITIVE_WORDS
def test_negative_words():
    """NEGATIVE_WORDS contains the words the sentiment tests rely on."""
    for word in ("bug", "crash"):
        assert word in NEGATIVE_WORDS
+213
View File
@@ -1,17 +1,24 @@
import contextlib
import io
import json
from pathlib import Path
from unittest.mock import patch
import pytest
from mempalace.hooks_cli import (
SAVE_INTERVAL,
STOP_BLOCK_REASON,
PRECOMPACT_BLOCK_REASON,
_count_human_messages,
_log,
_maybe_auto_ingest,
_parse_harness_input,
_sanitize_session_id,
hook_stop,
hook_session_start,
hook_precompact,
run_hook,
)
@@ -205,3 +212,209 @@ def test_precompact_always_blocks(tmp_path):
)
assert result["decision"] == "block"
assert result["reason"] == PRECOMPACT_BLOCK_REASON
# --- _log ---
def test_log_writes_to_hook_log(tmp_path):
    """_log writes the message into ``hook.log`` inside the patched STATE_DIR."""
    with patch("mempalace.hooks_cli.STATE_DIR", tmp_path):
        _log("test message")
    hook_log = tmp_path / "hook.log"
    assert hook_log.is_file()
    assert "test message" in hook_log.read_text()
def test_log_oserror_is_silenced(tmp_path):
    """_log should not raise if the directory cannot be created.

    STATE_DIR is pointed at a path whose parent is a regular file, so any
    attempt to create the directory or open a file beneath it raises an
    OSError subclass. A hard-coded absolute path such as ``/nonexistent/...``
    is avoided because it can be creatable (for example when running as root,
    or on Windows where it resolves against the current drive), which would
    skip the error path and leave stray directories behind.
    """
    blocker = tmp_path / "not_a_directory"
    blocker.write_text("occupied")
    with patch("mempalace.hooks_cli.STATE_DIR", blocker / "state"):
        # Should not raise
        _log("this will fail silently")
    # The blocking file must be untouched: nothing was created in its place.
    assert blocker.is_file()
# --- _maybe_auto_ingest ---
def test_maybe_auto_ingest_no_env(tmp_path):
    """Without MEMPAL_DIR set, does nothing.

    ``subprocess.Popen`` is patched so the test can assert that no process is
    spawned, and so a regression cannot launch a real subprocess from the
    test run.
    """
    with patch.dict("os.environ", {}, clear=True):
        with patch("mempalace.hooks_cli.STATE_DIR", tmp_path):
            with patch("mempalace.hooks_cli.subprocess.Popen") as mock_popen:
                _maybe_auto_ingest()  # should not raise
    mock_popen.assert_not_called()
def test_maybe_auto_ingest_with_env(tmp_path):
    """With MEMPAL_DIR naming an existing directory, Popen is called exactly once."""
    project_dir = tmp_path / "project"
    project_dir.mkdir()
    env_patch = patch.dict("os.environ", {"MEMPAL_DIR": str(project_dir)})
    state_patch = patch("mempalace.hooks_cli.STATE_DIR", tmp_path)
    popen_patch = patch("mempalace.hooks_cli.subprocess.Popen")
    with env_patch, state_patch, popen_patch as mock_popen:
        _maybe_auto_ingest()
    mock_popen.assert_called_once()
def test_maybe_auto_ingest_oserror(tmp_path):
    """An OSError raised by Popen does not propagate out of _maybe_auto_ingest."""
    project_dir = tmp_path / "project"
    project_dir.mkdir()
    env_patch = patch.dict("os.environ", {"MEMPAL_DIR": str(project_dir)})
    state_patch = patch("mempalace.hooks_cli.STATE_DIR", tmp_path)
    popen_patch = patch("mempalace.hooks_cli.subprocess.Popen", side_effect=OSError("fail"))
    with env_patch, state_patch, popen_patch:
        _maybe_auto_ingest()  # should not raise
# --- _parse_harness_input ---
def test_parse_harness_input_unknown():
    """An unrecognised harness name makes _parse_harness_input exit with code 1."""
    payload = {"session_id": "test"}
    with pytest.raises(SystemExit) as raised:
        _parse_harness_input(payload, "unknown-harness")
    assert raised.value.code == 1
def test_parse_harness_input_valid():
    """For the claude-code harness, session_id and stop_hook_active pass through."""
    payload = {
        "session_id": "abc-123",
        "stop_hook_active": True,
        "transcript_path": "/tmp/t.jsonl",
    }
    parsed = _parse_harness_input(payload, "claude-code")
    assert parsed["session_id"] == "abc-123"
    assert parsed["stop_hook_active"] is True
# --- hook_stop with OSError on write ---
def test_stop_hook_oserror_on_last_save_read(tmp_path):
    """A non-numeric last-save file does not prevent the stop hook from blocking.

    NOTE(review): despite the name, this feeds unparsable file content rather
    than provoking an OSError on read; presumably the hook then treats the
    last save point as 0 — confirm against hook_stop.
    """
    transcript = tmp_path / "t.jsonl"
    user_turns = [
        {"message": {"role": "user", "content": f"msg {i}"}} for i in range(SAVE_INTERVAL)
    ]
    _write_transcript(transcript, user_turns)
    # Write invalid content to last save file
    (tmp_path / "test_last_save").write_text("not_a_number")
    hook_input = {
        "session_id": "test",
        "stop_hook_active": False,
        "transcript_path": str(transcript),
    }
    result = _capture_hook_output(hook_stop, hook_input, state_dir=tmp_path)
    assert result["decision"] == "block"
def test_stop_hook_oserror_on_write(tmp_path):
    """The stop hook still emits a block decision when ``Path.write_text`` fails."""
    transcript = tmp_path / "t.jsonl"
    user_turns = [
        {"message": {"role": "user", "content": f"msg {i}"}} for i in range(SAVE_INTERVAL)
    ]
    # The transcript is written before write_text is replaced below.
    _write_transcript(transcript, user_turns)

    def _raise_disk_full(*args, **kwargs):
        raise OSError("disk full")

    hook_input = {
        "session_id": "test",
        "stop_hook_active": False,
        "transcript_path": str(transcript),
    }
    state_patch = patch("mempalace.hooks_cli.STATE_DIR", tmp_path)
    write_patch = patch.object(Path, "write_text", _raise_disk_full)
    with state_patch, write_patch:
        result = _capture_hook_output(hook_stop, hook_input, state_dir=tmp_path)
    assert result["decision"] == "block"
# --- hook_precompact with MEMPAL_DIR ---
def test_precompact_with_mempal_dir(tmp_path):
    """With MEMPAL_DIR set, precompact calls subprocess.run once and still blocks."""
    project_dir = tmp_path / "project"
    project_dir.mkdir()
    env_patch = patch.dict("os.environ", {"MEMPAL_DIR": str(project_dir)})
    run_patch = patch("mempalace.hooks_cli.subprocess.run")
    with env_patch, run_patch as mock_run:
        result = _capture_hook_output(hook_precompact, {"session_id": "test"}, state_dir=tmp_path)
    assert result["decision"] == "block"
    mock_run.assert_called_once()
def test_precompact_with_mempal_dir_oserror(tmp_path):
    """An OSError from subprocess.run does not stop precompact from blocking."""
    project_dir = tmp_path / "project"
    project_dir.mkdir()
    env_patch = patch.dict("os.environ", {"MEMPAL_DIR": str(project_dir)})
    run_patch = patch("mempalace.hooks_cli.subprocess.run", side_effect=OSError("fail"))
    with env_patch, run_patch:
        result = _capture_hook_output(hook_precompact, {"session_id": "test"}, state_dir=tmp_path)
    assert result["decision"] == "block"
# --- run_hook ---
def test_run_hook_dispatches_session_start(tmp_path):
    """run_hook parses JSON from stdin and, for session-start, outputs ``{}``."""
    payload = json.dumps({"session_id": "run-test"})
    stdin_patch = patch("sys.stdin", io.StringIO(payload))
    state_patch = patch("mempalace.hooks_cli.STATE_DIR", tmp_path)
    output_patch = patch("mempalace.hooks_cli._output")
    with stdin_patch, state_patch, output_patch as mock_output:
        run_hook("session-start", "claude-code")
    mock_output.assert_called_once_with({})
def test_run_hook_dispatches_stop(tmp_path):
    """run_hook routes ``stop`` to its handler; a three-message transcript outputs ``{}``."""
    transcript = tmp_path / "t.jsonl"
    user_turns = [{"message": {"role": "user", "content": f"msg {i}"}} for i in range(3)]
    _write_transcript(transcript, user_turns)
    hook_input = {
        "session_id": "run-test",
        "stop_hook_active": False,
        "transcript_path": str(transcript),
    }
    stdin_patch = patch("sys.stdin", io.StringIO(json.dumps(hook_input)))
    state_patch = patch("mempalace.hooks_cli.STATE_DIR", tmp_path)
    output_patch = patch("mempalace.hooks_cli._output")
    with stdin_patch, state_patch, output_patch as mock_output:
        run_hook("stop", "claude-code")
    mock_output.assert_called_once_with({})
def test_run_hook_dispatches_precompact(tmp_path):
    """run_hook routes ``precompact`` to its handler, which outputs a block decision."""
    payload = json.dumps({"session_id": "run-test"})
    stdin_patch = patch("sys.stdin", io.StringIO(payload))
    state_patch = patch("mempalace.hooks_cli.STATE_DIR", tmp_path)
    output_patch = patch("mempalace.hooks_cli._output")
    with stdin_patch, state_patch, output_patch as mock_output:
        run_hook("precompact", "claude-code")
    mock_output.assert_called_once()
    # First positional argument of the single _output call.
    emitted = mock_output.call_args[0][0]
    assert emitted["decision"] == "block"
def test_run_hook_unknown_hook():
    """An unrecognised hook name makes run_hook exit with code 1."""
    payload = json.dumps({"session_id": "test"})
    with patch("sys.stdin", io.StringIO(payload)), pytest.raises(SystemExit) as raised:
        run_hook("nonexistent", "claude-code")
    assert raised.value.code == 1
def test_run_hook_invalid_json(tmp_path):
    """Unparsable stdin does not crash run_hook; session-start still outputs ``{}``."""
    stdin_patch = patch("sys.stdin", io.StringIO("not valid json"))
    state_patch = patch("mempalace.hooks_cli.STATE_DIR", tmp_path)
    output_patch = patch("mempalace.hooks_cli._output")
    with stdin_patch, state_patch, output_patch as mock_output:
        run_hook("session-start", "claude-code")
    mock_output.assert_called_once_with({})
+45
View File
@@ -0,0 +1,45 @@
"""Tests for mempalace.instructions_cli — instruction text output."""
from unittest.mock import patch
import pytest
from mempalace.instructions_cli import AVAILABLE, INSTRUCTIONS_DIR, run_instructions
def test_run_instructions_valid_name(capsys):
    """Running a known name prints that instruction's ``.md`` file to stdout."""
    name = "init"
    md_path = INSTRUCTIONS_DIR / f"{name}.md"
    expected = md_path.read_text()
    run_instructions(name)
    printed = capsys.readouterr().out
    # Compared modulo surrounding whitespace.
    assert printed.strip() == expected.strip()
def test_run_instructions_all_available(capsys):
    """Each name listed in AVAILABLE runs cleanly and prints something to stdout."""
    for name in AVAILABLE:
        run_instructions(name)
        printed = capsys.readouterr().out
        assert len(printed) > 0
def test_run_instructions_invalid_name(capsys):
    """An unknown name exits with code 1 and explains itself on stderr."""
    with pytest.raises(SystemExit) as raised:
        run_instructions("nonexistent")
    assert raised.value.code == 1
    stderr_text = capsys.readouterr().err
    for fragment in ("Unknown instructions: nonexistent", "Available:"):
        assert fragment in stderr_text
def test_run_instructions_missing_md_file(capsys, tmp_path):
    """A listed name whose ``.md`` file is absent exits with code 1 and reports it."""
    dir_patch = patch("mempalace.instructions_cli.INSTRUCTIONS_DIR", tmp_path)
    names_patch = patch("mempalace.instructions_cli.AVAILABLE", ["fakecmd"])
    with dir_patch, names_patch, pytest.raises(SystemExit) as raised:
        run_instructions("fakecmd")
    assert raised.value.code == 1
    stderr_text = capsys.readouterr().err
    assert "Instructions file not found" in stderr_text
+105
View File
@@ -0,0 +1,105 @@
"""Extra knowledge graph tests for seed_from_entity_facts and query_relationship."""
import pytest
from mempalace.knowledge_graph import KnowledgeGraph
@pytest.fixture
def kg(tmp_path):
    """Provide a KnowledgeGraph whose database file lives under ``tmp_path``."""
    db_file = tmp_path / "kg.db"
    return KnowledgeGraph(db_path=str(db_file))
class TestSeedFromEntityFacts:
    """Checks on the triples that seed_from_entity_facts writes for each fact shape."""

    def test_seed_person_with_partner(self, kg):
        """A partner fact produces ``married_to`` and ``is_partner_of`` predicates."""
        alice = {
            "full_name": "Alice Smith",
            "type": "person",
            "gender": "female",
            "partner": "bob",
            "relationship": "husband",
        }
        kg.seed_from_entity_facts({"alice": alice})
        assert kg.stats()["entities"] >= 1
        outgoing = kg.query_entity("Alice Smith", direction="outgoing")
        predicates = {edge["predicate"] for edge in outgoing}
        for expected in ("married_to", "is_partner_of"):
            assert expected in predicates

    def test_seed_child(self, kg):
        """A parent fact produces ``child_of`` and ``is_child_of`` predicates."""
        child = {
            "full_name": "Max",
            "type": "person",
            "birthday": "2015-04-01",
            "parent": "alice",
            "relationship": "daughter",
        }
        kg.seed_from_entity_facts({"max": child})
        outgoing = kg.query_entity("Max", direction="outgoing")
        predicates = {edge["predicate"] for edge in outgoing}
        for expected in ("child_of", "is_child_of"):
            assert expected in predicates

    def test_seed_sibling(self, kg):
        """A sibling fact produces an ``is_sibling_of`` predicate."""
        sibling = {
            "full_name": "Emma",
            "type": "person",
            "relationship": "brother",
            "sibling": "max",
        }
        kg.seed_from_entity_facts({"emma": sibling})
        outgoing = kg.query_entity("Emma", direction="outgoing")
        assert "is_sibling_of" in {edge["predicate"] for edge in outgoing}

    def test_seed_dog(self, kg):
        """An animal with an owner produces an ``is_pet_of`` predicate."""
        pet = {
            "full_name": "Rex",
            "type": "animal",
            "relationship": "dog",
            "owner": "alice",
        }
        kg.seed_from_entity_facts({"rex": pet})
        outgoing = kg.query_entity("Rex", direction="outgoing")
        assert "is_pet_of" in {edge["predicate"] for edge in outgoing}

    def test_seed_with_interests(self, kg):
        """Each interest becomes a ``loves`` triple with a capitalised object."""
        person = {
            "full_name": "Max",
            "type": "person",
            "interests": ["swimming", "chess"],
        }
        kg.seed_from_entity_facts({"max": person})
        outgoing = kg.query_entity("Max", direction="outgoing")
        loved = {edge["object"] for edge in outgoing if edge["predicate"] == "loves"}
        for expected in ("Swimming", "Chess"):
            assert expected in loved

    def test_seed_minimal_facts(self, kg):
        """Facts with no relationships just create entities."""
        kg.seed_from_entity_facts({"bob": {"full_name": "Bob"}})
        entity_count = kg.stats()["entities"]
        assert entity_count >= 1
class TestQueryRelationshipWithTime:
    """Checks on query_relationship when an ``as_of`` date is supplied."""

    def test_query_relationship_with_as_of(self, kg):
        """Only the triple whose validity window covers ``as_of`` is returned."""
        kg.add_triple("Alice", "works_at", "Acme", valid_from="2020-01-01", valid_to="2024-12-31")
        kg.add_triple("Alice", "works_at", "NewCo", valid_from="2025-01-01")
        rows = kg.query_relationship("works_at", as_of="2023-06-01")
        employers = [row["object"] for row in rows]
        assert "Acme" in employers
        assert "NewCo" not in employers
+719
View File
@@ -0,0 +1,719 @@
"""Tests for mempalace.layers — Layer0, Layer1, Layer2, Layer3, MemoryStack."""
import os
from unittest.mock import MagicMock, patch
from mempalace.layers import Layer0, Layer1, Layer2, Layer3, MemoryStack
# ── Layer0 — with identity file ─────────────────────────────────────────
def test_layer0_reads_identity_file(tmp_path):
    """render() returns text taken from the identity file."""
    identity_file = tmp_path / "identity.txt"
    identity_file.write_text("I am Atlas, a personal AI assistant for Alice.")
    rendered = Layer0(identity_path=str(identity_file)).render()
    for expected in ("Atlas", "Alice"):
        assert expected in rendered
def test_layer0_caches_text(tmp_path):
    """A second render() returns the first result even after the file changes."""
    identity_file = tmp_path / "identity.txt"
    identity_file.write_text("Hello world")
    layer = Layer0(identity_path=str(identity_file))
    before_change = layer.render()
    identity_file.write_text("Changed content")
    after_change = layer.render()
    assert before_change == after_change
    assert after_change == "Hello world"
def test_layer0_missing_file_returns_default(tmp_path):
    """A missing identity file renders a placeholder that names ``identity.txt``."""
    absent_path = str(tmp_path / "nonexistent.txt")
    rendered = Layer0(identity_path=absent_path).render()
    for expected in ("No identity configured", "identity.txt"):
        assert expected in rendered
def test_layer0_token_estimate(tmp_path):
    """A 400-character identity file is estimated at 100 tokens."""
    identity_file = tmp_path / "identity.txt"
    identity_file.write_text("A" * 400)
    layer = Layer0(identity_path=str(identity_file))
    assert layer.token_estimate() == 100
def test_layer0_token_estimate_empty(tmp_path):
    """An empty identity file is estimated at zero tokens."""
    identity_file = tmp_path / "identity.txt"
    identity_file.write_text("")
    estimate = Layer0(identity_path=str(identity_file)).token_estimate()
    assert estimate == 0
def test_layer0_strips_whitespace(tmp_path):
    """Leading and trailing whitespace in the identity file is removed by render()."""
    identity_file = tmp_path / "identity.txt"
    identity_file.write_text(" Hello world \n\n")
    rendered = Layer0(identity_path=str(identity_file)).render()
    assert rendered == "Hello world"
def test_layer0_default_path():
    """Without an explicit path, Layer0 points at ``~/.mempalace/identity.txt``."""
    default_path = os.path.expanduser("~/.mempalace/identity.txt")
    assert Layer0().path == default_path
# ── Layer1 — mocked chromadb ────────────────────────────────────────────
def _mock_chromadb_for_layer(docs, metas, monkeypatch=None):
"""Return a mock PersistentClient whose collection.get returns docs/metas."""
mock_col = MagicMock()
# First batch returns data, second batch returns empty (end of pagination)
mock_col.get.side_effect = [
{"documents": docs, "metadatas": metas},
{"documents": [], "metadatas": []},
]
mock_client = MagicMock()
mock_client.get_collection.return_value = mock_col
return mock_client
def test_layer1_no_palace():
    """With a palace path that does not exist, generate() returns a helpful message."""
    missing = "/nonexistent/palace"
    with patch("mempalace.layers.MempalaceConfig") as mock_cfg:
        mock_cfg.return_value.palace_path = missing
        story = Layer1(palace_path=missing).generate()
    assert "No palace found" in story or "No memories" in story
def test_layer1_generates_essential_story():
    """Stored memories are rendered under an ESSENTIAL STORY heading."""
    docs = [
        "Important memory about project decisions",
        "Key architectural choice for the backend",
    ]
    metas = [
        {"room": "decisions", "source_file": "meeting.txt", "importance": 5},
        {"room": "architecture", "source_file": "design.txt", "importance": 4},
    ]
    client = _mock_chromadb_for_layer(docs, metas)
    cfg_patch = patch("mempalace.layers.MempalaceConfig")
    db_patch = patch("mempalace.layers.chromadb.PersistentClient", return_value=client)
    with cfg_patch as mock_cfg, db_patch:
        mock_cfg.return_value.palace_path = "/fake"
        story = Layer1(palace_path="/fake").generate()
    assert "ESSENTIAL STORY" in story
    assert "project decisions" in story
def test_layer1_empty_palace():
    """A collection with no documents makes generate() report no memories."""
    collection = MagicMock()
    collection.get.return_value = {"documents": [], "metadatas": []}
    client = MagicMock()
    client.get_collection.return_value = collection
    cfg_patch = patch("mempalace.layers.MempalaceConfig")
    db_patch = patch("mempalace.layers.chromadb.PersistentClient", return_value=client)
    with cfg_patch as mock_cfg, db_patch:
        mock_cfg.return_value.palace_path = "/fake"
        story = Layer1(palace_path="/fake").generate()
    assert "No memories" in story
def test_layer1_with_wing_filter():
    """A Layer1 built with ``wing`` passes ``where={"wing": ...}`` to the first get()."""
    client = _mock_chromadb_for_layer(
        ["Memory about project X"],
        [{"room": "general", "source_file": "x.txt", "importance": 3}],
    )
    cfg_patch = patch("mempalace.layers.MempalaceConfig")
    db_patch = patch("mempalace.layers.chromadb.PersistentClient", return_value=client)
    with cfg_patch as mock_cfg, db_patch:
        mock_cfg.return_value.palace_path = "/fake"
        story = Layer1(palace_path="/fake", wing="project_x").generate()
    assert "ESSENTIAL STORY" in story
    # Verify wing filter was passed
    first_get = client.get_collection.return_value.get.call_args_list[0]
    first_get_kwargs = first_get[1]
    assert first_get_kwargs.get("where") == {"wing": "project_x"}
def test_layer1_truncates_long_snippets():
    """A 300-character document is rendered with an ellipsis."""
    client = _mock_chromadb_for_layer(
        ["A" * 300],
        [{"room": "general", "source_file": "long.txt"}],
    )
    cfg_patch = patch("mempalace.layers.MempalaceConfig")
    db_patch = patch("mempalace.layers.chromadb.PersistentClient", return_value=client)
    with cfg_patch as mock_cfg, db_patch:
        mock_cfg.return_value.palace_path = "/fake"
        story = Layer1(palace_path="/fake").generate()
    assert "..." in story
def test_layer1_respects_max_chars():
    """With MAX_CHARS lowered to 200, the output points to L3 search for the rest."""
    docs = [f"Memory number {i} with substantial content padding here" for i in range(30)]
    metas = [
        {"room": "general", "source_file": f"f{i}.txt", "importance": 5} for i in range(30)
    ]
    client = _mock_chromadb_for_layer(docs, metas)
    cfg_patch = patch("mempalace.layers.MempalaceConfig")
    db_patch = patch("mempalace.layers.chromadb.PersistentClient", return_value=client)
    with cfg_patch as mock_cfg, db_patch:
        mock_cfg.return_value.palace_path = "/fake"
        layer = Layer1(palace_path="/fake")
        layer.MAX_CHARS = 200  # Very low cap to trigger truncation
        story = layer.generate()
    assert "more in L3 search" in story
def test_layer1_importance_from_various_keys():
    """Metadata using ``emotional_weight``, ``weight`` or no weight key still renders."""
    docs = ["mem1", "mem2", "mem3"]
    metas = [
        {"room": "r", "emotional_weight": 5},
        {"room": "r", "weight": 1},
        {"room": "r"},  # no weight key, defaults to 3
    ]
    client = _mock_chromadb_for_layer(docs, metas)
    cfg_patch = patch("mempalace.layers.MempalaceConfig")
    db_patch = patch("mempalace.layers.chromadb.PersistentClient", return_value=client)
    with cfg_patch as mock_cfg, db_patch:
        mock_cfg.return_value.palace_path = "/fake"
        story = Layer1(palace_path="/fake").generate()
    assert "ESSENTIAL STORY" in story
def test_layer1_batch_exception_breaks():
    """An exception on the second get() batch still leaves a rendered story."""
    first_batch = {"documents": ["doc1"], "metadatas": [{"room": "r"}]}
    collection = MagicMock()
    collection.get.side_effect = [first_batch, RuntimeError("batch error")]
    client = MagicMock()
    client.get_collection.return_value = collection
    cfg_patch = patch("mempalace.layers.MempalaceConfig")
    db_patch = patch("mempalace.layers.chromadb.PersistentClient", return_value=client)
    with cfg_patch as mock_cfg, db_patch:
        mock_cfg.return_value.palace_path = "/fake"
        story = Layer1(palace_path="/fake").generate()
    assert "ESSENTIAL STORY" in story
# ── Layer2 — mocked chromadb ────────────────────────────────────────────
def test_layer2_no_palace():
    """With a palace path that does not exist, retrieve() reports no palace."""
    missing = "/nonexistent/palace"
    with patch("mempalace.layers.MempalaceConfig") as mock_cfg:
        mock_cfg.return_value.palace_path = missing
        text = Layer2(palace_path=missing).retrieve(wing="test")
    assert "No palace found" in text
def test_layer2_retrieve_with_wing():
    """Retrieving by wing renders the stored document under an ON-DEMAND heading."""
    collection = MagicMock()
    collection.get.return_value = {
        "documents": ["Some memory about the project"],
        "metadatas": [{"room": "backend", "source_file": "notes.txt"}],
    }
    client = MagicMock()
    client.get_collection.return_value = collection
    cfg_patch = patch("mempalace.layers.MempalaceConfig")
    db_patch = patch("mempalace.layers.chromadb.PersistentClient", return_value=client)
    with cfg_patch as mock_cfg, db_patch:
        mock_cfg.return_value.palace_path = "/fake"
        text = Layer2(palace_path="/fake").retrieve(wing="project")
    assert "ON-DEMAND" in text
    assert "memory about the project" in text
def test_layer2_retrieve_with_room():
    """Retrieving by room renders an ON-DEMAND section."""
    collection = MagicMock()
    collection.get.return_value = {
        "documents": ["Backend architecture notes"],
        "metadatas": [{"room": "architecture", "source_file": "arch.txt"}],
    }
    client = MagicMock()
    client.get_collection.return_value = collection
    cfg_patch = patch("mempalace.layers.MempalaceConfig")
    db_patch = patch("mempalace.layers.chromadb.PersistentClient", return_value=client)
    with cfg_patch as mock_cfg, db_patch:
        mock_cfg.return_value.palace_path = "/fake"
        text = Layer2(palace_path="/fake").retrieve(room="architecture")
    assert "ON-DEMAND" in text
def test_layer2_retrieve_wing_and_room():
    """Passing both wing and room sends a ``$and`` filter to collection.get()."""
    collection = MagicMock()
    collection.get.return_value = {
        "documents": ["Filtered result"],
        "metadatas": [{"room": "backend", "source_file": "x.txt"}],
    }
    client = MagicMock()
    client.get_collection.return_value = collection
    cfg_patch = patch("mempalace.layers.MempalaceConfig")
    db_patch = patch("mempalace.layers.chromadb.PersistentClient", return_value=client)
    with cfg_patch as mock_cfg, db_patch:
        mock_cfg.return_value.palace_path = "/fake"
        text = Layer2(palace_path="/fake").retrieve(wing="proj", room="backend")
    assert "ON-DEMAND" in text
    get_kwargs = collection.get.call_args[1]
    assert "$and" in get_kwargs.get("where", {})
def test_layer2_retrieve_empty():
    """An empty result set makes retrieve() report that no drawers were found."""
    collection = MagicMock()
    collection.get.return_value = {"documents": [], "metadatas": []}
    client = MagicMock()
    client.get_collection.return_value = collection
    cfg_patch = patch("mempalace.layers.MempalaceConfig")
    db_patch = patch("mempalace.layers.chromadb.PersistentClient", return_value=client)
    with cfg_patch as mock_cfg, db_patch:
        mock_cfg.return_value.palace_path = "/fake"
        text = Layer2(palace_path="/fake").retrieve(wing="missing")
    assert "No drawers found" in text
def test_layer2_retrieve_no_filter():
    """Calling retrieve() with no wing or room sends no ``where`` argument."""
    collection = MagicMock()
    collection.get.return_value = {"documents": [], "metadatas": []}
    client = MagicMock()
    client.get_collection.return_value = collection
    cfg_patch = patch("mempalace.layers.MempalaceConfig")
    db_patch = patch("mempalace.layers.chromadb.PersistentClient", return_value=client)
    with cfg_patch as mock_cfg, db_patch:
        mock_cfg.return_value.palace_path = "/fake"
        Layer2(palace_path="/fake").retrieve()
    # No where filter should be passed
    get_kwargs = collection.get.call_args[1]
    assert "where" not in get_kwargs
def test_layer2_retrieve_error():
    """A RuntimeError from collection.get() is reported as a retrieval error string."""
    collection = MagicMock()
    collection.get.side_effect = RuntimeError("db error")
    client = MagicMock()
    client.get_collection.return_value = collection
    cfg_patch = patch("mempalace.layers.MempalaceConfig")
    db_patch = patch("mempalace.layers.chromadb.PersistentClient", return_value=client)
    with cfg_patch as mock_cfg, db_patch:
        mock_cfg.return_value.palace_path = "/fake"
        text = Layer2(palace_path="/fake").retrieve(wing="test")
    assert "Retrieval error" in text
def test_layer2_truncates_long_snippets():
    """A 400-character document is rendered with an ellipsis."""
    collection = MagicMock()
    collection.get.return_value = {
        "documents": ["B" * 400],
        "metadatas": [{"room": "r", "source_file": "s.txt"}],
    }
    client = MagicMock()
    client.get_collection.return_value = collection
    cfg_patch = patch("mempalace.layers.MempalaceConfig")
    db_patch = patch("mempalace.layers.chromadb.PersistentClient", return_value=client)
    with cfg_patch as mock_cfg, db_patch:
        mock_cfg.return_value.palace_path = "/fake"
        text = Layer2(palace_path="/fake").retrieve(wing="test")
    assert "..." in text
# ── Layer3 — mocked chromadb ────────────────────────────────────────────
def _mock_query_results(docs, metas, dists):
return {
"documents": [docs],
"metadatas": [metas],
"distances": [dists],
}
def test_layer3_no_palace():
    """With a palace path that does not exist, search() reports no palace."""
    missing = "/nonexistent/palace"
    with patch("mempalace.layers.MempalaceConfig") as mock_cfg:
        mock_cfg.return_value.palace_path = missing
        text = Layer3(palace_path=missing).search("test query")
    assert "No palace found" in text
def test_layer3_search_raw_no_palace():
    """With a palace path that does not exist, search_raw() returns an empty list."""
    missing = "/nonexistent/palace"
    with patch("mempalace.layers.MempalaceConfig") as mock_cfg:
        mock_cfg.return_value.palace_path = missing
        hits = Layer3(palace_path=missing).search_raw("test query")
    assert hits == []
def test_layer3_search_with_results():
    """A hit at distance 0.2 is rendered with its text and ``sim=0.8``."""
    collection = MagicMock()
    collection.query.return_value = _mock_query_results(
        ["Found this important memory"],
        [{"wing": "project", "room": "backend", "source_file": "notes.txt"}],
        [0.2],
    )
    client = MagicMock()
    client.get_collection.return_value = collection
    cfg_patch = patch("mempalace.layers.MempalaceConfig")
    db_patch = patch("mempalace.layers.chromadb.PersistentClient", return_value=client)
    with cfg_patch as mock_cfg, db_patch:
        mock_cfg.return_value.palace_path = "/fake"
        text = Layer3(palace_path="/fake").search("important")
    for expected in ("SEARCH RESULTS", "important memory", "sim=0.8"):
        assert expected in text
def test_layer3_search_no_results():
    """An empty query result makes search() report that nothing was found."""
    collection = MagicMock()
    collection.query.return_value = _mock_query_results([], [], [])
    client = MagicMock()
    client.get_collection.return_value = collection
    cfg_patch = patch("mempalace.layers.MempalaceConfig")
    db_patch = patch("mempalace.layers.chromadb.PersistentClient", return_value=client)
    with cfg_patch as mock_cfg, db_patch:
        mock_cfg.return_value.palace_path = "/fake"
        text = Layer3(palace_path="/fake").search("nothing")
    assert "No results found" in text
def test_layer3_search_with_wing_filter():
    """Searching with only ``wing`` sends ``where={"wing": ...}`` to collection.query()."""
    collection = MagicMock()
    collection.query.return_value = _mock_query_results(
        ["result"],
        [{"wing": "proj", "room": "r"}],
        [0.1],
    )
    client = MagicMock()
    client.get_collection.return_value = collection
    cfg_patch = patch("mempalace.layers.MempalaceConfig")
    db_patch = patch("mempalace.layers.chromadb.PersistentClient", return_value=client)
    with cfg_patch as mock_cfg, db_patch:
        mock_cfg.return_value.palace_path = "/fake"
        Layer3(palace_path="/fake").search("q", wing="proj")
    query_kwargs = collection.query.call_args[1]
    assert query_kwargs["where"] == {"wing": "proj"}
def test_layer3_search_with_room_filter():
    """Searching with only ``room`` sends ``where={"room": ...}`` to collection.query()."""
    collection = MagicMock()
    collection.query.return_value = _mock_query_results(
        ["result"],
        [{"wing": "w", "room": "backend"}],
        [0.1],
    )
    client = MagicMock()
    client.get_collection.return_value = collection
    cfg_patch = patch("mempalace.layers.MempalaceConfig")
    db_patch = patch("mempalace.layers.chromadb.PersistentClient", return_value=client)
    with cfg_patch as mock_cfg, db_patch:
        mock_cfg.return_value.palace_path = "/fake"
        Layer3(palace_path="/fake").search("q", room="backend")
    query_kwargs = collection.query.call_args[1]
    assert query_kwargs["where"] == {"room": "backend"}
def test_layer3_search_with_wing_and_room():
    """Searching with both wing and room sends a ``$and`` filter to collection.query()."""
    collection = MagicMock()
    collection.query.return_value = _mock_query_results(
        ["result"],
        [{"wing": "proj", "room": "backend"}],
        [0.1],
    )
    client = MagicMock()
    client.get_collection.return_value = collection
    cfg_patch = patch("mempalace.layers.MempalaceConfig")
    db_patch = patch("mempalace.layers.chromadb.PersistentClient", return_value=client)
    with cfg_patch as mock_cfg, db_patch:
        mock_cfg.return_value.palace_path = "/fake"
        Layer3(palace_path="/fake").search("q", wing="proj", room="backend")
    query_kwargs = collection.query.call_args[1]
    assert "$and" in query_kwargs["where"]
def test_layer3_search_error():
    """search() returns text containing "Search error" when the query raises."""
    collection = MagicMock()
    collection.query.side_effect = RuntimeError("search failed")
    client = MagicMock()
    client.get_collection.return_value = collection

    config_patch = patch("mempalace.layers.MempalaceConfig")
    client_patch = patch("mempalace.layers.chromadb.PersistentClient", return_value=client)
    with config_patch as cfg, client_patch:
        cfg.return_value.palace_path = "/fake"
        output = Layer3(palace_path="/fake").search("q")

    assert "Search error" in output
def test_layer3_search_truncates_long_docs():
    """A 400-character document appears in search() output with an ellipsis."""
    collection = MagicMock()
    collection.query.return_value = _mock_query_results(
        ["C" * 400], [{"wing": "w", "room": "r", "source_file": "s.txt"}], [0.1]
    )
    client = MagicMock()
    client.get_collection.return_value = collection

    config_patch = patch("mempalace.layers.MempalaceConfig")
    client_patch = patch("mempalace.layers.chromadb.PersistentClient", return_value=client)
    with config_patch as cfg, client_patch:
        cfg.return_value.palace_path = "/fake"
        output = Layer3(palace_path="/fake").search("q")

    assert "..." in output
def test_layer3_search_raw_returns_dicts():
    """search_raw() returns one dict per hit with text, wing, similarity and metadata."""
    collection = MagicMock()
    collection.query.return_value = _mock_query_results(
        ["doc text"], [{"wing": "proj", "room": "backend", "source_file": "f.txt"}], [0.3]
    )
    client = MagicMock()
    client.get_collection.return_value = collection

    config_patch = patch("mempalace.layers.MempalaceConfig")
    client_patch = patch("mempalace.layers.chromadb.PersistentClient", return_value=client)
    with config_patch as cfg, client_patch:
        cfg.return_value.palace_path = "/fake"
        hits = Layer3(palace_path="/fake").search_raw("q")

    assert len(hits) == 1
    first = hits[0]
    assert first["text"] == "doc text"
    assert first["wing"] == "proj"
    # The mocked distance is 0.3; the hit is expected to report 0.7.
    assert first["similarity"] == 0.7
    assert "metadata" in first
def test_layer3_search_raw_with_filters():
    """search_raw() with wing and room yields a ``where`` clause containing ``$and``."""
    collection = MagicMock()
    collection.query.return_value = _mock_query_results(
        ["doc"], [{"wing": "w", "room": "r"}], [0.1]
    )
    client = MagicMock()
    client.get_collection.return_value = collection

    config_patch = patch("mempalace.layers.MempalaceConfig")
    client_patch = patch("mempalace.layers.chromadb.PersistentClient", return_value=client)
    with config_patch as cfg, client_patch:
        cfg.return_value.palace_path = "/fake"
        Layer3(palace_path="/fake").search_raw("q", wing="w", room="r")

    # call_args[1] holds the keyword arguments of the last query() call.
    where = collection.query.call_args[1]["where"]
    assert "$and" in where
def test_layer3_search_raw_error():
    """search_raw() returns an empty list when the query raises."""
    collection = MagicMock()
    collection.query.side_effect = RuntimeError("fail")
    client = MagicMock()
    client.get_collection.return_value = collection

    config_patch = patch("mempalace.layers.MempalaceConfig")
    client_patch = patch("mempalace.layers.chromadb.PersistentClient", return_value=client)
    with config_patch as cfg, client_patch:
        cfg.return_value.palace_path = "/fake"
        hits = Layer3(palace_path="/fake").search_raw("q")

    assert hits == []
# ── MemoryStack ─────────────────────────────────────────────────────────
def test_memory_stack_wake_up(tmp_path):
    """wake_up() output carries the identity text plus a no-palace/no-memories notice."""
    identity = tmp_path / "identity.txt"
    identity.write_text("I am Atlas.")

    with patch("mempalace.layers.MempalaceConfig") as cfg:
        cfg.return_value.palace_path = "/nonexistent"
        stack = MemoryStack(palace_path="/nonexistent", identity_path=str(identity))
        woken = stack.wake_up()

    assert "Atlas" in woken
    # L1 will say no palace found
    assert any(notice in woken for notice in ("No palace", "No memories"))
def test_memory_stack_wake_up_with_wing(tmp_path):
    """wake_up(wing=...) stores the wing on the L1 layer and still shows the identity."""
    identity = tmp_path / "identity.txt"
    identity.write_text("I am Atlas.")

    with patch("mempalace.layers.MempalaceConfig") as cfg:
        cfg.return_value.palace_path = "/nonexistent"
        stack = MemoryStack(palace_path="/nonexistent", identity_path=str(identity))
        woken = stack.wake_up(wing="my_project")

    assert stack.l1.wing == "my_project"
    assert "Atlas" in woken
def test_memory_stack_recall(tmp_path):
    """recall() reports "No palace found" when the palace path does not exist."""
    identity = tmp_path / "identity.txt"
    identity.write_text("I am Atlas.")

    with patch("mempalace.layers.MempalaceConfig") as cfg:
        cfg.return_value.palace_path = "/nonexistent"
        stack = MemoryStack(palace_path="/nonexistent", identity_path=str(identity))
        recalled = stack.recall(wing="test")

    assert "No palace found" in recalled
def test_memory_stack_search(tmp_path):
    """search() reports "No palace found" when the palace path does not exist."""
    identity = tmp_path / "identity.txt"
    identity.write_text("I am Atlas.")

    with patch("mempalace.layers.MempalaceConfig") as cfg:
        cfg.return_value.palace_path = "/nonexistent"
        stack = MemoryStack(palace_path="/nonexistent", identity_path=str(identity))
        found = stack.search("test query")

    assert "No palace found" in found
def test_memory_stack_status(tmp_path):
    """status() without a palace reports the path, zero drawers and all four layer keys."""
    identity = tmp_path / "identity.txt"
    identity.write_text("I am Atlas.")

    with patch("mempalace.layers.MempalaceConfig") as cfg:
        cfg.return_value.palace_path = "/nonexistent"
        stack = MemoryStack(palace_path="/nonexistent", identity_path=str(identity))
        status = stack.status()

    assert status["palace_path"] == "/nonexistent"
    assert status["total_drawers"] == 0
    for layer_key in ("L0_identity", "L1_essential", "L2_on_demand", "L3_deep_search"):
        assert layer_key in status
def test_memory_stack_status_with_palace(tmp_path):
    """status() reports the collection's count and that the identity file exists."""
    identity = tmp_path / "identity.txt"
    identity.write_text("I am Atlas.")

    collection = MagicMock()
    collection.count.return_value = 42
    client = MagicMock()
    client.get_collection.return_value = collection

    config_patch = patch("mempalace.layers.MempalaceConfig")
    client_patch = patch("mempalace.layers.chromadb.PersistentClient", return_value=client)
    with config_patch as cfg, client_patch:
        cfg.return_value.palace_path = "/fake"
        stack = MemoryStack(palace_path="/fake", identity_path=str(identity))
        status = stack.status()

    assert status["total_drawers"] == 42
    assert status["L0_identity"]["exists"] is True
+61
View File
@@ -42,6 +42,50 @@ class TestHandleRequest:
assert resp["result"]["serverInfo"]["name"] == "mempalace"
assert resp["id"] == 1
def test_initialize_negotiates_client_version(self):
    """An initialize request for "2025-11-25" is answered with that same version."""
    from mempalace.mcp_server import handle_request

    request = {
        "method": "initialize",
        "id": 1,
        "params": {"protocolVersion": "2025-11-25"},
    }
    response = handle_request(request)
    assert response["result"]["protocolVersion"] == "2025-11-25"
def test_initialize_negotiates_older_supported_version(self):
    """An initialize request for "2025-03-26" is answered with that same version."""
    from mempalace.mcp_server import handle_request

    request = {
        "method": "initialize",
        "id": 1,
        "params": {"protocolVersion": "2025-03-26"},
    }
    response = handle_request(request)
    assert response["result"]["protocolVersion"] == "2025-03-26"
def test_initialize_unknown_version_falls_back_to_latest(self):
    """An unrecognised version is answered with SUPPORTED_PROTOCOL_VERSIONS[0]."""
    from mempalace.mcp_server import handle_request

    request = {
        "method": "initialize",
        "id": 1,
        "params": {"protocolVersion": "9999-12-31"},
    }
    response = handle_request(request)

    from mempalace.mcp_server import SUPPORTED_PROTOCOL_VERSIONS

    expected = SUPPORTED_PROTOCOL_VERSIONS[0]
    assert response["result"]["protocolVersion"] == expected
def test_initialize_missing_version_uses_oldest(self):
    """With no protocolVersion param the reply uses SUPPORTED_PROTOCOL_VERSIONS[-1]."""
    from mempalace.mcp_server import handle_request, SUPPORTED_PROTOCOL_VERSIONS

    request = {"method": "initialize", "id": 1, "params": {}}
    response = handle_request(request)
    expected = SUPPORTED_PROTOCOL_VERSIONS[-1]
    assert response["result"]["protocolVersion"] == expected
def test_notifications_initialized_returns_none(self):
from mempalace.mcp_server import handle_request
@@ -59,6 +103,23 @@ class TestHandleRequest:
assert "mempalace_add_drawer" in names
assert "mempalace_kg_add" in names
def test_null_arguments_does_not_hang(self, monkeypatch, config, palace_path, seeded_kg):
    """A tools/call with ``arguments: null`` must yield a result, not hang (#394)."""
    _patch_mcp_server(monkeypatch, config, seeded_kg)
    from mempalace.mcp_server import handle_request

    # Create the collection first, then drop the client reference right away.
    _client, _col = _get_collection(palace_path, create=True)
    del _client

    request = {
        "method": "tools/call",
        "id": 10,
        "params": {"name": "mempalace_status", "arguments": None},
    }
    response = handle_request(request)
    assert "error" not in response
    assert response["result"] is not None
def test_unknown_tool(self):
from mempalace.mcp_server import handle_request
+55 -1
View File
@@ -7,6 +7,7 @@ import chromadb
import yaml
from mempalace.miner import mine, scan_project
from mempalace.palace import file_already_mined
def write_file(path: Path, content: str):
@@ -47,7 +48,7 @@ def test_project_mining():
col = client.get_collection("mempalace_drawers")
assert col.count() > 0
finally:
shutil.rmtree(tmpdir)
shutil.rmtree(tmpdir, ignore_errors=True)
def test_scan_project_respects_gitignore():
@@ -206,3 +207,56 @@ def test_scan_project_skip_dirs_still_apply_without_override():
assert scanned_files(project_root, respect_gitignore=False) == ["main.py"]
finally:
shutil.rmtree(tmpdir)
def test_file_already_mined_check_mtime():
    """Exercise file_already_mined() with and without ``check_mtime``.

    Covers an unmined file, a mined file whose stored mtime matches, a mined
    file whose mtime changed afterwards, and a record stored without any
    ``source_mtime`` metadata.
    """
    tmpdir = tempfile.mkdtemp()
    # Pre-bind both names so the ``finally`` clause cannot raise
    # UnboundLocalError (masking the real failure) when setup fails before
    # the ChromaDB objects have been created.
    client = col = None
    try:
        palace_path = os.path.join(tmpdir, "palace")
        os.makedirs(palace_path)
        client = chromadb.PersistentClient(path=palace_path)
        col = client.get_or_create_collection("mempalace_drawers")

        test_file = os.path.join(tmpdir, "test.txt")
        with open(test_file, "w") as f:
            f.write("hello world")
        mtime = os.path.getmtime(test_file)

        # Not mined yet
        assert file_already_mined(col, test_file) is False
        assert file_already_mined(col, test_file, check_mtime=True) is False

        # Add it with mtime
        col.add(
            ids=["d1"],
            documents=["hello world"],
            metadatas=[{"source_file": test_file, "source_mtime": str(mtime)}],
        )

        # Already mined (no mtime check)
        assert file_already_mined(col, test_file) is True
        # Already mined (mtime matches)
        assert file_already_mined(col, test_file, check_mtime=True) is True

        # Modify file and force a different mtime (Windows has low mtime resolution)
        with open(test_file, "w") as f:
            f.write("modified content")
        os.utime(test_file, (mtime + 10, mtime + 10))

        # Still mined without mtime check
        assert file_already_mined(col, test_file) is True
        # Needs re-mining with mtime check
        assert file_already_mined(col, test_file, check_mtime=True) is False

        # Record with no mtime stored should return False for check_mtime
        col.add(
            ids=["d2"],
            documents=["other"],
            metadatas=[{"source_file": "/fake/no_mtime.txt"}],
        )
        assert file_already_mined(col, "/fake/no_mtime.txt", check_mtime=True) is False
    finally:
        # Release ChromaDB file handles before cleanup (required on Windows)
        del col, client
        shutil.rmtree(tmpdir, ignore_errors=True)
+500 -20
View File
@@ -1,31 +1,511 @@
import os
import json
import tempfile
from mempalace.normalize import normalize
from unittest.mock import patch
from mempalace.normalize import (
_extract_content,
_messages_to_transcript,
_try_chatgpt_json,
_try_claude_ai_json,
_try_claude_code_jsonl,
_try_codex_jsonl,
_try_normalize_json,
_try_slack_json,
normalize,
)
def test_plain_text():
f = tempfile.NamedTemporaryFile(mode="w", suffix=".txt", delete=False)
f.write("Hello world\nSecond line\n")
f.close()
result = normalize(f.name)
# ── normalize() top-level ──────────────────────────────────────────────
def test_plain_text(tmp_path):
f = tmp_path / "plain.txt"
f.write_text("Hello world\nSecond line\n")
result = normalize(str(f))
assert "Hello world" in result
os.unlink(f.name)
def test_claude_json():
def test_claude_json(tmp_path):
data = [{"role": "user", "content": "Hi"}, {"role": "assistant", "content": "Hello"}]
f = tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False)
json.dump(data, f)
f.close()
result = normalize(f.name)
f = tmp_path / "claude.json"
f.write_text(json.dumps(data))
result = normalize(str(f))
assert "Hi" in result
os.unlink(f.name)
def test_empty():
f = tempfile.NamedTemporaryFile(mode="w", suffix=".txt", delete=False)
f.close()
result = normalize(f.name)
def test_empty(tmp_path):
f = tmp_path / "empty.txt"
f.write_text("")
result = normalize(str(f))
assert result.strip() == ""
os.unlink(f.name)
def test_normalize_io_error():
    """normalize() raises IOError mentioning "Could not read" for a missing path."""
    try:
        normalize("/nonexistent/path/file.txt")
    except IOError as err:
        assert "Could not read" in str(err)
    else:
        assert False, "Should have raised"
def test_normalize_already_has_markers(tmp_path):
    """Text that already has three '>' lines comes back from normalize() unchanged."""
    content = (
        "> question 1\nanswer 1\n"
        "> question 2\nanswer 2\n"
        "> question 3\nanswer 3\n"
    )
    path = tmp_path / "markers.txt"
    path.write_text(content)
    assert normalize(str(path)) == content
def test_normalize_json_content_detected_by_brace(tmp_path):
    """A ``.txt`` file holding a JSON message list still yields the user text."""
    messages = [
        {"role": "user", "content": "Hey"},
        {"role": "assistant", "content": "Hi there"},
    ]
    path = tmp_path / "chat.txt"
    path.write_text(json.dumps(messages))
    normalized = normalize(str(path))
    assert "Hey" in normalized
def test_normalize_whitespace_only(tmp_path):
    """A file containing only whitespace normalizes to nothing but whitespace."""
    path = tmp_path / "ws.txt"
    path.write_text(" \n \n ")
    normalized = normalize(str(path))
    assert normalized.strip() == ""
# ── _extract_content ───────────────────────────────────────────────────
def test_extract_content_string():
    """A plain string is returned as-is."""
    extracted = _extract_content("hello")
    assert extracted == "hello"
def test_extract_content_list_of_strings():
    """A list of strings is joined with single spaces."""
    parts = ["hello", "world"]
    assert _extract_content(parts) == "hello world"
def test_extract_content_list_of_blocks():
    """Only the text block contributes; the image block is ignored."""
    text_block = {"type": "text", "text": "hello"}
    image_block = {"type": "image", "url": "x"}
    extracted = _extract_content([text_block, image_block])
    assert extracted == "hello"
def test_extract_content_dict():
    """A dict with a ``text`` key yields that text."""
    payload = {"text": "hello"}
    assert _extract_content(payload) == "hello"
def test_extract_content_none():
    """``None`` content yields the empty string."""
    extracted = _extract_content(None)
    assert extracted == ""
def test_extract_content_mixed_list():
    """Plain strings and text blocks in one list are joined in order."""
    mixed = ["plain", {"type": "text", "text": "block"}]
    extracted = _extract_content(mixed)
    assert extracted == "plain block"
# ── _try_claude_code_jsonl ─────────────────────────────────────────────
def test_claude_code_jsonl_valid():
    """A human/assistant pair becomes a transcript with a '>'-prefixed question."""
    entries = [
        {"type": "human", "message": {"content": "What is X?"}},
        {"type": "assistant", "message": {"content": "X is Y."}},
    ]
    transcript = _try_claude_code_jsonl("\n".join(json.dumps(e) for e in entries))
    assert transcript is not None
    assert "> What is X?" in transcript
    assert "X is Y." in transcript
def test_claude_code_jsonl_user_type():
    """Entries typed ``user`` are treated as questions too."""
    entries = [
        {"type": "user", "message": {"content": "Q"}},
        {"type": "assistant", "message": {"content": "A"}},
    ]
    transcript = _try_claude_code_jsonl("\n".join(json.dumps(e) for e in entries))
    assert transcript is not None
    assert "> Q" in transcript
def test_claude_code_jsonl_too_few_messages():
    """A single message is not enough; the parser returns None."""
    only_entry = {"type": "human", "message": {"content": "only one"}}
    content = "\n".join([json.dumps(only_entry)])
    assert _try_claude_code_jsonl(content) is None
def test_claude_code_jsonl_invalid_json_lines():
    """An unparseable line ahead of valid entries does not stop parsing."""
    valid = [
        {"type": "human", "message": {"content": "Q"}},
        {"type": "assistant", "message": {"content": "A"}},
    ]
    raw_lines = ["not json"] + [json.dumps(entry) for entry in valid]
    transcript = _try_claude_code_jsonl("\n".join(raw_lines))
    assert transcript is not None
def test_claude_code_jsonl_non_dict_entries():
    """A JSON line that is a list rather than a dict does not stop parsing."""
    entries = [
        [1, 2, 3],
        {"type": "human", "message": {"content": "Q"}},
        {"type": "assistant", "message": {"content": "A"}},
    ]
    transcript = _try_claude_code_jsonl("\n".join(json.dumps(e) for e in entries))
    assert transcript is not None
# ── _try_codex_jsonl ───────────────────────────────────────────────────
def test_codex_jsonl_valid():
    """session_meta plus a user/agent event pair produces a transcript."""
    entries = [
        {"type": "session_meta", "payload": {}},
        {"type": "event_msg", "payload": {"type": "user_message", "message": "Q"}},
        {"type": "event_msg", "payload": {"type": "agent_message", "message": "A"}},
    ]
    transcript = _try_codex_jsonl("\n".join(json.dumps(e) for e in entries))
    assert transcript is not None
    assert "> Q" in transcript
def test_codex_jsonl_no_session_meta():
    """The codex parser returns None when no session_meta entry is present."""
    entries = [
        {"type": "event_msg", "payload": {"type": "user_message", "message": "Q"}},
        {"type": "event_msg", "payload": {"type": "agent_message", "message": "A"}},
    ]
    transcript = _try_codex_jsonl("\n".join(json.dumps(e) for e in entries))
    assert transcript is None
def test_codex_jsonl_skips_non_event_msg():
    """A response_item entry's message does not appear ahead of the first question."""
    entries = [
        {"type": "session_meta"},
        {"type": "response_item", "payload": {"type": "user_message", "message": "X"}},
        {"type": "event_msg", "payload": {"type": "user_message", "message": "Q"}},
        {"type": "event_msg", "payload": {"type": "agent_message", "message": "A"}},
    ]
    transcript = _try_codex_jsonl("\n".join(json.dumps(e) for e in entries))
    assert transcript is not None
    # Everything before the first "> Q" must be free of the skipped message.
    before_question, _, _ = transcript.partition("> Q")
    assert "X" not in before_question
def test_codex_jsonl_non_string_message():
    """A non-string message value does not prevent a transcript from being built."""
    entries = [
        {"type": "session_meta"},
        {"type": "event_msg", "payload": {"type": "user_message", "message": 123}},
        {"type": "event_msg", "payload": {"type": "user_message", "message": "Q"}},
        {"type": "event_msg", "payload": {"type": "agent_message", "message": "A"}},
    ]
    transcript = _try_codex_jsonl("\n".join(json.dumps(e) for e in entries))
    assert transcript is not None
def test_codex_jsonl_empty_text_skipped():
    """A whitespace-only message does not prevent a transcript from being built."""
    entries = [
        {"type": "session_meta"},
        {"type": "event_msg", "payload": {"type": "user_message", "message": " "}},
        {"type": "event_msg", "payload": {"type": "user_message", "message": "Q"}},
        {"type": "event_msg", "payload": {"type": "agent_message", "message": "A"}},
    ]
    transcript = _try_codex_jsonl("\n".join(json.dumps(e) for e in entries))
    assert transcript is not None
def test_codex_jsonl_payload_not_dict():
    """A string payload does not prevent a transcript from being built."""
    entries = [
        {"type": "session_meta"},
        {"type": "event_msg", "payload": "not a dict"},
        {"type": "event_msg", "payload": {"type": "user_message", "message": "Q"}},
        {"type": "event_msg", "payload": {"type": "agent_message", "message": "A"}},
    ]
    transcript = _try_codex_jsonl("\n".join(json.dumps(e) for e in entries))
    assert transcript is not None
# ── _try_claude_ai_json ───────────────────────────────────────────────
def test_claude_ai_flat_messages():
    """A flat list of role/content dicts yields a '>'-prefixed user line."""
    user_msg = {"role": "user", "content": "Hello"}
    assistant_msg = {"role": "assistant", "content": "Hi there"}
    transcript = _try_claude_ai_json([user_msg, assistant_msg])
    assert transcript is not None
    assert "> Hello" in transcript
def test_claude_ai_dict_with_messages_key():
    """A dict wrapping the conversation under ``messages`` is accepted."""
    conversation = [
        {"role": "user", "content": "Hello"},
        {"role": "assistant", "content": "Hi"},
    ]
    transcript = _try_claude_ai_json({"messages": conversation})
    assert transcript is not None
def test_claude_ai_privacy_export():
    """A list of conversations with ``chat_messages`` (human/ai roles) is parsed."""
    conversation = {
        "chat_messages": [
            {"role": "human", "content": "Q1"},
            {"role": "ai", "content": "A1"},
        ]
    }
    transcript = _try_claude_ai_json([conversation])
    assert transcript is not None
    assert "> Q1" in transcript
def test_claude_ai_not_a_list():
    """A bare string input is rejected with None."""
    assert _try_claude_ai_json("not a list") is None
def test_claude_ai_too_few_messages():
    """A single message is not enough; the parser returns None."""
    lone_message = {"role": "user", "content": "Hello"}
    assert _try_claude_ai_json([lone_message]) is None
def test_claude_ai_dict_with_chat_messages_key():
    """A dict wrapping the conversation under ``chat_messages`` is accepted."""
    conversation = [
        {"role": "user", "content": "Hello"},
        {"role": "assistant", "content": "World"},
    ]
    transcript = _try_claude_ai_json({"chat_messages": conversation})
    assert transcript is not None
def test_claude_ai_privacy_export_non_dict_items():
    """Stray non-dict messages and conversations do not prevent a transcript."""
    conversation = {
        "chat_messages": [
            "not a dict",
            {"role": "user", "content": "Q"},
            {"role": "assistant", "content": "A"},
        ]
    }
    export = [conversation, "not a convo"]
    transcript = _try_claude_ai_json(export)
    assert transcript is not None
# ── _try_chatgpt_json ─────────────────────────────────────────────────
def test_chatgpt_json_valid():
    """A mapping tree with an empty root and a user/assistant chain is parsed."""
    root = {"parent": None, "message": None, "children": ["msg1"]}
    question = {
        "parent": "root",
        "message": {
            "author": {"role": "user"},
            "content": {"parts": ["Hello ChatGPT"]},
        },
        "children": ["msg2"],
    }
    answer = {
        "parent": "msg1",
        "message": {
            "author": {"role": "assistant"},
            "content": {"parts": ["Hello! How can I help?"]},
        },
        "children": [],
    }
    export = {"mapping": {"root": root, "msg1": question, "msg2": answer}}

    transcript = _try_chatgpt_json(export)
    assert transcript is not None
    assert "> Hello ChatGPT" in transcript
def test_chatgpt_json_no_mapping():
    """A dict without a ``mapping`` key is rejected with None."""
    assert _try_chatgpt_json({"data": []}) is None
def test_chatgpt_json_not_dict():
    """A list input is rejected with None."""
    assert _try_chatgpt_json([1, 2, 3]) is None
def test_chatgpt_json_fallback_root():
    """A root node that itself carries a (system) message is still parsed."""
    root = {
        "parent": None,
        "message": {
            "author": {"role": "system"},
            "content": {"parts": ["system prompt"]},
        },
        "children": ["msg1"],
    }
    question = {
        "parent": "root",
        "message": {
            "author": {"role": "user"},
            "content": {"parts": ["Hello"]},
        },
        "children": ["msg2"],
    }
    answer = {
        "parent": "msg1",
        "message": {
            "author": {"role": "assistant"},
            "content": {"parts": ["Hi there"]},
        },
        "children": [],
    }
    export = {"mapping": {"root": root, "msg1": question, "msg2": answer}}

    transcript = _try_chatgpt_json(export)
    assert transcript is not None
def test_chatgpt_json_too_few_messages():
    """A tree holding a single user message yields None."""
    root = {"parent": None, "message": None, "children": ["msg1"]}
    only_message = {
        "parent": "root",
        "message": {
            "author": {"role": "user"},
            "content": {"parts": ["Only one"]},
        },
        "children": [],
    }
    export = {"mapping": {"root": root, "msg1": only_message}}

    assert _try_chatgpt_json(export) is None
# ── _try_slack_json ────────────────────────────────────────────────────
def test_slack_json_valid():
    """Two Slack messages from different users produce a transcript."""
    export = [
        {"type": "message", "user": "U1", "text": "Hello"},
        {"type": "message", "user": "U2", "text": "Hi there"},
    ]
    transcript = _try_slack_json(export)
    assert transcript is not None
    assert "Hello" in transcript
def test_slack_json_not_a_list():
    """A dict input is rejected with None."""
    assert _try_slack_json({"type": "message"}) is None
def test_slack_json_too_few_messages():
    """A single Slack message is not enough; the parser returns None."""
    lone_message = {"type": "message", "user": "U1", "text": "Hello"}
    assert _try_slack_json([lone_message]) is None
def test_slack_json_skips_non_message_types():
    """A channel_join event alongside two messages still yields a transcript."""
    join_event = {"type": "channel_join", "user": "U1", "text": "joined"}
    messages = [
        {"type": "message", "user": "U1", "text": "Hello"},
        {"type": "message", "user": "U2", "text": "Hi"},
    ]
    transcript = _try_slack_json([join_event] + messages)
    assert transcript is not None
def test_slack_json_three_users():
    """Messages from three distinct speakers still yield a transcript."""
    export = [
        {"type": "message", "user": speaker, "text": text}
        for speaker, text in (("U1", "Hello"), ("U2", "Hi"), ("U3", "Hey"))
    ]
    transcript = _try_slack_json(export)
    assert transcript is not None
def test_slack_json_empty_text_skipped():
    """An empty-text message alongside two real ones still yields a transcript."""
    blank = {"type": "message", "user": "U1", "text": ""}
    messages = [
        {"type": "message", "user": "U1", "text": "Hello"},
        {"type": "message", "user": "U2", "text": "Hi"},
    ]
    transcript = _try_slack_json([blank] + messages)
    assert transcript is not None
def test_slack_json_username_fallback():
    """Messages carrying ``username`` instead of ``user`` are accepted."""
    export = [
        {"type": "message", "username": "bot1", "text": "Hello"},
        {"type": "message", "username": "bot2", "text": "Hi"},
    ]
    transcript = _try_slack_json(export)
    assert transcript is not None
# ── _try_normalize_json ────────────────────────────────────────────────
def test_try_normalize_json_invalid_json():
    """Unparseable input yields None."""
    assert _try_normalize_json("not json at all {{{") is None
def test_try_normalize_json_valid_but_unknown_schema():
    """Well-formed JSON with an unrecognised shape yields None."""
    payload = json.dumps({"random": "data"})
    assert _try_normalize_json(payload) is None
# ── _messages_to_transcript ────────────────────────────────────────────
def test_messages_to_transcript_basic():
    """A user/assistant pair renders as a '>' question followed by the answer."""
    messages = [("user", "Q"), ("assistant", "A")]
    # create=True lets the patch apply even if the module lacks the attribute.
    spellcheck_patch = patch(
        "mempalace.normalize.spellcheck_user_text", side_effect=lambda x: x, create=True
    )
    with spellcheck_patch:
        transcript = _messages_to_transcript(messages, spellcheck=False)
    assert "> Q" in transcript
    assert "A" in transcript
def test_messages_to_transcript_consecutive_users():
    """Back-to-back user messages each get their own '>' line."""
    messages = [("user", "Q1"), ("user", "Q2"), ("assistant", "A")]
    transcript = _messages_to_transcript(messages, spellcheck=False)
    for question in ("> Q1", "> Q2"):
        assert question in transcript
def test_messages_to_transcript_assistant_first():
    """An assistant message preceding any user message is kept in the output."""
    messages = [("assistant", "preamble"), ("user", "Q"), ("assistant", "A")]
    transcript = _messages_to_transcript(messages, spellcheck=False)
    assert "preamble" in transcript
    assert "> Q" in transcript
def test_normalize_rejects_large_file():
    """normalize() raises a "too large" IOError when getsize reports 600 MB."""
    reported_size = 600 * 1024 * 1024
    with patch("mempalace.normalize.os.path.getsize", return_value=reported_size):
        try:
            normalize("/fake/huge_file.txt")
        except IOError as err:
            assert "too large" in str(err).lower()
        else:
            assert False, "Should have raised IOError"
+452
View File
@@ -0,0 +1,452 @@
"""Tests for mempalace.onboarding."""
import os
from unittest.mock import patch
from mempalace.onboarding import (
DEFAULT_WINGS,
_ask,
_ask_mode,
_ask_people,
_ask_projects,
_ask_wings,
_auto_detect,
_generate_aaak_bootstrap,
_header,
_hr,
_warn_ambiguous,
_yn,
quick_setup,
run_onboarding,
)
# Force UTF-8 for Windows (source file contains Unicode symbols like hearts/stars)
# NOTE(review): PYTHONUTF8 is normally read at interpreter startup, so assigning it
# here presumably only affects child processes — confirm it has the intended effect.
os.environ["PYTHONUTF8"] = "1"
# ── DEFAULT_WINGS ───────────────────────────────────────────────────────
def test_default_wings_has_expected_keys():
    """DEFAULT_WINGS defines the work, personal and combo modes."""
    for mode in ("work", "personal", "combo"):
        assert mode in DEFAULT_WINGS
def test_default_wings_work_has_projects():
    """The work mode includes a ``projects`` wing."""
    work_wings = DEFAULT_WINGS["work"]
    assert "projects" in work_wings
def test_default_wings_personal_has_family():
    """The personal mode includes a ``family`` wing."""
    personal_wings = DEFAULT_WINGS["personal"]
    assert "family" in personal_wings
def test_default_wings_combo_has_both():
    """The combo mode includes both ``family`` and ``work`` wings."""
    combo_wings = DEFAULT_WINGS["combo"]
    for wing in ("family", "work"):
        assert wing in combo_wings
def test_default_wings_values_are_lists():
    """Every mode maps to a list holding at least three wings."""
    for mode in DEFAULT_WINGS:
        wings = DEFAULT_WINGS[mode]
        assert isinstance(wings, list), f"{mode} wings should be a list"
        assert len(wings) >= 3, f"{mode} should have at least 3 wings"
# ── _warn_ambiguous ─────────────────────────────────────────────────────
def test_warn_ambiguous_flags_common_words():
    """"Grace" is flagged as ambiguous while "Riley" is not."""
    grace = {"name": "Grace", "relationship": "friend"}
    riley = {"name": "Riley", "relationship": "daughter"}
    flagged = _warn_ambiguous([grace, riley])
    assert "Grace" in flagged
    # Riley is not a common English word
    assert "Riley" not in flagged
def test_warn_ambiguous_empty_list():
    """No people in means no warnings out."""
    assert _warn_ambiguous([]) == []
def test_warn_ambiguous_no_ambiguous_names():
    """Neither "Riley" nor "Devon" is flagged."""
    riley = {"name": "Riley", "relationship": "daughter"}
    devon = {"name": "Devon", "relationship": "friend"}
    flagged = _warn_ambiguous([riley, devon])
    assert flagged == []
def test_warn_ambiguous_multiple_hits():
    """"Grace", "May" and "Joy" are all flagged in a single call."""
    people = [
        {"name": "Grace", "relationship": "friend"},
        {"name": "May", "relationship": "aunt"},
        {"name": "Joy", "relationship": "sister"},
    ]
    flagged = _warn_ambiguous(people)
    for name in ("Grace", "May", "Joy"):
        assert name in flagged
# ── quick_setup ─────────────────────────────────────────────────────────
def test_quick_setup_creates_registry(tmp_path):
    """quick_setup() returns a registry holding the given person, project and mode."""
    riley = {"name": "Riley", "relationship": "daughter", "context": "personal"}
    registry = quick_setup(
        mode="personal", people=[riley], projects=["MemPalace"], config_dir=tmp_path
    )
    assert "Riley" in registry.people
    assert "MemPalace" in registry.projects
    assert registry.mode == "personal"
def test_quick_setup_work_mode(tmp_path):
    """quick_setup() in work mode records the mode, colleague and project."""
    alice = {"name": "Alice", "relationship": "colleague", "context": "work"}
    registry = quick_setup(
        mode="work", people=[alice], projects=["Acme"], config_dir=tmp_path
    )
    assert registry.mode == "work"
    assert "Alice" in registry.people
    assert "Acme" in registry.projects
def test_quick_setup_empty(tmp_path):
    """With no people and no projects the registry has neither."""
    registry = quick_setup(mode="personal", people=[], config_dir=tmp_path)
    people_count = len(registry.people)
    project_count = len(registry.projects)
    assert people_count == 0
    assert project_count == 0
def test_quick_setup_saves_to_disk(tmp_path):
    """quick_setup() writes ``entity_registry.json`` into the config directory."""
    riley = {"name": "Riley", "relationship": "daughter", "context": "personal"}
    quick_setup(mode="personal", people=[riley], config_dir=tmp_path)
    registry_file = tmp_path / "entity_registry.json"
    assert registry_file.exists()
# ── _generate_aaak_bootstrap ───────────────────────────────────────────
def test_generate_aaak_bootstrap_creates_files(tmp_path):
    """Bootstrapping writes both ``aaak_entities.md`` and ``critical_facts.md``."""
    people = [
        {"name": "Riley", "relationship": "daughter", "context": "personal"},
        {"name": "Devon", "relationship": "friend", "context": "personal"},
    ]
    _generate_aaak_bootstrap(
        people, ["MemPalace"], ["family", "creative"], "personal", config_dir=tmp_path
    )
    for filename in ("aaak_entities.md", "critical_facts.md"):
        assert (tmp_path / filename).exists()
def test_generate_aaak_bootstrap_entities_content(tmp_path):
    """The entities file mentions the person, their code and the project."""
    riley = {"name": "Riley", "relationship": "daughter", "context": "personal"}
    _generate_aaak_bootstrap(
        [riley], ["MemPalace"], ["family"], "personal", config_dir=tmp_path
    )
    entities = (tmp_path / "aaak_entities.md").read_text()
    assert "Riley" in entities
    assert "RIL" in entities  # entity code
    assert "MemPalace" in entities
def test_generate_aaak_bootstrap_facts_content(tmp_path):
    """The facts file mentions the person, the project and the mode."""
    alice = {"name": "Alice", "relationship": "colleague", "context": "work"}
    _generate_aaak_bootstrap([alice], ["Acme"], ["projects"], "work", config_dir=tmp_path)
    facts = (tmp_path / "critical_facts.md").read_text()
    assert "Alice" in facts
    assert "Acme" in facts
    assert "work" in facts.lower()
def test_generate_aaak_bootstrap_empty_people(tmp_path):
    """Both files are still written when there are no people or projects."""
    _generate_aaak_bootstrap([], [], ["general"], "personal", config_dir=tmp_path)
    for filename in ("aaak_entities.md", "critical_facts.md"):
        assert (tmp_path / filename).exists()
def test_generate_aaak_bootstrap_collision(tmp_path):
    """Names sharing a three-letter prefix end up with distinct codes."""
    alice = {"name": "Alice", "relationship": "friend", "context": "work"}
    alison = {"name": "Alison", "relationship": "coworker", "context": "work"}
    _generate_aaak_bootstrap([alice, alison], [], ["work"], "work", config_dir=tmp_path)
    entities = (tmp_path / "aaak_entities.md").read_text()
    assert "ALI" in entities
    assert "ALIS" in entities
def test_generate_aaak_bootstrap_no_relationship(tmp_path):
    """A person lacking a ``relationship`` key still gets a ``CODE=Name`` entry."""
    bob = {"name": "Bob", "context": "work"}
    _generate_aaak_bootstrap([bob], [], ["work"], "work", config_dir=tmp_path)
    entities = (tmp_path / "aaak_entities.md").read_text()
    assert "BOB=Bob" in entities
# ── _hr, _header ──────────────────────────────────────────────────────
def test_hr_prints_line(capsys):
    """_hr() writes visible (non-whitespace) output to stdout."""
    _hr()
    out = capsys.readouterr().out
    # The previous check (`"" in out`) is true for every string, so it could
    # never fail; require that something other than whitespace was printed.
    assert out.strip()
def test_header_prints_banner(capsys):
    """_header() prints the title together with '=' rule characters."""
    _header("Test Title")
    printed = capsys.readouterr().out
    for fragment in ("Test Title", "="):
        assert fragment in printed
# ── _ask ──────────────────────────────────────────────────────────────
def test_ask_with_default_uses_default():
    """Empty input makes _ask() return the supplied default."""
    with patch("builtins.input", return_value=""):
        answer = _ask("prompt", default="fallback")
    assert answer == "fallback"
def test_ask_with_default_uses_input():
    """Typed input takes precedence over the default."""
    with patch("builtins.input", return_value="custom"):
        answer = _ask("prompt", default="fallback")
    assert answer == "custom"
def test_ask_no_default():
    """Without a default, _ask() returns what was typed."""
    with patch("builtins.input", return_value="answer"):
        answer = _ask("prompt")
    assert answer == "answer"
# ── _yn ───────────────────────────────────────────────────────────────
def test_yn_default_yes_empty_input():
    """Empty input with the implicit default answers True."""
    with patch("builtins.input", return_value=""):
        answer = _yn("continue?")
        assert answer is True
def test_yn_default_no_empty_input():
    """Empty input with ``default="n"`` answers False."""
    with patch("builtins.input", return_value=""):
        answer = _yn("continue?", default="n")
        assert answer is False
def test_yn_explicit_yes():
    """Typing "yes" overrides a "n" default."""
    with patch("builtins.input", return_value="yes"):
        answer = _yn("continue?", default="n")
        assert answer is True
def test_yn_explicit_no():
    """Typing "no" overrides the implicit default."""
    with patch("builtins.input", return_value="no"):
        answer = _yn("continue?")
        assert answer is False
# ── _ask_mode ─────────────────────────────────────────────────────────
def test_ask_mode_work():
    """Choice "1" selects the work mode."""
    with patch("builtins.input", return_value="1"):
        chosen = _ask_mode()
        assert chosen == "work"
def test_ask_mode_personal():
    """Choice "2" selects the personal mode."""
    with patch("builtins.input", return_value="2"):
        chosen = _ask_mode()
        assert chosen == "personal"
def test_ask_mode_combo():
    """Choice "3" selects the combo mode."""
    with patch("builtins.input", return_value="3"):
        chosen = _ask_mode()
        assert chosen == "combo"
def test_ask_mode_retries_on_bad_input():
    """Invalid answers are retried until a valid choice is entered."""
    typed = ["x", "bad", "1"]
    with patch("builtins.input", side_effect=typed):
        chosen = _ask_mode()
        assert chosen == "work"
# ── _ask_people ───────────────────────────────────────────────────────
def test_ask_people_personal_mode():
    """One "name, relationship" entry in personal mode yields one person."""
    typed = ["Alice, daughter", "", "done"]
    with patch("builtins.input", side_effect=typed):
        people, _aliases = _ask_people("personal")
    assert len(people) == 1
    person = people[0]
    assert person["name"] == "Alice"
    assert person["relationship"] == "daughter"
def test_ask_people_work_mode():
    """A person entered in work mode is tagged with the ``work`` context."""
    typed = ["Bob, manager", "", "done"]
    with patch("builtins.input", side_effect=typed):
        people, _aliases = _ask_people("work")
    assert len(people) == 1
    person = people[0]
    assert person["name"] == "Bob"
    assert person["context"] == "work"
def test_ask_people_combo_mode():
    """Combo mode collects people from both the personal and the work rounds."""
    personal_round = ["Alice, daughter", "", "done"]  # personal
    work_round = ["Bob, boss", "done"]  # work
    with patch("builtins.input", side_effect=personal_round + work_round):
        people, _aliases = _ask_people("combo")
    assert len(people) == 2
def test_ask_people_with_nickname():
    """A nickname entered after a person is recorded as an alias for that person."""
    typed = ["Alice, daughter", "Ali", "done"]
    with patch("builtins.input", side_effect=typed):
        _people, aliases = _ask_people("personal")
    assert aliases == {"Ali": "Alice"}
def test_ask_people_empty_name_skipped():
    """A blank entry adds nobody to the people list."""
    typed = ["", "done"]
    with patch("builtins.input", side_effect=typed):
        people, _aliases = _ask_people("personal")
    assert len(people) == 0
# ── _ask_projects ─────────────────────────────────────────────────────
def test_ask_projects_personal_returns_empty():
    """Personal mode returns an empty project list (no input is patched here)."""
    projects = _ask_projects("personal")
    assert projects == []
def test_ask_projects_work_mode():
    """Work mode gathers each typed project until "done" is entered."""
    scripted_answers = ["Acme", "BigCo", "done"]
    with patch("builtins.input", side_effect=scripted_answers):
        projects = _ask_projects("work")
    assert projects == ["Acme", "BigCo"]
def test_ask_projects_empty_entry_stops():
    """A blank entry ends project collection just like "done"."""
    scripted_answers = ["Acme", ""]
    with patch("builtins.input", side_effect=scripted_answers):
        projects = _ask_projects("work")
    assert projects == ["Acme"]
# ── _ask_wings ────────────────────────────────────────────────────────
def test_ask_wings_accept_defaults():
    """An empty answer keeps the default wings for the chosen mode."""
    with patch("builtins.input", return_value=""):
        wings = _ask_wings("work")
    assert wings == DEFAULT_WINGS["work"]
def test_ask_wings_custom():
    """A comma-separated answer is split into individual wing names."""
    with patch("builtins.input", return_value="alpha, beta, gamma"):
        wings = _ask_wings("personal")
    assert wings == ["alpha", "beta", "gamma"]
# ── _auto_detect ──────────────────────────────────────────────────────
def test_auto_detect_no_files(tmp_path):
    """An empty directory produces no detected people."""
    detected = _auto_detect(str(tmp_path), [])
    assert detected == []
def test_auto_detect_filters_known(tmp_path):
    """People already in the known list are dropped from the detection result."""
    already_known = [{"name": "Alice"}]
    detection_payload = {
        "people": [
            {"name": "Alice", "confidence": 0.9, "signals": ["test"]},
            {"name": "Bob", "confidence": 0.8, "signals": ["test"]},
        ],
        "projects": [],
        "uncertain": [],
    }
    with (
        patch("mempalace.onboarding.scan_for_detection", return_value=["file.txt"]),
        patch("mempalace.onboarding.detect_entities", return_value=detection_payload),
    ):
        detected = _auto_detect(str(tmp_path), already_known)
    detected_names = [person["name"] for person in detected]
    assert "Alice" not in detected_names
    assert "Bob" in detected_names
def test_auto_detect_filters_low_confidence(tmp_path):
    """A candidate reported with confidence 0.5 is not returned."""
    detection_payload = {
        "people": [{"name": "Bob", "confidence": 0.5, "signals": ["test"]}],
        "projects": [],
        "uncertain": [],
    }
    with (
        patch("mempalace.onboarding.scan_for_detection", return_value=["file.txt"]),
        patch("mempalace.onboarding.detect_entities", return_value=detection_payload),
    ):
        detected = _auto_detect(str(tmp_path), [])
    assert len(detected) == 0
def test_auto_detect_handles_exception(tmp_path):
    """A failure inside the file scan is swallowed and yields an empty list."""
    scan_target = "mempalace.onboarding.scan_for_detection"
    with patch(scan_target, side_effect=Exception("boom")):
        detected = _auto_detect(str(tmp_path), [])
    assert detected == []
# ── run_onboarding ────────────────────────────────────────────────────
def test_run_onboarding_basic_flow(tmp_path):
    """Run the whole onboarding flow with every interactive prompt stubbed out."""
    stub_people = ([{"name": "Bob", "relationship": "boss", "context": "work"}], {})
    with (
        patch("mempalace.onboarding._ask_mode", return_value="work"),
        patch("mempalace.onboarding._ask_people", return_value=stub_people),
        patch("mempalace.onboarding._ask_projects", return_value=["Acme"]),
        patch("mempalace.onboarding._ask_wings", return_value=["projects", "team"]),
        patch("mempalace.onboarding._yn", return_value=False),
        patch("mempalace.onboarding._warn_ambiguous", return_value=[]),
    ):
        registry = run_onboarding(directory=".", config_dir=tmp_path, auto_detect=False)
    # The stubbed answers must end up in the returned registry.
    assert "Bob" in registry.people
    assert "Acme" in registry.projects
def test_run_onboarding_with_ambiguous_names(tmp_path):
    """Onboarding still registers a person when ``_warn_ambiguous`` is left unpatched.

    Only the registry content is asserted; any warning output is not checked.
    """
    stub_people = ([{"name": "Grace", "relationship": "friend", "context": "personal"}], {})
    with (
        patch("mempalace.onboarding._ask_mode", return_value="personal"),
        patch("mempalace.onboarding._ask_people", return_value=stub_people),
        patch("mempalace.onboarding._ask_projects", return_value=[]),
        patch("mempalace.onboarding._ask_wings", return_value=["family"]),
        patch("mempalace.onboarding._yn", return_value=False),
    ):
        registry = run_onboarding(directory=".", config_dir=tmp_path, auto_detect=False)
    assert "Grace" in registry.people
+244
View File
@@ -0,0 +1,244 @@
"""Tests for mempalace.palace_graph — graph traversal layer.
All ChromaDB access is mocked — no real database needed.
"""
from unittest.mock import MagicMock, patch
def _make_fake_collection(metadatas, ids=None):
"""Create a mock collection that returns the given metadata in batches."""
if ids is None:
ids = [f"id_{i}" for i in range(len(metadatas))]
col = MagicMock()
col.count.return_value = len(metadatas)
def fake_get(limit=1000, offset=0, include=None):
batch_meta = metadatas[offset : offset + limit]
batch_ids = ids[offset : offset + limit]
return {"ids": batch_ids, "metadatas": batch_meta}
col.get.side_effect = fake_get
return col
# Patch chromadb at import time so palace_graph can be imported
with patch.dict("sys.modules", {"chromadb": MagicMock()}):
from mempalace.palace_graph import (
_fuzzy_match,
build_graph,
find_tunnels,
graph_stats,
traverse,
)
# --- build_graph ---
class TestBuildGraph:
    """build_graph() exercised against mocked collections."""

    def test_empty_collection(self):
        """A collection without metadata yields no nodes and no edges."""
        nodes, edges = build_graph(col=_make_fake_collection([]))
        assert nodes == {}
        assert edges == []

    def test_falsy_collection(self):
        """A falsy ``col`` argument short-circuits to an empty graph."""
        nodes, edges = build_graph(col=0)
        assert nodes == {}
        assert edges == []

    def test_single_wing_no_edges(self):
        """Two entries for one room in one wing count twice but create no edge."""
        rows = [
            {"room": "auth", "wing": "wing_code", "hall": "security", "date": day}
            for day in ("2026-01-01", "2026-01-02")
        ]
        nodes, edges = build_graph(col=_make_fake_collection(rows))
        assert "auth" in nodes
        assert nodes["auth"]["count"] == 2
        assert edges == []

    def test_multi_wing_creates_edges(self):
        """The same room seen in two wings produces one edge between them."""
        rows = [
            {"room": "chromadb", "wing": wing, "hall": "databases", "date": day}
            for wing, day in (("wing_code", "2026-01-01"), ("wing_project", "2026-01-02"))
        ]
        nodes, edges = build_graph(col=_make_fake_collection(rows))
        assert "chromadb" in nodes
        assert len(edges) == 1
        only_edge = edges[0]
        assert only_edge["wing_a"] == "wing_code"
        assert only_edge["wing_b"] == "wing_project"
        assert only_edge["hall"] == "databases"

    def test_general_room_excluded(self):
        """The "general" room never becomes a graph node."""
        rows = [{"room": "general", "wing": "wing_code", "hall": "misc", "date": ""}]
        nodes, _edges = build_graph(col=_make_fake_collection(rows))
        assert "general" not in nodes

    def test_missing_wing_excluded(self):
        """An entry with an empty wing is left out of the graph."""
        rows = [{"room": "orphan", "wing": "", "hall": "misc", "date": ""}]
        nodes, _edges = build_graph(col=_make_fake_collection(rows))
        assert "orphan" not in nodes

    def test_dates_capped_at_five(self):
        """Nine dated entries for one room keep at most five dates on the node."""
        rows = [
            {"room": "busy", "wing": "w", "hall": "h", "date": f"2026-01-{day:02d}"}
            for day in range(1, 10)
        ]
        nodes, _edges = build_graph(col=_make_fake_collection(rows))
        assert len(nodes["busy"]["dates"]) <= 5
# --- traverse ---
class TestTraverse:
    """traverse() exercised against a three-room mocked collection."""

    def _build_col(self):
        """Return a fake collection: auth and login in wing_code, deploy in wing_ops."""
        rows = [
            {"room": "auth", "wing": "wing_code", "hall": "security", "date": "2026-01-01"},
            {"room": "login", "wing": "wing_code", "hall": "security", "date": "2026-01-01"},
            {"room": "deploy", "wing": "wing_ops", "hall": "infra", "date": "2026-01-01"},
        ]
        return _make_fake_collection(rows)

    def test_traverse_known_room(self):
        """Starting from a known room returns a list containing it and its wing-mates."""
        found = traverse("auth", col=self._build_col())
        assert isinstance(found, list)
        visited = [entry["room"] for entry in found]
        assert "auth" in visited
        # login shares wing_code with auth
        assert "login" in visited

    def test_traverse_unknown_room(self):
        """An unknown start room yields an error dict that carries suggestions."""
        found = traverse("nonexistent", col=self._build_col())
        assert isinstance(found, dict)
        assert "error" in found
        assert "suggestions" in found

    def test_traverse_max_hops(self):
        """With ``max_hops=0`` only the start room itself (hop 0) is returned."""
        found = traverse("auth", col=self._build_col(), max_hops=0)
        assert len(found) == 1
        assert found[0]["room"] == "auth"
# --- find_tunnels ---
class TestFindTunnels:
    """find_tunnels() exercised against a collection with one cross-wing room."""

    def _build_tunnel_col(self):
        """Return a fake collection where only "chromadb" appears in two wings."""
        rows = [
            {"room": "chromadb", "wing": "wing_code", "hall": "db", "date": "2026-01-01"},
            {"room": "chromadb", "wing": "wing_project", "hall": "db", "date": "2026-01-02"},
            {"room": "auth", "wing": "wing_code", "hall": "security", "date": "2026-01-01"},
        ]
        return _make_fake_collection(rows)

    def test_find_all_tunnels(self):
        """Without filters, the single cross-wing room is reported."""
        tunnels = find_tunnels(col=self._build_tunnel_col())
        assert len(tunnels) == 1
        assert tunnels[0]["room"] == "chromadb"

    def test_find_tunnels_with_wing_filter(self):
        """Filtering on one of the tunnel's wings still finds it."""
        tunnels = find_tunnels(wing_a="wing_code", col=self._build_tunnel_col())
        assert len(tunnels) == 1

    def test_find_tunnels_no_match(self):
        """Filtering on a wing that does not exist returns an empty list."""
        tunnels = find_tunnels(wing_a="wing_nonexistent", col=self._build_tunnel_col())
        assert tunnels == []

    def test_find_tunnels_both_wings(self):
        """Filtering on both wings of the tunnel returns exactly that room."""
        tunnels = find_tunnels(
            wing_a="wing_code",
            wing_b="wing_project",
            col=self._build_tunnel_col(),
        )
        assert len(tunnels) == 1
        assert tunnels[0]["room"] == "chromadb"
# --- graph_stats ---
class TestGraphStats:
    """graph_stats() summary counters."""

    def test_empty_graph(self):
        """Every counter is zero for an empty collection."""
        stats = graph_stats(col=_make_fake_collection([]))
        for counter in ("total_rooms", "tunnel_rooms", "total_edges"):
            assert stats[counter] == 0

    def test_stats_with_data(self):
        """Two rooms, one of them spanning two wings, give one tunnel and one edge."""
        rows = [
            {"room": "chromadb", "wing": "wing_code", "hall": "db", "date": "2026-01-01"},
            {"room": "chromadb", "wing": "wing_project", "hall": "db", "date": "2026-01-02"},
            {"room": "auth", "wing": "wing_code", "hall": "security", "date": "2026-01-01"},
        ]
        stats = graph_stats(col=_make_fake_collection(rows))
        assert stats["total_rooms"] == 2
        assert stats["tunnel_rooms"] == 1
        assert stats["total_edges"] == 1
        assert "wing_code" in stats["rooms_per_wing"]
# --- _fuzzy_match ---
class TestFuzzyMatch:
    """_fuzzy_match() room-name suggestions."""

    def test_exact_substring(self):
        """A query that is a substring of a room name matches that room."""
        rooms = {"chromadb-setup": {}, "auth-module": {}, "deploy-config": {}}
        matches = _fuzzy_match("chromadb", rooms)
        assert "chromadb-setup" in matches

    def test_partial_word_match(self):
        """A query equal to one hyphen-separated word of a room name matches it."""
        rooms = {"chromadb-setup": {}, "auth-module": {}, "deploy-config": {}}
        matches = _fuzzy_match("auth", rooms)
        assert "auth-module" in matches

    def test_no_match(self):
        """A query unrelated to every room returns an empty list."""
        rooms = {"chromadb-setup": {}, "auth-module": {}}
        matches = _fuzzy_match("zzzzz", rooms)
        assert matches == []

    def test_hyphenated_query(self):
        """A hyphenated query matches the room that starts with it."""
        rooms = {"riley-college-apps": {}, "college-prep": {}}
        matches = _fuzzy_match("riley-college", rooms)
        assert "riley-college-apps" in matches

    def test_max_results(self):
        """The ``n`` argument caps how many suggestions come back."""
        rooms = {f"room-{idx}": {} for idx in range(20)}
        matches = _fuzzy_match("room", rooms, n=3)
        assert len(matches) <= 3
+264
View File
@@ -0,0 +1,264 @@
"""Tests for mempalace.room_detector_local."""
from unittest.mock import MagicMock, patch
from mempalace.room_detector_local import (
FOLDER_ROOM_MAP,
detect_rooms_from_files,
detect_rooms_from_folders,
detect_rooms_local,
get_user_approval,
print_proposed_structure,
save_config,
)
# ── FOLDER_ROOM_MAP ────────────────────────────────────────────────────
def test_folder_room_map_has_expected_mappings():
    """The canonical folder names map to their expected room names."""
    expected = {
        "frontend": "frontend",
        "backend": "backend",
        "docs": "documentation",
        "tests": "testing",
        "config": "configuration",
    }
    for folder, room in expected.items():
        assert FOLDER_ROOM_MAP[folder] == room
def test_folder_room_map_alternative_names():
    """Alternative folder spellings resolve to the frontend/backend rooms."""
    expected = {
        "front-end": "frontend",
        "back-end": "backend",
        "server": "backend",
        "client": "frontend",
        "api": "backend",
    }
    for folder, room in expected.items():
        assert FOLDER_ROOM_MAP[folder] == room
# ── detect_rooms_from_folders ───────────────────────────────────────────
def test_detect_rooms_from_folders_standard_layout(tmp_path):
    """frontend/backend/docs folders are detected as their mapped rooms."""
    for folder in ("frontend", "backend", "docs"):
        (tmp_path / folder).mkdir()
    detected = detect_rooms_from_folders(str(tmp_path))
    detected_names = {room["name"] for room in detected}
    assert "frontend" in detected_names
    assert "backend" in detected_names
    assert "documentation" in detected_names
def test_detect_rooms_from_folders_always_has_general(tmp_path):
    """The "general" room is present even when no folders exist."""
    detected = detect_rooms_from_folders(str(tmp_path))
    detected_names = {room["name"] for room in detected}
    assert "general" in detected_names
def test_detect_rooms_from_folders_empty_dir(tmp_path):
    """An empty directory still yields at least the "general" room."""
    detected = detect_rooms_from_folders(str(tmp_path))
    assert len(detected) >= 1
    general_rooms = [room for room in detected if room["name"] == "general"]
    assert general_rooms
def test_detect_rooms_from_folders_skips_git(tmp_path):
    """.git and node_modules folders never become rooms."""
    for folder in (".git", "node_modules", "frontend"):
        (tmp_path / folder).mkdir()
    detected = detect_rooms_from_folders(str(tmp_path))
    detected_names = {room["name"] for room in detected}
    assert ".git" not in detected_names
    assert "node_modules" not in detected_names
def test_detect_rooms_from_folders_nested_dirs(tmp_path):
    """Folders one level below the root (src/components, src/routes) are considered."""
    source_root = tmp_path / "src"
    source_root.mkdir()
    for child in ("components", "routes"):
        (source_root / child).mkdir()
    detected = detect_rooms_from_folders(str(tmp_path))
    detected_names = {room["name"] for room in detected}
    # Nested dirs should be detected at one level deep
    assert "frontend" in detected_names or "backend" in detected_names
def test_detect_rooms_from_folders_room_has_description(tmp_path):
    """The documentation room carries a description mentioning its source folder."""
    (tmp_path / "docs").mkdir()
    detected = detect_rooms_from_folders(str(tmp_path))
    matching = [room for room in detected if room["name"] == "documentation"]
    documentation = matching[0] if matching else None
    assert documentation is not None
    assert "description" in documentation
    assert "docs" in documentation["description"]
def test_detect_rooms_from_folders_room_has_keywords(tmp_path):
    """The frontend room carries a non-empty keyword list."""
    (tmp_path / "frontend").mkdir()
    detected = detect_rooms_from_folders(str(tmp_path))
    matching = [room for room in detected if room["name"] == "frontend"]
    frontend = matching[0] if matching else None
    assert frontend is not None
    assert "keywords" in frontend
    assert len(frontend["keywords"]) > 0
def test_detect_rooms_from_folders_custom_named_dirs(tmp_path):
    """A folder name absent from FOLDER_ROOM_MAP yields "mylib" or at least "general"."""
    (tmp_path / "mylib").mkdir()
    detected = detect_rooms_from_folders(str(tmp_path))
    detected_names = {room["name"] for room in detected}
    # Custom dir names that don't match FOLDER_ROOM_MAP get added as-is
    assert "mylib" in detected_names or "general" in detected_names
# ── detect_rooms_from_files ─────────────────────────────────────────────
def test_detect_rooms_from_files_with_matching_filenames(tmp_path):
    """Files whose names contain room keywords yield "testing" or at least "general"."""
    for filename in ("test_auth.py", "test_login.py", "test_api.py"):
        (tmp_path / filename).write_text("content")
    detected = detect_rooms_from_files(str(tmp_path))
    detected_names = {room["name"] for room in detected}
    assert "testing" in detected_names or "general" in detected_names
def test_detect_rooms_from_files_empty_dir(tmp_path):
    """With no files at all, the result still contains the "general" room."""
    detected = detect_rooms_from_files(str(tmp_path))
    assert len(detected) >= 1
    general_rooms = [room for room in detected if room["name"] == "general"]
    assert general_rooms
def test_detect_rooms_from_files_caps_at_six(tmp_path):
    """Even with eight distinct keywords in file names, at most six rooms come back."""
    keywords = ["test", "doc", "api", "config", "frontend", "backend", "design", "meeting"]
    # Three files per keyword so each keyword is represented several times.
    for keyword in keywords:
        for index in range(3):
            (tmp_path / f"{keyword}_file_{index}.txt").write_text("content")
    detected = detect_rooms_from_files(str(tmp_path))
    assert len(detected) <= 6
# ── save_config ─────────────────────────────────────────────────────────
def test_save_config_creates_yaml(tmp_path):
    """save_config writes mempalace.yaml containing the project and room names."""
    proposed = [
        {"name": "frontend", "description": "UI files", "keywords": ["frontend"]},
        {"name": "backend", "description": "Server files", "keywords": ["backend"]},
    ]
    save_config(str(tmp_path), "myproject", proposed)
    written = tmp_path / "mempalace.yaml"
    assert written.exists()
    text = written.read_text()
    for expected in ("myproject", "frontend", "backend"):
        assert expected in text
def test_save_config_valid_yaml(tmp_path):
    """The written config parses as YAML with the expected wing and rooms."""
    import yaml

    proposed = [{"name": "general", "description": "All files", "keywords": []}]
    save_config(str(tmp_path), "test_proj", proposed)
    written = tmp_path / "mempalace.yaml"
    parsed = yaml.safe_load(written.read_text())
    assert parsed["wing"] == "test_proj"
    assert len(parsed["rooms"]) == 1
    assert parsed["rooms"][0]["name"] == "general"
# ── print_proposed_structure ──────────────────────────────────────────
def test_print_proposed_structure(capsys):
    """The printed proposal mentions the project, a room, the file count and the source."""
    proposed = [
        {"name": "frontend", "description": "UI files"},
        {"name": "general", "description": "Everything else"},
    ]
    print_proposed_structure("myapp", proposed, 42, "folder structure")
    printed = capsys.readouterr().out
    for expected in ("myapp", "frontend", "42 files", "folder structure"):
        assert expected in printed
# ── get_user_approval ─────────────────────────────────────────────────
def test_get_user_approval_accept_all():
    """An empty answer accepts the proposed rooms unchanged."""
    proposed = [{"name": "frontend", "description": "UI"}]
    with patch("builtins.input", return_value=""):
        approved = get_user_approval(proposed)
    assert approved == proposed
def test_get_user_approval_edit_remove():
    """Choosing "edit" and then "1" removes the first proposed room."""
    proposed = [
        {"name": "frontend", "description": "UI"},
        {"name": "backend", "description": "Server"},
    ]
    scripted_answers = ["edit", "1", "n"]
    with patch("builtins.input", side_effect=scripted_answers):
        approved = get_user_approval(proposed)
    # Room 1 (frontend) removed
    assert len(approved) == 1
    assert approved[0]["name"] == "backend"
def test_get_user_approval_add_room():
    """Choosing "add" appends a room with the typed name to the result."""
    proposed = [{"name": "general", "description": "All files"}]
    scripted_answers = [
        "add",
        "custom_room",
        "My custom room",
        "",
    ]
    with patch("builtins.input", side_effect=scripted_answers):
        approved = get_user_approval(proposed)
    approved_names = [room["name"] for room in approved]
    assert "custom_room" in approved_names
# ── detect_rooms_local ────────────────────────────────────────────────
def test_detect_rooms_local_yes_mode(tmp_path):
    """With ``yes=True`` the config file is written without prompting."""
    docs_dir = tmp_path / "docs"
    docs_dir.mkdir()
    (docs_dir / "readme.md").write_text("hello")
    # Stand-in for mempalace.miner so the real module is not loaded.
    miner_stub = MagicMock()
    miner_stub.scan_project.return_value = ["file1.py"]
    with patch.dict("sys.modules", {"mempalace.miner": miner_stub}):
        detect_rooms_local(str(tmp_path), yes=True)
    assert (tmp_path / "mempalace.yaml").exists()
def test_detect_rooms_local_fallback_to_files(tmp_path):
    """A directory with only loose files (no sub-folders) still produces a config."""
    for index in range(3):
        (tmp_path / f"test_file_{index}.py").write_text("content")
    # Stand-in for mempalace.miner so the real module is not loaded.
    miner_stub = MagicMock()
    miner_stub.scan_project.return_value = ["f1", "f2"]
    with patch.dict("sys.modules", {"mempalace.miner": miner_stub}):
        detect_rooms_local(str(tmp_path), yes=True)
    assert (tmp_path / "mempalace.yaml").exists()
def test_detect_rooms_local_missing_dir():
    """Pointing at a directory that does not exist raises SystemExit."""
    import pytest

    missing = "/nonexistent/path/that/does/not/exist"
    with pytest.raises(SystemExit):
        detect_rooms_local(missing, yes=True)
def test_detect_rooms_local_interactive(tmp_path):
    """With ``yes=False`` the (stubbed) approval step runs and the config is written."""
    source_dir = tmp_path / "src"
    source_dir.mkdir()
    (source_dir / "main.py").write_text("code")
    miner_stub = MagicMock()
    miner_stub.scan_project.return_value = ["f1"]
    approved_rooms = [{"name": "general", "description": "All files", "keywords": []}]
    with (
        patch.dict("sys.modules", {"mempalace.miner": miner_stub}),
        patch("mempalace.room_detector_local.get_user_approval", return_value=approved_rooms),
    ):
        detect_rooms_local(str(tmp_path), yes=False)
    assert (tmp_path / "mempalace.yaml").exists()
+83 -3
View File
@@ -1,10 +1,18 @@
"""
test_searcher.py Tests for the programmatic search_memories API.
test_searcher.py -- Tests for both search() (CLI) and search_memories() (API).
Tests the library-facing search interface (not the CLI print variant).
Uses the real ChromaDB fixtures from conftest.py for integration tests,
plus mock-based tests for error paths.
"""
from mempalace.searcher import search_memories
from unittest.mock import MagicMock, patch
import pytest
from mempalace.searcher import SearchError, search, search_memories
# ── search_memories (API) ──────────────────────────────────────────────
class TestSearchMemories:
@@ -43,3 +51,75 @@ class TestSearchMemories:
assert "source_file" in hit
assert "similarity" in hit
assert isinstance(hit["similarity"], float)
def test_search_memories_query_error(self):
    """A failing collection query is reported through an "error" entry, not raised."""
    failing_collection = MagicMock()
    failing_collection.query.side_effect = RuntimeError("query failed")
    fake_client = MagicMock()
    fake_client.get_collection.return_value = failing_collection
    client_target = "mempalace.searcher.chromadb.PersistentClient"
    with patch(client_target, return_value=fake_client):
        outcome = search_memories("test", "/fake/path")
    assert "error" in outcome
    assert "query failed" in outcome["error"]
def test_search_memories_filters_in_result(self, palace_path, seeded_collection):
    """The wing and room filters passed in are echoed back under "filters"."""
    outcome = search_memories("test", palace_path, wing="project", room="backend")
    echoed = outcome["filters"]
    assert echoed["wing"] == "project"
    assert echoed["room"] == "backend"
# ── search() (CLI print function) ─────────────────────────────────────
class TestSearchCLI:
    """search() — the printing variant — checked through captured stdout."""

    def test_search_prints_results(self, palace_path, seeded_collection, capsys):
        """A matching query prints text containing one of the query terms."""
        search("JWT authentication", palace_path)
        printed = capsys.readouterr().out
        assert "JWT" in printed or "authentication" in printed

    def test_search_with_wing_filter(self, palace_path, seeded_collection, capsys):
        """A wing-filtered search prints the results header."""
        search("planning", palace_path, wing="notes")
        printed = capsys.readouterr().out
        assert "Results for" in printed

    def test_search_with_room_filter(self, palace_path, seeded_collection, capsys):
        """A room-filtered search prints a "Room:" label."""
        search("database", palace_path, room="backend")
        printed = capsys.readouterr().out
        assert "Room:" in printed

    def test_search_with_wing_and_room(self, palace_path, seeded_collection, capsys):
        """Filtering on both wing and room prints both labels."""
        search("code", palace_path, wing="project", room="frontend")
        printed = capsys.readouterr().out
        assert "Wing:" in printed
        assert "Room:" in printed

    def test_search_no_palace_raises(self, tmp_path):
        """A palace path that does not exist raises SearchError."""
        missing_palace = str(tmp_path / "missing")
        with pytest.raises(SearchError, match="No palace found"):
            search("anything", missing_palace)

    def test_search_no_results(self, palace_path, collection, capsys):
        """Searching an unseeded collection returns None or prints "No results"."""
        # collection is empty (no seeded data)
        outcome = search("xyzzy_nonexistent_query", palace_path, n_results=1)
        printed = capsys.readouterr().out
        # Either prints "No results" or returns None
        assert outcome is None or "No results" in printed

    def test_search_query_error_raises(self):
        """A failing collection query is surfaced as SearchError."""
        failing_collection = MagicMock()
        failing_collection.query.side_effect = RuntimeError("boom")
        fake_client = MagicMock()
        fake_client.get_collection.return_value = failing_collection
        client_target = "mempalace.searcher.chromadb.PersistentClient"
        with patch(client_target, return_value=fake_client):
            with pytest.raises(SearchError, match="Search error"):
                search("test", "/fake/path")

    def test_search_n_results(self, palace_path, seeded_collection, capsys):
        """With ``n_results=1`` the output contains the first numbered result block."""
        search("code", palace_path, n_results=1)
        printed = capsys.readouterr().out
        # Should have output with at least one result block
        assert "[1]" in printed
+160
View File
@@ -0,0 +1,160 @@
"""Tests for mempalace.spellcheck — spell-correction utilities."""
from unittest.mock import patch
from mempalace.spellcheck import (
_edit_distance,
_get_system_words,
_should_skip,
spellcheck_transcript,
spellcheck_transcript_line,
spellcheck_user_text,
)
# --- _should_skip ---
class TestShouldSkip:
    """Which tokens _should_skip() exempts from spell-correction."""

    def test_short_tokens_skipped(self):
        """One- and two-character tokens are skipped."""
        for token in ("hi", "ok", "I"):
            assert _should_skip(token, set()) is True

    def test_digits_skipped(self):
        """Tokens that contain digits are skipped."""
        for token in ("3am", "top10", "bge-large-v1.5"):
            assert _should_skip(token, set()) is True

    def test_camelcase_skipped(self):
        """CamelCase identifiers are skipped."""
        for token in ("ChromaDB", "MemPalace"):
            assert _should_skip(token, set()) is True

    def test_allcaps_skipped(self):
        """All-caps tokens, with or without underscores, are skipped."""
        for token in ("NDCG", "MAX_RESULTS"):
            assert _should_skip(token, set()) is True

    def test_technical_skipped(self):
        """Hyphenated and underscored technical tokens are skipped."""
        for token in ("bge-large", "train_test"):
            assert _should_skip(token, set()) is True

    def test_url_skipped(self):
        """URL-like tokens are skipped."""
        for token in ("https://example.com", "www.google.com"):
            assert _should_skip(token, set()) is True

    def test_code_or_emoji_skipped(self):
        """Tokens wrapped in markdown/code punctuation are skipped."""
        for token in ("`code`", "**bold**"):
            assert _should_skip(token, set()) is True

    def test_known_name_skipped(self):
        """A token present in the known-names set is skipped."""
        assert _should_skip("mempalace", {"mempalace"}) is True

    def test_normal_word_not_skipped(self):
        """Ordinary lowercase words remain eligible for correction."""
        for token in ("hello", "question"):
            assert _should_skip(token, set()) is False
# --- _edit_distance ---
class TestEditDistance:
    """_edit_distance() on identical, empty and near-identical strings."""

    def test_identical(self):
        """Equal strings are zero edits apart."""
        assert _edit_distance("hello", "hello") == 0

    def test_empty_strings(self):
        """Distance to or from an empty string equals the other string's length."""
        cases = [("", "abc", 3), ("abc", "", 3), ("", "", 0)]
        for left, right, expected in cases:
            assert _edit_distance(left, right) == expected

    def test_single_edit(self):
        """One substitution, insertion or deletion each costs exactly one."""
        assert _edit_distance("cat", "bat") == 1  # substitution
        assert _edit_distance("cat", "cats") == 1  # insertion
        assert _edit_distance("cats", "cat") == 1  # deletion

    def test_known_distance(self):
        """The kitten/sitting pair is three edits apart."""
        assert _edit_distance("kitten", "sitting") == 3
# --- _get_system_words ---
def test_get_system_words_returns_set():
    """_get_system_words() always hands back a set."""
    words = _get_system_words()
    assert isinstance(words, set)
# --- spellcheck_user_text ---
def test_spellcheck_user_text_passthrough_no_autocorrect():
    """With no speller available (``_get_speller`` returns None), text is returned as-is."""
    misspelled = "somee misspeledd textt"
    with patch("mempalace.spellcheck._get_speller", return_value=None):
        assert spellcheck_user_text(misspelled) == misspelled
def test_spellcheck_user_text_with_speller():
    """With a speller in place, misspelled words are replaced by its suggestions."""
    known_fixes = {"knoe": "know", "befor": "before"}

    def fake_speller(word):
        # Words without a known fix come back unchanged.
        return known_fixes.get(word, word)

    with patch("mempalace.spellcheck._get_speller", return_value=fake_speller):
        with patch("mempalace.spellcheck._get_system_words", return_value=set()):
            with patch("mempalace.spellcheck._load_known_names", return_value=set()):
                corrected = spellcheck_user_text("knoe the question befor")
    assert "know" in corrected
    assert "before" in corrected
def test_spellcheck_preserves_technical_terms():
    """Technical tokens survive even when the speller would rewrite every word."""

    def fake_speller(word):
        # Any token that reaches the speller becomes visibly wrong.
        return "WRONG"

    with patch("mempalace.spellcheck._get_speller", return_value=fake_speller):
        with patch("mempalace.spellcheck._get_system_words", return_value=set()):
            corrected = spellcheck_user_text("ChromaDB bge-large", known_names=set())
    assert "ChromaDB" in corrected
    assert "bge-large" in corrected
    assert "WRONG" not in corrected
# --- spellcheck_transcript_line ---
def test_transcript_line_user_turn():
    """A line beginning with '>' is routed through spellcheck_user_text."""
    with patch("mempalace.spellcheck.spellcheck_user_text", return_value="corrected"):
        processed = spellcheck_transcript_line("> hello world")
    assert "corrected" in processed
def test_transcript_line_assistant_turn():
    """A line without the '>' prefix is returned untouched."""
    assistant_line = "This is an assistant response"
    processed = spellcheck_transcript_line(assistant_line)
    assert processed == assistant_line
def test_transcript_line_empty_user_turn():
    """A bare '> ' marker with no message text is returned untouched."""
    bare_marker = "> "
    processed = spellcheck_transcript_line(bare_marker)
    assert processed == bare_marker
# --- spellcheck_transcript ---
def test_spellcheck_transcript_processes_content():
    """In a full transcript, only the '>' line is rewritten; the others are kept."""
    transcript = "Assistant line\n> user line\nAnother assistant line"
    with patch("mempalace.spellcheck.spellcheck_user_text", return_value="fixed"):
        processed = spellcheck_transcript(transcript)
    first, second, third = processed.split("\n")[:3]
    assert first == "Assistant line"
    assert "fixed" in second
    assert third == "Another assistant line"
+72
View File
@@ -0,0 +1,72 @@
"""Extra spellcheck tests covering _load_known_names and speller edge cases."""
from unittest.mock import patch, MagicMock
from mempalace.spellcheck import (
_load_known_names,
spellcheck_user_text,
)
class TestLoadKnownNames:
    """_load_known_names() against a mocked entity registry."""

    def test_returns_names_from_registry(self):
        """Canonical names and aliases both appear, lower-cased, in the result."""
        registry_stub = MagicMock()
        registry_stub._data = {
            "entities": {
                "e1": {"canonical": "Alice", "aliases": ["ali"]},
                "e2": {"canonical": "Bob", "aliases": []},
            }
        }
        with patch("mempalace.entity_registry.EntityRegistry") as registry_cls:
            registry_cls.load.return_value = registry_stub
            names = _load_known_names()
        for expected in ("alice", "ali", "bob"):
            assert expected in names

    def test_returns_empty_on_exception(self):
        """If loading the registry raises, an empty set is returned."""
        load_target = "mempalace.entity_registry.EntityRegistry.load"
        with patch(load_target, side_effect=Exception("no registry")):
            names = _load_known_names()
        assert names == set()
class TestSpellerEdgeCases:
    """Cases where spellcheck_user_text() must ignore the speller's suggestion."""

    def test_capitalized_word_skipped(self):
        """Capitalized words (likely proper nouns) are not corrected."""

        def fake_speller(word):
            # Any token that reaches the speller becomes visibly wrong.
            return "WRONG"

        with patch("mempalace.spellcheck._get_speller", return_value=fake_speller):
            with patch("mempalace.spellcheck._get_system_words", return_value=set()):
                with patch("mempalace.spellcheck._load_known_names", return_value=set()):
                    corrected = spellcheck_user_text("Alice went home")
        assert "Alice" in corrected
        assert "WRONG" not in corrected

    def test_system_word_not_corrected(self):
        """A word found in the system dictionary is kept as typed."""

        def fake_speller(word):
            return "WRONG"

        system_words = {"coherently"}
        with patch("mempalace.spellcheck._get_speller", return_value=fake_speller):
            with patch("mempalace.spellcheck._get_system_words", return_value=system_words):
                with patch("mempalace.spellcheck._load_known_names", return_value=set()):
                    corrected = spellcheck_user_text("coherently")
        assert "coherently" in corrected

    def test_high_edit_distance_rejected(self):
        """A suggestion very different from the original word is discarded."""

        def fake_speller(word):
            return "completely_different_word"

        with patch("mempalace.spellcheck._get_speller", return_value=fake_speller):
            with patch("mempalace.spellcheck._get_system_words", return_value=set()):
                with patch("mempalace.spellcheck._load_known_names", return_value=set()):
                    corrected = spellcheck_user_text("hello")
        assert "hello" in corrected
+244
View File
@@ -3,6 +3,9 @@ import json
from mempalace import split_mega_files as smf
# ── Config loading ─────────────────────────────────────────────────────
def test_load_known_people_falls_back_when_config_missing(monkeypatch, tmp_path):
monkeypatch.setattr(smf, "_KNOWN_NAMES_PATH", tmp_path / "missing.json")
smf._KNOWN_NAMES_CACHE = None
@@ -46,3 +49,244 @@ def test_extract_people_detects_names_from_content(monkeypatch):
monkeypatch.setattr(smf, "KNOWN_PEOPLE", ["Alice", "Ben"])
people = smf.extract_people(["> Alice reviewed the change with Ben\n"])
assert people == ["Alice", "Ben"]
# ── Config: force_reload and invalid JSON ──────────────────────────────
def test_load_known_names_force_reload(monkeypatch, tmp_path):
    """``force_reload=True`` re-reads the config file instead of serving the cache."""
    config_path = tmp_path / "known_names.json"
    config_path.write_text(json.dumps(["Alice"]))
    monkeypatch.setattr(smf, "_KNOWN_NAMES_PATH", config_path)
    # Reset the module-level cache through monkeypatch (not plain assignment)
    # so whatever this test leaves in it is rolled back on teardown and
    # cannot leak into later tests.
    monkeypatch.setattr(smf, "_KNOWN_NAMES_CACHE", None)

    smf._load_known_names_config()
    assert smf._KNOWN_NAMES_CACHE == ["Alice"]

    # Change the file on disk; only a forced reload should pick it up.
    config_path.write_text(json.dumps(["Bob"]))
    smf._load_known_names_config(force_reload=True)
    assert smf._KNOWN_NAMES_CACHE == ["Bob"]
def test_load_known_names_invalid_json(monkeypatch, tmp_path):
    """A config file that is not valid JSON makes the loader return None."""
    config_path = tmp_path / "known_names.json"
    config_path.write_text("not json {{{")
    monkeypatch.setattr(smf, "_KNOWN_NAMES_PATH", config_path)
    # Reset the module-level cache through monkeypatch (not plain assignment)
    # so the previous value is restored on teardown.
    monkeypatch.setattr(smf, "_KNOWN_NAMES_CACHE", None)

    result = smf._load_known_names_config()
    assert result is None
def test_load_known_names_caching(monkeypatch, tmp_path):
    """A second load without ``force_reload`` returns the cached value."""
    config_path = tmp_path / "known_names.json"
    config_path.write_text(json.dumps(["Alice"]))
    monkeypatch.setattr(smf, "_KNOWN_NAMES_PATH", config_path)
    # Reset the module-level cache through monkeypatch (not plain assignment)
    # so the value cached by this test is rolled back on teardown.
    monkeypatch.setattr(smf, "_KNOWN_NAMES_CACHE", None)

    smf._load_known_names_config()
    # Second call returns cached value without re-reading
    config_path.write_text(json.dumps(["Changed"]))
    result = smf._load_known_names_config()
    assert result == ["Alice"]
# ── is_true_session_start ──────────────────────────────────────────────
def test_is_true_session_start_yes():
    """A header followed by ordinary content counts as a true session start."""
    transcript = ["Claude Code v1.0", "Some content", "More content", "", "", ""]
    verdict = smf.is_true_session_start(transcript, 0)
    assert verdict is True
def test_is_true_session_start_no_ctrl_e():
    """A header followed by the "Ctrl+E to show ..." hint is not a true start."""
    transcript = ["Claude Code v1.0", "Ctrl+E to show 5 previous messages"]
    transcript += [""] * 4
    verdict = smf.is_true_session_start(transcript, 0)
    assert verdict is False
def test_is_true_session_start_no_previous_messages():
    """A "previous messages" line shortly after the header rules out a true start."""
    transcript = ["Claude Code v1.0", "Some text", "previous messages here"]
    transcript += [""] * 3
    verdict = smf.is_true_session_start(transcript, 0)
    assert verdict is False
# ── find_session_boundaries ────────────────────────────────────────────
def test_find_session_boundaries_two_sessions():
    """Two session headers seven lines apart give boundaries at indexes 0 and 7."""
    blank_padding = [""] * 5
    transcript = (
        ["Claude Code v1.0", "content 1"]
        + blank_padding
        + ["Claude Code v1.0", "content 2"]
        + blank_padding
    )
    found = smf.find_session_boundaries(transcript)
    assert found == [0, 7]
def test_find_session_boundaries_none():
    """Text without any session header has no boundaries."""
    transcript = ["Just some text", "No sessions here"]
    found = smf.find_session_boundaries(transcript)
    assert found == []
def test_find_session_boundaries_context_restore_skipped():
    """A header followed by the "Ctrl+E" hint is not counted as a new boundary."""
    transcript = (
        ["Claude Code v1.0", "content"]
        + [""] * 5
        + ["Claude Code v1.0", "Ctrl+E to show 5 previous messages"]
        + [""] * 4
    )
    found = smf.find_session_boundaries(transcript)
    assert len(found) == 1
# ── extract_timestamp ──────────────────────────────────────────────────
def test_extract_timestamp_found():
    """A timestamp line is parsed into a file-name-style stamp and an ISO date."""
    transcript = ["⏺ 2:30 PM Wednesday, March 25, 2026"]
    stamp, iso_date = smf.extract_timestamp(transcript)
    assert stamp == "2026-03-25_230PM"
    assert iso_date == "2026-03-25"
def test_extract_timestamp_not_found():
    """Without a timestamp line, both returned values are None."""
    transcript = ["No timestamp here"]
    stamp, iso_date = smf.extract_timestamp(transcript)
    assert stamp is None
    assert iso_date is None
def test_extract_timestamp_only_checks_first_50():
    """A timestamp placed after 51 filler lines is not picked up.

    Both returned values are checked, matching the not-found test: a
    timestamp outside the scanned window must look the same as no timestamp
    at all, not a half-populated result.
    """
    lines = ["filler\n"] * 51 + ["⏺ 1:00 AM Monday, January 01, 2026"]
    human, iso = smf.extract_timestamp(lines)
    assert human is None
    # Previously unpacked but never asserted.
    assert iso is None
# ── extract_subject ────────────────────────────────────────────────────
def test_extract_subject_found():
    """The subject is derived from the first '> ' prompt line."""
    prompt_lines = ["> How do we handle authentication?"]
    topic = smf.extract_subject(prompt_lines)
    assert "authentication" in topic.lower()
def test_extract_subject_skips_commands():
    """Shell-command prompts (cd, git) are passed over in favour of a real question."""
    prompt_lines = [
        "> cd /some/dir",
        "> git status",
        "> What is the plan?",
    ]
    topic = smf.extract_subject(prompt_lines)
    assert "plan" in topic.lower()
def test_extract_subject_fallback():
    """With no prompt lines at all, the subject falls back to 'session'."""
    topic = smf.extract_subject(["No prompts at all", "Just text"])
    assert topic == "session"
def test_extract_subject_short_prompt_skipped():
    """Very short prompts ('ok', 'yes') are skipped when choosing the subject."""
    prompt_lines = [
        "> ok",
        "> yes",
        "> What about the deployment strategy?",
    ]
    topic = smf.extract_subject(prompt_lines)
    assert "deployment" in topic.lower()
def test_extract_subject_truncated():
    """A 100-character prompt yields a subject of at most 60 characters."""
    long_prompt = "> " + "a" * 100
    topic = smf.extract_subject([long_prompt])
    assert len(topic) <= 60
# ── split_file ─────────────────────────────────────────────────────────
def _make_mega_file(tmp_path, n_sessions=3, lines_per_session=15):
    """Create a mega-file with N sessions.

    Each session consists of a ``Claude Code v1.<i>`` header, one ``> ``
    prompt line, and ``lines_per_session - 2`` filler lines.

    Args:
        tmp_path: Directory (a ``pathlib.Path``-like object) in which
            ``mega.txt`` is created.
        n_sessions: Number of sessions to emit.
        lines_per_session: Lines per session including the header and the
            prompt; values below 2 still produce those two fixed lines.

    Returns:
        The path of the written ``mega.txt``.
    """
    parts = []
    for i in range(n_sessions):
        parts.append(f"Claude Code v1.{i}\n")
        parts.append(f"> What about topic {i} and how it works?\n")
        parts.extend(f"Line {j} of session {i}\n" for j in range(lines_per_session - 2))
    path = tmp_path / "mega.txt"
    # Join once instead of growing a string with +=, and pin the encoding so
    # the fixture does not depend on the platform's default locale encoding.
    path.write_text("".join(parts), encoding="utf-8")
    return path
def test_split_file_creates_output(tmp_path):
    """Splitting a multi-session file reports at least two paths, all present on disk."""
    target_dir = tmp_path / "output"
    target_dir.mkdir()
    source = _make_mega_file(tmp_path)
    produced = smf.split_file(str(source), str(target_dir))
    assert len(produced) >= 2
    assert all(piece.exists() for piece in produced)
def test_split_file_dry_run(tmp_path):
    """With dry_run=True the planned paths are returned but none is created."""
    target_dir = tmp_path / "output"
    target_dir.mkdir()
    source = _make_mega_file(tmp_path)
    planned = smf.split_file(str(source), str(target_dir), dry_run=True)
    assert len(planned) >= 2
    assert not any(piece.exists() for piece in planned)
def test_split_file_not_mega(tmp_path):
    """A file holding a single session is left alone: nothing is reported as written."""
    single = tmp_path / "single.txt"
    body = "".join(["Claude Code v1.0\n", "Just one session\n", *["line\n"] * 20])
    single.write_text(body)
    produced = smf.split_file(str(single), str(tmp_path))
    assert produced == []
def test_split_file_output_dir_none(tmp_path):
    """Passing None as output_dir places the pieces next to the source file."""
    source = _make_mega_file(tmp_path)
    produced = smf.split_file(str(source), None)
    assert len(produced) >= 2
    expected_parent = str(tmp_path)
    assert all(str(piece.parent) == expected_parent for piece in produced)
def test_split_file_tiny_fragments_skipped(tmp_path):
    """Tiny chunks (< 10 lines) are skipped.

    The fixture contains two header+line pairs (2 lines each) followed by a
    header with 20 lines, so every file reported as written must hold at
    least 10 lines.
    """
    content = "Claude Code v1.0\nline\n" * 2 + "Claude Code v1.0\n" + "line\n" * 20
    path = tmp_path / "tiny.txt"
    path.write_text(content)
    written = smf.split_file(str(path), str(tmp_path))
    # The first chunks are very small and should be skipped.
    for p in written:
        assert p.stat().st_size > 0
        # A size check alone would also pass for a 2-line fragment; check the
        # documented threshold so an unskipped tiny chunk fails the test.
        assert len(p.read_text().splitlines()) >= 10