Merge branch 'main' into fix/issue-347-codex-hook-message-counting
This commit is contained in:
@@ -0,0 +1,13 @@
|
|||||||
|
# Default owners for everything
|
||||||
|
* @milla-jovovich @bensig @igorls
|
||||||
|
|
||||||
|
# Core library
|
||||||
|
mempalace/ @milla-jovovich @bensig
|
||||||
|
|
||||||
|
# CI and workflows
|
||||||
|
.github/ @bensig
|
||||||
|
|
||||||
|
# Plugins and integrations
|
||||||
|
.claude-plugin/ @bensig
|
||||||
|
.codex-plugin/ @bensig
|
||||||
|
integrations/ @bensig
|
||||||
@@ -0,0 +1,12 @@
|
|||||||
|
version: 2
|
||||||
|
updates:
|
||||||
|
- package-ecosystem: "pip"
|
||||||
|
directory: "/"
|
||||||
|
schedule:
|
||||||
|
interval: "weekly"
|
||||||
|
open-pull-requests-limit: 5
|
||||||
|
- package-ecosystem: "github-actions"
|
||||||
|
directory: "/"
|
||||||
|
schedule:
|
||||||
|
interval: "weekly"
|
||||||
|
open-pull-requests-limit: 3
|
||||||
@@ -7,7 +7,7 @@ on:
|
|||||||
branches: [main]
|
branches: [main]
|
||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
test:
|
test-linux:
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
strategy:
|
strategy:
|
||||||
matrix:
|
matrix:
|
||||||
@@ -18,8 +18,27 @@ jobs:
|
|||||||
with:
|
with:
|
||||||
python-version: ${{ matrix.python-version }}
|
python-version: ${{ matrix.python-version }}
|
||||||
- run: pip install -e ".[dev]"
|
- run: pip install -e ".[dev]"
|
||||||
- run: python -m pytest tests/ -v --ignore=tests/benchmarks --cov=mempalace --cov-report=term-missing --cov-fail-under=30
|
- run: python -m pytest tests/ -v --ignore=tests/benchmarks --cov=mempalace --cov-report=term-missing --cov-fail-under=80 --durations=10
|
||||||
|
|
||||||
|
test-windows:
|
||||||
|
runs-on: windows-latest
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v6
|
||||||
|
- uses: actions/setup-python@v6
|
||||||
|
with:
|
||||||
|
python-version: "3.9"
|
||||||
|
- run: pip install -e ".[dev]"
|
||||||
|
- run: python -m pytest tests/ -v --ignore=tests/benchmarks --cov=mempalace --cov-report=term-missing --cov-fail-under=80 --durations=10
|
||||||
|
|
||||||
|
test-macos:
|
||||||
|
runs-on: macos-latest
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v6
|
||||||
|
- uses: actions/setup-python@v6
|
||||||
|
with:
|
||||||
|
python-version: "3.9"
|
||||||
|
- run: pip install -e ".[dev]"
|
||||||
|
- run: python -m pytest tests/ -v --ignore=tests/benchmarks --cov=mempalace --cov-report=term-missing --cov-fail-under=80 --durations=10
|
||||||
lint:
|
lint:
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
steps:
|
steps:
|
||||||
@@ -27,6 +46,6 @@ jobs:
|
|||||||
- uses: actions/setup-python@v6
|
- uses: actions/setup-python@v6
|
||||||
with:
|
with:
|
||||||
python-version: "3.11"
|
python-version: "3.11"
|
||||||
- run: pip install ruff
|
- run: pip install "ruff>=0.4.0,<0.5"
|
||||||
- run: ruff check .
|
- run: ruff check .
|
||||||
- run: ruff format --check .
|
- run: ruff format --check .
|
||||||
|
|||||||
+27
@@ -6,3 +6,30 @@ __pycache__/
|
|||||||
.pytest_cache/
|
.pytest_cache/
|
||||||
mempal.yaml
|
mempal.yaml
|
||||||
.a5c/
|
.a5c/
|
||||||
|
|
||||||
|
# Environment
|
||||||
|
.env
|
||||||
|
.env.*
|
||||||
|
|
||||||
|
# OS
|
||||||
|
.DS_Store
|
||||||
|
Thumbs.db
|
||||||
|
|
||||||
|
# IDEs
|
||||||
|
.idea/
|
||||||
|
.vscode/
|
||||||
|
*.swp
|
||||||
|
*.swo
|
||||||
|
*~
|
||||||
|
|
||||||
|
# Coverage
|
||||||
|
htmlcov/
|
||||||
|
.coverage
|
||||||
|
coverage.xml
|
||||||
|
|
||||||
|
# Virtual environments
|
||||||
|
.venv/
|
||||||
|
venv/
|
||||||
|
|
||||||
|
# ChromaDB local data
|
||||||
|
*.sqlite3-journal
|
||||||
|
|||||||
@@ -0,0 +1,78 @@
|
|||||||
|
# AGENTS.md
|
||||||
|
|
||||||
|
> How to build, test, and contribute to MemPalace.
|
||||||
|
|
||||||
|
## Setup
|
||||||
|
|
||||||
|
```bash
|
||||||
|
pip install -e ".[dev]"
|
||||||
|
```
|
||||||
|
|
||||||
|
## Commands
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Run tests
|
||||||
|
python -m pytest tests/ -v --ignore=tests/benchmarks
|
||||||
|
|
||||||
|
# Run tests with coverage
|
||||||
|
python -m pytest tests/ -v --ignore=tests/benchmarks --cov=mempalace --cov-report=term-missing
|
||||||
|
|
||||||
|
# Lint
|
||||||
|
ruff check .
|
||||||
|
|
||||||
|
# Format
|
||||||
|
ruff format .
|
||||||
|
|
||||||
|
# Format check (CI mode)
|
||||||
|
ruff format --check .
|
||||||
|
```
|
||||||
|
|
||||||
|
## Project structure
|
||||||
|
|
||||||
|
```
|
||||||
|
mempalace/
|
||||||
|
├── mcp_server.py # MCP server — all read/write tools
|
||||||
|
├── miner.py # Project file miner
|
||||||
|
├── convo_miner.py # Conversation transcript miner
|
||||||
|
├── searcher.py # Semantic search
|
||||||
|
├── knowledge_graph.py # Temporal entity-relationship graph (SQLite)
|
||||||
|
├── palace.py # Shared palace operations (ChromaDB access)
|
||||||
|
├── config.py # Configuration + input validation
|
||||||
|
├── normalize.py # Transcript format detection + normalization
|
||||||
|
├── cli.py # CLI dispatcher
|
||||||
|
├── dialect.py # AAAK compression dialect
|
||||||
|
├── palace_graph.py # Room traversal + cross-wing tunnels
|
||||||
|
├── hooks_cli.py # Hook system for auto-save
|
||||||
|
└── version.py # Single source of truth for version
|
||||||
|
```
|
||||||
|
|
||||||
|
## Conventions
|
||||||
|
|
||||||
|
- **Python style**: snake_case for functions/variables, PascalCase for classes
|
||||||
|
- **Linter**: ruff with E/F/W rules
|
||||||
|
- **Formatter**: ruff format, double quotes
|
||||||
|
- **Commits**: conventional commits (`fix:`, `feat:`, `test:`, `docs:`, `ci:`)
|
||||||
|
- **Tests**: `tests/test_*.py`, fixtures in `tests/conftest.py`
|
||||||
|
- **Coverage**: 85% threshold (80% on Windows due to ChromaDB file lock cleanup)
|
||||||
|
|
||||||
|
## Architecture
|
||||||
|
|
||||||
|
```
|
||||||
|
User → CLI / MCP Server → ChromaDB (vector store) + SQLite (knowledge graph)
|
||||||
|
|
||||||
|
Palace structure:
|
||||||
|
WING (person/project)
|
||||||
|
└── ROOM (topic)
|
||||||
|
└── DRAWER (verbatim text chunk)
|
||||||
|
|
||||||
|
Knowledge Graph:
|
||||||
|
ENTITY → PREDICATE → ENTITY (with valid_from / valid_to dates)
|
||||||
|
```
|
||||||
|
|
||||||
|
## Key files for common tasks
|
||||||
|
|
||||||
|
- **Adding an MCP tool**: `mempalace/mcp_server.py` — add handler function + TOOLS dict entry
|
||||||
|
- **Changing search**: `mempalace/searcher.py`
|
||||||
|
- **Modifying mining**: `mempalace/miner.py` (project files) or `mempalace/convo_miner.py` (transcripts)
|
||||||
|
- **Input validation**: `mempalace/config.py` — `sanitize_name()` / `sanitize_content()`
|
||||||
|
- **Tests**: mirror source structure in `tests/test_<module>.py`
|
||||||
@@ -585,6 +585,9 @@ mempalace compress --wing myapp # AAAK compress
|
|||||||
|
|
||||||
# Status
|
# Status
|
||||||
mempalace status # palace overview
|
mempalace status # palace overview
|
||||||
|
|
||||||
|
# MCP
|
||||||
|
mempalace mcp # show MCP setup command
|
||||||
```
|
```
|
||||||
|
|
||||||
All commands accept `--palace <path>` to override the default location.
|
All commands accept `--palace <path>` to override the default location.
|
||||||
@@ -707,7 +710,7 @@ PRs welcome. See [CONTRIBUTING.md](CONTRIBUTING.md) for setup and guidelines.
|
|||||||
MIT — see [LICENSE](LICENSE).
|
MIT — see [LICENSE](LICENSE).
|
||||||
|
|
||||||
<!-- Link Definitions -->
|
<!-- Link Definitions -->
|
||||||
[version-shield]: https://img.shields.io/badge/version-3.0.0-4dc9f6?style=flat-square&labelColor=0a0e14
|
[version-shield]: https://img.shields.io/badge/version-3.1.0-4dc9f6?style=flat-square&labelColor=0a0e14
|
||||||
[release-link]: https://github.com/milla-jovovich/mempalace/releases
|
[release-link]: https://github.com/milla-jovovich/mempalace/releases
|
||||||
[python-shield]: https://img.shields.io/badge/python-3.9+-7dd8f8?style=flat-square&labelColor=0a0e14&logo=python&logoColor=7dd8f8
|
[python-shield]: https://img.shields.io/badge/python-3.9+-7dd8f8?style=flat-square&labelColor=0a0e14&logo=python&logoColor=7dd8f8
|
||||||
[python-link]: https://www.python.org/
|
[python-link]: https://www.python.org/
|
||||||
|
|||||||
@@ -0,0 +1,36 @@
|
|||||||
|
-- MemPalace Knowledge Graph Schema
|
||||||
|
-- SQLite database at ~/.mempalace/knowledge_graph.db
|
||||||
|
|
||||||
|
CREATE TABLE IF NOT EXISTS entities (
|
||||||
|
id TEXT PRIMARY KEY,
|
||||||
|
name TEXT NOT NULL,
|
||||||
|
type TEXT DEFAULT 'unknown',
|
||||||
|
properties TEXT DEFAULT '{}'
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE TABLE IF NOT EXISTS triples (
|
||||||
|
id TEXT PRIMARY KEY,
|
||||||
|
subject TEXT NOT NULL,
|
||||||
|
predicate TEXT NOT NULL,
|
||||||
|
object TEXT NOT NULL,
|
||||||
|
valid_from TEXT,
|
||||||
|
valid_to TEXT,
|
||||||
|
confidence REAL DEFAULT 1.0,
|
||||||
|
source_closet TEXT,
|
||||||
|
source_file TEXT
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE TABLE IF NOT EXISTS attributes (
|
||||||
|
entity_id TEXT NOT NULL,
|
||||||
|
key TEXT NOT NULL,
|
||||||
|
value TEXT,
|
||||||
|
valid_from TEXT,
|
||||||
|
valid_to TEXT,
|
||||||
|
PRIMARY KEY (entity_id, key, valid_from)
|
||||||
|
);
|
||||||
|
|
||||||
|
-- Indexes
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_triples_subject ON triples(subject);
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_triples_object ON triples(object);
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_triples_predicate ON triples(predicate);
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_triples_valid ON triples(valid_from, valid_to);
|
||||||
@@ -64,13 +64,20 @@ MEMPAL_DIR=""
|
|||||||
# Read JSON input from stdin
|
# Read JSON input from stdin
|
||||||
INPUT=$(cat)
|
INPUT=$(cat)
|
||||||
|
|
||||||
# Parse fields from Claude Code's JSON
|
# Parse all fields in a single Python call (3x faster than separate invocations)
|
||||||
SESSION_ID=$(echo "$INPUT" | python3 -c "import sys,json; print(json.load(sys.stdin).get('session_id','unknown'))" 2>/dev/null)
|
eval $(echo "$INPUT" | python3 -c "
|
||||||
# Sanitize SESSION_ID to prevent path traversal (only allow alnum, dash, underscore)
|
import sys, json
|
||||||
SESSION_ID=$(echo "$SESSION_ID" | tr -cd 'a-zA-Z0-9_-')
|
data = json.load(sys.stdin)
|
||||||
[ -z "$SESSION_ID" ] && SESSION_ID="unknown"
|
sid = data.get('session_id', 'unknown')
|
||||||
STOP_HOOK_ACTIVE=$(echo "$INPUT" | python3 -c "import sys,json; print(json.load(sys.stdin).get('stop_hook_active', False))" 2>/dev/null)
|
sha = data.get('stop_hook_active', False)
|
||||||
TRANSCRIPT_PATH=$(echo "$INPUT" | python3 -c "import sys,json; print(json.load(sys.stdin).get('transcript_path',''))" 2>/dev/null)
|
tp = data.get('transcript_path', '')
|
||||||
|
# Shell-safe output — only allow alphanumeric, underscore, hyphen, slash, dot, tilde
|
||||||
|
import re
|
||||||
|
safe = lambda s: re.sub(r'[^a-zA-Z0-9_/.\-~]', '', str(s))
|
||||||
|
print(f'SESSION_ID=\"{safe(sid)}\"')
|
||||||
|
print(f'STOP_HOOK_ACTIVE=\"{sha}\"')
|
||||||
|
print(f'TRANSCRIPT_PATH=\"{safe(tp)}\"')
|
||||||
|
" 2>/dev/null)
|
||||||
|
|
||||||
# Expand ~ in path
|
# Expand ~ in path
|
||||||
TRANSCRIPT_PATH="${TRANSCRIPT_PATH/#\~/$HOME}"
|
TRANSCRIPT_PATH="${TRANSCRIPT_PATH/#\~/$HOME}"
|
||||||
@@ -83,6 +90,7 @@ if [ "$STOP_HOOK_ACTIVE" = "True" ] || [ "$STOP_HOOK_ACTIVE" = "true" ]; then
|
|||||||
fi
|
fi
|
||||||
|
|
||||||
# Count human messages in the JSONL transcript
|
# Count human messages in the JSONL transcript
|
||||||
|
# SECURITY: Pass transcript path as sys.argv to avoid shell injection via crafted paths
|
||||||
if [ -f "$TRANSCRIPT_PATH" ]; then
|
if [ -f "$TRANSCRIPT_PATH" ]; then
|
||||||
EXCHANGE_COUNT=$(python3 - "$TRANSCRIPT_PATH" <<'PYEOF'
|
EXCHANGE_COUNT=$(python3 - "$TRANSCRIPT_PATH" <<'PYEOF'
|
||||||
import json, sys
|
import json, sys
|
||||||
@@ -94,7 +102,6 @@ with open(sys.argv[1]) as f:
|
|||||||
msg = entry.get('message', {})
|
msg = entry.get('message', {})
|
||||||
if isinstance(msg, dict) and msg.get('role') == 'user':
|
if isinstance(msg, dict) and msg.get('role') == 'user':
|
||||||
content = msg.get('content', '')
|
content = msg.get('content', '')
|
||||||
# Skip system/command messages — only count real human input
|
|
||||||
if isinstance(content, str) and '<command-message>' in content:
|
if isinstance(content, str) and '<command-message>' in content:
|
||||||
continue
|
continue
|
||||||
count += 1
|
count += 1
|
||||||
|
|||||||
@@ -0,0 +1,154 @@
|
|||||||
|
---
|
||||||
|
name: mempalace
|
||||||
|
description: "MemPalace — Local AI memory with 96.6% recall. Semantic search, temporal knowledge graph, palace architecture (wings/rooms/drawers). Free, no cloud, no API keys."
|
||||||
|
version: 3.1.0
|
||||||
|
homepage: https://github.com/milla-jovovich/mempalace
|
||||||
|
user-invocable: true
|
||||||
|
metadata:
|
||||||
|
openclaw:
|
||||||
|
emoji: "\U0001F3DB"
|
||||||
|
os:
|
||||||
|
- darwin
|
||||||
|
- linux
|
||||||
|
- win32
|
||||||
|
requires:
|
||||||
|
anyBins:
|
||||||
|
- mempalace
|
||||||
|
- python3
|
||||||
|
install:
|
||||||
|
- id: mempalace-pip
|
||||||
|
kind: uv
|
||||||
|
label: "Install MemPalace (Python, local ChromaDB)"
|
||||||
|
package: mempalace
|
||||||
|
bins:
|
||||||
|
- mempalace
|
||||||
|
---
|
||||||
|
|
||||||
|
# MemPalace — Local AI Memory System
|
||||||
|
|
||||||
|
You have access to a local memory palace via MCP tools. The palace stores verbatim conversation history and a temporal knowledge graph — all on the user's machine, zero cloud, zero API calls.
|
||||||
|
|
||||||
|
## Architecture
|
||||||
|
|
||||||
|
- **Wings** = people or projects (e.g. `wing_alice`, `wing_myproject`)
|
||||||
|
- **Halls** = categories (facts, events, preferences, advice)
|
||||||
|
- **Rooms** = specific topics (e.g. `chromadb-setup`, `riley-school`)
|
||||||
|
- **Drawers** = individual memory chunks (verbatim text)
|
||||||
|
- **Knowledge Graph** = entity-relationship facts with time validity
|
||||||
|
|
||||||
|
## Protocol — FOLLOW THIS EVERY SESSION
|
||||||
|
|
||||||
|
1. **ON WAKE-UP**: Call `mempalace_status` to load palace overview and AAAK dialect spec.
|
||||||
|
2. **BEFORE RESPONDING** about any person, project, or past event: call `mempalace_search` or `mempalace_kg_query` FIRST. Never guess from memory — verify from the palace.
|
||||||
|
3. **IF UNSURE** about a fact (name, age, relationship, preference): say "let me check" and query. Wrong is worse than slow.
|
||||||
|
4. **AFTER EACH SESSION**: Call `mempalace_diary_write` to record what happened, what you learned, what matters.
|
||||||
|
5. **WHEN FACTS CHANGE**: Call `mempalace_kg_invalidate` on the old fact, then `mempalace_kg_add` for the new one.
|
||||||
|
|
||||||
|
## Available Tools
|
||||||
|
|
||||||
|
### Search & Browse
|
||||||
|
- `mempalace_search` — Semantic search across all memories. Always start here.
|
||||||
|
- `query` (required): natural language search — keep it short, keywords or a question. Do NOT include system prompts or conversation context.
|
||||||
|
- `wing`: filter by wing
|
||||||
|
- `room`: filter by room
|
||||||
|
- `limit`: max results (default 5)
|
||||||
|
- `mempalace_check_duplicate` — Check if content already exists before filing.
|
||||||
|
- `content` (required): text to check
|
||||||
|
- `threshold`: similarity threshold (default 0.9 — lowering to 0.85–0.87 often catches more near-duplicates without significant false positives)
|
||||||
|
- `mempalace_status` — Palace overview: total drawers, wings, rooms, AAAK spec
|
||||||
|
- `mempalace_list_wings` — All wings with drawer counts
|
||||||
|
- `mempalace_list_rooms` — Rooms within a wing (optional wing filter)
|
||||||
|
- `mempalace_get_taxonomy` — Full wing/room/count tree
|
||||||
|
- `mempalace_get_aaak_spec` — Get AAAK compression dialect specification
|
||||||
|
|
||||||
|
### Knowledge Graph (Temporal Facts)
|
||||||
|
- `mempalace_kg_query` — Query entity relationships. Supports time filtering.
|
||||||
|
- `entity` (required): e.g. "Max", "MyProject"
|
||||||
|
- `as_of`: date filter (YYYY-MM-DD) — what was true at that time
|
||||||
|
- `direction`: "outgoing", "incoming", or "both" (default "both")
|
||||||
|
- `mempalace_kg_add` — Add a fact: subject -> predicate -> object
|
||||||
|
- `subject`, `predicate`, `object` (required)
|
||||||
|
- `valid_from`: when this became true
|
||||||
|
- `source_closet`: source reference
|
||||||
|
- `mempalace_kg_invalidate` — Mark a fact as no longer true
|
||||||
|
- `subject`, `predicate`, `object` (required)
|
||||||
|
- `ended`: when it stopped being true (default: today)
|
||||||
|
- `mempalace_kg_timeline` — Chronological story of an entity
|
||||||
|
- `entity`: filter by entity name (optional — all events if omitted)
|
||||||
|
- `mempalace_kg_stats` — Graph overview: entities, triples, relationship types
|
||||||
|
|
||||||
|
### Palace Graph (Cross-Domain Connections)
|
||||||
|
- `mempalace_traverse` — Walk from a room, find connected ideas across wings
|
||||||
|
- `start_room` (required): room to start from
|
||||||
|
- `max_hops`: connection depth (default 2)
|
||||||
|
- `mempalace_find_tunnels` — Find rooms that bridge two wings
|
||||||
|
- `wing_a`, `wing_b` (required)
|
||||||
|
- `mempalace_graph_stats` — Graph connectivity overview
|
||||||
|
|
||||||
|
### Write
|
||||||
|
- `mempalace_add_drawer` — Store verbatim content into a wing/room
|
||||||
|
- `wing`, `room`, `content` (required)
|
||||||
|
- `source_file`: optional source reference
|
||||||
|
- Checks for duplicates automatically
|
||||||
|
- `mempalace_delete_drawer` — Remove a drawer by ID
|
||||||
|
- `drawer_id` (required)
|
||||||
|
- `mempalace_diary_write` — Write a session diary entry
|
||||||
|
- `agent_name` (required): your name/identifier
|
||||||
|
- `entry` (required): what happened, what you learned, what matters
|
||||||
|
- `topic`: category tag (default "general")
|
||||||
|
- `mempalace_diary_read` — Read recent diary entries
|
||||||
|
- `agent_name` (required)
|
||||||
|
- `last_n`: number of entries (default 10)
|
||||||
|
|
||||||
|
## Setup
|
||||||
|
|
||||||
|
Install MemPalace and populate the palace:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
pip install mempalace
|
||||||
|
mempalace init ~/my-convos
|
||||||
|
mempalace mine ~/my-convos
|
||||||
|
```
|
||||||
|
|
||||||
|
### OpenClaw MCP config
|
||||||
|
|
||||||
|
Add to your OpenClaw MCP configuration:
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"mcpServers": {
|
||||||
|
"mempalace": {
|
||||||
|
"command": "python3",
|
||||||
|
"args": ["-m", "mempalace.mcp_server"]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
Or via CLI:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
openclaw mcp set mempalace '{"command":"python3","args":["-m","mempalace.mcp_server"]}'
|
||||||
|
```
|
||||||
|
|
||||||
|
### Other MCP hosts
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Claude Code
|
||||||
|
claude mcp add mempalace -- python -m mempalace.mcp_server
|
||||||
|
|
||||||
|
# Cursor — add to .cursor/mcp.json
|
||||||
|
# Codex — add to .codex/mcp.json
|
||||||
|
```
|
||||||
|
|
||||||
|
## Tips
|
||||||
|
|
||||||
|
- Search is semantic (meaning-based), not keyword. "What did we discuss about database performance?" works better than "database".
|
||||||
|
- The knowledge graph stores typed relationships with time windows. Use it for facts about people and projects — it knows WHEN things were true.
|
||||||
|
- Diary entries accumulate across sessions. Write one at the end of each conversation to build continuity.
|
||||||
|
- Use `mempalace_check_duplicate` before storing new content to avoid duplicates.
|
||||||
|
- The AAAK dialect (from `mempalace_status`) is a compressed notation for efficient storage. Read it naturally — expand codes mentally, treat *markers* as emotional context.
|
||||||
|
|
||||||
|
## License
|
||||||
|
|
||||||
|
[MemPalace](https://github.com/milla-jovovich/mempalace) is MIT licensed. Created by Milla Jovovich, Ben Sigman, Igor Lins e Silva, and contributors.
|
||||||
@@ -14,6 +14,7 @@ Commands:
|
|||||||
mempalace mine <dir> Mine project files (default)
|
mempalace mine <dir> Mine project files (default)
|
||||||
mempalace mine <dir> --mode convos Mine conversation exports
|
mempalace mine <dir> --mode convos Mine conversation exports
|
||||||
mempalace search "query" Find anything, exact words
|
mempalace search "query" Find anything, exact words
|
||||||
|
mempalace mcp Show MCP setup command
|
||||||
mempalace wake-up Show L0 + L1 wake-up context
|
mempalace wake-up Show L0 + L1 wake-up context
|
||||||
mempalace wake-up --wing my_app Wake-up for a specific project
|
mempalace wake-up --wing my_app Wake-up for a specific project
|
||||||
mempalace status Show what's been filed
|
mempalace status Show what's been filed
|
||||||
@@ -28,6 +29,7 @@ Examples:
|
|||||||
|
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
|
import shlex
|
||||||
import argparse
|
import argparse
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
@@ -148,6 +150,14 @@ def cmd_split(args):
|
|||||||
sys.argv = old_argv
|
sys.argv = old_argv
|
||||||
|
|
||||||
|
|
||||||
|
def cmd_migrate(args):
|
||||||
|
"""Migrate palace from a different ChromaDB version."""
|
||||||
|
from .migrate import migrate
|
||||||
|
|
||||||
|
palace_path = os.path.expanduser(args.palace) if args.palace else MempalaceConfig().palace_path
|
||||||
|
migrate(palace_path=palace_path, dry_run=args.dry_run)
|
||||||
|
|
||||||
|
|
||||||
def cmd_status(args):
|
def cmd_status(args):
|
||||||
from .miner import status
|
from .miner import status
|
||||||
|
|
||||||
@@ -202,6 +212,7 @@ def cmd_repair(args):
|
|||||||
print(f" Extracted {len(all_ids)} drawers")
|
print(f" Extracted {len(all_ids)} drawers")
|
||||||
|
|
||||||
# Backup and rebuild
|
# Backup and rebuild
|
||||||
|
palace_path = palace_path.rstrip(os.sep)
|
||||||
backup_path = palace_path + ".backup"
|
backup_path = palace_path + ".backup"
|
||||||
if os.path.exists(backup_path):
|
if os.path.exists(backup_path):
|
||||||
shutil.rmtree(backup_path)
|
shutil.rmtree(backup_path)
|
||||||
@@ -240,6 +251,27 @@ def cmd_instructions(args):
|
|||||||
run_instructions(name=args.name)
|
run_instructions(name=args.name)
|
||||||
|
|
||||||
|
|
||||||
|
def cmd_mcp(args):
|
||||||
|
"""Show how to wire MemPalace into MCP-capable hosts."""
|
||||||
|
base_server_cmd = "python -m mempalace.mcp_server"
|
||||||
|
|
||||||
|
if args.palace:
|
||||||
|
resolved_palace = str(Path(args.palace).expanduser())
|
||||||
|
server_cmd = f"{base_server_cmd} --palace {shlex.quote(resolved_palace)}"
|
||||||
|
else:
|
||||||
|
server_cmd = base_server_cmd
|
||||||
|
|
||||||
|
print("MemPalace MCP quick setup:")
|
||||||
|
print(f" claude mcp add mempalace -- {server_cmd}")
|
||||||
|
print("\nRun the server directly:")
|
||||||
|
print(f" {server_cmd}")
|
||||||
|
|
||||||
|
if not args.palace:
|
||||||
|
print("\nOptional custom palace:")
|
||||||
|
print(f" claude mcp add mempalace -- {base_server_cmd} --palace /path/to/palace")
|
||||||
|
print(f" {base_server_cmd} --palace /path/to/palace")
|
||||||
|
|
||||||
|
|
||||||
def cmd_compress(args):
|
def cmd_compress(args):
|
||||||
"""Compress drawers in a wing using AAAK Dialect."""
|
"""Compress drawers in a wing using AAAK Dialect."""
|
||||||
import chromadb
|
import chromadb
|
||||||
@@ -500,7 +532,24 @@ def main():
|
|||||||
help="Rebuild palace vector index from stored data (fixes segfaults after corruption)",
|
help="Rebuild palace vector index from stored data (fixes segfaults after corruption)",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# mcp
|
||||||
|
sub.add_parser(
|
||||||
|
"mcp",
|
||||||
|
help="Show MCP setup command for connecting MemPalace to your AI client",
|
||||||
|
)
|
||||||
|
|
||||||
# status
|
# status
|
||||||
|
# migrate
|
||||||
|
p_migrate = sub.add_parser(
|
||||||
|
"migrate",
|
||||||
|
help="Migrate palace from a different ChromaDB version (fixes 3.0.0 → 3.1.0 upgrade)",
|
||||||
|
)
|
||||||
|
p_migrate.add_argument(
|
||||||
|
"--dry-run",
|
||||||
|
action="store_true",
|
||||||
|
help="Show what would be migrated without changing anything",
|
||||||
|
)
|
||||||
|
|
||||||
sub.add_parser("status", help="Show what's been filed")
|
sub.add_parser("status", help="Show what's been filed")
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
@@ -531,9 +580,11 @@ def main():
|
|||||||
"mine": cmd_mine,
|
"mine": cmd_mine,
|
||||||
"split": cmd_split,
|
"split": cmd_split,
|
||||||
"search": cmd_search,
|
"search": cmd_search,
|
||||||
|
"mcp": cmd_mcp,
|
||||||
"compress": cmd_compress,
|
"compress": cmd_compress,
|
||||||
"wake-up": cmd_wakeup,
|
"wake-up": cmd_wakeup,
|
||||||
"repair": cmd_repair,
|
"repair": cmd_repair,
|
||||||
|
"migrate": cmd_migrate,
|
||||||
"status": cmd_status,
|
"status": cmd_status,
|
||||||
}
|
}
|
||||||
dispatch[args.command](args)
|
dispatch[args.command](args)
|
||||||
|
|||||||
@@ -6,8 +6,58 @@ Priority: env vars > config file (~/.mempalace/config.json) > defaults
|
|||||||
|
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
|
import re
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
|
|
||||||
|
# ── Input validation ──────────────────────────────────────────────────────────
|
||||||
|
# Shared sanitizers for wing/room/entity names. Prevents path traversal,
|
||||||
|
# excessively long strings, and special characters that could cause issues
|
||||||
|
# in file paths, SQLite, or ChromaDB metadata.
|
||||||
|
|
||||||
|
MAX_NAME_LENGTH = 128
|
||||||
|
_SAFE_NAME_RE = re.compile(r"^[a-zA-Z0-9][a-zA-Z0-9_ .'-]{0,126}[a-zA-Z0-9]?$")
|
||||||
|
|
||||||
|
|
||||||
|
def sanitize_name(value: str, field_name: str = "name") -> str:
|
||||||
|
"""Validate and sanitize a wing/room/entity name.
|
||||||
|
|
||||||
|
Raises ValueError if the name is invalid.
|
||||||
|
"""
|
||||||
|
if not isinstance(value, str) or not value.strip():
|
||||||
|
raise ValueError(f"{field_name} must be a non-empty string")
|
||||||
|
|
||||||
|
value = value.strip()
|
||||||
|
|
||||||
|
if len(value) > MAX_NAME_LENGTH:
|
||||||
|
raise ValueError(f"{field_name} exceeds maximum length of {MAX_NAME_LENGTH} characters")
|
||||||
|
|
||||||
|
# Block path traversal
|
||||||
|
if ".." in value or "/" in value or "\\" in value:
|
||||||
|
raise ValueError(f"{field_name} contains invalid path characters")
|
||||||
|
|
||||||
|
# Block null bytes
|
||||||
|
if "\x00" in value:
|
||||||
|
raise ValueError(f"{field_name} contains null bytes")
|
||||||
|
|
||||||
|
# Enforce safe character set
|
||||||
|
if not _SAFE_NAME_RE.match(value):
|
||||||
|
raise ValueError(f"{field_name} contains invalid characters")
|
||||||
|
|
||||||
|
return value
|
||||||
|
|
||||||
|
|
||||||
|
def sanitize_content(value: str, max_length: int = 100_000) -> str:
|
||||||
|
"""Validate drawer/diary content length."""
|
||||||
|
if not isinstance(value, str) or not value.strip():
|
||||||
|
raise ValueError("content must be a non-empty string")
|
||||||
|
if len(value) > max_length:
|
||||||
|
raise ValueError(f"content exceeds maximum length of {max_length} characters")
|
||||||
|
if "\x00" in value:
|
||||||
|
raise ValueError("content contains null bytes")
|
||||||
|
return value
|
||||||
|
|
||||||
|
|
||||||
DEFAULT_PALACE_PATH = os.path.expanduser("~/.mempalace/palace")
|
DEFAULT_PALACE_PATH = os.path.expanduser("~/.mempalace/palace")
|
||||||
DEFAULT_COLLECTION_NAME = "mempalace_drawers"
|
DEFAULT_COLLECTION_NAME = "mempalace_drawers"
|
||||||
|
|
||||||
@@ -126,6 +176,11 @@ class MempalaceConfig:
|
|||||||
def init(self):
|
def init(self):
|
||||||
"""Create config directory and write default config.json if it doesn't exist."""
|
"""Create config directory and write default config.json if it doesn't exist."""
|
||||||
self._config_dir.mkdir(parents=True, exist_ok=True)
|
self._config_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
# Restrict directory permissions to owner only (Unix)
|
||||||
|
try:
|
||||||
|
self._config_dir.chmod(0o700)
|
||||||
|
except (OSError, NotImplementedError):
|
||||||
|
pass # Windows doesn't support Unix permissions
|
||||||
if not self._config_file.exists():
|
if not self._config_file.exists():
|
||||||
default_config = {
|
default_config = {
|
||||||
"palace_path": DEFAULT_PALACE_PATH,
|
"palace_path": DEFAULT_PALACE_PATH,
|
||||||
@@ -135,6 +190,11 @@ class MempalaceConfig:
|
|||||||
}
|
}
|
||||||
with open(self._config_file, "w") as f:
|
with open(self._config_file, "w") as f:
|
||||||
json.dump(default_config, f, indent=2)
|
json.dump(default_config, f, indent=2)
|
||||||
|
# Restrict config file to owner read/write only
|
||||||
|
try:
|
||||||
|
self._config_file.chmod(0o600)
|
||||||
|
except (OSError, NotImplementedError):
|
||||||
|
pass
|
||||||
return self._config_file
|
return self._config_file
|
||||||
|
|
||||||
def save_people_map(self, people_map):
|
def save_people_map(self, people_map):
|
||||||
|
|||||||
+11
-35
@@ -15,9 +15,8 @@ from pathlib import Path
|
|||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
|
|
||||||
import chromadb
|
|
||||||
|
|
||||||
from .normalize import normalize
|
from .normalize import normalize
|
||||||
|
from .palace import SKIP_DIRS, get_collection, file_already_mined
|
||||||
|
|
||||||
|
|
||||||
# File types that might contain conversations
|
# File types that might contain conversations
|
||||||
@@ -28,22 +27,8 @@ CONVO_EXTENSIONS = {
|
|||||||
".jsonl",
|
".jsonl",
|
||||||
}
|
}
|
||||||
|
|
||||||
SKIP_DIRS = {
|
|
||||||
".git",
|
|
||||||
"node_modules",
|
|
||||||
"__pycache__",
|
|
||||||
".venv",
|
|
||||||
"venv",
|
|
||||||
"env",
|
|
||||||
"dist",
|
|
||||||
"build",
|
|
||||||
".next",
|
|
||||||
".mempalace",
|
|
||||||
"tool-results",
|
|
||||||
"memory",
|
|
||||||
}
|
|
||||||
|
|
||||||
MIN_CHUNK_SIZE = 30
|
MIN_CHUNK_SIZE = 30
|
||||||
|
MAX_FILE_SIZE = 10 * 1024 * 1024 # 10 MB — skip files larger than this
|
||||||
|
|
||||||
|
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
@@ -211,23 +196,6 @@ def detect_convo_room(content: str) -> str:
|
|||||||
# =============================================================================
|
# =============================================================================
|
||||||
|
|
||||||
|
|
||||||
def get_collection(palace_path: str):
|
|
||||||
os.makedirs(palace_path, exist_ok=True)
|
|
||||||
client = chromadb.PersistentClient(path=palace_path)
|
|
||||||
try:
|
|
||||||
return client.get_collection("mempalace_drawers")
|
|
||||||
except Exception:
|
|
||||||
return client.create_collection("mempalace_drawers")
|
|
||||||
|
|
||||||
|
|
||||||
def file_already_mined(collection, source_file: str) -> bool:
|
|
||||||
try:
|
|
||||||
results = collection.get(where={"source_file": source_file}, limit=1)
|
|
||||||
return len(results.get("ids", [])) > 0
|
|
||||||
except Exception:
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
# SCAN FOR CONVERSATION FILES
|
# SCAN FOR CONVERSATION FILES
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
@@ -244,6 +212,14 @@ def scan_convos(convo_dir: str) -> list:
|
|||||||
continue
|
continue
|
||||||
filepath = Path(root) / filename
|
filepath = Path(root) / filename
|
||||||
if filepath.suffix.lower() in CONVO_EXTENSIONS:
|
if filepath.suffix.lower() in CONVO_EXTENSIONS:
|
||||||
|
# Skip symlinks and oversized files
|
||||||
|
if filepath.is_symlink():
|
||||||
|
continue
|
||||||
|
try:
|
||||||
|
if filepath.stat().st_size > MAX_FILE_SIZE:
|
||||||
|
continue
|
||||||
|
except OSError:
|
||||||
|
continue
|
||||||
files.append(filepath)
|
files.append(filepath)
|
||||||
return files
|
return files
|
||||||
|
|
||||||
@@ -356,7 +332,7 @@ def mine_convos(
|
|||||||
chunk_room = chunk.get("memory_type", room) if extract_mode == "general" else room
|
chunk_room = chunk.get("memory_type", room) if extract_mode == "general" else room
|
||||||
if extract_mode == "general":
|
if extract_mode == "general":
|
||||||
room_counts[chunk_room] += 1
|
room_counts[chunk_room] += 1
|
||||||
drawer_id = f"drawer_{wing}_{chunk_room}_{hashlib.md5((source_file + str(chunk['chunk_index'])).encode(), usedforsecurity=False).hexdigest()[:16]}"
|
drawer_id = f"drawer_{wing}_{chunk_room}_{hashlib.sha256((source_file + str(chunk['chunk_index'])).encode()).hexdigest()[:24]}"
|
||||||
try:
|
try:
|
||||||
collection.add(
|
collection.add(
|
||||||
documents=[chunk["content"]],
|
documents=[chunk["content"]],
|
||||||
|
|||||||
@@ -309,7 +309,7 @@ class EntityRegistry:
|
|||||||
|
|
||||||
def save(self):
|
def save(self):
|
||||||
self._path.parent.mkdir(parents=True, exist_ok=True)
|
self._path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
self._path.write_text(json.dumps(self._data, indent=2))
|
self._path.write_text(json.dumps(self._data, indent=2), encoding="utf-8")
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _empty() -> dict:
|
def _empty() -> dict:
|
||||||
|
|||||||
@@ -158,7 +158,7 @@ def hook_stop(data: dict, harness: str):
|
|||||||
if since_last >= SAVE_INTERVAL and exchange_count > 0:
|
if since_last >= SAVE_INTERVAL and exchange_count > 0:
|
||||||
# Update last save point
|
# Update last save point
|
||||||
try:
|
try:
|
||||||
last_save_file.write_text(str(exchange_count))
|
last_save_file.write_text(str(exchange_count), encoding="utf-8")
|
||||||
except OSError:
|
except OSError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|||||||
@@ -60,6 +60,7 @@ AI memory system. Store everything, find anything. Local, free, no API key.
|
|||||||
mempalace compress Compress palace storage
|
mempalace compress Compress palace storage
|
||||||
mempalace status Show palace status
|
mempalace status Show palace status
|
||||||
mempalace repair Rebuild vector index
|
mempalace repair Rebuild vector index
|
||||||
|
mempalace mcp Show MCP setup command
|
||||||
mempalace hook run Run hook logic (for harness integration)
|
mempalace hook run Run hook logic (for harness integration)
|
||||||
mempalace instructions <name> Output skill instructions
|
mempalace instructions <name> Output skill instructions
|
||||||
|
|
||||||
|
|||||||
@@ -50,11 +50,14 @@ class KnowledgeGraph:
|
|||||||
def __init__(self, db_path: str = None):
|
def __init__(self, db_path: str = None):
|
||||||
self.db_path = db_path or DEFAULT_KG_PATH
|
self.db_path = db_path or DEFAULT_KG_PATH
|
||||||
Path(self.db_path).parent.mkdir(parents=True, exist_ok=True)
|
Path(self.db_path).parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
self._connection = None
|
||||||
self._init_db()
|
self._init_db()
|
||||||
|
|
||||||
def _init_db(self):
|
def _init_db(self):
|
||||||
conn = self._conn()
|
conn = self._conn()
|
||||||
conn.executescript("""
|
conn.executescript("""
|
||||||
|
PRAGMA journal_mode=WAL;
|
||||||
|
|
||||||
CREATE TABLE IF NOT EXISTS entities (
|
CREATE TABLE IF NOT EXISTS entities (
|
||||||
id TEXT PRIMARY KEY,
|
id TEXT PRIMARY KEY,
|
||||||
name TEXT NOT NULL,
|
name TEXT NOT NULL,
|
||||||
@@ -84,12 +87,19 @@ class KnowledgeGraph:
|
|||||||
CREATE INDEX IF NOT EXISTS idx_triples_valid ON triples(valid_from, valid_to);
|
CREATE INDEX IF NOT EXISTS idx_triples_valid ON triples(valid_from, valid_to);
|
||||||
""")
|
""")
|
||||||
conn.commit()
|
conn.commit()
|
||||||
conn.close()
|
|
||||||
|
|
||||||
def _conn(self):
|
def _conn(self):
|
||||||
conn = sqlite3.connect(self.db_path, timeout=10)
|
if self._connection is None:
|
||||||
conn.execute("PRAGMA journal_mode=WAL")
|
self._connection = sqlite3.connect(self.db_path, timeout=10, check_same_thread=False)
|
||||||
return conn
|
self._connection.execute("PRAGMA journal_mode=WAL")
|
||||||
|
self._connection.row_factory = sqlite3.Row
|
||||||
|
return self._connection
|
||||||
|
|
||||||
|
def close(self):
|
||||||
|
"""Close the database connection."""
|
||||||
|
if self._connection is not None:
|
||||||
|
self._connection.close()
|
||||||
|
self._connection = None
|
||||||
|
|
||||||
def _entity_id(self, name: str) -> str:
|
def _entity_id(self, name: str) -> str:
|
||||||
return name.lower().replace(" ", "_").replace("'", "")
|
return name.lower().replace(" ", "_").replace("'", "")
|
||||||
@@ -101,12 +111,11 @@ class KnowledgeGraph:
|
|||||||
eid = self._entity_id(name)
|
eid = self._entity_id(name)
|
||||||
props = json.dumps(properties or {})
|
props = json.dumps(properties or {})
|
||||||
conn = self._conn()
|
conn = self._conn()
|
||||||
conn.execute(
|
with conn:
|
||||||
"INSERT OR REPLACE INTO entities (id, name, type, properties) VALUES (?, ?, ?, ?)",
|
conn.execute(
|
||||||
(eid, name, entity_type, props),
|
"INSERT OR REPLACE INTO entities (id, name, type, properties) VALUES (?, ?, ?, ?)",
|
||||||
)
|
(eid, name, entity_type, props),
|
||||||
conn.commit()
|
)
|
||||||
conn.close()
|
|
||||||
return eid
|
return eid
|
||||||
|
|
||||||
def add_triple(
|
def add_triple(
|
||||||
@@ -134,38 +143,38 @@ class KnowledgeGraph:
|
|||||||
|
|
||||||
# Auto-create entities if they don't exist
|
# Auto-create entities if they don't exist
|
||||||
conn = self._conn()
|
conn = self._conn()
|
||||||
conn.execute("INSERT OR IGNORE INTO entities (id, name) VALUES (?, ?)", (sub_id, subject))
|
with conn:
|
||||||
conn.execute("INSERT OR IGNORE INTO entities (id, name) VALUES (?, ?)", (obj_id, obj))
|
conn.execute(
|
||||||
|
"INSERT OR IGNORE INTO entities (id, name) VALUES (?, ?)", (sub_id, subject)
|
||||||
|
)
|
||||||
|
conn.execute("INSERT OR IGNORE INTO entities (id, name) VALUES (?, ?)", (obj_id, obj))
|
||||||
|
|
||||||
# Check for existing identical triple
|
# Check for existing identical triple
|
||||||
existing = conn.execute(
|
existing = conn.execute(
|
||||||
"SELECT id FROM triples WHERE subject=? AND predicate=? AND object=? AND valid_to IS NULL",
|
"SELECT id FROM triples WHERE subject=? AND predicate=? AND object=? AND valid_to IS NULL",
|
||||||
(sub_id, pred, obj_id),
|
(sub_id, pred, obj_id),
|
||||||
).fetchone()
|
).fetchone()
|
||||||
|
|
||||||
if existing:
|
if existing:
|
||||||
conn.close()
|
return existing["id"] # Already exists and still valid
|
||||||
return existing[0] # Already exists and still valid
|
|
||||||
|
|
||||||
triple_id = f"t_{sub_id}_{pred}_{obj_id}_{hashlib.md5(f'{valid_from}{datetime.now().isoformat()}'.encode()).hexdigest()[:8]}"
|
triple_id = f"t_{sub_id}_{pred}_{obj_id}_{hashlib.sha256(f'{valid_from}{datetime.now().isoformat()}'.encode()).hexdigest()[:12]}"
|
||||||
|
|
||||||
conn.execute(
|
conn.execute(
|
||||||
"""INSERT INTO triples (id, subject, predicate, object, valid_from, valid_to, confidence, source_closet, source_file)
|
"""INSERT INTO triples (id, subject, predicate, object, valid_from, valid_to, confidence, source_closet, source_file)
|
||||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)""",
|
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)""",
|
||||||
(
|
(
|
||||||
triple_id,
|
triple_id,
|
||||||
sub_id,
|
sub_id,
|
||||||
pred,
|
pred,
|
||||||
obj_id,
|
obj_id,
|
||||||
valid_from,
|
valid_from,
|
||||||
valid_to,
|
valid_to,
|
||||||
confidence,
|
confidence,
|
||||||
source_closet,
|
source_closet,
|
||||||
source_file,
|
source_file,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
conn.commit()
|
|
||||||
conn.close()
|
|
||||||
return triple_id
|
return triple_id
|
||||||
|
|
||||||
def invalidate(self, subject: str, predicate: str, obj: str, ended: str = None):
|
def invalidate(self, subject: str, predicate: str, obj: str, ended: str = None):
|
||||||
@@ -176,12 +185,11 @@ class KnowledgeGraph:
|
|||||||
ended = ended or date.today().isoformat()
|
ended = ended or date.today().isoformat()
|
||||||
|
|
||||||
conn = self._conn()
|
conn = self._conn()
|
||||||
conn.execute(
|
with conn:
|
||||||
"UPDATE triples SET valid_to=? WHERE subject=? AND predicate=? AND object=? AND valid_to IS NULL",
|
conn.execute(
|
||||||
(ended, sub_id, pred, obj_id),
|
"UPDATE triples SET valid_to=? WHERE subject=? AND predicate=? AND object=? AND valid_to IS NULL",
|
||||||
)
|
(ended, sub_id, pred, obj_id),
|
||||||
conn.commit()
|
)
|
||||||
conn.close()
|
|
||||||
|
|
||||||
# ── Query operations ──────────────────────────────────────────────────
|
# ── Query operations ──────────────────────────────────────────────────
|
||||||
|
|
||||||
@@ -208,13 +216,13 @@ class KnowledgeGraph:
|
|||||||
{
|
{
|
||||||
"direction": "outgoing",
|
"direction": "outgoing",
|
||||||
"subject": name,
|
"subject": name,
|
||||||
"predicate": row[2],
|
"predicate": row["predicate"],
|
||||||
"object": row[10], # obj_name
|
"object": row["obj_name"],
|
||||||
"valid_from": row[4],
|
"valid_from": row["valid_from"],
|
||||||
"valid_to": row[5],
|
"valid_to": row["valid_to"],
|
||||||
"confidence": row[6],
|
"confidence": row["confidence"],
|
||||||
"source_closet": row[7],
|
"source_closet": row["source_closet"],
|
||||||
"current": row[5] is None,
|
"current": row["valid_to"] is None,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -228,18 +236,17 @@ class KnowledgeGraph:
|
|||||||
results.append(
|
results.append(
|
||||||
{
|
{
|
||||||
"direction": "incoming",
|
"direction": "incoming",
|
||||||
"subject": row[10], # sub_name
|
"subject": row["sub_name"],
|
||||||
"predicate": row[2],
|
"predicate": row["predicate"],
|
||||||
"object": name,
|
"object": name,
|
||||||
"valid_from": row[4],
|
"valid_from": row["valid_from"],
|
||||||
"valid_to": row[5],
|
"valid_to": row["valid_to"],
|
||||||
"confidence": row[6],
|
"confidence": row["confidence"],
|
||||||
"source_closet": row[7],
|
"source_closet": row["source_closet"],
|
||||||
"current": row[5] is None,
|
"current": row["valid_to"] is None,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
conn.close()
|
|
||||||
return results
|
return results
|
||||||
|
|
||||||
def query_relationship(self, predicate: str, as_of: str = None):
|
def query_relationship(self, predicate: str, as_of: str = None):
|
||||||
@@ -262,15 +269,14 @@ class KnowledgeGraph:
|
|||||||
for row in conn.execute(query, params).fetchall():
|
for row in conn.execute(query, params).fetchall():
|
||||||
results.append(
|
results.append(
|
||||||
{
|
{
|
||||||
"subject": row[10],
|
"subject": row["sub_name"],
|
||||||
"predicate": pred,
|
"predicate": pred,
|
||||||
"object": row[11],
|
"object": row["obj_name"],
|
||||||
"valid_from": row[4],
|
"valid_from": row["valid_from"],
|
||||||
"valid_to": row[5],
|
"valid_to": row["valid_to"],
|
||||||
"current": row[5] is None,
|
"current": row["valid_to"] is None,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
conn.close()
|
|
||||||
return results
|
return results
|
||||||
|
|
||||||
def timeline(self, entity_name: str = None):
|
def timeline(self, entity_name: str = None):
|
||||||
@@ -300,15 +306,14 @@ class KnowledgeGraph:
|
|||||||
LIMIT 100
|
LIMIT 100
|
||||||
""").fetchall()
|
""").fetchall()
|
||||||
|
|
||||||
conn.close()
|
|
||||||
return [
|
return [
|
||||||
{
|
{
|
||||||
"subject": r[10],
|
"subject": r["sub_name"],
|
||||||
"predicate": r[2],
|
"predicate": r["predicate"],
|
||||||
"object": r[11],
|
"object": r["obj_name"],
|
||||||
"valid_from": r[4],
|
"valid_from": r["valid_from"],
|
||||||
"valid_to": r[5],
|
"valid_to": r["valid_to"],
|
||||||
"current": r[5] is None,
|
"current": r["valid_to"] is None,
|
||||||
}
|
}
|
||||||
for r in rows
|
for r in rows
|
||||||
]
|
]
|
||||||
@@ -317,17 +322,18 @@ class KnowledgeGraph:
|
|||||||
|
|
||||||
def stats(self):
|
def stats(self):
|
||||||
conn = self._conn()
|
conn = self._conn()
|
||||||
entities = conn.execute("SELECT COUNT(*) FROM entities").fetchone()[0]
|
entities = conn.execute("SELECT COUNT(*) as cnt FROM entities").fetchone()["cnt"]
|
||||||
triples = conn.execute("SELECT COUNT(*) FROM triples").fetchone()[0]
|
triples = conn.execute("SELECT COUNT(*) as cnt FROM triples").fetchone()["cnt"]
|
||||||
current = conn.execute("SELECT COUNT(*) FROM triples WHERE valid_to IS NULL").fetchone()[0]
|
current = conn.execute(
|
||||||
|
"SELECT COUNT(*) as cnt FROM triples WHERE valid_to IS NULL"
|
||||||
|
).fetchone()["cnt"]
|
||||||
expired = triples - current
|
expired = triples - current
|
||||||
predicates = [
|
predicates = [
|
||||||
r[0]
|
r["predicate"]
|
||||||
for r in conn.execute(
|
for r in conn.execute(
|
||||||
"SELECT DISTINCT predicate FROM triples ORDER BY predicate"
|
"SELECT DISTINCT predicate FROM triples ORDER BY predicate"
|
||||||
).fetchall()
|
).fetchall()
|
||||||
]
|
]
|
||||||
conn.close()
|
|
||||||
return {
|
return {
|
||||||
"entities": entities,
|
"entities": entities,
|
||||||
"triples": triples,
|
"triples": triples,
|
||||||
|
|||||||
+145
-11
@@ -24,8 +24,9 @@ import json
|
|||||||
import logging
|
import logging
|
||||||
import hashlib
|
import hashlib
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
from .config import MempalaceConfig
|
from .config import MempalaceConfig, sanitize_name, sanitize_content
|
||||||
from .version import __version__
|
from .version import __version__
|
||||||
from .searcher import search_memories
|
from .searcher import search_memories
|
||||||
from .palace_graph import traverse, find_tunnels, graph_stats
|
from .palace_graph import traverse, find_tunnels, graph_stats
|
||||||
@@ -44,7 +45,9 @@ def _parse_args():
|
|||||||
metavar="PATH",
|
metavar="PATH",
|
||||||
help="Path to the palace directory (overrides config file and env var)",
|
help="Path to the palace directory (overrides config file and env var)",
|
||||||
)
|
)
|
||||||
args, _ = parser.parse_known_args()
|
args, unknown = parser.parse_known_args()
|
||||||
|
if unknown:
|
||||||
|
logger.debug("Ignoring unknown args: %s", unknown)
|
||||||
return args
|
return args
|
||||||
|
|
||||||
|
|
||||||
@@ -64,16 +67,60 @@ _client_cache = None
|
|||||||
_collection_cache = None
|
_collection_cache = None
|
||||||
|
|
||||||
|
|
||||||
|
# ==================== WRITE-AHEAD LOG ====================
|
||||||
|
# Every write operation is logged to a JSONL file before execution.
|
||||||
|
# This provides an audit trail for detecting memory poisoning and
|
||||||
|
# enables review/rollback of writes from external or untrusted sources.
|
||||||
|
|
||||||
|
_WAL_DIR = Path(os.path.expanduser("~/.mempalace/wal"))
|
||||||
|
_WAL_DIR.mkdir(parents=True, exist_ok=True)
|
||||||
|
try:
|
||||||
|
_WAL_DIR.chmod(0o700)
|
||||||
|
except (OSError, NotImplementedError):
|
||||||
|
pass
|
||||||
|
_WAL_FILE = _WAL_DIR / "write_log.jsonl"
|
||||||
|
|
||||||
|
|
||||||
|
def _wal_log(operation: str, params: dict, result: dict = None):
|
||||||
|
"""Append a write operation to the write-ahead log."""
|
||||||
|
entry = {
|
||||||
|
"timestamp": datetime.now().isoformat(),
|
||||||
|
"operation": operation,
|
||||||
|
"params": params,
|
||||||
|
"result": result,
|
||||||
|
}
|
||||||
|
try:
|
||||||
|
with open(_WAL_FILE, "a", encoding="utf-8") as f:
|
||||||
|
f.write(json.dumps(entry, default=str) + "\n")
|
||||||
|
try:
|
||||||
|
_WAL_FILE.chmod(0o600)
|
||||||
|
except (OSError, NotImplementedError):
|
||||||
|
pass
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"WAL write failed: {e}")
|
||||||
|
|
||||||
|
|
||||||
|
_client_cache = None
|
||||||
|
_collection_cache = None
|
||||||
|
|
||||||
|
|
||||||
|
def _get_client():
|
||||||
|
"""Return a singleton ChromaDB PersistentClient."""
|
||||||
|
global _client_cache
|
||||||
|
if _client_cache is None:
|
||||||
|
_client_cache = chromadb.PersistentClient(path=_config.palace_path)
|
||||||
|
return _client_cache
|
||||||
|
|
||||||
|
|
||||||
def _get_collection(create=False):
|
def _get_collection(create=False):
|
||||||
"""Return the ChromaDB collection, caching the client between calls."""
|
"""Return the ChromaDB collection, caching the client between calls."""
|
||||||
global _client_cache, _collection_cache
|
global _collection_cache
|
||||||
try:
|
try:
|
||||||
if _client_cache is None:
|
client = _get_client()
|
||||||
_client_cache = chromadb.PersistentClient(path=_config.palace_path)
|
|
||||||
if create:
|
if create:
|
||||||
_collection_cache = _client_cache.get_or_create_collection(_config.collection_name)
|
_collection_cache = client.get_or_create_collection(_config.collection_name)
|
||||||
elif _collection_cache is None:
|
elif _collection_cache is None:
|
||||||
_collection_cache = _client_cache.get_collection(_config.collection_name)
|
_collection_cache = client.get_collection(_config.collection_name)
|
||||||
return _collection_cache
|
return _collection_cache
|
||||||
except Exception:
|
except Exception:
|
||||||
return None
|
return None
|
||||||
@@ -280,11 +327,30 @@ def tool_add_drawer(
|
|||||||
wing: str, room: str, content: str, source_file: str = None, added_by: str = "mcp"
|
wing: str, room: str, content: str, source_file: str = None, added_by: str = "mcp"
|
||||||
):
|
):
|
||||||
"""File verbatim content into a wing/room. Checks for duplicates first."""
|
"""File verbatim content into a wing/room. Checks for duplicates first."""
|
||||||
|
try:
|
||||||
|
wing = sanitize_name(wing, "wing")
|
||||||
|
room = sanitize_name(room, "room")
|
||||||
|
content = sanitize_content(content)
|
||||||
|
except ValueError as e:
|
||||||
|
return {"success": False, "error": str(e)}
|
||||||
|
|
||||||
col = _get_collection(create=True)
|
col = _get_collection(create=True)
|
||||||
if not col:
|
if not col:
|
||||||
return _no_palace()
|
return _no_palace()
|
||||||
|
|
||||||
drawer_id = f"drawer_{wing}_{room}_{hashlib.md5(content.encode()).hexdigest()[:16]}"
|
drawer_id = f"drawer_{wing}_{room}_{hashlib.sha256((wing + room + content[:100]).encode()).hexdigest()[:24]}"
|
||||||
|
|
||||||
|
_wal_log(
|
||||||
|
"add_drawer",
|
||||||
|
{
|
||||||
|
"drawer_id": drawer_id,
|
||||||
|
"wing": wing,
|
||||||
|
"room": room,
|
||||||
|
"added_by": added_by,
|
||||||
|
"content_length": len(content),
|
||||||
|
"content_preview": content[:200],
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
# Idempotency: if the deterministic ID already exists, return success as a no-op.
|
# Idempotency: if the deterministic ID already exists, return success as a no-op.
|
||||||
try:
|
try:
|
||||||
@@ -323,6 +389,19 @@ def tool_delete_drawer(drawer_id: str):
|
|||||||
existing = col.get(ids=[drawer_id])
|
existing = col.get(ids=[drawer_id])
|
||||||
if not existing["ids"]:
|
if not existing["ids"]:
|
||||||
return {"success": False, "error": f"Drawer not found: {drawer_id}"}
|
return {"success": False, "error": f"Drawer not found: {drawer_id}"}
|
||||||
|
|
||||||
|
# Log the deletion with the content being removed for audit trail
|
||||||
|
deleted_content = existing.get("documents", [""])[0] if existing.get("documents") else ""
|
||||||
|
deleted_meta = existing.get("metadatas", [{}])[0] if existing.get("metadatas") else {}
|
||||||
|
_wal_log(
|
||||||
|
"delete_drawer",
|
||||||
|
{
|
||||||
|
"drawer_id": drawer_id,
|
||||||
|
"deleted_meta": deleted_meta,
|
||||||
|
"content_preview": deleted_content[:200],
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
col.delete(ids=[drawer_id])
|
col.delete(ids=[drawer_id])
|
||||||
logger.info(f"Deleted drawer: {drawer_id}")
|
logger.info(f"Deleted drawer: {drawer_id}")
|
||||||
@@ -344,6 +423,23 @@ def tool_kg_add(
|
|||||||
subject: str, predicate: str, object: str, valid_from: str = None, source_closet: str = None
|
subject: str, predicate: str, object: str, valid_from: str = None, source_closet: str = None
|
||||||
):
|
):
|
||||||
"""Add a relationship to the knowledge graph."""
|
"""Add a relationship to the knowledge graph."""
|
||||||
|
try:
|
||||||
|
subject = sanitize_name(subject, "subject")
|
||||||
|
predicate = sanitize_name(predicate, "predicate")
|
||||||
|
object = sanitize_name(object, "object")
|
||||||
|
except ValueError as e:
|
||||||
|
return {"success": False, "error": str(e)}
|
||||||
|
|
||||||
|
_wal_log(
|
||||||
|
"kg_add",
|
||||||
|
{
|
||||||
|
"subject": subject,
|
||||||
|
"predicate": predicate,
|
||||||
|
"object": object,
|
||||||
|
"valid_from": valid_from,
|
||||||
|
"source_closet": source_closet,
|
||||||
|
},
|
||||||
|
)
|
||||||
triple_id = _kg.add_triple(
|
triple_id = _kg.add_triple(
|
||||||
subject, predicate, object, valid_from=valid_from, source_closet=source_closet
|
subject, predicate, object, valid_from=valid_from, source_closet=source_closet
|
||||||
)
|
)
|
||||||
@@ -352,6 +448,10 @@ def tool_kg_add(
|
|||||||
|
|
||||||
def tool_kg_invalidate(subject: str, predicate: str, object: str, ended: str = None):
|
def tool_kg_invalidate(subject: str, predicate: str, object: str, ended: str = None):
|
||||||
"""Mark a fact as no longer true (set end date)."""
|
"""Mark a fact as no longer true (set end date)."""
|
||||||
|
_wal_log(
|
||||||
|
"kg_invalidate",
|
||||||
|
{"subject": subject, "predicate": predicate, "object": object, "ended": ended},
|
||||||
|
)
|
||||||
_kg.invalidate(subject, predicate, object, ended=ended)
|
_kg.invalidate(subject, predicate, object, ended=ended)
|
||||||
return {
|
return {
|
||||||
"success": True,
|
"success": True,
|
||||||
@@ -382,6 +482,12 @@ def tool_diary_write(agent_name: str, entry: str, topic: str = "general"):
|
|||||||
This is the agent's personal journal — observations, thoughts,
|
This is the agent's personal journal — observations, thoughts,
|
||||||
what it worked on, what it noticed, what it thinks matters.
|
what it worked on, what it noticed, what it thinks matters.
|
||||||
"""
|
"""
|
||||||
|
try:
|
||||||
|
agent_name = sanitize_name(agent_name, "agent_name")
|
||||||
|
entry = sanitize_content(entry)
|
||||||
|
except ValueError as e:
|
||||||
|
return {"success": False, "error": str(e)}
|
||||||
|
|
||||||
wing = f"wing_{agent_name.lower().replace(' ', '_')}"
|
wing = f"wing_{agent_name.lower().replace(' ', '_')}"
|
||||||
room = "diary"
|
room = "diary"
|
||||||
col = _get_collection(create=True)
|
col = _get_collection(create=True)
|
||||||
@@ -389,9 +495,23 @@ def tool_diary_write(agent_name: str, entry: str, topic: str = "general"):
|
|||||||
return _no_palace()
|
return _no_palace()
|
||||||
|
|
||||||
now = datetime.now()
|
now = datetime.now()
|
||||||
entry_id = f"diary_{wing}_{now.strftime('%Y%m%d_%H%M%S')}_{hashlib.md5(entry[:50].encode()).hexdigest()[:8]}"
|
entry_id = f"diary_{wing}_{now.strftime('%Y%m%d_%H%M%S')}_{hashlib.sha256(entry[:50].encode()).hexdigest()[:12]}"
|
||||||
|
|
||||||
|
_wal_log(
|
||||||
|
"diary_write",
|
||||||
|
{
|
||||||
|
"agent_name": agent_name,
|
||||||
|
"topic": topic,
|
||||||
|
"entry_id": entry_id,
|
||||||
|
"entry_preview": entry[:200],
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
# TODO: Future versions should expand AAAK before embedding to improve
|
||||||
|
# semantic search quality. For now, store raw AAAK in metadata so it's
|
||||||
|
# preserved, and keep the document as-is for embedding (even though
|
||||||
|
# compressed AAAK degrades embedding quality).
|
||||||
col.add(
|
col.add(
|
||||||
ids=[entry_id],
|
ids=[entry_id],
|
||||||
documents=[entry],
|
documents=[entry],
|
||||||
@@ -717,17 +837,31 @@ TOOLS = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
SUPPORTED_PROTOCOL_VERSIONS = [
|
||||||
|
"2025-11-25",
|
||||||
|
"2025-06-18",
|
||||||
|
"2025-03-26",
|
||||||
|
"2024-11-05",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
def handle_request(request):
|
def handle_request(request):
|
||||||
method = request.get("method", "")
|
method = request.get("method", "")
|
||||||
params = request.get("params", {})
|
params = request.get("params", {})
|
||||||
req_id = request.get("id")
|
req_id = request.get("id")
|
||||||
|
|
||||||
if method == "initialize":
|
if method == "initialize":
|
||||||
|
client_version = params.get("protocolVersion", SUPPORTED_PROTOCOL_VERSIONS[-1])
|
||||||
|
negotiated = (
|
||||||
|
client_version
|
||||||
|
if client_version in SUPPORTED_PROTOCOL_VERSIONS
|
||||||
|
else SUPPORTED_PROTOCOL_VERSIONS[0]
|
||||||
|
)
|
||||||
return {
|
return {
|
||||||
"jsonrpc": "2.0",
|
"jsonrpc": "2.0",
|
||||||
"id": req_id,
|
"id": req_id,
|
||||||
"result": {
|
"result": {
|
||||||
"protocolVersion": "2024-11-05",
|
"protocolVersion": negotiated,
|
||||||
"capabilities": {"tools": {}},
|
"capabilities": {"tools": {}},
|
||||||
"serverInfo": {"name": "mempalace", "version": __version__},
|
"serverInfo": {"name": "mempalace", "version": __version__},
|
||||||
},
|
},
|
||||||
@@ -747,7 +881,7 @@ def handle_request(request):
|
|||||||
}
|
}
|
||||||
elif method == "tools/call":
|
elif method == "tools/call":
|
||||||
tool_name = params.get("name")
|
tool_name = params.get("name")
|
||||||
tool_args = params.get("arguments", {})
|
tool_args = params.get("arguments") or {}
|
||||||
if tool_name not in TOOLS:
|
if tool_name not in TOOLS:
|
||||||
return {
|
return {
|
||||||
"jsonrpc": "2.0",
|
"jsonrpc": "2.0",
|
||||||
|
|||||||
@@ -0,0 +1,214 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
mempalace migrate — Recover a palace created with a different ChromaDB version.
|
||||||
|
|
||||||
|
Reads documents and metadata directly from the palace's SQLite database
|
||||||
|
(bypassing ChromaDB's API, which fails on version-mismatched palaces),
|
||||||
|
then re-imports everything into a fresh palace using the currently installed
|
||||||
|
ChromaDB version.
|
||||||
|
|
||||||
|
This fixes the 3.0.0 → 3.1.0 upgrade path where chromadb was downgraded
|
||||||
|
from 1.5.x to 0.6.x, breaking the on-disk storage format.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
mempalace migrate # migrate default palace
|
||||||
|
mempalace migrate --palace /path/to/palace # migrate specific palace
|
||||||
|
mempalace migrate --dry-run # show what would be migrated
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import shutil
|
||||||
|
import sqlite3
|
||||||
|
from collections import defaultdict
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
|
||||||
|
def extract_drawers_from_sqlite(db_path: str) -> list:
|
||||||
|
"""Read all drawers directly from ChromaDB's SQLite, bypassing the API.
|
||||||
|
|
||||||
|
Works regardless of which ChromaDB version created the database.
|
||||||
|
Returns list of dicts with 'id', 'document', and 'metadata' keys.
|
||||||
|
"""
|
||||||
|
conn = sqlite3.connect(db_path)
|
||||||
|
conn.row_factory = sqlite3.Row
|
||||||
|
|
||||||
|
# Get all embedding IDs and their documents
|
||||||
|
rows = conn.execute("""
|
||||||
|
SELECT e.embedding_id,
|
||||||
|
MAX(CASE WHEN em.key = 'chroma:document' THEN em.string_value END) as document
|
||||||
|
FROM embeddings e
|
||||||
|
JOIN embedding_metadata em ON em.id = e.id
|
||||||
|
GROUP BY e.embedding_id
|
||||||
|
""").fetchall()
|
||||||
|
|
||||||
|
drawers = []
|
||||||
|
for row in rows:
|
||||||
|
embedding_id = row["embedding_id"]
|
||||||
|
document = row["document"]
|
||||||
|
if not document:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Get metadata for this embedding
|
||||||
|
meta_rows = conn.execute(
|
||||||
|
"""
|
||||||
|
SELECT em.key, em.string_value, em.int_value, em.float_value, em.bool_value
|
||||||
|
FROM embedding_metadata em
|
||||||
|
JOIN embeddings e ON e.id = em.id
|
||||||
|
WHERE e.embedding_id = ?
|
||||||
|
AND em.key NOT LIKE 'chroma:%'
|
||||||
|
""",
|
||||||
|
(embedding_id,),
|
||||||
|
).fetchall()
|
||||||
|
|
||||||
|
metadata = {}
|
||||||
|
for mr in meta_rows:
|
||||||
|
key = mr["key"]
|
||||||
|
if mr["string_value"] is not None:
|
||||||
|
metadata[key] = mr["string_value"]
|
||||||
|
elif mr["int_value"] is not None:
|
||||||
|
metadata[key] = mr["int_value"]
|
||||||
|
elif mr["float_value"] is not None:
|
||||||
|
metadata[key] = mr["float_value"]
|
||||||
|
elif mr["bool_value"] is not None:
|
||||||
|
metadata[key] = bool(mr["bool_value"])
|
||||||
|
|
||||||
|
drawers.append(
|
||||||
|
{
|
||||||
|
"id": embedding_id,
|
||||||
|
"document": document,
|
||||||
|
"metadata": metadata,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
conn.close()
|
||||||
|
return drawers
|
||||||
|
|
||||||
|
|
||||||
|
def detect_chromadb_version(db_path: str) -> str:
|
||||||
|
"""Detect which ChromaDB version created the database by checking schema."""
|
||||||
|
conn = sqlite3.connect(db_path)
|
||||||
|
try:
|
||||||
|
# 1.x has schema_str column in collections table
|
||||||
|
cols = [r[1] for r in conn.execute("PRAGMA table_info(collections)").fetchall()]
|
||||||
|
if "schema_str" in cols:
|
||||||
|
return "1.x"
|
||||||
|
# 0.6.x has embeddings_queue but no schema_str
|
||||||
|
tables = [
|
||||||
|
r[0]
|
||||||
|
for r in conn.execute("SELECT name FROM sqlite_master WHERE type='table'").fetchall()
|
||||||
|
]
|
||||||
|
if "embeddings_queue" in tables:
|
||||||
|
return "0.6.x"
|
||||||
|
return "unknown"
|
||||||
|
finally:
|
||||||
|
conn.close()
|
||||||
|
|
||||||
|
|
||||||
|
def migrate(palace_path: str, dry_run: bool = False):
|
||||||
|
"""Migrate a palace to the currently installed ChromaDB version."""
|
||||||
|
import chromadb
|
||||||
|
|
||||||
|
palace_path = os.path.expanduser(palace_path)
|
||||||
|
db_path = os.path.join(palace_path, "chroma.sqlite3")
|
||||||
|
|
||||||
|
if not os.path.isfile(db_path):
|
||||||
|
print(f"\n No palace database found at {db_path}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
print(f"\n{'=' * 60}")
|
||||||
|
print(" MemPalace Migrate")
|
||||||
|
print(f"{'=' * 60}\n")
|
||||||
|
print(f" Palace: {palace_path}")
|
||||||
|
print(f" Database: {db_path}")
|
||||||
|
print(f" DB size: {os.path.getsize(db_path) / 1024 / 1024:.1f} MB")
|
||||||
|
|
||||||
|
# Detect version
|
||||||
|
source_version = detect_chromadb_version(db_path)
|
||||||
|
print(f" Source: ChromaDB {source_version}")
|
||||||
|
print(f" Target: ChromaDB {chromadb.__version__}")
|
||||||
|
|
||||||
|
# Try reading with current chromadb first
|
||||||
|
try:
|
||||||
|
client = chromadb.PersistentClient(path=palace_path)
|
||||||
|
col = client.get_collection("mempalace_drawers")
|
||||||
|
count = col.count()
|
||||||
|
print(f"\n Palace is already readable by chromadb {chromadb.__version__}.")
|
||||||
|
print(f" {count} drawers found. No migration needed.")
|
||||||
|
return True
|
||||||
|
except Exception:
|
||||||
|
print(f"\n Palace is NOT readable by chromadb {chromadb.__version__}.")
|
||||||
|
print(" Extracting from SQLite directly...")
|
||||||
|
|
||||||
|
# Extract all drawers via raw SQL
|
||||||
|
drawers = extract_drawers_from_sqlite(db_path)
|
||||||
|
print(f" Extracted {len(drawers)} drawers from SQLite")
|
||||||
|
|
||||||
|
if not drawers:
|
||||||
|
print(" Nothing to migrate.")
|
||||||
|
return True
|
||||||
|
|
||||||
|
# Show summary
|
||||||
|
wings = defaultdict(lambda: defaultdict(int))
|
||||||
|
for d in drawers:
|
||||||
|
w = d["metadata"].get("wing", "?")
|
||||||
|
r = d["metadata"].get("room", "?")
|
||||||
|
wings[w][r] += 1
|
||||||
|
|
||||||
|
print("\n Summary:")
|
||||||
|
for wing, rooms in sorted(wings.items()):
|
||||||
|
total = sum(rooms.values())
|
||||||
|
print(f" WING: {wing} ({total} drawers)")
|
||||||
|
for room, count in sorted(rooms.items(), key=lambda x: -x[1]):
|
||||||
|
print(f" ROOM: {room:30} {count:5}")
|
||||||
|
|
||||||
|
if dry_run:
|
||||||
|
print("\n DRY RUN — no changes made.")
|
||||||
|
print(f" Would migrate {len(drawers)} drawers.")
|
||||||
|
return True
|
||||||
|
|
||||||
|
# Backup the old palace
|
||||||
|
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||||
|
backup_path = f"{palace_path}.pre-migrate.{timestamp}"
|
||||||
|
print(f"\n Backing up to {backup_path}...")
|
||||||
|
shutil.copytree(palace_path, backup_path)
|
||||||
|
|
||||||
|
# Build fresh palace in a temp directory (avoids chromadb reading old state)
|
||||||
|
import tempfile
|
||||||
|
|
||||||
|
temp_palace = tempfile.mkdtemp(prefix="mempalace_migrate_")
|
||||||
|
print(f" Creating fresh palace in {temp_palace}...")
|
||||||
|
client = chromadb.PersistentClient(path=temp_palace)
|
||||||
|
col = client.get_or_create_collection("mempalace_drawers")
|
||||||
|
|
||||||
|
# Re-import in batches
|
||||||
|
batch_size = 500
|
||||||
|
imported = 0
|
||||||
|
for i in range(0, len(drawers), batch_size):
|
||||||
|
batch = drawers[i : i + batch_size]
|
||||||
|
col.add(
|
||||||
|
ids=[d["id"] for d in batch],
|
||||||
|
documents=[d["document"] for d in batch],
|
||||||
|
metadatas=[d["metadata"] for d in batch],
|
||||||
|
)
|
||||||
|
imported += len(batch)
|
||||||
|
print(f" Imported {imported}/{len(drawers)} drawers...")
|
||||||
|
|
||||||
|
# Verify before swapping
|
||||||
|
final_count = col.count()
|
||||||
|
del col
|
||||||
|
del client
|
||||||
|
|
||||||
|
# Swap: remove old palace, move new one into place
|
||||||
|
print(" Swapping old palace for migrated version...")
|
||||||
|
shutil.rmtree(palace_path)
|
||||||
|
shutil.move(temp_palace, palace_path)
|
||||||
|
|
||||||
|
print("\n Migration complete.")
|
||||||
|
print(f" Drawers migrated: {final_count}")
|
||||||
|
print(f" Backup at: {backup_path}")
|
||||||
|
|
||||||
|
if final_count != len(drawers):
|
||||||
|
print(f" WARNING: Expected {len(drawers)}, got {final_count}")
|
||||||
|
|
||||||
|
print(f"\n{'=' * 60}\n")
|
||||||
|
return True
|
||||||
+14
-58
@@ -17,6 +17,8 @@ from collections import defaultdict
|
|||||||
|
|
||||||
import chromadb
|
import chromadb
|
||||||
|
|
||||||
|
from .palace import SKIP_DIRS, get_collection, file_already_mined
|
||||||
|
|
||||||
READABLE_EXTENSIONS = {
|
READABLE_EXTENSIONS = {
|
||||||
".txt",
|
".txt",
|
||||||
".md",
|
".md",
|
||||||
@@ -40,32 +42,6 @@ READABLE_EXTENSIONS = {
|
|||||||
".toml",
|
".toml",
|
||||||
}
|
}
|
||||||
|
|
||||||
SKIP_DIRS = {
|
|
||||||
".git",
|
|
||||||
"node_modules",
|
|
||||||
"__pycache__",
|
|
||||||
".venv",
|
|
||||||
"venv",
|
|
||||||
"env",
|
|
||||||
"dist",
|
|
||||||
"build",
|
|
||||||
".next",
|
|
||||||
"coverage",
|
|
||||||
".mempalace",
|
|
||||||
".ruff_cache",
|
|
||||||
".mypy_cache",
|
|
||||||
".pytest_cache",
|
|
||||||
".cache",
|
|
||||||
".tox",
|
|
||||||
".nox",
|
|
||||||
".idea",
|
|
||||||
".vscode",
|
|
||||||
".ipynb_checkpoints",
|
|
||||||
".eggs",
|
|
||||||
"htmlcov",
|
|
||||||
"target",
|
|
||||||
}
|
|
||||||
|
|
||||||
SKIP_FILENAMES = {
|
SKIP_FILENAMES = {
|
||||||
"mempalace.yaml",
|
"mempalace.yaml",
|
||||||
"mempalace.yml",
|
"mempalace.yml",
|
||||||
@@ -78,6 +54,7 @@ SKIP_FILENAMES = {
|
|||||||
CHUNK_SIZE = 800 # chars per drawer
|
CHUNK_SIZE = 800 # chars per drawer
|
||||||
CHUNK_OVERLAP = 100 # overlap between chunks
|
CHUNK_OVERLAP = 100 # overlap between chunks
|
||||||
MIN_CHUNK_SIZE = 50 # skip tiny chunks
|
MIN_CHUNK_SIZE = 50 # skip tiny chunks
|
||||||
|
MAX_FILE_SIZE = 10 * 1024 * 1024 # 10 MB — skip files larger than this
|
||||||
|
|
||||||
|
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
@@ -393,41 +370,11 @@ def chunk_text(content: str, source_file: str) -> list:
|
|||||||
# =============================================================================
|
# =============================================================================
|
||||||
|
|
||||||
|
|
||||||
def get_collection(palace_path: str):
|
|
||||||
os.makedirs(palace_path, exist_ok=True)
|
|
||||||
client = chromadb.PersistentClient(path=palace_path)
|
|
||||||
try:
|
|
||||||
return client.get_collection("mempalace_drawers")
|
|
||||||
except Exception:
|
|
||||||
return client.create_collection("mempalace_drawers")
|
|
||||||
|
|
||||||
|
|
||||||
def file_already_mined(collection, source_file: str) -> bool:
|
|
||||||
"""Fast check: has this file been filed before and is unchanged?
|
|
||||||
|
|
||||||
Compares the stored mtime in drawer metadata against the file's current
|
|
||||||
mtime. Returns False (needs re-mining) when the file has been modified
|
|
||||||
since it was last mined, or when no mtime was stored.
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
results = collection.get(where={"source_file": source_file}, limit=1)
|
|
||||||
if not results.get("ids"):
|
|
||||||
return False
|
|
||||||
stored_meta = results["metadatas"][0] if results.get("metadatas") else {}
|
|
||||||
stored_mtime = stored_meta.get("source_mtime")
|
|
||||||
if stored_mtime is None:
|
|
||||||
return False
|
|
||||||
current_mtime = os.path.getmtime(source_file)
|
|
||||||
return float(stored_mtime) == current_mtime
|
|
||||||
except Exception:
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
def add_drawer(
|
def add_drawer(
|
||||||
collection, wing: str, room: str, content: str, source_file: str, chunk_index: int, agent: str
|
collection, wing: str, room: str, content: str, source_file: str, chunk_index: int, agent: str
|
||||||
):
|
):
|
||||||
"""Add one drawer to the palace."""
|
"""Add one drawer to the palace."""
|
||||||
drawer_id = f"drawer_{wing}_{room}_{hashlib.md5((source_file + str(chunk_index)).encode(), usedforsecurity=False).hexdigest()[:16]}"
|
drawer_id = f"drawer_{wing}_{room}_{hashlib.sha256((source_file + str(chunk_index)).encode()).hexdigest()[:24]}"
|
||||||
try:
|
try:
|
||||||
metadata = {
|
metadata = {
|
||||||
"wing": wing,
|
"wing": wing,
|
||||||
@@ -470,7 +417,7 @@ def process_file(
|
|||||||
|
|
||||||
# Skip if already filed
|
# Skip if already filed
|
||||||
source_file = str(filepath)
|
source_file = str(filepath)
|
||||||
if not dry_run and file_already_mined(collection, source_file):
|
if not dry_run and file_already_mined(collection, source_file, check_mtime=True):
|
||||||
return 0, None
|
return 0, None
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -562,6 +509,15 @@ def scan_project(
|
|||||||
if respect_gitignore and active_matchers and not force_include:
|
if respect_gitignore and active_matchers and not force_include:
|
||||||
if is_gitignored(filepath, active_matchers, is_dir=False):
|
if is_gitignored(filepath, active_matchers, is_dir=False):
|
||||||
continue
|
continue
|
||||||
|
# Skip symlinks — prevents following links to /dev/urandom, etc.
|
||||||
|
if filepath.is_symlink():
|
||||||
|
continue
|
||||||
|
# Skip files exceeding size limit
|
||||||
|
try:
|
||||||
|
if filepath.stat().st_size > MAX_FILE_SIZE:
|
||||||
|
continue
|
||||||
|
except OSError:
|
||||||
|
continue
|
||||||
files.append(filepath)
|
files.append(filepath)
|
||||||
return files
|
return files
|
||||||
|
|
||||||
|
|||||||
@@ -25,6 +25,12 @@ def normalize(filepath: str) -> str:
|
|||||||
Load a file and normalize to transcript format if it's a chat export.
|
Load a file and normalize to transcript format if it's a chat export.
|
||||||
Plain text files pass through unchanged.
|
Plain text files pass through unchanged.
|
||||||
"""
|
"""
|
||||||
|
try:
|
||||||
|
file_size = os.path.getsize(filepath)
|
||||||
|
except OSError as e:
|
||||||
|
raise IOError(f"Could not read {filepath}: {e}")
|
||||||
|
if file_size > 500 * 1024 * 1024: # 500 MB safety limit
|
||||||
|
raise IOError(f"File too large ({file_size // (1024*1024)} MB): {filepath}")
|
||||||
try:
|
try:
|
||||||
with open(filepath, "r", encoding="utf-8", errors="replace") as f:
|
with open(filepath, "r", encoding="utf-8", errors="replace") as f:
|
||||||
content = f.read()
|
content = f.read()
|
||||||
|
|||||||
@@ -312,7 +312,7 @@ def _generate_aaak_bootstrap(
|
|||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
(mempalace_dir / "aaak_entities.md").write_text("\n".join(registry_lines))
|
(mempalace_dir / "aaak_entities.md").write_text("\n".join(registry_lines), encoding="utf-8")
|
||||||
|
|
||||||
# Critical facts bootstrap (pre-palace — before any mining)
|
# Critical facts bootstrap (pre-palace — before any mining)
|
||||||
facts_lines = [
|
facts_lines = [
|
||||||
@@ -359,7 +359,7 @@ def _generate_aaak_bootstrap(
|
|||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
(mempalace_dir / "critical_facts.md").write_text("\n".join(facts_lines))
|
(mempalace_dir / "critical_facts.md").write_text("\n".join(facts_lines), encoding="utf-8")
|
||||||
|
|
||||||
|
|
||||||
def run_onboarding(
|
def run_onboarding(
|
||||||
|
|||||||
@@ -0,0 +1,71 @@
|
|||||||
|
"""
|
||||||
|
palace.py — Shared palace operations.
|
||||||
|
|
||||||
|
Consolidates ChromaDB access patterns used by both miners and the MCP server.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import chromadb
|
||||||
|
|
||||||
|
SKIP_DIRS = {
|
||||||
|
".git",
|
||||||
|
"node_modules",
|
||||||
|
"__pycache__",
|
||||||
|
".venv",
|
||||||
|
"venv",
|
||||||
|
"env",
|
||||||
|
"dist",
|
||||||
|
"build",
|
||||||
|
".next",
|
||||||
|
"coverage",
|
||||||
|
".mempalace",
|
||||||
|
".ruff_cache",
|
||||||
|
".mypy_cache",
|
||||||
|
".pytest_cache",
|
||||||
|
".cache",
|
||||||
|
".tox",
|
||||||
|
".nox",
|
||||||
|
".idea",
|
||||||
|
".vscode",
|
||||||
|
".ipynb_checkpoints",
|
||||||
|
".eggs",
|
||||||
|
"htmlcov",
|
||||||
|
"target",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def get_collection(palace_path: str, collection_name: str = "mempalace_drawers"):
|
||||||
|
"""Get or create the palace ChromaDB collection."""
|
||||||
|
os.makedirs(palace_path, exist_ok=True)
|
||||||
|
try:
|
||||||
|
os.chmod(palace_path, 0o700)
|
||||||
|
except (OSError, NotImplementedError):
|
||||||
|
pass
|
||||||
|
client = chromadb.PersistentClient(path=palace_path)
|
||||||
|
try:
|
||||||
|
return client.get_collection(collection_name)
|
||||||
|
except Exception:
|
||||||
|
return client.create_collection(collection_name)
|
||||||
|
|
||||||
|
|
||||||
|
def file_already_mined(collection, source_file: str, check_mtime: bool = False) -> bool:
|
||||||
|
"""Check if a file has already been filed in the palace.
|
||||||
|
|
||||||
|
When check_mtime=True (used by project miner), returns False if the file
|
||||||
|
has been modified since it was last mined, so it gets re-mined.
|
||||||
|
When check_mtime=False (used by convo miner), just checks existence.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
results = collection.get(where={"source_file": source_file}, limit=1)
|
||||||
|
if not results.get("ids"):
|
||||||
|
return False
|
||||||
|
if check_mtime:
|
||||||
|
stored_meta = results.get("metadatas", [{}])[0]
|
||||||
|
stored_mtime = stored_meta.get("source_mtime")
|
||||||
|
if stored_mtime is None:
|
||||||
|
return False
|
||||||
|
current_mtime = os.path.getmtime(source_file)
|
||||||
|
return float(stored_mtime) == current_mtime
|
||||||
|
return True
|
||||||
|
except Exception:
|
||||||
|
return False
|
||||||
@@ -182,6 +182,10 @@ def split_file(filepath, output_dir, dry_run=False):
|
|||||||
Returns list of output paths written (or would be written if dry_run).
|
Returns list of output paths written (or would be written if dry_run).
|
||||||
"""
|
"""
|
||||||
path = Path(filepath)
|
path = Path(filepath)
|
||||||
|
max_size = 500 * 1024 * 1024 # 500 MB safety limit
|
||||||
|
if path.stat().st_size > max_size:
|
||||||
|
print(f" SKIP: {path.name} exceeds {max_size // (1024*1024)} MB limit")
|
||||||
|
return []
|
||||||
lines = path.read_text(errors="replace").splitlines(keepends=True)
|
lines = path.read_text(errors="replace").splitlines(keepends=True)
|
||||||
|
|
||||||
boundaries = find_session_boundaries(lines)
|
boundaries = find_session_boundaries(lines)
|
||||||
@@ -219,7 +223,7 @@ def split_file(filepath, output_dir, dry_run=False):
|
|||||||
if dry_run:
|
if dry_run:
|
||||||
print(f" [{i + 1}/{len(boundaries) - 1}] {name} ({len(chunk)} lines)")
|
print(f" [{i + 1}/{len(boundaries) - 1}] {name} ({len(chunk)} lines)")
|
||||||
else:
|
else:
|
||||||
out_path.write_text("".join(chunk))
|
out_path.write_text("".join(chunk), encoding="utf-8")
|
||||||
print(f" ✓ {name} ({len(chunk)} lines)")
|
print(f" ✓ {name} ({len(chunk)} lines)")
|
||||||
|
|
||||||
written.append(out_path)
|
written.append(out_path)
|
||||||
@@ -266,7 +270,11 @@ def main():
|
|||||||
files = sorted(src_dir.glob("*.txt"))
|
files = sorted(src_dir.glob("*.txt"))
|
||||||
|
|
||||||
mega_files = []
|
mega_files = []
|
||||||
|
max_scan_size = 500 * 1024 * 1024 # 500 MB
|
||||||
for f in files:
|
for f in files:
|
||||||
|
if f.stat().st_size > max_scan_size:
|
||||||
|
print(f" SKIP: {f.name} exceeds {max_scan_size // (1024*1024)} MB limit")
|
||||||
|
continue
|
||||||
lines = f.read_text(errors="replace").splitlines(keepends=True)
|
lines = f.read_text(errors="replace").splitlines(keepends=True)
|
||||||
boundaries = find_session_boundaries(lines)
|
boundaries = find_session_boundaries(lines)
|
||||||
if len(boundaries) >= args.min_sessions:
|
if len(boundaries) >= args.min_sessions:
|
||||||
|
|||||||
@@ -1,3 +1,3 @@
|
|||||||
"""Single source of truth for the MemPalace package version."""
|
"""Single source of truth for the MemPalace package version."""
|
||||||
|
|
||||||
__version__ = "3.0.14"
|
__version__ = "3.1.0"
|
||||||
|
|||||||
+8
-4
@@ -1,6 +1,6 @@
|
|||||||
[project]
|
[project]
|
||||||
name = "mempalace"
|
name = "mempalace"
|
||||||
version = "3.0.14"
|
version = "3.1.0"
|
||||||
description = "Give your AI a memory — mine projects and conversations into a searchable palace. No API key required."
|
description = "Give your AI a memory — mine projects and conversations into a searchable palace. No API key required."
|
||||||
readme = "README.md"
|
readme = "README.md"
|
||||||
requires-python = ">=3.9"
|
requires-python = ">=3.9"
|
||||||
@@ -26,7 +26,7 @@ classifiers = [
|
|||||||
]
|
]
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"chromadb>=0.5.0,<0.7",
|
"chromadb>=0.5.0,<0.7",
|
||||||
"pyyaml>=6.0",
|
"pyyaml>=6.0,<7",
|
||||||
]
|
]
|
||||||
|
|
||||||
[project.urls]
|
[project.urls]
|
||||||
@@ -54,11 +54,15 @@ packages = ["mempalace"]
|
|||||||
[tool.ruff]
|
[tool.ruff]
|
||||||
line-length = 100
|
line-length = 100
|
||||||
target-version = "py39"
|
target-version = "py39"
|
||||||
|
extend-exclude = ["benchmarks"]
|
||||||
|
|
||||||
[tool.ruff.lint]
|
[tool.ruff.lint]
|
||||||
select = ["E", "F", "W"]
|
select = ["E", "F", "W", "C901"]
|
||||||
ignore = ["E501"]
|
ignore = ["E501"]
|
||||||
|
|
||||||
|
[tool.ruff.lint.mccabe]
|
||||||
|
max-complexity = 25
|
||||||
|
|
||||||
[tool.ruff.format]
|
[tool.ruff.format]
|
||||||
quote-style = "double"
|
quote-style = "double"
|
||||||
|
|
||||||
@@ -76,7 +80,7 @@ markers = [
|
|||||||
source = ["mempalace"]
|
source = ["mempalace"]
|
||||||
|
|
||||||
[tool.coverage.report]
|
[tool.coverage.report]
|
||||||
fail_under = 30
|
fail_under = 85
|
||||||
show_missing = true
|
show_missing = true
|
||||||
exclude_lines = [
|
exclude_lines = [
|
||||||
"if __name__",
|
"if __name__",
|
||||||
|
|||||||
@@ -148,9 +148,9 @@ class TestWakeUpTokenBudget:
|
|||||||
record_metric("wakeup_budget", f"tokens_at_{n_drawers}", token_estimate)
|
record_metric("wakeup_budget", f"tokens_at_{n_drawers}", token_estimate)
|
||||||
record_metric("wakeup_budget", f"chars_at_{n_drawers}", len(text))
|
record_metric("wakeup_budget", f"chars_at_{n_drawers}", len(text))
|
||||||
|
|
||||||
assert token_estimate < 1200, (
|
assert (
|
||||||
f"Wake-up exceeded budget: ~{token_estimate} tokens at {n_drawers} drawers"
|
token_estimate < 1200
|
||||||
)
|
), f"Wake-up exceeded budget: ~{token_estimate} tokens at {n_drawers} drawers"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.benchmark
|
@pytest.mark.benchmark
|
||||||
|
|||||||
@@ -0,0 +1,652 @@
|
|||||||
|
"""Tests for mempalace.cli — the main CLI dispatcher."""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from mempalace.cli import (
|
||||||
|
cmd_compress,
|
||||||
|
cmd_hook,
|
||||||
|
cmd_init,
|
||||||
|
cmd_instructions,
|
||||||
|
cmd_mine,
|
||||||
|
cmd_repair,
|
||||||
|
cmd_search,
|
||||||
|
cmd_split,
|
||||||
|
cmd_status,
|
||||||
|
cmd_wakeup,
|
||||||
|
main,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ── cmd_status ─────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
@patch("mempalace.cli.MempalaceConfig")
|
||||||
|
def test_cmd_status_default_palace(mock_config_cls):
|
||||||
|
mock_config_cls.return_value.palace_path = "/fake/palace"
|
||||||
|
args = argparse.Namespace(palace=None)
|
||||||
|
mock_miner = MagicMock()
|
||||||
|
with patch.dict("sys.modules", {"mempalace.miner": mock_miner}):
|
||||||
|
cmd_status(args)
|
||||||
|
mock_miner.status.assert_called_once_with(palace_path="/fake/palace")
|
||||||
|
|
||||||
|
|
||||||
|
@patch("mempalace.cli.MempalaceConfig")
|
||||||
|
def test_cmd_status_custom_palace(mock_config_cls):
|
||||||
|
args = argparse.Namespace(palace="~/my_palace")
|
||||||
|
mock_miner = MagicMock()
|
||||||
|
with patch.dict("sys.modules", {"mempalace.miner": mock_miner}):
|
||||||
|
cmd_status(args)
|
||||||
|
import os
|
||||||
|
|
||||||
|
expected = os.path.expanduser("~/my_palace")
|
||||||
|
mock_miner.status.assert_called_once_with(palace_path=expected)
|
||||||
|
|
||||||
|
|
||||||
|
# ── cmd_search ─────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
@patch("mempalace.cli.MempalaceConfig")
|
||||||
|
def test_cmd_search_calls_search(mock_config_cls):
|
||||||
|
mock_config_cls.return_value.palace_path = "/fake/palace"
|
||||||
|
args = argparse.Namespace(
|
||||||
|
palace=None, query="test query", wing="mywing", room="myroom", results=3
|
||||||
|
)
|
||||||
|
with patch("mempalace.searcher.search") as mock_search:
|
||||||
|
cmd_search(args)
|
||||||
|
mock_search.assert_called_once_with(
|
||||||
|
query="test query",
|
||||||
|
palace_path="/fake/palace",
|
||||||
|
wing="mywing",
|
||||||
|
room="myroom",
|
||||||
|
n_results=3,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@patch("mempalace.cli.MempalaceConfig")
|
||||||
|
def test_cmd_search_error_exits(mock_config_cls):
|
||||||
|
mock_config_cls.return_value.palace_path = "/fake/palace"
|
||||||
|
args = argparse.Namespace(palace=None, query="q", wing=None, room=None, results=5)
|
||||||
|
from mempalace.searcher import SearchError
|
||||||
|
|
||||||
|
with patch("mempalace.searcher.search", side_effect=SearchError("fail")):
|
||||||
|
with pytest.raises(SystemExit) as exc_info:
|
||||||
|
cmd_search(args)
|
||||||
|
assert exc_info.value.code == 1
|
||||||
|
|
||||||
|
|
||||||
|
# ── cmd_instructions ───────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_cmd_instructions_calls_run_instructions():
|
||||||
|
args = argparse.Namespace(name="help")
|
||||||
|
with patch("mempalace.instructions_cli.run_instructions") as mock_run:
|
||||||
|
cmd_instructions(args)
|
||||||
|
mock_run.assert_called_once_with(name="help")
|
||||||
|
|
||||||
|
|
||||||
|
# ── cmd_hook ───────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_cmd_hook_calls_run_hook():
|
||||||
|
args = argparse.Namespace(hook="session-start", harness="claude-code")
|
||||||
|
with patch("mempalace.hooks_cli.run_hook") as mock_run:
|
||||||
|
cmd_hook(args)
|
||||||
|
mock_run.assert_called_once_with(hook_name="session-start", harness="claude-code")
|
||||||
|
|
||||||
|
|
||||||
|
# ── cmd_init ───────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
@patch("mempalace.cli.MempalaceConfig")
|
||||||
|
def test_cmd_init_no_entities(mock_config_cls, tmp_path):
|
||||||
|
args = argparse.Namespace(dir=str(tmp_path), yes=True)
|
||||||
|
with (
|
||||||
|
patch("mempalace.entity_detector.scan_for_detection", return_value=[]),
|
||||||
|
patch("mempalace.room_detector_local.detect_rooms_local") as mock_rooms,
|
||||||
|
):
|
||||||
|
cmd_init(args)
|
||||||
|
mock_rooms.assert_called_once_with(project_dir=str(tmp_path), yes=True)
|
||||||
|
mock_config_cls.return_value.init.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
|
@patch("mempalace.cli.MempalaceConfig")
|
||||||
|
def test_cmd_init_with_entities(mock_config_cls, tmp_path):
|
||||||
|
fake_files = [tmp_path / "a.txt"]
|
||||||
|
detected = {"people": [{"name": "Alice"}], "projects": [], "uncertain": []}
|
||||||
|
confirmed = {"people": ["Alice"], "projects": []}
|
||||||
|
args = argparse.Namespace(dir=str(tmp_path), yes=True)
|
||||||
|
with (
|
||||||
|
patch("mempalace.entity_detector.scan_for_detection", return_value=fake_files),
|
||||||
|
patch("mempalace.entity_detector.detect_entities", return_value=detected),
|
||||||
|
patch("mempalace.entity_detector.confirm_entities", return_value=confirmed),
|
||||||
|
patch("mempalace.room_detector_local.detect_rooms_local"),
|
||||||
|
patch("builtins.open", MagicMock()),
|
||||||
|
):
|
||||||
|
cmd_init(args)
|
||||||
|
|
||||||
|
|
||||||
|
@patch("mempalace.cli.MempalaceConfig")
|
||||||
|
def test_cmd_init_with_entities_zero_total(mock_config_cls, tmp_path, capsys):
|
||||||
|
"""When entities detected but total is 0, prints 'No entities' message."""
|
||||||
|
fake_files = [tmp_path / "a.txt"]
|
||||||
|
detected = {"people": [], "projects": [], "uncertain": []}
|
||||||
|
args = argparse.Namespace(dir=str(tmp_path), yes=False)
|
||||||
|
with (
|
||||||
|
patch("mempalace.entity_detector.scan_for_detection", return_value=fake_files),
|
||||||
|
patch("mempalace.entity_detector.detect_entities", return_value=detected),
|
||||||
|
patch("mempalace.room_detector_local.detect_rooms_local"),
|
||||||
|
):
|
||||||
|
cmd_init(args)
|
||||||
|
out = capsys.readouterr().out
|
||||||
|
assert "No entities detected" in out
|
||||||
|
|
||||||
|
|
||||||
|
# ── cmd_mine ───────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
@patch("mempalace.cli.MempalaceConfig")
|
||||||
|
def test_cmd_mine_projects_mode(mock_config_cls):
|
||||||
|
mock_config_cls.return_value.palace_path = "/fake/palace"
|
||||||
|
args = argparse.Namespace(
|
||||||
|
dir="/src",
|
||||||
|
palace=None,
|
||||||
|
mode="projects",
|
||||||
|
wing=None,
|
||||||
|
agent="mempalace",
|
||||||
|
limit=0,
|
||||||
|
dry_run=False,
|
||||||
|
no_gitignore=False,
|
||||||
|
include_ignored=[],
|
||||||
|
extract="exchange",
|
||||||
|
)
|
||||||
|
with patch("mempalace.miner.mine") as mock_mine:
|
||||||
|
cmd_mine(args)
|
||||||
|
mock_mine.assert_called_once_with(
|
||||||
|
project_dir="/src",
|
||||||
|
palace_path="/fake/palace",
|
||||||
|
wing_override=None,
|
||||||
|
agent="mempalace",
|
||||||
|
limit=0,
|
||||||
|
dry_run=False,
|
||||||
|
respect_gitignore=True,
|
||||||
|
include_ignored=[],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@patch("mempalace.cli.MempalaceConfig")
|
||||||
|
def test_cmd_mine_convos_mode(mock_config_cls):
|
||||||
|
mock_config_cls.return_value.palace_path = "/fake/palace"
|
||||||
|
args = argparse.Namespace(
|
||||||
|
dir="/chats",
|
||||||
|
palace=None,
|
||||||
|
mode="convos",
|
||||||
|
wing="mywing",
|
||||||
|
agent="me",
|
||||||
|
limit=10,
|
||||||
|
dry_run=True,
|
||||||
|
no_gitignore=False,
|
||||||
|
include_ignored=[],
|
||||||
|
extract="general",
|
||||||
|
)
|
||||||
|
with patch("mempalace.convo_miner.mine_convos") as mock_mine:
|
||||||
|
cmd_mine(args)
|
||||||
|
mock_mine.assert_called_once_with(
|
||||||
|
convo_dir="/chats",
|
||||||
|
palace_path="/fake/palace",
|
||||||
|
wing="mywing",
|
||||||
|
agent="me",
|
||||||
|
limit=10,
|
||||||
|
dry_run=True,
|
||||||
|
extract_mode="general",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@patch("mempalace.cli.MempalaceConfig")
|
||||||
|
def test_cmd_mine_include_ignored_comma_split(mock_config_cls):
|
||||||
|
mock_config_cls.return_value.palace_path = "/fake/palace"
|
||||||
|
args = argparse.Namespace(
|
||||||
|
dir="/src",
|
||||||
|
palace=None,
|
||||||
|
mode="projects",
|
||||||
|
wing=None,
|
||||||
|
agent="mempalace",
|
||||||
|
limit=0,
|
||||||
|
dry_run=False,
|
||||||
|
no_gitignore=False,
|
||||||
|
include_ignored=["a.txt,b.txt", "c.txt"],
|
||||||
|
extract="exchange",
|
||||||
|
)
|
||||||
|
with patch("mempalace.miner.mine") as mock_mine:
|
||||||
|
cmd_mine(args)
|
||||||
|
mock_mine.assert_called_once()
|
||||||
|
call_kwargs = mock_mine.call_args[1]
|
||||||
|
assert call_kwargs["include_ignored"] == ["a.txt", "b.txt", "c.txt"]
|
||||||
|
|
||||||
|
|
||||||
|
# ── cmd_wakeup ─────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
@patch("mempalace.cli.MempalaceConfig")
|
||||||
|
def test_cmd_wakeup(mock_config_cls, capsys):
|
||||||
|
mock_config_cls.return_value.palace_path = "/fake/palace"
|
||||||
|
args = argparse.Namespace(palace=None, wing=None)
|
||||||
|
mock_stack = MagicMock()
|
||||||
|
mock_stack.wake_up.return_value = "Hello world context"
|
||||||
|
with patch("mempalace.layers.MemoryStack", return_value=mock_stack):
|
||||||
|
cmd_wakeup(args)
|
||||||
|
out = capsys.readouterr().out
|
||||||
|
assert "Hello world context" in out
|
||||||
|
assert "tokens" in out
|
||||||
|
|
||||||
|
|
||||||
|
# ── cmd_split ──────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_cmd_split_basic():
|
||||||
|
args = argparse.Namespace(dir="/chats", output_dir=None, dry_run=False, min_sessions=2)
|
||||||
|
with patch("mempalace.split_mega_files.main") as mock_main:
|
||||||
|
cmd_split(args)
|
||||||
|
mock_main.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
|
def test_cmd_split_all_options():
|
||||||
|
args = argparse.Namespace(dir="/chats", output_dir="/out", dry_run=True, min_sessions=5)
|
||||||
|
with patch("mempalace.split_mega_files.main") as mock_main:
|
||||||
|
cmd_split(args)
|
||||||
|
mock_main.assert_called_once()
|
||||||
|
# sys.argv should be restored
|
||||||
|
assert sys.argv[0] != "mempalace split"
|
||||||
|
|
||||||
|
|
||||||
|
# ── main() argparse dispatch ──────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_main_no_args_prints_help(capsys):
|
||||||
|
with patch("sys.argv", ["mempalace"]):
|
||||||
|
main()
|
||||||
|
out = capsys.readouterr().out
|
||||||
|
assert "MemPalace" in out
|
||||||
|
|
||||||
|
|
||||||
|
def test_main_status_dispatches():
|
||||||
|
with (
|
||||||
|
patch("sys.argv", ["mempalace", "status"]),
|
||||||
|
patch("mempalace.cli.cmd_status") as mock_cmd,
|
||||||
|
):
|
||||||
|
main()
|
||||||
|
mock_cmd.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
|
def test_main_search_dispatches():
|
||||||
|
with (
|
||||||
|
patch("sys.argv", ["mempalace", "search", "my query"]),
|
||||||
|
patch("mempalace.cli.cmd_search") as mock_cmd,
|
||||||
|
):
|
||||||
|
main()
|
||||||
|
mock_cmd.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
|
def test_main_init_dispatches():
|
||||||
|
with (
|
||||||
|
patch("sys.argv", ["mempalace", "init", "/some/dir"]),
|
||||||
|
patch("mempalace.cli.cmd_init") as mock_cmd,
|
||||||
|
):
|
||||||
|
main()
|
||||||
|
mock_cmd.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
|
def test_main_mine_dispatches():
|
||||||
|
with (
|
||||||
|
patch("sys.argv", ["mempalace", "mine", "/some/dir"]),
|
||||||
|
patch("mempalace.cli.cmd_mine") as mock_cmd,
|
||||||
|
):
|
||||||
|
main()
|
||||||
|
mock_cmd.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
|
def test_main_wakeup_dispatches():
|
||||||
|
with (
|
||||||
|
patch("sys.argv", ["mempalace", "wake-up"]),
|
||||||
|
patch("mempalace.cli.cmd_wakeup") as mock_cmd,
|
||||||
|
):
|
||||||
|
main()
|
||||||
|
mock_cmd.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
|
def test_main_split_dispatches():
|
||||||
|
with (
|
||||||
|
patch("sys.argv", ["mempalace", "split", "/chats"]),
|
||||||
|
patch("mempalace.cli.cmd_split") as mock_cmd,
|
||||||
|
):
|
||||||
|
main()
|
||||||
|
mock_cmd.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
|
def test_mcp_command_prints_setup_guidance(monkeypatch, capsys):
|
||||||
|
monkeypatch.setattr(sys, "argv", ["mempalace", "mcp"])
|
||||||
|
|
||||||
|
main()
|
||||||
|
|
||||||
|
captured = capsys.readouterr()
|
||||||
|
assert "MemPalace MCP quick setup:" in captured.out
|
||||||
|
assert "claude mcp add mempalace -- python -m mempalace.mcp_server" in captured.out
|
||||||
|
assert "\nOptional custom palace:\n" in captured.out
|
||||||
|
assert "python -m mempalace.mcp_server --palace /path/to/palace" in captured.out
|
||||||
|
assert "[--palace /path/to/palace]" not in captured.out
|
||||||
|
assert captured.err == ""
|
||||||
|
|
||||||
|
|
||||||
|
def test_mcp_command_uses_custom_palace_path_when_provided(monkeypatch, capsys):
|
||||||
|
monkeypatch.setattr(sys, "argv", ["mempalace", "--palace", "~/tmp/my palace", "mcp"])
|
||||||
|
|
||||||
|
main()
|
||||||
|
|
||||||
|
captured = capsys.readouterr()
|
||||||
|
expanded = str(Path("~/tmp/my palace").expanduser())
|
||||||
|
|
||||||
|
assert "python -m mempalace.mcp_server --palace" in captured.out
|
||||||
|
assert expanded in captured.out
|
||||||
|
assert "Optional custom palace:" not in captured.out
|
||||||
|
assert "[--palace /path/to/palace]" not in captured.out
|
||||||
|
assert captured.err == ""
|
||||||
|
|
||||||
|
|
||||||
|
def test_main_hook_no_subcommand_prints_help(capsys):
|
||||||
|
with patch("sys.argv", ["mempalace", "hook"]):
|
||||||
|
main()
|
||||||
|
out = capsys.readouterr().out
|
||||||
|
assert "hook" in out.lower() or "run" in out.lower()
|
||||||
|
|
||||||
|
|
||||||
|
def test_main_hook_run_dispatches():
|
||||||
|
with (
|
||||||
|
patch(
|
||||||
|
"sys.argv",
|
||||||
|
["mempalace", "hook", "run", "--hook", "session-start", "--harness", "claude-code"],
|
||||||
|
),
|
||||||
|
patch("mempalace.cli.cmd_hook") as mock_cmd,
|
||||||
|
):
|
||||||
|
main()
|
||||||
|
mock_cmd.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
|
def test_main_instructions_no_subcommand_prints_help(capsys):
|
||||||
|
with patch("sys.argv", ["mempalace", "instructions"]):
|
||||||
|
main()
|
||||||
|
out = capsys.readouterr().out
|
||||||
|
assert "instructions" in out.lower() or "init" in out.lower()
|
||||||
|
|
||||||
|
|
||||||
|
def test_main_instructions_dispatches():
|
||||||
|
with (
|
||||||
|
patch("sys.argv", ["mempalace", "instructions", "help"]),
|
||||||
|
patch("mempalace.cli.cmd_instructions") as mock_cmd,
|
||||||
|
):
|
||||||
|
main()
|
||||||
|
mock_cmd.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
|
def test_main_repair_dispatches():
|
||||||
|
with (
|
||||||
|
patch("sys.argv", ["mempalace", "repair"]),
|
||||||
|
patch("mempalace.cli.cmd_repair") as mock_cmd,
|
||||||
|
):
|
||||||
|
main()
|
||||||
|
mock_cmd.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
|
def test_main_compress_dispatches():
|
||||||
|
with (
|
||||||
|
patch("sys.argv", ["mempalace", "compress"]),
|
||||||
|
patch("mempalace.cli.cmd_compress") as mock_cmd,
|
||||||
|
):
|
||||||
|
main()
|
||||||
|
mock_cmd.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
|
# ── cmd_repair ─────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
@patch("mempalace.cli.MempalaceConfig")
|
||||||
|
def test_cmd_repair_no_palace(mock_config_cls, tmp_path, capsys):
|
||||||
|
mock_config_cls.return_value.palace_path = str(tmp_path / "nonexistent")
|
||||||
|
args = argparse.Namespace(palace=None)
|
||||||
|
mock_chromadb = MagicMock()
|
||||||
|
with patch.dict("sys.modules", {"chromadb": mock_chromadb}):
|
||||||
|
cmd_repair(args)
|
||||||
|
out = capsys.readouterr().out
|
||||||
|
assert "No palace found" in out
|
||||||
|
|
||||||
|
|
||||||
|
@patch("mempalace.cli.MempalaceConfig")
|
||||||
|
def test_cmd_repair_error_reading(mock_config_cls, tmp_path, capsys):
|
||||||
|
palace_dir = tmp_path / "palace"
|
||||||
|
palace_dir.mkdir()
|
||||||
|
mock_config_cls.return_value.palace_path = str(palace_dir)
|
||||||
|
args = argparse.Namespace(palace=None)
|
||||||
|
mock_chromadb = MagicMock()
|
||||||
|
mock_client = MagicMock()
|
||||||
|
mock_client.get_collection.side_effect = Exception("corrupt db")
|
||||||
|
mock_chromadb.PersistentClient.return_value = mock_client
|
||||||
|
with patch.dict("sys.modules", {"chromadb": mock_chromadb}):
|
||||||
|
cmd_repair(args)
|
||||||
|
out = capsys.readouterr().out
|
||||||
|
assert "Error reading palace" in out
|
||||||
|
|
||||||
|
|
||||||
|
@patch("mempalace.cli.MempalaceConfig")
|
||||||
|
def test_cmd_repair_zero_drawers(mock_config_cls, tmp_path, capsys):
|
||||||
|
palace_dir = tmp_path / "palace"
|
||||||
|
palace_dir.mkdir()
|
||||||
|
mock_config_cls.return_value.palace_path = str(palace_dir)
|
||||||
|
args = argparse.Namespace(palace=None)
|
||||||
|
mock_chromadb = MagicMock()
|
||||||
|
mock_col = MagicMock()
|
||||||
|
mock_col.count.return_value = 0
|
||||||
|
mock_client = MagicMock()
|
||||||
|
mock_client.get_collection.return_value = mock_col
|
||||||
|
mock_chromadb.PersistentClient.return_value = mock_client
|
||||||
|
with patch.dict("sys.modules", {"chromadb": mock_chromadb}):
|
||||||
|
cmd_repair(args)
|
||||||
|
out = capsys.readouterr().out
|
||||||
|
assert "Nothing to repair" in out
|
||||||
|
|
||||||
|
|
||||||
|
@patch("mempalace.cli.MempalaceConfig")
|
||||||
|
def test_cmd_repair_success(mock_config_cls, tmp_path, capsys):
|
||||||
|
palace_dir = tmp_path / "palace"
|
||||||
|
palace_dir.mkdir()
|
||||||
|
mock_config_cls.return_value.palace_path = str(palace_dir)
|
||||||
|
args = argparse.Namespace(palace=None)
|
||||||
|
mock_chromadb = MagicMock()
|
||||||
|
mock_col = MagicMock()
|
||||||
|
mock_col.count.return_value = 2
|
||||||
|
mock_col.get.return_value = {
|
||||||
|
"ids": ["id1", "id2"],
|
||||||
|
"documents": ["doc1", "doc2"],
|
||||||
|
"metadatas": [{"wing": "a"}, {"wing": "b"}],
|
||||||
|
}
|
||||||
|
mock_client = MagicMock()
|
||||||
|
mock_client.get_collection.return_value = mock_col
|
||||||
|
mock_new_col = MagicMock()
|
||||||
|
mock_client.create_collection.return_value = mock_new_col
|
||||||
|
mock_chromadb.PersistentClient.return_value = mock_client
|
||||||
|
with patch.dict("sys.modules", {"chromadb": mock_chromadb}):
|
||||||
|
cmd_repair(args)
|
||||||
|
out = capsys.readouterr().out
|
||||||
|
assert "Repair complete" in out
|
||||||
|
assert "2 drawers rebuilt" in out
|
||||||
|
|
||||||
|
|
||||||
|
# ── cmd_compress ───────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
@patch("mempalace.cli.MempalaceConfig")
|
||||||
|
def test_cmd_compress_no_palace(mock_config_cls, capsys):
|
||||||
|
mock_config_cls.return_value.palace_path = "/fake/palace"
|
||||||
|
args = argparse.Namespace(palace=None, wing=None, dry_run=False, config=None)
|
||||||
|
mock_chromadb = MagicMock()
|
||||||
|
mock_chromadb.PersistentClient.side_effect = Exception("no palace")
|
||||||
|
with (
|
||||||
|
patch.dict("sys.modules", {"chromadb": mock_chromadb}),
|
||||||
|
pytest.raises(SystemExit),
|
||||||
|
):
|
||||||
|
cmd_compress(args)
|
||||||
|
|
||||||
|
|
||||||
|
@patch("mempalace.cli.MempalaceConfig")
|
||||||
|
def test_cmd_compress_no_drawers(mock_config_cls, capsys):
|
||||||
|
mock_config_cls.return_value.palace_path = "/fake/palace"
|
||||||
|
args = argparse.Namespace(palace=None, wing="mywing", dry_run=False, config=None)
|
||||||
|
mock_chromadb = MagicMock()
|
||||||
|
mock_col = MagicMock()
|
||||||
|
mock_col.get.return_value = {"documents": [], "metadatas": [], "ids": []}
|
||||||
|
mock_client = MagicMock()
|
||||||
|
mock_client.get_collection.return_value = mock_col
|
||||||
|
mock_chromadb.PersistentClient.return_value = mock_client
|
||||||
|
with patch.dict("sys.modules", {"chromadb": mock_chromadb}):
|
||||||
|
cmd_compress(args)
|
||||||
|
out = capsys.readouterr().out
|
||||||
|
assert "No drawers found" in out
|
||||||
|
|
||||||
|
|
||||||
|
def _make_mock_dialect_module(dialect_instance):
|
||||||
|
"""Create a mock dialect module with a Dialect class that returns the given instance."""
|
||||||
|
mock_mod = MagicMock()
|
||||||
|
mock_mod.Dialect.return_value = dialect_instance
|
||||||
|
mock_mod.Dialect.from_config.return_value = dialect_instance
|
||||||
|
mock_mod.Dialect.count_tokens = MagicMock(side_effect=lambda x: len(x) // 4)
|
||||||
|
return mock_mod
|
||||||
|
|
||||||
|
|
||||||
|
@patch("mempalace.cli.MempalaceConfig")
|
||||||
|
def test_cmd_compress_dry_run(mock_config_cls, capsys):
|
||||||
|
mock_config_cls.return_value.palace_path = "/fake/palace"
|
||||||
|
args = argparse.Namespace(palace=None, wing=None, dry_run=True, config=None)
|
||||||
|
mock_chromadb = MagicMock()
|
||||||
|
mock_col = MagicMock()
|
||||||
|
mock_col.get.side_effect = [
|
||||||
|
{
|
||||||
|
"documents": ["some long text here for testing"],
|
||||||
|
"metadatas": [{"wing": "test", "room": "general", "source_file": "test.txt"}],
|
||||||
|
"ids": ["id1"],
|
||||||
|
},
|
||||||
|
{"documents": [], "metadatas": [], "ids": []},
|
||||||
|
]
|
||||||
|
mock_client = MagicMock()
|
||||||
|
mock_client.get_collection.return_value = mock_col
|
||||||
|
mock_chromadb.PersistentClient.return_value = mock_client
|
||||||
|
|
||||||
|
mock_dialect = MagicMock()
|
||||||
|
mock_dialect.compress.return_value = "compressed"
|
||||||
|
mock_dialect.compression_stats.return_value = {
|
||||||
|
"original_chars": 100,
|
||||||
|
"compressed_chars": 30,
|
||||||
|
"original_tokens": 25,
|
||||||
|
"compressed_tokens": 8,
|
||||||
|
"ratio": 3.3,
|
||||||
|
}
|
||||||
|
mock_dialect_mod = _make_mock_dialect_module(mock_dialect)
|
||||||
|
|
||||||
|
with patch.dict(
|
||||||
|
"sys.modules",
|
||||||
|
{
|
||||||
|
"chromadb": mock_chromadb,
|
||||||
|
"mempalace.dialect": mock_dialect_mod,
|
||||||
|
},
|
||||||
|
):
|
||||||
|
cmd_compress(args)
|
||||||
|
out = capsys.readouterr().out
|
||||||
|
assert "dry run" in out.lower()
|
||||||
|
assert "Compressing" in out
|
||||||
|
|
||||||
|
|
||||||
|
@patch("mempalace.cli.MempalaceConfig")
|
||||||
|
def test_cmd_compress_with_config(mock_config_cls, tmp_path, capsys):
|
||||||
|
mock_config_cls.return_value.palace_path = "/fake/palace"
|
||||||
|
config_file = tmp_path / "entities.json"
|
||||||
|
config_file.write_text('{"people": [], "projects": []}')
|
||||||
|
args = argparse.Namespace(palace=None, wing=None, dry_run=True, config=str(config_file))
|
||||||
|
mock_chromadb = MagicMock()
|
||||||
|
mock_col = MagicMock()
|
||||||
|
mock_col.get.return_value = {"documents": [], "metadatas": [], "ids": []}
|
||||||
|
mock_client = MagicMock()
|
||||||
|
mock_client.get_collection.return_value = mock_col
|
||||||
|
mock_chromadb.PersistentClient.return_value = mock_client
|
||||||
|
|
||||||
|
mock_dialect = MagicMock()
|
||||||
|
mock_dialect_mod = _make_mock_dialect_module(mock_dialect)
|
||||||
|
|
||||||
|
with patch.dict(
|
||||||
|
"sys.modules",
|
||||||
|
{
|
||||||
|
"chromadb": mock_chromadb,
|
||||||
|
"mempalace.dialect": mock_dialect_mod,
|
||||||
|
},
|
||||||
|
):
|
||||||
|
cmd_compress(args)
|
||||||
|
out = capsys.readouterr().out
|
||||||
|
assert "Loaded entity config" in out
|
||||||
|
|
||||||
|
|
||||||
|
@patch("mempalace.cli.MempalaceConfig")
|
||||||
|
def test_cmd_compress_stores_results(mock_config_cls, capsys):
|
||||||
|
"""Non-dry-run compress stores to mempalace_compressed collection."""
|
||||||
|
mock_config_cls.return_value.palace_path = "/fake/palace"
|
||||||
|
args = argparse.Namespace(palace=None, wing=None, dry_run=False, config=None)
|
||||||
|
mock_chromadb = MagicMock()
|
||||||
|
mock_col = MagicMock()
|
||||||
|
mock_col.get.side_effect = [
|
||||||
|
{
|
||||||
|
"documents": ["text"],
|
||||||
|
"metadatas": [{"wing": "w", "room": "r", "source_file": "f.txt"}],
|
||||||
|
"ids": ["id1"],
|
||||||
|
},
|
||||||
|
{"documents": [], "metadatas": [], "ids": []},
|
||||||
|
]
|
||||||
|
mock_client = MagicMock()
|
||||||
|
mock_client.get_collection.return_value = mock_col
|
||||||
|
mock_comp_col = MagicMock()
|
||||||
|
mock_client.get_or_create_collection.return_value = mock_comp_col
|
||||||
|
mock_chromadb.PersistentClient.return_value = mock_client
|
||||||
|
|
||||||
|
mock_dialect = MagicMock()
|
||||||
|
mock_dialect.compress.return_value = "compressed"
|
||||||
|
mock_dialect.compression_stats.return_value = {
|
||||||
|
"original_chars": 100,
|
||||||
|
"compressed_chars": 30,
|
||||||
|
"original_tokens": 25,
|
||||||
|
"compressed_tokens": 8,
|
||||||
|
"ratio": 3.3,
|
||||||
|
}
|
||||||
|
mock_dialect_mod = _make_mock_dialect_module(mock_dialect)
|
||||||
|
|
||||||
|
with patch.dict(
|
||||||
|
"sys.modules",
|
||||||
|
{
|
||||||
|
"chromadb": mock_chromadb,
|
||||||
|
"mempalace.dialect": mock_dialect_mod,
|
||||||
|
},
|
||||||
|
):
|
||||||
|
cmd_compress(args)
|
||||||
|
out = capsys.readouterr().out
|
||||||
|
assert "Stored" in out
|
||||||
|
mock_comp_col.upsert.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
|
def test_cmd_repair_trailing_slash_does_not_recurse():
|
||||||
|
"""Repair with trailing slash should put backup outside palace dir (#395)."""
|
||||||
|
import os
|
||||||
|
|
||||||
|
args = argparse.Namespace(palace="/tmp/fake_palace/")
|
||||||
|
with patch("mempalace.cli.os.path.isdir", return_value=False):
|
||||||
|
cmd_repair(args)
|
||||||
|
# Verify the rstrip logic: palace_path should not end with separator
|
||||||
|
palace_path = os.path.expanduser(args.palace).rstrip(os.sep)
|
||||||
|
backup_path = palace_path + ".backup"
|
||||||
|
assert not backup_path.startswith(palace_path + os.sep)
|
||||||
@@ -0,0 +1,79 @@
|
|||||||
|
"""Extra tests for mempalace.config to cover remaining gaps."""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
|
||||||
|
from mempalace.config import MempalaceConfig
|
||||||
|
|
||||||
|
|
||||||
|
def test_config_bad_json(tmp_path):
|
||||||
|
"""Bad JSON in config file falls back to empty."""
|
||||||
|
(tmp_path / "config.json").write_text("not json", encoding="utf-8")
|
||||||
|
cfg = MempalaceConfig(config_dir=str(tmp_path))
|
||||||
|
assert cfg.palace_path # still returns default
|
||||||
|
|
||||||
|
|
||||||
|
def test_people_map_from_file(tmp_path):
|
||||||
|
(tmp_path / "people_map.json").write_text(json.dumps({"bob": "Robert"}), encoding="utf-8")
|
||||||
|
cfg = MempalaceConfig(config_dir=str(tmp_path))
|
||||||
|
assert cfg.people_map == {"bob": "Robert"}
|
||||||
|
|
||||||
|
|
||||||
|
def test_people_map_bad_json(tmp_path):
|
||||||
|
(tmp_path / "people_map.json").write_text("bad", encoding="utf-8")
|
||||||
|
cfg = MempalaceConfig(config_dir=str(tmp_path))
|
||||||
|
assert cfg.people_map == {}
|
||||||
|
|
||||||
|
|
||||||
|
def test_people_map_missing(tmp_path):
|
||||||
|
cfg = MempalaceConfig(config_dir=str(tmp_path))
|
||||||
|
assert cfg.people_map == {}
|
||||||
|
|
||||||
|
|
||||||
|
def test_topic_wings_default(tmp_path):
|
||||||
|
cfg = MempalaceConfig(config_dir=str(tmp_path))
|
||||||
|
assert isinstance(cfg.topic_wings, list)
|
||||||
|
assert "emotions" in cfg.topic_wings
|
||||||
|
|
||||||
|
|
||||||
|
def test_hall_keywords_default(tmp_path):
|
||||||
|
cfg = MempalaceConfig(config_dir=str(tmp_path))
|
||||||
|
assert isinstance(cfg.hall_keywords, dict)
|
||||||
|
assert "technical" in cfg.hall_keywords
|
||||||
|
|
||||||
|
|
||||||
|
def test_init_idempotent(tmp_path):
|
||||||
|
cfg = MempalaceConfig(config_dir=str(tmp_path))
|
||||||
|
cfg.init()
|
||||||
|
cfg.init() # second call should not overwrite
|
||||||
|
with open(tmp_path / "config.json") as f:
|
||||||
|
data = json.load(f)
|
||||||
|
assert "palace_path" in data
|
||||||
|
|
||||||
|
|
||||||
|
def test_save_people_map(tmp_path):
|
||||||
|
cfg = MempalaceConfig(config_dir=str(tmp_path))
|
||||||
|
result = cfg.save_people_map({"alice": "Alice Smith"})
|
||||||
|
assert result.exists()
|
||||||
|
with open(result) as f:
|
||||||
|
data = json.load(f)
|
||||||
|
assert data["alice"] == "Alice Smith"
|
||||||
|
|
||||||
|
|
||||||
|
def test_env_mempal_palace_path(tmp_path):
|
||||||
|
"""MEMPAL_PALACE_PATH (legacy) should also work."""
|
||||||
|
os.environ.pop("MEMPALACE_PALACE_PATH", None)
|
||||||
|
os.environ["MEMPAL_PALACE_PATH"] = "/legacy/path"
|
||||||
|
try:
|
||||||
|
cfg = MempalaceConfig(config_dir=str(tmp_path))
|
||||||
|
assert cfg.palace_path == "/legacy/path"
|
||||||
|
finally:
|
||||||
|
del os.environ["MEMPAL_PALACE_PATH"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_collection_name_from_config(tmp_path):
|
||||||
|
(tmp_path / "config.json").write_text(
|
||||||
|
json.dumps({"collection_name": "custom_col"}), encoding="utf-8"
|
||||||
|
)
|
||||||
|
cfg = MempalaceConfig(config_dir=str(tmp_path))
|
||||||
|
assert cfg.collection_name == "custom_col"
|
||||||
@@ -23,4 +23,4 @@ def test_convo_mining():
|
|||||||
results = col.query(query_texts=["memory persistence"], n_results=1)
|
results = col.query(query_texts=["memory persistence"], n_results=1)
|
||||||
assert len(results["documents"][0]) > 0
|
assert len(results["documents"][0]) > 0
|
||||||
|
|
||||||
shutil.rmtree(tmpdir)
|
shutil.rmtree(tmpdir, ignore_errors=True)
|
||||||
|
|||||||
@@ -0,0 +1,102 @@
|
|||||||
|
"""Unit tests for convo_miner pure functions (no chromadb needed)."""
|
||||||
|
|
||||||
|
from mempalace.convo_miner import (
|
||||||
|
chunk_exchanges,
|
||||||
|
detect_convo_room,
|
||||||
|
scan_convos,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestChunkExchanges:
|
||||||
|
def test_exchange_chunking(self):
|
||||||
|
content = (
|
||||||
|
"> What is memory?\n"
|
||||||
|
"Memory is persistence of information over time.\n\n"
|
||||||
|
"> Why does it matter?\n"
|
||||||
|
"It enables continuity across sessions and conversations.\n\n"
|
||||||
|
"> How do we build it?\n"
|
||||||
|
"With structured storage and retrieval mechanisms.\n"
|
||||||
|
)
|
||||||
|
chunks = chunk_exchanges(content)
|
||||||
|
assert len(chunks) >= 2
|
||||||
|
assert all("content" in c and "chunk_index" in c for c in chunks)
|
||||||
|
|
||||||
|
def test_paragraph_fallback(self):
|
||||||
|
"""Content without '>' lines falls back to paragraph chunking."""
|
||||||
|
content = (
|
||||||
|
"This is a long paragraph about memory systems. " * 10 + "\n\n"
|
||||||
|
"This is another paragraph about storage. " * 10 + "\n\n"
|
||||||
|
"And a third paragraph about retrieval. " * 10
|
||||||
|
)
|
||||||
|
chunks = chunk_exchanges(content)
|
||||||
|
assert len(chunks) >= 2
|
||||||
|
|
||||||
|
def test_paragraph_line_group_fallback(self):
|
||||||
|
"""Long content with no paragraph breaks chunks by line groups."""
|
||||||
|
lines = [f"Line {i}: some content that is meaningful" for i in range(60)]
|
||||||
|
content = "\n".join(lines)
|
||||||
|
chunks = chunk_exchanges(content)
|
||||||
|
assert len(chunks) >= 1
|
||||||
|
|
||||||
|
def test_empty_content(self):
|
||||||
|
chunks = chunk_exchanges("")
|
||||||
|
assert chunks == []
|
||||||
|
|
||||||
|
def test_short_content_skipped(self):
|
||||||
|
chunks = chunk_exchanges("> hi\nbye")
|
||||||
|
# Too short to produce chunks (below MIN_CHUNK_SIZE)
|
||||||
|
assert isinstance(chunks, list)
|
||||||
|
|
||||||
|
|
||||||
|
class TestDetectConvoRoom:
|
||||||
|
def test_technical_room(self):
|
||||||
|
content = "Let me debug this python function and fix the code error in the api"
|
||||||
|
assert detect_convo_room(content) == "technical"
|
||||||
|
|
||||||
|
def test_planning_room(self):
|
||||||
|
content = "We need to plan the roadmap for the next sprint and set milestone deadlines"
|
||||||
|
assert detect_convo_room(content) == "planning"
|
||||||
|
|
||||||
|
def test_architecture_room(self):
|
||||||
|
content = "The architecture uses a service layer with component interface and module design"
|
||||||
|
assert detect_convo_room(content) == "architecture"
|
||||||
|
|
||||||
|
def test_decisions_room(self):
|
||||||
|
content = "We decided to switch and migrated to the new framework after we chose it"
|
||||||
|
assert detect_convo_room(content) == "decisions"
|
||||||
|
|
||||||
|
def test_general_fallback(self):
|
||||||
|
content = "Hello, how are you doing today? The weather is nice."
|
||||||
|
assert detect_convo_room(content) == "general"
|
||||||
|
|
||||||
|
|
||||||
|
class TestScanConvos:
|
||||||
|
def test_scan_finds_txt_and_md(self, tmp_path):
|
||||||
|
(tmp_path / "chat.txt").write_text("hello", encoding="utf-8")
|
||||||
|
(tmp_path / "notes.md").write_text("world", encoding="utf-8")
|
||||||
|
(tmp_path / "image.png").write_bytes(b"fake")
|
||||||
|
files = scan_convos(str(tmp_path))
|
||||||
|
extensions = {f.suffix for f in files}
|
||||||
|
assert ".txt" in extensions
|
||||||
|
assert ".md" in extensions
|
||||||
|
assert ".png" not in extensions
|
||||||
|
|
||||||
|
def test_scan_skips_git_dir(self, tmp_path):
|
||||||
|
git_dir = tmp_path / ".git"
|
||||||
|
git_dir.mkdir()
|
||||||
|
(git_dir / "config.txt").write_text("git stuff", encoding="utf-8")
|
||||||
|
(tmp_path / "chat.txt").write_text("hello", encoding="utf-8")
|
||||||
|
files = scan_convos(str(tmp_path))
|
||||||
|
assert len(files) == 1
|
||||||
|
|
||||||
|
def test_scan_skips_meta_json(self, tmp_path):
|
||||||
|
(tmp_path / "chat.meta.json").write_text("{}", encoding="utf-8")
|
||||||
|
(tmp_path / "chat.json").write_text("{}", encoding="utf-8")
|
||||||
|
files = scan_convos(str(tmp_path))
|
||||||
|
names = [f.name for f in files]
|
||||||
|
assert "chat.json" in names
|
||||||
|
assert "chat.meta.json" not in names
|
||||||
|
|
||||||
|
def test_scan_empty_dir(self, tmp_path):
|
||||||
|
files = scan_convos(str(tmp_path))
|
||||||
|
assert files == []
|
||||||
@@ -0,0 +1,380 @@
|
|||||||
|
"""Tests for mempalace.entity_detector."""
|
||||||
|
|
||||||
|
import os
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
from mempalace.entity_detector import (
|
||||||
|
PROSE_EXTENSIONS,
|
||||||
|
STOPWORDS,
|
||||||
|
_print_entity_list,
|
||||||
|
classify_entity,
|
||||||
|
confirm_entities,
|
||||||
|
detect_entities,
|
||||||
|
extract_candidates,
|
||||||
|
scan_for_detection,
|
||||||
|
score_entity,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ── extract_candidates ──────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_extract_candidates_finds_frequent_names():
|
||||||
|
text = "Riley said hello. Riley laughed. Riley smiled. Riley waved."
|
||||||
|
result = extract_candidates(text)
|
||||||
|
assert "Riley" in result
|
||||||
|
assert result["Riley"] >= 3
|
||||||
|
|
||||||
|
|
||||||
|
def test_extract_candidates_ignores_stopwords():
|
||||||
|
# "The" appears many times but is a stopword
|
||||||
|
text = "The The The The The The"
|
||||||
|
result = extract_candidates(text)
|
||||||
|
assert "The" not in result
|
||||||
|
|
||||||
|
|
||||||
|
def test_extract_candidates_requires_min_frequency():
|
||||||
|
text = "Riley said hi. Devon waved."
|
||||||
|
result = extract_candidates(text)
|
||||||
|
# Each name appears only once, below the threshold of 3
|
||||||
|
assert "Riley" not in result
|
||||||
|
assert "Devon" not in result
|
||||||
|
|
||||||
|
|
||||||
|
def test_extract_candidates_finds_multi_word_names():
|
||||||
|
# Multi-word names need 3+ occurrences and no stopwords
|
||||||
|
text = "Claude Code is great. Claude Code rocks. Claude Code works. Claude Code rules."
|
||||||
|
result = extract_candidates(text)
|
||||||
|
assert "Claude Code" in result
|
||||||
|
|
||||||
|
|
||||||
|
def test_extract_candidates_empty_text():
|
||||||
|
result = extract_candidates("")
|
||||||
|
assert result == {}
|
||||||
|
|
||||||
|
|
||||||
|
# ── score_entity ────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_score_entity_person_verbs():
|
||||||
|
text = "Riley said hello. Riley asked why. Riley told me."
|
||||||
|
lines = text.splitlines()
|
||||||
|
result = score_entity("Riley", text, lines)
|
||||||
|
assert result["person_score"] > 0
|
||||||
|
assert len(result["person_signals"]) > 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_score_entity_project_verbs():
|
||||||
|
text = "We are building ChromaDB. We deployed ChromaDB. Install ChromaDB."
|
||||||
|
lines = text.splitlines()
|
||||||
|
result = score_entity("ChromaDB", text, lines)
|
||||||
|
assert result["project_score"] > 0
|
||||||
|
assert len(result["project_signals"]) > 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_score_entity_dialogue_markers():
|
||||||
|
text = "Riley: Hey, how are you?\nRiley: I'm fine."
|
||||||
|
lines = text.splitlines()
|
||||||
|
result = score_entity("Riley", text, lines)
|
||||||
|
assert result["person_score"] > 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_score_entity_code_ref():
|
||||||
|
text = "Check out ChromaDB.py for details. Also ChromaDB.js is good."
|
||||||
|
lines = text.splitlines()
|
||||||
|
result = score_entity("ChromaDB", text, lines)
|
||||||
|
assert result["project_score"] > 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_score_entity_no_signals():
|
||||||
|
text = "Nothing interesting here at all."
|
||||||
|
lines = text.splitlines()
|
||||||
|
result = score_entity("Riley", text, lines)
|
||||||
|
assert result["person_score"] == 0
|
||||||
|
assert result["project_score"] == 0
|
||||||
|
|
||||||
|
|
||||||
|
# ── classify_entity ─────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_classify_entity_no_signals_gives_uncertain():
|
||||||
|
scores = {
|
||||||
|
"person_score": 0,
|
||||||
|
"project_score": 0,
|
||||||
|
"person_signals": [],
|
||||||
|
"project_signals": [],
|
||||||
|
}
|
||||||
|
result = classify_entity("Foo", 10, scores)
|
||||||
|
assert result["type"] == "uncertain"
|
||||||
|
assert result["name"] == "Foo"
|
||||||
|
|
||||||
|
|
||||||
|
def test_classify_entity_strong_project():
|
||||||
|
scores = {
|
||||||
|
"person_score": 0,
|
||||||
|
"project_score": 10,
|
||||||
|
"person_signals": [],
|
||||||
|
"project_signals": ["project verb (5x)", "code file reference (2x)"],
|
||||||
|
}
|
||||||
|
result = classify_entity("ChromaDB", 5, scores)
|
||||||
|
assert result["type"] == "project"
|
||||||
|
|
||||||
|
|
||||||
|
def test_classify_entity_strong_person_needs_two_signal_types():
|
||||||
|
scores = {
|
||||||
|
"person_score": 10,
|
||||||
|
"project_score": 0,
|
||||||
|
"person_signals": [
|
||||||
|
"dialogue marker (3x)",
|
||||||
|
"'Riley ...' action (4x)",
|
||||||
|
],
|
||||||
|
"project_signals": [],
|
||||||
|
}
|
||||||
|
result = classify_entity("Riley", 8, scores)
|
||||||
|
assert result["type"] == "person"
|
||||||
|
|
||||||
|
|
||||||
|
def test_classify_entity_pronoun_only_is_uncertain():
|
||||||
|
scores = {
|
||||||
|
"person_score": 8,
|
||||||
|
"project_score": 0,
|
||||||
|
"person_signals": ["pronoun nearby (4x)"],
|
||||||
|
"project_signals": [],
|
||||||
|
}
|
||||||
|
result = classify_entity("Riley", 5, scores)
|
||||||
|
assert result["type"] == "uncertain"
|
||||||
|
|
||||||
|
|
||||||
|
def test_classify_entity_mixed_signals():
|
||||||
|
scores = {
|
||||||
|
"person_score": 5,
|
||||||
|
"project_score": 5,
|
||||||
|
"person_signals": ["pronoun nearby (2x)"],
|
||||||
|
"project_signals": ["project verb (2x)"],
|
||||||
|
}
|
||||||
|
result = classify_entity("Lantern", 5, scores)
|
||||||
|
assert result["type"] == "uncertain"
|
||||||
|
assert "mixed signals" in result["signals"][-1]
|
||||||
|
|
||||||
|
|
||||||
|
# ── detect_entities (integration) ───────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_detect_entities_with_person_file(tmp_path):
|
||||||
|
f = tmp_path / "notes.txt"
|
||||||
|
content = "\n".join(
|
||||||
|
[
|
||||||
|
"Riley said hello today.",
|
||||||
|
"Riley asked about the project.",
|
||||||
|
"Riley told me she was happy.",
|
||||||
|
"Riley: I think we should go.",
|
||||||
|
"Hey Riley, thanks for the help.",
|
||||||
|
"Riley laughed and smiled.",
|
||||||
|
"Riley decided to join.",
|
||||||
|
"Riley pushed the change.",
|
||||||
|
]
|
||||||
|
)
|
||||||
|
f.write_text(content)
|
||||||
|
result = detect_entities([f])
|
||||||
|
all_names = [e["name"] for cat in result.values() for e in cat]
|
||||||
|
assert "Riley" in all_names
|
||||||
|
|
||||||
|
|
||||||
|
def test_detect_entities_with_project_file(tmp_path):
|
||||||
|
f = tmp_path / "readme.txt"
|
||||||
|
# "ChromaDB" has uppercase+lowercase mix but extract_candidates looks
|
||||||
|
# for /[A-Z][a-z]{1,19}/ — so we need a name that matches that regex.
|
||||||
|
# Use "Lantern" which matches the capitalized-word pattern.
|
||||||
|
content = "\n".join(
|
||||||
|
[
|
||||||
|
"The Lantern project is great.",
|
||||||
|
"Building Lantern was fun.",
|
||||||
|
"We deployed Lantern today.",
|
||||||
|
"Install Lantern with pip install Lantern.",
|
||||||
|
"Check Lantern.py for the source.",
|
||||||
|
"Lantern v2 is faster.",
|
||||||
|
]
|
||||||
|
)
|
||||||
|
f.write_text(content)
|
||||||
|
result = detect_entities([f])
|
||||||
|
all_names = [e["name"] for cat in result.values() for e in cat]
|
||||||
|
assert "Lantern" in all_names
|
||||||
|
|
||||||
|
|
||||||
|
def test_detect_entities_empty_files(tmp_path):
|
||||||
|
f = tmp_path / "empty.txt"
|
||||||
|
f.write_text("")
|
||||||
|
result = detect_entities([f])
|
||||||
|
assert result == {"people": [], "projects": [], "uncertain": []}
|
||||||
|
|
||||||
|
|
||||||
|
def test_detect_entities_handles_missing_file(tmp_path):
|
||||||
|
missing = tmp_path / "nonexistent.txt"
|
||||||
|
result = detect_entities([missing])
|
||||||
|
assert result == {"people": [], "projects": [], "uncertain": []}
|
||||||
|
|
||||||
|
|
||||||
|
def test_detect_entities_respects_max_files(tmp_path):
|
||||||
|
files = []
|
||||||
|
for i in range(5):
|
||||||
|
f = tmp_path / f"file{i}.txt"
|
||||||
|
f.write_text("Riley said hello. " * 10)
|
||||||
|
files.append(f)
|
||||||
|
# max_files=2 should only read 2 files
|
||||||
|
result = detect_entities(files, max_files=2)
|
||||||
|
# Should still work without error
|
||||||
|
assert isinstance(result, dict)
|
||||||
|
|
||||||
|
|
||||||
|
# ── scan_for_detection ──────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_scan_for_detection_finds_prose(tmp_path):
|
||||||
|
(tmp_path / "notes.md").write_text("hello")
|
||||||
|
(tmp_path / "data.txt").write_text("world")
|
||||||
|
(tmp_path / "code.py").write_text("import os")
|
||||||
|
files = scan_for_detection(str(tmp_path))
|
||||||
|
extensions = {os.path.splitext(str(f))[1] for f in files}
|
||||||
|
# Prose files should be found
|
||||||
|
assert ".md" in extensions or ".txt" in extensions
|
||||||
|
|
||||||
|
|
||||||
|
def test_scan_for_detection_skips_git_dir(tmp_path):
|
||||||
|
git_dir = tmp_path / ".git"
|
||||||
|
git_dir.mkdir()
|
||||||
|
(git_dir / "config.txt").write_text("git config")
|
||||||
|
(tmp_path / "readme.md").write_text("hello")
|
||||||
|
files = scan_for_detection(str(tmp_path))
|
||||||
|
file_strs = [str(f) for f in files]
|
||||||
|
assert not any(".git" in f for f in file_strs)
|
||||||
|
|
||||||
|
|
||||||
|
# ── module-level constants ──────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_stopwords_contains_common_words():
|
||||||
|
assert "the" in STOPWORDS
|
||||||
|
assert "import" in STOPWORDS
|
||||||
|
assert "class" in STOPWORDS
|
||||||
|
|
||||||
|
|
||||||
|
def test_prose_extensions():
|
||||||
|
assert ".txt" in PROSE_EXTENSIONS
|
||||||
|
assert ".md" in PROSE_EXTENSIONS
|
||||||
|
|
||||||
|
|
||||||
|
# ── _print_entity_list ─────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_print_entity_list_with_entities(capsys):
|
||||||
|
entities = [
|
||||||
|
{"name": "Alice", "confidence": 0.9, "signals": ["dialogue marker (3x)"]},
|
||||||
|
{"name": "Bob", "confidence": 0.5, "signals": []},
|
||||||
|
]
|
||||||
|
_print_entity_list(entities, "PEOPLE")
|
||||||
|
out = capsys.readouterr().out
|
||||||
|
assert "PEOPLE" in out
|
||||||
|
assert "Alice" in out
|
||||||
|
assert "Bob" in out
|
||||||
|
|
||||||
|
|
||||||
|
def test_print_entity_list_empty(capsys):
|
||||||
|
_print_entity_list([], "PEOPLE")
|
||||||
|
out = capsys.readouterr().out
|
||||||
|
assert "none detected" in out
|
||||||
|
|
||||||
|
|
||||||
|
# ── confirm_entities ───────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_confirm_entities_yes_mode():
|
||||||
|
detected = {
|
||||||
|
"people": [{"name": "Alice", "confidence": 0.9, "signals": ["test"]}],
|
||||||
|
"projects": [{"name": "Acme", "confidence": 0.8, "signals": ["test"]}],
|
||||||
|
"uncertain": [{"name": "Foo", "confidence": 0.4, "signals": ["test"]}],
|
||||||
|
}
|
||||||
|
result = confirm_entities(detected, yes=True)
|
||||||
|
assert result["people"] == ["Alice"]
|
||||||
|
assert result["projects"] == ["Acme"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_confirm_entities_accept_all():
|
||||||
|
detected = {
|
||||||
|
"people": [{"name": "Alice", "confidence": 0.9, "signals": ["test"]}],
|
||||||
|
"projects": [],
|
||||||
|
"uncertain": [],
|
||||||
|
}
|
||||||
|
with patch("builtins.input", side_effect=["", "n"]):
|
||||||
|
result = confirm_entities(detected, yes=False)
|
||||||
|
assert "Alice" in result["people"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_confirm_entities_edit_reclassify_uncertain():
|
||||||
|
detected = {
|
||||||
|
"people": [],
|
||||||
|
"projects": [],
|
||||||
|
"uncertain": [
|
||||||
|
{"name": "Foo", "confidence": 0.4, "signals": ["test"]},
|
||||||
|
{"name": "Bar", "confidence": 0.4, "signals": ["test"]},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
with patch(
|
||||||
|
"builtins.input",
|
||||||
|
side_effect=[
|
||||||
|
"edit", # choice
|
||||||
|
"p", # Foo -> person
|
||||||
|
"s", # Bar -> skip
|
||||||
|
"", # no removals from people
|
||||||
|
"", # no removals from projects
|
||||||
|
"n", # don't add missing
|
||||||
|
],
|
||||||
|
):
|
||||||
|
result = confirm_entities(detected, yes=False)
|
||||||
|
assert "Foo" in result["people"]
|
||||||
|
assert "Bar" not in result["people"]
|
||||||
|
assert "Bar" not in result["projects"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_confirm_entities_add_mode():
|
||||||
|
detected = {
|
||||||
|
"people": [],
|
||||||
|
"projects": [],
|
||||||
|
"uncertain": [],
|
||||||
|
}
|
||||||
|
with patch(
|
||||||
|
"builtins.input",
|
||||||
|
side_effect=[
|
||||||
|
"add", # choice = add
|
||||||
|
"NewPerson", # name
|
||||||
|
"p", # person
|
||||||
|
"NewProj", # name
|
||||||
|
"r", # project
|
||||||
|
"", # stop adding
|
||||||
|
],
|
||||||
|
):
|
||||||
|
result = confirm_entities(detected, yes=False)
|
||||||
|
assert "NewPerson" in result["people"]
|
||||||
|
assert "NewProj" in result["projects"]
|
||||||
|
|
||||||
|
|
||||||
|
# ── scan_for_detection fallback ────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_scan_for_detection_fallback_to_all_readable(tmp_path):
|
||||||
|
"""When fewer than 3 prose files, falls back to include all readable files."""
|
||||||
|
(tmp_path / "one.md").write_text("hello")
|
||||||
|
(tmp_path / "two.txt").write_text("world")
|
||||||
|
# Only 2 prose files, so it should also include code files
|
||||||
|
(tmp_path / "code.py").write_text("import os")
|
||||||
|
(tmp_path / "app.js").write_text("console.log()")
|
||||||
|
files = scan_for_detection(str(tmp_path))
|
||||||
|
extensions = {os.path.splitext(str(f))[1] for f in files}
|
||||||
|
assert ".py" in extensions or ".js" in extensions
|
||||||
|
|
||||||
|
|
||||||
|
def test_scan_for_detection_max_files(tmp_path):
|
||||||
|
"""Caps to max_files."""
|
||||||
|
for i in range(20):
|
||||||
|
(tmp_path / f"note{i}.md").write_text(f"content {i}")
|
||||||
|
files = scan_for_detection(str(tmp_path), max_files=5)
|
||||||
|
assert len(files) <= 5
|
||||||
@@ -0,0 +1,313 @@
|
|||||||
|
"""Tests for mempalace.entity_registry."""
|
||||||
|
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
from mempalace.entity_registry import (
|
||||||
|
COMMON_ENGLISH_WORDS,
|
||||||
|
PERSON_CONTEXT_PATTERNS,
|
||||||
|
EntityRegistry,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ── COMMON_ENGLISH_WORDS ────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_common_english_words_has_expected_entries():
|
||||||
|
assert "ever" in COMMON_ENGLISH_WORDS
|
||||||
|
assert "grace" in COMMON_ENGLISH_WORDS
|
||||||
|
assert "will" in COMMON_ENGLISH_WORDS
|
||||||
|
assert "may" in COMMON_ENGLISH_WORDS
|
||||||
|
assert "monday" in COMMON_ENGLISH_WORDS
|
||||||
|
|
||||||
|
|
||||||
|
def test_common_english_words_is_lowercase():
|
||||||
|
for word in COMMON_ENGLISH_WORDS:
|
||||||
|
assert word == word.lower(), f"{word} should be lowercase"
|
||||||
|
|
||||||
|
|
||||||
|
# ── PERSON_CONTEXT_PATTERNS ─────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_person_context_patterns_is_nonempty():
|
||||||
|
assert len(PERSON_CONTEXT_PATTERNS) > 0
|
||||||
|
|
||||||
|
|
||||||
|
# ── EntityRegistry creation and empty state ─────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_load_from_nonexistent_dir(tmp_path):
|
||||||
|
registry = EntityRegistry.load(config_dir=tmp_path)
|
||||||
|
assert registry.people == {}
|
||||||
|
assert registry.projects == []
|
||||||
|
assert registry.mode == "personal"
|
||||||
|
assert registry.ambiguous_flags == []
|
||||||
|
|
||||||
|
|
||||||
|
def test_save_and_load_roundtrip(tmp_path):
|
||||||
|
registry = EntityRegistry.load(config_dir=tmp_path)
|
||||||
|
registry.seed(
|
||||||
|
mode="work",
|
||||||
|
people=[{"name": "Alice", "relationship": "colleague", "context": "work"}],
|
||||||
|
projects=["MemPalace"],
|
||||||
|
)
|
||||||
|
# Load again from same dir
|
||||||
|
loaded = EntityRegistry.load(config_dir=tmp_path)
|
||||||
|
assert loaded.mode == "work"
|
||||||
|
assert "Alice" in loaded.people
|
||||||
|
assert "MemPalace" in loaded.projects
|
||||||
|
|
||||||
|
|
||||||
|
def test_save_creates_file(tmp_path):
|
||||||
|
registry = EntityRegistry.load(config_dir=tmp_path)
|
||||||
|
registry.save()
|
||||||
|
assert (tmp_path / "entity_registry.json").exists()
|
||||||
|
|
||||||
|
|
||||||
|
# ── seed ────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_seed_registers_people(tmp_path):
|
||||||
|
registry = EntityRegistry.load(config_dir=tmp_path)
|
||||||
|
registry.seed(
|
||||||
|
mode="personal",
|
||||||
|
people=[
|
||||||
|
{"name": "Riley", "relationship": "daughter", "context": "personal"},
|
||||||
|
{"name": "Devon", "relationship": "friend", "context": "personal"},
|
||||||
|
],
|
||||||
|
projects=["MemPalace"],
|
||||||
|
)
|
||||||
|
assert "Riley" in registry.people
|
||||||
|
assert "Devon" in registry.people
|
||||||
|
assert registry.people["Riley"]["relationship"] == "daughter"
|
||||||
|
assert registry.people["Riley"]["source"] == "onboarding"
|
||||||
|
assert registry.people["Riley"]["confidence"] == 1.0
|
||||||
|
|
||||||
|
|
||||||
|
def test_seed_registers_projects(tmp_path):
|
||||||
|
registry = EntityRegistry.load(config_dir=tmp_path)
|
||||||
|
registry.seed(mode="work", people=[], projects=["Acme", "Widget"])
|
||||||
|
assert registry.projects == ["Acme", "Widget"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_seed_sets_mode(tmp_path):
|
||||||
|
registry = EntityRegistry.load(config_dir=tmp_path)
|
||||||
|
registry.seed(mode="combo", people=[], projects=[])
|
||||||
|
assert registry.mode == "combo"
|
||||||
|
|
||||||
|
|
||||||
|
def test_seed_flags_ambiguous_names(tmp_path):
|
||||||
|
registry = EntityRegistry.load(config_dir=tmp_path)
|
||||||
|
registry.seed(
|
||||||
|
mode="personal",
|
||||||
|
people=[
|
||||||
|
{"name": "Grace", "relationship": "friend", "context": "personal"},
|
||||||
|
{"name": "Riley", "relationship": "daughter", "context": "personal"},
|
||||||
|
],
|
||||||
|
projects=[],
|
||||||
|
)
|
||||||
|
assert "grace" in registry.ambiguous_flags
|
||||||
|
# Riley is not a common English word
|
||||||
|
assert "riley" not in registry.ambiguous_flags
|
||||||
|
|
||||||
|
|
||||||
|
def test_seed_with_aliases(tmp_path):
|
||||||
|
registry = EntityRegistry.load(config_dir=tmp_path)
|
||||||
|
registry.seed(
|
||||||
|
mode="personal",
|
||||||
|
people=[{"name": "Maxwell", "relationship": "friend", "context": "personal"}],
|
||||||
|
projects=[],
|
||||||
|
aliases={"Max": "Maxwell"},
|
||||||
|
)
|
||||||
|
assert "Maxwell" in registry.people
|
||||||
|
assert "Max" in registry.people
|
||||||
|
assert registry.people["Max"].get("canonical") == "Maxwell"
|
||||||
|
|
||||||
|
|
||||||
|
def test_seed_skips_empty_names(tmp_path):
|
||||||
|
registry = EntityRegistry.load(config_dir=tmp_path)
|
||||||
|
registry.seed(
|
||||||
|
mode="personal",
|
||||||
|
people=[{"name": "", "relationship": "", "context": "personal"}],
|
||||||
|
projects=[],
|
||||||
|
)
|
||||||
|
assert len(registry.people) == 0
|
||||||
|
|
||||||
|
|
||||||
|
# ── lookup ──────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_lookup_known_person(tmp_path):
|
||||||
|
registry = EntityRegistry.load(config_dir=tmp_path)
|
||||||
|
registry.seed(
|
||||||
|
mode="personal",
|
||||||
|
people=[{"name": "Riley", "relationship": "daughter", "context": "personal"}],
|
||||||
|
projects=[],
|
||||||
|
)
|
||||||
|
result = registry.lookup("Riley")
|
||||||
|
assert result["type"] == "person"
|
||||||
|
assert result["confidence"] == 1.0
|
||||||
|
assert result["name"] == "Riley"
|
||||||
|
|
||||||
|
|
||||||
|
def test_lookup_known_project(tmp_path):
|
||||||
|
registry = EntityRegistry.load(config_dir=tmp_path)
|
||||||
|
registry.seed(mode="work", people=[], projects=["MemPalace"])
|
||||||
|
result = registry.lookup("MemPalace")
|
||||||
|
assert result["type"] == "project"
|
||||||
|
assert result["confidence"] == 1.0
|
||||||
|
|
||||||
|
|
||||||
|
def test_lookup_unknown_word(tmp_path):
|
||||||
|
registry = EntityRegistry.load(config_dir=tmp_path)
|
||||||
|
registry.seed(mode="personal", people=[], projects=[])
|
||||||
|
result = registry.lookup("Xyzzy")
|
||||||
|
assert result["type"] == "unknown"
|
||||||
|
assert result["confidence"] == 0.0
|
||||||
|
|
||||||
|
|
||||||
|
def test_lookup_case_insensitive(tmp_path):
|
||||||
|
registry = EntityRegistry.load(config_dir=tmp_path)
|
||||||
|
registry.seed(
|
||||||
|
mode="personal",
|
||||||
|
people=[{"name": "Riley", "relationship": "daughter", "context": "personal"}],
|
||||||
|
projects=[],
|
||||||
|
)
|
||||||
|
result = registry.lookup("riley")
|
||||||
|
assert result["type"] == "person"
|
||||||
|
|
||||||
|
|
||||||
|
def test_lookup_alias(tmp_path):
|
||||||
|
registry = EntityRegistry.load(config_dir=tmp_path)
|
||||||
|
registry.seed(
|
||||||
|
mode="personal",
|
||||||
|
people=[{"name": "Maxwell", "relationship": "friend", "context": "personal"}],
|
||||||
|
projects=[],
|
||||||
|
aliases={"Max": "Maxwell"},
|
||||||
|
)
|
||||||
|
result = registry.lookup("Max")
|
||||||
|
assert result["type"] == "person"
|
||||||
|
|
||||||
|
|
||||||
|
# ── disambiguation ──────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_lookup_ambiguous_word_as_person(tmp_path):
|
||||||
|
registry = EntityRegistry.load(config_dir=tmp_path)
|
||||||
|
registry.seed(
|
||||||
|
mode="personal",
|
||||||
|
people=[{"name": "Grace", "relationship": "friend", "context": "personal"}],
|
||||||
|
projects=[],
|
||||||
|
)
|
||||||
|
result = registry.lookup("Grace", context="I went with Grace today")
|
||||||
|
assert result["type"] == "person"
|
||||||
|
|
||||||
|
|
||||||
|
def test_lookup_ambiguous_word_as_concept(tmp_path):
|
||||||
|
registry = EntityRegistry.load(config_dir=tmp_path)
|
||||||
|
registry.seed(
|
||||||
|
mode="personal",
|
||||||
|
people=[{"name": "Ever", "relationship": "friend", "context": "personal"}],
|
||||||
|
projects=[],
|
||||||
|
)
|
||||||
|
result = registry.lookup("Ever", context="have you ever tried this")
|
||||||
|
assert result["type"] == "concept"
|
||||||
|
|
||||||
|
|
||||||
|
# ── research (Wikipedia) — mocked ──────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_research_caches_result(tmp_path):
|
||||||
|
registry = EntityRegistry.load(config_dir=tmp_path)
|
||||||
|
registry.seed(mode="personal", people=[], projects=[])
|
||||||
|
|
||||||
|
mock_result = {
|
||||||
|
"inferred_type": "person",
|
||||||
|
"confidence": 0.80,
|
||||||
|
"wiki_summary": "Saoirse is an Irish given name.",
|
||||||
|
"wiki_title": "Saoirse",
|
||||||
|
}
|
||||||
|
|
||||||
|
with patch("mempalace.entity_registry._wikipedia_lookup", return_value=mock_result):
|
||||||
|
result = registry.research("Saoirse", auto_confirm=True)
|
||||||
|
assert result["inferred_type"] == "person"
|
||||||
|
|
||||||
|
# Second call should use cache, not call Wikipedia again
|
||||||
|
with patch(
|
||||||
|
"mempalace.entity_registry._wikipedia_lookup",
|
||||||
|
side_effect=AssertionError("should not be called"),
|
||||||
|
):
|
||||||
|
cached = registry.research("Saoirse")
|
||||||
|
assert cached["inferred_type"] == "person"
|
||||||
|
|
||||||
|
|
||||||
|
def test_confirm_research_adds_to_people(tmp_path):
|
||||||
|
registry = EntityRegistry.load(config_dir=tmp_path)
|
||||||
|
registry.seed(mode="personal", people=[], projects=[])
|
||||||
|
|
||||||
|
mock_result = {
|
||||||
|
"inferred_type": "person",
|
||||||
|
"confidence": 0.80,
|
||||||
|
"wiki_summary": "Saoirse is a name",
|
||||||
|
"wiki_title": "Saoirse",
|
||||||
|
}
|
||||||
|
with patch("mempalace.entity_registry._wikipedia_lookup", return_value=mock_result):
|
||||||
|
registry.research("Saoirse", auto_confirm=False)
|
||||||
|
|
||||||
|
registry.confirm_research("Saoirse", entity_type="person", relationship="friend")
|
||||||
|
assert "Saoirse" in registry.people
|
||||||
|
assert registry.people["Saoirse"]["source"] == "wiki"
|
||||||
|
|
||||||
|
|
||||||
|
# ── extract_people_from_query ───────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_extract_people_from_query(tmp_path):
|
||||||
|
registry = EntityRegistry.load(config_dir=tmp_path)
|
||||||
|
registry.seed(
|
||||||
|
mode="personal",
|
||||||
|
people=[
|
||||||
|
{"name": "Riley", "relationship": "daughter", "context": "personal"},
|
||||||
|
{"name": "Devon", "relationship": "friend", "context": "personal"},
|
||||||
|
],
|
||||||
|
projects=[],
|
||||||
|
)
|
||||||
|
found = registry.extract_people_from_query("What did Riley say about the weather?")
|
||||||
|
assert "Riley" in found
|
||||||
|
assert "Devon" not in found
|
||||||
|
|
||||||
|
|
||||||
|
# ── extract_unknown_candidates ──────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_extract_unknown_candidates(tmp_path):
|
||||||
|
registry = EntityRegistry.load(config_dir=tmp_path)
|
||||||
|
registry.seed(mode="personal", people=[], projects=[])
|
||||||
|
unknowns = registry.extract_unknown_candidates("Saoirse went to the store")
|
||||||
|
assert "Saoirse" in unknowns
|
||||||
|
|
||||||
|
|
||||||
|
def test_extract_unknown_candidates_skips_known(tmp_path):
|
||||||
|
registry = EntityRegistry.load(config_dir=tmp_path)
|
||||||
|
registry.seed(
|
||||||
|
mode="personal",
|
||||||
|
people=[{"name": "Riley", "relationship": "daughter", "context": "personal"}],
|
||||||
|
projects=[],
|
||||||
|
)
|
||||||
|
unknowns = registry.extract_unknown_candidates("Riley went to the store")
|
||||||
|
assert "Riley" not in unknowns
|
||||||
|
|
||||||
|
|
||||||
|
# ── summary ─────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_summary(tmp_path):
|
||||||
|
registry = EntityRegistry.load(config_dir=tmp_path)
|
||||||
|
registry.seed(
|
||||||
|
mode="personal",
|
||||||
|
people=[{"name": "Riley", "relationship": "daughter", "context": "personal"}],
|
||||||
|
projects=["MemPalace"],
|
||||||
|
)
|
||||||
|
s = registry.summary()
|
||||||
|
assert "personal" in s
|
||||||
|
assert "Riley" in s
|
||||||
|
assert "MemPalace" in s
|
||||||
@@ -0,0 +1,248 @@
|
|||||||
|
"""Tests for mempalace.general_extractor."""
|
||||||
|
|
||||||
|
from mempalace.general_extractor import (
|
||||||
|
ALL_MARKERS,
|
||||||
|
NEGATIVE_WORDS,
|
||||||
|
POSITIVE_WORDS,
|
||||||
|
_extract_prose,
|
||||||
|
_get_sentiment,
|
||||||
|
_has_resolution,
|
||||||
|
_is_code_line,
|
||||||
|
_score_markers,
|
||||||
|
_split_into_segments,
|
||||||
|
extract_memories,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ── extract_memories — empty / no markers ───────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_extract_memories_empty_text():
|
||||||
|
result = extract_memories("")
|
||||||
|
assert result == []
|
||||||
|
|
||||||
|
|
||||||
|
def test_extract_memories_no_markers():
|
||||||
|
result = extract_memories("The quick brown fox jumped over the lazy dog.")
|
||||||
|
assert result == []
|
||||||
|
|
||||||
|
|
||||||
|
def test_extract_memories_short_text_skipped():
|
||||||
|
# Paragraphs shorter than 20 chars are skipped
|
||||||
|
result = extract_memories("ok sure")
|
||||||
|
assert result == []
|
||||||
|
|
||||||
|
|
||||||
|
# ── extract_memories — decision markers ─────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_extract_memories_decision():
|
||||||
|
text = (
|
||||||
|
"We decided to go with PostgreSQL instead of MySQL "
|
||||||
|
"because the performance was better for our use case. "
|
||||||
|
"The trade-off was more complexity in setup."
|
||||||
|
)
|
||||||
|
result = extract_memories(text)
|
||||||
|
assert len(result) >= 1
|
||||||
|
assert any(m["memory_type"] == "decision" for m in result)
|
||||||
|
|
||||||
|
|
||||||
|
# ── extract_memories — preference markers ───────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_extract_memories_preference():
|
||||||
|
text = (
|
||||||
|
"I prefer using snake_case in Python code. "
|
||||||
|
"Please always use type hints. "
|
||||||
|
"Never use wildcard imports."
|
||||||
|
)
|
||||||
|
result = extract_memories(text)
|
||||||
|
assert len(result) >= 1
|
||||||
|
assert any(m["memory_type"] == "preference" for m in result)
|
||||||
|
|
||||||
|
|
||||||
|
# ── extract_memories — milestone markers ────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_extract_memories_milestone():
|
||||||
|
text = (
|
||||||
|
"It finally works! After three days of debugging, "
|
||||||
|
"I figured out the issue. The breakthrough was realizing "
|
||||||
|
"the config file was cached. Got it working at 2am."
|
||||||
|
)
|
||||||
|
result = extract_memories(text)
|
||||||
|
assert len(result) >= 1
|
||||||
|
assert any(m["memory_type"] == "milestone" for m in result)
|
||||||
|
|
||||||
|
|
||||||
|
# ── extract_memories — problem markers ──────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_extract_memories_problem():
|
||||||
|
text = (
|
||||||
|
"There's a critical bug in the auth module. "
|
||||||
|
"The error keeps crashing the server. "
|
||||||
|
"The root cause was a missing null check. "
|
||||||
|
"The problem is that tokens expire silently."
|
||||||
|
)
|
||||||
|
result = extract_memories(text)
|
||||||
|
assert len(result) >= 1
|
||||||
|
types = {m["memory_type"] for m in result}
|
||||||
|
assert "problem" in types or "milestone" in types # resolved problems become milestones
|
||||||
|
|
||||||
|
|
||||||
|
# ── extract_memories — emotional markers ────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_extract_memories_emotional():
|
||||||
|
text = (
|
||||||
|
"I feel so proud of what we built together. "
|
||||||
|
"I love working on this project, it makes me happy. "
|
||||||
|
"I'm grateful for the team and the beautiful code we wrote."
|
||||||
|
)
|
||||||
|
result = extract_memories(text)
|
||||||
|
assert len(result) >= 1
|
||||||
|
assert any(m["memory_type"] == "emotional" for m in result)
|
||||||
|
|
||||||
|
|
||||||
|
# ── extract_memories — chunk_index ──────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_extract_memories_chunk_index_increments():
|
||||||
|
text = (
|
||||||
|
"We decided to use React because it fits our team.\n\n"
|
||||||
|
"I prefer functional components always.\n\n"
|
||||||
|
"It works! We finally shipped the v1.0 release."
|
||||||
|
)
|
||||||
|
result = extract_memories(text)
|
||||||
|
if len(result) >= 2:
|
||||||
|
indices = [m["chunk_index"] for m in result]
|
||||||
|
assert indices == list(range(len(result)))
|
||||||
|
|
||||||
|
|
||||||
|
# ── _score_markers ──────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_score_markers_with_matches():
|
||||||
|
score, keywords = _score_markers(
|
||||||
|
"we decided to go with postgres because it is faster",
|
||||||
|
ALL_MARKERS["decision"],
|
||||||
|
)
|
||||||
|
assert score > 0
|
||||||
|
assert len(keywords) > 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_score_markers_no_matches():
|
||||||
|
score, keywords = _score_markers("nothing relevant here", ALL_MARKERS["decision"])
|
||||||
|
assert score == 0.0
|
||||||
|
|
||||||
|
|
||||||
|
# ── _get_sentiment ──────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_sentiment_positive():
|
||||||
|
assert _get_sentiment("I am so happy and proud of this breakthrough") == "positive"
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_sentiment_negative():
|
||||||
|
assert _get_sentiment("This bug caused a crash and total failure") == "negative"
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_sentiment_neutral():
|
||||||
|
assert _get_sentiment("The meeting is at three") == "neutral"
|
||||||
|
|
||||||
|
|
||||||
|
# ── _has_resolution ─────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_has_resolution_true():
|
||||||
|
assert _has_resolution("I fixed the auth bug and it works now") is True
|
||||||
|
|
||||||
|
|
||||||
|
def test_has_resolution_false():
|
||||||
|
assert _has_resolution("The server keeps crashing") is False
|
||||||
|
|
||||||
|
|
||||||
|
# ── _is_code_line ───────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_is_code_line_detects_code():
|
||||||
|
assert _is_code_line(" import os") is True
|
||||||
|
assert _is_code_line(" $ pip install flask") is True
|
||||||
|
assert _is_code_line(" ```python") is True
|
||||||
|
|
||||||
|
|
||||||
|
def test_is_code_line_allows_prose():
|
||||||
|
assert _is_code_line("This is a regular sentence about coding.") is False
|
||||||
|
assert _is_code_line("") is False
|
||||||
|
|
||||||
|
|
||||||
|
# ── _extract_prose ──────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_extract_prose_strips_code_blocks():
|
||||||
|
text = "Hello world\n```\nimport os\nprint('hi')\n```\nGoodbye"
|
||||||
|
result = _extract_prose(text)
|
||||||
|
assert "import os" not in result
|
||||||
|
assert "Hello world" in result
|
||||||
|
assert "Goodbye" in result
|
||||||
|
|
||||||
|
|
||||||
|
def test_extract_prose_returns_original_if_all_code():
|
||||||
|
text = "import os\nfrom sys import argv"
|
||||||
|
result = _extract_prose(text)
|
||||||
|
# Falls back to original text if nothing left
|
||||||
|
assert len(result) > 0
|
||||||
|
|
||||||
|
|
||||||
|
# ── _split_into_segments ───────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_split_into_segments_by_paragraph():
|
||||||
|
text = "First paragraph.\n\nSecond paragraph.\n\nThird paragraph."
|
||||||
|
result = _split_into_segments(text)
|
||||||
|
assert len(result) == 3
|
||||||
|
|
||||||
|
|
||||||
|
def test_split_into_segments_by_turns():
|
||||||
|
lines = []
|
||||||
|
for i in range(5):
|
||||||
|
lines.append(f"Human: Question {i}")
|
||||||
|
lines.append(f"Assistant: Answer {i}")
|
||||||
|
text = "\n".join(lines)
|
||||||
|
result = _split_into_segments(text)
|
||||||
|
assert len(result) >= 3 # turn-based splitting should fire
|
||||||
|
|
||||||
|
|
||||||
|
def test_split_into_segments_single_block():
|
||||||
|
# Many lines without double-newline produces chunked segments
|
||||||
|
lines = [f"Line {i} of the document" for i in range(30)]
|
||||||
|
text = "\n".join(lines)
|
||||||
|
result = _split_into_segments(text)
|
||||||
|
assert len(result) >= 1
|
||||||
|
|
||||||
|
|
||||||
|
# ── ALL_MARKERS constant ───────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_all_markers_has_five_types():
|
||||||
|
assert set(ALL_MARKERS.keys()) == {
|
||||||
|
"decision",
|
||||||
|
"preference",
|
||||||
|
"milestone",
|
||||||
|
"problem",
|
||||||
|
"emotional",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# ── POSITIVE_WORDS / NEGATIVE_WORDS ────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_positive_words():
|
||||||
|
assert "happy" in POSITIVE_WORDS
|
||||||
|
assert "proud" in POSITIVE_WORDS
|
||||||
|
|
||||||
|
|
||||||
|
def test_negative_words():
|
||||||
|
assert "bug" in NEGATIVE_WORDS
|
||||||
|
assert "crash" in NEGATIVE_WORDS
|
||||||
@@ -1,17 +1,24 @@
|
|||||||
import contextlib
|
import contextlib
|
||||||
|
import io
|
||||||
import json
|
import json
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
from mempalace.hooks_cli import (
|
from mempalace.hooks_cli import (
|
||||||
SAVE_INTERVAL,
|
SAVE_INTERVAL,
|
||||||
STOP_BLOCK_REASON,
|
STOP_BLOCK_REASON,
|
||||||
PRECOMPACT_BLOCK_REASON,
|
PRECOMPACT_BLOCK_REASON,
|
||||||
_count_human_messages,
|
_count_human_messages,
|
||||||
|
_log,
|
||||||
|
_maybe_auto_ingest,
|
||||||
|
_parse_harness_input,
|
||||||
_sanitize_session_id,
|
_sanitize_session_id,
|
||||||
hook_stop,
|
hook_stop,
|
||||||
hook_session_start,
|
hook_session_start,
|
||||||
hook_precompact,
|
hook_precompact,
|
||||||
|
run_hook,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -205,3 +212,209 @@ def test_precompact_always_blocks(tmp_path):
|
|||||||
)
|
)
|
||||||
assert result["decision"] == "block"
|
assert result["decision"] == "block"
|
||||||
assert result["reason"] == PRECOMPACT_BLOCK_REASON
|
assert result["reason"] == PRECOMPACT_BLOCK_REASON
|
||||||
|
|
||||||
|
|
||||||
|
# --- _log ---
|
||||||
|
|
||||||
|
|
||||||
|
def test_log_writes_to_hook_log(tmp_path):
|
||||||
|
with patch("mempalace.hooks_cli.STATE_DIR", tmp_path):
|
||||||
|
_log("test message")
|
||||||
|
log_path = tmp_path / "hook.log"
|
||||||
|
assert log_path.is_file()
|
||||||
|
content = log_path.read_text()
|
||||||
|
assert "test message" in content
|
||||||
|
|
||||||
|
|
||||||
|
def test_log_oserror_is_silenced(tmp_path):
|
||||||
|
"""_log should not raise if the directory cannot be created."""
|
||||||
|
with patch("mempalace.hooks_cli.STATE_DIR", Path("/nonexistent/deeply/nested/dir")):
|
||||||
|
# Should not raise
|
||||||
|
_log("this will fail silently")
|
||||||
|
|
||||||
|
|
||||||
|
# --- _maybe_auto_ingest ---
|
||||||
|
|
||||||
|
|
||||||
|
def test_maybe_auto_ingest_no_env(tmp_path):
|
||||||
|
"""Without MEMPAL_DIR set, does nothing."""
|
||||||
|
with patch.dict("os.environ", {}, clear=True):
|
||||||
|
with patch("mempalace.hooks_cli.STATE_DIR", tmp_path):
|
||||||
|
_maybe_auto_ingest() # should not raise
|
||||||
|
|
||||||
|
|
||||||
|
def test_maybe_auto_ingest_with_env(tmp_path):
|
||||||
|
"""With MEMPAL_DIR set to a valid directory, spawns subprocess."""
|
||||||
|
mempal_dir = tmp_path / "project"
|
||||||
|
mempal_dir.mkdir()
|
||||||
|
with patch.dict("os.environ", {"MEMPAL_DIR": str(mempal_dir)}):
|
||||||
|
with patch("mempalace.hooks_cli.STATE_DIR", tmp_path):
|
||||||
|
with patch("mempalace.hooks_cli.subprocess.Popen") as mock_popen:
|
||||||
|
_maybe_auto_ingest()
|
||||||
|
mock_popen.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
|
def test_maybe_auto_ingest_oserror(tmp_path):
|
||||||
|
"""OSError during subprocess spawn is silenced."""
|
||||||
|
mempal_dir = tmp_path / "project"
|
||||||
|
mempal_dir.mkdir()
|
||||||
|
with patch.dict("os.environ", {"MEMPAL_DIR": str(mempal_dir)}):
|
||||||
|
with patch("mempalace.hooks_cli.STATE_DIR", tmp_path):
|
||||||
|
with patch("mempalace.hooks_cli.subprocess.Popen", side_effect=OSError("fail")):
|
||||||
|
_maybe_auto_ingest() # should not raise
|
||||||
|
|
||||||
|
|
||||||
|
# --- _parse_harness_input ---
|
||||||
|
|
||||||
|
|
||||||
|
def test_parse_harness_input_unknown():
|
||||||
|
"""Unknown harness should sys.exit(1)."""
|
||||||
|
with pytest.raises(SystemExit) as exc_info:
|
||||||
|
_parse_harness_input({"session_id": "test"}, "unknown-harness")
|
||||||
|
assert exc_info.value.code == 1
|
||||||
|
|
||||||
|
|
||||||
|
def test_parse_harness_input_valid():
|
||||||
|
result = _parse_harness_input(
|
||||||
|
{"session_id": "abc-123", "stop_hook_active": True, "transcript_path": "/tmp/t.jsonl"},
|
||||||
|
"claude-code",
|
||||||
|
)
|
||||||
|
assert result["session_id"] == "abc-123"
|
||||||
|
assert result["stop_hook_active"] is True
|
||||||
|
|
||||||
|
|
||||||
|
# --- hook_stop with OSError on write ---
|
||||||
|
|
||||||
|
|
||||||
|
def test_stop_hook_oserror_on_last_save_read(tmp_path):
|
||||||
|
"""When last_save_file has invalid content, falls back to 0."""
|
||||||
|
transcript = tmp_path / "t.jsonl"
|
||||||
|
_write_transcript(
|
||||||
|
transcript,
|
||||||
|
[{"message": {"role": "user", "content": f"msg {i}"}} for i in range(SAVE_INTERVAL)],
|
||||||
|
)
|
||||||
|
# Write invalid content to last save file
|
||||||
|
(tmp_path / "test_last_save").write_text("not_a_number")
|
||||||
|
result = _capture_hook_output(
|
||||||
|
hook_stop,
|
||||||
|
{"session_id": "test", "stop_hook_active": False, "transcript_path": str(transcript)},
|
||||||
|
state_dir=tmp_path,
|
||||||
|
)
|
||||||
|
assert result["decision"] == "block"
|
||||||
|
|
||||||
|
|
||||||
|
def test_stop_hook_oserror_on_write(tmp_path):
|
||||||
|
"""When write to last_save_file fails, hook still outputs correctly."""
|
||||||
|
transcript = tmp_path / "t.jsonl"
|
||||||
|
_write_transcript(
|
||||||
|
transcript,
|
||||||
|
[{"message": {"role": "user", "content": f"msg {i}"}} for i in range(SAVE_INTERVAL)],
|
||||||
|
)
|
||||||
|
|
||||||
|
def bad_write_text(*args, **kwargs):
|
||||||
|
raise OSError("disk full")
|
||||||
|
|
||||||
|
with patch("mempalace.hooks_cli.STATE_DIR", tmp_path):
|
||||||
|
with patch.object(Path, "write_text", bad_write_text):
|
||||||
|
result = _capture_hook_output(
|
||||||
|
hook_stop,
|
||||||
|
{
|
||||||
|
"session_id": "test",
|
||||||
|
"stop_hook_active": False,
|
||||||
|
"transcript_path": str(transcript),
|
||||||
|
},
|
||||||
|
state_dir=tmp_path,
|
||||||
|
)
|
||||||
|
assert result["decision"] == "block"
|
||||||
|
|
||||||
|
|
||||||
|
# --- hook_precompact with MEMPAL_DIR ---
|
||||||
|
|
||||||
|
|
||||||
|
def test_precompact_with_mempal_dir(tmp_path):
|
||||||
|
"""Precompact runs subprocess.run when MEMPAL_DIR is set."""
|
||||||
|
mempal_dir = tmp_path / "project"
|
||||||
|
mempal_dir.mkdir()
|
||||||
|
with patch.dict("os.environ", {"MEMPAL_DIR": str(mempal_dir)}):
|
||||||
|
with patch("mempalace.hooks_cli.subprocess.run") as mock_run:
|
||||||
|
result = _capture_hook_output(
|
||||||
|
hook_precompact,
|
||||||
|
{"session_id": "test"},
|
||||||
|
state_dir=tmp_path,
|
||||||
|
)
|
||||||
|
assert result["decision"] == "block"
|
||||||
|
mock_run.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
|
def test_precompact_with_mempal_dir_oserror(tmp_path):
|
||||||
|
"""Precompact handles OSError from subprocess gracefully."""
|
||||||
|
mempal_dir = tmp_path / "project"
|
||||||
|
mempal_dir.mkdir()
|
||||||
|
with patch.dict("os.environ", {"MEMPAL_DIR": str(mempal_dir)}):
|
||||||
|
with patch("mempalace.hooks_cli.subprocess.run", side_effect=OSError("fail")):
|
||||||
|
result = _capture_hook_output(
|
||||||
|
hook_precompact,
|
||||||
|
{"session_id": "test"},
|
||||||
|
state_dir=tmp_path,
|
||||||
|
)
|
||||||
|
assert result["decision"] == "block"
|
||||||
|
|
||||||
|
|
||||||
|
# --- run_hook ---
|
||||||
|
|
||||||
|
|
||||||
|
def test_run_hook_dispatches_session_start(tmp_path):
|
||||||
|
"""run_hook reads stdin JSON and dispatches to correct handler."""
|
||||||
|
stdin_data = json.dumps({"session_id": "run-test"})
|
||||||
|
with patch("sys.stdin", io.StringIO(stdin_data)):
|
||||||
|
with patch("mempalace.hooks_cli.STATE_DIR", tmp_path):
|
||||||
|
with patch("mempalace.hooks_cli._output") as mock_output:
|
||||||
|
run_hook("session-start", "claude-code")
|
||||||
|
mock_output.assert_called_once_with({})
|
||||||
|
|
||||||
|
|
||||||
|
def test_run_hook_dispatches_stop(tmp_path):
|
||||||
|
transcript = tmp_path / "t.jsonl"
|
||||||
|
_write_transcript(
|
||||||
|
transcript, [{"message": {"role": "user", "content": f"msg {i}"}} for i in range(3)]
|
||||||
|
)
|
||||||
|
stdin_data = json.dumps(
|
||||||
|
{
|
||||||
|
"session_id": "run-test",
|
||||||
|
"stop_hook_active": False,
|
||||||
|
"transcript_path": str(transcript),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
with patch("sys.stdin", io.StringIO(stdin_data)):
|
||||||
|
with patch("mempalace.hooks_cli.STATE_DIR", tmp_path):
|
||||||
|
with patch("mempalace.hooks_cli._output") as mock_output:
|
||||||
|
run_hook("stop", "claude-code")
|
||||||
|
mock_output.assert_called_once_with({})
|
||||||
|
|
||||||
|
|
||||||
|
def test_run_hook_dispatches_precompact(tmp_path):
|
||||||
|
stdin_data = json.dumps({"session_id": "run-test"})
|
||||||
|
with patch("sys.stdin", io.StringIO(stdin_data)):
|
||||||
|
with patch("mempalace.hooks_cli.STATE_DIR", tmp_path):
|
||||||
|
with patch("mempalace.hooks_cli._output") as mock_output:
|
||||||
|
run_hook("precompact", "claude-code")
|
||||||
|
mock_output.assert_called_once()
|
||||||
|
call_args = mock_output.call_args[0][0]
|
||||||
|
assert call_args["decision"] == "block"
|
||||||
|
|
||||||
|
|
||||||
|
def test_run_hook_unknown_hook():
|
||||||
|
stdin_data = json.dumps({"session_id": "test"})
|
||||||
|
with patch("sys.stdin", io.StringIO(stdin_data)):
|
||||||
|
with pytest.raises(SystemExit) as exc_info:
|
||||||
|
run_hook("nonexistent", "claude-code")
|
||||||
|
assert exc_info.value.code == 1
|
||||||
|
|
||||||
|
|
||||||
|
def test_run_hook_invalid_json(tmp_path):
|
||||||
|
"""Invalid stdin JSON should not crash — falls back to empty dict."""
|
||||||
|
with patch("sys.stdin", io.StringIO("not valid json")):
|
||||||
|
with patch("mempalace.hooks_cli.STATE_DIR", tmp_path):
|
||||||
|
with patch("mempalace.hooks_cli._output") as mock_output:
|
||||||
|
run_hook("session-start", "claude-code")
|
||||||
|
mock_output.assert_called_once_with({})
|
||||||
|
|||||||
@@ -0,0 +1,45 @@
|
|||||||
|
"""Tests for mempalace.instructions_cli — instruction text output."""
|
||||||
|
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from mempalace.instructions_cli import AVAILABLE, INSTRUCTIONS_DIR, run_instructions
|
||||||
|
|
||||||
|
|
||||||
|
def test_run_instructions_valid_name(capsys):
|
||||||
|
"""Valid name prints the .md file content."""
|
||||||
|
name = "init"
|
||||||
|
expected = (INSTRUCTIONS_DIR / f"{name}.md").read_text()
|
||||||
|
run_instructions(name)
|
||||||
|
captured = capsys.readouterr()
|
||||||
|
assert captured.out.strip() == expected.strip()
|
||||||
|
|
||||||
|
|
||||||
|
def test_run_instructions_all_available(capsys):
|
||||||
|
"""Every name in AVAILABLE should succeed without error."""
|
||||||
|
for name in AVAILABLE:
|
||||||
|
run_instructions(name)
|
||||||
|
out = capsys.readouterr().out
|
||||||
|
assert len(out) > 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_run_instructions_invalid_name(capsys):
|
||||||
|
"""Invalid name should sys.exit(1) and print error to stderr."""
|
||||||
|
with pytest.raises(SystemExit) as exc_info:
|
||||||
|
run_instructions("nonexistent")
|
||||||
|
assert exc_info.value.code == 1
|
||||||
|
captured = capsys.readouterr()
|
||||||
|
assert "Unknown instructions: nonexistent" in captured.err
|
||||||
|
assert "Available:" in captured.err
|
||||||
|
|
||||||
|
|
||||||
|
def test_run_instructions_missing_md_file(capsys, tmp_path):
|
||||||
|
"""If the .md file is missing on disk, should sys.exit(1)."""
|
||||||
|
with patch("mempalace.instructions_cli.INSTRUCTIONS_DIR", tmp_path):
|
||||||
|
with patch("mempalace.instructions_cli.AVAILABLE", ["fakecmd"]):
|
||||||
|
with pytest.raises(SystemExit) as exc_info:
|
||||||
|
run_instructions("fakecmd")
|
||||||
|
assert exc_info.value.code == 1
|
||||||
|
captured = capsys.readouterr()
|
||||||
|
assert "Instructions file not found" in captured.err
|
||||||
@@ -0,0 +1,105 @@
|
|||||||
|
"""Extra knowledge graph tests for seed_from_entity_facts and query_relationship."""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from mempalace.knowledge_graph import KnowledgeGraph
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def kg(tmp_path):
|
||||||
|
return KnowledgeGraph(db_path=str(tmp_path / "kg.db"))
|
||||||
|
|
||||||
|
|
||||||
|
class TestSeedFromEntityFacts:
|
||||||
|
def test_seed_person_with_partner(self, kg):
|
||||||
|
facts = {
|
||||||
|
"alice": {
|
||||||
|
"full_name": "Alice Smith",
|
||||||
|
"type": "person",
|
||||||
|
"gender": "female",
|
||||||
|
"partner": "bob",
|
||||||
|
"relationship": "husband",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
kg.seed_from_entity_facts(facts)
|
||||||
|
stats = kg.stats()
|
||||||
|
assert stats["entities"] >= 1
|
||||||
|
results = kg.query_entity("Alice Smith", direction="outgoing")
|
||||||
|
predicates = {r["predicate"] for r in results}
|
||||||
|
assert "married_to" in predicates
|
||||||
|
assert "is_partner_of" in predicates
|
||||||
|
|
||||||
|
def test_seed_child(self, kg):
|
||||||
|
facts = {
|
||||||
|
"max": {
|
||||||
|
"full_name": "Max",
|
||||||
|
"type": "person",
|
||||||
|
"birthday": "2015-04-01",
|
||||||
|
"parent": "alice",
|
||||||
|
"relationship": "daughter",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
kg.seed_from_entity_facts(facts)
|
||||||
|
results = kg.query_entity("Max", direction="outgoing")
|
||||||
|
predicates = {r["predicate"] for r in results}
|
||||||
|
assert "child_of" in predicates
|
||||||
|
assert "is_child_of" in predicates
|
||||||
|
|
||||||
|
def test_seed_sibling(self, kg):
|
||||||
|
facts = {
|
||||||
|
"emma": {
|
||||||
|
"full_name": "Emma",
|
||||||
|
"type": "person",
|
||||||
|
"relationship": "brother",
|
||||||
|
"sibling": "max",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
kg.seed_from_entity_facts(facts)
|
||||||
|
results = kg.query_entity("Emma", direction="outgoing")
|
||||||
|
predicates = {r["predicate"] for r in results}
|
||||||
|
assert "is_sibling_of" in predicates
|
||||||
|
|
||||||
|
def test_seed_dog(self, kg):
|
||||||
|
facts = {
|
||||||
|
"rex": {
|
||||||
|
"full_name": "Rex",
|
||||||
|
"type": "animal",
|
||||||
|
"relationship": "dog",
|
||||||
|
"owner": "alice",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
kg.seed_from_entity_facts(facts)
|
||||||
|
results = kg.query_entity("Rex", direction="outgoing")
|
||||||
|
predicates = {r["predicate"] for r in results}
|
||||||
|
assert "is_pet_of" in predicates
|
||||||
|
|
||||||
|
def test_seed_with_interests(self, kg):
|
||||||
|
facts = {
|
||||||
|
"max": {
|
||||||
|
"full_name": "Max",
|
||||||
|
"type": "person",
|
||||||
|
"interests": ["swimming", "chess"],
|
||||||
|
}
|
||||||
|
}
|
||||||
|
kg.seed_from_entity_facts(facts)
|
||||||
|
results = kg.query_entity("Max", direction="outgoing")
|
||||||
|
objects = {r["object"] for r in results if r["predicate"] == "loves"}
|
||||||
|
assert "Swimming" in objects
|
||||||
|
assert "Chess" in objects
|
||||||
|
|
||||||
|
def test_seed_minimal_facts(self, kg):
|
||||||
|
"""Facts with no relationships just create entities."""
|
||||||
|
facts = {"bob": {"full_name": "Bob"}}
|
||||||
|
kg.seed_from_entity_facts(facts)
|
||||||
|
stats = kg.stats()
|
||||||
|
assert stats["entities"] >= 1
|
||||||
|
|
||||||
|
|
||||||
|
class TestQueryRelationshipWithTime:
|
||||||
|
def test_query_relationship_with_as_of(self, kg):
|
||||||
|
kg.add_triple("Alice", "works_at", "Acme", valid_from="2020-01-01", valid_to="2024-12-31")
|
||||||
|
kg.add_triple("Alice", "works_at", "NewCo", valid_from="2025-01-01")
|
||||||
|
results = kg.query_relationship("works_at", as_of="2023-06-01")
|
||||||
|
objects = [r["object"] for r in results]
|
||||||
|
assert "Acme" in objects
|
||||||
|
assert "NewCo" not in objects
|
||||||
@@ -0,0 +1,719 @@
|
|||||||
|
"""Tests for mempalace.layers — Layer0, Layer1, Layer2, Layer3, MemoryStack."""
|
||||||
|
|
||||||
|
import os
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
from mempalace.layers import Layer0, Layer1, Layer2, Layer3, MemoryStack
|
||||||
|
|
||||||
|
|
||||||
|
# ── Layer0 — with identity file ─────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_layer0_reads_identity_file(tmp_path):
|
||||||
|
identity_file = tmp_path / "identity.txt"
|
||||||
|
identity_file.write_text("I am Atlas, a personal AI assistant for Alice.")
|
||||||
|
layer = Layer0(identity_path=str(identity_file))
|
||||||
|
text = layer.render()
|
||||||
|
assert "Atlas" in text
|
||||||
|
assert "Alice" in text
|
||||||
|
|
||||||
|
|
||||||
|
def test_layer0_caches_text(tmp_path):
|
||||||
|
identity_file = tmp_path / "identity.txt"
|
||||||
|
identity_file.write_text("Hello world")
|
||||||
|
layer = Layer0(identity_path=str(identity_file))
|
||||||
|
first = layer.render()
|
||||||
|
identity_file.write_text("Changed content")
|
||||||
|
second = layer.render()
|
||||||
|
assert first == second
|
||||||
|
assert second == "Hello world"
|
||||||
|
|
||||||
|
|
||||||
|
def test_layer0_missing_file_returns_default(tmp_path):
|
||||||
|
missing = str(tmp_path / "nonexistent.txt")
|
||||||
|
layer = Layer0(identity_path=missing)
|
||||||
|
text = layer.render()
|
||||||
|
assert "No identity configured" in text
|
||||||
|
assert "identity.txt" in text
|
||||||
|
|
||||||
|
|
||||||
|
def test_layer0_token_estimate(tmp_path):
|
||||||
|
identity_file = tmp_path / "identity.txt"
|
||||||
|
content = "A" * 400
|
||||||
|
identity_file.write_text(content)
|
||||||
|
layer = Layer0(identity_path=str(identity_file))
|
||||||
|
estimate = layer.token_estimate()
|
||||||
|
assert estimate == 100
|
||||||
|
|
||||||
|
|
||||||
|
def test_layer0_token_estimate_empty(tmp_path):
|
||||||
|
identity_file = tmp_path / "identity.txt"
|
||||||
|
identity_file.write_text("")
|
||||||
|
layer = Layer0(identity_path=str(identity_file))
|
||||||
|
assert layer.token_estimate() == 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_layer0_strips_whitespace(tmp_path):
|
||||||
|
identity_file = tmp_path / "identity.txt"
|
||||||
|
identity_file.write_text(" Hello world \n\n")
|
||||||
|
layer = Layer0(identity_path=str(identity_file))
|
||||||
|
text = layer.render()
|
||||||
|
assert text == "Hello world"
|
||||||
|
|
||||||
|
|
||||||
|
def test_layer0_default_path():
|
||||||
|
layer = Layer0()
|
||||||
|
expected = os.path.expanduser("~/.mempalace/identity.txt")
|
||||||
|
assert layer.path == expected
|
||||||
|
|
||||||
|
|
||||||
|
# ── Layer1 — mocked chromadb ────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def _mock_chromadb_for_layer(docs, metas, monkeypatch=None):
|
||||||
|
"""Return a mock PersistentClient whose collection.get returns docs/metas."""
|
||||||
|
mock_col = MagicMock()
|
||||||
|
# First batch returns data, second batch returns empty (end of pagination)
|
||||||
|
mock_col.get.side_effect = [
|
||||||
|
{"documents": docs, "metadatas": metas},
|
||||||
|
{"documents": [], "metadatas": []},
|
||||||
|
]
|
||||||
|
mock_client = MagicMock()
|
||||||
|
mock_client.get_collection.return_value = mock_col
|
||||||
|
return mock_client
|
||||||
|
|
||||||
|
|
||||||
|
def test_layer1_no_palace():
|
||||||
|
"""Layer1 returns helpful message when no palace exists."""
|
||||||
|
with patch("mempalace.layers.MempalaceConfig") as mock_cfg:
|
||||||
|
mock_cfg.return_value.palace_path = "/nonexistent/palace"
|
||||||
|
layer = Layer1(palace_path="/nonexistent/palace")
|
||||||
|
result = layer.generate()
|
||||||
|
assert "No palace found" in result or "No memories" in result
|
||||||
|
|
||||||
|
|
||||||
|
def test_layer1_generates_essential_story():
|
||||||
|
docs = [
|
||||||
|
"Important memory about project decisions",
|
||||||
|
"Key architectural choice for the backend",
|
||||||
|
]
|
||||||
|
metas = [
|
||||||
|
{"room": "decisions", "source_file": "meeting.txt", "importance": 5},
|
||||||
|
{"room": "architecture", "source_file": "design.txt", "importance": 4},
|
||||||
|
]
|
||||||
|
mock_client = _mock_chromadb_for_layer(docs, metas)
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch("mempalace.layers.MempalaceConfig") as mock_cfg,
|
||||||
|
patch("mempalace.layers.chromadb.PersistentClient", return_value=mock_client),
|
||||||
|
):
|
||||||
|
mock_cfg.return_value.palace_path = "/fake"
|
||||||
|
layer = Layer1(palace_path="/fake")
|
||||||
|
result = layer.generate()
|
||||||
|
|
||||||
|
assert "ESSENTIAL STORY" in result
|
||||||
|
assert "project decisions" in result
|
||||||
|
|
||||||
|
|
||||||
|
def test_layer1_empty_palace():
|
||||||
|
mock_col = MagicMock()
|
||||||
|
mock_col.get.return_value = {"documents": [], "metadatas": []}
|
||||||
|
mock_client = MagicMock()
|
||||||
|
mock_client.get_collection.return_value = mock_col
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch("mempalace.layers.MempalaceConfig") as mock_cfg,
|
||||||
|
patch("mempalace.layers.chromadb.PersistentClient", return_value=mock_client),
|
||||||
|
):
|
||||||
|
mock_cfg.return_value.palace_path = "/fake"
|
||||||
|
layer = Layer1(palace_path="/fake")
|
||||||
|
result = layer.generate()
|
||||||
|
|
||||||
|
assert "No memories" in result
|
||||||
|
|
||||||
|
|
||||||
|
def test_layer1_with_wing_filter():
|
||||||
|
docs = ["Memory about project X"]
|
||||||
|
metas = [{"room": "general", "source_file": "x.txt", "importance": 3}]
|
||||||
|
mock_client = _mock_chromadb_for_layer(docs, metas)
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch("mempalace.layers.MempalaceConfig") as mock_cfg,
|
||||||
|
patch("mempalace.layers.chromadb.PersistentClient", return_value=mock_client),
|
||||||
|
):
|
||||||
|
mock_cfg.return_value.palace_path = "/fake"
|
||||||
|
layer = Layer1(palace_path="/fake", wing="project_x")
|
||||||
|
result = layer.generate()
|
||||||
|
|
||||||
|
assert "ESSENTIAL STORY" in result
|
||||||
|
# Verify wing filter was passed
|
||||||
|
call_kwargs = mock_client.get_collection.return_value.get.call_args_list[0][1]
|
||||||
|
assert call_kwargs.get("where") == {"wing": "project_x"}
|
||||||
|
|
||||||
|
|
||||||
|
def test_layer1_truncates_long_snippets():
|
||||||
|
docs = ["A" * 300]
|
||||||
|
metas = [{"room": "general", "source_file": "long.txt"}]
|
||||||
|
mock_client = _mock_chromadb_for_layer(docs, metas)
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch("mempalace.layers.MempalaceConfig") as mock_cfg,
|
||||||
|
patch("mempalace.layers.chromadb.PersistentClient", return_value=mock_client),
|
||||||
|
):
|
||||||
|
mock_cfg.return_value.palace_path = "/fake"
|
||||||
|
layer = Layer1(palace_path="/fake")
|
||||||
|
result = layer.generate()
|
||||||
|
|
||||||
|
assert "..." in result
|
||||||
|
|
||||||
|
|
||||||
|
def test_layer1_respects_max_chars():
|
||||||
|
"""L1 stops adding entries once MAX_CHARS is reached."""
|
||||||
|
docs = [f"Memory number {i} with substantial content padding here" for i in range(30)]
|
||||||
|
metas = [{"room": "general", "source_file": f"f{i}.txt", "importance": 5} for i in range(30)]
|
||||||
|
mock_client = _mock_chromadb_for_layer(docs, metas)
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch("mempalace.layers.MempalaceConfig") as mock_cfg,
|
||||||
|
patch("mempalace.layers.chromadb.PersistentClient", return_value=mock_client),
|
||||||
|
):
|
||||||
|
mock_cfg.return_value.palace_path = "/fake"
|
||||||
|
layer = Layer1(palace_path="/fake")
|
||||||
|
layer.MAX_CHARS = 200 # Very low cap to trigger truncation
|
||||||
|
result = layer.generate()
|
||||||
|
|
||||||
|
assert "more in L3 search" in result
|
||||||
|
|
||||||
|
|
||||||
|
def test_layer1_importance_from_various_keys():
|
||||||
|
"""Layer1 tries importance, emotional_weight, weight keys."""
|
||||||
|
docs = ["mem1", "mem2", "mem3"]
|
||||||
|
metas = [
|
||||||
|
{"room": "r", "emotional_weight": 5},
|
||||||
|
{"room": "r", "weight": 1},
|
||||||
|
{"room": "r"}, # no weight key, defaults to 3
|
||||||
|
]
|
||||||
|
mock_client = _mock_chromadb_for_layer(docs, metas)
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch("mempalace.layers.MempalaceConfig") as mock_cfg,
|
||||||
|
patch("mempalace.layers.chromadb.PersistentClient", return_value=mock_client),
|
||||||
|
):
|
||||||
|
mock_cfg.return_value.palace_path = "/fake"
|
||||||
|
layer = Layer1(palace_path="/fake")
|
||||||
|
result = layer.generate()
|
||||||
|
|
||||||
|
assert "ESSENTIAL STORY" in result
|
||||||
|
|
||||||
|
|
||||||
|
def test_layer1_batch_exception_breaks():
|
||||||
|
"""If col.get raises on a batch, loop breaks gracefully."""
|
||||||
|
mock_col = MagicMock()
|
||||||
|
mock_col.get.side_effect = [
|
||||||
|
{"documents": ["doc1"], "metadatas": [{"room": "r"}]},
|
||||||
|
RuntimeError("batch error"),
|
||||||
|
]
|
||||||
|
mock_client = MagicMock()
|
||||||
|
mock_client.get_collection.return_value = mock_col
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch("mempalace.layers.MempalaceConfig") as mock_cfg,
|
||||||
|
patch("mempalace.layers.chromadb.PersistentClient", return_value=mock_client),
|
||||||
|
):
|
||||||
|
mock_cfg.return_value.palace_path = "/fake"
|
||||||
|
layer = Layer1(palace_path="/fake")
|
||||||
|
result = layer.generate()
|
||||||
|
|
||||||
|
assert "ESSENTIAL STORY" in result
|
||||||
|
|
||||||
|
|
||||||
|
# ── Layer2 — mocked chromadb ────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_layer2_no_palace():
|
||||||
|
with patch("mempalace.layers.MempalaceConfig") as mock_cfg:
|
||||||
|
mock_cfg.return_value.palace_path = "/nonexistent/palace"
|
||||||
|
layer = Layer2(palace_path="/nonexistent/palace")
|
||||||
|
result = layer.retrieve(wing="test")
|
||||||
|
assert "No palace found" in result
|
||||||
|
|
||||||
|
|
||||||
|
def test_layer2_retrieve_with_wing():
|
||||||
|
mock_col = MagicMock()
|
||||||
|
mock_col.get.return_value = {
|
||||||
|
"documents": ["Some memory about the project"],
|
||||||
|
"metadatas": [{"room": "backend", "source_file": "notes.txt"}],
|
||||||
|
}
|
||||||
|
mock_client = MagicMock()
|
||||||
|
mock_client.get_collection.return_value = mock_col
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch("mempalace.layers.MempalaceConfig") as mock_cfg,
|
||||||
|
patch("mempalace.layers.chromadb.PersistentClient", return_value=mock_client),
|
||||||
|
):
|
||||||
|
mock_cfg.return_value.palace_path = "/fake"
|
||||||
|
layer = Layer2(palace_path="/fake")
|
||||||
|
result = layer.retrieve(wing="project")
|
||||||
|
|
||||||
|
assert "ON-DEMAND" in result
|
||||||
|
assert "memory about the project" in result
|
||||||
|
|
||||||
|
|
||||||
|
def test_layer2_retrieve_with_room():
|
||||||
|
mock_col = MagicMock()
|
||||||
|
mock_col.get.return_value = {
|
||||||
|
"documents": ["Backend architecture notes"],
|
||||||
|
"metadatas": [{"room": "architecture", "source_file": "arch.txt"}],
|
||||||
|
}
|
||||||
|
mock_client = MagicMock()
|
||||||
|
mock_client.get_collection.return_value = mock_col
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch("mempalace.layers.MempalaceConfig") as mock_cfg,
|
||||||
|
patch("mempalace.layers.chromadb.PersistentClient", return_value=mock_client),
|
||||||
|
):
|
||||||
|
mock_cfg.return_value.palace_path = "/fake"
|
||||||
|
layer = Layer2(palace_path="/fake")
|
||||||
|
result = layer.retrieve(room="architecture")
|
||||||
|
|
||||||
|
assert "ON-DEMAND" in result
|
||||||
|
|
||||||
|
|
||||||
|
def test_layer2_retrieve_wing_and_room():
|
||||||
|
mock_col = MagicMock()
|
||||||
|
mock_col.get.return_value = {
|
||||||
|
"documents": ["Filtered result"],
|
||||||
|
"metadatas": [{"room": "backend", "source_file": "x.txt"}],
|
||||||
|
}
|
||||||
|
mock_client = MagicMock()
|
||||||
|
mock_client.get_collection.return_value = mock_col
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch("mempalace.layers.MempalaceConfig") as mock_cfg,
|
||||||
|
patch("mempalace.layers.chromadb.PersistentClient", return_value=mock_client),
|
||||||
|
):
|
||||||
|
mock_cfg.return_value.palace_path = "/fake"
|
||||||
|
layer = Layer2(palace_path="/fake")
|
||||||
|
result = layer.retrieve(wing="proj", room="backend")
|
||||||
|
|
||||||
|
assert "ON-DEMAND" in result
|
||||||
|
call_kwargs = mock_col.get.call_args[1]
|
||||||
|
assert "$and" in call_kwargs.get("where", {})
|
||||||
|
|
||||||
|
|
||||||
|
def test_layer2_retrieve_empty():
|
||||||
|
mock_col = MagicMock()
|
||||||
|
mock_col.get.return_value = {"documents": [], "metadatas": []}
|
||||||
|
mock_client = MagicMock()
|
||||||
|
mock_client.get_collection.return_value = mock_col
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch("mempalace.layers.MempalaceConfig") as mock_cfg,
|
||||||
|
patch("mempalace.layers.chromadb.PersistentClient", return_value=mock_client),
|
||||||
|
):
|
||||||
|
mock_cfg.return_value.palace_path = "/fake"
|
||||||
|
layer = Layer2(palace_path="/fake")
|
||||||
|
result = layer.retrieve(wing="missing")
|
||||||
|
|
||||||
|
assert "No drawers found" in result
|
||||||
|
|
||||||
|
|
||||||
|
def test_layer2_retrieve_no_filter():
|
||||||
|
mock_col = MagicMock()
|
||||||
|
mock_col.get.return_value = {"documents": [], "metadatas": []}
|
||||||
|
mock_client = MagicMock()
|
||||||
|
mock_client.get_collection.return_value = mock_col
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch("mempalace.layers.MempalaceConfig") as mock_cfg,
|
||||||
|
patch("mempalace.layers.chromadb.PersistentClient", return_value=mock_client),
|
||||||
|
):
|
||||||
|
mock_cfg.return_value.palace_path = "/fake"
|
||||||
|
layer = Layer2(palace_path="/fake")
|
||||||
|
layer.retrieve()
|
||||||
|
|
||||||
|
# No where filter should be passed
|
||||||
|
call_kwargs = mock_col.get.call_args[1]
|
||||||
|
assert "where" not in call_kwargs
|
||||||
|
|
||||||
|
|
||||||
|
def test_layer2_retrieve_error():
|
||||||
|
mock_col = MagicMock()
|
||||||
|
mock_col.get.side_effect = RuntimeError("db error")
|
||||||
|
mock_client = MagicMock()
|
||||||
|
mock_client.get_collection.return_value = mock_col
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch("mempalace.layers.MempalaceConfig") as mock_cfg,
|
||||||
|
patch("mempalace.layers.chromadb.PersistentClient", return_value=mock_client),
|
||||||
|
):
|
||||||
|
mock_cfg.return_value.palace_path = "/fake"
|
||||||
|
layer = Layer2(palace_path="/fake")
|
||||||
|
result = layer.retrieve(wing="test")
|
||||||
|
|
||||||
|
assert "Retrieval error" in result
|
||||||
|
|
||||||
|
|
||||||
|
def test_layer2_truncates_long_snippets():
|
||||||
|
mock_col = MagicMock()
|
||||||
|
mock_col.get.return_value = {
|
||||||
|
"documents": ["B" * 400],
|
||||||
|
"metadatas": [{"room": "r", "source_file": "s.txt"}],
|
||||||
|
}
|
||||||
|
mock_client = MagicMock()
|
||||||
|
mock_client.get_collection.return_value = mock_col
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch("mempalace.layers.MempalaceConfig") as mock_cfg,
|
||||||
|
patch("mempalace.layers.chromadb.PersistentClient", return_value=mock_client),
|
||||||
|
):
|
||||||
|
mock_cfg.return_value.palace_path = "/fake"
|
||||||
|
layer = Layer2(palace_path="/fake")
|
||||||
|
result = layer.retrieve(wing="test")
|
||||||
|
|
||||||
|
assert "..." in result
|
||||||
|
|
||||||
|
|
||||||
|
# ── Layer3 — mocked chromadb ────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def _mock_query_results(docs, metas, dists):
|
||||||
|
return {
|
||||||
|
"documents": [docs],
|
||||||
|
"metadatas": [metas],
|
||||||
|
"distances": [dists],
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def test_layer3_no_palace():
|
||||||
|
with patch("mempalace.layers.MempalaceConfig") as mock_cfg:
|
||||||
|
mock_cfg.return_value.palace_path = "/nonexistent/palace"
|
||||||
|
layer = Layer3(palace_path="/nonexistent/palace")
|
||||||
|
result = layer.search("test query")
|
||||||
|
assert "No palace found" in result
|
||||||
|
|
||||||
|
|
||||||
|
def test_layer3_search_raw_no_palace():
|
||||||
|
with patch("mempalace.layers.MempalaceConfig") as mock_cfg:
|
||||||
|
mock_cfg.return_value.palace_path = "/nonexistent/palace"
|
||||||
|
layer = Layer3(palace_path="/nonexistent/palace")
|
||||||
|
result = layer.search_raw("test query")
|
||||||
|
assert result == []
|
||||||
|
|
||||||
|
|
||||||
|
def test_layer3_search_with_results():
|
||||||
|
mock_col = MagicMock()
|
||||||
|
mock_col.query.return_value = _mock_query_results(
|
||||||
|
["Found this important memory"],
|
||||||
|
[{"wing": "project", "room": "backend", "source_file": "notes.txt"}],
|
||||||
|
[0.2],
|
||||||
|
)
|
||||||
|
mock_client = MagicMock()
|
||||||
|
mock_client.get_collection.return_value = mock_col
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch("mempalace.layers.MempalaceConfig") as mock_cfg,
|
||||||
|
patch("mempalace.layers.chromadb.PersistentClient", return_value=mock_client),
|
||||||
|
):
|
||||||
|
mock_cfg.return_value.palace_path = "/fake"
|
||||||
|
layer = Layer3(palace_path="/fake")
|
||||||
|
result = layer.search("important")
|
||||||
|
|
||||||
|
assert "SEARCH RESULTS" in result
|
||||||
|
assert "important memory" in result
|
||||||
|
assert "sim=0.8" in result
|
||||||
|
|
||||||
|
|
||||||
|
def test_layer3_search_no_results():
|
||||||
|
mock_col = MagicMock()
|
||||||
|
mock_col.query.return_value = _mock_query_results([], [], [])
|
||||||
|
mock_client = MagicMock()
|
||||||
|
mock_client.get_collection.return_value = mock_col
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch("mempalace.layers.MempalaceConfig") as mock_cfg,
|
||||||
|
patch("mempalace.layers.chromadb.PersistentClient", return_value=mock_client),
|
||||||
|
):
|
||||||
|
mock_cfg.return_value.palace_path = "/fake"
|
||||||
|
layer = Layer3(palace_path="/fake")
|
||||||
|
result = layer.search("nothing")
|
||||||
|
|
||||||
|
assert "No results found" in result
|
||||||
|
|
||||||
|
|
||||||
|
def test_layer3_search_with_wing_filter():
|
||||||
|
mock_col = MagicMock()
|
||||||
|
mock_col.query.return_value = _mock_query_results(
|
||||||
|
["result"],
|
||||||
|
[{"wing": "proj", "room": "r"}],
|
||||||
|
[0.1],
|
||||||
|
)
|
||||||
|
mock_client = MagicMock()
|
||||||
|
mock_client.get_collection.return_value = mock_col
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch("mempalace.layers.MempalaceConfig") as mock_cfg,
|
||||||
|
patch("mempalace.layers.chromadb.PersistentClient", return_value=mock_client),
|
||||||
|
):
|
||||||
|
mock_cfg.return_value.palace_path = "/fake"
|
||||||
|
layer = Layer3(palace_path="/fake")
|
||||||
|
layer.search("q", wing="proj")
|
||||||
|
|
||||||
|
call_kwargs = mock_col.query.call_args[1]
|
||||||
|
assert call_kwargs["where"] == {"wing": "proj"}
|
||||||
|
|
||||||
|
|
||||||
|
def test_layer3_search_with_room_filter():
|
||||||
|
mock_col = MagicMock()
|
||||||
|
mock_col.query.return_value = _mock_query_results(
|
||||||
|
["result"],
|
||||||
|
[{"wing": "w", "room": "backend"}],
|
||||||
|
[0.1],
|
||||||
|
)
|
||||||
|
mock_client = MagicMock()
|
||||||
|
mock_client.get_collection.return_value = mock_col
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch("mempalace.layers.MempalaceConfig") as mock_cfg,
|
||||||
|
patch("mempalace.layers.chromadb.PersistentClient", return_value=mock_client),
|
||||||
|
):
|
||||||
|
mock_cfg.return_value.palace_path = "/fake"
|
||||||
|
layer = Layer3(palace_path="/fake")
|
||||||
|
layer.search("q", room="backend")
|
||||||
|
|
||||||
|
call_kwargs = mock_col.query.call_args[1]
|
||||||
|
assert call_kwargs["where"] == {"room": "backend"}
|
||||||
|
|
||||||
|
|
||||||
|
def test_layer3_search_with_wing_and_room():
|
||||||
|
mock_col = MagicMock()
|
||||||
|
mock_col.query.return_value = _mock_query_results(
|
||||||
|
["result"],
|
||||||
|
[{"wing": "proj", "room": "backend"}],
|
||||||
|
[0.1],
|
||||||
|
)
|
||||||
|
mock_client = MagicMock()
|
||||||
|
mock_client.get_collection.return_value = mock_col
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch("mempalace.layers.MempalaceConfig") as mock_cfg,
|
||||||
|
patch("mempalace.layers.chromadb.PersistentClient", return_value=mock_client),
|
||||||
|
):
|
||||||
|
mock_cfg.return_value.palace_path = "/fake"
|
||||||
|
layer = Layer3(palace_path="/fake")
|
||||||
|
layer.search("q", wing="proj", room="backend")
|
||||||
|
|
||||||
|
call_kwargs = mock_col.query.call_args[1]
|
||||||
|
assert "$and" in call_kwargs["where"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_layer3_search_error():
|
||||||
|
mock_col = MagicMock()
|
||||||
|
mock_col.query.side_effect = RuntimeError("search failed")
|
||||||
|
mock_client = MagicMock()
|
||||||
|
mock_client.get_collection.return_value = mock_col
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch("mempalace.layers.MempalaceConfig") as mock_cfg,
|
||||||
|
patch("mempalace.layers.chromadb.PersistentClient", return_value=mock_client),
|
||||||
|
):
|
||||||
|
mock_cfg.return_value.palace_path = "/fake"
|
||||||
|
layer = Layer3(palace_path="/fake")
|
||||||
|
result = layer.search("q")
|
||||||
|
|
||||||
|
assert "Search error" in result
|
||||||
|
|
||||||
|
|
||||||
|
def test_layer3_search_truncates_long_docs():
|
||||||
|
mock_col = MagicMock()
|
||||||
|
mock_col.query.return_value = _mock_query_results(
|
||||||
|
["C" * 400],
|
||||||
|
[{"wing": "w", "room": "r", "source_file": "s.txt"}],
|
||||||
|
[0.1],
|
||||||
|
)
|
||||||
|
mock_client = MagicMock()
|
||||||
|
mock_client.get_collection.return_value = mock_col
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch("mempalace.layers.MempalaceConfig") as mock_cfg,
|
||||||
|
patch("mempalace.layers.chromadb.PersistentClient", return_value=mock_client),
|
||||||
|
):
|
||||||
|
mock_cfg.return_value.palace_path = "/fake"
|
||||||
|
layer = Layer3(palace_path="/fake")
|
||||||
|
result = layer.search("q")
|
||||||
|
|
||||||
|
assert "..." in result
|
||||||
|
|
||||||
|
|
||||||
|
def test_layer3_search_raw_returns_dicts():
|
||||||
|
mock_col = MagicMock()
|
||||||
|
mock_col.query.return_value = _mock_query_results(
|
||||||
|
["doc text"],
|
||||||
|
[{"wing": "proj", "room": "backend", "source_file": "f.txt"}],
|
||||||
|
[0.3],
|
||||||
|
)
|
||||||
|
mock_client = MagicMock()
|
||||||
|
mock_client.get_collection.return_value = mock_col
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch("mempalace.layers.MempalaceConfig") as mock_cfg,
|
||||||
|
patch("mempalace.layers.chromadb.PersistentClient", return_value=mock_client),
|
||||||
|
):
|
||||||
|
mock_cfg.return_value.palace_path = "/fake"
|
||||||
|
layer = Layer3(palace_path="/fake")
|
||||||
|
hits = layer.search_raw("q")
|
||||||
|
|
||||||
|
assert len(hits) == 1
|
||||||
|
assert hits[0]["text"] == "doc text"
|
||||||
|
assert hits[0]["wing"] == "proj"
|
||||||
|
assert hits[0]["similarity"] == 0.7
|
||||||
|
assert "metadata" in hits[0]
|
||||||
|
|
||||||
|
|
||||||
|
def test_layer3_search_raw_with_filters():
|
||||||
|
mock_col = MagicMock()
|
||||||
|
mock_col.query.return_value = _mock_query_results(
|
||||||
|
["doc"],
|
||||||
|
[{"wing": "w", "room": "r"}],
|
||||||
|
[0.1],
|
||||||
|
)
|
||||||
|
mock_client = MagicMock()
|
||||||
|
mock_client.get_collection.return_value = mock_col
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch("mempalace.layers.MempalaceConfig") as mock_cfg,
|
||||||
|
patch("mempalace.layers.chromadb.PersistentClient", return_value=mock_client),
|
||||||
|
):
|
||||||
|
mock_cfg.return_value.palace_path = "/fake"
|
||||||
|
layer = Layer3(palace_path="/fake")
|
||||||
|
layer.search_raw("q", wing="w", room="r")
|
||||||
|
|
||||||
|
call_kwargs = mock_col.query.call_args[1]
|
||||||
|
assert "$and" in call_kwargs["where"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_layer3_search_raw_error():
|
||||||
|
mock_col = MagicMock()
|
||||||
|
mock_col.query.side_effect = RuntimeError("fail")
|
||||||
|
mock_client = MagicMock()
|
||||||
|
mock_client.get_collection.return_value = mock_col
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch("mempalace.layers.MempalaceConfig") as mock_cfg,
|
||||||
|
patch("mempalace.layers.chromadb.PersistentClient", return_value=mock_client),
|
||||||
|
):
|
||||||
|
mock_cfg.return_value.palace_path = "/fake"
|
||||||
|
layer = Layer3(palace_path="/fake")
|
||||||
|
result = layer.search_raw("q")
|
||||||
|
|
||||||
|
assert result == []
|
||||||
|
|
||||||
|
|
||||||
|
# ── MemoryStack ─────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_memory_stack_wake_up(tmp_path):
|
||||||
|
identity_file = tmp_path / "identity.txt"
|
||||||
|
identity_file.write_text("I am Atlas.")
|
||||||
|
|
||||||
|
with patch("mempalace.layers.MempalaceConfig") as mock_cfg:
|
||||||
|
mock_cfg.return_value.palace_path = "/nonexistent"
|
||||||
|
stack = MemoryStack(
|
||||||
|
palace_path="/nonexistent",
|
||||||
|
identity_path=str(identity_file),
|
||||||
|
)
|
||||||
|
result = stack.wake_up()
|
||||||
|
|
||||||
|
assert "Atlas" in result
|
||||||
|
# L1 will say no palace found
|
||||||
|
assert "No palace" in result or "No memories" in result
|
||||||
|
|
||||||
|
|
||||||
|
def test_memory_stack_wake_up_with_wing(tmp_path):
|
||||||
|
identity_file = tmp_path / "identity.txt"
|
||||||
|
identity_file.write_text("I am Atlas.")
|
||||||
|
|
||||||
|
with patch("mempalace.layers.MempalaceConfig") as mock_cfg:
|
||||||
|
mock_cfg.return_value.palace_path = "/nonexistent"
|
||||||
|
stack = MemoryStack(
|
||||||
|
palace_path="/nonexistent",
|
||||||
|
identity_path=str(identity_file),
|
||||||
|
)
|
||||||
|
result = stack.wake_up(wing="my_project")
|
||||||
|
|
||||||
|
assert stack.l1.wing == "my_project"
|
||||||
|
assert "Atlas" in result
|
||||||
|
|
||||||
|
|
||||||
|
def test_memory_stack_recall(tmp_path):
|
||||||
|
identity_file = tmp_path / "identity.txt"
|
||||||
|
identity_file.write_text("I am Atlas.")
|
||||||
|
|
||||||
|
with patch("mempalace.layers.MempalaceConfig") as mock_cfg:
|
||||||
|
mock_cfg.return_value.palace_path = "/nonexistent"
|
||||||
|
stack = MemoryStack(
|
||||||
|
palace_path="/nonexistent",
|
||||||
|
identity_path=str(identity_file),
|
||||||
|
)
|
||||||
|
result = stack.recall(wing="test")
|
||||||
|
|
||||||
|
assert "No palace found" in result
|
||||||
|
|
||||||
|
|
||||||
|
def test_memory_stack_search(tmp_path):
|
||||||
|
identity_file = tmp_path / "identity.txt"
|
||||||
|
identity_file.write_text("I am Atlas.")
|
||||||
|
|
||||||
|
with patch("mempalace.layers.MempalaceConfig") as mock_cfg:
|
||||||
|
mock_cfg.return_value.palace_path = "/nonexistent"
|
||||||
|
stack = MemoryStack(
|
||||||
|
palace_path="/nonexistent",
|
||||||
|
identity_path=str(identity_file),
|
||||||
|
)
|
||||||
|
result = stack.search("test query")
|
||||||
|
|
||||||
|
assert "No palace found" in result
|
||||||
|
|
||||||
|
|
||||||
|
def test_memory_stack_status(tmp_path):
|
||||||
|
identity_file = tmp_path / "identity.txt"
|
||||||
|
identity_file.write_text("I am Atlas.")
|
||||||
|
|
||||||
|
with patch("mempalace.layers.MempalaceConfig") as mock_cfg:
|
||||||
|
mock_cfg.return_value.palace_path = "/nonexistent"
|
||||||
|
stack = MemoryStack(
|
||||||
|
palace_path="/nonexistent",
|
||||||
|
identity_path=str(identity_file),
|
||||||
|
)
|
||||||
|
result = stack.status()
|
||||||
|
|
||||||
|
assert result["palace_path"] == "/nonexistent"
|
||||||
|
assert result["total_drawers"] == 0
|
||||||
|
assert "L0_identity" in result
|
||||||
|
assert "L1_essential" in result
|
||||||
|
assert "L2_on_demand" in result
|
||||||
|
assert "L3_deep_search" in result
|
||||||
|
|
||||||
|
|
||||||
|
def test_memory_stack_status_with_palace(tmp_path):
|
||||||
|
identity_file = tmp_path / "identity.txt"
|
||||||
|
identity_file.write_text("I am Atlas.")
|
||||||
|
|
||||||
|
mock_col = MagicMock()
|
||||||
|
mock_col.count.return_value = 42
|
||||||
|
mock_client = MagicMock()
|
||||||
|
mock_client.get_collection.return_value = mock_col
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch("mempalace.layers.MempalaceConfig") as mock_cfg,
|
||||||
|
patch("mempalace.layers.chromadb.PersistentClient", return_value=mock_client),
|
||||||
|
):
|
||||||
|
mock_cfg.return_value.palace_path = "/fake"
|
||||||
|
stack = MemoryStack(
|
||||||
|
palace_path="/fake",
|
||||||
|
identity_path=str(identity_file),
|
||||||
|
)
|
||||||
|
result = stack.status()
|
||||||
|
|
||||||
|
assert result["total_drawers"] == 42
|
||||||
|
assert result["L0_identity"]["exists"] is True
|
||||||
@@ -42,6 +42,50 @@ class TestHandleRequest:
|
|||||||
assert resp["result"]["serverInfo"]["name"] == "mempalace"
|
assert resp["result"]["serverInfo"]["name"] == "mempalace"
|
||||||
assert resp["id"] == 1
|
assert resp["id"] == 1
|
||||||
|
|
||||||
|
def test_initialize_negotiates_client_version(self):
|
||||||
|
from mempalace.mcp_server import handle_request
|
||||||
|
|
||||||
|
resp = handle_request(
|
||||||
|
{
|
||||||
|
"method": "initialize",
|
||||||
|
"id": 1,
|
||||||
|
"params": {"protocolVersion": "2025-11-25"},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
assert resp["result"]["protocolVersion"] == "2025-11-25"
|
||||||
|
|
||||||
|
def test_initialize_negotiates_older_supported_version(self):
|
||||||
|
from mempalace.mcp_server import handle_request
|
||||||
|
|
||||||
|
resp = handle_request(
|
||||||
|
{
|
||||||
|
"method": "initialize",
|
||||||
|
"id": 1,
|
||||||
|
"params": {"protocolVersion": "2025-03-26"},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
assert resp["result"]["protocolVersion"] == "2025-03-26"
|
||||||
|
|
||||||
|
def test_initialize_unknown_version_falls_back_to_latest(self):
|
||||||
|
from mempalace.mcp_server import handle_request
|
||||||
|
|
||||||
|
resp = handle_request(
|
||||||
|
{
|
||||||
|
"method": "initialize",
|
||||||
|
"id": 1,
|
||||||
|
"params": {"protocolVersion": "9999-12-31"},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
from mempalace.mcp_server import SUPPORTED_PROTOCOL_VERSIONS
|
||||||
|
|
||||||
|
assert resp["result"]["protocolVersion"] == SUPPORTED_PROTOCOL_VERSIONS[0]
|
||||||
|
|
||||||
|
def test_initialize_missing_version_uses_oldest(self):
|
||||||
|
from mempalace.mcp_server import handle_request, SUPPORTED_PROTOCOL_VERSIONS
|
||||||
|
|
||||||
|
resp = handle_request({"method": "initialize", "id": 1, "params": {}})
|
||||||
|
assert resp["result"]["protocolVersion"] == SUPPORTED_PROTOCOL_VERSIONS[-1]
|
||||||
|
|
||||||
def test_notifications_initialized_returns_none(self):
|
def test_notifications_initialized_returns_none(self):
|
||||||
from mempalace.mcp_server import handle_request
|
from mempalace.mcp_server import handle_request
|
||||||
|
|
||||||
@@ -59,6 +103,23 @@ class TestHandleRequest:
|
|||||||
assert "mempalace_add_drawer" in names
|
assert "mempalace_add_drawer" in names
|
||||||
assert "mempalace_kg_add" in names
|
assert "mempalace_kg_add" in names
|
||||||
|
|
||||||
|
def test_null_arguments_does_not_hang(self, monkeypatch, config, palace_path, seeded_kg):
|
||||||
|
"""Sending arguments: null should return a result, not hang (#394)."""
|
||||||
|
_patch_mcp_server(monkeypatch, config, seeded_kg)
|
||||||
|
from mempalace.mcp_server import handle_request
|
||||||
|
|
||||||
|
_client, _col = _get_collection(palace_path, create=True)
|
||||||
|
del _client
|
||||||
|
resp = handle_request(
|
||||||
|
{
|
||||||
|
"method": "tools/call",
|
||||||
|
"id": 10,
|
||||||
|
"params": {"name": "mempalace_status", "arguments": None},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
assert "error" not in resp
|
||||||
|
assert resp["result"] is not None
|
||||||
|
|
||||||
def test_unknown_tool(self):
|
def test_unknown_tool(self):
|
||||||
from mempalace.mcp_server import handle_request
|
from mempalace.mcp_server import handle_request
|
||||||
|
|
||||||
|
|||||||
+55
-1
@@ -7,6 +7,7 @@ import chromadb
|
|||||||
import yaml
|
import yaml
|
||||||
|
|
||||||
from mempalace.miner import mine, scan_project
|
from mempalace.miner import mine, scan_project
|
||||||
|
from mempalace.palace import file_already_mined
|
||||||
|
|
||||||
|
|
||||||
def write_file(path: Path, content: str):
|
def write_file(path: Path, content: str):
|
||||||
@@ -47,7 +48,7 @@ def test_project_mining():
|
|||||||
col = client.get_collection("mempalace_drawers")
|
col = client.get_collection("mempalace_drawers")
|
||||||
assert col.count() > 0
|
assert col.count() > 0
|
||||||
finally:
|
finally:
|
||||||
shutil.rmtree(tmpdir)
|
shutil.rmtree(tmpdir, ignore_errors=True)
|
||||||
|
|
||||||
|
|
||||||
def test_scan_project_respects_gitignore():
|
def test_scan_project_respects_gitignore():
|
||||||
@@ -206,3 +207,56 @@ def test_scan_project_skip_dirs_still_apply_without_override():
|
|||||||
assert scanned_files(project_root, respect_gitignore=False) == ["main.py"]
|
assert scanned_files(project_root, respect_gitignore=False) == ["main.py"]
|
||||||
finally:
|
finally:
|
||||||
shutil.rmtree(tmpdir)
|
shutil.rmtree(tmpdir)
|
||||||
|
|
||||||
|
|
||||||
|
def test_file_already_mined_check_mtime():
|
||||||
|
tmpdir = tempfile.mkdtemp()
|
||||||
|
try:
|
||||||
|
palace_path = os.path.join(tmpdir, "palace")
|
||||||
|
os.makedirs(palace_path)
|
||||||
|
client = chromadb.PersistentClient(path=palace_path)
|
||||||
|
col = client.get_or_create_collection("mempalace_drawers")
|
||||||
|
|
||||||
|
test_file = os.path.join(tmpdir, "test.txt")
|
||||||
|
with open(test_file, "w") as f:
|
||||||
|
f.write("hello world")
|
||||||
|
|
||||||
|
mtime = os.path.getmtime(test_file)
|
||||||
|
|
||||||
|
# Not mined yet
|
||||||
|
assert file_already_mined(col, test_file) is False
|
||||||
|
assert file_already_mined(col, test_file, check_mtime=True) is False
|
||||||
|
|
||||||
|
# Add it with mtime
|
||||||
|
col.add(
|
||||||
|
ids=["d1"],
|
||||||
|
documents=["hello world"],
|
||||||
|
metadatas=[{"source_file": test_file, "source_mtime": str(mtime)}],
|
||||||
|
)
|
||||||
|
|
||||||
|
# Already mined (no mtime check)
|
||||||
|
assert file_already_mined(col, test_file) is True
|
||||||
|
# Already mined (mtime matches)
|
||||||
|
assert file_already_mined(col, test_file, check_mtime=True) is True
|
||||||
|
|
||||||
|
# Modify file and force a different mtime (Windows has low mtime resolution)
|
||||||
|
with open(test_file, "w") as f:
|
||||||
|
f.write("modified content")
|
||||||
|
os.utime(test_file, (mtime + 10, mtime + 10))
|
||||||
|
|
||||||
|
# Still mined without mtime check
|
||||||
|
assert file_already_mined(col, test_file) is True
|
||||||
|
# Needs re-mining with mtime check
|
||||||
|
assert file_already_mined(col, test_file, check_mtime=True) is False
|
||||||
|
|
||||||
|
# Record with no mtime stored should return False for check_mtime
|
||||||
|
col.add(
|
||||||
|
ids=["d2"],
|
||||||
|
documents=["other"],
|
||||||
|
metadatas=[{"source_file": "/fake/no_mtime.txt"}],
|
||||||
|
)
|
||||||
|
assert file_already_mined(col, "/fake/no_mtime.txt", check_mtime=True) is False
|
||||||
|
finally:
|
||||||
|
# Release ChromaDB file handles before cleanup (required on Windows)
|
||||||
|
del col, client
|
||||||
|
shutil.rmtree(tmpdir, ignore_errors=True)
|
||||||
|
|||||||
+500
-20
@@ -1,31 +1,511 @@
|
|||||||
import os
|
|
||||||
import json
|
import json
|
||||||
import tempfile
|
from unittest.mock import patch
|
||||||
from mempalace.normalize import normalize
|
|
||||||
|
from mempalace.normalize import (
|
||||||
|
_extract_content,
|
||||||
|
_messages_to_transcript,
|
||||||
|
_try_chatgpt_json,
|
||||||
|
_try_claude_ai_json,
|
||||||
|
_try_claude_code_jsonl,
|
||||||
|
_try_codex_jsonl,
|
||||||
|
_try_normalize_json,
|
||||||
|
_try_slack_json,
|
||||||
|
normalize,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_plain_text():
|
# ── normalize() top-level ──────────────────────────────────────────────
|
||||||
f = tempfile.NamedTemporaryFile(mode="w", suffix=".txt", delete=False)
|
|
||||||
f.write("Hello world\nSecond line\n")
|
|
||||||
f.close()
|
def test_plain_text(tmp_path):
|
||||||
result = normalize(f.name)
|
f = tmp_path / "plain.txt"
|
||||||
|
f.write_text("Hello world\nSecond line\n")
|
||||||
|
result = normalize(str(f))
|
||||||
assert "Hello world" in result
|
assert "Hello world" in result
|
||||||
os.unlink(f.name)
|
|
||||||
|
|
||||||
|
|
||||||
def test_claude_json():
|
def test_claude_json(tmp_path):
|
||||||
data = [{"role": "user", "content": "Hi"}, {"role": "assistant", "content": "Hello"}]
|
data = [{"role": "user", "content": "Hi"}, {"role": "assistant", "content": "Hello"}]
|
||||||
f = tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False)
|
f = tmp_path / "claude.json"
|
||||||
json.dump(data, f)
|
f.write_text(json.dumps(data))
|
||||||
f.close()
|
result = normalize(str(f))
|
||||||
result = normalize(f.name)
|
|
||||||
assert "Hi" in result
|
assert "Hi" in result
|
||||||
os.unlink(f.name)
|
|
||||||
|
|
||||||
|
|
||||||
def test_empty():
|
def test_empty(tmp_path):
|
||||||
f = tempfile.NamedTemporaryFile(mode="w", suffix=".txt", delete=False)
|
f = tmp_path / "empty.txt"
|
||||||
f.close()
|
f.write_text("")
|
||||||
result = normalize(f.name)
|
result = normalize(str(f))
|
||||||
assert result.strip() == ""
|
assert result.strip() == ""
|
||||||
os.unlink(f.name)
|
|
||||||
|
|
||||||
|
def test_normalize_io_error():
|
||||||
|
"""normalize raises IOError for unreadable file."""
|
||||||
|
try:
|
||||||
|
normalize("/nonexistent/path/file.txt")
|
||||||
|
assert False, "Should have raised"
|
||||||
|
except IOError as e:
|
||||||
|
assert "Could not read" in str(e)
|
||||||
|
|
||||||
|
|
||||||
|
def test_normalize_already_has_markers(tmp_path):
|
||||||
|
"""Files with >= 3 '>' lines pass through unchanged."""
|
||||||
|
content = "> question 1\nanswer 1\n> question 2\nanswer 2\n> question 3\nanswer 3\n"
|
||||||
|
f = tmp_path / "markers.txt"
|
||||||
|
f.write_text(content)
|
||||||
|
result = normalize(str(f))
|
||||||
|
assert result == content
|
||||||
|
|
||||||
|
|
||||||
|
def test_normalize_json_content_detected_by_brace(tmp_path):
|
||||||
|
"""A .txt file starting with [ triggers JSON parsing."""
|
||||||
|
data = [{"role": "user", "content": "Hey"}, {"role": "assistant", "content": "Hi there"}]
|
||||||
|
f = tmp_path / "chat.txt"
|
||||||
|
f.write_text(json.dumps(data))
|
||||||
|
result = normalize(str(f))
|
||||||
|
assert "Hey" in result
|
||||||
|
|
||||||
|
|
||||||
|
def test_normalize_whitespace_only(tmp_path):
|
||||||
|
f = tmp_path / "ws.txt"
|
||||||
|
f.write_text(" \n \n ")
|
||||||
|
result = normalize(str(f))
|
||||||
|
assert result.strip() == ""
|
||||||
|
|
||||||
|
|
||||||
|
# ── _extract_content ───────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_extract_content_string():
|
||||||
|
assert _extract_content("hello") == "hello"
|
||||||
|
|
||||||
|
|
||||||
|
def test_extract_content_list_of_strings():
|
||||||
|
assert _extract_content(["hello", "world"]) == "hello world"
|
||||||
|
|
||||||
|
|
||||||
|
def test_extract_content_list_of_blocks():
|
||||||
|
blocks = [{"type": "text", "text": "hello"}, {"type": "image", "url": "x"}]
|
||||||
|
assert _extract_content(blocks) == "hello"
|
||||||
|
|
||||||
|
|
||||||
|
def test_extract_content_dict():
|
||||||
|
assert _extract_content({"text": "hello"}) == "hello"
|
||||||
|
|
||||||
|
|
||||||
|
def test_extract_content_none():
|
||||||
|
assert _extract_content(None) == ""
|
||||||
|
|
||||||
|
|
||||||
|
def test_extract_content_mixed_list():
|
||||||
|
blocks = ["plain", {"type": "text", "text": "block"}]
|
||||||
|
assert _extract_content(blocks) == "plain block"
|
||||||
|
|
||||||
|
|
||||||
|
# ── _try_claude_code_jsonl ─────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_claude_code_jsonl_valid():
|
||||||
|
lines = [
|
||||||
|
json.dumps({"type": "human", "message": {"content": "What is X?"}}),
|
||||||
|
json.dumps({"type": "assistant", "message": {"content": "X is Y."}}),
|
||||||
|
]
|
||||||
|
result = _try_claude_code_jsonl("\n".join(lines))
|
||||||
|
assert result is not None
|
||||||
|
assert "> What is X?" in result
|
||||||
|
assert "X is Y." in result
|
||||||
|
|
||||||
|
|
||||||
|
def test_claude_code_jsonl_user_type():
|
||||||
|
lines = [
|
||||||
|
json.dumps({"type": "user", "message": {"content": "Q"}}),
|
||||||
|
json.dumps({"type": "assistant", "message": {"content": "A"}}),
|
||||||
|
]
|
||||||
|
result = _try_claude_code_jsonl("\n".join(lines))
|
||||||
|
assert result is not None
|
||||||
|
assert "> Q" in result
|
||||||
|
|
||||||
|
|
||||||
|
def test_claude_code_jsonl_too_few_messages():
|
||||||
|
lines = [json.dumps({"type": "human", "message": {"content": "only one"}})]
|
||||||
|
result = _try_claude_code_jsonl("\n".join(lines))
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_claude_code_jsonl_invalid_json_lines():
|
||||||
|
lines = [
|
||||||
|
"not json",
|
||||||
|
json.dumps({"type": "human", "message": {"content": "Q"}}),
|
||||||
|
json.dumps({"type": "assistant", "message": {"content": "A"}}),
|
||||||
|
]
|
||||||
|
result = _try_claude_code_jsonl("\n".join(lines))
|
||||||
|
assert result is not None
|
||||||
|
|
||||||
|
|
||||||
|
def test_claude_code_jsonl_non_dict_entries():
|
||||||
|
lines = [
|
||||||
|
json.dumps([1, 2, 3]),
|
||||||
|
json.dumps({"type": "human", "message": {"content": "Q"}}),
|
||||||
|
json.dumps({"type": "assistant", "message": {"content": "A"}}),
|
||||||
|
]
|
||||||
|
result = _try_claude_code_jsonl("\n".join(lines))
|
||||||
|
assert result is not None
|
||||||
|
|
||||||
|
|
||||||
|
# ── _try_codex_jsonl ───────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_codex_jsonl_valid():
|
||||||
|
lines = [
|
||||||
|
json.dumps({"type": "session_meta", "payload": {}}),
|
||||||
|
json.dumps({"type": "event_msg", "payload": {"type": "user_message", "message": "Q"}}),
|
||||||
|
json.dumps({"type": "event_msg", "payload": {"type": "agent_message", "message": "A"}}),
|
||||||
|
]
|
||||||
|
result = _try_codex_jsonl("\n".join(lines))
|
||||||
|
assert result is not None
|
||||||
|
assert "> Q" in result
|
||||||
|
|
||||||
|
|
||||||
|
def test_codex_jsonl_no_session_meta():
|
||||||
|
"""Without session_meta, codex parser returns None."""
|
||||||
|
lines = [
|
||||||
|
json.dumps({"type": "event_msg", "payload": {"type": "user_message", "message": "Q"}}),
|
||||||
|
json.dumps({"type": "event_msg", "payload": {"type": "agent_message", "message": "A"}}),
|
||||||
|
]
|
||||||
|
result = _try_codex_jsonl("\n".join(lines))
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_codex_jsonl_skips_non_event_msg():
|
||||||
|
lines = [
|
||||||
|
json.dumps({"type": "session_meta"}),
|
||||||
|
json.dumps({"type": "response_item", "payload": {"type": "user_message", "message": "X"}}),
|
||||||
|
json.dumps({"type": "event_msg", "payload": {"type": "user_message", "message": "Q"}}),
|
||||||
|
json.dumps({"type": "event_msg", "payload": {"type": "agent_message", "message": "A"}}),
|
||||||
|
]
|
||||||
|
result = _try_codex_jsonl("\n".join(lines))
|
||||||
|
assert result is not None
|
||||||
|
assert "X" not in result.split("> Q")[0]
|
||||||
|
|
||||||
|
|
||||||
|
def test_codex_jsonl_non_string_message():
|
||||||
|
lines = [
|
||||||
|
json.dumps({"type": "session_meta"}),
|
||||||
|
json.dumps({"type": "event_msg", "payload": {"type": "user_message", "message": 123}}),
|
||||||
|
json.dumps({"type": "event_msg", "payload": {"type": "user_message", "message": "Q"}}),
|
||||||
|
json.dumps({"type": "event_msg", "payload": {"type": "agent_message", "message": "A"}}),
|
||||||
|
]
|
||||||
|
result = _try_codex_jsonl("\n".join(lines))
|
||||||
|
assert result is not None
|
||||||
|
|
||||||
|
|
||||||
|
def test_codex_jsonl_empty_text_skipped():
|
||||||
|
lines = [
|
||||||
|
json.dumps({"type": "session_meta"}),
|
||||||
|
json.dumps({"type": "event_msg", "payload": {"type": "user_message", "message": " "}}),
|
||||||
|
json.dumps({"type": "event_msg", "payload": {"type": "user_message", "message": "Q"}}),
|
||||||
|
json.dumps({"type": "event_msg", "payload": {"type": "agent_message", "message": "A"}}),
|
||||||
|
]
|
||||||
|
result = _try_codex_jsonl("\n".join(lines))
|
||||||
|
assert result is not None
|
||||||
|
|
||||||
|
|
||||||
|
def test_codex_jsonl_payload_not_dict():
|
||||||
|
lines = [
|
||||||
|
json.dumps({"type": "session_meta"}),
|
||||||
|
json.dumps({"type": "event_msg", "payload": "not a dict"}),
|
||||||
|
json.dumps({"type": "event_msg", "payload": {"type": "user_message", "message": "Q"}}),
|
||||||
|
json.dumps({"type": "event_msg", "payload": {"type": "agent_message", "message": "A"}}),
|
||||||
|
]
|
||||||
|
result = _try_codex_jsonl("\n".join(lines))
|
||||||
|
assert result is not None
|
||||||
|
|
||||||
|
|
||||||
|
# ── _try_claude_ai_json ───────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_claude_ai_flat_messages():
|
||||||
|
data = [
|
||||||
|
{"role": "user", "content": "Hello"},
|
||||||
|
{"role": "assistant", "content": "Hi there"},
|
||||||
|
]
|
||||||
|
result = _try_claude_ai_json(data)
|
||||||
|
assert result is not None
|
||||||
|
assert "> Hello" in result
|
||||||
|
|
||||||
|
|
||||||
|
def test_claude_ai_dict_with_messages_key():
|
||||||
|
data = {
|
||||||
|
"messages": [
|
||||||
|
{"role": "user", "content": "Hello"},
|
||||||
|
{"role": "assistant", "content": "Hi"},
|
||||||
|
]
|
||||||
|
}
|
||||||
|
result = _try_claude_ai_json(data)
|
||||||
|
assert result is not None
|
||||||
|
|
||||||
|
|
||||||
|
def test_claude_ai_privacy_export():
|
||||||
|
data = [
|
||||||
|
{
|
||||||
|
"chat_messages": [
|
||||||
|
{"role": "human", "content": "Q1"},
|
||||||
|
{"role": "ai", "content": "A1"},
|
||||||
|
]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
result = _try_claude_ai_json(data)
|
||||||
|
assert result is not None
|
||||||
|
assert "> Q1" in result
|
||||||
|
|
||||||
|
|
||||||
|
def test_claude_ai_not_a_list():
|
||||||
|
result = _try_claude_ai_json("not a list")
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_claude_ai_too_few_messages():
|
||||||
|
data = [{"role": "user", "content": "Hello"}]
|
||||||
|
result = _try_claude_ai_json(data)
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_claude_ai_dict_with_chat_messages_key():
|
||||||
|
data = {
|
||||||
|
"chat_messages": [
|
||||||
|
{"role": "user", "content": "Hello"},
|
||||||
|
{"role": "assistant", "content": "World"},
|
||||||
|
]
|
||||||
|
}
|
||||||
|
result = _try_claude_ai_json(data)
|
||||||
|
assert result is not None
|
||||||
|
|
||||||
|
|
||||||
|
def test_claude_ai_privacy_export_non_dict_items():
|
||||||
|
"""Non-dict items in privacy export are skipped."""
|
||||||
|
data = [
|
||||||
|
{
|
||||||
|
"chat_messages": [
|
||||||
|
"not a dict",
|
||||||
|
{"role": "user", "content": "Q"},
|
||||||
|
{"role": "assistant", "content": "A"},
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"not a convo",
|
||||||
|
]
|
||||||
|
result = _try_claude_ai_json(data)
|
||||||
|
assert result is not None
|
||||||
|
|
||||||
|
|
||||||
|
# ── _try_chatgpt_json ─────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_chatgpt_json_valid():
|
||||||
|
data = {
|
||||||
|
"mapping": {
|
||||||
|
"root": {
|
||||||
|
"parent": None,
|
||||||
|
"message": None,
|
||||||
|
"children": ["msg1"],
|
||||||
|
},
|
||||||
|
"msg1": {
|
||||||
|
"parent": "root",
|
||||||
|
"message": {
|
||||||
|
"author": {"role": "user"},
|
||||||
|
"content": {"parts": ["Hello ChatGPT"]},
|
||||||
|
},
|
||||||
|
"children": ["msg2"],
|
||||||
|
},
|
||||||
|
"msg2": {
|
||||||
|
"parent": "msg1",
|
||||||
|
"message": {
|
||||||
|
"author": {"role": "assistant"},
|
||||||
|
"content": {"parts": ["Hello! How can I help?"]},
|
||||||
|
},
|
||||||
|
"children": [],
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
result = _try_chatgpt_json(data)
|
||||||
|
assert result is not None
|
||||||
|
assert "> Hello ChatGPT" in result
|
||||||
|
|
||||||
|
|
||||||
|
def test_chatgpt_json_no_mapping():
|
||||||
|
result = _try_chatgpt_json({"data": []})
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_chatgpt_json_not_dict():
|
||||||
|
result = _try_chatgpt_json([1, 2, 3])
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_chatgpt_json_fallback_root():
|
||||||
|
"""Root node has a message (no synthetic root), uses fallback."""
|
||||||
|
data = {
|
||||||
|
"mapping": {
|
||||||
|
"root": {
|
||||||
|
"parent": None,
|
||||||
|
"message": {
|
||||||
|
"author": {"role": "system"},
|
||||||
|
"content": {"parts": ["system prompt"]},
|
||||||
|
},
|
||||||
|
"children": ["msg1"],
|
||||||
|
},
|
||||||
|
"msg1": {
|
||||||
|
"parent": "root",
|
||||||
|
"message": {
|
||||||
|
"author": {"role": "user"},
|
||||||
|
"content": {"parts": ["Hello"]},
|
||||||
|
},
|
||||||
|
"children": ["msg2"],
|
||||||
|
},
|
||||||
|
"msg2": {
|
||||||
|
"parent": "msg1",
|
||||||
|
"message": {
|
||||||
|
"author": {"role": "assistant"},
|
||||||
|
"content": {"parts": ["Hi there"]},
|
||||||
|
},
|
||||||
|
"children": [],
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
result = _try_chatgpt_json(data)
|
||||||
|
assert result is not None
|
||||||
|
|
||||||
|
|
||||||
|
def test_chatgpt_json_too_few_messages():
|
||||||
|
data = {
|
||||||
|
"mapping": {
|
||||||
|
"root": {
|
||||||
|
"parent": None,
|
||||||
|
"message": None,
|
||||||
|
"children": ["msg1"],
|
||||||
|
},
|
||||||
|
"msg1": {
|
||||||
|
"parent": "root",
|
||||||
|
"message": {
|
||||||
|
"author": {"role": "user"},
|
||||||
|
"content": {"parts": ["Only one"]},
|
||||||
|
},
|
||||||
|
"children": [],
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
result = _try_chatgpt_json(data)
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
|
||||||
|
# ── _try_slack_json ────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_slack_json_valid():
|
||||||
|
data = [
|
||||||
|
{"type": "message", "user": "U1", "text": "Hello"},
|
||||||
|
{"type": "message", "user": "U2", "text": "Hi there"},
|
||||||
|
]
|
||||||
|
result = _try_slack_json(data)
|
||||||
|
assert result is not None
|
||||||
|
assert "Hello" in result
|
||||||
|
|
||||||
|
|
||||||
|
def test_slack_json_not_a_list():
|
||||||
|
result = _try_slack_json({"type": "message"})
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_slack_json_too_few_messages():
|
||||||
|
data = [{"type": "message", "user": "U1", "text": "Hello"}]
|
||||||
|
result = _try_slack_json(data)
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_slack_json_skips_non_message_types():
|
||||||
|
data = [
|
||||||
|
{"type": "channel_join", "user": "U1", "text": "joined"},
|
||||||
|
{"type": "message", "user": "U1", "text": "Hello"},
|
||||||
|
{"type": "message", "user": "U2", "text": "Hi"},
|
||||||
|
]
|
||||||
|
result = _try_slack_json(data)
|
||||||
|
assert result is not None
|
||||||
|
|
||||||
|
|
||||||
|
def test_slack_json_three_users():
|
||||||
|
"""Three speakers get alternating roles."""
|
||||||
|
data = [
|
||||||
|
{"type": "message", "user": "U1", "text": "Hello"},
|
||||||
|
{"type": "message", "user": "U2", "text": "Hi"},
|
||||||
|
{"type": "message", "user": "U3", "text": "Hey"},
|
||||||
|
]
|
||||||
|
result = _try_slack_json(data)
|
||||||
|
assert result is not None
|
||||||
|
|
||||||
|
|
||||||
|
def test_slack_json_empty_text_skipped():
|
||||||
|
data = [
|
||||||
|
{"type": "message", "user": "U1", "text": ""},
|
||||||
|
{"type": "message", "user": "U1", "text": "Hello"},
|
||||||
|
{"type": "message", "user": "U2", "text": "Hi"},
|
||||||
|
]
|
||||||
|
result = _try_slack_json(data)
|
||||||
|
assert result is not None
|
||||||
|
|
||||||
|
|
||||||
|
def test_slack_json_username_fallback():
|
||||||
|
data = [
|
||||||
|
{"type": "message", "username": "bot1", "text": "Hello"},
|
||||||
|
{"type": "message", "username": "bot2", "text": "Hi"},
|
||||||
|
]
|
||||||
|
result = _try_slack_json(data)
|
||||||
|
assert result is not None
|
||||||
|
|
||||||
|
|
||||||
|
# ── _try_normalize_json ────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_try_normalize_json_invalid_json():
|
||||||
|
result = _try_normalize_json("not json at all {{{")
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_try_normalize_json_valid_but_unknown_schema():
|
||||||
|
result = _try_normalize_json(json.dumps({"random": "data"}))
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
|
||||||
|
# ── _messages_to_transcript ────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_messages_to_transcript_basic():
|
||||||
|
msgs = [("user", "Q"), ("assistant", "A")]
|
||||||
|
with patch("mempalace.normalize.spellcheck_user_text", side_effect=lambda x: x, create=True):
|
||||||
|
result = _messages_to_transcript(msgs, spellcheck=False)
|
||||||
|
assert "> Q" in result
|
||||||
|
assert "A" in result
|
||||||
|
|
||||||
|
|
||||||
|
def test_messages_to_transcript_consecutive_users():
|
||||||
|
"""Two user messages in a row (no assistant between)."""
|
||||||
|
msgs = [("user", "Q1"), ("user", "Q2"), ("assistant", "A")]
|
||||||
|
result = _messages_to_transcript(msgs, spellcheck=False)
|
||||||
|
assert "> Q1" in result
|
||||||
|
assert "> Q2" in result
|
||||||
|
|
||||||
|
|
||||||
|
def test_messages_to_transcript_assistant_first():
|
||||||
|
"""Leading assistant message (no user before it)."""
|
||||||
|
msgs = [("assistant", "preamble"), ("user", "Q"), ("assistant", "A")]
|
||||||
|
result = _messages_to_transcript(msgs, spellcheck=False)
|
||||||
|
assert "preamble" in result
|
||||||
|
assert "> Q" in result
|
||||||
|
|
||||||
|
|
||||||
|
def test_normalize_rejects_large_file():
|
||||||
|
"""Files over 500 MB should raise IOError before reading."""
|
||||||
|
with patch("mempalace.normalize.os.path.getsize", return_value=600 * 1024 * 1024):
|
||||||
|
try:
|
||||||
|
normalize("/fake/huge_file.txt")
|
||||||
|
assert False, "Should have raised IOError"
|
||||||
|
except IOError as e:
|
||||||
|
assert "too large" in str(e).lower()
|
||||||
|
|||||||
@@ -0,0 +1,452 @@
|
|||||||
|
"""Tests for mempalace.onboarding."""
|
||||||
|
|
||||||
|
import os
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
from mempalace.onboarding import (
|
||||||
|
DEFAULT_WINGS,
|
||||||
|
_ask,
|
||||||
|
_ask_mode,
|
||||||
|
_ask_people,
|
||||||
|
_ask_projects,
|
||||||
|
_ask_wings,
|
||||||
|
_auto_detect,
|
||||||
|
_generate_aaak_bootstrap,
|
||||||
|
_header,
|
||||||
|
_hr,
|
||||||
|
_warn_ambiguous,
|
||||||
|
_yn,
|
||||||
|
quick_setup,
|
||||||
|
run_onboarding,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Force UTF-8 for Windows (source file contains Unicode symbols like hearts/stars)
|
||||||
|
os.environ["PYTHONUTF8"] = "1"
|
||||||
|
|
||||||
|
|
||||||
|
# ── DEFAULT_WINGS ───────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_default_wings_has_expected_keys():
|
||||||
|
assert "work" in DEFAULT_WINGS
|
||||||
|
assert "personal" in DEFAULT_WINGS
|
||||||
|
assert "combo" in DEFAULT_WINGS
|
||||||
|
|
||||||
|
|
||||||
|
def test_default_wings_work_has_projects():
|
||||||
|
assert "projects" in DEFAULT_WINGS["work"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_default_wings_personal_has_family():
|
||||||
|
assert "family" in DEFAULT_WINGS["personal"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_default_wings_combo_has_both():
|
||||||
|
wings = DEFAULT_WINGS["combo"]
|
||||||
|
assert "family" in wings
|
||||||
|
assert "work" in wings
|
||||||
|
|
||||||
|
|
||||||
|
def test_default_wings_values_are_lists():
|
||||||
|
for mode, wings in DEFAULT_WINGS.items():
|
||||||
|
assert isinstance(wings, list), f"{mode} wings should be a list"
|
||||||
|
assert len(wings) >= 3, f"{mode} should have at least 3 wings"
|
||||||
|
|
||||||
|
|
||||||
|
# ── _warn_ambiguous ─────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_warn_ambiguous_flags_common_words():
|
||||||
|
people = [
|
||||||
|
{"name": "Grace", "relationship": "friend"},
|
||||||
|
{"name": "Riley", "relationship": "daughter"},
|
||||||
|
]
|
||||||
|
result = _warn_ambiguous(people)
|
||||||
|
assert "Grace" in result
|
||||||
|
# Riley is not a common English word
|
||||||
|
assert "Riley" not in result
|
||||||
|
|
||||||
|
|
||||||
|
def test_warn_ambiguous_empty_list():
|
||||||
|
result = _warn_ambiguous([])
|
||||||
|
assert result == []
|
||||||
|
|
||||||
|
|
||||||
|
def test_warn_ambiguous_no_ambiguous_names():
|
||||||
|
people = [
|
||||||
|
{"name": "Riley", "relationship": "daughter"},
|
||||||
|
{"name": "Devon", "relationship": "friend"},
|
||||||
|
]
|
||||||
|
result = _warn_ambiguous(people)
|
||||||
|
assert result == []
|
||||||
|
|
||||||
|
|
||||||
|
def test_warn_ambiguous_multiple_hits():
|
||||||
|
people = [
|
||||||
|
{"name": "Grace", "relationship": "friend"},
|
||||||
|
{"name": "May", "relationship": "aunt"},
|
||||||
|
{"name": "Joy", "relationship": "sister"},
|
||||||
|
]
|
||||||
|
result = _warn_ambiguous(people)
|
||||||
|
assert "Grace" in result
|
||||||
|
assert "May" in result
|
||||||
|
assert "Joy" in result
|
||||||
|
|
||||||
|
|
||||||
|
# ── quick_setup ─────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_quick_setup_creates_registry(tmp_path):
|
||||||
|
registry = quick_setup(
|
||||||
|
mode="personal",
|
||||||
|
people=[{"name": "Riley", "relationship": "daughter", "context": "personal"}],
|
||||||
|
projects=["MemPalace"],
|
||||||
|
config_dir=tmp_path,
|
||||||
|
)
|
||||||
|
assert "Riley" in registry.people
|
||||||
|
assert "MemPalace" in registry.projects
|
||||||
|
assert registry.mode == "personal"
|
||||||
|
|
||||||
|
|
||||||
|
def test_quick_setup_work_mode(tmp_path):
|
||||||
|
registry = quick_setup(
|
||||||
|
mode="work",
|
||||||
|
people=[{"name": "Alice", "relationship": "colleague", "context": "work"}],
|
||||||
|
projects=["Acme"],
|
||||||
|
config_dir=tmp_path,
|
||||||
|
)
|
||||||
|
assert registry.mode == "work"
|
||||||
|
assert "Alice" in registry.people
|
||||||
|
assert "Acme" in registry.projects
|
||||||
|
|
||||||
|
|
||||||
|
def test_quick_setup_empty(tmp_path):
|
||||||
|
registry = quick_setup(mode="personal", people=[], config_dir=tmp_path)
|
||||||
|
assert len(registry.people) == 0
|
||||||
|
assert len(registry.projects) == 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_quick_setup_saves_to_disk(tmp_path):
|
||||||
|
quick_setup(
|
||||||
|
mode="personal",
|
||||||
|
people=[{"name": "Riley", "relationship": "daughter", "context": "personal"}],
|
||||||
|
config_dir=tmp_path,
|
||||||
|
)
|
||||||
|
assert (tmp_path / "entity_registry.json").exists()
|
||||||
|
|
||||||
|
|
||||||
|
# ── _generate_aaak_bootstrap ───────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_generate_aaak_bootstrap_creates_files(tmp_path):
|
||||||
|
people = [
|
||||||
|
{"name": "Riley", "relationship": "daughter", "context": "personal"},
|
||||||
|
{"name": "Devon", "relationship": "friend", "context": "personal"},
|
||||||
|
]
|
||||||
|
projects = ["MemPalace"]
|
||||||
|
wings = ["family", "creative"]
|
||||||
|
_generate_aaak_bootstrap(people, projects, wings, "personal", config_dir=tmp_path)
|
||||||
|
|
||||||
|
assert (tmp_path / "aaak_entities.md").exists()
|
||||||
|
assert (tmp_path / "critical_facts.md").exists()
|
||||||
|
|
||||||
|
|
||||||
|
def test_generate_aaak_bootstrap_entities_content(tmp_path):
|
||||||
|
people = [{"name": "Riley", "relationship": "daughter", "context": "personal"}]
|
||||||
|
projects = ["MemPalace"]
|
||||||
|
wings = ["family"]
|
||||||
|
_generate_aaak_bootstrap(people, projects, wings, "personal", config_dir=tmp_path)
|
||||||
|
|
||||||
|
content = (tmp_path / "aaak_entities.md").read_text()
|
||||||
|
assert "Riley" in content
|
||||||
|
assert "RIL" in content # entity code
|
||||||
|
assert "MemPalace" in content
|
||||||
|
|
||||||
|
|
||||||
|
def test_generate_aaak_bootstrap_facts_content(tmp_path):
|
||||||
|
people = [
|
||||||
|
{"name": "Alice", "relationship": "colleague", "context": "work"},
|
||||||
|
]
|
||||||
|
projects = ["Acme"]
|
||||||
|
wings = ["projects"]
|
||||||
|
_generate_aaak_bootstrap(people, projects, wings, "work", config_dir=tmp_path)
|
||||||
|
|
||||||
|
content = (tmp_path / "critical_facts.md").read_text()
|
||||||
|
assert "Alice" in content
|
||||||
|
assert "Acme" in content
|
||||||
|
assert "work" in content.lower()
|
||||||
|
|
||||||
|
|
||||||
|
def test_generate_aaak_bootstrap_empty_people(tmp_path):
|
||||||
|
_generate_aaak_bootstrap([], [], ["general"], "personal", config_dir=tmp_path)
|
||||||
|
assert (tmp_path / "aaak_entities.md").exists()
|
||||||
|
assert (tmp_path / "critical_facts.md").exists()
|
||||||
|
|
||||||
|
|
||||||
|
def test_generate_aaak_bootstrap_collision(tmp_path):
|
||||||
|
"""Two people with same 3-letter code get different codes."""
|
||||||
|
people = [
|
||||||
|
{"name": "Alice", "relationship": "friend", "context": "work"},
|
||||||
|
{"name": "Alison", "relationship": "coworker", "context": "work"},
|
||||||
|
]
|
||||||
|
_generate_aaak_bootstrap(people, [], ["work"], "work", config_dir=tmp_path)
|
||||||
|
content = (tmp_path / "aaak_entities.md").read_text()
|
||||||
|
assert "ALI" in content
|
||||||
|
assert "ALIS" in content
|
||||||
|
|
||||||
|
|
||||||
|
def test_generate_aaak_bootstrap_no_relationship(tmp_path):
|
||||||
|
"""Person without relationship string still generates entry."""
|
||||||
|
people = [{"name": "Bob", "context": "work"}]
|
||||||
|
_generate_aaak_bootstrap(people, [], ["work"], "work", config_dir=tmp_path)
|
||||||
|
content = (tmp_path / "aaak_entities.md").read_text()
|
||||||
|
assert "BOB=Bob" in content
|
||||||
|
|
||||||
|
|
||||||
|
# ── _hr, _header ──────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_hr_prints_line(capsys):
|
||||||
|
_hr()
|
||||||
|
out = capsys.readouterr().out
|
||||||
|
assert "─" in out
|
||||||
|
|
||||||
|
|
||||||
|
def test_header_prints_banner(capsys):
|
||||||
|
_header("Test Title")
|
||||||
|
out = capsys.readouterr().out
|
||||||
|
assert "Test Title" in out
|
||||||
|
assert "=" in out
|
||||||
|
|
||||||
|
|
||||||
|
# ── _ask ──────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_ask_with_default_uses_default():
|
||||||
|
with patch("builtins.input", return_value=""):
|
||||||
|
result = _ask("prompt", default="fallback")
|
||||||
|
assert result == "fallback"
|
||||||
|
|
||||||
|
|
||||||
|
def test_ask_with_default_uses_input():
|
||||||
|
with patch("builtins.input", return_value="custom"):
|
||||||
|
result = _ask("prompt", default="fallback")
|
||||||
|
assert result == "custom"
|
||||||
|
|
||||||
|
|
||||||
|
def test_ask_no_default():
|
||||||
|
with patch("builtins.input", return_value="answer"):
|
||||||
|
result = _ask("prompt")
|
||||||
|
assert result == "answer"
|
||||||
|
|
||||||
|
|
||||||
|
# ── _yn ───────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_yn_default_yes_empty_input():
|
||||||
|
with patch("builtins.input", return_value=""):
|
||||||
|
assert _yn("continue?") is True
|
||||||
|
|
||||||
|
|
||||||
|
def test_yn_default_no_empty_input():
|
||||||
|
with patch("builtins.input", return_value=""):
|
||||||
|
assert _yn("continue?", default="n") is False
|
||||||
|
|
||||||
|
|
||||||
|
def test_yn_explicit_yes():
|
||||||
|
with patch("builtins.input", return_value="yes"):
|
||||||
|
assert _yn("continue?", default="n") is True
|
||||||
|
|
||||||
|
|
||||||
|
def test_yn_explicit_no():
|
||||||
|
with patch("builtins.input", return_value="no"):
|
||||||
|
assert _yn("continue?") is False
|
||||||
|
|
||||||
|
|
||||||
|
# ── _ask_mode ─────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_ask_mode_work():
|
||||||
|
with patch("builtins.input", return_value="1"):
|
||||||
|
assert _ask_mode() == "work"
|
||||||
|
|
||||||
|
|
||||||
|
def test_ask_mode_personal():
|
||||||
|
with patch("builtins.input", return_value="2"):
|
||||||
|
assert _ask_mode() == "personal"
|
||||||
|
|
||||||
|
|
||||||
|
def test_ask_mode_combo():
|
||||||
|
with patch("builtins.input", return_value="3"):
|
||||||
|
assert _ask_mode() == "combo"
|
||||||
|
|
||||||
|
|
||||||
|
def test_ask_mode_retries_on_bad_input():
|
||||||
|
with patch("builtins.input", side_effect=["x", "bad", "1"]):
|
||||||
|
assert _ask_mode() == "work"
|
||||||
|
|
||||||
|
|
||||||
|
# ── _ask_people ───────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_ask_people_personal_mode():
|
||||||
|
with patch("builtins.input", side_effect=["Alice, daughter", "", "done"]):
|
||||||
|
people, aliases = _ask_people("personal")
|
||||||
|
assert len(people) == 1
|
||||||
|
assert people[0]["name"] == "Alice"
|
||||||
|
assert people[0]["relationship"] == "daughter"
|
||||||
|
|
||||||
|
|
||||||
|
def test_ask_people_work_mode():
|
||||||
|
with patch("builtins.input", side_effect=["Bob, manager", "", "done"]):
|
||||||
|
people, aliases = _ask_people("work")
|
||||||
|
assert len(people) == 1
|
||||||
|
assert people[0]["name"] == "Bob"
|
||||||
|
assert people[0]["context"] == "work"
|
||||||
|
|
||||||
|
|
||||||
|
def test_ask_people_combo_mode():
|
||||||
|
with patch(
|
||||||
|
"builtins.input",
|
||||||
|
side_effect=[
|
||||||
|
"Alice, daughter",
|
||||||
|
"",
|
||||||
|
"done", # personal
|
||||||
|
"Bob, boss",
|
||||||
|
"done", # work
|
||||||
|
],
|
||||||
|
):
|
||||||
|
people, aliases = _ask_people("combo")
|
||||||
|
assert len(people) == 2
|
||||||
|
|
||||||
|
|
||||||
|
def test_ask_people_with_nickname():
|
||||||
|
with patch("builtins.input", side_effect=["Alice, daughter", "Ali", "done"]):
|
||||||
|
people, aliases = _ask_people("personal")
|
||||||
|
assert aliases == {"Ali": "Alice"}
|
||||||
|
|
||||||
|
|
||||||
|
def test_ask_people_empty_name_skipped():
|
||||||
|
with patch("builtins.input", side_effect=["", "done"]):
|
||||||
|
people, aliases = _ask_people("personal")
|
||||||
|
assert len(people) == 0
|
||||||
|
|
||||||
|
|
||||||
|
# ── _ask_projects ─────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_ask_projects_personal_returns_empty():
|
||||||
|
result = _ask_projects("personal")
|
||||||
|
assert result == []
|
||||||
|
|
||||||
|
|
||||||
|
def test_ask_projects_work_mode():
|
||||||
|
with patch("builtins.input", side_effect=["Acme", "BigCo", "done"]):
|
||||||
|
result = _ask_projects("work")
|
||||||
|
assert result == ["Acme", "BigCo"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_ask_projects_empty_entry_stops():
|
||||||
|
with patch("builtins.input", side_effect=["Acme", ""]):
|
||||||
|
result = _ask_projects("work")
|
||||||
|
assert result == ["Acme"]
|
||||||
|
|
||||||
|
|
||||||
|
# ── _ask_wings ────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_ask_wings_accept_defaults():
|
||||||
|
with patch("builtins.input", return_value=""):
|
||||||
|
result = _ask_wings("work")
|
||||||
|
assert result == DEFAULT_WINGS["work"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_ask_wings_custom():
|
||||||
|
with patch("builtins.input", return_value="alpha, beta, gamma"):
|
||||||
|
result = _ask_wings("personal")
|
||||||
|
assert result == ["alpha", "beta", "gamma"]
|
||||||
|
|
||||||
|
|
||||||
|
# ── _auto_detect ──────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_auto_detect_no_files(tmp_path):
|
||||||
|
result = _auto_detect(str(tmp_path), [])
|
||||||
|
assert result == []
|
||||||
|
|
||||||
|
|
||||||
|
def test_auto_detect_filters_known(tmp_path):
|
||||||
|
known = [{"name": "Alice"}]
|
||||||
|
fake_detected = {
|
||||||
|
"people": [
|
||||||
|
{"name": "Alice", "confidence": 0.9, "signals": ["test"]},
|
||||||
|
{"name": "Bob", "confidence": 0.8, "signals": ["test"]},
|
||||||
|
],
|
||||||
|
"projects": [],
|
||||||
|
"uncertain": [],
|
||||||
|
}
|
||||||
|
with (
|
||||||
|
patch("mempalace.onboarding.scan_for_detection", return_value=["file.txt"]),
|
||||||
|
patch("mempalace.onboarding.detect_entities", return_value=fake_detected),
|
||||||
|
):
|
||||||
|
result = _auto_detect(str(tmp_path), known)
|
||||||
|
names = [p["name"] for p in result]
|
||||||
|
assert "Alice" not in names
|
||||||
|
assert "Bob" in names
|
||||||
|
|
||||||
|
|
||||||
|
def test_auto_detect_filters_low_confidence(tmp_path):
|
||||||
|
fake_detected = {
|
||||||
|
"people": [{"name": "Bob", "confidence": 0.5, "signals": ["test"]}],
|
||||||
|
"projects": [],
|
||||||
|
"uncertain": [],
|
||||||
|
}
|
||||||
|
with (
|
||||||
|
patch("mempalace.onboarding.scan_for_detection", return_value=["file.txt"]),
|
||||||
|
patch("mempalace.onboarding.detect_entities", return_value=fake_detected),
|
||||||
|
):
|
||||||
|
result = _auto_detect(str(tmp_path), [])
|
||||||
|
assert len(result) == 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_auto_detect_handles_exception(tmp_path):
|
||||||
|
with patch("mempalace.onboarding.scan_for_detection", side_effect=Exception("boom")):
|
||||||
|
result = _auto_detect(str(tmp_path), [])
|
||||||
|
assert result == []
|
||||||
|
|
||||||
|
|
||||||
|
# ── run_onboarding ────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_run_onboarding_basic_flow(tmp_path):
|
||||||
|
"""Test the full onboarding flow with minimal mocking."""
|
||||||
|
with (
|
||||||
|
patch("mempalace.onboarding._ask_mode", return_value="work"),
|
||||||
|
patch(
|
||||||
|
"mempalace.onboarding._ask_people",
|
||||||
|
return_value=([{"name": "Bob", "relationship": "boss", "context": "work"}], {}),
|
||||||
|
),
|
||||||
|
patch("mempalace.onboarding._ask_projects", return_value=["Acme"]),
|
||||||
|
patch("mempalace.onboarding._ask_wings", return_value=["projects", "team"]),
|
||||||
|
patch("mempalace.onboarding._yn", return_value=False),
|
||||||
|
patch("mempalace.onboarding._warn_ambiguous", return_value=[]),
|
||||||
|
):
|
||||||
|
registry = run_onboarding(directory=".", config_dir=tmp_path, auto_detect=False)
|
||||||
|
assert "Bob" in registry.people
|
||||||
|
assert "Acme" in registry.projects
|
||||||
|
|
||||||
|
|
||||||
|
def test_run_onboarding_with_ambiguous_names(tmp_path):
|
||||||
|
"""Onboarding prints a warning for ambiguous names."""
|
||||||
|
with (
|
||||||
|
patch("mempalace.onboarding._ask_mode", return_value="personal"),
|
||||||
|
patch(
|
||||||
|
"mempalace.onboarding._ask_people",
|
||||||
|
return_value=([{"name": "Grace", "relationship": "friend", "context": "personal"}], {}),
|
||||||
|
),
|
||||||
|
patch("mempalace.onboarding._ask_projects", return_value=[]),
|
||||||
|
patch("mempalace.onboarding._ask_wings", return_value=["family"]),
|
||||||
|
patch("mempalace.onboarding._yn", return_value=False),
|
||||||
|
):
|
||||||
|
registry = run_onboarding(directory=".", config_dir=tmp_path, auto_detect=False)
|
||||||
|
assert "Grace" in registry.people
|
||||||
@@ -0,0 +1,244 @@
|
|||||||
|
"""Tests for mempalace.palace_graph — graph traversal layer.
|
||||||
|
|
||||||
|
All ChromaDB access is mocked — no real database needed.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
|
||||||
|
def _make_fake_collection(metadatas, ids=None):
|
||||||
|
"""Create a mock collection that returns the given metadata in batches."""
|
||||||
|
if ids is None:
|
||||||
|
ids = [f"id_{i}" for i in range(len(metadatas))]
|
||||||
|
|
||||||
|
col = MagicMock()
|
||||||
|
col.count.return_value = len(metadatas)
|
||||||
|
|
||||||
|
def fake_get(limit=1000, offset=0, include=None):
|
||||||
|
batch_meta = metadatas[offset : offset + limit]
|
||||||
|
batch_ids = ids[offset : offset + limit]
|
||||||
|
return {"ids": batch_ids, "metadatas": batch_meta}
|
||||||
|
|
||||||
|
col.get.side_effect = fake_get
|
||||||
|
return col
|
||||||
|
|
||||||
|
|
||||||
|
# Patch chromadb at import time so palace_graph can be imported
|
||||||
|
with patch.dict("sys.modules", {"chromadb": MagicMock()}):
|
||||||
|
from mempalace.palace_graph import (
|
||||||
|
_fuzzy_match,
|
||||||
|
build_graph,
|
||||||
|
find_tunnels,
|
||||||
|
graph_stats,
|
||||||
|
traverse,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# --- build_graph ---
|
||||||
|
|
||||||
|
|
||||||
|
class TestBuildGraph:
|
||||||
|
def test_empty_collection(self):
|
||||||
|
col = _make_fake_collection([])
|
||||||
|
nodes, edges = build_graph(col=col)
|
||||||
|
assert nodes == {}
|
||||||
|
assert edges == []
|
||||||
|
|
||||||
|
def test_falsy_collection(self):
|
||||||
|
"""When col is explicitly falsy, build_graph returns empty."""
|
||||||
|
nodes, edges = build_graph(col=0)
|
||||||
|
assert nodes == {}
|
||||||
|
assert edges == []
|
||||||
|
|
||||||
|
def test_single_wing_no_edges(self):
|
||||||
|
col = _make_fake_collection(
|
||||||
|
[
|
||||||
|
{"room": "auth", "wing": "wing_code", "hall": "security", "date": "2026-01-01"},
|
||||||
|
{"room": "auth", "wing": "wing_code", "hall": "security", "date": "2026-01-02"},
|
||||||
|
]
|
||||||
|
)
|
||||||
|
nodes, edges = build_graph(col=col)
|
||||||
|
assert "auth" in nodes
|
||||||
|
assert nodes["auth"]["count"] == 2
|
||||||
|
assert edges == []
|
||||||
|
|
||||||
|
def test_multi_wing_creates_edges(self):
|
||||||
|
col = _make_fake_collection(
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"room": "chromadb",
|
||||||
|
"wing": "wing_code",
|
||||||
|
"hall": "databases",
|
||||||
|
"date": "2026-01-01",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"room": "chromadb",
|
||||||
|
"wing": "wing_project",
|
||||||
|
"hall": "databases",
|
||||||
|
"date": "2026-01-02",
|
||||||
|
},
|
||||||
|
]
|
||||||
|
)
|
||||||
|
nodes, edges = build_graph(col=col)
|
||||||
|
assert "chromadb" in nodes
|
||||||
|
assert len(edges) == 1
|
||||||
|
assert edges[0]["wing_a"] == "wing_code"
|
||||||
|
assert edges[0]["wing_b"] == "wing_project"
|
||||||
|
assert edges[0]["hall"] == "databases"
|
||||||
|
|
||||||
|
def test_general_room_excluded(self):
|
||||||
|
col = _make_fake_collection(
|
||||||
|
[
|
||||||
|
{"room": "general", "wing": "wing_code", "hall": "misc", "date": ""},
|
||||||
|
]
|
||||||
|
)
|
||||||
|
nodes, edges = build_graph(col=col)
|
||||||
|
assert "general" not in nodes
|
||||||
|
|
||||||
|
def test_missing_wing_excluded(self):
|
||||||
|
col = _make_fake_collection(
|
||||||
|
[
|
||||||
|
{"room": "orphan", "wing": "", "hall": "misc", "date": ""},
|
||||||
|
]
|
||||||
|
)
|
||||||
|
nodes, edges = build_graph(col=col)
|
||||||
|
assert "orphan" not in nodes
|
||||||
|
|
||||||
|
def test_dates_capped_at_five(self):
|
||||||
|
col = _make_fake_collection(
|
||||||
|
[
|
||||||
|
{"room": "busy", "wing": "w", "hall": "h", "date": f"2026-01-{i:02d}"}
|
||||||
|
for i in range(1, 10)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
nodes, _ = build_graph(col=col)
|
||||||
|
assert len(nodes["busy"]["dates"]) <= 5
|
||||||
|
|
||||||
|
|
||||||
|
# --- traverse ---
|
||||||
|
|
||||||
|
|
||||||
|
class TestTraverse:
|
||||||
|
def _build_col(self):
|
||||||
|
return _make_fake_collection(
|
||||||
|
[
|
||||||
|
{"room": "auth", "wing": "wing_code", "hall": "security", "date": "2026-01-01"},
|
||||||
|
{"room": "login", "wing": "wing_code", "hall": "security", "date": "2026-01-01"},
|
||||||
|
{"room": "deploy", "wing": "wing_ops", "hall": "infra", "date": "2026-01-01"},
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_traverse_known_room(self):
|
||||||
|
col = self._build_col()
|
||||||
|
result = traverse("auth", col=col)
|
||||||
|
assert isinstance(result, list)
|
||||||
|
rooms = [r["room"] for r in result]
|
||||||
|
assert "auth" in rooms
|
||||||
|
# login shares wing_code with auth
|
||||||
|
assert "login" in rooms
|
||||||
|
|
||||||
|
def test_traverse_unknown_room(self):
|
||||||
|
col = self._build_col()
|
||||||
|
result = traverse("nonexistent", col=col)
|
||||||
|
assert isinstance(result, dict)
|
||||||
|
assert "error" in result
|
||||||
|
assert "suggestions" in result
|
||||||
|
|
||||||
|
def test_traverse_max_hops(self):
|
||||||
|
col = self._build_col()
|
||||||
|
result = traverse("auth", col=col, max_hops=0)
|
||||||
|
# Only the start room itself at hop 0
|
||||||
|
assert len(result) == 1
|
||||||
|
assert result[0]["room"] == "auth"
|
||||||
|
|
||||||
|
|
||||||
|
# --- find_tunnels ---
|
||||||
|
|
||||||
|
|
||||||
|
class TestFindTunnels:
|
||||||
|
def _build_tunnel_col(self):
|
||||||
|
return _make_fake_collection(
|
||||||
|
[
|
||||||
|
{"room": "chromadb", "wing": "wing_code", "hall": "db", "date": "2026-01-01"},
|
||||||
|
{"room": "chromadb", "wing": "wing_project", "hall": "db", "date": "2026-01-02"},
|
||||||
|
{"room": "auth", "wing": "wing_code", "hall": "security", "date": "2026-01-01"},
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_find_all_tunnels(self):
|
||||||
|
col = self._build_tunnel_col()
|
||||||
|
tunnels = find_tunnels(col=col)
|
||||||
|
assert len(tunnels) == 1
|
||||||
|
assert tunnels[0]["room"] == "chromadb"
|
||||||
|
|
||||||
|
def test_find_tunnels_with_wing_filter(self):
|
||||||
|
col = self._build_tunnel_col()
|
||||||
|
tunnels = find_tunnels(wing_a="wing_code", col=col)
|
||||||
|
assert len(tunnels) == 1
|
||||||
|
|
||||||
|
def test_find_tunnels_no_match(self):
|
||||||
|
col = self._build_tunnel_col()
|
||||||
|
tunnels = find_tunnels(wing_a="wing_nonexistent", col=col)
|
||||||
|
assert tunnels == []
|
||||||
|
|
||||||
|
def test_find_tunnels_both_wings(self):
|
||||||
|
col = self._build_tunnel_col()
|
||||||
|
tunnels = find_tunnels(wing_a="wing_code", wing_b="wing_project", col=col)
|
||||||
|
assert len(tunnels) == 1
|
||||||
|
assert tunnels[0]["room"] == "chromadb"
|
||||||
|
|
||||||
|
|
||||||
|
# --- graph_stats ---
|
||||||
|
|
||||||
|
|
||||||
|
class TestGraphStats:
|
||||||
|
def test_empty_graph(self):
|
||||||
|
col = _make_fake_collection([])
|
||||||
|
stats = graph_stats(col=col)
|
||||||
|
assert stats["total_rooms"] == 0
|
||||||
|
assert stats["tunnel_rooms"] == 0
|
||||||
|
assert stats["total_edges"] == 0
|
||||||
|
|
||||||
|
def test_stats_with_data(self):
|
||||||
|
col = _make_fake_collection(
|
||||||
|
[
|
||||||
|
{"room": "chromadb", "wing": "wing_code", "hall": "db", "date": "2026-01-01"},
|
||||||
|
{"room": "chromadb", "wing": "wing_project", "hall": "db", "date": "2026-01-02"},
|
||||||
|
{"room": "auth", "wing": "wing_code", "hall": "security", "date": "2026-01-01"},
|
||||||
|
]
|
||||||
|
)
|
||||||
|
stats = graph_stats(col=col)
|
||||||
|
assert stats["total_rooms"] == 2
|
||||||
|
assert stats["tunnel_rooms"] == 1
|
||||||
|
assert stats["total_edges"] == 1
|
||||||
|
assert "wing_code" in stats["rooms_per_wing"]
|
||||||
|
|
||||||
|
|
||||||
|
# --- _fuzzy_match ---
|
||||||
|
|
||||||
|
|
||||||
|
class TestFuzzyMatch:
|
||||||
|
def test_exact_substring(self):
|
||||||
|
nodes = {"chromadb-setup": {}, "auth-module": {}, "deploy-config": {}}
|
||||||
|
result = _fuzzy_match("chromadb", nodes)
|
||||||
|
assert "chromadb-setup" in result
|
||||||
|
|
||||||
|
def test_partial_word_match(self):
|
||||||
|
nodes = {"chromadb-setup": {}, "auth-module": {}, "deploy-config": {}}
|
||||||
|
result = _fuzzy_match("auth", nodes)
|
||||||
|
assert "auth-module" in result
|
||||||
|
|
||||||
|
def test_no_match(self):
|
||||||
|
nodes = {"chromadb-setup": {}, "auth-module": {}}
|
||||||
|
result = _fuzzy_match("zzzzz", nodes)
|
||||||
|
assert result == []
|
||||||
|
|
||||||
|
def test_hyphenated_query(self):
|
||||||
|
nodes = {"riley-college-apps": {}, "college-prep": {}}
|
||||||
|
result = _fuzzy_match("riley-college", nodes)
|
||||||
|
assert "riley-college-apps" in result
|
||||||
|
|
||||||
|
def test_max_results(self):
|
||||||
|
nodes = {f"room-{i}": {} for i in range(20)}
|
||||||
|
result = _fuzzy_match("room", nodes, n=3)
|
||||||
|
assert len(result) <= 3
|
||||||
@@ -0,0 +1,264 @@
|
|||||||
|
"""Tests for mempalace.room_detector_local."""
|
||||||
|
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
from mempalace.room_detector_local import (
|
||||||
|
FOLDER_ROOM_MAP,
|
||||||
|
detect_rooms_from_files,
|
||||||
|
detect_rooms_from_folders,
|
||||||
|
detect_rooms_local,
|
||||||
|
get_user_approval,
|
||||||
|
print_proposed_structure,
|
||||||
|
save_config,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ── FOLDER_ROOM_MAP ────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_folder_room_map_has_expected_mappings():
|
||||||
|
assert FOLDER_ROOM_MAP["frontend"] == "frontend"
|
||||||
|
assert FOLDER_ROOM_MAP["backend"] == "backend"
|
||||||
|
assert FOLDER_ROOM_MAP["docs"] == "documentation"
|
||||||
|
assert FOLDER_ROOM_MAP["tests"] == "testing"
|
||||||
|
assert FOLDER_ROOM_MAP["config"] == "configuration"
|
||||||
|
|
||||||
|
|
||||||
|
def test_folder_room_map_alternative_names():
|
||||||
|
assert FOLDER_ROOM_MAP["front-end"] == "frontend"
|
||||||
|
assert FOLDER_ROOM_MAP["back-end"] == "backend"
|
||||||
|
assert FOLDER_ROOM_MAP["server"] == "backend"
|
||||||
|
assert FOLDER_ROOM_MAP["client"] == "frontend"
|
||||||
|
assert FOLDER_ROOM_MAP["api"] == "backend"
|
||||||
|
|
||||||
|
|
||||||
|
# ── detect_rooms_from_folders ───────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_detect_rooms_from_folders_standard_layout(tmp_path):
|
||||||
|
(tmp_path / "frontend").mkdir()
|
||||||
|
(tmp_path / "backend").mkdir()
|
||||||
|
(tmp_path / "docs").mkdir()
|
||||||
|
rooms = detect_rooms_from_folders(str(tmp_path))
|
||||||
|
room_names = {r["name"] for r in rooms}
|
||||||
|
assert "frontend" in room_names
|
||||||
|
assert "backend" in room_names
|
||||||
|
assert "documentation" in room_names
|
||||||
|
|
||||||
|
|
||||||
|
def test_detect_rooms_from_folders_always_has_general(tmp_path):
|
||||||
|
rooms = detect_rooms_from_folders(str(tmp_path))
|
||||||
|
room_names = {r["name"] for r in rooms}
|
||||||
|
assert "general" in room_names
|
||||||
|
|
||||||
|
|
||||||
|
def test_detect_rooms_from_folders_empty_dir(tmp_path):
|
||||||
|
rooms = detect_rooms_from_folders(str(tmp_path))
|
||||||
|
# Should at least have "general"
|
||||||
|
assert len(rooms) >= 1
|
||||||
|
assert any(r["name"] == "general" for r in rooms)
|
||||||
|
|
||||||
|
|
||||||
|
def test_detect_rooms_from_folders_skips_git(tmp_path):
|
||||||
|
(tmp_path / ".git").mkdir()
|
||||||
|
(tmp_path / "node_modules").mkdir()
|
||||||
|
(tmp_path / "frontend").mkdir()
|
||||||
|
rooms = detect_rooms_from_folders(str(tmp_path))
|
||||||
|
room_names = {r["name"] for r in rooms}
|
||||||
|
assert ".git" not in room_names
|
||||||
|
assert "node_modules" not in room_names
|
||||||
|
|
||||||
|
|
||||||
|
def test_detect_rooms_from_folders_nested_dirs(tmp_path):
|
||||||
|
src = tmp_path / "src"
|
||||||
|
src.mkdir()
|
||||||
|
(src / "components").mkdir()
|
||||||
|
(src / "routes").mkdir()
|
||||||
|
rooms = detect_rooms_from_folders(str(tmp_path))
|
||||||
|
room_names = {r["name"] for r in rooms}
|
||||||
|
# Nested dirs should be detected at one level deep
|
||||||
|
assert "frontend" in room_names or "backend" in room_names
|
||||||
|
|
||||||
|
|
||||||
|
def test_detect_rooms_from_folders_room_has_description(tmp_path):
|
||||||
|
(tmp_path / "docs").mkdir()
|
||||||
|
rooms = detect_rooms_from_folders(str(tmp_path))
|
||||||
|
doc_room = next((r for r in rooms if r["name"] == "documentation"), None)
|
||||||
|
assert doc_room is not None
|
||||||
|
assert "description" in doc_room
|
||||||
|
assert "docs" in doc_room["description"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_detect_rooms_from_folders_room_has_keywords(tmp_path):
|
||||||
|
(tmp_path / "frontend").mkdir()
|
||||||
|
rooms = detect_rooms_from_folders(str(tmp_path))
|
||||||
|
fe_room = next((r for r in rooms if r["name"] == "frontend"), None)
|
||||||
|
assert fe_room is not None
|
||||||
|
assert "keywords" in fe_room
|
||||||
|
assert len(fe_room["keywords"]) > 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_detect_rooms_from_folders_custom_named_dirs(tmp_path):
|
||||||
|
(tmp_path / "mylib").mkdir()
|
||||||
|
rooms = detect_rooms_from_folders(str(tmp_path))
|
||||||
|
room_names = {r["name"] for r in rooms}
|
||||||
|
# Custom dir names that don't match FOLDER_ROOM_MAP get added as-is
|
||||||
|
assert "mylib" in room_names or "general" in room_names
|
||||||
|
|
||||||
|
|
||||||
|
# ── detect_rooms_from_files ─────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_detect_rooms_from_files_with_matching_filenames(tmp_path):
|
||||||
|
# Create files whose names contain room keywords
|
||||||
|
for name in ["test_auth.py", "test_login.py", "test_api.py"]:
|
||||||
|
(tmp_path / name).write_text("content")
|
||||||
|
rooms = detect_rooms_from_files(str(tmp_path))
|
||||||
|
room_names = {r["name"] for r in rooms}
|
||||||
|
assert "testing" in room_names or "general" in room_names
|
||||||
|
|
||||||
|
|
||||||
|
def test_detect_rooms_from_files_empty_dir(tmp_path):
|
||||||
|
rooms = detect_rooms_from_files(str(tmp_path))
|
||||||
|
assert len(rooms) >= 1
|
||||||
|
assert any(r["name"] == "general" for r in rooms)
|
||||||
|
|
||||||
|
|
||||||
|
def test_detect_rooms_from_files_caps_at_six(tmp_path):
|
||||||
|
# Create many files with different keywords to hit the cap
|
||||||
|
for keyword in ["test", "doc", "api", "config", "frontend", "backend", "design", "meeting"]:
|
||||||
|
for i in range(3):
|
||||||
|
(tmp_path / f"{keyword}_file_{i}.txt").write_text("content")
|
||||||
|
rooms = detect_rooms_from_files(str(tmp_path))
|
||||||
|
assert len(rooms) <= 6
|
||||||
|
|
||||||
|
|
||||||
|
# ── save_config ─────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_save_config_creates_yaml(tmp_path):
|
||||||
|
rooms = [
|
||||||
|
{"name": "frontend", "description": "UI files", "keywords": ["frontend"]},
|
||||||
|
{"name": "backend", "description": "Server files", "keywords": ["backend"]},
|
||||||
|
]
|
||||||
|
save_config(str(tmp_path), "myproject", rooms)
|
||||||
|
config_file = tmp_path / "mempalace.yaml"
|
||||||
|
assert config_file.exists()
|
||||||
|
content = config_file.read_text()
|
||||||
|
assert "myproject" in content
|
||||||
|
assert "frontend" in content
|
||||||
|
assert "backend" in content
|
||||||
|
|
||||||
|
|
||||||
|
def test_save_config_valid_yaml(tmp_path):
|
||||||
|
import yaml
|
||||||
|
|
||||||
|
rooms = [{"name": "general", "description": "All files", "keywords": []}]
|
||||||
|
save_config(str(tmp_path), "test_proj", rooms)
|
||||||
|
config_file = tmp_path / "mempalace.yaml"
|
||||||
|
data = yaml.safe_load(config_file.read_text())
|
||||||
|
assert data["wing"] == "test_proj"
|
||||||
|
assert len(data["rooms"]) == 1
|
||||||
|
assert data["rooms"][0]["name"] == "general"
|
||||||
|
|
||||||
|
|
||||||
|
# ── print_proposed_structure ──────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_print_proposed_structure(capsys):
|
||||||
|
rooms = [
|
||||||
|
{"name": "frontend", "description": "UI files"},
|
||||||
|
{"name": "general", "description": "Everything else"},
|
||||||
|
]
|
||||||
|
print_proposed_structure("myapp", rooms, 42, "folder structure")
|
||||||
|
out = capsys.readouterr().out
|
||||||
|
assert "myapp" in out
|
||||||
|
assert "frontend" in out
|
||||||
|
assert "42 files" in out
|
||||||
|
assert "folder structure" in out
|
||||||
|
|
||||||
|
|
||||||
|
# ── get_user_approval ─────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_user_approval_accept_all():
|
||||||
|
rooms = [{"name": "frontend", "description": "UI"}]
|
||||||
|
with patch("builtins.input", return_value=""):
|
||||||
|
result = get_user_approval(rooms)
|
||||||
|
assert result == rooms
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_user_approval_edit_remove():
|
||||||
|
rooms = [
|
||||||
|
{"name": "frontend", "description": "UI"},
|
||||||
|
{"name": "backend", "description": "Server"},
|
||||||
|
]
|
||||||
|
with patch("builtins.input", side_effect=["edit", "1", "n"]):
|
||||||
|
result = get_user_approval(rooms)
|
||||||
|
# Room 1 (frontend) removed
|
||||||
|
assert len(result) == 1
|
||||||
|
assert result[0]["name"] == "backend"
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_user_approval_add_room():
|
||||||
|
rooms = [{"name": "general", "description": "All files"}]
|
||||||
|
with patch(
|
||||||
|
"builtins.input",
|
||||||
|
side_effect=[
|
||||||
|
"add",
|
||||||
|
"custom_room",
|
||||||
|
"My custom room",
|
||||||
|
"",
|
||||||
|
],
|
||||||
|
):
|
||||||
|
result = get_user_approval(rooms)
|
||||||
|
names = [r["name"] for r in result]
|
||||||
|
assert "custom_room" in names
|
||||||
|
|
||||||
|
|
||||||
|
# ── detect_rooms_local ────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_detect_rooms_local_yes_mode(tmp_path):
|
||||||
|
(tmp_path / "docs").mkdir()
|
||||||
|
(tmp_path / "docs" / "readme.md").write_text("hello")
|
||||||
|
mock_miner = MagicMock()
|
||||||
|
mock_miner.scan_project.return_value = ["file1.py"]
|
||||||
|
with patch.dict("sys.modules", {"mempalace.miner": mock_miner}):
|
||||||
|
detect_rooms_local(str(tmp_path), yes=True)
|
||||||
|
assert (tmp_path / "mempalace.yaml").exists()
|
||||||
|
|
||||||
|
|
||||||
|
def test_detect_rooms_local_fallback_to_files(tmp_path):
|
||||||
|
"""When folder detection gives only 'general', falls back to file patterns."""
|
||||||
|
for i in range(3):
|
||||||
|
(tmp_path / f"test_file_{i}.py").write_text("content")
|
||||||
|
mock_miner = MagicMock()
|
||||||
|
mock_miner.scan_project.return_value = ["f1", "f2"]
|
||||||
|
with patch.dict("sys.modules", {"mempalace.miner": mock_miner}):
|
||||||
|
detect_rooms_local(str(tmp_path), yes=True)
|
||||||
|
assert (tmp_path / "mempalace.yaml").exists()
|
||||||
|
|
||||||
|
|
||||||
|
def test_detect_rooms_local_missing_dir():
|
||||||
|
"""Non-existent directory causes sys.exit."""
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
with pytest.raises(SystemExit):
|
||||||
|
detect_rooms_local("/nonexistent/path/that/does/not/exist", yes=True)
|
||||||
|
|
||||||
|
|
||||||
|
def test_detect_rooms_local_interactive(tmp_path):
|
||||||
|
(tmp_path / "src").mkdir()
|
||||||
|
(tmp_path / "src" / "main.py").write_text("code")
|
||||||
|
mock_miner = MagicMock()
|
||||||
|
mock_miner.scan_project.return_value = ["f1"]
|
||||||
|
with (
|
||||||
|
patch.dict("sys.modules", {"mempalace.miner": mock_miner}),
|
||||||
|
patch(
|
||||||
|
"mempalace.room_detector_local.get_user_approval",
|
||||||
|
return_value=[{"name": "general", "description": "All files", "keywords": []}],
|
||||||
|
),
|
||||||
|
):
|
||||||
|
detect_rooms_local(str(tmp_path), yes=False)
|
||||||
|
assert (tmp_path / "mempalace.yaml").exists()
|
||||||
+83
-3
@@ -1,10 +1,18 @@
|
|||||||
"""
|
"""
|
||||||
test_searcher.py — Tests for the programmatic search_memories API.
|
test_searcher.py -- Tests for both search() (CLI) and search_memories() (API).
|
||||||
|
|
||||||
Tests the library-facing search interface (not the CLI print variant).
|
Uses the real ChromaDB fixtures from conftest.py for integration tests,
|
||||||
|
plus mock-based tests for error paths.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from mempalace.searcher import search_memories
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from mempalace.searcher import SearchError, search, search_memories
|
||||||
|
|
||||||
|
|
||||||
|
# ── search_memories (API) ──────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
class TestSearchMemories:
|
class TestSearchMemories:
|
||||||
@@ -43,3 +51,75 @@ class TestSearchMemories:
|
|||||||
assert "source_file" in hit
|
assert "source_file" in hit
|
||||||
assert "similarity" in hit
|
assert "similarity" in hit
|
||||||
assert isinstance(hit["similarity"], float)
|
assert isinstance(hit["similarity"], float)
|
||||||
|
|
||||||
|
def test_search_memories_query_error(self):
|
||||||
|
"""search_memories returns error dict when query raises."""
|
||||||
|
mock_col = MagicMock()
|
||||||
|
mock_col.query.side_effect = RuntimeError("query failed")
|
||||||
|
mock_client = MagicMock()
|
||||||
|
mock_client.get_collection.return_value = mock_col
|
||||||
|
|
||||||
|
with patch("mempalace.searcher.chromadb.PersistentClient", return_value=mock_client):
|
||||||
|
result = search_memories("test", "/fake/path")
|
||||||
|
assert "error" in result
|
||||||
|
assert "query failed" in result["error"]
|
||||||
|
|
||||||
|
def test_search_memories_filters_in_result(self, palace_path, seeded_collection):
|
||||||
|
result = search_memories("test", palace_path, wing="project", room="backend")
|
||||||
|
assert result["filters"]["wing"] == "project"
|
||||||
|
assert result["filters"]["room"] == "backend"
|
||||||
|
|
||||||
|
|
||||||
|
# ── search() (CLI print function) ─────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestSearchCLI:
|
||||||
|
def test_search_prints_results(self, palace_path, seeded_collection, capsys):
|
||||||
|
search("JWT authentication", palace_path)
|
||||||
|
captured = capsys.readouterr()
|
||||||
|
assert "JWT" in captured.out or "authentication" in captured.out
|
||||||
|
|
||||||
|
def test_search_with_wing_filter(self, palace_path, seeded_collection, capsys):
|
||||||
|
search("planning", palace_path, wing="notes")
|
||||||
|
captured = capsys.readouterr()
|
||||||
|
assert "Results for" in captured.out
|
||||||
|
|
||||||
|
def test_search_with_room_filter(self, palace_path, seeded_collection, capsys):
|
||||||
|
search("database", palace_path, room="backend")
|
||||||
|
captured = capsys.readouterr()
|
||||||
|
assert "Room:" in captured.out
|
||||||
|
|
||||||
|
def test_search_with_wing_and_room(self, palace_path, seeded_collection, capsys):
|
||||||
|
search("code", palace_path, wing="project", room="frontend")
|
||||||
|
captured = capsys.readouterr()
|
||||||
|
assert "Wing:" in captured.out
|
||||||
|
assert "Room:" in captured.out
|
||||||
|
|
||||||
|
def test_search_no_palace_raises(self, tmp_path):
|
||||||
|
with pytest.raises(SearchError, match="No palace found"):
|
||||||
|
search("anything", str(tmp_path / "missing"))
|
||||||
|
|
||||||
|
def test_search_no_results(self, palace_path, collection, capsys):
|
||||||
|
"""Empty collection returns no results message."""
|
||||||
|
# collection is empty (no seeded data)
|
||||||
|
result = search("xyzzy_nonexistent_query", palace_path, n_results=1)
|
||||||
|
captured = capsys.readouterr()
|
||||||
|
# Either prints "No results" or returns None
|
||||||
|
assert result is None or "No results" in captured.out
|
||||||
|
|
||||||
|
def test_search_query_error_raises(self):
|
||||||
|
"""search raises SearchError when query fails."""
|
||||||
|
mock_col = MagicMock()
|
||||||
|
mock_col.query.side_effect = RuntimeError("boom")
|
||||||
|
mock_client = MagicMock()
|
||||||
|
mock_client.get_collection.return_value = mock_col
|
||||||
|
|
||||||
|
with patch("mempalace.searcher.chromadb.PersistentClient", return_value=mock_client):
|
||||||
|
with pytest.raises(SearchError, match="Search error"):
|
||||||
|
search("test", "/fake/path")
|
||||||
|
|
||||||
|
def test_search_n_results(self, palace_path, seeded_collection, capsys):
|
||||||
|
search("code", palace_path, n_results=1)
|
||||||
|
captured = capsys.readouterr()
|
||||||
|
# Should have output with at least one result block
|
||||||
|
assert "[1]" in captured.out
|
||||||
|
|||||||
@@ -0,0 +1,160 @@
|
|||||||
|
"""Tests for mempalace.spellcheck — spell-correction utilities."""
|
||||||
|
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
from mempalace.spellcheck import (
|
||||||
|
_edit_distance,
|
||||||
|
_get_system_words,
|
||||||
|
_should_skip,
|
||||||
|
spellcheck_transcript,
|
||||||
|
spellcheck_transcript_line,
|
||||||
|
spellcheck_user_text,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# --- _should_skip ---
|
||||||
|
|
||||||
|
|
||||||
|
class TestShouldSkip:
|
||||||
|
"""Token-level skip logic."""
|
||||||
|
|
||||||
|
def test_short_tokens_skipped(self):
|
||||||
|
assert _should_skip("hi", set()) is True
|
||||||
|
assert _should_skip("ok", set()) is True
|
||||||
|
assert _should_skip("I", set()) is True
|
||||||
|
|
||||||
|
def test_digits_skipped(self):
|
||||||
|
assert _should_skip("3am", set()) is True
|
||||||
|
assert _should_skip("top10", set()) is True
|
||||||
|
assert _should_skip("bge-large-v1.5", set()) is True
|
||||||
|
|
||||||
|
def test_camelcase_skipped(self):
|
||||||
|
assert _should_skip("ChromaDB", set()) is True
|
||||||
|
assert _should_skip("MemPalace", set()) is True
|
||||||
|
|
||||||
|
def test_allcaps_skipped(self):
|
||||||
|
assert _should_skip("NDCG", set()) is True
|
||||||
|
assert _should_skip("MAX_RESULTS", set()) is True
|
||||||
|
|
||||||
|
def test_technical_skipped(self):
|
||||||
|
assert _should_skip("bge-large", set()) is True
|
||||||
|
assert _should_skip("train_test", set()) is True
|
||||||
|
|
||||||
|
def test_url_skipped(self):
|
||||||
|
assert _should_skip("https://example.com", set()) is True
|
||||||
|
assert _should_skip("www.google.com", set()) is True
|
||||||
|
|
||||||
|
def test_code_or_emoji_skipped(self):
|
||||||
|
assert _should_skip("`code`", set()) is True
|
||||||
|
assert _should_skip("**bold**", set()) is True
|
||||||
|
|
||||||
|
def test_known_name_skipped(self):
|
||||||
|
assert _should_skip("mempalace", {"mempalace"}) is True
|
||||||
|
|
||||||
|
def test_normal_word_not_skipped(self):
|
||||||
|
assert _should_skip("hello", set()) is False
|
||||||
|
assert _should_skip("question", set()) is False
|
||||||
|
|
||||||
|
|
||||||
|
# --- _edit_distance ---
|
||||||
|
|
||||||
|
|
||||||
|
class TestEditDistance:
|
||||||
|
def test_identical(self):
|
||||||
|
assert _edit_distance("hello", "hello") == 0
|
||||||
|
|
||||||
|
def test_empty_strings(self):
|
||||||
|
assert _edit_distance("", "abc") == 3
|
||||||
|
assert _edit_distance("abc", "") == 3
|
||||||
|
assert _edit_distance("", "") == 0
|
||||||
|
|
||||||
|
def test_single_edit(self):
|
||||||
|
assert _edit_distance("cat", "bat") == 1 # substitution
|
||||||
|
assert _edit_distance("cat", "cats") == 1 # insertion
|
||||||
|
assert _edit_distance("cats", "cat") == 1 # deletion
|
||||||
|
|
||||||
|
def test_known_distance(self):
|
||||||
|
assert _edit_distance("kitten", "sitting") == 3
|
||||||
|
|
||||||
|
|
||||||
|
# --- _get_system_words ---
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_system_words_returns_set():
|
||||||
|
result = _get_system_words()
|
||||||
|
assert isinstance(result, set)
|
||||||
|
|
||||||
|
|
||||||
|
# --- spellcheck_user_text ---
|
||||||
|
|
||||||
|
|
||||||
|
def test_spellcheck_user_text_passthrough_no_autocorrect():
|
||||||
|
"""When autocorrect is not installed, text passes through unchanged."""
|
||||||
|
with patch("mempalace.spellcheck._get_speller", return_value=None):
|
||||||
|
text = "somee misspeledd textt"
|
||||||
|
assert spellcheck_user_text(text) == text
|
||||||
|
|
||||||
|
|
||||||
|
def test_spellcheck_user_text_with_speller():
|
||||||
|
"""When a speller is available, it corrects words."""
|
||||||
|
|
||||||
|
def fake_speller(word):
|
||||||
|
corrections = {"knoe": "know", "befor": "before"}
|
||||||
|
return corrections.get(word, word)
|
||||||
|
|
||||||
|
with patch("mempalace.spellcheck._get_speller", return_value=fake_speller):
|
||||||
|
with patch("mempalace.spellcheck._get_system_words", return_value=set()):
|
||||||
|
with patch("mempalace.spellcheck._load_known_names", return_value=set()):
|
||||||
|
result = spellcheck_user_text("knoe the question befor")
|
||||||
|
assert "know" in result
|
||||||
|
assert "before" in result
|
||||||
|
|
||||||
|
|
||||||
|
def test_spellcheck_preserves_technical_terms():
|
||||||
|
"""Technical terms should never be touched even with a speller."""
|
||||||
|
|
||||||
|
def fake_speller(word):
|
||||||
|
return "WRONG"
|
||||||
|
|
||||||
|
with patch("mempalace.spellcheck._get_speller", return_value=fake_speller):
|
||||||
|
with patch("mempalace.spellcheck._get_system_words", return_value=set()):
|
||||||
|
result = spellcheck_user_text("ChromaDB bge-large", known_names=set())
|
||||||
|
assert "ChromaDB" in result
|
||||||
|
assert "bge-large" in result
|
||||||
|
assert "WRONG" not in result
|
||||||
|
|
||||||
|
|
||||||
|
# --- spellcheck_transcript_line ---
|
||||||
|
|
||||||
|
|
||||||
|
def test_transcript_line_user_turn():
|
||||||
|
"""Lines starting with '>' should be processed."""
|
||||||
|
with patch("mempalace.spellcheck.spellcheck_user_text", return_value="corrected"):
|
||||||
|
result = spellcheck_transcript_line("> hello world")
|
||||||
|
assert "corrected" in result
|
||||||
|
|
||||||
|
|
||||||
|
def test_transcript_line_assistant_turn():
|
||||||
|
"""Lines not starting with '>' should pass through unchanged."""
|
||||||
|
line = "This is an assistant response"
|
||||||
|
assert spellcheck_transcript_line(line) == line
|
||||||
|
|
||||||
|
|
||||||
|
def test_transcript_line_empty_user_turn():
|
||||||
|
"""A '> ' line with no message content should pass through."""
|
||||||
|
line = "> "
|
||||||
|
assert spellcheck_transcript_line(line) == line
|
||||||
|
|
||||||
|
|
||||||
|
# --- spellcheck_transcript ---
|
||||||
|
|
||||||
|
|
||||||
|
def test_spellcheck_transcript_processes_content():
|
||||||
|
"""Full transcript: only '>' lines are touched."""
|
||||||
|
content = "Assistant line\n> user line\nAnother assistant line"
|
||||||
|
with patch("mempalace.spellcheck.spellcheck_user_text", return_value="fixed"):
|
||||||
|
result = spellcheck_transcript(content)
|
||||||
|
lines = result.split("\n")
|
||||||
|
assert lines[0] == "Assistant line"
|
||||||
|
assert "fixed" in lines[1]
|
||||||
|
assert lines[2] == "Another assistant line"
|
||||||
@@ -0,0 +1,72 @@
|
|||||||
|
"""Extra spellcheck tests covering _load_known_names and speller edge cases."""
|
||||||
|
|
||||||
|
from unittest.mock import patch, MagicMock
|
||||||
|
|
||||||
|
from mempalace.spellcheck import (
|
||||||
|
_load_known_names,
|
||||||
|
spellcheck_user_text,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestLoadKnownNames:
|
||||||
|
def test_returns_names_from_registry(self):
|
||||||
|
mock_reg = MagicMock()
|
||||||
|
mock_reg._data = {
|
||||||
|
"entities": {
|
||||||
|
"e1": {"canonical": "Alice", "aliases": ["ali"]},
|
||||||
|
"e2": {"canonical": "Bob", "aliases": []},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
with patch("mempalace.entity_registry.EntityRegistry") as MockER:
|
||||||
|
MockER.load.return_value = mock_reg
|
||||||
|
names = _load_known_names()
|
||||||
|
assert "alice" in names
|
||||||
|
assert "ali" in names
|
||||||
|
assert "bob" in names
|
||||||
|
|
||||||
|
def test_returns_empty_on_exception(self):
|
||||||
|
with patch(
|
||||||
|
"mempalace.entity_registry.EntityRegistry.load",
|
||||||
|
side_effect=Exception("no registry"),
|
||||||
|
):
|
||||||
|
names = _load_known_names()
|
||||||
|
assert names == set()
|
||||||
|
|
||||||
|
|
||||||
|
class TestSpellerEdgeCases:
|
||||||
|
def test_capitalized_word_skipped(self):
|
||||||
|
"""Capitalized words (likely proper nouns) are not corrected."""
|
||||||
|
|
||||||
|
def fake_speller(word):
|
||||||
|
return "WRONG"
|
||||||
|
|
||||||
|
with patch("mempalace.spellcheck._get_speller", return_value=fake_speller):
|
||||||
|
with patch("mempalace.spellcheck._get_system_words", return_value=set()):
|
||||||
|
with patch("mempalace.spellcheck._load_known_names", return_value=set()):
|
||||||
|
result = spellcheck_user_text("Alice went home")
|
||||||
|
assert "Alice" in result
|
||||||
|
assert "WRONG" not in result
|
||||||
|
|
||||||
|
def test_system_word_not_corrected(self):
|
||||||
|
"""Words in system dict should not be corrected."""
|
||||||
|
|
||||||
|
def fake_speller(word):
|
||||||
|
return "WRONG"
|
||||||
|
|
||||||
|
with patch("mempalace.spellcheck._get_speller", return_value=fake_speller):
|
||||||
|
with patch("mempalace.spellcheck._get_system_words", return_value={"coherently"}):
|
||||||
|
with patch("mempalace.spellcheck._load_known_names", return_value=set()):
|
||||||
|
result = spellcheck_user_text("coherently")
|
||||||
|
assert "coherently" in result
|
||||||
|
|
||||||
|
def test_high_edit_distance_rejected(self):
|
||||||
|
"""Corrections with too many edits are rejected."""
|
||||||
|
|
||||||
|
def fake_speller(word):
|
||||||
|
return "completely_different_word"
|
||||||
|
|
||||||
|
with patch("mempalace.spellcheck._get_speller", return_value=fake_speller):
|
||||||
|
with patch("mempalace.spellcheck._get_system_words", return_value=set()):
|
||||||
|
with patch("mempalace.spellcheck._load_known_names", return_value=set()):
|
||||||
|
result = spellcheck_user_text("hello")
|
||||||
|
assert "hello" in result
|
||||||
@@ -3,6 +3,9 @@ import json
|
|||||||
from mempalace import split_mega_files as smf
|
from mempalace import split_mega_files as smf
|
||||||
|
|
||||||
|
|
||||||
|
# ── Config loading ─────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
def test_load_known_people_falls_back_when_config_missing(monkeypatch, tmp_path):
|
def test_load_known_people_falls_back_when_config_missing(monkeypatch, tmp_path):
|
||||||
monkeypatch.setattr(smf, "_KNOWN_NAMES_PATH", tmp_path / "missing.json")
|
monkeypatch.setattr(smf, "_KNOWN_NAMES_PATH", tmp_path / "missing.json")
|
||||||
smf._KNOWN_NAMES_CACHE = None
|
smf._KNOWN_NAMES_CACHE = None
|
||||||
@@ -46,3 +49,244 @@ def test_extract_people_detects_names_from_content(monkeypatch):
|
|||||||
monkeypatch.setattr(smf, "KNOWN_PEOPLE", ["Alice", "Ben"])
|
monkeypatch.setattr(smf, "KNOWN_PEOPLE", ["Alice", "Ben"])
|
||||||
people = smf.extract_people(["> Alice reviewed the change with Ben\n"])
|
people = smf.extract_people(["> Alice reviewed the change with Ben\n"])
|
||||||
assert people == ["Alice", "Ben"]
|
assert people == ["Alice", "Ben"]
|
||||||
|
|
||||||
|
|
||||||
|
# ── Config: force_reload and invalid JSON ──────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_load_known_names_force_reload(monkeypatch, tmp_path):
|
||||||
|
config_path = tmp_path / "known_names.json"
|
||||||
|
config_path.write_text(json.dumps(["Alice"]))
|
||||||
|
monkeypatch.setattr(smf, "_KNOWN_NAMES_PATH", config_path)
|
||||||
|
smf._KNOWN_NAMES_CACHE = None
|
||||||
|
|
||||||
|
smf._load_known_names_config()
|
||||||
|
assert smf._KNOWN_NAMES_CACHE == ["Alice"]
|
||||||
|
|
||||||
|
config_path.write_text(json.dumps(["Bob"]))
|
||||||
|
smf._load_known_names_config(force_reload=True)
|
||||||
|
assert smf._KNOWN_NAMES_CACHE == ["Bob"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_load_known_names_invalid_json(monkeypatch, tmp_path):
|
||||||
|
config_path = tmp_path / "known_names.json"
|
||||||
|
config_path.write_text("not json {{{")
|
||||||
|
monkeypatch.setattr(smf, "_KNOWN_NAMES_PATH", config_path)
|
||||||
|
smf._KNOWN_NAMES_CACHE = None
|
||||||
|
|
||||||
|
result = smf._load_known_names_config()
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_load_known_names_caching(monkeypatch, tmp_path):
|
||||||
|
config_path = tmp_path / "known_names.json"
|
||||||
|
config_path.write_text(json.dumps(["Alice"]))
|
||||||
|
monkeypatch.setattr(smf, "_KNOWN_NAMES_PATH", config_path)
|
||||||
|
smf._KNOWN_NAMES_CACHE = None
|
||||||
|
|
||||||
|
smf._load_known_names_config()
|
||||||
|
# Second call returns cached value without re-reading
|
||||||
|
config_path.write_text(json.dumps(["Changed"]))
|
||||||
|
result = smf._load_known_names_config()
|
||||||
|
assert result == ["Alice"]
|
||||||
|
|
||||||
|
|
||||||
|
# ── is_true_session_start ──────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_is_true_session_start_yes():
|
||||||
|
lines = ["Claude Code v1.0", "Some content", "More content", "", "", ""]
|
||||||
|
assert smf.is_true_session_start(lines, 0) is True
|
||||||
|
|
||||||
|
|
||||||
|
def test_is_true_session_start_no_ctrl_e():
|
||||||
|
lines = [
|
||||||
|
"Claude Code v1.0",
|
||||||
|
"Ctrl+E to show 5 previous messages",
|
||||||
|
"",
|
||||||
|
"",
|
||||||
|
"",
|
||||||
|
"",
|
||||||
|
]
|
||||||
|
assert smf.is_true_session_start(lines, 0) is False
|
||||||
|
|
||||||
|
|
||||||
|
def test_is_true_session_start_no_previous_messages():
|
||||||
|
lines = [
|
||||||
|
"Claude Code v1.0",
|
||||||
|
"Some text",
|
||||||
|
"previous messages here",
|
||||||
|
"",
|
||||||
|
"",
|
||||||
|
"",
|
||||||
|
]
|
||||||
|
assert smf.is_true_session_start(lines, 0) is False
|
||||||
|
|
||||||
|
|
||||||
|
# ── find_session_boundaries ────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_find_session_boundaries_two_sessions():
|
||||||
|
lines = [
|
||||||
|
"Claude Code v1.0",
|
||||||
|
"content 1",
|
||||||
|
"",
|
||||||
|
"",
|
||||||
|
"",
|
||||||
|
"",
|
||||||
|
"",
|
||||||
|
"Claude Code v1.0",
|
||||||
|
"content 2",
|
||||||
|
"",
|
||||||
|
"",
|
||||||
|
"",
|
||||||
|
"",
|
||||||
|
"",
|
||||||
|
]
|
||||||
|
boundaries = smf.find_session_boundaries(lines)
|
||||||
|
assert boundaries == [0, 7]
|
||||||
|
|
||||||
|
|
||||||
|
def test_find_session_boundaries_none():
|
||||||
|
lines = ["Just some text", "No sessions here"]
|
||||||
|
assert smf.find_session_boundaries(lines) == []
|
||||||
|
|
||||||
|
|
||||||
|
def test_find_session_boundaries_context_restore_skipped():
|
||||||
|
lines = [
|
||||||
|
"Claude Code v1.0",
|
||||||
|
"content",
|
||||||
|
"",
|
||||||
|
"",
|
||||||
|
"",
|
||||||
|
"",
|
||||||
|
"",
|
||||||
|
"Claude Code v1.0",
|
||||||
|
"Ctrl+E to show 5 previous messages",
|
||||||
|
"",
|
||||||
|
"",
|
||||||
|
"",
|
||||||
|
"",
|
||||||
|
]
|
||||||
|
boundaries = smf.find_session_boundaries(lines)
|
||||||
|
assert len(boundaries) == 1
|
||||||
|
|
||||||
|
|
||||||
|
# ── extract_timestamp ──────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_extract_timestamp_found():
|
||||||
|
lines = ["⏺ 2:30 PM Wednesday, March 25, 2026"]
|
||||||
|
human, iso = smf.extract_timestamp(lines)
|
||||||
|
assert human == "2026-03-25_230PM"
|
||||||
|
assert iso == "2026-03-25"
|
||||||
|
|
||||||
|
|
||||||
|
def test_extract_timestamp_not_found():
|
||||||
|
lines = ["No timestamp here"]
|
||||||
|
human, iso = smf.extract_timestamp(lines)
|
||||||
|
assert human is None
|
||||||
|
assert iso is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_extract_timestamp_only_checks_first_50():
|
||||||
|
lines = ["filler\n"] * 51 + ["⏺ 1:00 AM Monday, January 01, 2026"]
|
||||||
|
human, iso = smf.extract_timestamp(lines)
|
||||||
|
assert human is None
|
||||||
|
|
||||||
|
|
||||||
|
# ── extract_subject ────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_extract_subject_found():
|
||||||
|
lines = ["> How do we handle authentication?"]
|
||||||
|
subject = smf.extract_subject(lines)
|
||||||
|
assert "authentication" in subject.lower()
|
||||||
|
|
||||||
|
|
||||||
|
def test_extract_subject_skips_commands():
|
||||||
|
lines = ["> cd /some/dir", "> git status", "> What is the plan?"]
|
||||||
|
subject = smf.extract_subject(lines)
|
||||||
|
assert "plan" in subject.lower()
|
||||||
|
|
||||||
|
|
||||||
|
def test_extract_subject_fallback():
|
||||||
|
lines = ["No prompts at all", "Just text"]
|
||||||
|
subject = smf.extract_subject(lines)
|
||||||
|
assert subject == "session"
|
||||||
|
|
||||||
|
|
||||||
|
def test_extract_subject_short_prompt_skipped():
|
||||||
|
lines = ["> ok", "> yes", "> What about the deployment strategy?"]
|
||||||
|
subject = smf.extract_subject(lines)
|
||||||
|
assert "deployment" in subject.lower()
|
||||||
|
|
||||||
|
|
||||||
|
def test_extract_subject_truncated():
|
||||||
|
lines = ["> " + "a" * 100]
|
||||||
|
subject = smf.extract_subject(lines)
|
||||||
|
assert len(subject) <= 60
|
||||||
|
|
||||||
|
|
||||||
|
# ── split_file ─────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def _make_mega_file(tmp_path, n_sessions=3, lines_per_session=15):
|
||||||
|
"""Create a mega-file with N sessions."""
|
||||||
|
content = ""
|
||||||
|
for i in range(n_sessions):
|
||||||
|
content += f"Claude Code v1.{i}\n"
|
||||||
|
content += f"> What about topic {i} and how it works?\n"
|
||||||
|
for j in range(lines_per_session - 2):
|
||||||
|
content += f"Line {j} of session {i}\n"
|
||||||
|
path = tmp_path / "mega.txt"
|
||||||
|
path.write_text(content)
|
||||||
|
return path
|
||||||
|
|
||||||
|
|
||||||
|
def test_split_file_creates_output(tmp_path):
|
||||||
|
mega = _make_mega_file(tmp_path)
|
||||||
|
out_dir = tmp_path / "output"
|
||||||
|
out_dir.mkdir()
|
||||||
|
written = smf.split_file(str(mega), str(out_dir))
|
||||||
|
assert len(written) >= 2
|
||||||
|
for p in written:
|
||||||
|
assert p.exists()
|
||||||
|
|
||||||
|
|
||||||
|
def test_split_file_dry_run(tmp_path):
|
||||||
|
mega = _make_mega_file(tmp_path)
|
||||||
|
out_dir = tmp_path / "output"
|
||||||
|
out_dir.mkdir()
|
||||||
|
written = smf.split_file(str(mega), str(out_dir), dry_run=True)
|
||||||
|
assert len(written) >= 2
|
||||||
|
for p in written:
|
||||||
|
assert not p.exists()
|
||||||
|
|
||||||
|
|
||||||
|
def test_split_file_not_mega(tmp_path):
|
||||||
|
"""File with fewer than 2 sessions is not split."""
|
||||||
|
path = tmp_path / "single.txt"
|
||||||
|
path.write_text("Claude Code v1.0\nJust one session\n" + "line\n" * 20)
|
||||||
|
written = smf.split_file(str(path), str(tmp_path))
|
||||||
|
assert written == []
|
||||||
|
|
||||||
|
|
||||||
|
def test_split_file_output_dir_none(tmp_path):
|
||||||
|
"""When output_dir is None, writes to same dir as source."""
|
||||||
|
mega = _make_mega_file(tmp_path)
|
||||||
|
written = smf.split_file(str(mega), None)
|
||||||
|
assert len(written) >= 2
|
||||||
|
for p in written:
|
||||||
|
assert str(p.parent) == str(tmp_path)
|
||||||
|
|
||||||
|
|
||||||
|
def test_split_file_tiny_fragments_skipped(tmp_path):
|
||||||
|
"""Tiny chunks (< 10 lines) are skipped."""
|
||||||
|
content = "Claude Code v1.0\nline\n" * 2 + "Claude Code v1.0\n" + "line\n" * 20
|
||||||
|
path = tmp_path / "tiny.txt"
|
||||||
|
path.write_text(content)
|
||||||
|
written = smf.split_file(str(path), str(tmp_path))
|
||||||
|
# The first chunk is very small, should be skipped
|
||||||
|
for p in written:
|
||||||
|
assert p.stat().st_size > 0
|
||||||
|
|||||||
Reference in New Issue
Block a user