Merge branch 'main' into fix/query-sanitizer-prompt-contamination

This commit is contained in:
Ben Sigman
2026-04-09 08:11:39 -07:00
committed by GitHub
79 changed files with 9292 additions and 123 deletions
+20
View File
@@ -0,0 +1,20 @@
{
"name": "mempalace",
"interface": {
"displayName": "MemPalace"
},
"plugins": [
{
"name": "mempalace",
"source": {
"source": "local",
"path": "./.codex-plugin"
},
"policy": {
"installation": "AVAILABLE",
"authentication": "NONE"
},
"category": "Coding"
}
]
}
+9
View File
@@ -0,0 +1,9 @@
{
"mempalace": {
"command": "python3",
"args": [
"-m",
"mempalace.mcp_server"
]
}
}
+57
View File
@@ -0,0 +1,57 @@
# MemPalace Claude Code Plugin
A Claude Code plugin that gives your AI a persistent memory system. Mine projects and conversations into a searchable palace backed by ChromaDB, with 19 MCP tools, auto-save hooks, and 5 guided skills.
## Prerequisites
- Python 3.9+
## Installation
### Claude Code Marketplace
```bash
claude plugin marketplace add milla-jovovich/mempalace
claude plugin install --scope user mempalace
```
### Local Clone
```bash
claude plugin add /path/to/mempalace
```
## Post-Install Setup
After installing the plugin, run the init command to complete setup (pip install, MCP configuration, etc.):
```
/mempalace:init
```
## Available Slash Commands
| Command | Description |
|---------|-------------|
| `/mempalace:help` | Show available tools, skills, and architecture |
| `/mempalace:init` | Set up MemPalace -- install, configure MCP, onboard |
| `/mempalace:search` | Search your memories across the palace |
| `/mempalace:mine` | Mine projects and conversations into the palace |
| `/mempalace:status` | Show palace overview -- wings, rooms, drawer counts |
## Hooks
MemPalace registers two hooks that run automatically:
- **Stop** -- Saves conversation context every 15 messages.
- **PreCompact** -- Preserves important memories before context compaction.
Set the `MEMPAL_DIR` environment variable to a directory path to automatically run `mempalace mine` on that directory during each save trigger.
## MCP Server
The plugin automatically configures a local MCP server with 19 tools for storing, searching, and managing memories. No manual MCP setup is required -- `/mempalace:init` handles everything.
## Full Documentation
See the main [README](../README.md) for complete documentation, architecture details, and advanced usage.
+6
View File
@@ -0,0 +1,6 @@
---
description: Show comprehensive MemPalace help — available skills, MCP tools, CLI commands, hooks, and architecture.
allowed-tools: Bash, Read
---
Invoke the generic mempalace skill (using the Skill tool) with the `help` command, then follow its instructions.
+6
View File
@@ -0,0 +1,6 @@
---
description: Set up MemPalace — install the package, initialize a palace, configure MCP server, and verify everything works.
allowed-tools: Bash, Read, Write, Edit, Glob, Grep
---
Invoke the generic mempalace skill (using the Skill tool) with the `init` command, then follow its instructions.
+7
View File
@@ -0,0 +1,7 @@
---
description: Mine projects and conversations into the MemPalace. Supports project files, conversation exports, and auto-classification.
argument-hint: Path to project or conversation export to mine.
allowed-tools: Bash, Read, Write, Edit, Glob, Grep
---
Invoke the generic mempalace skill (using the Skill tool) with the `mine` command, then follow its instructions.
+7
View File
@@ -0,0 +1,7 @@
---
description: Search your memories across the MemPalace using semantic search with wing/room filtering.
argument-hint: Search query, optionally with wing/room filters.
allowed-tools: Bash, Read
---
Invoke the generic mempalace skill (using the Skill tool) with the `search` command, then follow its instructions.
+6
View File
@@ -0,0 +1,6 @@
---
description: Show the current state of your memory palace — wings, rooms, drawer counts, and suggestions.
allowed-tools: Bash, Read
---
Invoke the generic mempalace skill (using the Skill tool) with the `status` command, then follow its instructions.
+25
View File
@@ -0,0 +1,25 @@
{
"description": "MemPalace auto-save and pre-compact hooks",
"hooks": {
"Stop": [
{
"hooks": [
{
"type": "command",
"command": "bash ${CLAUDE_PLUGIN_ROOT}/hooks/mempal-stop-hook.sh"
}
]
}
],
"PreCompact": [
{
"hooks": [
{
"type": "command",
"command": "bash ${CLAUDE_PLUGIN_ROOT}/hooks/mempal-precompact-hook.sh"
}
]
}
]
}
}
@@ -0,0 +1,5 @@
#!/bin/bash
# MemPalace PreCompact Hook.
# Thin wrapper only: buffer the payload the harness sends on stdin, then pipe
# it to the Python CLI. The hook logic itself lives in mempalace.hooks_cli so
# that other harnesses can reuse it.
PAYLOAD="$(cat)"
printf '%s\n' "$PAYLOAD" | python3 -m mempalace hook run --hook precompact --harness claude-code
+5
View File
@@ -0,0 +1,5 @@
#!/bin/bash
# MemPalace Stop Hook.
# Thin wrapper only: buffer the payload the harness sends on stdin, then pipe
# it to the Python CLI. The hook logic itself lives in mempalace.hooks_cli so
# that other harnesses can reuse it.
PAYLOAD="$(cat)"
printf '%s\n' "$PAYLOAD" | python3 -m mempalace hook run --hook stop --harness claude-code
+18
View File
@@ -0,0 +1,18 @@
{
"name": "mempalace",
"owner": {
"name": "milla-jovovich",
"url": "https://github.com/milla-jovovich"
},
"plugins": [
{
"name": "mempalace",
"source": "./.claude-plugin",
"description": "AI memory system — mine projects and conversations into a searchable palace. 19 MCP tools, auto-save hooks, guided setup.",
"version": "3.0.14",
"author": {
"name": "milla-jovovich"
}
}
]
}
+29
View File
@@ -0,0 +1,29 @@
{
"name": "mempalace",
"version": "3.0.14",
"description": "Give your AI a memory — mine projects and conversations into a searchable palace. 19 MCP tools, auto-save hooks, and guided setup.",
"author": {
"name": "milla-jovovich"
},
"license": "MIT",
"commands": [],
"mcpServers": {
"mempalace": {
"command": "python3",
"args": [
"-m",
"mempalace.mcp_server"
]
}
},
"keywords": [
"memory",
"ai",
"rag",
"mcp",
"chromadb",
"palace",
"search"
],
"repository": "https://github.com/milla-jovovich/mempalace"
}
+35
View File
@@ -0,0 +1,35 @@
---
name: mempalace
description: MemPalace — mine projects and conversations into a searchable memory palace. Use when asked about mempalace, memory palace, mining memories, searching memories, or palace setup.
allowed-tools: Bash, Read, Write, Edit, Glob, Grep
---
# MemPalace
A searchable memory palace for AI — mine projects and conversations, then search them semantically.
## Prerequisites
Ensure `mempalace` is installed:
```bash
mempalace --version
```
If not installed:
```bash
pip install mempalace
```
## Usage
MemPalace provides dynamic instructions via the CLI. To get instructions for any operation:
```bash
mempalace instructions <command>
```
Where `<command>` is one of: `help`, `init`, `mine`, `search`, `status`.
Run the appropriate instructions command, then follow the returned instructions step by step.
+75
View File
@@ -0,0 +1,75 @@
# MemPalace - Codex CLI Plugin
Give your AI a persistent memory -- mine projects and conversations into a searchable palace backed by ChromaDB, with 19 MCP tools, auto-save hooks, and guided skills.
## Prerequisites
- Python 3.9+
- Codex CLI installed and configured
- `pip install mempalace`
## Installation
### Local Install
1. Copy or symlink the `.codex-plugin` directory into your project root:
```bash
cp -r .codex-plugin /path/to/your/project/.codex-plugin
```
2. Verify the plugin is detected:
```bash
codex --plugins
```
3. Initialize your palace:
```bash
codex /init
```
### Git Install
1. Clone the MemPalace repository:
```bash
git clone https://github.com/milla-jovovich/mempalace.git
cd mempalace
```
2. Install the Python package:
```bash
pip install -e .
```
3. The `.codex-plugin` directory is already in the repo root. Codex CLI will detect it automatically when you run Codex from inside the repository.
4. Initialize your palace:
```bash
codex /init
```
## Available Skills
| Skill | Description |
|-------|-------------|
| `/help` | Show available commands and usage tips |
| `/init` | Initialize a new memory palace |
| `/search` | Semantic search across all mined memories |
| `/mine` | Mine a project or conversation into your palace |
| `/status` | Show palace status, room counts, and health |
## Hooks
The plugin includes auto-save hooks that run on session stop (every 15 messages) and before context compaction, automatically preserving conversation context into your palace.
Set the `MEMPAL_DIR` environment variable to a directory path to automatically run `mempalace mine` on that directory during each save trigger.
## Support
- Repository: https://github.com/milla-jovovich/mempalace
- Issues: https://github.com/milla-jovovich/mempalace/issues
+37
View File
@@ -0,0 +1,37 @@
{
"hooks": {
"SessionStart": [
{
"matcher": "*",
"hooks": [
{
"type": "command",
"command": "${CODEX_PLUGIN_ROOT}/hooks/mempal-hook.sh session-start"
}
]
}
],
"Stop": [
{
"matcher": "*",
"hooks": [
{
"type": "command",
"command": "${CODEX_PLUGIN_ROOT}/hooks/mempal-hook.sh stop"
}
]
}
],
"PreCompact": [
{
"matcher": "*",
"hooks": [
{
"type": "command",
"command": "${CODEX_PLUGIN_ROOT}/hooks/mempal-hook.sh precompact"
}
]
}
]
}
}
+9
View File
@@ -0,0 +1,9 @@
#!/usr/bin/env bash
# MemPalace Codex hook wrapper.
# Usage: mempal-hook.sh <hook-name>
# Spools the JSON payload from stdin to a temp file and replays it into the
# Python CLI, then exits with the CLI's exit status.
set -euo pipefail

HOOK_NAME="${1:?Usage: mempal-hook.sh <hook-name>}"

INPUT_FILE=$(mktemp) || { echo "Failed to create temp file" >&2; exit 1; }
# Remove the temp file on every exit path. With `set -e` active, a failing
# command aborts the script on the spot, so cleanup written after the python
# call would be skipped exactly when the call fails.
trap 'rm -f "$INPUT_FILE" 2>/dev/null' EXIT

cat > "$INPUT_FILE"

# `|| EXIT_CODE=$?` keeps errexit from killing the script before the status
# of the Python process has been recorded.
EXIT_CODE=0
python3 -m mempalace hook run --hook "$HOOK_NAME" --harness codex < "$INPUT_FILE" || EXIT_CODE=$?

exit "$EXIT_CODE"
+52
View File
@@ -0,0 +1,52 @@
{
"name": "mempalace",
"version": "3.0.14",
"description": "Give your AI a memory — mine projects and conversations into a searchable palace. 19 MCP tools, auto-save hooks, and guided setup.",
"author": {
"name": "milla-jovovich"
},
"homepage": "https://github.com/milla-jovovich/mempalace",
"repository": "https://github.com/milla-jovovich/mempalace",
"license": "MIT",
"keywords": [
"memory",
"ai",
"rag",
"mcp",
"chromadb",
"palace",
"search"
],
"skills": "./skills/",
"hooks": "./hooks.json",
"mcpServers": {
"mempalace": {
"command": "python3",
"args": [
"-m",
"mempalace.mcp_server"
]
}
},
"interface": {
"displayName": "MemPalace",
"shortDescription": "AI memory system for Codex",
"longDescription": "Give your AI a persistent memory — mine projects and conversations into a searchable palace backed by ChromaDB, with 19 MCP tools, auto-save hooks, and guided skills.",
"developerName": "milla-jovovich",
"category": "Coding",
"capabilities": [
"Interactive",
"Read",
"Write"
],
"websiteURL": "https://github.com/milla-jovovich/mempalace",
"privacyPolicyURL": "https://github.com/milla-jovovich/mempalace",
"termsOfServiceURL": "https://github.com/milla-jovovich/mempalace",
"defaultPrompt": [
"Search my memories for recent decisions",
"Mine this project into my memory palace",
"Show my palace status and room counts"
],
"brandColor": "#7C3AED"
}
}
+13
View File
@@ -0,0 +1,13 @@
---
name: help
description: Show MemPalace help — available commands, usage tips, and getting started guidance.
allowed-tools: Bash, Read
---
# MemPalace Help
Run the following command and follow the returned instructions step by step:
```bash
mempalace instructions help
```
+13
View File
@@ -0,0 +1,13 @@
---
name: init
description: Initialize a new MemPalace — guided setup for your AI memory palace with ChromaDB backend.
allowed-tools: Bash, Read, Write, Edit
---
# MemPalace Init
Run the following command and follow the returned instructions step by step:
```bash
mempalace instructions init
```
+13
View File
@@ -0,0 +1,13 @@
---
name: mine
description: Mine a project or conversation into your MemPalace — extract and store memories for later retrieval.
allowed-tools: Bash, Read, Glob, Grep
---
# MemPalace Mine
Run the following command and follow the returned instructions step by step:
```bash
mempalace instructions mine
```
+13
View File
@@ -0,0 +1,13 @@
---
name: search
description: Search your MemPalace — semantic search across all mined memories, projects, and conversations.
allowed-tools: Bash, Read
---
# MemPalace Search
Run the following command and follow the returned instructions step by step:
```bash
mempalace instructions search
```
+13
View File
@@ -0,0 +1,13 @@
---
name: status
description: Show MemPalace status — room counts, storage usage, and palace health.
allowed-tools: Bash, Read
---
# MemPalace Status
Run the following command and follow the returned instructions step by step:
```bash
mempalace instructions status
```
+51
View File
@@ -0,0 +1,51 @@
name: Bump Version
on:
push:
branches: [main]
jobs:
bump-version:
runs-on: ubuntu-latest
permissions:
contents: write
steps:
- uses: actions/checkout@v6
- name: Bump patch version
run: |
CURRENT=$(python3 -c "exec(open('mempalace/version.py').read()); print(__version__)")
IFS='.' read -r MAJOR MINOR PATCH <<< "$CURRENT"
PATCH=$((PATCH + 1))
NEW="${MAJOR}.${MINOR}.${PATCH}"
echo "__version__ = \"${NEW}\"" > mempalace/version.py
# Prepend docstring
sed -i '1i"""Single source of truth for the MemPalace package version."""\n' mempalace/version.py
echo "version=$NEW" >> "$GITHUB_OUTPUT"
id: version
- name: Sync plugin.json
run: |
jq --arg v "${{ steps.version.outputs.version }}" '.version = $v' .claude-plugin/plugin.json > tmp.json && mv tmp.json .claude-plugin/plugin.json
- name: Sync marketplace.json
run: |
jq --arg v "${{ steps.version.outputs.version }}" '.plugins[0].version = $v' .claude-plugin/marketplace.json > tmp.json && mv tmp.json .claude-plugin/marketplace.json
- name: Sync codex plugin.json
run: |
jq --arg v "${{ steps.version.outputs.version }}" '.version = $v' .codex-plugin/plugin.json > tmp.json && mv tmp.json .codex-plugin/plugin.json
- name: Sync pyproject.toml
run: |
sed -i "s/^version = \".*\"/version = \"${{ steps.version.outputs.version }}\"/" pyproject.toml
- name: Commit and push
run: |
git config user.name "github-actions[bot]"
git config user.email "github-actions[bot]@users.noreply.github.com"
git add mempalace/version.py .claude-plugin/plugin.json .claude-plugin/marketplace.json .codex-plugin/plugin.json pyproject.toml
if ! git diff --staged --quiet; then
git commit -m "chore: bump version to ${{ steps.version.outputs.version }}"
git push
fi
+22 -3
View File
@@ -7,7 +7,7 @@ on:
branches: [main] branches: [main]
jobs: jobs:
test: test-linux:
runs-on: ubuntu-latest runs-on: ubuntu-latest
strategy: strategy:
matrix: matrix:
@@ -18,8 +18,27 @@ jobs:
with: with:
python-version: ${{ matrix.python-version }} python-version: ${{ matrix.python-version }}
- run: pip install -e ".[dev]" - run: pip install -e ".[dev]"
- run: python -m pytest tests/ -v - run: python -m pytest tests/ -v --ignore=tests/benchmarks --cov=mempalace --cov-report=term-missing --cov-fail-under=85
test-windows:
runs-on: windows-latest
steps:
- uses: actions/checkout@v6
- uses: actions/setup-python@v6
with:
python-version: "3.9"
- run: pip install -e ".[dev]"
- run: python -m pytest tests/ -v --ignore=tests/benchmarks --cov=mempalace --cov-report=term-missing --cov-fail-under=85
test-macos:
runs-on: macos-latest
steps:
- uses: actions/checkout@v6
- uses: actions/setup-python@v6
with:
python-version: "3.9"
- run: pip install -e ".[dev]"
- run: python -m pytest tests/ -v --ignore=tests/benchmarks --cov=mempalace --cov-report=term-missing --cov-fail-under=85
lint: lint:
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
@@ -27,6 +46,6 @@ jobs:
- uses: actions/setup-python@v6 - uses: actions/setup-python@v6
with: with:
python-version: "3.11" python-version: "3.11"
- run: pip install ruff - run: pip install "ruff>=0.4.0,<0.5"
- run: ruff check . - run: ruff check .
- run: ruff format --check . - run: ruff format --check .
+1
View File
@@ -5,3 +5,4 @@ __pycache__/
*.pyc *.pyc
.pytest_cache/ .pytest_cache/
mempal.yaml mempal.yaml
.a5c/
+19 -1
View File
@@ -29,7 +29,7 @@ Other memory systems try to fix this by letting AI decide what's worth rememberi
<br> <br>
[Quick Start](#quick-start) · [The Palace](#the-palace) · [AAAK Dialect](#aaak-compression) · [Benchmarks](#benchmarks) · [MCP Tools](#mcp-server) [Quick Start](#quick-start) · [The Palace](#the-palace) · [AAAK Dialect](#aaak-dialect-experimental) · [Benchmarks](#benchmarks) · [MCP Tools](#mcp-server)
<br> <br>
@@ -112,6 +112,17 @@ Three mining modes: **projects** (code and docs), **convos** (conversation expor
After the one-time setup (install → init → mine), you don't run MemPalace commands manually. Your AI uses it for you. There are two ways, depending on which AI you use. After the one-time setup (install → init → mine), you don't run MemPalace commands manually. Your AI uses it for you. There are two ways, depending on which AI you use.
### With Claude Code (recommended)
Native marketplace install:
```bash
claude plugin marketplace add milla-jovovich/mempalace
claude plugin install --scope user mempalace
```
Restart Claude Code, then type `/skills` to verify "mempalace" appears.
### With Claude, ChatGPT, Cursor, Gemini (MCP-compatible tools) ### With Claude, ChatGPT, Cursor, Gemini (MCP-compatible tools)
```bash ```bash
@@ -439,6 +450,11 @@ Letta charges $20–200/mo for agent-managed memory. MemPalace does it with a wi
## MCP Server ## MCP Server
```bash ```bash
# Via plugin (recommended)
claude plugin marketplace add milla-jovovich/mempalace
claude plugin install --scope user mempalace
# Or manually
claude mcp add mempalace -- python -m mempalace.mcp_server claude mcp add mempalace -- python -m mempalace.mcp_server
``` ```
@@ -509,6 +525,8 @@ Two hooks for Claude Code that automatically save memories during work:
} }
``` ```
**Optional auto-ingest:** Set the `MEMPAL_DIR` environment variable to a directory path and the hooks will automatically run `mempalace mine` on that directory during each save trigger (background on stop, synchronous on precompact).
--- ---
## Benchmarks ## Benchmarks
+17 -2
View File
@@ -1,6 +1,21 @@
"""MemPalace — Give your AI a memory. No API key required.""" """MemPalace — Give your AI a memory. No API key required."""
from .cli import main import logging
from .version import __version__ import os
import platform
from .cli import main # noqa: E402
from .version import __version__ # noqa: E402
# ChromaDB 0.6.x ships a Posthog telemetry client whose capture() signature is
# incompatible with the bundled posthog library, producing noisy stderr warnings
# on every client operation ("Failed to send telemetry event … capture() takes
# 1 positional argument but 3 were given"). Silence just that logger.
logging.getLogger("chromadb.telemetry.product.posthog").setLevel(logging.CRITICAL)
# ONNX Runtime's CoreML provider segfaults during vector queries on Apple Silicon.
# Force CPU execution unless the user has explicitly set a preference.
if platform.machine() == "arm64" and platform.system() == "Darwin":
os.environ.setdefault("ORT_DISABLE_COREML", "1")
__all__ = ["main", "__version__"] __all__ = ["main", "__version__"]
+60
View File
@@ -226,6 +226,20 @@ def cmd_repair(args):
print(f"\n{'=' * 55}\n") print(f"\n{'=' * 55}\n")
def cmd_hook(args):
    """Handle ``mempalace hook run``.

    Passes the parsed ``--hook`` and ``--harness`` values on to
    ``hooks_cli.run_hook``, which reads JSON from stdin and writes JSON to
    stdout. The hook module is imported here, at call time, rather than when
    the CLI module is loaded.
    """
    from . import hooks_cli

    hooks_cli.run_hook(hook_name=args.hook, harness=args.harness)
def cmd_instructions(args):
    """Handle ``mempalace instructions <name>``.

    Passes ``args.name`` on to ``instructions_cli.run_instructions``, which
    writes the skill instructions to stdout. The module is imported here, at
    call time, rather than when the CLI module is loaded.
    """
    from . import instructions_cli

    instructions_cli.run_instructions(name=args.name)
def cmd_compress(args): def cmd_compress(args):
"""Compress drawers in a wing using AAAK Dialect.""" """Compress drawers in a wing using AAAK Dialect."""
import chromadb import chromadb
@@ -451,6 +465,35 @@ def main():
help="Only split files containing at least N sessions (default: 2)", help="Only split files containing at least N sessions (default: 2)",
) )
# hook
p_hook = sub.add_parser(
"hook",
help="Run hook logic (reads JSON from stdin, outputs JSON to stdout)",
)
hook_sub = p_hook.add_subparsers(dest="hook_action")
p_hook_run = hook_sub.add_parser("run", help="Execute a hook")
p_hook_run.add_argument(
"--hook",
required=True,
choices=["session-start", "stop", "precompact"],
help="Hook name to run",
)
p_hook_run.add_argument(
"--harness",
required=True,
choices=["claude-code", "codex"],
help="Harness type (determines stdin JSON format)",
)
# instructions
p_instructions = sub.add_parser(
"instructions",
help="Output skill instructions to stdout",
)
instructions_sub = p_instructions.add_subparsers(dest="instructions_name")
for instr_name in ["init", "search", "mine", "help", "status"]:
instructions_sub.add_parser(instr_name, help=f"Output {instr_name} instructions")
# repair # repair
sub.add_parser( sub.add_parser(
"repair", "repair",
@@ -466,6 +509,23 @@ def main():
parser.print_help() parser.print_help()
return return
# Handle two-level subcommands
if args.command == "hook":
if not getattr(args, "hook_action", None):
p_hook.print_help()
return
cmd_hook(args)
return
if args.command == "instructions":
name = getattr(args, "instructions_name", None)
if not name:
p_instructions.print_help()
return
args.name = name
cmd_instructions(args)
return
dispatch = { dispatch = {
"init": cmd_init, "init": cmd_init,
"mine": cmd_mine, "mine": cmd_mine,
+1 -1
View File
@@ -309,7 +309,7 @@ class EntityRegistry:
def save(self): def save(self):
self._path.parent.mkdir(parents=True, exist_ok=True) self._path.parent.mkdir(parents=True, exist_ok=True)
self._path.write_text(json.dumps(self._data, indent=2)) self._path.write_text(json.dumps(self._data, indent=2), encoding="utf-8")
@staticmethod @staticmethod
def _empty() -> dict: def _empty() -> dict:
+226
View File
@@ -0,0 +1,226 @@
"""
Hook logic for MemPalace — Python implementation of session-start, stop, and precompact hooks.
Reads JSON from stdin, outputs JSON to stdout.
Supported hooks: session-start, stop, precompact
Supported harnesses: claude-code, codex (extensible to cursor, gemini, etc.)
"""
import json
import os
import re
import subprocess
import sys
from datetime import datetime
from pathlib import Path
SAVE_INTERVAL = 15
STATE_DIR = Path.home() / ".mempalace" / "hook_state"
STOP_BLOCK_REASON = (
"AUTO-SAVE checkpoint. Save key topics, decisions, quotes, and code "
"from this session to your memory system. Organize into appropriate "
"categories. Use verbatim quotes where possible. Continue conversation "
"after saving."
)
PRECOMPACT_BLOCK_REASON = (
"COMPACTION IMMINENT. Save ALL topics, decisions, quotes, code, and "
"important context from this session to your memory system. Be thorough "
"\u2014 after compaction, detailed context will be lost. Organize into "
"appropriate categories. Use verbatim quotes where possible. Save "
"everything, then allow compaction to proceed."
)
def _sanitize_session_id(session_id: str) -> str:
    """Reduce a session id to characters that are safe inside a file name.

    Only ASCII letters, digits, ``_`` and ``-`` are kept; everything else
    (including path separators and dots) is dropped, so the result cannot be
    used for path traversal. When nothing survives, ``"unknown"`` is returned.
    """
    safe_runs = re.findall(r"[a-zA-Z0-9_-]+", session_id)
    cleaned = "".join(safe_runs)
    if not cleaned:
        return "unknown"
    return cleaned
def _is_human_message(entry) -> bool:
    """Return True when a parsed transcript entry is a countable user message."""
    if not isinstance(entry, dict):
        return False
    msg = entry.get("message", {})
    if not isinstance(msg, dict) or msg.get("role") != "user":
        return False

    content = msg.get("content", "")
    if isinstance(content, str):
        text = content
    elif isinstance(content, list):
        # Only string ``text`` fields are joined: a block such as
        # {"type": "text", "text": null} would otherwise make str.join raise
        # TypeError and take the whole hook down.
        text = " ".join(
            block["text"]
            for block in content
            if isinstance(block, dict) and isinstance(block.get("text"), str)
        )
    else:
        text = ""
    return "<command-message>" not in text


def _count_human_messages(transcript_path: str) -> int:
    """Count human messages in a JSONL transcript, skipping command-messages.

    An entry is counted when its ``message`` is a dict with ``role == "user"``
    and its content (a string, or the joined ``text`` fields of a list of
    blocks) does not contain ``<command-message>``. Lines that are not valid
    JSON, and entries that are not JSON objects, are ignored.

    Args:
        transcript_path: Path to the transcript file; ``~`` is expanded.

    Returns:
        The number of counted messages, or 0 when the path is not a regular
        file or the file cannot be read.
    """
    path = Path(transcript_path).expanduser()
    if not path.is_file():
        return 0

    count = 0
    try:
        # errors="replace" keeps a stray undecodable byte from aborting the scan.
        with open(path, encoding="utf-8", errors="replace") as f:
            for line in f:
                try:
                    entry = json.loads(line)
                except json.JSONDecodeError:
                    continue
                if _is_human_message(entry):
                    count += 1
    except OSError:
        return 0
    return count
def _log(message: str):
    """Append a timestamped line to ``hook.log`` in the hook state directory.

    The state directory is created on demand. Each line has the form
    ``[HH:MM:SS] message``. Logging is best-effort: any ``OSError`` (directory
    not creatable, file not writable) is swallowed so that a logging problem
    never breaks the hook itself.
    """
    try:
        STATE_DIR.mkdir(parents=True, exist_ok=True)
        log_path = STATE_DIR / "hook.log"
        timestamp = datetime.now().strftime("%H:%M:%S")
        # Explicit UTF-8 so the log does not depend on the platform's locale
        # encoding; this matches the other text writes in the package.
        with open(log_path, "a", encoding="utf-8") as f:
            f.write(f"[{timestamp}] {message}\n")
    except OSError:
        pass
def _output(data: dict):
    """Write *data* to stdout as pretty-printed JSON.

    Uses a two-space indent and leaves non-ASCII characters unescaped.
    """
    rendered = json.dumps(data, ensure_ascii=False, indent=2)
    print(rendered)
def _maybe_auto_ingest():
    """Start a background ``mempalace mine`` on ``MEMPAL_DIR`` when it is set.

    Does nothing unless the ``MEMPAL_DIR`` environment variable names an
    existing directory. The miner is launched with the current interpreter
    and is not waited on; its stdout and stderr are appended to ``hook.log``
    in the hook state directory. ``OSError`` while opening the log or
    spawning the process is ignored.
    """
    target = os.environ.get("MEMPAL_DIR", "")
    if not target or not os.path.isdir(target):
        return

    miner_cmd = [sys.executable, "-m", "mempalace", "mine", target]
    try:
        with open(STATE_DIR / "hook.log", "a") as sink:
            subprocess.Popen(miner_cmd, stdout=sink, stderr=sink)
    except OSError:
        pass
SUPPORTED_HARNESSES = {"claude-code", "codex"}


def _parse_harness_input(data: dict, harness: str) -> dict:
    """Normalise the hook's stdin payload into the fields the hooks consume.

    Every harness in ``SUPPORTED_HARNESSES`` is currently read the same way.
    For any other *harness* a message is printed to stderr and the process
    exits with status 1.

    Returns:
        A dict with ``session_id`` (stringified and sanitized, default
        ``"unknown"``), ``stop_hook_active`` (passed through as-is, default
        ``False``) and ``transcript_path`` (stringified, default ``""``).
    """
    if harness in SUPPORTED_HARNESSES:
        raw_session = data.get("session_id", "unknown")
        raw_transcript = data.get("transcript_path", "")
        return {
            "session_id": _sanitize_session_id(str(raw_session)),
            "stop_hook_active": data.get("stop_hook_active", False),
            "transcript_path": str(raw_transcript),
        }

    print(f"Unknown harness: {harness}", file=sys.stderr)
    sys.exit(1)
def hook_stop(data: dict, harness: str):
    """Stop hook: block for an auto-save once enough new human messages exist.

    Counts the human messages in the session transcript and compares that
    with the count stored at the previous save (a per-session
    ``<session_id>_last_save`` file in the state directory). When at least
    ``SAVE_INTERVAL`` new messages have accumulated, the new count is stored,
    the optional background ingest is started, and a ``block`` decision
    carrying ``STOP_BLOCK_REASON`` is written to stdout. In every other case
    an empty JSON object is written, which lets the stop proceed.
    """
    fields = _parse_harness_input(data, harness)
    session_id = fields["session_id"]

    # A stop raised while a save cycle is already running must pass straight
    # through, otherwise the save instruction would re-trigger itself forever.
    if str(fields["stop_hook_active"]).lower() in ("true", "1", "yes"):
        _output({})
        return

    exchange_count = _count_human_messages(fields["transcript_path"])

    # Load the message count recorded at the previous save; a missing or
    # unreadable marker counts as "never saved".
    STATE_DIR.mkdir(parents=True, exist_ok=True)
    marker = STATE_DIR / f"{session_id}_last_save"
    last_save = 0
    if marker.is_file():
        try:
            last_save = int(marker.read_text().strip())
        except (ValueError, OSError):
            last_save = 0

    since_last = exchange_count - last_save
    _log(f"Session {session_id}: {exchange_count} exchanges, {since_last} since last save")

    if exchange_count <= 0 or since_last < SAVE_INTERVAL:
        _output({})
        return

    # Record the new save point first; failing to persist it is not fatal.
    try:
        marker.write_text(str(exchange_count), encoding="utf-8")
    except OSError:
        pass
    _log(f"TRIGGERING SAVE at exchange {exchange_count}")

    # Optional: kick off a background mine when MEMPAL_DIR is set.
    _maybe_auto_ingest()

    _output({"decision": "block", "reason": STOP_BLOCK_REASON})
def hook_session_start(data: dict, harness: str):
    """Session-start hook: log the session and make sure the state dir exists.

    Never blocks: an empty JSON object is always written to stdout.
    """
    session_id = _parse_harness_input(data, harness)["session_id"]
    _log(f"SESSION START for session {session_id}")

    # Make sure the directory for per-session state is in place.
    STATE_DIR.mkdir(parents=True, exist_ok=True)

    # Nothing to decide at session start; let the harness carry on.
    _output({})
def hook_precompact(data: dict, harness: str):
    """Precompact hook: always block with a comprehensive save instruction.

    When the ``MEMPAL_DIR`` environment variable names an existing directory,
    ``mempalace mine`` is first run on it synchronously (60 second limit,
    output appended to ``hook.log``) so that mined memories land before
    compaction. Whatever happens to that run, a ``block`` decision carrying
    ``PRECOMPACT_BLOCK_REASON`` is written to stdout.
    """
    parsed = _parse_harness_input(data, harness)
    session_id = parsed["session_id"]

    _log(f"PRE-COMPACT triggered for session {session_id}")

    # Optional: auto-ingest synchronously before compaction (so memories land first)
    mempal_dir = os.environ.get("MEMPAL_DIR", "")
    if mempal_dir and os.path.isdir(mempal_dir):
        try:
            log_path = STATE_DIR / "hook.log"
            with open(log_path, "a") as log_f:
                subprocess.run(
                    [sys.executable, "-m", "mempalace", "mine", mempal_dir],
                    stdout=log_f,
                    stderr=log_f,
                    timeout=60,
                )
        except subprocess.TimeoutExpired:
            # TimeoutExpired is not an OSError. Left uncaught, a slow mine
            # would crash the hook before the block decision is emitted.
            _log("WARNING: pre-compact auto-ingest timed out after 60s, continuing")
        except OSError:
            pass

    # Always block -- compaction = save everything
    _output({"decision": "block", "reason": PRECOMPACT_BLOCK_REASON})
def run_hook(hook_name: str, harness: str):
    """Main entry point: read stdin JSON, dispatch to hook handler.

    Args:
        hook_name: One of ``"session-start"``, ``"stop"`` or ``"precompact"``.
        harness: Harness identifier, passed through to the handler.

    Input that cannot be parsed, or that parses to something other than a
    JSON object, is logged and replaced with an empty dict so the handlers
    can always call ``.get`` on it. An unknown *hook_name* prints a message
    to stderr and exits with status 1.
    """
    try:
        data = json.load(sys.stdin)
    except (ValueError, EOFError):
        # ValueError covers json.JSONDecodeError as well as the
        # UnicodeDecodeError raised when stdin is not decodable text.
        _log("WARNING: Failed to parse stdin JSON, proceeding with empty data")
        data = {}

    if not isinstance(data, dict):
        # Valid JSON such as [] or null would crash the handlers on data.get().
        _log("WARNING: stdin JSON is not an object, proceeding with empty data")
        data = {}

    hooks = {
        "session-start": hook_session_start,
        "stop": hook_stop,
        "precompact": hook_precompact,
    }

    handler = hooks.get(hook_name)
    if handler is None:
        print(f"Unknown hook: {hook_name}", file=sys.stderr)
        sys.exit(1)

    handler(data, harness)
+105
View File
@@ -0,0 +1,105 @@
# MemPalace
AI memory system. Store everything, find anything. Local, free, no API key.
---
## Slash Commands
| Command | Description |
|----------------------|--------------------------------|
| /mempalace:init | Install and set up MemPalace |
| /mempalace:search | Search your memories |
| /mempalace:mine | Mine projects and conversations|
| /mempalace:status | Palace overview and stats |
| /mempalace:help | This help message |
---
## MCP Tools (19)
### Palace (read)
- mempalace_status -- Palace status and stats
- mempalace_list_wings -- List all wings
- mempalace_list_rooms -- List rooms in a wing
- mempalace_get_taxonomy -- Get the full taxonomy tree
- mempalace_search -- Search memories by query
- mempalace_check_duplicate -- Check if a memory already exists
- mempalace_get_aaak_spec -- Get the AAAK specification
### Palace (write)
- mempalace_add_drawer -- Add a new memory (drawer)
- mempalace_delete_drawer -- Delete a memory (drawer)
### Knowledge Graph
- mempalace_kg_query -- Query the knowledge graph
- mempalace_kg_add -- Add a knowledge graph entry
- mempalace_kg_invalidate -- Invalidate a knowledge graph entry
- mempalace_kg_timeline -- View knowledge graph timeline
- mempalace_kg_stats -- Knowledge graph statistics
### Navigation
- mempalace_traverse -- Traverse the palace structure
- mempalace_find_tunnels -- Find cross-wing connections
- mempalace_graph_stats -- Graph connectivity statistics
### Agent Diary
- mempalace_diary_write -- Write a diary entry
- mempalace_diary_read -- Read diary entries
---
## CLI Commands
mempalace init <dir> Initialize a new palace
mempalace mine <dir> Mine a project (default mode)
mempalace mine <dir> --mode convos Mine conversation exports
mempalace search "query" Search your memories
mempalace split <dir> Split large transcript files
mempalace wake-up Load palace into context
mempalace compress Compress palace storage
mempalace status Show palace status
mempalace repair Rebuild vector index
mempalace hook run Run hook logic (for harness integration)
mempalace instructions <name> Output skill instructions
---
## Auto-Save Hooks
- Stop hook -- Automatically saves memories every 15 messages. Counts human
messages in the session transcript (skipping command-messages). When the
threshold is reached, blocks the AI with a save instruction. Uses
~/.mempalace/hook_state/ to track save points per session. If
stop_hook_active is true, passes through to prevent infinite loops.
- PreCompact hook -- Emergency save before context compaction. Always blocks
with a comprehensive save instruction because compaction means the AI is
about to lose detailed context.
Hooks read JSON from stdin and output JSON to stdout. They can be invoked via:
echo '{"session_id":"abc","stop_hook_active":false,"transcript_path":"..."}' | mempalace hook run --hook stop --harness claude-code
---
## Architecture
Wings (projects/people)
+-- Rooms (topics)
+-- Closets (summaries)
+-- Drawers (verbatim memories)
Halls connect rooms within a wing.
Tunnels connect rooms across wings.
The palace is stored locally using ChromaDB for vector search and SQLite for
metadata. No cloud services or API keys required.
---
## Getting Started
1. /mempalace:init -- Set up your palace
2. /mempalace:mine -- Mine a project or conversation
3. /mempalace:search -- Find what you stored
+69
View File
@@ -0,0 +1,69 @@
# MemPalace Init
Guide the user through a complete MemPalace setup. Follow each step in order,
stopping to report errors and attempt remediation before proceeding.
## Step 1: Check Python version
Run `python3 --version` (or `python --version` on Windows) and confirm the
version is 3.9 or higher. If Python is not found or the version is too old,
tell the user they need Python 3.9+ installed and stop.
## Step 2: Check if mempalace is already installed
Run `pip show mempalace` to see if the package is already present. If it is,
report the installed version and skip to Step 4.
## Step 3: Install mempalace
Run `pip install mempalace`.
### Error handling -- pip failures
If `pip install mempalace` fails, try these fallbacks in order:
1. Try `pip3 install mempalace`
2. Try `python -m pip install mempalace` (or `python3 -m pip install mempalace`)
3. If the error mentions missing build tools or compilation failures (commonly
from chromadb or its native dependencies):
- On Linux/macOS: suggest `sudo apt-get install build-essential python3-dev`
(Debian/Ubuntu) or `xcode-select --install` (macOS)
- On Windows: suggest installing Microsoft C++ Build Tools from
https://visualstudio.microsoft.com/visual-cpp-build-tools/
- Then retry the install command
4. If all attempts fail, report the error clearly and stop.
## Step 4: Ask for project directory
Ask the user which project directory they want to initialize with MemPalace.
Offer the current working directory as the default. Wait for their response
before continuing.
## Step 5: Initialize the palace
Run `mempalace init <dir>` where `<dir>` is the directory from Step 4.
If this fails, report the error and stop.
## Step 6: Configure MCP server
Run the following command to register the MemPalace MCP server with Claude (substitute `python3` for `python` if that is the interpreter that worked in Step 1):
claude mcp add mempalace -- python -m mempalace.mcp_server
If this fails, report the error but continue to the next step (MCP
configuration can be done manually later).
## Step 7: Verify installation
Run `mempalace status` and confirm the output shows a healthy palace.
If the command fails or reports errors, walk the user through troubleshooting
based on the output.
## Step 8: Show next steps
Tell the user setup is complete and suggest these next actions:
- Use /mempalace:mine to start adding data to their palace
- Use /mempalace:search to query their palace and retrieve stored knowledge
+64
View File
@@ -0,0 +1,64 @@
# MemPalace Mine
When the user invokes this skill, follow these steps:
## 1. Ask what to mine
Ask the user what they want to mine and where the source data is located.
Clarify:
- Is it a project directory (code, docs, notes)?
- Is it conversation exports (Claude, ChatGPT, Slack)?
- Do they want auto-classification (decisions, milestones, problems)?
## 2. Choose the mining mode
There are three mining modes:
### Project mining
mempalace mine <dir>
Mines code files, documentation, and notes from a project directory.
### Conversation mining
mempalace mine <dir> --mode convos
Mines conversation exports from Claude, ChatGPT, or Slack into the palace.
### General extraction (auto-classify)
mempalace mine <dir> --mode convos --extract general
Auto-classifies mined content into decisions, milestones, and problems.
## 3. Optionally split mega-files first
If the source directory contains very large files, suggest splitting them
before mining:
mempalace split <dir> [--dry-run]
Use --dry-run first to preview what will be split without making changes.
## 4. Optionally tag with a wing
If the user wants to organize mined content under a specific wing, add the
--wing flag:
mempalace mine <dir> --wing <name>
## 5. Show progress and results
Run the selected mining command and display progress as it executes. After
completion, summarize the results including:
- Number of items mined
- Categories or classifications applied
- Any warnings or skipped files
## 6. Suggest next steps
After mining completes, suggest the user try:
- /mempalace:search -- search the newly mined content
- /mempalace:status -- check the current state of their palace
- Mine more data from additional sources
+57
View File
@@ -0,0 +1,57 @@
# MemPalace Search
When the user wants to search their MemPalace memories, follow these steps:
## 1. Parse the Search Query
Extract the core search intent from the user's message. Identify any explicit
or implicit filters:
- Wing -- a top-level category (e.g., "work", "personal", "research")
- Room -- a sub-category within a wing
- Keywords / semantic query -- the actual search terms
## 2. Determine Wing/Room Filters
If the user mentions a specific domain, topic area, or context, map it to the
appropriate wing and/or room. If unsure, omit filters to search globally. You
can discover the taxonomy first if needed.
## 3. Use MCP Tools (Preferred)
If MCP tools are available, use them in this priority order:
- mempalace_search(query, wing, room) -- Primary search tool. Pass the semantic
query and any wing/room filters.
- mempalace_list_wings -- Discover all available wings. Use when the user asks
what categories exist or you need to resolve a wing name.
- mempalace_list_rooms(wing) -- List rooms within a specific wing. Use to help
the user navigate or to resolve a room name.
- mempalace_get_taxonomy -- Retrieve the full wing/room/drawer tree. Use when
the user wants an overview of their entire memory structure.
- mempalace_traverse(room) -- Walk the knowledge graph starting from a room.
Use when the user wants to explore connections and related memories.
- mempalace_find_tunnels(wing1, wing2) -- Find cross-wing connections (tunnels)
between two wings. Use when the user asks about relationships between
different knowledge domains.
## 4. CLI Fallback
If MCP tools are not available, fall back to the CLI:
mempalace search "query" [--wing X] [--room Y]
## 5. Present Results
When presenting search results:
- Always include source attribution: wing, room, and drawer for each result
- Show relevance or similarity scores if available
- Group results by wing/room when returning multiple hits
- Quote or summarize the memory content clearly
## 6. Offer Next Steps
After presenting results, offer the user options to go deeper:
- Drill deeper -- search within a specific room or narrow the query
- Traverse -- explore the knowledge graph from a related room
- Check tunnels -- look for cross-wing connections if the topic spans domains
- Browse taxonomy -- show the full structure for manual exploration
+49
View File
@@ -0,0 +1,49 @@
# MemPalace Status
Display the current state of the user's memory palace.
## Step 1: Gather Palace Status
Check if MCP tools are available (look for mempalace_status in available tools).
- If MCP is available: Call the mempalace_status tool to retrieve palace state.
- If MCP is not available: Run the CLI command: mempalace status
## Step 2: Display Wing/Room/Drawer Counts
Present the palace structure counts clearly:
- Number of wings
- Number of rooms
- Number of drawers
- Total memories stored
Keep the output concise -- use a brief summary format, not verbose tables.
## Step 3: Knowledge Graph Stats (MCP only)
If MCP tools are available, also call:
- mempalace_kg_stats -- for a knowledge graph overview (triple count, entity
count, relationship types)
- mempalace_graph_stats -- for connectivity information (connected components,
average connections per entity)
Present these alongside the palace counts in a unified summary.
## Step 4: Suggest Next Actions
Based on the current state, suggest one relevant action:
- Empty palace (zero memories): Suggest "Try /mempalace:mine to add data from
project directories or conversation exports."
- Has data but no knowledge graph (memories exist but KG stats show zero
triples): Suggest "Consider adding knowledge graph triples for richer
queries."
- Healthy palace (has memories and KG data): Suggest "Use /mempalace:search to
query your memories."
## Output Style
- Be concise and informative -- aim for a quick glance, not a report.
- Use short labels and numbers, not prose paragraphs.
- If any step fails or a tool is unavailable, note it briefly and continue
with what is available.
+28
View File
@@ -0,0 +1,28 @@
"""
Instruction text output for MemPalace CLI commands.
Each instruction lives as a .md file in the instructions/ directory
inside the package. The CLI reads and prints the file content.
"""
import sys
from pathlib import Path
INSTRUCTIONS_DIR = Path(__file__).parent / "instructions"
AVAILABLE = ["init", "search", "mine", "help", "status"]


def run_instructions(name: str) -> None:
    """Print the instruction Markdown file registered under ``name``.

    Args:
        name: One of the entries in ``AVAILABLE``; the file printed is
            ``INSTRUCTIONS_DIR / f"{name}.md"``.

    Exits the process with status 1 (via ``sys.exit``) after writing a
    message to stderr when ``name`` is not in ``AVAILABLE`` or when the
    corresponding ``.md`` file does not exist.
    """
    if name not in AVAILABLE:
        print(f"Unknown instructions: {name}", file=sys.stderr)
        print(f"Available: {', '.join(sorted(AVAILABLE))}", file=sys.stderr)
        sys.exit(1)

    md_path = INSTRUCTIONS_DIR / f"{name}.md"
    if not md_path.is_file():
        print(f"Instructions file not found: {md_path}", file=sys.stderr)
        sys.exit(1)

    # Decode explicitly as UTF-8: without an encoding, read_text() uses the
    # locale's preferred encoding, which can garble or reject non-ASCII
    # Markdown on platforms whose default is not UTF-8.
    print(md_path.read_text(encoding="utf-8"))
+47 -17
View File
@@ -2,7 +2,7 @@
""" """
MemPalace MCP Server — read/write palace access for Claude Code MemPalace MCP Server — read/write palace access for Claude Code
================================================================ ================================================================
Install: claude mcp add mempalace -- python -m mempalace.mcp_server Install: claude mcp add mempalace -- python -m mempalace.mcp_server [--palace /path/to/palace]
Tools (read): Tools (read):
mempalace_status — total drawers, wing/room breakdown mempalace_status — total drawers, wing/room breakdown
@@ -17,6 +17,8 @@ Tools (write):
mempalace_delete_drawer — remove a drawer by ID mempalace_delete_drawer — remove a drawer by ID
""" """
import argparse
import os
import sys import sys
import json import json
import logging import logging
@@ -32,21 +34,50 @@ import chromadb
from .knowledge_graph import KnowledgeGraph from .knowledge_graph import KnowledgeGraph
_kg = KnowledgeGraph()
logging.basicConfig(level=logging.INFO, format="%(message)s", stream=sys.stderr) logging.basicConfig(level=logging.INFO, format="%(message)s", stream=sys.stderr)
logger = logging.getLogger("mempalace_mcp") logger = logging.getLogger("mempalace_mcp")
def _parse_args():
parser = argparse.ArgumentParser(description="MemPalace MCP Server")
parser.add_argument(
"--palace",
metavar="PATH",
help="Path to the palace directory (overrides config file and env var)",
)
args, unknown = parser.parse_known_args()
if unknown:
logger.debug("Ignoring unknown args: %s", unknown)
return args
_args = _parse_args()
if _args.palace:
os.environ["MEMPALACE_PALACE_PATH"] = os.path.abspath(_args.palace)
_config = MempalaceConfig() _config = MempalaceConfig()
if _args.palace:
_kg = KnowledgeGraph(db_path=os.path.join(_config.palace_path, "knowledge_graph.sqlite3"))
else:
_kg = KnowledgeGraph()
_client_cache = None
_collection_cache = None
def _get_collection(create=False): def _get_collection(create=False):
"""Return the ChromaDB collection, or None on failure.""" """Return the ChromaDB collection, caching the client between calls."""
global _client_cache, _collection_cache
try: try:
client = chromadb.PersistentClient(path=_config.palace_path) if _client_cache is None:
_client_cache = chromadb.PersistentClient(path=_config.palace_path)
if create: if create:
return client.get_or_create_collection(_config.collection_name) _collection_cache = _client_cache.get_or_create_collection(_config.collection_name)
return client.get_collection(_config.collection_name) elif _collection_cache is None:
_collection_cache = _client_cache.get_collection(_config.collection_name)
return _collection_cache
except Exception: except Exception:
return None return None
@@ -270,19 +301,18 @@ def tool_add_drawer(
if not col: if not col:
return _no_palace() return _no_palace()
# Duplicate check drawer_id = f"drawer_{wing}_{room}_{hashlib.md5(content.encode()).hexdigest()[:16]}"
dup = tool_check_duplicate(content, threshold=0.9)
if dup.get("is_duplicate"):
return {
"success": False,
"reason": "duplicate",
"matches": dup["matches"],
}
drawer_id = f"drawer_{wing}_{room}_{hashlib.md5((content[:100] + datetime.now().isoformat()).encode()).hexdigest()[:16]}" # Idempotency: if the deterministic ID already exists, return success as a no-op.
try:
existing = col.get(ids=[drawer_id])
if existing and existing["ids"]:
return {"success": True, "reason": "already_exists", "drawer_id": drawer_id}
except Exception:
pass
try: try:
col.add( col.upsert(
ids=[drawer_id], ids=[drawer_id],
documents=[content], documents=[content],
metadatas=[ metadatas=[
+38 -25
View File
@@ -403,10 +403,22 @@ def get_collection(palace_path: str):
def file_already_mined(collection, source_file: str) -> bool: def file_already_mined(collection, source_file: str) -> bool:
"""Fast check: has this file been filed before?""" """Fast check: has this file been filed before and is unchanged?
Compares the stored mtime in drawer metadata against the file's current
mtime. Returns False (needs re-mining) when the file has been modified
since it was last mined, or when no mtime was stored.
"""
try: try:
results = collection.get(where={"source_file": source_file}, limit=1) results = collection.get(where={"source_file": source_file}, limit=1)
return len(results.get("ids", [])) > 0 if not results.get("ids"):
return False
stored_meta = results["metadatas"][0] if results.get("metadatas") else {}
stored_mtime = stored_meta.get("source_mtime")
if stored_mtime is None:
return False
current_mtime = os.path.getmtime(source_file)
return float(stored_mtime) == current_mtime
except Exception: except Exception:
return False return False
@@ -417,24 +429,26 @@ def add_drawer(
"""Add one drawer to the palace.""" """Add one drawer to the palace."""
drawer_id = f"drawer_{wing}_{room}_{hashlib.md5((source_file + str(chunk_index)).encode(), usedforsecurity=False).hexdigest()[:16]}" drawer_id = f"drawer_{wing}_{room}_{hashlib.md5((source_file + str(chunk_index)).encode(), usedforsecurity=False).hexdigest()[:16]}"
try: try:
collection.add( metadata = {
"wing": wing,
"room": room,
"source_file": source_file,
"chunk_index": chunk_index,
"added_by": agent,
"filed_at": datetime.now().isoformat(),
}
# Store file mtime so we can detect modifications later.
try:
metadata["source_mtime"] = os.path.getmtime(source_file)
except OSError:
pass
collection.upsert(
documents=[content], documents=[content],
ids=[drawer_id], ids=[drawer_id],
metadatas=[ metadatas=[metadata],
{
"wing": wing,
"room": room,
"source_file": source_file,
"chunk_index": chunk_index,
"added_by": agent,
"filed_at": datetime.now().isoformat(),
}
],
) )
return True return True
except Exception as e: except Exception:
if "already exists" in str(e).lower() or "duplicate" in str(e).lower():
return False
raise raise
@@ -451,29 +465,29 @@ def process_file(
rooms: list, rooms: list,
agent: str, agent: str,
dry_run: bool, dry_run: bool,
) -> int: ) -> tuple:
"""Read, chunk, route, and file one file. Returns drawer count.""" """Read, chunk, route, and file one file. Returns (drawer_count, room_name)."""
# Skip if already filed # Skip if already filed
source_file = str(filepath) source_file = str(filepath)
if not dry_run and file_already_mined(collection, source_file): if not dry_run and file_already_mined(collection, source_file):
return 0 return 0, None
try: try:
content = filepath.read_text(encoding="utf-8", errors="replace") content = filepath.read_text(encoding="utf-8", errors="replace")
except OSError: except OSError:
return 0 return 0, None
content = content.strip() content = content.strip()
if len(content) < MIN_CHUNK_SIZE: if len(content) < MIN_CHUNK_SIZE:
return 0 return 0, None
room = detect_room(filepath, content, rooms, project_path) room = detect_room(filepath, content, rooms, project_path)
chunks = chunk_text(content, source_file) chunks = chunk_text(content, source_file)
if dry_run: if dry_run:
print(f" [DRY RUN] {filepath.name} → room:{room} ({len(chunks)} drawers)") print(f" [DRY RUN] {filepath.name} → room:{room} ({len(chunks)} drawers)")
return len(chunks) return len(chunks), room
drawers_added = 0 drawers_added = 0
for chunk in chunks: for chunk in chunks:
@@ -489,7 +503,7 @@ def process_file(
if added: if added:
drawers_added += 1 drawers_added += 1
return drawers_added return drawers_added, room
# ============================================================================= # =============================================================================
@@ -608,7 +622,7 @@ def mine(
room_counts = defaultdict(int) room_counts = defaultdict(int)
for i, filepath in enumerate(files, 1): for i, filepath in enumerate(files, 1):
drawers = process_file( drawers, room = process_file(
filepath=filepath, filepath=filepath,
project_path=project_path, project_path=project_path,
collection=collection, collection=collection,
@@ -621,7 +635,6 @@ def mine(
files_skipped += 1 files_skipped += 1
else: else:
total_drawers += drawers total_drawers += drawers
room = detect_room(filepath, "", rooms, project_path)
room_counts[room] += 1 room_counts[room] += 1
if not dry_run: if not dry_run:
print(f" ✓ [{i:4}/{len(files)}] {filepath.name[:50]:50} +{drawers}") print(f" ✓ [{i:4}/{len(files)}] {filepath.name[:50]:50} +{drawers}")
+2 -2
View File
@@ -312,7 +312,7 @@ def _generate_aaak_bootstrap(
] ]
) )
(mempalace_dir / "aaak_entities.md").write_text("\n".join(registry_lines)) (mempalace_dir / "aaak_entities.md").write_text("\n".join(registry_lines), encoding="utf-8")
# Critical facts bootstrap (pre-palace — before any mining) # Critical facts bootstrap (pre-palace — before any mining)
facts_lines = [ facts_lines = [
@@ -359,7 +359,7 @@ def _generate_aaak_bootstrap(
] ]
) )
(mempalace_dir / "critical_facts.md").write_text("\n".join(facts_lines)) (mempalace_dir / "critical_facts.md").write_text("\n".join(facts_lines), encoding="utf-8")
def run_onboarding( def run_onboarding(
+1 -1
View File
@@ -219,7 +219,7 @@ def split_file(filepath, output_dir, dry_run=False):
if dry_run: if dry_run:
print(f" [{i + 1}/{len(boundaries) - 1}] {name} ({len(chunk)} lines)") print(f" [{i + 1}/{len(boundaries) - 1}] {name} ({len(chunk)} lines)")
else: else:
out_path.write_text("".join(chunk)) out_path.write_text("".join(chunk), encoding="utf-8")
print(f"{name} ({len(chunk)} lines)") print(f"{name} ({len(chunk)} lines)")
written.append(out_path) written.append(out_path)
+1 -1
View File
@@ -1,3 +1,3 @@
"""Single source of truth for the MemPalace package version.""" """Single source of truth for the MemPalace package version."""
__version__ = "3.0.0" __version__ = "3.0.14"
+21 -3
View File
@@ -1,6 +1,6 @@
[project] [project]
name = "mempalace" name = "mempalace"
version = "3.0.0" version = "3.0.14"
description = "Give your AI a memory — mine projects and conversations into a searchable palace. No API key required." description = "Give your AI a memory — mine projects and conversations into a searchable palace. No API key required."
readme = "README.md" readme = "README.md"
requires-python = ">=3.9" requires-python = ">=3.9"
@@ -38,11 +38,11 @@ Repository = "https://github.com/milla-jovovich/mempalace"
mempalace = "mempalace:main" mempalace = "mempalace:main"
[project.optional-dependencies] [project.optional-dependencies]
dev = ["pytest>=7.0", "ruff>=0.4.0"] dev = ["pytest>=7.0", "pytest-cov>=4.0", "ruff>=0.4.0", "psutil>=5.9"]
spellcheck = ["autocorrect>=2.0"] spellcheck = ["autocorrect>=2.0"]
[dependency-groups] [dependency-groups]
dev = ["pytest>=7.0", "ruff>=0.4.0"] dev = ["pytest>=7.0", "pytest-cov>=4.0", "ruff>=0.4.0", "psutil>=5.9"]
[build-system] [build-system]
requires = ["hatchling"] requires = ["hatchling"]
@@ -64,3 +64,21 @@ quote-style = "double"
[tool.pytest.ini_options] [tool.pytest.ini_options]
testpaths = ["tests"] testpaths = ["tests"]
pythonpath = ["."]
addopts = "-m 'not benchmark and not slow and not stress'"
markers = [
"benchmark: scale/performance benchmark tests",
"slow: tests that take more than 30 seconds",
"stress: destructive scale tests (100K+ drawers)",
]
[tool.coverage.run]
source = ["mempalace"]
[tool.coverage.report]
fail_under = 85
show_missing = true
exclude_lines = [
"if __name__",
"pragma: no cover",
]
+138
View File
@@ -0,0 +1,138 @@
# MemPalace Scale Benchmark Suite
106 tests that benchmark mempalace at scale to validate real-world performance limits.
## Why
MemPalace has strong academic scores (96.6% R@5 on LongMemEval) but no empirical data on how it behaves at scale. Key unknowns:
- `tool_status()` loads ALL metadata into memory — at what palace size does this OOM?
- `PersistentClient` is re-instantiated on every MCP call — what's the overhead?
- Modified files are never re-ingested — what's the skip-check cost at scale?
- How does query latency degrade as the palace grows from 1K to 100K drawers?
- Does wing/room filtering actually improve retrieval, and by how much?
- At what per-room drawer count does recall break regardless of filtering?
This suite finds those answers.
## Quick Start
```bash
# Fast smoke test (~2 min)
uv run pytest tests/benchmarks/ -v --bench-scale=small -m "benchmark and not slow"
# Full small scale (~35 min)
uv run pytest tests/benchmarks/ -v --bench-scale=small
# Medium scale with JSON report
uv run pytest tests/benchmarks/ -v --bench-scale=medium --bench-report=results.json
# Stress test (local only, very slow)
uv run pytest tests/benchmarks/ -v --bench-scale=stress -m stress
```
## Scale Levels
| Level | Drawers | Wings | Rooms/Wing | KG Triples | Use case |
|---------|---------|-------|------------|------------|---------------------|
| small | 1,000 | 3 | 5 | 200 | CI, quick checks |
| medium | 10,000 | 8 | 12 | 2,000 | Pre-release testing |
| large | 50,000 | 15 | 20 | 10,000 | Scale limit finding |
| stress | 100,000 | 25 | 30 | 50,000 | Breaking point |
## Test Modules
### Critical Path
| File | What it tests |
|------|--------------|
| `test_mcp_bench.py` | MCP tool response times, unbounded metadata fetch, client re-instantiation overhead |
| `test_chromadb_stress.py` | ChromaDB breaking point, query degradation curve, batch vs sequential insert |
| `test_memory_profile.py` | RSS/heap growth over repeated operations, leak detection |
### Performance Baselines
| File | What it tests |
|------|--------------|
| `test_ingest_bench.py` | Mining throughput (files/sec, drawers/sec), peak RSS, chunking speed, re-ingest skip overhead |
| `test_search_bench.py` | Query latency vs palace size, recall@k with planted needles, concurrent queries, n_results scaling |
### Architectural Validation
| File | What it tests |
|------|--------------|
| `test_palace_boost.py` | Retrieval improvement from wing/room filtering at different scales |
| `test_recall_threshold.py` | Per-room recall ceiling — isolates embedding model limit with all drawers in one bucket |
| `test_knowledge_graph_bench.py` | Triple insertion rate, temporal query accuracy, SQLite concurrent access |
| `test_layers_bench.py` | MemoryStack wake-up cost, Layer1 unbounded fetch, token budget compliance |
## Architecture
```
tests/benchmarks/
conftest.py # --bench-scale / --bench-report CLI options, fixtures, markers
data_generator.py # Deterministic data factory (seeded RNG, planted needles)
report.py # JSON report writer + regression checker
test_*.py # 9 test modules (106 tests total)
```
### Data Generator
`PalaceDataGenerator(seed=42, scale="small")` produces deterministic, realistic test data:
- **`generate_project_tree()`** — writes real files + `mempalace.yaml` for `mine()` to ingest
- **`populate_palace_directly()`** — bypasses mining, inserts directly into ChromaDB (10-100x faster for search/MCP benchmarks)
- **`generate_kg_triples()`** — entity-relationship triples with temporal validity
- **`generate_search_queries()`** — queries with known-good answers for recall measurement
**Planted needles**: Unique identifiable content (e.g., `NEEDLE_0042: PostgreSQL vacuum autovacuum threshold...`) seeded into specific wings/rooms. Search queries target these needles, enabling recall@k measurement without an LLM judge.
### JSON Reports
When run with `--bench-report=path.json`, produces machine-readable output:
```json
{
"timestamp": "2026-04-07T...",
"git_sha": "abc123",
"scale": "small",
"system": {"os": "linux", "cpu_count": 8},
"results": {
"mcp_status": {"latency_ms_at_1000": 45.2, "rss_delta_mb_at_5000": 12.3},
"search": {"avg_latency_ms_at_5000": 23.1, "recall_at_5": 0.92},
"chromadb_insert": {"sequential_ms": 8500, "batched_ms": 1200, "speedup_ratio": 7.1}
}
}
```
### Regression Detection
```python
from tests.benchmarks.report import check_regression
regressions = check_regression("current.json", "baseline.json", threshold=0.2)
# Returns list of metric descriptions that degraded beyond 20%
```
## CI Integration
The GitHub Actions workflow runs benchmarks on PRs at small scale:
```yaml
benchmark:
runs-on: ubuntu-latest
if: github.event_name == 'pull_request'
# Runs: pytest tests/benchmarks/ -m "benchmark and not stress and not slow" --bench-scale=small
```
Existing unit tests are isolated with `--ignore=tests/benchmarks`.
## Markers
- `@pytest.mark.benchmark` — all benchmark tests
- `@pytest.mark.slow` — tests taking >30s even at small scale
- `@pytest.mark.stress` — tests that should only run at large/stress scale
## Dependencies
Only one new dependency beyond the existing dev stack: `psutil` (for cross-platform RSS measurement). `tracemalloc` and `resource` are stdlib.
+1
View File
@@ -0,0 +1 @@
# MemPalace scale benchmark suite
+144
View File
@@ -0,0 +1,144 @@
"""Benchmark-specific pytest configuration, fixtures, and CLI options."""
import json
import os
import tempfile
import pytest
SCALE_OPTIONS = ["small", "medium", "large", "stress"]


def pytest_addoption(parser):
    """Register the benchmark command-line options on pytest's parser.

    Adds ``--bench-scale`` (restricted to ``SCALE_OPTIONS``, default
    ``"small"``) and ``--bench-report`` (default ``None``).
    """
    option_specs = (
        (
            "--bench-scale",
            {
                "default": "small",
                "choices": SCALE_OPTIONS,
                "help": "Scale level for benchmark tests: small (1K), medium (10K), large (50K), stress (100K)",
            },
        ),
        (
            "--bench-report",
            {
                "default": None,
                "help": "Path for JSON benchmark report output",
            },
        ),
    )
    for flag, settings in option_specs:
        parser.addoption(flag, **settings)
@pytest.fixture(scope="session")
def bench_scale(request):
    """Return the value passed to ``--bench-scale`` for this session."""
    selected_scale = request.config.getoption("--bench-scale")
    return selected_scale
@pytest.fixture(scope="session")
def bench_report_path(request):
    """Return the value passed to ``--bench-report`` (``None`` when unset)."""
    report_target = request.config.getoption("--bench-report")
    return report_target
@pytest.fixture
def palace_dir(tmp_path):
    """Create a fresh ``palace`` directory under ``tmp_path`` and return it as a string."""
    palace_root = tmp_path.joinpath("palace")
    palace_root.mkdir()
    return str(palace_root)
@pytest.fixture
def kg_db(tmp_path):
    """Return a per-test SQLite path (as a string) for the knowledge graph; no file is created here."""
    db_file = tmp_path.joinpath("test_kg.sqlite3")
    return str(db_file)
@pytest.fixture
def config_dir(tmp_path):
    """Create a ``config`` directory holding a ``config.json`` and return its path as a string.

    The JSON file points ``palace_path`` at ``tmp_path / "palace"`` and sets
    ``collection_name`` to ``"mempalace_drawers"``.
    """
    cfg_root = tmp_path / "config"
    cfg_root.mkdir()
    settings = {
        "palace_path": str(tmp_path / "palace"),
        "collection_name": "mempalace_drawers",
    }
    serialized = json.dumps(settings)
    with open(cfg_root / "config.json", "w") as handle:
        handle.write(serialized)
    return str(cfg_root)
@pytest.fixture
def project_dir(tmp_path):
    """Create an empty ``project`` directory under ``tmp_path`` and return it as a Path."""
    project_root = tmp_path.joinpath("project")
    project_root.mkdir()
    return project_root
# ── Session-scoped result collector ──────────────────────────────────────
class BenchmarkResults:
    """Accumulate benchmark metrics as a two-level ``{category: {metric: value}}`` mapping."""

    def __init__(self):
        self.results = {}

    def record(self, category: str, metric: str, value):
        """Store ``value`` under ``results[category][metric]``, creating the category on first use."""
        bucket = self.results.setdefault(category, {})
        bucket[metric] = value
@pytest.fixture(scope="session")
def bench_results():
    """Provide one ``BenchmarkResults`` collector for the whole test session."""
    collector = BenchmarkResults()
    return collector
def pytest_terminal_summary(terminalreporter, config):
    """Write the JSON benchmark report after all tests complete.

    Does nothing unless ``--bench-report`` was given. Otherwise builds a
    report with environment metadata (timestamp, git SHA, Python and
    chromadb versions, scale, OS details), merges in any metrics found in
    ``<tempdir>/mempalace_bench_results.json``, writes the report to the
    requested path (creating parent directories), and announces the path
    on the terminal.

    Args:
        terminalreporter: pytest's terminal reporter; only ``write_line``
            is used.
        config: pytest config object; only ``getoption`` is used.
    """
    report_path = config.getoption("--bench-report", default=None)
    if not report_path:
        return

    # Local imports: only needed when a report is actually requested.
    import platform
    import subprocess
    from datetime import datetime

    # Best-effort metadata: a missing git binary or a non-repo checkout
    # must not prevent the report from being written.
    try:
        git_sha = subprocess.check_output(
            ["git", "rev-parse", "--short", "HEAD"], text=True, stderr=subprocess.DEVNULL
        ).strip()
    except Exception:
        git_sha = "unknown"

    try:
        import chromadb

        chromadb_version = chromadb.__version__
    except Exception:
        chromadb_version = "unknown"

    report = {
        "timestamp": datetime.now().isoformat(),
        "git_sha": git_sha,
        "python_version": platform.python_version(),
        "chromadb_version": chromadb_version,
        "scale": config.getoption("--bench-scale", default="small"),
        "system": {
            "os": platform.system().lower(),
            "cpu_count": os.cpu_count(),
            "platform": platform.platform(),
        },
        "results": {},
    }

    # Read results from the temp file written by record_metric()
    results_file = os.path.join(tempfile.gettempdir(), "mempalace_bench_results.json")
    if os.path.exists(results_file):
        try:
            with open(results_file, encoding="utf-8") as f:
                report["results"] = json.load(f)
        except Exception as exc:
            # Still best-effort, but say so instead of silently emitting an
            # empty "results" section.
            terminalreporter.write_line(
                f"Warning: could not read benchmark results from {results_file}: {exc}"
            )
        finally:
            # Always discard the temp file, even when it was unreadable, so a
            # corrupt or stale file cannot leak into the next session.
            try:
                os.unlink(results_file)
            except OSError:
                pass

    os.makedirs(os.path.dirname(os.path.abspath(report_path)), exist_ok=True)
    with open(report_path, "w", encoding="utf-8") as f:
        json.dump(report, f, indent=2)

    terminalreporter.write_line(f"\nBenchmark report written to: {report_path}")
+568
View File
@@ -0,0 +1,568 @@
"""
Deterministic data factory for MemPalace scale benchmarks.
Generates realistic project files, conversations, and KG triples at
configurable scale levels. All randomness uses seeded RNG for reproducibility.
Planted "needle" drawers enable recall measurement without an LLM judge.
"""
import hashlib
import os
import random
from datetime import datetime, timedelta
from pathlib import Path
import chromadb
import yaml
# ── Scale configurations ─────────────────────────────────────────────────
# Sizing parameters for each benchmark scale level. PalaceDataGenerator looks
# up its configuration with SCALE_CONFIGS[scale], so the keys here must match
# the scale names accepted on the command line.
#
# Per-level fields:
#   wings           -- how many entries are taken from the front of WING_NAMES
#   rooms_per_wing  -- how many ROOM_NAMES are sampled for each wing
#   drawers, kg_entities, kg_triples, needles, search_queries
#                   -- target counts; presumably consumed by the generator
#                      methods that build drawers, KG data, planted needles and
#                      queries (those methods are not shown here -- confirm).
SCALE_CONFIGS = {
"small": {
"drawers": 1_000,
"wings": 3,
"rooms_per_wing": 5,
"kg_entities": 50,
"kg_triples": 200,
"needles": 20,
"search_queries": 20,
},
"medium": {
"drawers": 10_000,
"wings": 8,
"rooms_per_wing": 12,
"kg_entities": 200,
"kg_triples": 2_000,
"needles": 50,
"search_queries": 50,
},
"large": {
"drawers": 50_000,
"wings": 15,
"rooms_per_wing": 20,
"kg_entities": 500,
"kg_triples": 10_000,
"needles": 100,
"search_queries": 100,
},
"stress": {
"drawers": 100_000,
"wings": 25,
"rooms_per_wing": 30,
"kg_entities": 1_000,
"kg_triples": 50_000,
"needles": 200,
"search_queries": 200,
},
}
# ── Vocabulary banks for realistic content ───────────────────────────────
# Pool of wing names. PalaceDataGenerator uses the first cfg["wings"] entries
# (a prefix slice), so the ORDER of this list determines which wings appear at
# each scale. The 25 entries cover the largest configured scale
# ("stress", wings=25); a larger "wings" value would silently yield only 25.
WING_NAMES = [
"webapp",
"backend_api",
"mobile_app",
"data_pipeline",
"ml_platform",
"devops",
"auth_service",
"payments",
"analytics",
"docs_site",
"cli_tool",
"dashboard",
"notification_service",
"search_engine",
"user_mgmt",
"inventory",
"reporting",
"testing_infra",
"monitoring",
"email_service",
"chat_bot",
"file_storage",
"scheduler",
"gateway",
"marketplace",
]
# Pool of room names. For every wing, PalaceDataGenerator draws a seeded
# random sample of min(rooms_per_wing, len(ROOM_NAMES)) names from this list,
# so different wings may share room names. The 30 entries equal the largest
# configured rooms_per_wing ("stress").
ROOM_NAMES = [
"backend",
"frontend",
"api",
"database",
"auth",
"tests",
"docs",
"config",
"deployment",
"models",
"views",
"controllers",
"middleware",
"utils",
"schemas",
"migrations",
"fixtures",
"scripts",
"styles",
"components",
"hooks",
"services",
"routes",
"templates",
"static",
"media",
"logging",
"cache",
"queue",
"workers",
]
TECH_TERMS = [
"authentication",
"authorization",
"middleware",
"endpoint",
"REST API",
"GraphQL",
"WebSocket",
"database migration",
"ORM",
"query optimization",
"caching strategy",
"load balancer",
"rate limiting",
"pagination",
"serialization",
"validation",
"error handling",
"logging framework",
"monitoring",
"deployment pipeline",
"CI/CD",
"containerization",
"microservice",
"event sourcing",
"message queue",
"pub/sub",
"connection pooling",
"session management",
"token refresh",
"CORS",
"SSL termination",
"health check",
"circuit breaker",
"retry logic",
"batch processing",
"stream processing",
"data pipeline",
"ETL",
"feature flag",
"A/B testing",
"blue-green deployment",
"canary release",
]
CODE_SNIPPETS = [
"def process_request(data):\n validated = schema.validate(data)\n result = handler.execute(validated)\n return Response(result, status=200)\n",
"class UserRepository:\n def __init__(self, db):\n self.db = db\n def find_by_id(self, user_id):\n return self.db.query(User).filter(User.id == user_id).first()\n",
"async def fetch_data(url, timeout=30):\n async with aiohttp.ClientSession() as session:\n async with session.get(url, timeout=timeout) as resp:\n return await resp.json()\n",
"const handleSubmit = async (formData) => {\n try {\n const response = await api.post('/users', formData);\n dispatch({ type: 'USER_CREATED', payload: response.data });\n } catch (error) {\n setError(error.message);\n }\n};\n",
"SELECT u.name, COUNT(o.id) as order_count\nFROM users u\nLEFT JOIN orders o ON u.id = o.user_id\nWHERE u.created_at > '2025-01-01'\nGROUP BY u.name\nHAVING COUNT(o.id) > 5\nORDER BY order_count DESC;\n",
]
PROSE_TEMPLATES = [
"The {component} module handles {task}. It was refactored in {month} to improve {quality}. Key design decision: {decision}.",
"Bug report: {component} fails when {condition}. Root cause: {cause}. Fixed by {fix}. Regression test added in {test_file}.",
"Architecture decision: switched from {old_tech} to {new_tech} for {reason}. Migration completed {date}. Performance improved by {percent}%.",
"Meeting notes: discussed {topic} with {person}. Agreed to {action}. Deadline: {deadline}. Follow-up: {followup}.",
"Feature spec: {feature_name} allows users to {capability}. Dependencies: {deps}. Estimated effort: {effort} days.",
]
ENTITY_NAMES = [
"Alice",
"Bob",
"Carol",
"Dave",
"Eve",
"Frank",
"Grace",
"Heidi",
"Ivan",
"Judy",
"Karl",
"Linda",
"Mike",
"Nina",
"Oscar",
"Pat",
"Quinn",
"Rita",
"Steve",
"Tina",
"Ursula",
"Victor",
"Wendy",
"Xander",
]
ENTITY_TYPES = ["person", "project", "tool", "concept", "team", "service"]
PREDICATES = [
"works_on",
"manages",
"reports_to",
"collaborates_with",
"created",
"maintains",
"uses",
"depends_on",
"replaced",
"reviewed",
"deployed",
"tested",
"documented",
"mentors",
"leads",
"contributes_to",
]
class PalaceDataGenerator:
    """Generate deterministic, realistic test data at configurable scale.

    Every random decision goes through a private ``random.Random(seed)``
    instance, so two generators built with the same seed and scale and driven
    through the same sequence of calls produce the same content. The
    ``filed_at`` metadata written by ``populate_palace_directly`` is the
    wall-clock time and is therefore not reproducible.
    """

    def __init__(self, seed=42, scale="small"):
        """Seed the RNG, select wings and rooms for ``scale``, and plant needles.

        Args:
            seed: Seed for the ``random.Random`` instance behind every choice.
            scale: Key into ``SCALE_CONFIGS`` selecting the size profile.

        Raises:
            KeyError: If ``scale`` is not a key of ``SCALE_CONFIGS``.
        """
        self.rng = random.Random(seed)
        self.scale = scale
        self.cfg = SCALE_CONFIGS[scale]
        self.wings = WING_NAMES[: self.cfg["wings"]]
        self.rooms_by_wing = {}
        for wing in self.wings:
            n = self.cfg["rooms_per_wing"]
            # sample() cannot draw more items than exist, hence the min().
            rooms = self.rng.sample(ROOM_NAMES, min(n, len(ROOM_NAMES)))
            self.rooms_by_wing[wing] = rooms
        # Planted needles for recall measurement
        self.needles = []
        self._generate_needles()

    def _generate_needles(self):
        """Create unique needle content for recall testing.

        Appends ``self.cfg["needles"]`` dicts to ``self.needles``, each with the
        keys ``id``, ``content``, ``wing``, ``room`` and ``query``. Topics are
        reused cyclically when more needles than topics are requested; the
        ``NEEDLE_NNNN`` prefix keeps the content of each needle distinct.
        """
        topics = [
            "Fibonacci sequence optimization uses memoization with O(n) space complexity",
            "PostgreSQL vacuum autovacuum threshold set to 50 percent for table users",
            "Redis cluster failover timeout configured at 30 seconds with sentinel monitoring",
            "Kubernetes horizontal pod autoscaler targets 70 percent CPU utilization",
            "GraphQL subscription uses WebSocket transport with heartbeat interval 25 seconds",
            "JWT token rotation policy requires refresh every 15 minutes with sliding window",
            "Elasticsearch index sharding strategy uses 5 primary shards with 1 replica each",
            "Docker multi-stage build reduces image size from 1.2GB to 180MB for production",
            "Apache Kafka consumer group rebalance timeout set to 45 seconds",
            "MongoDB change streams resume token persisted every 100 operations",
            "gRPC streaming uses bidirectional flow control with 64KB window size",
            "Prometheus alerting rule fires when p99 latency exceeds 500ms for 5 minutes",
            "Terraform state locking uses DynamoDB with consistent reads enabled",
            "Nginx rate limiting configured at 100 requests per second with burst of 50",
            "SQLAlchemy connection pool size set to 20 with max overflow of 10 connections",
            "React concurrent mode uses startTransition for non-urgent state updates",
            "AWS Lambda cold start mitigation uses provisioned concurrency of 10 instances",
            "Git bisect automated with custom test script for regression hunting",
            "OpenTelemetry trace sampling rate set to 10 percent in production environment",
            "Celery worker prefetch multiplier set to 1 for fair task distribution",
        ]
        for i in range(self.cfg["needles"]):
            topic = topics[i % len(topics)]
            wing = self.rng.choice(self.wings)
            room = self.rng.choice(self.rooms_by_wing[wing])
            needle_id = f"NEEDLE_{i:04d}"
            content = f"{needle_id}: {topic}. This is a unique planted needle for recall benchmarking at scale."
            # The search query is the part of the topic before " uses " or
            # " set to ", or else its first 60 characters.
            if " uses " in topic:
                query = topic.split(" uses ")[0]
            elif " set to " in topic:
                query = topic.split(" set to ")[0]
            else:
                query = topic[:60]
            self.needles.append(
                {
                    "id": needle_id,
                    "content": content,
                    "wing": wing,
                    "room": room,
                    "query": query,
                }
            )

    def _random_text(self, min_chars=600, max_chars=900):
        """Generate a random text block of realistic content.

        Picks a target length in ``[min_chars, max_chars]`` and keeps appending
        fragments (30% code snippet, 40% filled prose template, 30% keyword
        sentence) until the target is reached. The newline-joined result is
        truncated to ``max_chars``.
        """
        parts = []
        total = 0
        target = self.rng.randint(min_chars, max_chars)
        while total < target:
            choice = self.rng.random()
            if choice < 0.3:
                text = self.rng.choice(CODE_SNIPPETS)
            elif choice < 0.7:
                template = self.rng.choice(PROSE_TEMPLATES)
                # Every placeholder is supplied for every template; the keyword
                # order fixes the order of RNG draws and must not be changed,
                # or seeded output would differ.
                text = template.format(
                    component=self.rng.choice(ROOM_NAMES),
                    task=self.rng.choice(TECH_TERMS),
                    month=self.rng.choice(["January", "February", "March", "April", "May"]),
                    quality=self.rng.choice(
                        ["performance", "readability", "test coverage", "latency"]
                    ),
                    decision=self.rng.choice(TECH_TERMS),
                    condition=self.rng.choice(TECH_TERMS) + " is null",
                    cause=self.rng.choice(["race condition", "null pointer", "timeout", "OOM"]),
                    fix="adding " + self.rng.choice(TECH_TERMS),
                    test_file=f"test_{self.rng.choice(ROOM_NAMES)}.py",
                    old_tech=self.rng.choice(["MySQL", "Flask", "REST", "Jenkins"]),
                    new_tech=self.rng.choice(
                        ["PostgreSQL", "FastAPI", "GraphQL", "GitHub Actions"]
                    ),
                    reason=self.rng.choice(TECH_TERMS),
                    date=f"2025-{self.rng.randint(1, 12):02d}-{self.rng.randint(1, 28):02d}",
                    percent=self.rng.randint(10, 80),
                    topic=self.rng.choice(TECH_TERMS),
                    person=self.rng.choice(ENTITY_NAMES),
                    action=self.rng.choice(["refactor", "migrate", "optimize", "test"]),
                    deadline=f"2025-{self.rng.randint(1, 12):02d}-{self.rng.randint(1, 28):02d}",
                    followup=self.rng.choice(TECH_TERMS),
                    feature_name=self.rng.choice(TECH_TERMS),
                    capability=self.rng.choice(TECH_TERMS),
                    deps=", ".join(self.rng.sample(TECH_TERMS, 2)),
                    effort=self.rng.randint(1, 15),
                )
            else:
                words = self.rng.sample(TECH_TERMS, min(5, len(TECH_TERMS)))
                text = (
                    " ".join(words)
                    + ". "
                    + self.rng.choice(TECH_TERMS)
                    + " implementation details follow.\n"
                )
            parts.append(text)
            total += len(text)
        return "\n".join(parts)[:max_chars]

    # ── Project tree generation (for mine() tests) ───────────────────────
    def generate_project_tree(self, base_path, wing=None, rooms=None, n_files=50):
        """
        Write realistic project files + mempalace.yaml to base_path.

        Files are spread round-robin over one sub-directory per room.

        Args:
            base_path: Directory to create (parents included) and fill.
            wing: Wing name for mempalace.yaml; a random configured wing if omitted.
            rooms: Room names; defaults to the wing's configured rooms, or
                ``["general"]`` when the wing has none.
            n_files: Number of content files to write.

        Returns:
            ``(project_path, wing, rooms, files_written)``; ``project_path`` is
            suitable for passing to mine().
        """
        base = Path(base_path)
        base.mkdir(parents=True, exist_ok=True)
        wing = wing or self.rng.choice(self.wings)
        # Fall back to "general" for an unknown wing AND for a wing whose room
        # list is empty; an empty list would make the modulo below divide by 0.
        rooms = rooms or self.rooms_by_wing.get(wing) or ["general"]
        # Write mempalace.yaml
        room_defs = [{"name": r, "description": f"{r} code and docs"} for r in rooms]
        # Explicit utf-8 to match the encoding used for the content files below.
        with open(base / "mempalace.yaml", "w", encoding="utf-8") as f:
            yaml.dump({"wing": wing, "rooms": room_defs}, f)
        # Write files distributed across room directories
        files_written = 0
        for i in range(n_files):
            room = rooms[i % len(rooms)]
            room_dir = base / room
            room_dir.mkdir(parents=True, exist_ok=True)
            ext = self.rng.choice([".py", ".js", ".md", ".ts", ".yaml"])
            filename = f"file_{i:04d}{ext}"
            content = self._random_text(400, 2000)
            (room_dir / filename).write_text(content, encoding="utf-8")
            files_written += 1
        return str(base), wing, rooms, files_written

    # ── Conversation file generation (for mine_convos() tests) ───────────
    def generate_conversation_files(self, base_path, wing=None, n_files=20):
        """Write conversation transcript files for convo_miner tests.

        Each ``convo_NNNN.txt`` holds 5-20 exchanges: a ``> User:`` line, a
        generated reply, and a blank separator line.

        Returns:
            ``(base_path_as_str, wing)``; ``wing`` is a random configured wing
            when not supplied.
        """
        base = Path(base_path)
        base.mkdir(parents=True, exist_ok=True)
        wing = wing or self.rng.choice(self.wings)
        for i in range(n_files):
            lines = []
            n_exchanges = self.rng.randint(5, 20)
            for _ in range(n_exchanges):
                user_msg = f"> User: {self.rng.choice(TECH_TERMS)}? How does {self.rng.choice(TECH_TERMS)} work with {self.rng.choice(TECH_TERMS)}?"
                ai_msg = self._random_text(200, 600)
                lines.append(user_msg)
                lines.append(ai_msg)
                lines.append("")
            (base / f"convo_{i:04d}.txt").write_text("\n".join(lines), encoding="utf-8")
        return str(base), wing

    # ── Direct palace population (bypasses mining for speed) ─────────────
    def populate_palace_directly(self, palace_path, n_drawers=None, include_needles=True):
        """
        Insert drawers directly into ChromaDB, bypassing the mining pipeline.

        Much faster than mining for benchmarks that only care about
        search/MCP behavior on a pre-populated palace.

        Args:
            palace_path: Directory for the persistent ChromaDB store (created
                if missing).
            n_drawers: Total number of drawers, needles included; defaults to
                ``self.cfg["drawers"]`` when omitted (or 0).
            include_needles: Whether to insert the planted needles first.

        Returns:
            (client, collection, needle_info).
        """
        n_drawers = n_drawers or self.cfg["drawers"]
        os.makedirs(palace_path, exist_ok=True)
        client = chromadb.PersistentClient(path=palace_path)
        col = client.get_or_create_collection("mempalace_drawers")
        batch_size = 500
        docs = []
        ids = []
        metas = []
        # Insert needles first
        needle_info = []
        if include_needles:
            for needle in self.needles:
                digest = hashlib.md5(needle["id"].encode()).hexdigest()[:16]
                needle_id = f"drawer_{needle['wing']}_{needle['room']}_{digest}"
                docs.append(needle["content"])
                ids.append(needle_id)
                metas.append(
                    {
                        "wing": needle["wing"],
                        "room": needle["room"],
                        "source_file": f"needle_{needle['id']}.txt",
                        "chunk_index": 0,
                        "added_by": "benchmark",
                        "filed_at": datetime.now().isoformat(),
                    }
                )
                needle_info.append(
                    {
                        "id": needle_id,
                        "query": needle["query"],
                        "wing": needle["wing"],
                        "room": needle["room"],
                    }
                )
        # Fill remaining drawers with realistic content
        remaining = n_drawers - len(docs)
        for i in range(remaining):
            wing = self.wings[i % len(self.wings)]
            rooms = self.rooms_by_wing[wing]
            room = rooms[i % len(rooms)]
            content = self._random_text(400, 800)
            drawer_id = f"drawer_{wing}_{room}_{hashlib.md5(f'gen_{i}'.encode()).hexdigest()[:16]}"
            docs.append(content)
            ids.append(drawer_id)
            metas.append(
                {
                    "wing": wing,
                    "room": room,
                    "source_file": f"generated_{i:06d}.txt",
                    "chunk_index": i % 10,
                    "added_by": "benchmark",
                    "filed_at": datetime.now().isoformat(),
                }
            )
            # Flush in batches
            if len(docs) >= batch_size:
                col.add(documents=docs, ids=ids, metadatas=metas)
                docs, ids, metas = [], [], []
        # Flush remainder
        if docs:
            col.add(documents=docs, ids=ids, metadatas=metas)
        return client, col, needle_info

    # ── KG triple generation ─────────────────────────────────────────────
    def generate_kg_triples(self, n_entities=None, n_triples=None):
        """
        Generate realistic entity-relationship triples.

        Args:
            n_entities: Number of entities; defaults to ``self.cfg["kg_entities"]``.
            n_triples: Number of triples; defaults to ``self.cfg["kg_triples"]``.

        Returns (entities, triples) where:
            entities = [(name, type), ...]
            triples = [(subject, predicate, object, valid_from, valid_to), ...]
        ``valid_to`` is ``None`` for roughly 70% of the triples.

        Raises:
            ValueError: If triples are requested but fewer than two entities
                exist, since subject and object must differ.
        """
        n_entities = n_entities or self.cfg["kg_entities"]
        n_triples = n_triples or self.cfg["kg_triples"]
        # Generate entities
        entities = []
        entity_names = []
        for i in range(n_entities):
            if i < len(ENTITY_NAMES):
                name = ENTITY_NAMES[i]
            else:
                name = f"Entity_{i:04d}"
            etype = self.rng.choice(ENTITY_TYPES)
            entities.append((name, etype))
            entity_names.append(name)
        # With a single entity the "object != subject" retry loop below could
        # never terminate; fail loudly instead of hanging.
        if n_triples > 0 and len(entity_names) < 2:
            raise ValueError(
                "generate_kg_triples needs at least 2 entities to build triples, "
                f"got {len(entity_names)}"
            )
        # Generate triples
        triples = []
        base_date = datetime(2024, 1, 1)
        for _ in range(n_triples):
            subject = self.rng.choice(entity_names)
            obj = self.rng.choice(entity_names)
            while obj == subject:
                obj = self.rng.choice(entity_names)
            predicate = self.rng.choice(PREDICATES)
            days_offset = self.rng.randint(0, 730)
            valid_from = (base_date + timedelta(days=days_offset)).strftime("%Y-%m-%d")
            # 30% chance of having a valid_to
            valid_to = None
            if self.rng.random() < 0.3:
                end_offset = self.rng.randint(30, 365)
                valid_to = (base_date + timedelta(days=days_offset + end_offset)).strftime(
                    "%Y-%m-%d"
                )
            triples.append((subject, predicate, obj, valid_from, valid_to))
        return entities, triples

    # ── Search query generation ──────────────────────────────────────────
    def generate_search_queries(self, n_queries=None):
        """
        Generate search queries with expected results.

        Returns a shuffled list of dicts with the keys ``query`` (str),
        ``expected_wing`` (str|None), ``expected_room`` (str|None),
        ``needle_id`` (str|None) and ``is_needle`` (bool).
        Needle queries have known-good answers for recall measurement.
        """
        n_queries = n_queries or self.cfg["search_queries"]
        queries = []
        # Half are needle queries (known-good answers)
        n_needle = min(n_queries // 2, len(self.needles))
        for needle in self.needles[:n_needle]:
            queries.append(
                {
                    "query": needle["query"],
                    "expected_wing": needle["wing"],
                    "expected_room": needle["room"],
                    "needle_id": needle["id"],
                    "is_needle": True,
                }
            )
        # Other half are generic queries (measure latency, not recall)
        n_generic = n_queries - n_needle
        for _ in range(n_generic):
            queries.append(
                {
                    "query": self.rng.choice(TECH_TERMS) + " " + self.rng.choice(TECH_TERMS),
                    "expected_wing": None,
                    "expected_room": None,
                    "needle_id": None,
                    "is_needle": False,
                }
            )
        self.rng.shuffle(queries)
        return queries
+117
View File
@@ -0,0 +1,117 @@
"""
Benchmark report utilities — JSON output and regression detection.
Each test records metrics via record_metric(). At session end, the
conftest.py pytest_terminal_summary hook writes the collected results.
"""
import json
import os
import tempfile
# Results file in the system temp directory; record_metric() reads and
# rewrites this JSON file on every call.
RESULTS_FILE = os.path.join(tempfile.gettempdir(), "mempalace_bench_results.json")
def record_metric(category: str, metric: str, value):
    """Store ``value`` under ``category``/``metric`` in the JSON results file.

    The file at ``RESULTS_FILE`` is loaded, updated and rewritten in full on
    each call. A missing, unreadable or malformed file is treated as empty.
    """
    try:
        with open(RESULTS_FILE) as fh:
            collected = json.load(fh)
    except (json.JSONDecodeError, OSError):
        # Covers "file does not exist yet" as well as corrupt content.
        collected = {}

    collected.setdefault(category, {})[metric] = value

    with open(RESULTS_FILE, "w") as fh:
        json.dump(collected, fh, indent=2)
def check_regression(current_report: str, baseline_report: str, threshold: float = 0.2):
    """
    Compare current benchmark results against a baseline.

    Both arguments are paths to JSON files holding a top-level ``"results"``
    mapping of ``{category: {metric: value}}``. Only metrics present in both
    files, with numeric values, are compared.

    A metric's direction is inferred from keywords in its name; metrics that
    match no keyword are ignored. Metrics whose baseline is exactly 0 are
    skipped because a relative change is undefined for them.

    The allowed band is ``threshold * abs(baseline)``, so negative baselines
    (e.g. an RSS delta that shrank) are judged in the correct direction.

    Returns a list of regression descriptions. Empty list = no regressions.
    threshold: fractional degradation allowed (0.2 = 20% worse is OK).
    """
    with open(current_report) as f:
        current = json.load(f)
    with open(baseline_report) as f:
        baseline = json.load(f)
    regressions = []
    # Keywords for metric direction — checked in order, first match wins.
    # "improvement" is checked before "latency" so that composite names
    # like "latency_improvement_pct" are classified correctly.
    _higher_is_better_kw = (
        "improvement",
        "recall",
        "throughput",
        "per_sec",
        "files_per_sec",
        "drawers_per_sec",
        "triples_per_sec",
        "speedup",
    )
    _higher_is_worse_kw = (
        "latency",
        "rss",
        "memory",
        "oom",
        "lock_failures",
        "elapsed",
        "p50_ms",
        "p95_ms",
        "p99_ms",
        "rss_delta_mb",
        "peak_rss_mb",
        "errors",
        "failures",
    )

    def _metric_direction(name: str) -> str:
        """Return 'higher_better', 'higher_worse', or 'unknown'."""
        low = name.lower()
        if any(kw in low for kw in _higher_is_better_kw):
            return "higher_better"
        if any(kw in low for kw in _higher_is_worse_kw):
            return "higher_worse"
        return "unknown"

    current_results = current.get("results", {})
    for category, base_metrics in baseline.get("results", {}).items():
        if category not in current_results:
            continue
        curr_metrics = current_results[category]
        for metric, base_val in base_metrics.items():
            if metric not in curr_metrics:
                continue
            curr_val = curr_metrics[metric]
            if not isinstance(base_val, (int, float)) or not isinstance(curr_val, (int, float)):
                continue
            if base_val == 0:
                continue
            direction = _metric_direction(metric)
            if direction == "unknown":
                continue
            # Scaling by (1 +/- threshold) moves the limit away from zero or
            # towards it depending on the sign of the baseline, so pick the
            # factor that shifts the limit in the "worse" direction.
            if direction == "higher_worse":
                factor = 1 + threshold if base_val > 0 else 1 - threshold
                regressed = curr_val > base_val * factor
            else:
                factor = 1 - threshold if base_val > 0 else 1 + threshold
                regressed = curr_val < base_val * factor
            if regressed:
                # Divide by the magnitude so the sign of pct follows the sign
                # of the change, even for a negative baseline.
                pct = ((curr_val - base_val) / abs(base_val)) * 100
                regressions.append(
                    f"{category}/{metric}: {base_val:.2f} -> {curr_val:.2f} ({pct:+.1f}%, threshold {threshold * 100:.0f}%)"
                )
    return regressions
+206
View File
@@ -0,0 +1,206 @@
"""
ChromaDB stress tests — find the breaking point.
Tests the raw ChromaDB patterns used by mempalace to determine:
- At what collection size does col.get(include=["metadatas"]) become dangerous?
- How does query latency degrade as collection grows?
- How much faster is batched insertion vs sequential?
"""
import os
import time
import chromadb
import pytest
from tests.benchmarks.data_generator import PalaceDataGenerator
from tests.benchmarks.report import record_metric
def _get_rss_mb():
try:
import psutil
return psutil.Process().memory_info().rss / (1024 * 1024)
except ImportError:
import resource
import platform
usage = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
if platform.system() == "Darwin":
return usage / (1024 * 1024)
return usage / 1024
@pytest.mark.benchmark
class TestGetAllMetadatasOOM:
    """
    Exercise the pattern behind finding #3:
    col.get(include=["metadatas"]) called with NO limit.

    Records RSS growth and latency per collection size to show when the
    unbounded fetch becomes dangerous.
    """

    SIZES = [1_000, 2_500, 5_000, 10_000]

    @pytest.mark.parametrize("n_drawers", SIZES)
    def test_get_all_metadatas_rss(self, n_drawers, tmp_path, bench_scale):
        """Fetch every metadata row at once and record RSS delta and latency."""
        palace_dir = str(tmp_path / "palace")
        generator = PalaceDataGenerator(seed=42, scale=bench_scale)
        generator.populate_palace_directly(palace_dir, n_drawers=n_drawers, include_needles=False)

        client = chromadb.PersistentClient(path=palace_dir)
        collection = client.get_collection("mempalace_drawers")

        rss_start_mb = _get_rss_mb()
        t0 = time.perf_counter()
        metadatas = collection.get(include=["metadatas"])["metadatas"]
        duration_ms = (time.perf_counter() - t0) * 1000
        rss_end_mb = _get_rss_mb()

        assert len(metadatas) == n_drawers

        growth_mb = rss_end_mb - rss_start_mb
        for label, reading in (
            (f"rss_delta_mb_at_{n_drawers}", round(growth_mb, 2)),
            (f"latency_ms_at_{n_drawers}", round(duration_ms, 1)),
        ):
            record_metric("chromadb_get_all", label, reading)
@pytest.mark.benchmark
class TestQueryDegradation:
    """Track how query latency changes as the collection grows."""

    SIZES = [1_000, 2_500, 5_000, 10_000]

    @pytest.mark.parametrize("n_drawers", SIZES)
    def test_query_latency_at_size(self, n_drawers, tmp_path, bench_scale):
        """Time five fixed queries against a collection of ``n_drawers`` drawers."""
        palace_dir = str(tmp_path / "palace")
        generator = PalaceDataGenerator(seed=42, scale=bench_scale)
        generator.populate_palace_directly(palace_dir, n_drawers=n_drawers, include_needles=False)

        client = chromadb.PersistentClient(path=palace_dir)
        collection = client.get_collection("mempalace_drawers")

        query_texts = (
            "authentication middleware optimization",
            "database connection pooling strategy",
            "error handling retry logic",
            "deployment pipeline configuration",
            "load balancer health check",
        )

        timings_ms = []
        for text in query_texts:
            t0 = time.perf_counter()
            hits = collection.query(
                query_texts=[text], n_results=5, include=["documents", "distances"]
            )
            timings_ms.append((time.perf_counter() - t0) * 1000)
            assert hits["documents"][0]  # got results

        mean_ms = sum(timings_ms) / len(timings_ms)
        ranked = sorted(timings_ms)
        # With five samples this index selects the slowest one.
        p95_ms = ranked[int(len(ranked) * 0.95)]

        record_metric("chromadb_query", f"avg_latency_ms_at_{n_drawers}", round(mean_ms, 1))
        record_metric("chromadb_query", f"p95_latency_ms_at_{n_drawers}", round(p95_ms, 1))
@pytest.mark.benchmark
class TestBulkInsertPerformance:
    """Contrast one-document-per-call insertion with batched insertion."""

    def test_sequential_vs_batched(self, tmp_path):
        """Insert the same 500 documents both ways and record the speedup."""
        n_docs = 500
        batch_size = 100
        generator = PalaceDataGenerator(seed=42)
        # Content is generated up front so it is excluded from both timings.
        contents = [generator._random_text(400, 800) for _ in range(n_docs)]

        def fresh_collection(dirname):
            """Create a new on-disk store under tmp_path; return (client, collection)."""
            location = str(tmp_path / dirname)
            os.makedirs(location)
            new_client = chromadb.PersistentClient(path=location)
            return new_client, new_client.get_or_create_collection("mempalace_drawers")

        # Sequential insertion (mimics add_drawer pattern)
        client_seq, col_seq = fresh_collection("seq")
        t0 = time.perf_counter()
        for idx, text in enumerate(contents):
            col_seq.add(
                documents=[text],
                ids=[f"seq_{idx}"],
                metadatas=[{"wing": "test", "room": "bench", "chunk_index": idx}],
            )
        sequential_ms = (time.perf_counter() - t0) * 1000

        # Batched insertion; slicing and metadata construction are timed too.
        client_batch, col_batch = fresh_collection("batch")
        t0 = time.perf_counter()
        for lo in range(0, n_docs, batch_size):
            hi = min(lo + batch_size, n_docs)
            window = range(lo, hi)
            col_batch.add(
                documents=contents[lo:hi],
                ids=[f"batch_{k}" for k in window],
                metadatas=[{"wing": "test", "room": "bench", "chunk_index": k} for k in window],
            )
        batched_ms = (time.perf_counter() - t0) * 1000

        # Floor the divisor so a near-zero batched time cannot divide by zero.
        speedup = sequential_ms / max(batched_ms, 0.01)

        assert col_seq.count() == n_docs
        assert col_batch.count() == n_docs

        readings = {
            "sequential_ms": round(sequential_ms, 1),
            "batched_ms": round(batched_ms, 1),
            "speedup_ratio": round(speedup, 2),
            "n_docs": n_docs,
            "batch_size": batch_size,
        }
        for label, reading in readings.items():
            record_metric("chromadb_insert", label, reading)
@pytest.mark.benchmark
@pytest.mark.slow
class TestMaxCollectionSize:
    """Grow a collection step by step to look for practical limits."""

    def test_incremental_growth(self, tmp_path, bench_scale):
        """Insert drawers in batches of 500 and time each ``add`` call."""
        gen = PalaceDataGenerator(seed=42, scale=bench_scale)
        target = min(gen.cfg["drawers"], 10_000)  # cap at 10K for this test

        palace_dir = str(tmp_path / "palace")
        os.makedirs(palace_dir)
        client = chromadb.PersistentClient(path=palace_dir)
        collection = client.get_or_create_collection("mempalace_drawers")

        batch_size = 500
        batch_times = []
        total_inserted = 0
        n_wings = len(gen.wings)

        for offset in range(0, target, batch_size):
            upper = min(offset + batch_size, target)
            count = upper - offset
            # Payload is built outside the timed section.
            docs = [gen._random_text(400, 800) for _ in range(count)]
            ids = [f"growth_{k}" for k in range(offset, upper)]
            metas = [
                {"wing": gen.wings[k % n_wings], "room": "bench", "chunk_index": k}
                for k in range(offset, upper)
            ]

            t0 = time.perf_counter()
            collection.add(documents=docs, ids=ids, metadatas=metas)
            spent_ms = (time.perf_counter() - t0) * 1000

            total_inserted += count
            batch_times.append({"at_size": total_inserted, "batch_ms": round(spent_ms, 1)})

        assert collection.count() == total_inserted

        # Record first and last batch times to show degradation
        first, last = batch_times[0], batch_times[-1]
        record_metric("chromadb_growth", "first_batch_ms", first["batch_ms"])
        record_metric("chromadb_growth", "last_batch_ms", last["batch_ms"])
        record_metric("chromadb_growth", "total_inserted", total_inserted)
        record_metric("chromadb_growth", "batch_times", batch_times)
+169
View File
@@ -0,0 +1,169 @@
"""
Ingestion throughput benchmarks.
Measures mining performance at scale:
- Files/sec and drawers/sec through the full mine() pipeline
- Peak RSS during mining
- Chunking throughput isolated from ChromaDB
- Re-ingest skip overhead (finding #11: file_already_mined check)
"""
import time
import chromadb
import pytest
from tests.benchmarks.data_generator import PalaceDataGenerator
from tests.benchmarks.report import record_metric
def _get_rss_mb():
try:
import psutil
return psutil.Process().memory_info().rss / (1024 * 1024)
except ImportError:
import resource
import platform
usage = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
if platform.system() == "Darwin":
return usage / (1024 * 1024)
return usage / 1024
@pytest.mark.benchmark
class TestMineThroughput:
    """Measure the full mine() pipeline throughput."""

    @pytest.mark.parametrize("n_files", [20, 50, 100])
    def test_mine_files_per_second(self, n_files, tmp_path, bench_scale):
        """End-to-end mining throughput: generate files, mine, count drawers."""
        gen = PalaceDataGenerator(seed=42, scale=bench_scale)
        project_path, _wing, _rooms, files_written = gen.generate_project_tree(
            tmp_path / "project", n_files=n_files
        )
        palace_path = str(tmp_path / "palace")
        from mempalace.miner import mine

        # Only the mine() call is timed; file generation happens before it.
        start = time.perf_counter()
        mine(project_path, palace_path)
        elapsed = time.perf_counter() - start

        client = chromadb.PersistentClient(path=palace_path)
        col = client.get_collection("mempalace_drawers")
        drawer_count = col.count()

        # Floor the divisor so a very fast run cannot divide by zero.
        files_per_sec = files_written / max(elapsed, 0.001)
        drawers_per_sec = drawer_count / max(elapsed, 0.001)

        record_metric("ingest", f"files_per_sec_at_{n_files}", round(files_per_sec, 1))
        record_metric("ingest", f"drawers_per_sec_at_{n_files}", round(drawers_per_sec, 1))
        record_metric("ingest", f"elapsed_sec_at_{n_files}", round(elapsed, 2))
        record_metric("ingest", f"drawers_created_at_{n_files}", drawer_count)

    def test_mine_peak_rss(self, tmp_path, bench_scale):
        """Track peak RSS during a mining run.

        A background thread samples RSS roughly every 0.1 s while mine() runs;
        the maximum sample is reported as the peak.
        """
        import threading

        gen = PalaceDataGenerator(seed=42, scale=bench_scale)
        project_path, _wing, _rooms, _files_written = gen.generate_project_tree(
            tmp_path / "project", n_files=100
        )
        palace_path = str(tmp_path / "palace")
        from mempalace.miner import mine

        rss_samples = []
        stop_sampling = threading.Event()

        def sample_rss():
            while not stop_sampling.is_set():
                rss_samples.append(_get_rss_mb())
                # wait() doubles as the sleep and returns early once stopped.
                stop_sampling.wait(0.1)

        sampler = threading.Thread(target=sample_rss, daemon=True)
        sampler.start()
        rss_before = _get_rss_mb()
        try:
            mine(project_path, palace_path)
        finally:
            # Stop the sampler even when mine() raises; otherwise the thread
            # would keep appending samples for the rest of the test session.
            stop_sampling.set()
            sampler.join(timeout=1)

        peak_rss = max(rss_samples) if rss_samples else _get_rss_mb()
        rss_delta = peak_rss - rss_before
        record_metric("ingest", "peak_rss_mb", round(peak_rss, 1))
        record_metric("ingest", "rss_delta_mb", round(rss_delta, 1))
@pytest.mark.benchmark
class TestChunkThroughput:
    """Benchmark chunking on its own, without any ChromaDB insertion."""

    @pytest.mark.parametrize("content_size_kb", [1, 10, 100])
    def test_chunk_text_throughput(self, content_size_kb):
        """Run chunk_text 50 times over generated content of roughly the given size."""
        from mempalace.miner import chunk_text

        gen = PalaceDataGenerator(seed=42)
        wanted_chars = content_size_kb * 1024

        # Start from one generated block, then pad with newline-separated
        # blocks until the target size is reached.
        pieces = [gen._random_text(content_size_kb * 500, content_size_kb * 1200)]
        size = len(pieces[0])
        while size < wanted_chars:
            extra = gen._random_text(200, 500)
            pieces.append(extra)
            size += 1 + len(extra)  # +1 for the joining newline
        content = "\n".join(pieces)

        n_iterations = 50
        t0 = time.perf_counter()
        total_chunks = 0
        for _ in range(n_iterations):
            total_chunks += len(chunk_text(content, "bench_file.py"))
        elapsed = time.perf_counter() - t0

        safe_elapsed = max(elapsed, 0.001)
        chunks_per_sec = total_chunks / safe_elapsed
        kb_per_sec = (len(content) * n_iterations / 1024) / safe_elapsed

        record_metric(
            "chunking", f"chunks_per_sec_at_{content_size_kb}kb", round(chunks_per_sec, 1)
        )
        record_metric("chunking", f"kb_per_sec_at_{content_size_kb}kb", round(kb_per_sec, 1))
@pytest.mark.benchmark
class TestReingestSkipOverhead:
    """Finding #11: cost of the file_already_mined() check at scale."""

    def test_skip_check_cost(self, tmp_path):
        """Mine a project twice and time the second pass, which should add nothing."""
        from mempalace.miner import mine

        generator = PalaceDataGenerator(seed=42, scale="small")
        project_dir, _wing, _rooms, n_written = generator.generate_project_tree(
            tmp_path / "project", n_files=50
        )
        palace_dir = str(tmp_path / "palace")

        # First mine
        mine(project_dir, palace_dir)
        client = chromadb.PersistentClient(path=palace_dir)
        collection = client.get_collection("mempalace_drawers")
        count_before = collection.count()

        # Re-mine (all files should be skipped)
        t0 = time.perf_counter()
        mine(project_dir, palace_dir)
        skip_elapsed = time.perf_counter() - t0

        # Verify no new drawers added
        count_after = collection.count()
        assert count_after == count_before, "Re-mine should not add new drawers"

        per_file_ms = skip_elapsed * 1000 / max(n_written, 1)
        record_metric("reingest", "skip_check_elapsed_sec", round(skip_elapsed, 2))
        record_metric("reingest", "files_checked", n_written)
        record_metric("reingest", "skip_check_per_file_ms", round(per_file_ms, 1))
@@ -0,0 +1,290 @@
"""
Knowledge graph benchmarks — SQLite temporal KG at scale.
Tests triple insertion throughput, query latency, temporal accuracy,
and SQLite concurrent access behavior.
"""
import threading
import time
import pytest
from tests.benchmarks.data_generator import PalaceDataGenerator
from tests.benchmarks.report import record_metric
@pytest.mark.benchmark
class TestTripleInsertionRate:
    """Record triple-insertion throughput at several graph sizes."""

    @pytest.mark.parametrize("n_triples", [200, 1_000, 5_000])
    def test_insertion_throughput(self, n_triples, tmp_path):
        """Time ``add_triple`` for ``n_triples`` generated triples; entities are added untimed."""
        from mempalace.knowledge_graph import KnowledgeGraph

        generator = PalaceDataGenerator(seed=42, scale="small")
        entity_cap = min(n_triples // 2, 200)
        entities, triples = generator.generate_kg_triples(
            n_entities=entity_cap, n_triples=n_triples
        )

        graph = KnowledgeGraph(db_path=str(tmp_path / "kg.sqlite3"))

        # Insert entities first
        for entity_name, entity_type in entities:
            graph.add_entity(entity_name, entity_type)

        # Measure triple insertion
        t0 = time.perf_counter()
        for subj, pred, target, since, until in triples:
            graph.add_triple(subj, pred, target, valid_from=since, valid_to=until)
        elapsed = time.perf_counter() - t0

        rate = n_triples / max(elapsed, 0.001)
        record_metric("kg_insert", f"triples_per_sec_at_{n_triples}", round(rate, 1))
        record_metric("kg_insert", f"elapsed_sec_at_{n_triples}", round(elapsed, 3))
@pytest.mark.benchmark
class TestQueryEntityLatency:
    """Measure entity query latency for a heavily connected entity."""

    def test_query_latency_vs_relationships(self, tmp_path):
        """Attach 10 + 50 + 100 relationships to one hub entity and time querying it."""
        from mempalace.knowledge_graph import KnowledgeGraph

        graph = KnowledgeGraph(db_path=str(tmp_path / "kg.sqlite3"))

        # Create a hub entity connected to many others
        graph.add_entity("Hub", "person")
        target_counts = [10, 50, 100]
        for group_size in target_counts:
            for idx in range(group_size):
                node = f"Node_{group_size}_{idx}"
                graph.add_entity(node, "project")
                graph.add_triple("Hub", "works_on", node, valid_from="2025-01-01")

        # Measure query for Hub (which has sum(target_counts) relationships)
        timings_ms = []
        for _ in range(20):
            t0 = time.perf_counter()
            graph.query_entity("Hub")
            timings_ms.append((time.perf_counter() - t0) * 1000)

        mean_ms = sum(timings_ms) / len(timings_ms)
        total_rels = sum(target_counts)
        record_metric("kg_query", f"avg_ms_with_{total_rels}_rels", round(mean_ms, 2))
        record_metric("kg_query", "total_relationships", total_rels)
@pytest.mark.benchmark
class TestTimelinePerformance:
    """Benchmark timeline() called without an entity filter (a full table scan)."""

    @pytest.mark.parametrize("n_triples", [200, 1_000, 5_000])
    def test_timeline_latency(self, n_triples, tmp_path):
        """Load ``n_triples`` triples, then average ten unfiltered timeline() calls."""
        from mempalace.knowledge_graph import KnowledgeGraph

        generator = PalaceDataGenerator(seed=42)
        entity_cap = min(n_triples // 2, 200)
        entities, triples = generator.generate_kg_triples(
            n_entities=entity_cap, n_triples=n_triples
        )

        graph = KnowledgeGraph(db_path=str(tmp_path / "kg.sqlite3"))
        for entity_name, entity_type in entities:
            graph.add_entity(entity_name, entity_type)
        for subj, pred, target, since, until in triples:
            graph.add_triple(subj, pred, target, valid_from=since, valid_to=until)

        # Measure timeline (no filter = full scan with LIMIT 100)
        timings_ms = []
        for _ in range(10):
            t0 = time.perf_counter()
            graph.timeline()
            timings_ms.append((time.perf_counter() - t0) * 1000)

        mean_ms = sum(timings_ms) / len(timings_ms)
        record_metric("kg_timeline", f"avg_ms_at_{n_triples}", round(mean_ms, 2))
@pytest.mark.benchmark
class TestTemporalQueryAccuracy:
    """Exercise as_of temporal filtering on a graph that also holds noise triples."""

    def test_as_of_filtering(self, tmp_path):
        """Insert triples with known date ranges and record as_of result counts.

        Only the result counts are recorded as metrics; nothing is asserted
        about them here.
        """
        from mempalace.knowledge_graph import KnowledgeGraph

        graph = KnowledgeGraph(db_path=str(tmp_path / "kg.sqlite3"))

        for entity_name, entity_type in (
            ("Alice", "person"),
            ("ProjectA", "project"),
            ("ProjectB", "project"),
        ):
            graph.add_entity(entity_name, entity_type)

        # Alice worked on ProjectA from 2024-01 to 2024-06
        graph.add_triple(
            "Alice", "works_on", "ProjectA", valid_from="2024-01-01", valid_to="2024-06-30"
        )
        # Alice worked on ProjectB from 2024-07 onwards
        graph.add_triple("Alice", "works_on", "ProjectB", valid_from="2024-07-01")

        # Add noise triples
        generator = PalaceDataGenerator(seed=42)
        entities, triples = generator.generate_kg_triples(n_entities=50, n_triples=500)
        for entity_name, entity_type in entities:
            graph.add_entity(entity_name, entity_type)
        for subj, pred, target, since, until in triples:
            graph.add_triple(subj, pred, target, valid_from=since, valid_to=until)

        def result_count(result):
            """Length of a list result; 0 for any other return type."""
            return len(result) if isinstance(result, list) else 0

        # Query Alice as of March 2024 — should find ProjectA
        march_result = graph.query_entity("Alice", as_of="2024-03-15")
        # Query Alice as of September 2024 — should find ProjectB
        sept_result = graph.query_entity("Alice", as_of="2024-09-15")

        record_metric("kg_temporal", "march_query_results", result_count(march_result))
        record_metric("kg_temporal", "sept_query_results", result_count(sept_result))
@pytest.mark.benchmark
class TestSQLiteConcurrentAccess:
    """Probe SQLite behavior under concurrent reads and writes (finding #8)."""

    def test_concurrent_writers(self, tmp_path):
        """Run four writer threads against one graph and count failed writes."""
        from mempalace.knowledge_graph import KnowledgeGraph

        graph = KnowledgeGraph(db_path=str(tmp_path / "kg.sqlite3"))

        # Pre-create entities
        for idx in range(100):
            graph.add_entity(f"Entity_{idx}", "concept")

        n_threads = 4
        triples_per_thread = 50
        lock_failures = []
        successes = []

        def writer(thread_id):
            failed = 0
            written = 0
            for step in range(triples_per_thread):
                try:
                    graph.add_triple(
                        f"Entity_{thread_id * 10}",
                        "relates_to",
                        f"Entity_{(thread_id * 10 + step) % 100}",
                        valid_from="2025-01-01",
                    )
                except Exception:
                    # Any exception counts as a failed write for this benchmark.
                    failed += 1
                else:
                    written += 1
            lock_failures.append(failed)
            successes.append(written)

        workers = [threading.Thread(target=writer, args=(tid,)) for tid in range(n_threads)]
        t0 = time.perf_counter()
        for worker in workers:
            worker.start()
        for worker in workers:
            worker.join(timeout=30)
        elapsed = time.perf_counter() - t0

        readings = {
            "total_failures": sum(lock_failures),
            "total_successes": sum(successes),
            "elapsed_sec": round(elapsed, 2),
            "threads": n_threads,
            "triples_per_thread": triples_per_thread,
        }
        for label, reading in readings.items():
            record_metric("kg_concurrent", label, reading)

    def test_concurrent_read_write(self, tmp_path):
        """Run two reader and two writer threads at once and count errors on each side."""
        from mempalace.knowledge_graph import KnowledgeGraph

        graph = KnowledgeGraph(db_path=str(tmp_path / "kg.sqlite3"))

        # Seed some data
        for idx in range(50):
            graph.add_entity(f"E_{idx}", "concept")
        for idx in range(200):
            graph.add_triple(
                f"E_{idx % 50}", "links", f"E_{(idx + 1) % 50}", valid_from="2025-01-01"
            )

        read_errors = []
        write_errors = []

        def reader():
            failed = 0
            for step in range(50):
                try:
                    graph.query_entity(f"E_{step % 50}")
                except Exception:
                    failed += 1
            read_errors.append(failed)

        def writer():
            failed = 0
            for step in range(50):
                try:
                    graph.add_triple(
                        f"E_{step % 50}", "new_rel", f"E_{(step + 7) % 50}", valid_from="2025-06-01"
                    )
                except Exception:
                    failed += 1
            write_errors.append(failed)

        workers = [threading.Thread(target=fn) for fn in (reader, reader, writer, writer)]
        for worker in workers:
            worker.start()
        for worker in workers:
            worker.join(timeout=30)

        record_metric("kg_concurrent_rw", "read_errors", sum(read_errors))
        record_metric("kg_concurrent_rw", "write_errors", sum(write_errors))
@pytest.mark.benchmark
class TestKGStats:
    """Time ``KnowledgeGraph.stats()`` on graphs of increasing size."""

    @pytest.mark.parametrize("n_triples", [200, 1_000, 5_000])
    def test_stats_latency(self, n_triples, tmp_path):
        """Record the mean latency of ten ``stats()`` calls on ``n_triples`` triples."""
        from mempalace.knowledge_graph import KnowledgeGraph

        generator = PalaceDataGenerator(seed=42)
        # Entity count is half the triple count, capped at 200.
        entity_count = min(n_triples // 2, 200)
        entities, triples = generator.generate_kg_triples(
            n_entities=entity_count, n_triples=n_triples
        )

        graph = KnowledgeGraph(db_path=str(tmp_path / "kg.sqlite3"))
        for entity_name, entity_type in entities:
            graph.add_entity(entity_name, entity_type)
        for subj, pred, target, begins, ends in triples:
            graph.add_triple(subj, pred, target, valid_from=begins, valid_to=ends)

        samples_ms = []
        for _ in range(10):
            began = time.perf_counter()
            graph.stats()
            samples_ms.append((time.perf_counter() - began) * 1000)

        mean_ms = sum(samples_ms) / len(samples_ms)
        record_metric("kg_stats", f"avg_ms_at_{n_triples}", round(mean_ms, 2))
+209
View File
@@ -0,0 +1,209 @@
"""
Memory stack (layers.py) benchmarks.
Tests MemoryStack.wake_up(), Layer1.generate(), and Layer2/L3
at scale. Layer1 has the same unbounded col.get() as tool_status.
"""
import time
import pytest
from tests.benchmarks.data_generator import PalaceDataGenerator
from tests.benchmarks.report import record_metric
def _get_rss_mb():
try:
import psutil
return psutil.Process().memory_info().rss / (1024 * 1024)
except ImportError:
import resource
import platform
usage = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
if platform.system() == "Darwin":
return usage / (1024 * 1024)
return usage / 1024
@pytest.mark.benchmark
class TestWakeUpCost:
    """Benchmark ``MemoryStack.wake_up()`` (L0 + L1) across palace sizes."""

    SIZES = [500, 1_000, 2_500, 5_000]

    @pytest.mark.parametrize("n_drawers", SIZES)
    def test_wakeup_latency(self, n_drawers, tmp_path, bench_scale):
        """Record the mean of five ``wake_up()`` timings for ``n_drawers`` drawers.

        Every returned text must contain at least one of the expected
        section markers; otherwise the test fails.
        """
        generator = PalaceDataGenerator(seed=42, scale=bench_scale)
        palace_dir = str(tmp_path / "palace")
        generator.populate_palace_directly(palace_dir, n_drawers=n_drawers, include_needles=False)

        # Identity file handed to MemoryStack as identity_path.
        identity_file = str(tmp_path / "identity.txt")
        with open(identity_file, "w") as handle:
            handle.write("I am a test AI. Traits: precise, fast.\n")

        from mempalace.layers import MemoryStack

        memory_stack = MemoryStack(palace_path=palace_dir, identity_path=identity_file)
        expected_markers = ("L0", "L1", "IDENTITY", "ESSENTIAL")

        timings_ms = []
        for _ in range(5):
            began = time.perf_counter()
            woken = memory_stack.wake_up()
            timings_ms.append((time.perf_counter() - began) * 1000)
            assert any(marker in woken for marker in expected_markers)

        mean_ms = sum(timings_ms) / len(timings_ms)
        record_metric("layers_wakeup", f"avg_ms_at_{n_drawers}", round(mean_ms, 1))
@pytest.mark.benchmark
class TestLayer1UnboundedFetch:
    """Benchmark ``Layer1.generate()`` latency and RSS as the palace grows."""

    SIZES = [500, 1_000, 2_500, 5_000]

    @pytest.mark.parametrize("n_drawers", SIZES)
    def test_layer1_rss_growth(self, n_drawers, tmp_path):
        """Record latency and RSS delta of one ``generate()`` call at ``n_drawers``."""
        generator = PalaceDataGenerator(seed=42, scale="small")
        palace_dir = str(tmp_path / "palace")
        generator.populate_palace_directly(palace_dir, n_drawers=n_drawers, include_needles=False)

        from mempalace.layers import Layer1

        layer_one = Layer1(palace_path=palace_dir)

        # RSS is sampled immediately around the timed call.
        before_mb = _get_rss_mb()
        began = time.perf_counter()
        generated = layer_one.generate()
        took_ms = (time.perf_counter() - began) * 1000
        after_mb = _get_rss_mb()

        assert "L1" in generated
        record_metric("layer1", f"latency_ms_at_{n_drawers}", round(took_ms, 1))
        record_metric("layer1", f"rss_delta_mb_at_{n_drawers}", round(after_mb - before_mb, 2))

    def test_layer1_wing_filtered(self, tmp_path):
        """Compare one unfiltered ``generate()`` against one wing-filtered call."""
        generator = PalaceDataGenerator(seed=42, scale="small")
        palace_dir = str(tmp_path / "palace")
        generator.populate_palace_directly(palace_dir, n_drawers=2_000, include_needles=False)

        from mempalace.layers import Layer1

        first_wing = generator.wings[0]

        # Unfiltered baseline
        everything = Layer1(palace_path=palace_dir)
        began = time.perf_counter()
        everything.generate()
        unfiltered_ms = (time.perf_counter() - began) * 1000

        # Restricted to a single wing
        one_wing = Layer1(palace_path=palace_dir, wing=first_wing)
        began = time.perf_counter()
        one_wing.generate()
        filtered_ms = (time.perf_counter() - began) * 1000

        record_metric("layer1_filter", "unfiltered_ms", round(unfiltered_ms, 1))
        record_metric("layer1_filter", "filtered_ms", round(filtered_ms, 1))
        # The speedup is only defined for a non-zero baseline.
        if unfiltered_ms > 0:
            speedup_pct = (1 - filtered_ms / unfiltered_ms) * 100
            record_metric("layer1_filter", "speedup_pct", round(speedup_pct, 1))
@pytest.mark.benchmark
class TestWakeUpTokenBudget:
    """Check that the wake-up text (L0 + L1) stays within its token budget."""

    SIZES = [500, 1_000, 2_500, 5_000]

    @pytest.mark.parametrize("n_drawers", SIZES)
    def test_token_budget(self, n_drawers, tmp_path):
        """Estimate wake-up tokens as ``len(text) // 4`` and require fewer than 1200."""
        generator = PalaceDataGenerator(seed=42, scale="small")
        palace_dir = str(tmp_path / "palace")
        generator.populate_palace_directly(palace_dir, n_drawers=n_drawers, include_needles=False)

        identity_file = tmp_path / "identity.txt"
        identity_file.write_text("I am a benchmark AI.\n")

        from mempalace.layers import MemoryStack

        memory_stack = MemoryStack(palace_path=palace_dir, identity_path=str(identity_file))
        wake_text = memory_stack.wake_up()
        n_chars = len(wake_text)
        token_estimate = n_chars // 4

        # Budget is ~600-900 tokens; 1200 leaves a safety margin.
        record_metric("wakeup_budget", f"tokens_at_{n_drawers}", token_estimate)
        record_metric("wakeup_budget", f"chars_at_{n_drawers}", n_chars)
        within_budget = token_estimate < 1200
        assert (
            within_budget
        ), f"Wake-up exceeded budget: ~{token_estimate} tokens at {n_drawers} drawers"
@pytest.mark.benchmark
class TestLayer2Retrieval:
    """Benchmark wing-filtered ``Layer2.retrieve()`` calls."""

    def test_layer2_latency(self, tmp_path, bench_scale):
        """Average ten wing-filtered retrievals over a 2,000-drawer palace."""
        generator = PalaceDataGenerator(seed=42, scale=bench_scale)
        palace_dir = str(tmp_path / "palace")
        generator.populate_palace_directly(palace_dir, n_drawers=2_000, include_needles=False)

        from mempalace.layers import Layer2

        retriever = Layer2(palace_path=palace_dir)
        target_wing = generator.wings[0]

        timings_ms = []
        for _ in range(10):
            began = time.perf_counter()
            retriever.retrieve(wing=target_wing, n_results=10)
            timings_ms.append((time.perf_counter() - began) * 1000)

        mean_ms = sum(timings_ms) / len(timings_ms)
        record_metric("layer2", "avg_retrieval_ms", round(mean_ms, 1))
@pytest.mark.benchmark
class TestLayer3Search:
    """Benchmark search latency through ``MemoryStack.search()``."""

    def test_layer3_latency(self, tmp_path, bench_scale):
        """Average one search per query over a 2,000-drawer palace."""
        generator = PalaceDataGenerator(seed=42, scale=bench_scale)
        palace_dir = str(tmp_path / "palace")
        generator.populate_palace_directly(palace_dir, n_drawers=2_000, include_needles=False)

        identity_file = tmp_path / "identity.txt"
        identity_file.write_text("I am a benchmark AI.\n")

        from mempalace.layers import MemoryStack

        memory_stack = MemoryStack(palace_path=palace_dir, identity_path=str(identity_file))

        timings_ms = []
        for term in ("authentication", "database", "deployment", "testing", "monitoring"):
            began = time.perf_counter()
            memory_stack.search(term, n_results=5)
            timings_ms.append((time.perf_counter() - began) * 1000)

        mean_ms = sum(timings_ms) / len(timings_ms)
        record_metric("layer3", "avg_search_ms", round(mean_ms, 1))
+226
View File
@@ -0,0 +1,226 @@
"""
MCP server tool performance benchmarks.
Validates production readiness findings:
- Finding #3: tool_status() unbounded col.get(include=["metadatas"]) → OOM
- Finding #7: _get_collection() re-instantiates PersistentClient every call
- Finding #3 variants: tool_list_wings(), tool_get_taxonomy() same pattern
Calls MCP tool handler functions directly with monkeypatched _config.
"""
import time
import chromadb
import pytest
from tests.benchmarks.data_generator import PalaceDataGenerator
from tests.benchmarks.report import record_metric
# ── Helpers ──────────────────────────────────────────────────────────────
def _make_palace(tmp_path, n_drawers, scale="small"):
    """Populate a needle-free palace under ``tmp_path`` and return its path string.

    The palace lives in ``tmp_path / "palace"`` and is filled by a generator
    seeded with 42, asked for ``n_drawers`` drawers at the given ``scale``.
    """
    palace_dir = str(tmp_path / "palace")
    PalaceDataGenerator(seed=42, scale=scale).populate_palace_directly(
        palace_dir, n_drawers=n_drawers, include_needles=False
    )
    return palace_dir
def _patch_mcp_config(monkeypatch, palace_path, tmp_path):
    """Redirect ``mempalace.mcp_server`` to throwaway test locations.

    Swaps the module's ``_config`` for a ``MempalaceConfig`` whose
    ``_file_config`` names ``palace_path``, and its ``_kg`` for a
    ``KnowledgeGraph`` stored under ``tmp_path``. Every change is made
    through ``monkeypatch``.
    """
    from mempalace.config import MempalaceConfig
    from mempalace.knowledge_graph import KnowledgeGraph

    test_config = MempalaceConfig(config_dir=str(tmp_path / "cfg"))
    # Set palace_path by replacing the config object's file-config dict.
    monkeypatch.setattr(test_config, "_file_config", {"palace_path": palace_path})

    import mempalace.mcp_server as server_module

    monkeypatch.setattr(server_module, "_config", test_config)
    scratch_graph = KnowledgeGraph(db_path=str(tmp_path / "kg.sqlite3"))
    monkeypatch.setattr(server_module, "_kg", scratch_graph)
def _get_rss_mb():
"""Get current process RSS in MB."""
try:
import psutil
return psutil.Process().memory_info().rss / (1024 * 1024)
except ImportError:
import resource
# ru_maxrss is in KB on Linux, bytes on macOS
import platform
usage = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
if platform.system() == "Darwin":
return usage / (1024 * 1024)
return usage / 1024
# ── Tests ────────────────────────────────────────────────────────────────
@pytest.mark.benchmark
class TestToolStatusOOM:
    """Finding #3: memory and latency of ``tool_status`` as the palace grows."""

    SIZES = [500, 1_000, 2_500, 5_000]

    @pytest.mark.parametrize("n_drawers", SIZES)
    def test_tool_status_rss_growth(self, n_drawers, tmp_path, monkeypatch):
        """Record the RSS delta of one ``tool_status()`` call at ``n_drawers``."""
        palace_dir = _make_palace(tmp_path, n_drawers)
        _patch_mcp_config(monkeypatch, palace_dir, tmp_path)

        from mempalace.mcp_server import tool_status

        before_mb = _get_rss_mb()
        status = tool_status()
        after_mb = _get_rss_mb()

        assert "error" not in status, f"tool_status failed: {status}"
        assert status["total_drawers"] == n_drawers
        delta_mb = after_mb - before_mb
        record_metric("mcp_status", f"rss_delta_mb_at_{n_drawers}", round(delta_mb, 2))

    @pytest.mark.parametrize("n_drawers", SIZES)
    def test_tool_status_latency(self, n_drawers, tmp_path, monkeypatch):
        """Record the latency of a ``tool_status()`` call made after one warm-up call."""
        palace_dir = _make_palace(tmp_path, n_drawers)
        _patch_mcp_config(monkeypatch, palace_dir, tmp_path)

        from mempalace.mcp_server import tool_status

        # Untimed warm-up call before the measured one.
        tool_status()

        began = time.perf_counter()
        status = tool_status()
        took_ms = (time.perf_counter() - began) * 1000

        assert "error" not in status
        record_metric("mcp_status", f"latency_ms_at_{n_drawers}", round(took_ms, 1))
@pytest.mark.benchmark
class TestToolListWingsUnbounded:
    """Finding #3 variant: latency of ``tool_list_wings`` as the palace grows."""

    @pytest.mark.parametrize("n_drawers", [500, 1_000, 2_500, 5_000])
    def test_list_wings_latency(self, n_drawers, tmp_path, monkeypatch):
        """Time one ``tool_list_wings()`` call and require a ``"wings"`` key in its result."""
        palace_dir = _make_palace(tmp_path, n_drawers)
        _patch_mcp_config(monkeypatch, palace_dir, tmp_path)

        from mempalace.mcp_server import tool_list_wings

        began = time.perf_counter()
        listing = tool_list_wings()
        took_ms = (time.perf_counter() - began) * 1000

        assert "wings" in listing
        record_metric("mcp_list_wings", f"latency_ms_at_{n_drawers}", round(took_ms, 1))
@pytest.mark.benchmark
class TestToolGetTaxonomyUnbounded:
    """Finding #3 variant: latency of ``tool_get_taxonomy`` as the palace grows."""

    @pytest.mark.parametrize("n_drawers", [500, 1_000, 2_500, 5_000])
    def test_get_taxonomy_latency(self, n_drawers, tmp_path, monkeypatch):
        """Time one ``tool_get_taxonomy()`` call and require a ``"taxonomy"`` key."""
        palace_dir = _make_palace(tmp_path, n_drawers)
        _patch_mcp_config(monkeypatch, palace_dir, tmp_path)

        from mempalace.mcp_server import tool_get_taxonomy

        began = time.perf_counter()
        taxonomy_reply = tool_get_taxonomy()
        took_ms = (time.perf_counter() - began) * 1000

        assert "taxonomy" in taxonomy_reply
        record_metric("mcp_taxonomy", f"latency_ms_at_{n_drawers}", round(took_ms, 1))
@pytest.mark.benchmark
class TestClientReinstantiation:
    """Finding #7: _get_collection() creates new PersistentClient every call."""

    def test_reinstantiation_overhead(self, tmp_path, monkeypatch):
        """Measure 50 ``count()`` lookups via _get_collection() vs. one cached collection.

        Both timed loops issue the same ``count()`` query, so the recorded
        ``overhead_ratio`` reflects only the extra cost of obtaining the
        collection afresh on every call.
        """
        palace_path = _make_palace(tmp_path, 500)
        _patch_mcp_config(monkeypatch, palace_path, tmp_path)

        from mempalace.mcp_server import _get_collection

        n_calls = 50

        # Measure re-instantiation (current behavior)
        # NOTE(review): assumes _get_collection() returns a Chroma collection
        # exposing count(), like client.get_collection() below — confirm.
        start = time.perf_counter()
        for _ in range(n_calls):
            col = _get_collection()
            assert col is not None
            col.count()
        uncached_ms = (time.perf_counter() - start) * 1000

        # Measure cached client (what it should be)
        client = chromadb.PersistentClient(path=palace_path)
        cached_col = client.get_collection("mempalace_drawers")
        start = time.perf_counter()
        for _ in range(n_calls):
            cached_col.count()
        cached_ms = (time.perf_counter() - start) * 1000

        # Floor the denominator so a near-zero cached time cannot blow up the ratio.
        overhead_ratio = uncached_ms / max(cached_ms, 0.01)

        record_metric("client_reinstantiation", "uncached_total_ms", round(uncached_ms, 1))
        record_metric("client_reinstantiation", "cached_total_ms", round(cached_ms, 1))
        record_metric("client_reinstantiation", "overhead_ratio", round(overhead_ratio, 2))
        record_metric("client_reinstantiation", "n_calls", n_calls)
@pytest.mark.benchmark
class TestToolSearchLatency:
    """Benchmark ``tool_search`` latency as the palace grows."""

    @pytest.mark.parametrize("n_drawers", [500, 1_000, 2_500, 5_000])
    def test_search_latency(self, n_drawers, tmp_path, monkeypatch):
        """Average one ``tool_search`` call per query; each result must be error-free."""
        palace_dir = _make_palace(tmp_path, n_drawers)
        _patch_mcp_config(monkeypatch, palace_dir, tmp_path)

        from mempalace.mcp_server import tool_search

        search_terms = ("authentication middleware", "database migration", "error handling")
        timings_ms = []
        for term in search_terms:
            began = time.perf_counter()
            reply = tool_search(query=term, limit=5)
            timings_ms.append((time.perf_counter() - began) * 1000)
            assert "error" not in reply

        mean_ms = sum(timings_ms) / len(timings_ms)
        record_metric("mcp_search", f"avg_latency_ms_at_{n_drawers}", round(mean_ms, 1))
@pytest.mark.benchmark
class TestDuplicateCheckCost:
    """Benchmark the latency of a single ``tool_check_duplicate`` call."""

    @pytest.mark.parametrize("n_drawers", [500, 1_000, 2_500])
    def test_duplicate_check_latency(self, n_drawers, tmp_path, monkeypatch):
        """Time one duplicate check of a fixed probe string at ``n_drawers``."""
        palace_dir = _make_palace(tmp_path, n_drawers)
        _patch_mcp_config(monkeypatch, palace_dir, tmp_path)

        from mempalace.mcp_server import tool_check_duplicate

        probe = "This is unique test content for duplicate checking benchmark."
        began = time.perf_counter()
        verdict = tool_check_duplicate(content=probe)
        took_ms = (time.perf_counter() - began) * 1000

        assert "error" not in verdict
        record_metric("mcp_duplicate_check", f"latency_ms_at_{n_drawers}", round(took_ms, 1))
+181
View File
@@ -0,0 +1,181 @@
"""
Memory profiling benchmarks — detect leaks and measure RSS growth.
Uses tracemalloc for heap snapshots and psutil/resource for RSS.
Targets the highest-risk code paths:
- Repeated search() calls (PersistentClient re-instantiation)
- Repeated tool_status() calls (unbounded metadata fetch)
- Layer1.generate() (fetches all drawers)
"""
import tracemalloc
import pytest
from tests.benchmarks.data_generator import PalaceDataGenerator
from tests.benchmarks.report import record_metric
def _get_rss_mb():
try:
import psutil
return psutil.Process().memory_info().rss / (1024 * 1024)
except ImportError:
import resource
import platform
usage = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
if platform.system() == "Darwin":
return usage / (1024 * 1024)
return usage / 1024
@pytest.mark.benchmark
class TestSearchMemoryProfile:
    """Track RSS growth across repeated ``search_memories()`` calls."""

    def test_search_rss_growth(self, tmp_path):
        """Run 200 searches over 1,000 drawers, sampling RSS after every 50th call."""
        generator = PalaceDataGenerator(seed=42, scale="small")
        palace_dir = str(tmp_path / "palace")
        generator.populate_palace_directly(palace_dir, n_drawers=1_000, include_needles=False)

        from mempalace.searcher import search_memories

        n_calls = 200
        check_interval = 50
        search_terms = ["authentication", "database", "deployment", "error handling", "testing"]

        # (label, rss_mb) samples; the first is the baseline.
        samples = [("start", _get_rss_mb())]
        for done in range(1, n_calls + 1):
            # Cycle through the search terms.
            term = search_terms[(done - 1) % len(search_terms)]
            search_memories(term, palace_path=palace_dir, n_results=5)
            if done % check_interval == 0:
                samples.append((f"after_{done}", _get_rss_mb()))

        first_mb = samples[0][1]
        last_mb = samples[-1][1]
        growth_mb = last_mb - first_mb

        record_metric("memory_search", "rss_start_mb", round(first_mb, 2))
        record_metric("memory_search", "rss_end_mb", round(last_mb, 2))
        record_metric("memory_search", "rss_growth_mb", round(growth_mb, 2))
        record_metric("memory_search", "n_calls", n_calls)
        record_metric(
            "memory_search", "growth_per_100_calls_mb", round(growth_mb / (n_calls / 100), 2)
        )
@pytest.mark.benchmark
class TestToolStatusMemoryProfile:
    """Track RSS growth across repeated ``tool_status()`` calls."""

    def test_tool_status_repeated_calls(self, tmp_path, monkeypatch):
        """Call ``tool_status`` 50 times on 2,000 drawers, sampling RSS every 10th call."""
        generator = PalaceDataGenerator(seed=42, scale="small")
        palace_dir = str(tmp_path / "palace")
        generator.populate_palace_directly(palace_dir, n_drawers=2_000, include_needles=False)

        from mempalace.config import MempalaceConfig
        from mempalace.knowledge_graph import KnowledgeGraph
        import mempalace.mcp_server as server_module

        # Point the MCP server module at the scratch palace and a scratch graph.
        test_config = MempalaceConfig(config_dir=str(tmp_path / "cfg"))
        monkeypatch.setattr(test_config, "_file_config", {"palace_path": palace_dir})
        monkeypatch.setattr(server_module, "_config", test_config)
        scratch_graph = KnowledgeGraph(db_path=str(tmp_path / "kg.sqlite3"))
        monkeypatch.setattr(server_module, "_kg", scratch_graph)

        from mempalace.mcp_server import tool_status

        n_calls = 50
        # (label, rss_mb) samples; the first is the baseline.
        samples = [("start", _get_rss_mb())]
        for done in range(1, n_calls + 1):
            status = tool_status()
            assert status["total_drawers"] == 2_000
            if done % 10 == 0:
                samples.append((f"after_{done}", _get_rss_mb()))

        first_mb = samples[0][1]
        last_mb = samples[-1][1]

        record_metric("memory_tool_status", "rss_start_mb", round(first_mb, 2))
        record_metric("memory_tool_status", "rss_end_mb", round(last_mb, 2))
        record_metric("memory_tool_status", "rss_growth_mb", round(last_mb - first_mb, 2))
        record_metric("memory_tool_status", "n_calls", n_calls)
        record_metric("memory_tool_status", "palace_size", 2_000)
@pytest.mark.benchmark
class TestLayer1MemoryProfile:
    """Track RSS growth across repeated ``Layer1.generate()`` calls."""

    def test_layer1_repeated_generate(self, tmp_path):
        """Call ``generate()`` 30 times on 2,000 drawers, sampling RSS every 10th call."""
        generator = PalaceDataGenerator(seed=42, scale="small")
        palace_dir = str(tmp_path / "palace")
        generator.populate_palace_directly(palace_dir, n_drawers=2_000, include_needles=False)

        from mempalace.layers import Layer1

        layer_one = Layer1(palace_path=palace_dir)
        n_calls = 30

        # (label, rss_mb) samples; the first is the baseline.
        samples = [("start", _get_rss_mb())]
        for done in range(1, n_calls + 1):
            generated = layer_one.generate()
            assert "L1" in generated
            if done % 10 == 0:
                samples.append((f"after_{done}", _get_rss_mb()))

        first_mb = samples[0][1]
        last_mb = samples[-1][1]

        record_metric("memory_layer1", "rss_start_mb", round(first_mb, 2))
        record_metric("memory_layer1", "rss_end_mb", round(last_mb, 2))
        record_metric("memory_layer1", "rss_growth_mb", round(last_mb - first_mb, 2))
        record_metric("memory_layer1", "n_calls", n_calls)
@pytest.mark.benchmark
class TestHeapSnapshot:
    """Use tracemalloc to identify top memory allocators during search."""

    def test_search_heap_top_allocators(self, tmp_path):
        """Record net heap growth of the ten largest allocation changes over 100 searches.

        Two tracemalloc snapshots bracket the search loop. Their per-line
        differences are ranked by ``Snapshot.compare_to`` and the size deltas
        of the first ten entries are summed into ``top_10_growth_kb``.
        """
        gen = PalaceDataGenerator(seed=42, scale="small")
        palace_path = str(tmp_path / "palace")
        gen.populate_palace_directly(palace_path, n_drawers=1_000, include_needles=False)

        from mempalace.searcher import search_memories

        n_searches = 100

        tracemalloc.start()
        try:
            snap_before = tracemalloc.take_snapshot()
            for _ in range(n_searches):
                search_memories("test query", palace_path=palace_path, n_results=5)
            snap_after = tracemalloc.take_snapshot()
        finally:
            # Always stop tracing, even if a search raises, so later tests
            # do not run with tracemalloc still enabled.
            tracemalloc.stop()

        stats = snap_after.compare_to(snap_before, "lineno")

        # Use the *_diff fields: `size`/`count` are totals in the new snapshot,
        # not the change between the two snapshots.
        top_allocators = [
            {
                "file": str(stat.traceback),
                "size_diff_kb": round(stat.size_diff / 1024, 1),
                "count_diff": stat.count_diff,
            }
            for stat in stats[:10]
        ]

        total_growth_kb = sum(entry["size_diff_kb"] for entry in top_allocators)
        record_metric("heap_search", "top_10_growth_kb", round(total_growth_kb, 1))
        record_metric("heap_search", "n_searches", n_searches)
+176
View File
@@ -0,0 +1,176 @@
"""
Palace boost validation — does wing/room filtering actually help?
Quantifies the retrieval improvement from the palace spatial metaphor.
Uses planted needles to measure recall with and without filtering
at different scales.
"""
import time
import pytest
from tests.benchmarks.data_generator import PalaceDataGenerator
from tests.benchmarks.report import record_metric
@pytest.mark.benchmark
class TestFilteredVsUnfilteredRecall:
    """Quantify the recall change from wing and wing+room filtering."""

    SIZES = [1_000, 2_500, 5_000]

    @pytest.mark.parametrize("n_drawers", SIZES)
    def test_palace_boost_recall(self, n_drawers, tmp_path, bench_scale):
        """Compare recall@5 for unfiltered, wing-filtered and wing+room-filtered search.

        Up to ten planted needles are queried. A query counts as a hit when
        any of its first five result texts contains ``"NEEDLE_"``.
        """
        generator = PalaceDataGenerator(seed=42, scale=bench_scale)
        palace_dir = str(tmp_path / "palace")
        _, _, needle_info = generator.populate_palace_directly(
            palace_dir, n_drawers=n_drawers, include_needles=True
        )

        from mempalace.searcher import search_memories

        def needle_in_top5(outcome):
            """Return True if any of the first five result texts holds a needle marker."""
            texts = [hit["text"] for hit in outcome.get("results", [])]
            return any("NEEDLE_" in text for text in texts[:5])

        n_queries = min(10, len(needle_info))
        hits_none = hits_wing = hits_room = 0

        for needle in needle_info[:n_queries]:
            # No filter
            outcome = search_memories(needle["query"], palace_path=palace_dir, n_results=5)
            hits_none += 1 if needle_in_top5(outcome) else 0

            # Wing only
            outcome = search_memories(
                needle["query"], palace_path=palace_dir, wing=needle["wing"], n_results=5
            )
            hits_wing += 1 if needle_in_top5(outcome) else 0

            # Wing and room
            outcome = search_memories(
                needle["query"],
                palace_path=palace_dir,
                wing=needle["wing"],
                room=needle["room"],
                n_results=5,
            )
            hits_room += 1 if needle_in_top5(outcome) else 0

        # Guard the denominator in case no needles were planted.
        denominator = max(n_queries, 1)
        recall_none = hits_none / denominator
        recall_wing = hits_wing / denominator
        recall_room = hits_room / denominator

        record_metric("palace_boost", f"recall_unfiltered_at_{n_drawers}", round(recall_none, 3))
        record_metric("palace_boost", f"recall_wing_filtered_at_{n_drawers}", round(recall_wing, 3))
        record_metric("palace_boost", f"recall_room_filtered_at_{n_drawers}", round(recall_room, 3))
        record_metric(
            "palace_boost", f"wing_boost_at_{n_drawers}", round(recall_wing - recall_none, 3)
        )
        record_metric(
            "palace_boost", f"room_boost_at_{n_drawers}", round(recall_room - recall_none, 3)
        )
@pytest.mark.benchmark
class TestFilterLatencyBenefit:
    """Measure whether wing/room filters change search latency."""

    def test_filter_speedup(self, tmp_path, bench_scale):
        """Compare mean latency over ten runs: no filter, wing, and wing+room."""
        generator = PalaceDataGenerator(seed=42, scale=bench_scale)
        palace_dir = str(tmp_path / "palace")
        generator.populate_palace_directly(palace_dir, n_drawers=5_000, include_needles=False)

        from mempalace.searcher import search_memories

        first_wing = generator.wings[0]
        first_room = generator.rooms_by_wing[first_wing][0]
        search_text = "authentication middleware optimization"
        n_runs = 10

        def mean_latency_ms(**filters):
            """Run the fixed query ``n_runs`` times with ``filters``; return the mean in ms."""
            samples = []
            for _ in range(n_runs):
                began = time.perf_counter()
                search_memories(search_text, palace_path=palace_dir, n_results=5, **filters)
                samples.append((time.perf_counter() - began) * 1000)
            return sum(samples) / len(samples)

        # Same order as always: unfiltered, then wing, then wing + room.
        avg_none = mean_latency_ms()
        avg_wing = mean_latency_ms(wing=first_wing)
        avg_room = mean_latency_ms(wing=first_wing, room=first_room)

        record_metric("filter_latency", "avg_unfiltered_ms", round(avg_none, 1))
        record_metric("filter_latency", "avg_wing_filtered_ms", round(avg_wing, 1))
        record_metric("filter_latency", "avg_room_filtered_ms", round(avg_room, 1))

        # Speedups are only defined relative to a non-zero baseline.
        if avg_none > 0:
            wing_pct = (1 - avg_wing / avg_none) * 100
            room_pct = (1 - avg_room / avg_none) * 100
            record_metric("filter_latency", "wing_speedup_pct", round(wing_pct, 1))
            record_metric("filter_latency", "room_speedup_pct", round(room_pct, 1))
@pytest.mark.benchmark
class TestBoostAtIncreasingScale:
    """Check whether the wing-filter recall boost grows with palace size."""

    def test_boost_scaling(self, tmp_path, bench_scale):
        """Record the wing-filter recall@5 boost at 500, 1,000 and 2,500 drawers.

        Each size gets its own palace directory. A query is a hit when any of
        its first five results has ``"NEEDLE_"`` in its text.
        """

        def needle_in_top5(outcome):
            """Return True if one of the first five results carries a needle marker."""
            return any("NEEDLE_" in hit["text"] for hit in outcome.get("results", [])[:5])

        boosts = []
        for size in (500, 1_000, 2_500):
            generator = PalaceDataGenerator(seed=42, scale=bench_scale)
            palace_dir = str(tmp_path / f"palace_{size}")
            _, _, needle_info = generator.populate_palace_directly(
                palace_dir, n_drawers=size, include_needles=True
            )

            from mempalace.searcher import search_memories

            n_queries = min(8, len(needle_info))
            plain_hits = 0
            wing_hits = 0
            for needle in needle_info[:n_queries]:
                outcome = search_memories(needle["query"], palace_path=palace_dir, n_results=5)
                if needle_in_top5(outcome):
                    plain_hits += 1

                outcome = search_memories(
                    needle["query"], palace_path=palace_dir, wing=needle["wing"], n_results=5
                )
                if needle_in_top5(outcome):
                    wing_hits += 1

            # Guard the denominator in case no needles were planted.
            denominator = max(n_queries, 1)
            boost = wing_hits / denominator - plain_hits / denominator
            boosts.append({"size": size, "boost": boost})

        record_metric("boost_scaling", "boosts_by_size", boosts)

        # Hypothesis check: is the boost at the largest size at least that of the smallest?
        if len(boosts) >= 2:
            record_metric(
                "boost_scaling", "trend_positive", boosts[-1]["boost"] >= boosts[0]["boost"]
            )
+182
View File
@@ -0,0 +1,182 @@
"""
Recall threshold test — find the per-bucket size where retrieval breaks.
The palace_boost tests showed room-filtered recall of 1.0, but only because
each room had ~333 drawers. This test concentrates ALL drawers into a single
wing+room to find the actual embedding model limit.
"""
import hashlib
import os
from datetime import datetime
import chromadb
import pytest
from tests.benchmarks.data_generator import PalaceDataGenerator
from tests.benchmarks.report import record_metric
# Texts of the planted "needle" drawers. _populate_single_room embeds entry i
# in the drawer whose content starts with NEEDLE_{i:04d}.
NEEDLE_TOPICS = [
"Fibonacci sequence optimization uses memoization with O(n) space complexity",
"PostgreSQL vacuum autovacuum threshold set to 50 percent for table users",
"Redis cluster failover timeout configured at 30 seconds with sentinel monitoring",
"Kubernetes horizontal pod autoscaler targets 70 percent CPU utilization",
"GraphQL subscription uses WebSocket transport with heartbeat interval 25 seconds",
"JWT token rotation policy requires refresh every 15 minutes with sliding window",
"Elasticsearch index sharding strategy uses 5 primary shards with 1 replica each",
"Docker multi-stage build reduces image size from 1.2GB to 180MB for production",
"Apache Kafka consumer group rebalance timeout set to 45 seconds",
"MongoDB change streams resume token persisted every 100 operations",
]
# Search queries, index-aligned with NEEDLE_TOPICS: the tests look for needle
# NEEDLE_{i:04d} in the results returned for NEEDLE_QUERIES[i].
NEEDLE_QUERIES = [
"Fibonacci sequence optimization memoization",
"PostgreSQL vacuum autovacuum threshold",
"Redis cluster failover timeout sentinel",
"Kubernetes horizontal pod autoscaler CPU",
"GraphQL subscription WebSocket heartbeat",
"JWT token rotation policy refresh",
"Elasticsearch index sharding primary replica",
"Docker multi-stage build image size production",
"Apache Kafka consumer group rebalance",
"MongoDB change streams resume token",
]
def _populate_single_room(palace_path, n_drawers, n_needles=10):
    """Pack needle and noise drawers into a single wing+room of a fresh palace.

    The first ``n_needles`` drawers are built from ``NEEDLE_TOPICS``; the rest
    are random noise texts. Every drawer is filed under wing ``"concentrated"``
    and room ``"single_room"`` of the ``mempalace_drawers`` collection.

    Args:
        palace_path: Directory for the Chroma store; created if missing.
        n_drawers: Target total number of drawers, needles included. If it is
            smaller than ``n_needles`` no noise is added and the palace holds
            just the needles.
        n_needles: Number of needles to plant; at most ``len(NEEDLE_TOPICS)``.

    Returns:
        A ``(client, col)`` tuple: the persistent Chroma client and the
        populated collection.

    Raises:
        ValueError: If ``n_needles`` exceeds the number of available topics.
    """
    # Validate before touching the filesystem, so a bad argument fails with a
    # clear message instead of an IndexError halfway through.
    if n_needles > len(NEEDLE_TOPICS):
        raise ValueError(
            f"n_needles={n_needles} exceeds the {len(NEEDLE_TOPICS)} available needle topics"
        )

    gen = PalaceDataGenerator(seed=42, scale="small")
    os.makedirs(palace_path, exist_ok=True)
    client = chromadb.PersistentClient(path=palace_path)
    col = client.get_or_create_collection("mempalace_drawers")

    def _drawer_id(key):
        """Derive a drawer id from the first 16 hex digits of md5(key)."""
        return f"drawer_single_room_{hashlib.md5(key.encode()).hexdigest()[:16]}"

    def _metadata(source_file, chunk_index):
        """Build the metadata shared by every drawer in the single room."""
        return {
            "wing": "concentrated",
            "room": "single_room",
            "source_file": source_file,
            "chunk_index": chunk_index,
            "added_by": "threshold_bench",
            "filed_at": datetime.now().isoformat(),
        }

    batch_size = 500
    docs, ids, metas = [], [], []

    # Plant needles
    for i in range(n_needles):
        needle_id = f"NEEDLE_{i:04d}"
        docs.append(f"{needle_id}: {NEEDLE_TOPICS[i]}. Unique planted needle for threshold test.")
        ids.append(_drawer_id(needle_id))
        metas.append(_metadata(f"needle_{i}.txt", 0))

    # Fill with noise — all in the SAME room
    remaining = n_drawers - len(docs)
    for i in range(remaining):
        docs.append(gen._random_text(400, 800))
        ids.append(_drawer_id(f"noise_{i}"))
        metas.append(_metadata(f"noise_{i:06d}.txt", i % 10))
        # Flush full batches so one add() call never carries the whole palace.
        if len(docs) >= batch_size:
            col.add(documents=docs, ids=ids, metadatas=metas)
            docs, ids, metas = [], [], []

    # Write whatever is left over (including needles when no batch filled up).
    if docs:
        col.add(documents=docs, ids=ids, metadatas=metas)

    return client, col
@pytest.mark.benchmark
class TestRecallThresholdSingleRoom:
    """
    Recall with every drawer concentrated in a single wing+room.

    With all drawers in one bucket, room filtering cannot narrow the search,
    so these numbers show the retrieval ceiling of the search itself.
    """

    SIZES = [250, 500, 1_000, 2_000, 3_000, 5_000]

    @staticmethod
    def _recall_at_5_and_10(palace_path, **filters):
        """Query every needle and return ``(recall@5, recall@10)``.

        ``filters`` is forwarded to ``search_memories``. A result dict with an
        ``"error"`` key is skipped, so it counts as a miss at both cut-offs.
        """
        from mempalace.searcher import search_memories

        hits = {5: 0, 10: 0}
        for idx, query in enumerate(NEEDLE_QUERIES):
            outcome = search_memories(query, palace_path=palace_path, n_results=10, **filters)
            if "error" in outcome:
                continue
            texts = [hit["text"] for hit in outcome.get("results", [])]
            needle_id = f"NEEDLE_{idx:04d}"
            for cutoff in (5, 10):
                if any(needle_id in text for text in texts[:cutoff]):
                    hits[cutoff] += 1

        total = len(NEEDLE_QUERIES)
        return hits[5] / total, hits[10] / total

    @pytest.mark.parametrize("n_drawers", SIZES)
    def test_single_room_recall(self, n_drawers, tmp_path):
        """Recall@5 and @10 when searching with the wing+room filter applied."""
        palace_dir = str(tmp_path / "palace")
        _populate_single_room(palace_dir, n_drawers, n_needles=10)

        recall_5, recall_10 = self._recall_at_5_and_10(
            palace_dir, wing="concentrated", room="single_room"
        )

        record_metric("single_room_recall", f"recall_at_5_at_{n_drawers}", round(recall_5, 3))
        record_metric("single_room_recall", f"recall_at_10_at_{n_drawers}", round(recall_10, 3))

    @pytest.mark.parametrize("n_drawers", SIZES)
    def test_single_room_no_filter_recall(self, n_drawers, tmp_path):
        """Recall@5 and @10 for the same palace searched without any filter."""
        palace_dir = str(tmp_path / "palace")
        _populate_single_room(palace_dir, n_drawers, n_needles=10)

        recall_5, recall_10 = self._recall_at_5_and_10(palace_dir)

        record_metric("single_room_unfiltered", f"recall_at_5_at_{n_drawers}", round(recall_5, 3))
        record_metric("single_room_unfiltered", f"recall_at_10_at_{n_drawers}", round(recall_10, 3))
+234
View File
@@ -0,0 +1,234 @@
"""
Search performance benchmarks.
Measures query latency, recall@k, and concurrent search behavior
as palace size grows. Uses planted needles for recall measurement.
"""
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
import pytest
from tests.benchmarks.data_generator import PalaceDataGenerator
from tests.benchmarks.report import record_metric
@pytest.mark.benchmark
class TestSearchLatencyVsSize:
    """Track how per-query search latency changes with palace size."""

    SIZES = [500, 1_000, 2_500, 5_000]

    @pytest.mark.parametrize("n_drawers", SIZES)
    def test_search_latency_curve(self, n_drawers, tmp_path, bench_scale):
        """Time five fixed queries and record avg / p50 / p95 latency in ms."""
        palace_path = str(tmp_path / "palace")
        generator = PalaceDataGenerator(seed=42, scale=bench_scale)
        generator.populate_palace_directly(
            palace_path, n_drawers=n_drawers, include_needles=False
        )

        from mempalace.searcher import search_memories

        timings_ms = []
        for text in (
            "authentication middleware",
            "database optimization",
            "error handling patterns",
            "deployment configuration",
            "testing strategy",
        ):
            began = time.perf_counter()
            outcome = search_memories(text, palace_path=palace_path, n_results=5)
            timings_ms.append((time.perf_counter() - began) * 1000)
            assert "error" not in outcome

        ordered = sorted(timings_ms)
        count = len(ordered)
        mean_ms = sum(timings_ms) / len(timings_ms)
        median_ms = ordered[count // 2]
        p95_ms = ordered[int(count * 0.95)]

        record_metric("search", f"avg_latency_ms_at_{n_drawers}", round(mean_ms, 1))
        record_metric("search", f"p50_ms_at_{n_drawers}", round(median_ms, 1))
        record_metric("search", f"p95_ms_at_{n_drawers}", round(p95_ms, 1))
@pytest.mark.benchmark
class TestSearchRecallAtScale:
    """Measure planted-needle recall at several palace sizes."""

    SIZES = [500, 1_000, 2_500, 5_000]

    @pytest.mark.parametrize("n_drawers", SIZES)
    def test_recall_at_k(self, n_drawers, tmp_path, bench_scale):
        """Record Recall@5 and Recall@10 over at most ten needle queries.

        Queries whose result has an ``"error"`` key are skipped but still
        count in the denominator.

        NOTE(review): a query counts as a hit when *any* returned text
        contains "NEEDLE_", not specifically the needle that query targets —
        confirm this is intended.
        """
        palace_path = str(tmp_path / "palace")
        generator = PalaceDataGenerator(seed=42, scale=bench_scale)
        _, _, needle_info = generator.populate_palace_directly(
            palace_path, n_drawers=n_drawers, include_needles=True
        )

        from mempalace.searcher import search_memories

        query_count = min(10, len(needle_info))
        top5_hits = 0
        top10_hits = 0
        for entry in needle_info[:query_count]:
            outcome = search_memories(entry["query"], palace_path=palace_path, n_results=10)
            if "error" in outcome:
                continue
            snippets = [hit["text"] for hit in outcome.get("results", [])]
            top5_hits += int(any("NEEDLE_" in s for s in snippets[:5]))
            top10_hits += int(any("NEEDLE_" in s for s in snippets[:10]))

        denominator = max(query_count, 1)
        recall_at_5 = top5_hits / denominator
        recall_at_10 = top10_hits / denominator

        record_metric("search_recall", f"recall_at_5_at_{n_drawers}", round(recall_at_5, 3))
        record_metric("search_recall", f"recall_at_10_at_{n_drawers}", round(recall_at_10, 3))
@pytest.mark.benchmark
class TestSearchFilteredVsUnfiltered:
    """Benchmark searches with a wing filter against unfiltered searches."""

    def test_filter_impact(self, tmp_path, bench_scale):
        """Record latency and recall@5 for unfiltered vs. wing-filtered needle queries."""
        palace_path = str(tmp_path / "palace")
        generator = PalaceDataGenerator(seed=42, scale=bench_scale)
        _, _, needle_info = generator.populate_palace_directly(
            palace_path, n_drawers=2_000, include_needles=True
        )

        from mempalace.searcher import search_memories

        def timed_search(query, **filters):
            """Return ``(elapsed_ms, hit)`` for one top-5 search."""
            began = time.perf_counter()
            outcome = search_memories(query, palace_path=palace_path, n_results=5, **filters)
            elapsed_ms = (time.perf_counter() - began) * 1000
            hit = any("NEEDLE_" in h["text"] for h in outcome.get("results", [])[:5])
            return elapsed_ms, hit

        n_queries = min(10, len(needle_info))
        unfiltered_ms, filtered_ms = [], []
        unfiltered_hits = 0
        filtered_hits = 0

        for needle in needle_info[:n_queries]:
            # Unfiltered search first ...
            elapsed, hit = timed_search(needle["query"])
            unfiltered_ms.append(elapsed)
            unfiltered_hits += int(hit)

            # ... then the same query restricted to the needle's wing.
            elapsed, hit = timed_search(needle["query"], wing=needle["wing"])
            filtered_ms.append(elapsed)
            filtered_hits += int(hit)

        avg_unfiltered = sum(unfiltered_ms) / max(len(unfiltered_ms), 1)
        avg_filtered = sum(filtered_ms) / max(len(filtered_ms), 1)
        # Floor the divisor at 0.01 ms so an empty run cannot divide by zero.
        latency_improvement = ((avg_unfiltered - avg_filtered) / max(avg_unfiltered, 0.01)) * 100
        denominator = max(n_queries, 1)

        record_metric("search_filter", "avg_unfiltered_ms", round(avg_unfiltered, 1))
        record_metric("search_filter", "avg_filtered_ms", round(avg_filtered, 1))
        record_metric("search_filter", "latency_improvement_pct", round(latency_improvement, 1))
        record_metric(
            "search_filter", "unfiltered_recall_at_5", round(unfiltered_hits / denominator, 3)
        )
        record_metric(
            "search_filter", "filtered_recall_at_5", round(filtered_hits / denominator, 3)
        )
@pytest.mark.benchmark
class TestConcurrentSearch:
    """Benchmark searches issued from several threads at once."""

    def test_concurrent_queries(self, tmp_path):
        """Run 30 searches on a 4-worker thread pool and record p50/p95/p99/avg."""
        worker_count = 4
        palace_path = str(tmp_path / "palace")
        PalaceDataGenerator(seed=42, scale="small").populate_palace_directly(
            palace_path, n_drawers=2_000, include_needles=False
        )

        from mempalace.searcher import search_memories

        topics = [
            "authentication",
            "database",
            "deployment",
            "error handling",
            "testing",
            "monitoring",
            "caching",
            "middleware",
            "serialization",
            "validation",
        ]
        queries = topics * 3  # 30 total queries

        def run_search(query):
            """Return ``(elapsed_ms, succeeded)`` for a single search."""
            began = time.perf_counter()
            outcome = search_memories(query, palace_path=palace_path, n_results=5)
            return (time.perf_counter() - began) * 1000, "error" not in outcome

        # Concurrent execution
        timings = []
        failures = 0
        with ThreadPoolExecutor(max_workers=worker_count) as pool:
            pending = [pool.submit(run_search, query) for query in queries]
            for finished in as_completed(pending):
                elapsed, succeeded = finished.result()
                timings.append(elapsed)
                if not succeeded:
                    failures += 1

        ordered = sorted(timings)
        count = len(ordered)

        record_metric("concurrent_search", "p50_ms", round(ordered[count // 2], 1))
        record_metric("concurrent_search", "p95_ms", round(ordered[int(count * 0.95)], 1))
        record_metric("concurrent_search", "p99_ms", round(ordered[int(count * 0.99)], 1))
        record_metric("concurrent_search", "avg_ms", round(sum(ordered) / count, 1))
        record_metric("concurrent_search", "error_count", failures)
        record_metric("concurrent_search", "total_queries", len(queries))
        record_metric("concurrent_search", "workers", worker_count)
@pytest.mark.benchmark
class TestSearchNResultsScaling:
    """How does n_results affect query latency?"""

    @pytest.mark.parametrize("n_results", [1, 5, 10, 25, 50])
    def test_n_results_latency(self, n_results, tmp_path):
        """Record the mean latency (ms) of five identical searches at ``n_results``.

        Each result is checked for an ``"error"`` key so that a failing search
        cannot contribute a latency sample; the check runs after the timer has
        stopped and therefore does not affect the measurement.
        """
        gen = PalaceDataGenerator(seed=42, scale="small")
        palace_path = str(tmp_path / "palace")
        gen.populate_palace_directly(palace_path, n_drawers=2_000, include_needles=False)

        from mempalace.searcher import search_memories

        latencies = []
        for _ in range(5):
            start = time.perf_counter()
            result = search_memories(
                "authentication middleware", palace_path=palace_path, n_results=n_results
            )
            latencies.append((time.perf_counter() - start) * 1000)
            # Same guard as the latency-vs-size benchmark: timing an errored
            # call would silently record a meaningless number.
            assert "error" not in result

        avg_ms = sum(latencies) / len(latencies)
        record_metric("search_n_results", f"avg_ms_at_n_{n_results}", round(avg_ms, 1))
+21 -1
View File
@@ -34,6 +34,24 @@ from mempalace.config import MempalaceConfig # noqa: E402
from mempalace.knowledge_graph import KnowledgeGraph # noqa: E402 from mempalace.knowledge_graph import KnowledgeGraph # noqa: E402
@pytest.fixture(autouse=True)
def _reset_mcp_cache():
    """Null out ``mempalace.mcp_server``'s cached client/collection around each test.

    ``_client_cache`` and ``_collection_cache`` are set to ``None`` before the
    test body runs and again after it finishes. An ``ImportError`` or
    ``AttributeError`` during the reset is ignored.
    """

    def _wipe():
        try:
            from mempalace import mcp_server

            for attr in ("_client_cache", "_collection_cache"):
                setattr(mcp_server, attr, None)
        except (ImportError, AttributeError):
            pass

    _wipe()
    yield
    _wipe()
@pytest.fixture(scope="session", autouse=True) @pytest.fixture(scope="session", autouse=True)
def _isolate_home(): def _isolate_home():
"""Ensure HOME points to a temp dir for the entire test session. """Ensure HOME points to a temp dir for the entire test session.
@@ -84,7 +102,9 @@ def collection(palace_path):
"""A ChromaDB collection pre-seeded in the temp palace.""" """A ChromaDB collection pre-seeded in the temp palace."""
client = chromadb.PersistentClient(path=palace_path) client = chromadb.PersistentClient(path=palace_path)
col = client.get_or_create_collection("mempalace_drawers") col = client.get_or_create_collection("mempalace_drawers")
return col yield col
client.delete_collection("mempalace_drawers")
del client
@pytest.fixture @pytest.fixture
+609
View File
@@ -0,0 +1,609 @@
"""Tests for mempalace.cli — the main CLI dispatcher."""
import argparse
import sys
from unittest.mock import MagicMock, patch
import pytest
from mempalace.cli import (
cmd_compress,
cmd_hook,
cmd_init,
cmd_instructions,
cmd_mine,
cmd_repair,
cmd_search,
cmd_split,
cmd_status,
cmd_wakeup,
main,
)
# ── cmd_status ─────────────────────────────────────────────────────────
@patch("mempalace.cli.MempalaceConfig")
def test_cmd_status_default_palace(mock_config_cls):
    """With ``palace=None``, status receives the palace path from the config."""
    mock_config_cls.return_value.palace_path = "/fake/palace"
    fake_miner = MagicMock()
    namespace = argparse.Namespace(palace=None)

    with patch.dict("sys.modules", {"mempalace.miner": fake_miner}):
        cmd_status(namespace)

    fake_miner.status.assert_called_once_with(palace_path="/fake/palace")
@patch("mempalace.cli.MempalaceConfig")
def test_cmd_status_custom_palace(mock_config_cls):
    """An explicit ``palace`` value reaches status with ``~`` expanded."""
    import os

    fake_miner = MagicMock()
    with patch.dict("sys.modules", {"mempalace.miner": fake_miner}):
        cmd_status(argparse.Namespace(palace="~/my_palace"))

    fake_miner.status.assert_called_once_with(palace_path=os.path.expanduser("~/my_palace"))
# ── cmd_search ─────────────────────────────────────────────────────────
@patch("mempalace.cli.MempalaceConfig")
def test_cmd_search_calls_search(mock_config_cls):
    """cmd_search forwards query, filters and result count to ``searcher.search``."""
    mock_config_cls.return_value.palace_path = "/fake/palace"
    cli_args = argparse.Namespace(
        palace=None, query="test query", wing="mywing", room="myroom", results=3
    )
    expected_kwargs = {
        "query": "test query",
        "palace_path": "/fake/palace",
        "wing": "mywing",
        "room": "myroom",
        "n_results": 3,
    }

    with patch("mempalace.searcher.search") as search_mock:
        cmd_search(cli_args)

    search_mock.assert_called_once_with(**expected_kwargs)
@patch("mempalace.cli.MempalaceConfig")
def test_cmd_search_error_exits(mock_config_cls):
    """A ``SearchError`` from the searcher makes cmd_search exit with code 1."""
    from mempalace.searcher import SearchError

    mock_config_cls.return_value.palace_path = "/fake/palace"
    cli_args = argparse.Namespace(palace=None, query="q", wing=None, room=None, results=5)
    failing_search = patch("mempalace.searcher.search", side_effect=SearchError("fail"))

    with failing_search, pytest.raises(SystemExit) as exc_info:
        cmd_search(cli_args)

    assert exc_info.value.code == 1
# ── cmd_instructions ───────────────────────────────────────────────────
def test_cmd_instructions_calls_run_instructions():
    """cmd_instructions passes ``args.name`` to ``run_instructions``."""
    with patch("mempalace.instructions_cli.run_instructions") as runner:
        cmd_instructions(argparse.Namespace(name="help"))
    runner.assert_called_once_with(name="help")
# ── cmd_hook ───────────────────────────────────────────────────────────
def test_cmd_hook_calls_run_hook():
    """cmd_hook passes the hook name and harness to ``run_hook``."""
    namespace = argparse.Namespace(hook="session-start", harness="claude-code")
    with patch("mempalace.hooks_cli.run_hook") as runner:
        cmd_hook(namespace)
    runner.assert_called_once_with(hook_name="session-start", harness="claude-code")
# ── cmd_init ───────────────────────────────────────────────────────────
@patch("mempalace.cli.MempalaceConfig")
def test_cmd_init_no_entities(mock_config_cls, tmp_path):
    """With no files to scan, cmd_init still detects rooms and initialises config."""
    project_dir = str(tmp_path)
    scan_patch = patch("mempalace.entity_detector.scan_for_detection", return_value=[])
    rooms_patch = patch("mempalace.room_detector_local.detect_rooms_local")

    with scan_patch, rooms_patch as mock_rooms:
        cmd_init(argparse.Namespace(dir=project_dir, yes=True))

    mock_rooms.assert_called_once_with(project_dir=project_dir, yes=True)
    mock_config_cls.return_value.init.assert_called_once()
@patch("mempalace.cli.MempalaceConfig")
def test_cmd_init_with_entities(mock_config_cls, tmp_path):
    """Smoke test: cmd_init runs to completion when entities are detected and confirmed."""
    fake_files = [tmp_path / "a.txt"]
    detected = {"people": [{"name": "Alice"}], "projects": [], "uncertain": []}
    confirmed = {"people": ["Alice"], "projects": []}

    # Patchers are built first and entered in this order.
    scan_patch = patch("mempalace.entity_detector.scan_for_detection", return_value=fake_files)
    detect_patch = patch("mempalace.entity_detector.detect_entities", return_value=detected)
    confirm_patch = patch("mempalace.entity_detector.confirm_entities", return_value=confirmed)
    rooms_patch = patch("mempalace.room_detector_local.detect_rooms_local")
    open_patch = patch("builtins.open", MagicMock())

    with scan_patch, detect_patch, confirm_patch, rooms_patch, open_patch:
        cmd_init(argparse.Namespace(dir=str(tmp_path), yes=True))
@patch("mempalace.cli.MempalaceConfig")
def test_cmd_init_with_entities_zero_total(mock_config_cls, tmp_path, capsys):
    """Files are scanned but detection finds nothing: 'No entities detected' is printed."""
    fake_files = [tmp_path / "a.txt"]
    nothing_found = {"people": [], "projects": [], "uncertain": []}

    scan_patch = patch("mempalace.entity_detector.scan_for_detection", return_value=fake_files)
    detect_patch = patch("mempalace.entity_detector.detect_entities", return_value=nothing_found)
    rooms_patch = patch("mempalace.room_detector_local.detect_rooms_local")

    with scan_patch, detect_patch, rooms_patch:
        cmd_init(argparse.Namespace(dir=str(tmp_path), yes=False))

    assert "No entities detected" in capsys.readouterr().out
# ── cmd_mine ───────────────────────────────────────────────────────────
@patch("mempalace.cli.MempalaceConfig")
def test_cmd_mine_projects_mode(mock_config_cls):
    """In ``projects`` mode cmd_mine calls ``miner.mine`` with the mapped arguments."""
    mock_config_cls.return_value.palace_path = "/fake/palace"
    cli_fields = {
        "dir": "/src",
        "palace": None,
        "mode": "projects",
        "wing": None,
        "agent": "mempalace",
        "limit": 0,
        "dry_run": False,
        "no_gitignore": False,
        "include_ignored": [],
        "extract": "exchange",
    }

    with patch("mempalace.miner.mine") as mine_mock:
        cmd_mine(argparse.Namespace(**cli_fields))

    mine_mock.assert_called_once_with(
        project_dir="/src",
        palace_path="/fake/palace",
        wing_override=None,
        agent="mempalace",
        limit=0,
        dry_run=False,
        respect_gitignore=True,
        include_ignored=[],
    )
@patch("mempalace.cli.MempalaceConfig")
def test_cmd_mine_convos_mode(mock_config_cls):
    """In ``convos`` mode cmd_mine calls ``convo_miner.mine_convos`` with the mapped arguments."""
    mock_config_cls.return_value.palace_path = "/fake/palace"
    cli_fields = {
        "dir": "/chats",
        "palace": None,
        "mode": "convos",
        "wing": "mywing",
        "agent": "me",
        "limit": 10,
        "dry_run": True,
        "no_gitignore": False,
        "include_ignored": [],
        "extract": "general",
    }

    with patch("mempalace.convo_miner.mine_convos") as mine_mock:
        cmd_mine(argparse.Namespace(**cli_fields))

    mine_mock.assert_called_once_with(
        convo_dir="/chats",
        palace_path="/fake/palace",
        wing="mywing",
        agent="me",
        limit=10,
        dry_run=True,
        extract_mode="general",
    )
@patch("mempalace.cli.MempalaceConfig")
def test_cmd_mine_include_ignored_comma_split(mock_config_cls):
    """Comma-separated ``include_ignored`` entries are flattened into single paths."""
    mock_config_cls.return_value.palace_path = "/fake/palace"
    cli_fields = {
        "dir": "/src",
        "palace": None,
        "mode": "projects",
        "wing": None,
        "agent": "mempalace",
        "limit": 0,
        "dry_run": False,
        "no_gitignore": False,
        "include_ignored": ["a.txt,b.txt", "c.txt"],
        "extract": "exchange",
    }

    with patch("mempalace.miner.mine") as mine_mock:
        cmd_mine(argparse.Namespace(**cli_fields))

    mine_mock.assert_called_once()
    _, passed_kwargs = mine_mock.call_args
    assert passed_kwargs["include_ignored"] == ["a.txt", "b.txt", "c.txt"]
# ── cmd_wakeup ─────────────────────────────────────────────────────────
@patch("mempalace.cli.MempalaceConfig")
def test_cmd_wakeup(mock_config_cls, capsys):
    """cmd_wakeup prints the stack's wake-up text and mentions 'tokens'."""
    mock_config_cls.return_value.palace_path = "/fake/palace"
    stack = MagicMock()
    stack.wake_up.return_value = "Hello world context"

    with patch("mempalace.layers.MemoryStack", return_value=stack):
        cmd_wakeup(argparse.Namespace(palace=None, wing=None))

    printed = capsys.readouterr().out
    assert "Hello world context" in printed
    assert "tokens" in printed
# ── cmd_split ──────────────────────────────────────────────────────────
def test_cmd_split_basic():
    """cmd_split invokes ``split_mega_files.main`` exactly once."""
    namespace = argparse.Namespace(dir="/chats", output_dir=None, dry_run=False, min_sessions=2)
    with patch("mempalace.split_mega_files.main") as split_main:
        cmd_split(namespace)
    split_main.assert_called_once()
def test_cmd_split_all_options():
    """With every option set, the splitter is still called once and argv[0] is not left altered."""
    namespace = argparse.Namespace(dir="/chats", output_dir="/out", dry_run=True, min_sessions=5)
    with patch("mempalace.split_mega_files.main") as split_main:
        cmd_split(namespace)
    split_main.assert_called_once()
    # sys.argv should be restored
    assert sys.argv[0] != "mempalace split"
# ── main() argparse dispatch ──────────────────────────────────────────
def test_main_no_args_prints_help(capsys):
    """Running with no subcommand prints text that contains 'MemPalace'."""
    argv = ["mempalace"]
    with patch("sys.argv", argv):
        main()
    assert "MemPalace" in capsys.readouterr().out
def test_main_status_dispatches():
    """``mempalace status`` calls ``cmd_status`` once."""
    argv = ["mempalace", "status"]
    with patch("sys.argv", argv), patch("mempalace.cli.cmd_status") as handler:
        main()
    handler.assert_called_once()
def test_main_search_dispatches():
    """``mempalace search <query>`` calls ``cmd_search`` once."""
    argv = ["mempalace", "search", "my query"]
    with patch("sys.argv", argv), patch("mempalace.cli.cmd_search") as handler:
        main()
    handler.assert_called_once()
def test_main_init_dispatches():
    """``mempalace init <dir>`` calls ``cmd_init`` once."""
    argv = ["mempalace", "init", "/some/dir"]
    with patch("sys.argv", argv), patch("mempalace.cli.cmd_init") as handler:
        main()
    handler.assert_called_once()
def test_main_mine_dispatches():
    """``mempalace mine <dir>`` calls ``cmd_mine`` once."""
    argv = ["mempalace", "mine", "/some/dir"]
    with patch("sys.argv", argv), patch("mempalace.cli.cmd_mine") as handler:
        main()
    handler.assert_called_once()
def test_main_wakeup_dispatches():
    """``mempalace wake-up`` calls ``cmd_wakeup`` once."""
    argv = ["mempalace", "wake-up"]
    with patch("sys.argv", argv), patch("mempalace.cli.cmd_wakeup") as handler:
        main()
    handler.assert_called_once()
def test_main_split_dispatches():
    """``mempalace split <dir>`` calls ``cmd_split`` once."""
    argv = ["mempalace", "split", "/chats"]
    with patch("sys.argv", argv), patch("mempalace.cli.cmd_split") as handler:
        main()
    handler.assert_called_once()
def test_main_hook_no_subcommand_prints_help(capsys):
    """``mempalace hook`` alone prints output mentioning 'hook' or 'run'."""
    argv = ["mempalace", "hook"]
    with patch("sys.argv", argv):
        main()
    lowered = capsys.readouterr().out.lower()
    assert "hook" in lowered or "run" in lowered
def test_main_hook_run_dispatches():
    """``mempalace hook run --hook … --harness …`` calls ``cmd_hook`` once."""
    argv = ["mempalace", "hook", "run", "--hook", "session-start", "--harness", "claude-code"]
    with patch("sys.argv", argv), patch("mempalace.cli.cmd_hook") as handler:
        main()
    handler.assert_called_once()
def test_main_instructions_no_subcommand_prints_help(capsys):
    """``mempalace instructions`` alone prints output mentioning 'instructions' or 'init'."""
    argv = ["mempalace", "instructions"]
    with patch("sys.argv", argv):
        main()
    lowered = capsys.readouterr().out.lower()
    assert "instructions" in lowered or "init" in lowered
def test_main_instructions_dispatches():
    """``mempalace instructions help`` calls ``cmd_instructions`` once."""
    argv = ["mempalace", "instructions", "help"]
    with patch("sys.argv", argv), patch("mempalace.cli.cmd_instructions") as handler:
        main()
    handler.assert_called_once()
def test_main_repair_dispatches():
    """``mempalace repair`` calls ``cmd_repair`` once."""
    argv = ["mempalace", "repair"]
    with patch("sys.argv", argv), patch("mempalace.cli.cmd_repair") as handler:
        main()
    handler.assert_called_once()
def test_main_compress_dispatches():
    """``mempalace compress`` calls ``cmd_compress`` once."""
    argv = ["mempalace", "compress"]
    with patch("sys.argv", argv), patch("mempalace.cli.cmd_compress") as handler:
        main()
    handler.assert_called_once()
# ── cmd_repair ─────────────────────────────────────────────────────────
@patch("mempalace.cli.MempalaceConfig")
def test_cmd_repair_no_palace(mock_config_cls, tmp_path, capsys):
    """A palace path that does not exist makes cmd_repair print 'No palace found'."""
    missing_dir = tmp_path / "nonexistent"
    mock_config_cls.return_value.palace_path = str(missing_dir)

    with patch.dict("sys.modules", {"chromadb": MagicMock()}):
        cmd_repair(argparse.Namespace(palace=None))

    assert "No palace found" in capsys.readouterr().out
@patch("mempalace.cli.MempalaceConfig")
def test_cmd_repair_error_reading(mock_config_cls, tmp_path, capsys):
    """If ``get_collection`` raises, cmd_repair prints 'Error reading palace'."""
    palace_dir = tmp_path / "palace"
    palace_dir.mkdir()
    mock_config_cls.return_value.palace_path = str(palace_dir)

    broken_client = MagicMock()
    broken_client.get_collection.side_effect = Exception("corrupt db")
    fake_chromadb = MagicMock()
    fake_chromadb.PersistentClient.return_value = broken_client

    with patch.dict("sys.modules", {"chromadb": fake_chromadb}):
        cmd_repair(argparse.Namespace(palace=None))

    assert "Error reading palace" in capsys.readouterr().out
@patch("mempalace.cli.MempalaceConfig")
def test_cmd_repair_zero_drawers(mock_config_cls, tmp_path, capsys):
    """A collection whose count is 0 makes cmd_repair print 'Nothing to repair'."""
    palace_dir = tmp_path / "palace"
    palace_dir.mkdir()
    mock_config_cls.return_value.palace_path = str(palace_dir)

    empty_col = MagicMock()
    empty_col.count.return_value = 0
    fake_client = MagicMock()
    fake_client.get_collection.return_value = empty_col
    fake_chromadb = MagicMock()
    fake_chromadb.PersistentClient.return_value = fake_client

    with patch.dict("sys.modules", {"chromadb": fake_chromadb}):
        cmd_repair(argparse.Namespace(palace=None))

    assert "Nothing to repair" in capsys.readouterr().out
@patch("mempalace.cli.MempalaceConfig")
def test_cmd_repair_success(mock_config_cls, tmp_path, capsys):
    """With two stored drawers, cmd_repair reports completion and '2 drawers rebuilt'."""
    palace_dir = tmp_path / "palace"
    palace_dir.mkdir()
    mock_config_cls.return_value.palace_path = str(palace_dir)

    existing_col = MagicMock()
    existing_col.count.return_value = 2
    existing_col.get.return_value = {
        "ids": ["id1", "id2"],
        "documents": ["doc1", "doc2"],
        "metadatas": [{"wing": "a"}, {"wing": "b"}],
    }
    fake_client = MagicMock()
    fake_client.get_collection.return_value = existing_col
    fake_client.create_collection.return_value = MagicMock()
    fake_chromadb = MagicMock()
    fake_chromadb.PersistentClient.return_value = fake_client

    with patch.dict("sys.modules", {"chromadb": fake_chromadb}):
        cmd_repair(argparse.Namespace(palace=None))

    printed = capsys.readouterr().out
    assert "Repair complete" in printed
    assert "2 drawers rebuilt" in printed
# ── cmd_compress ───────────────────────────────────────────────────────
@patch("mempalace.cli.MempalaceConfig")
def test_cmd_compress_no_palace(mock_config_cls, capsys):
    """If ``PersistentClient`` raises, cmd_compress exits via ``SystemExit``."""
    mock_config_cls.return_value.palace_path = "/fake/palace"
    fake_chromadb = MagicMock()
    fake_chromadb.PersistentClient.side_effect = Exception("no palace")
    namespace = argparse.Namespace(palace=None, wing=None, dry_run=False, config=None)

    with patch.dict("sys.modules", {"chromadb": fake_chromadb}), pytest.raises(SystemExit):
        cmd_compress(namespace)
@patch("mempalace.cli.MempalaceConfig")
def test_cmd_compress_no_drawers(mock_config_cls, capsys):
    """An empty ``get`` result makes cmd_compress print 'No drawers found'."""
    mock_config_cls.return_value.palace_path = "/fake/palace"

    empty_col = MagicMock()
    empty_col.get.return_value = {"documents": [], "metadatas": [], "ids": []}
    fake_client = MagicMock()
    fake_client.get_collection.return_value = empty_col
    fake_chromadb = MagicMock()
    fake_chromadb.PersistentClient.return_value = fake_client
    namespace = argparse.Namespace(palace=None, wing="mywing", dry_run=False, config=None)

    with patch.dict("sys.modules", {"chromadb": fake_chromadb}):
        cmd_compress(namespace)

    assert "No drawers found" in capsys.readouterr().out
def _make_mock_dialect_module(dialect_instance):
"""Create a mock dialect module with a Dialect class that returns the given instance."""
mock_mod = MagicMock()
mock_mod.Dialect.return_value = dialect_instance
mock_mod.Dialect.from_config.return_value = dialect_instance
mock_mod.Dialect.count_tokens = MagicMock(side_effect=lambda x: len(x) // 4)
return mock_mod
@patch("mempalace.cli.MempalaceConfig")
def test_cmd_compress_dry_run(mock_config_cls, capsys):
    """A dry-run compress over one drawer prints 'Compressing' and a dry-run notice."""
    mock_config_cls.return_value.palace_path = "/fake/palace"

    # First get() returns one drawer, the second returns an empty page.
    one_drawer = {
        "documents": ["some long text here for testing"],
        "metadatas": [{"wing": "test", "room": "general", "source_file": "test.txt"}],
        "ids": ["id1"],
    }
    no_more = {"documents": [], "metadatas": [], "ids": []}
    drawers_col = MagicMock()
    drawers_col.get.side_effect = [one_drawer, no_more]

    fake_client = MagicMock()
    fake_client.get_collection.return_value = drawers_col
    fake_chromadb = MagicMock()
    fake_chromadb.PersistentClient.return_value = fake_client

    dialect = MagicMock()
    dialect.compress.return_value = "compressed"
    dialect.compression_stats.return_value = {
        "original_chars": 100,
        "compressed_chars": 30,
        "original_tokens": 25,
        "compressed_tokens": 8,
        "ratio": 3.3,
    }
    replaced_modules = {
        "chromadb": fake_chromadb,
        "mempalace.dialect": _make_mock_dialect_module(dialect),
    }
    namespace = argparse.Namespace(palace=None, wing=None, dry_run=True, config=None)

    with patch.dict("sys.modules", replaced_modules):
        cmd_compress(namespace)

    printed = capsys.readouterr().out
    assert "dry run" in printed.lower()
    assert "Compressing" in printed
@patch("mempalace.cli.MempalaceConfig")
def test_cmd_compress_with_config(mock_config_cls, tmp_path, capsys):
    """Passing an entity config file makes cmd_compress print 'Loaded entity config'."""
    mock_config_cls.return_value.palace_path = "/fake/palace"
    entity_config = tmp_path / "entities.json"
    entity_config.write_text('{"people": [], "projects": []}')

    empty_col = MagicMock()
    empty_col.get.return_value = {"documents": [], "metadatas": [], "ids": []}
    fake_client = MagicMock()
    fake_client.get_collection.return_value = empty_col
    fake_chromadb = MagicMock()
    fake_chromadb.PersistentClient.return_value = fake_client

    replaced_modules = {
        "chromadb": fake_chromadb,
        "mempalace.dialect": _make_mock_dialect_module(MagicMock()),
    }
    namespace = argparse.Namespace(
        palace=None, wing=None, dry_run=True, config=str(entity_config)
    )

    with patch.dict("sys.modules", replaced_modules):
        cmd_compress(namespace)

    assert "Loaded entity config" in capsys.readouterr().out
@patch("mempalace.cli.MempalaceConfig")
def test_cmd_compress_stores_results(mock_config_cls, capsys):
    """A non-dry-run compress upserts once into the get_or_create'd collection and prints 'Stored'."""
    mock_config_cls.return_value.palace_path = "/fake/palace"

    # First get() returns one drawer, the second returns an empty page.
    one_drawer = {
        "documents": ["text"],
        "metadatas": [{"wing": "w", "room": "r", "source_file": "f.txt"}],
        "ids": ["id1"],
    }
    no_more = {"documents": [], "metadatas": [], "ids": []}
    drawers_col = MagicMock()
    drawers_col.get.side_effect = [one_drawer, no_more]

    compressed_col = MagicMock()
    fake_client = MagicMock()
    fake_client.get_collection.return_value = drawers_col
    fake_client.get_or_create_collection.return_value = compressed_col
    fake_chromadb = MagicMock()
    fake_chromadb.PersistentClient.return_value = fake_client

    dialect = MagicMock()
    dialect.compress.return_value = "compressed"
    dialect.compression_stats.return_value = {
        "original_chars": 100,
        "compressed_chars": 30,
        "original_tokens": 25,
        "compressed_tokens": 8,
        "ratio": 3.3,
    }
    replaced_modules = {
        "chromadb": fake_chromadb,
        "mempalace.dialect": _make_mock_dialect_module(dialect),
    }
    namespace = argparse.Namespace(palace=None, wing=None, dry_run=False, config=None)

    with patch.dict("sys.modules", replaced_modules):
        cmd_compress(namespace)

    assert "Stored" in capsys.readouterr().out
    compressed_col.upsert.assert_called_once()
+79
View File
@@ -0,0 +1,79 @@
"""Extra tests for mempalace.config to cover remaining gaps."""
import json
import os
from mempalace.config import MempalaceConfig
def test_config_bad_json(tmp_path):
    """An unparseable config.json still yields a truthy ``palace_path``."""
    broken = tmp_path / "config.json"
    broken.write_text("not json", encoding="utf-8")
    assert MempalaceConfig(config_dir=str(tmp_path)).palace_path  # still returns default
def test_people_map_from_file(tmp_path):
    """``people_map`` reflects the contents of people_map.json in the config dir."""
    mapping = {"bob": "Robert"}
    (tmp_path / "people_map.json").write_text(json.dumps(mapping), encoding="utf-8")
    assert MempalaceConfig(config_dir=str(tmp_path)).people_map == {"bob": "Robert"}
def test_people_map_bad_json(tmp_path):
    """An unparseable people_map.json results in an empty ``people_map``."""
    broken = tmp_path / "people_map.json"
    broken.write_text("bad", encoding="utf-8")
    assert MempalaceConfig(config_dir=str(tmp_path)).people_map == {}
def test_people_map_missing(tmp_path):
    """With no people_map.json present, ``people_map`` is an empty dict."""
    config = MempalaceConfig(config_dir=str(tmp_path))
    loaded = config.people_map
    assert loaded == {}
def test_topic_wings_default(tmp_path):
    """Default ``topic_wings`` is a list that includes "emotions"."""
    wings = MempalaceConfig(config_dir=str(tmp_path)).topic_wings
    assert isinstance(wings, list)
    assert "emotions" in wings
def test_hall_keywords_default(tmp_path):
    """Default ``hall_keywords`` is a dict that has a "technical" key."""
    keywords = MempalaceConfig(config_dir=str(tmp_path)).hall_keywords
    assert isinstance(keywords, dict)
    assert "technical" in keywords
def test_init_idempotent(tmp_path):
    """After two ``init()`` calls, config.json exists and contains "palace_path"."""
    config = MempalaceConfig(config_dir=str(tmp_path))
    for _ in range(2):  # second call should not overwrite
        config.init()

    with open(tmp_path / "config.json") as handle:
        stored = json.load(handle)
    assert "palace_path" in stored
def test_save_people_map(tmp_path):
    """``save_people_map`` returns the path of a JSON file holding the saved mapping."""
    config = MempalaceConfig(config_dir=str(tmp_path))
    saved_path = config.save_people_map({"alice": "Alice Smith"})

    assert saved_path.exists()
    with open(saved_path) as handle:
        stored = json.load(handle)
    assert stored["alice"] == "Alice Smith"
def test_env_mempal_palace_path(tmp_path):
    """MEMPAL_PALACE_PATH (legacy) should also work.

    ``MEMPALACE_PALACE_PATH`` is removed for the duration of the test and the
    legacy ``MEMPAL_PALACE_PATH`` is set to ``/legacy/path``. Both variables
    are put back to their prior state afterwards so the change cannot leak
    into other tests.
    """
    touched = ("MEMPALACE_PALACE_PATH", "MEMPAL_PALACE_PATH")
    # Remember what was there before (None means "was not set").
    previous = {name: os.environ.get(name) for name in touched}

    os.environ.pop("MEMPALACE_PALACE_PATH", None)
    os.environ["MEMPAL_PALACE_PATH"] = "/legacy/path"
    try:
        cfg = MempalaceConfig(config_dir=str(tmp_path))
        assert cfg.palace_path == "/legacy/path"
    finally:
        # Restore rather than just delete: a pre-existing value of either
        # variable must survive this test.
        for name, value in previous.items():
            if value is None:
                os.environ.pop(name, None)
            else:
                os.environ[name] = value
def test_collection_name_from_config(tmp_path):
    """``collection_name`` is read from config.json when present."""
    payload = json.dumps({"collection_name": "custom_col"})
    (tmp_path / "config.json").write_text(payload, encoding="utf-8")
    assert MempalaceConfig(config_dir=str(tmp_path)).collection_name == "custom_col"
+1 -1
View File
@@ -23,4 +23,4 @@ def test_convo_mining():
results = col.query(query_texts=["memory persistence"], n_results=1) results = col.query(query_texts=["memory persistence"], n_results=1)
assert len(results["documents"][0]) > 0 assert len(results["documents"][0]) > 0
shutil.rmtree(tmpdir) shutil.rmtree(tmpdir, ignore_errors=True)
+102
View File
@@ -0,0 +1,102 @@
"""Unit tests for convo_miner pure functions (no chromadb needed)."""
from mempalace.convo_miner import (
chunk_exchanges,
detect_convo_room,
scan_convos,
)
class TestChunkExchanges:
    """Checks for ``chunk_exchanges`` on several input shapes."""

    def test_exchange_chunking(self):
        """Three '>'-prefixed exchanges give at least two chunks with the expected keys."""
        transcript = (
            "> What is memory?\n"
            "Memory is persistence of information over time.\n\n"
            "> Why does it matter?\n"
            "It enables continuity across sessions and conversations.\n\n"
            "> How do we build it?\n"
            "With structured storage and retrieval mechanisms.\n"
        )
        produced = chunk_exchanges(transcript)
        assert len(produced) >= 2
        for piece in produced:
            assert "content" in piece and "chunk_index" in piece

    def test_paragraph_fallback(self):
        """Content without '>' lines falls back to paragraph chunking."""
        # NOTE: adjacent string literals are joined before `*` is applied, so
        # each "\n\n" below is repeated together with the sentence after it.
        content = (
            "This is a long paragraph about memory systems. " * 10 + "\n\n"
            "This is another paragraph about storage. " * 10 + "\n\n"
            "And a third paragraph about retrieval. " * 10
        )
        produced = chunk_exchanges(content)
        assert len(produced) >= 2

    def test_paragraph_line_group_fallback(self):
        """Long content with no paragraph breaks chunks by line groups."""
        content = "\n".join(f"Line {n}: some content that is meaningful" for n in range(60))
        produced = chunk_exchanges(content)
        assert len(produced) >= 1

    def test_empty_content(self):
        """An empty string produces an empty list."""
        assert chunk_exchanges("") == []

    def test_short_content_skipped(self):
        """Very short input still returns a list."""
        produced = chunk_exchanges("> hi\nbye")
        # Too short to produce chunks (below MIN_CHUNK_SIZE)
        assert isinstance(produced, list)
class TestDetectConvoRoom:
    """``detect_convo_room`` maps sample text to the expected room name."""

    def test_technical_room(self):
        """Debugging/code vocabulary is classified as "technical"."""
        sample = "Let me debug this python function and fix the code error in the api"
        room = detect_convo_room(sample)
        assert room == "technical"

    def test_planning_room(self):
        """Roadmap/sprint vocabulary is classified as "planning"."""
        sample = "We need to plan the roadmap for the next sprint and set milestone deadlines"
        room = detect_convo_room(sample)
        assert room == "planning"

    def test_architecture_room(self):
        """Service/component vocabulary is classified as "architecture"."""
        sample = "The architecture uses a service layer with component interface and module design"
        room = detect_convo_room(sample)
        assert room == "architecture"

    def test_decisions_room(self):
        """Decided/migrated/chose vocabulary is classified as "decisions"."""
        sample = "We decided to switch and migrated to the new framework after we chose it"
        room = detect_convo_room(sample)
        assert room == "decisions"

    def test_general_fallback(self):
        """Small talk is classified as "general"."""
        sample = "Hello, how are you doing today? The weather is nice."
        room = detect_convo_room(sample)
        assert room == "general"
class TestScanConvos:
    """Tests for ``scan_convos``: which files a directory scan returns."""

    def test_scan_finds_txt_and_md(self, tmp_path):
        """Text and markdown files are returned; a PNG is not."""
        for filename, text in (("chat.txt", "hello"), ("notes.md", "world")):
            (tmp_path / filename).write_text(text, encoding="utf-8")
        (tmp_path / "image.png").write_bytes(b"fake")
        found = scan_convos(str(tmp_path))
        suffixes = set()
        for entry in found:
            suffixes.add(entry.suffix)
        assert ".txt" in suffixes
        assert ".md" in suffixes
        assert ".png" not in suffixes

    def test_scan_skips_git_dir(self, tmp_path):
        """Only the file outside ``.git`` is returned."""
        repo_meta = tmp_path / ".git"
        repo_meta.mkdir()
        (repo_meta / "config.txt").write_text("git stuff", encoding="utf-8")
        (tmp_path / "chat.txt").write_text("hello", encoding="utf-8")
        found = scan_convos(str(tmp_path))
        assert len(found) == 1

    def test_scan_skips_meta_json(self, tmp_path):
        """``*.meta.json`` is excluded while a plain ``.json`` file is kept."""
        for filename in ("chat.meta.json", "chat.json"):
            (tmp_path / filename).write_text("{}", encoding="utf-8")
        found = scan_convos(str(tmp_path))
        basenames = []
        for entry in found:
            basenames.append(entry.name)
        assert "chat.json" in basenames
        assert "chat.meta.json" not in basenames

    def test_scan_empty_dir(self, tmp_path):
        """An empty directory yields an empty list."""
        assert scan_convos(str(tmp_path)) == []
+380
View File
@@ -0,0 +1,380 @@
"""Tests for mempalace.entity_detector."""
import os
from unittest.mock import patch
from mempalace.entity_detector import (
PROSE_EXTENSIONS,
STOPWORDS,
_print_entity_list,
classify_entity,
confirm_entities,
detect_entities,
extract_candidates,
scan_for_detection,
score_entity,
)
# ── extract_candidates ──────────────────────────────────────────────────
def test_extract_candidates_finds_frequent_names():
    """A capitalized name used four times is reported with a count of at least 3."""
    sample = "Riley said hello. Riley laughed. Riley smiled. Riley waved."
    candidates = extract_candidates(sample)
    assert "Riley" in candidates
    assert candidates["Riley"] >= 3
def test_extract_candidates_ignores_stopwords():
    """A stopword is never reported, however often it repeats."""
    # "The" is capitalized and frequent, yet it is a stopword.
    repeated = "The The The The The The"
    candidates = extract_candidates(repeated)
    assert "The" not in candidates
def test_extract_candidates_requires_min_frequency():
    """Names that occur only once are not reported."""
    candidates = extract_candidates("Riley said hi. Devon waved.")
    # Each name occurs a single time, below the threshold of 3.
    for name in ("Riley", "Devon"):
        assert name not in candidates
def test_extract_candidates_finds_multi_word_names():
    """A two-word capitalized name repeated four times is reported as one candidate."""
    # Multi-word names need 3+ occurrences and must contain no stopwords.
    sentences = "Claude Code is great. Claude Code rocks. Claude Code works. Claude Code rules."
    found = extract_candidates(sentences)
    assert "Claude Code" in found
def test_extract_candidates_empty_text():
    """Empty input produces an empty mapping."""
    assert extract_candidates("") == {}
# ── score_entity ────────────────────────────────────────────────────────
def test_score_entity_person_verbs():
    """Speech verbs following the name are expected to raise the person score."""
    passage = "Riley said hello. Riley asked why. Riley told me."
    scores = score_entity("Riley", passage, passage.splitlines())
    assert scores["person_score"] > 0
    assert len(scores["person_signals"]) > 0
def test_score_entity_project_verbs():
    """Build/deploy/install verbs around the name are expected to raise the project score."""
    passage = "We are building ChromaDB. We deployed ChromaDB. Install ChromaDB."
    scores = score_entity("ChromaDB", passage, passage.splitlines())
    assert scores["project_score"] > 0
    assert len(scores["project_signals"]) > 0
def test_score_entity_dialogue_markers():
    """Lines of the form 'Name: ...' are expected to count toward the person score."""
    transcript = "Riley: Hey, how are you?\nRiley: I'm fine."
    scores = score_entity("Riley", transcript, transcript.splitlines())
    assert scores["person_score"] > 0
def test_score_entity_code_ref():
    """File-name references such as 'Name.py' are expected to count toward the project score."""
    passage = "Check out ChromaDB.py for details. Also ChromaDB.js is good."
    scores = score_entity("ChromaDB", passage, passage.splitlines())
    assert scores["project_score"] > 0
def test_score_entity_no_signals():
    """A name absent from the text scores zero on both axes."""
    passage = "Nothing interesting here at all."
    scores = score_entity("Riley", passage, passage.splitlines())
    for axis in ("person_score", "project_score"):
        assert scores[axis] == 0
# ── classify_entity ─────────────────────────────────────────────────────
def test_classify_entity_no_signals_gives_uncertain():
    """With zero scores and no signals the entity is classified as uncertain."""
    blank_scores = dict(
        person_score=0,
        project_score=0,
        person_signals=[],
        project_signals=[],
    )
    verdict = classify_entity("Foo", 10, blank_scores)
    assert verdict["type"] == "uncertain"
    assert verdict["name"] == "Foo"
def test_classify_entity_strong_project():
    """A project-only score with project signals classifies as a project."""
    project_scores = dict(
        person_score=0,
        project_score=10,
        person_signals=[],
        project_signals=["project verb (5x)", "code file reference (2x)"],
    )
    verdict = classify_entity("ChromaDB", 5, project_scores)
    assert verdict["type"] == "project"
def test_classify_entity_strong_person_needs_two_signal_types():
    """A person-only score backed by two distinct signal kinds classifies as a person."""
    signals = [
        "dialogue marker (3x)",
        "'Riley ...' action (4x)",
    ]
    person_scores = dict(
        person_score=10,
        project_score=0,
        person_signals=signals,
        project_signals=[],
    )
    verdict = classify_entity("Riley", 8, person_scores)
    assert verdict["type"] == "person"
def test_classify_entity_pronoun_only_is_uncertain():
    """A person score supported only by a pronoun signal stays uncertain."""
    pronoun_scores = dict(
        person_score=8,
        project_score=0,
        person_signals=["pronoun nearby (4x)"],
        project_signals=[],
    )
    verdict = classify_entity("Riley", 5, pronoun_scores)
    assert verdict["type"] == "uncertain"
def test_classify_entity_mixed_signals():
    """Equal person and project scores yield 'uncertain' with a trailing mixed-signals note."""
    tied_scores = dict(
        person_score=5,
        project_score=5,
        person_signals=["pronoun nearby (2x)"],
        project_signals=["project verb (2x)"],
    )
    verdict = classify_entity("Lantern", 5, tied_scores)
    assert verdict["type"] == "uncertain"
    last_signal = verdict["signals"][-1]
    assert "mixed signals" in last_signal
# ── detect_entities (integration) ───────────────────────────────────────
def test_detect_entities_with_person_file(tmp_path):
    """A file full of person-style mentions of 'Riley' surfaces that name in some category."""
    lines = [
        "Riley said hello today.",
        "Riley asked about the project.",
        "Riley told me she was happy.",
        "Riley: I think we should go.",
        "Hey Riley, thanks for the help.",
        "Riley laughed and smiled.",
        "Riley decided to join.",
        "Riley pushed the change.",
    ]
    notes = tmp_path / "notes.txt"
    notes.write_text("\n".join(lines))
    detected = detect_entities([notes])
    # Flatten every category into one list of names.
    names = []
    for bucket in detected.values():
        for entity in bucket:
            names.append(entity["name"])
    assert "Riley" in names
def test_detect_entities_with_project_file(tmp_path):
    """A file full of project-style mentions of 'Lantern' surfaces that name in some category."""
    # "ChromaDB" mixes upper and lower case, but extract_candidates looks
    # for /[A-Z][a-z]{1,19}/, so the fixture uses "Lantern", which matches
    # the capitalized-word pattern.
    lines = [
        "The Lantern project is great.",
        "Building Lantern was fun.",
        "We deployed Lantern today.",
        "Install Lantern with pip install Lantern.",
        "Check Lantern.py for the source.",
        "Lantern v2 is faster.",
    ]
    readme = tmp_path / "readme.txt"
    readme.write_text("\n".join(lines))
    detected = detect_entities([readme])
    # Flatten every category into one list of names.
    names = []
    for bucket in detected.values():
        for entity in bucket:
            names.append(entity["name"])
    assert "Lantern" in names
def test_detect_entities_empty_files(tmp_path):
    """An empty file yields three empty categories."""
    blank = tmp_path / "empty.txt"
    blank.write_text("")
    expected = {"people": [], "projects": [], "uncertain": []}
    assert detect_entities([blank]) == expected
def test_detect_entities_handles_missing_file(tmp_path):
    """A path that does not exist is tolerated and yields three empty categories."""
    absent = tmp_path / "nonexistent.txt"
    expected = {"people": [], "projects": [], "uncertain": []}
    assert detect_entities([absent]) == expected
def test_detect_entities_respects_max_files(tmp_path):
    """Passing max_files smaller than the file count still returns a dict."""
    files = [tmp_path / f"file{i}.txt" for i in range(5)]
    for path in files:
        path.write_text("Riley said hello. " * 10)
    # max_files=2 should only read 2 files
    detected = detect_entities(files, max_files=2)
    # Only checks that the call completes and returns the expected container type.
    assert isinstance(detected, dict)
# ── scan_for_detection ──────────────────────────────────────────────────
def test_scan_for_detection_finds_prose(tmp_path):
    """At least one of the prose files (.md / .txt) is returned by the scan."""
    fixtures = (("notes.md", "hello"), ("data.txt", "world"), ("code.py", "import os"))
    for filename, text in fixtures:
        (tmp_path / filename).write_text(text)
    found = scan_for_detection(str(tmp_path))
    suffixes = set()
    for entry in found:
        _, ext = os.path.splitext(str(entry))
        suffixes.add(ext)
    # Prose files should be found
    assert ".md" in suffixes or ".txt" in suffixes
def test_scan_for_detection_skips_git_dir(tmp_path):
    """No returned path contains '.git'."""
    repo_meta = tmp_path / ".git"
    repo_meta.mkdir()
    (repo_meta / "config.txt").write_text("git config")
    (tmp_path / "readme.md").write_text("hello")
    found = scan_for_detection(str(tmp_path))
    as_text = list(map(str, found))
    assert all(".git" not in path for path in as_text)
# ── module-level constants ──────────────────────────────────────────────
def test_stopwords_contains_common_words():
    """STOPWORDS includes an English article and two Python keywords."""
    for expected in ("the", "import", "class"):
        assert expected in STOPWORDS
def test_prose_extensions():
    """PROSE_EXTENSIONS includes the text and markdown suffixes."""
    for suffix in (".txt", ".md"):
        assert suffix in PROSE_EXTENSIONS
# ── _print_entity_list ─────────────────────────────────────────────────
def test_print_entity_list_with_entities(capsys):
    """The printed list contains the heading and every entity name."""
    alice = {"name": "Alice", "confidence": 0.9, "signals": ["dialogue marker (3x)"]}
    bob = {"name": "Bob", "confidence": 0.5, "signals": []}
    _print_entity_list([alice, bob], "PEOPLE")
    printed = capsys.readouterr().out
    for fragment in ("PEOPLE", "Alice", "Bob"):
        assert fragment in printed
def test_print_entity_list_empty(capsys):
    """An empty entity list prints a 'none detected' notice."""
    _print_entity_list([], "PEOPLE")
    captured = capsys.readouterr()
    assert "none detected" in captured.out
# ── confirm_entities ───────────────────────────────────────────────────
def test_confirm_entities_yes_mode():
    """With yes=True, detected people and projects are returned as plain name lists."""

    def entity(name, confidence):
        return {"name": name, "confidence": confidence, "signals": ["test"]}

    detected = {
        "people": [entity("Alice", 0.9)],
        "projects": [entity("Acme", 0.8)],
        "uncertain": [entity("Foo", 0.4)],
    }
    confirmed = confirm_entities(detected, yes=True)
    assert confirmed["people"] == ["Alice"]
    assert confirmed["projects"] == ["Acme"]
def test_confirm_entities_accept_all():
    """Answering with an empty line, then 'n', keeps the detected person."""
    detected = {
        "people": [{"name": "Alice", "confidence": 0.9, "signals": ["test"]}],
        "projects": [],
        "uncertain": [],
    }
    answers = ["", "n"]
    with patch("builtins.input", side_effect=answers):
        confirmed = confirm_entities(detected, yes=False)
    assert "Alice" in confirmed["people"]
def test_confirm_entities_edit_reclassify_uncertain():
    """In edit mode an uncertain entity can be promoted to a person or skipped."""
    uncertain = [
        {"name": "Foo", "confidence": 0.4, "signals": ["test"]},
        {"name": "Bar", "confidence": 0.4, "signals": ["test"]},
    ]
    detected = {"people": [], "projects": [], "uncertain": uncertain}
    # Scripted answers, in the order the prompts are expected to appear.
    answers = [
        "edit",  # choice
        "p",  # Foo -> person
        "s",  # Bar -> skip
        "",  # no removals from people
        "",  # no removals from projects
        "n",  # don't add missing
    ]
    with patch("builtins.input", side_effect=answers):
        confirmed = confirm_entities(detected, yes=False)
    assert "Foo" in confirmed["people"]
    assert "Bar" not in confirmed["people"]
    assert "Bar" not in confirmed["projects"]
def test_confirm_entities_add_mode():
    """In add mode, manually entered names land in the chosen category."""
    detected = dict(people=[], projects=[], uncertain=[])
    # Scripted answers, in the order the prompts are expected to appear.
    answers = [
        "add",  # choice = add
        "NewPerson",  # name
        "p",  # person
        "NewProj",  # name
        "r",  # project
        "",  # stop adding
    ]
    with patch("builtins.input", side_effect=answers):
        confirmed = confirm_entities(detected, yes=False)
    assert "NewPerson" in confirmed["people"]
    assert "NewProj" in confirmed["projects"]
# ── scan_for_detection fallback ────────────────────────────────────────
def test_scan_for_detection_fallback_to_all_readable(tmp_path):
    """When fewer than 3 prose files, falls back to include all readable files."""
    prose = (("one.md", "hello"), ("two.txt", "world"))
    # Only 2 prose files, so the code files below should be included as well.
    code = (("code.py", "import os"), ("app.js", "console.log()"))
    for filename, text in prose + code:
        (tmp_path / filename).write_text(text)
    found = scan_for_detection(str(tmp_path))
    suffixes = set()
    for entry in found:
        suffixes.add(os.path.splitext(str(entry))[1])
    assert ".py" in suffixes or ".js" in suffixes
def test_scan_for_detection_max_files(tmp_path):
    """Caps to max_files."""
    limit = 5
    for i in range(20):
        note = tmp_path / f"note{i}.md"
        note.write_text(f"content {i}")
    found = scan_for_detection(str(tmp_path), max_files=limit)
    assert len(found) <= limit
+313
View File
@@ -0,0 +1,313 @@
"""Tests for mempalace.entity_registry."""
from unittest.mock import patch
from mempalace.entity_registry import (
COMMON_ENGLISH_WORDS,
PERSON_CONTEXT_PATTERNS,
EntityRegistry,
)
# ── COMMON_ENGLISH_WORDS ────────────────────────────────────────────────
def test_common_english_words_has_expected_entries():
    """COMMON_ENGLISH_WORDS contains a sample of words that double as names."""
    for expected in ("ever", "grace", "will", "may", "monday"):
        assert expected in COMMON_ENGLISH_WORDS
def test_common_english_words_is_lowercase():
    """Every entry of COMMON_ENGLISH_WORDS is already lowercase."""
    for word in COMMON_ENGLISH_WORDS:
        lowered = word.lower()
        assert word == lowered, f"{word} should be lowercase"
# ── PERSON_CONTEXT_PATTERNS ─────────────────────────────────────────────
def test_person_context_patterns_is_nonempty():
    """PERSON_CONTEXT_PATTERNS holds at least one pattern."""
    pattern_count = len(PERSON_CONTEXT_PATTERNS)
    assert pattern_count > 0
# ── EntityRegistry creation and empty state ─────────────────────────────
def test_load_from_nonexistent_dir(tmp_path):
    """Loading from a directory with no saved registry yields empty defaults."""
    reg = EntityRegistry.load(config_dir=tmp_path)
    assert reg.people == {}
    assert reg.projects == []
    assert reg.ambiguous_flags == []
    assert reg.mode == "personal"
def test_save_and_load_roundtrip(tmp_path):
    """Seeded data is visible to a second registry loaded from the same directory."""
    colleague = {"name": "Alice", "relationship": "colleague", "context": "work"}
    original = EntityRegistry.load(config_dir=tmp_path)
    original.seed(mode="work", people=[colleague], projects=["MemPalace"])
    # A fresh load from the same directory must see what was seeded.
    reloaded = EntityRegistry.load(config_dir=tmp_path)
    assert reloaded.mode == "work"
    assert "Alice" in reloaded.people
    assert "MemPalace" in reloaded.projects
def test_save_creates_file(tmp_path):
    """save() writes entity_registry.json into the config directory."""
    EntityRegistry.load(config_dir=tmp_path).save()
    registry_file = tmp_path / "entity_registry.json"
    assert registry_file.exists()
# ── seed ────────────────────────────────────────────────────────────────
def test_seed_registers_people(tmp_path):
    """Seeded people are stored with their relationship, an onboarding source, and full confidence."""
    family = [
        {"name": "Riley", "relationship": "daughter", "context": "personal"},
        {"name": "Devon", "relationship": "friend", "context": "personal"},
    ]
    reg = EntityRegistry.load(config_dir=tmp_path)
    reg.seed(mode="personal", people=family, projects=["MemPalace"])
    assert "Riley" in reg.people
    assert "Devon" in reg.people
    riley = reg.people["Riley"]
    assert riley["relationship"] == "daughter"
    assert riley["source"] == "onboarding"
    assert riley["confidence"] == 1.0
def test_seed_registers_projects(tmp_path):
    """Seeded projects are stored in the order given."""
    reg = EntityRegistry.load(config_dir=tmp_path)
    reg.seed(mode="work", people=[], projects=["Acme", "Widget"])
    stored = reg.projects
    assert stored == ["Acme", "Widget"]
def test_seed_sets_mode(tmp_path):
    """seed() records the mode it was given."""
    reg = EntityRegistry.load(config_dir=tmp_path)
    reg.seed(mode="combo", people=[], projects=[])
    current_mode = reg.mode
    assert current_mode == "combo"
def test_seed_flags_ambiguous_names(tmp_path):
    """A name that is also a common English word is flagged; an ordinary name is not."""
    friends = [
        {"name": "Grace", "relationship": "friend", "context": "personal"},
        {"name": "Riley", "relationship": "daughter", "context": "personal"},
    ]
    reg = EntityRegistry.load(config_dir=tmp_path)
    reg.seed(mode="personal", people=friends, projects=[])
    flagged = reg.ambiguous_flags
    assert "grace" in flagged
    # Riley is not a common English word
    assert "riley" not in flagged
def test_seed_with_aliases(tmp_path):
    """An alias gets its own people entry pointing at the canonical name."""
    maxwell = {"name": "Maxwell", "relationship": "friend", "context": "personal"}
    reg = EntityRegistry.load(config_dir=tmp_path)
    reg.seed(mode="personal", people=[maxwell], projects=[], aliases={"Max": "Maxwell"})
    for name in ("Maxwell", "Max"):
        assert name in reg.people
    alias_entry = reg.people["Max"]
    assert alias_entry.get("canonical") == "Maxwell"
def test_seed_skips_empty_names(tmp_path):
    """A person entry with an empty name is not registered."""
    nameless = {"name": "", "relationship": "", "context": "personal"}
    reg = EntityRegistry.load(config_dir=tmp_path)
    reg.seed(mode="personal", people=[nameless], projects=[])
    assert len(reg.people) == 0
# ── lookup ──────────────────────────────────────────────────────────────
def test_lookup_known_person(tmp_path):
    """Looking up a seeded person returns type 'person' with full confidence."""
    riley = {"name": "Riley", "relationship": "daughter", "context": "personal"}
    reg = EntityRegistry.load(config_dir=tmp_path)
    reg.seed(mode="personal", people=[riley], projects=[])
    hit = reg.lookup("Riley")
    assert hit["type"] == "person"
    assert hit["confidence"] == 1.0
    assert hit["name"] == "Riley"
def test_lookup_known_project(tmp_path):
    """Looking up a seeded project returns type 'project' with full confidence."""
    reg = EntityRegistry.load(config_dir=tmp_path)
    reg.seed(mode="work", people=[], projects=["MemPalace"])
    hit = reg.lookup("MemPalace")
    assert hit["type"] == "project"
    assert hit["confidence"] == 1.0
def test_lookup_unknown_word(tmp_path):
    """A word the registry has never seen returns type 'unknown' with zero confidence."""
    reg = EntityRegistry.load(config_dir=tmp_path)
    reg.seed(mode="personal", people=[], projects=[])
    miss = reg.lookup("Xyzzy")
    assert miss["type"] == "unknown"
    assert miss["confidence"] == 0.0
def test_lookup_case_insensitive(tmp_path):
    """A lowercase query still matches a person seeded with a capitalized name."""
    riley = {"name": "Riley", "relationship": "daughter", "context": "personal"}
    reg = EntityRegistry.load(config_dir=tmp_path)
    reg.seed(mode="personal", people=[riley], projects=[])
    hit = reg.lookup("riley")
    assert hit["type"] == "person"
def test_lookup_alias(tmp_path):
    """Looking up a seeded alias resolves to a person."""
    maxwell = {"name": "Maxwell", "relationship": "friend", "context": "personal"}
    reg = EntityRegistry.load(config_dir=tmp_path)
    reg.seed(mode="personal", people=[maxwell], projects=[], aliases={"Max": "Maxwell"})
    hit = reg.lookup("Max")
    assert hit["type"] == "person"
# ── disambiguation ──────────────────────────────────────────────────────
def test_lookup_ambiguous_word_as_person(tmp_path):
    """An ambiguous name used in a person-like sentence resolves to a person."""
    grace = {"name": "Grace", "relationship": "friend", "context": "personal"}
    reg = EntityRegistry.load(config_dir=tmp_path)
    reg.seed(mode="personal", people=[grace], projects=[])
    sentence = "I went with Grace today"
    hit = reg.lookup("Grace", context=sentence)
    assert hit["type"] == "person"
def test_lookup_ambiguous_word_as_concept(tmp_path):
    """An ambiguous name used as an ordinary word resolves to a concept."""
    ever = {"name": "Ever", "relationship": "friend", "context": "personal"}
    reg = EntityRegistry.load(config_dir=tmp_path)
    reg.seed(mode="personal", people=[ever], projects=[])
    sentence = "have you ever tried this"
    hit = reg.lookup("Ever", context=sentence)
    assert hit["type"] == "concept"
# ── research (Wikipedia) — mocked ──────────────────────────────────────
def test_research_caches_result(tmp_path):
    """A researched word is cached, so a repeat call never reaches the Wikipedia lookup."""
    reg = EntityRegistry.load(config_dir=tmp_path)
    reg.seed(mode="personal", people=[], projects=[])
    wiki_answer = dict(
        inferred_type="person",
        confidence=0.80,
        wiki_summary="Saoirse is an Irish given name.",
        wiki_title="Saoirse",
    )
    lookup_target = "mempalace.entity_registry._wikipedia_lookup"
    with patch(lookup_target, return_value=wiki_answer):
        first = reg.research("Saoirse", auto_confirm=True)
    assert first["inferred_type"] == "person"
    # The second call must be served from the cache; reaching the lookup fails the test.
    with patch(lookup_target, side_effect=AssertionError("should not be called")):
        second = reg.research("Saoirse")
    assert second["inferred_type"] == "person"
def test_confirm_research_adds_to_people(tmp_path):
    """Confirming a researched word as a person registers it with a 'wiki' source."""
    reg = EntityRegistry.load(config_dir=tmp_path)
    reg.seed(mode="personal", people=[], projects=[])
    wiki_answer = dict(
        inferred_type="person",
        confidence=0.80,
        wiki_summary="Saoirse is a name",
        wiki_title="Saoirse",
    )
    with patch("mempalace.entity_registry._wikipedia_lookup", return_value=wiki_answer):
        reg.research("Saoirse", auto_confirm=False)
    reg.confirm_research("Saoirse", entity_type="person", relationship="friend")
    assert "Saoirse" in reg.people
    entry = reg.people["Saoirse"]
    assert entry["source"] == "wiki"
# ── extract_people_from_query ───────────────────────────────────────────
def test_extract_people_from_query(tmp_path):
    """Only the registered people actually named in the query are returned."""
    family = [
        {"name": "Riley", "relationship": "daughter", "context": "personal"},
        {"name": "Devon", "relationship": "friend", "context": "personal"},
    ]
    reg = EntityRegistry.load(config_dir=tmp_path)
    reg.seed(mode="personal", people=family, projects=[])
    question = "What did Riley say about the weather?"
    mentioned = reg.extract_people_from_query(question)
    assert "Riley" in mentioned
    assert "Devon" not in mentioned
# ── extract_unknown_candidates ──────────────────────────────────────────
def test_extract_unknown_candidates(tmp_path):
    """A capitalized name the registry does not know is reported as a candidate."""
    reg = EntityRegistry.load(config_dir=tmp_path)
    reg.seed(mode="personal", people=[], projects=[])
    candidates = reg.extract_unknown_candidates("Saoirse went to the store")
    assert "Saoirse" in candidates
def test_extract_unknown_candidates_skips_known(tmp_path):
    """A name already in the registry is not reported as unknown."""
    riley = {"name": "Riley", "relationship": "daughter", "context": "personal"}
    reg = EntityRegistry.load(config_dir=tmp_path)
    reg.seed(mode="personal", people=[riley], projects=[])
    candidates = reg.extract_unknown_candidates("Riley went to the store")
    assert "Riley" not in candidates
# ── summary ─────────────────────────────────────────────────────────────
def test_summary(tmp_path):
    """The summary text mentions the mode, the seeded person, and the seeded project."""
    riley = {"name": "Riley", "relationship": "daughter", "context": "personal"}
    reg = EntityRegistry.load(config_dir=tmp_path)
    reg.seed(mode="personal", people=[riley], projects=["MemPalace"])
    report = reg.summary()
    for fragment in ("personal", "Riley", "MemPalace"):
        assert fragment in report
+248
View File
@@ -0,0 +1,248 @@
"""Tests for mempalace.general_extractor."""
from mempalace.general_extractor import (
ALL_MARKERS,
NEGATIVE_WORDS,
POSITIVE_WORDS,
_extract_prose,
_get_sentiment,
_has_resolution,
_is_code_line,
_score_markers,
_split_into_segments,
extract_memories,
)
# ── extract_memories — empty / no markers ───────────────────────────────
def test_extract_memories_empty_text():
    """Empty input yields no memories."""
    assert extract_memories("") == []
def test_extract_memories_no_markers():
    """A sentence without any marker phrases yields no memories."""
    pangram = "The quick brown fox jumped over the lazy dog."
    assert extract_memories(pangram) == []
def test_extract_memories_short_text_skipped():
    """Very short input yields no memories."""
    # Paragraphs shorter than 20 chars are skipped
    memories = extract_memories("ok sure")
    assert memories == []
# ── extract_memories — decision markers ─────────────────────────────────
def test_extract_memories_decision():
    """Text describing a choice and its trade-off yields a 'decision' memory."""
    sentences = [
        "We decided to go with PostgreSQL instead of MySQL ",
        "because the performance was better for our use case. ",
        "The trade-off was more complexity in setup.",
    ]
    memories = extract_memories("".join(sentences))
    assert len(memories) >= 1
    assert "decision" in (m["memory_type"] for m in memories)
# ── extract_memories — preference markers ───────────────────────────────
def test_extract_memories_preference():
    """Text stating likes and always/never rules yields a 'preference' memory."""
    sentences = [
        "I prefer using snake_case in Python code. ",
        "Please always use type hints. ",
        "Never use wildcard imports.",
    ]
    memories = extract_memories("".join(sentences))
    assert len(memories) >= 1
    assert "preference" in (m["memory_type"] for m in memories)
# ── extract_memories — milestone markers ────────────────────────────────
def test_extract_memories_milestone():
    """Text celebrating a breakthrough yields a 'milestone' memory."""
    sentences = [
        "It finally works! After three days of debugging, ",
        "I figured out the issue. The breakthrough was realizing ",
        "the config file was cached. Got it working at 2am.",
    ]
    memories = extract_memories("".join(sentences))
    assert len(memories) >= 1
    assert "milestone" in (m["memory_type"] for m in memories)
# ── extract_memories — problem markers ──────────────────────────────────
def test_extract_memories_problem():
    """Text describing a bug yields a 'problem' memory, or a 'milestone' once resolved."""
    sentences = [
        "There's a critical bug in the auth module. ",
        "The error keeps crashing the server. ",
        "The root cause was a missing null check. ",
        "The problem is that tokens expire silently.",
    ]
    memories = extract_memories("".join(sentences))
    assert len(memories) >= 1
    kinds = set()
    for memory in memories:
        kinds.add(memory["memory_type"])
    # resolved problems become milestones
    assert kinds & {"problem", "milestone"}
# ── extract_memories — emotional markers ────────────────────────────────
def test_extract_memories_emotional():
    """Text expressing pride, love, and gratitude yields an 'emotional' memory."""
    sentences = [
        "I feel so proud of what we built together. ",
        "I love working on this project, it makes me happy. ",
        "I'm grateful for the team and the beautiful code we wrote.",
    ]
    memories = extract_memories("".join(sentences))
    assert len(memories) >= 1
    assert "emotional" in (m["memory_type"] for m in memories)
# ── extract_memories — chunk_index ──────────────────────────────────────
def test_extract_memories_chunk_index_increments():
    """chunk_index values run 0, 1, 2, ... in the order memories are returned."""
    text = (
        "We decided to use React because it fits our team.\n\n"
        "I prefer functional components always.\n\n"
        "It works! We finally shipped the v1.0 release."
    )
    result = extract_memories(text)
    indices = [m["chunk_index"] for m in result]
    # Checked unconditionally: the comparison also holds for zero or one
    # memory, so the test can no longer pass without asserting anything.
    assert indices == list(range(len(result)))
# ── _score_markers ──────────────────────────────────────────────────────
def test_score_markers_with_matches():
    """Decision phrasing scores above zero and reports matched keywords."""
    sentence = "we decided to go with postgres because it is faster"
    score, keywords = _score_markers(sentence, ALL_MARKERS["decision"])
    assert score > 0
    assert len(keywords) > 0
def test_score_markers_no_matches():
    """Text without decision markers scores exactly 0.0."""
    outcome = _score_markers("nothing relevant here", ALL_MARKERS["decision"])
    score, keywords = outcome
    assert score == 0.0
# ── _get_sentiment ──────────────────────────────────────────────────────
def test_get_sentiment_positive():
    """Upbeat wording is labelled 'positive'."""
    mood = _get_sentiment("I am so happy and proud of this breakthrough")
    assert mood == "positive"
def test_get_sentiment_negative():
    """Failure wording is labelled 'negative'."""
    mood = _get_sentiment("This bug caused a crash and total failure")
    assert mood == "negative"
def test_get_sentiment_neutral():
    """Wording without sentiment words is labelled 'neutral'."""
    mood = _get_sentiment("The meeting is at three")
    assert mood == "neutral"
# ── _has_resolution ─────────────────────────────────────────────────────
def test_has_resolution_true():
    """A sentence reporting a fix is recognized as a resolution."""
    resolved = _has_resolution("I fixed the auth bug and it works now")
    assert resolved is True
def test_has_resolution_false():
    """A sentence reporting an ongoing failure is not a resolution."""
    resolved = _has_resolution("The server keeps crashing")
    assert resolved is False
# ── _is_code_line ───────────────────────────────────────────────────────
def test_is_code_line_detects_code():
    """An import, a shell prompt line, and a code fence are all treated as code."""
    code_samples = (
        " import os",
        " $ pip install flask",
        " ```python",
    )
    for line in code_samples:
        assert _is_code_line(line) is True
def test_is_code_line_allows_prose():
    """A regular sentence and an empty line are not treated as code."""
    for line in ("This is a regular sentence about coding.", ""):
        assert _is_code_line(line) is False
# ── _extract_prose ──────────────────────────────────────────────────────
def test_extract_prose_strips_code_blocks():
    """Fenced code is removed while the prose on either side survives."""
    mixed = "Hello world\n```\nimport os\nprint('hi')\n```\nGoodbye"
    prose = _extract_prose(mixed)
    assert "import os" not in prose
    for kept in ("Hello world", "Goodbye"):
        assert kept in prose
def test_extract_prose_returns_original_if_all_code():
    """Input consisting only of code still yields a non-empty result."""
    only_code = "import os\nfrom sys import argv"
    prose = _extract_prose(only_code)
    # Falls back to original text if nothing left
    assert len(prose) > 0
# ── _split_into_segments ───────────────────────────────────────────────
def test_split_into_segments_by_paragraph():
    """Three blank-line-separated paragraphs produce three segments."""
    document = "First paragraph.\n\nSecond paragraph.\n\nThird paragraph."
    segments = _split_into_segments(document)
    assert len(segments) == 3
def test_split_into_segments_by_turns():
    """A five-exchange Human/Assistant transcript splits into at least three segments."""
    turns = []
    for i in range(5):
        turns.extend((f"Human: Question {i}", f"Assistant: Answer {i}"))
    segments = _split_into_segments("\n".join(turns))
    # turn-based splitting should fire
    assert len(segments) >= 3
def test_split_into_segments_single_block():
    """Thirty lines with no blank-line breaks still produce at least one segment."""
    # Many lines without double-newline produces chunked segments
    document = "\n".join(f"Line {i} of the document" for i in range(30))
    segments = _split_into_segments(document)
    assert len(segments) >= 1
# ── ALL_MARKERS constant ───────────────────────────────────────────────
def test_all_markers_has_five_types():
    """ALL_MARKERS is keyed by exactly the five memory types."""
    expected_types = {"decision", "preference", "milestone", "problem", "emotional"}
    actual_types = set(ALL_MARKERS.keys())
    assert actual_types == expected_types
# ── POSITIVE_WORDS / NEGATIVE_WORDS ────────────────────────────────────
def test_positive_words():
    """POSITIVE_WORDS includes 'happy' and 'proud'."""
    for word in ("happy", "proud"):
        assert word in POSITIVE_WORDS
def test_negative_words():
    """NEGATIVE_WORDS includes 'bug' and 'crash'."""
    for word in ("bug", "crash"):
        assert word in NEGATIVE_WORDS
+420
View File
@@ -0,0 +1,420 @@
import contextlib
import io
import json
from pathlib import Path
from unittest.mock import patch
import pytest
from mempalace.hooks_cli import (
SAVE_INTERVAL,
STOP_BLOCK_REASON,
PRECOMPACT_BLOCK_REASON,
_count_human_messages,
_log,
_maybe_auto_ingest,
_parse_harness_input,
_sanitize_session_id,
hook_stop,
hook_session_start,
hook_precompact,
run_hook,
)
# --- _sanitize_session_id ---
def test_sanitize_normal_id():
    """An id made of letters, digits, '-' and '_' passes through unchanged."""
    session_id = "abc-123_XYZ"
    assert _sanitize_session_id(session_id) == session_id
def test_sanitize_strips_dangerous_chars():
    """Dots and slashes are removed from a path-traversal style id."""
    cleaned = _sanitize_session_id("../../etc/passwd")
    assert cleaned == "etcpasswd"
def test_sanitize_empty_returns_unknown():
    """An id with nothing left after sanitizing becomes 'unknown'."""
    for raw in ("", "!!!"):
        assert _sanitize_session_id(raw) == "unknown"
# --- _count_human_messages ---
def _write_transcript(path: Path, entries: list[dict]):
    """Write ``entries`` to ``path`` as JSON Lines: one JSON object per line, UTF-8."""
    with open(path, "w", encoding="utf-8") as handle:
        handle.writelines(json.dumps(record) + "\n" for record in entries)
def test_count_human_messages_basic(tmp_path):
    """Only user-role messages are counted; the assistant reply is ignored."""
    entries = [
        {"message": {"role": "user", "content": "hello"}},
        {"message": {"role": "assistant", "content": "hi"}},
        {"message": {"role": "user", "content": "bye"}},
    ]
    log_file = tmp_path / "t.jsonl"
    _write_transcript(log_file, entries)
    assert _count_human_messages(str(log_file)) == 2
def test_count_skips_command_messages(tmp_path):
    """A user message wrapped in <command-message> tags is not counted."""
    entries = [
        {"message": {"role": "user", "content": "<command-message>status</command-message>"}},
        {"message": {"role": "user", "content": "real question"}},
    ]
    log_file = tmp_path / "t.jsonl"
    _write_transcript(log_file, entries)
    assert _count_human_messages(str(log_file)) == 1
def test_count_handles_list_content(tmp_path):
    """List-form content is inspected too: plain text counts, command text does not."""
    plain = {"message": {"role": "user", "content": [{"type": "text", "text": "hello"}]}}
    command = {
        "message": {
            "role": "user",
            "content": [{"type": "text", "text": "<command-message>x</command-message>"}],
        }
    }
    log_file = tmp_path / "t.jsonl"
    _write_transcript(log_file, [plain, command])
    assert _count_human_messages(str(log_file)) == 1
def test_count_missing_file():
    """A transcript path that does not exist counts as zero messages."""
    message_count = _count_human_messages("/nonexistent/path.jsonl")
    assert message_count == 0
def test_count_empty_file(tmp_path):
    """An empty transcript counts as zero messages."""
    log_file = tmp_path / "t.jsonl"
    log_file.write_text("")
    message_count = _count_human_messages(str(log_file))
    assert message_count == 0
def test_count_malformed_json_lines(tmp_path):
    """A non-JSON line is skipped and the valid user message after it is still counted."""
    log_file = tmp_path / "t.jsonl"
    raw = 'not json\n{"message": {"role": "user", "content": "ok"}}\n'
    log_file.write_text(raw)
    assert _count_human_messages(str(log_file)) == 1
# --- hook_stop ---
def _capture_hook_output(hook_fn, data, harness="claude-code", state_dir=None):
    """Run a hook and return the JSON payload it handed to ``_output``.

    ``mempalace.hooks_cli._output`` is replaced by a stub that serializes its
    argument into an in-memory buffer. When ``state_dir`` is given,
    ``mempalace.hooks_cli.STATE_DIR`` is patched to it for the duration of the
    call.

    Args:
        hook_fn: Hook callable, invoked as ``hook_fn(data, harness)``.
        data: Input dict passed to the hook.
        harness: Harness name passed to the hook.
        state_dir: Optional replacement for ``STATE_DIR``.

    Returns:
        The object decoded from the buffered JSON text.
    """
    # `io` is already imported at module level, so no local import is needed.
    buf = io.StringIO()
    with contextlib.ExitStack() as stack:
        stack.enter_context(
            patch("mempalace.hooks_cli._output", side_effect=lambda d: buf.write(json.dumps(d)))
        )
        if state_dir:
            stack.enter_context(patch("mempalace.hooks_cli.STATE_DIR", state_dir))
        hook_fn(data, harness)
    return json.loads(buf.getvalue())
def test_stop_hook_passthrough_when_active(tmp_path):
    """With stop_hook_active set to True the stop hook outputs an empty dict."""
    # state_dir= already patches STATE_DIR inside _capture_hook_output, so no
    # additional outer patch is needed (matches the other stop-hook tests).
    result = _capture_hook_output(
        hook_stop,
        {"session_id": "test", "stop_hook_active": True, "transcript_path": ""},
        state_dir=tmp_path,
    )
    assert result == {}
def test_stop_hook_passthrough_when_active_string(tmp_path):
    """The string "true" for stop_hook_active is honoured like the boolean True."""
    # state_dir= already patches STATE_DIR inside _capture_hook_output, so no
    # additional outer patch is needed (matches the other stop-hook tests).
    result = _capture_hook_output(
        hook_stop,
        {"session_id": "test", "stop_hook_active": "true", "transcript_path": ""},
        state_dir=tmp_path,
    )
    assert result == {}
def test_stop_hook_passthrough_below_interval(tmp_path):
    """One user message short of SAVE_INTERVAL, the stop hook outputs an empty dict."""
    entries = []
    for i in range(SAVE_INTERVAL - 1):
        entries.append({"message": {"role": "user", "content": f"msg {i}"}})
    log_file = tmp_path / "t.jsonl"
    _write_transcript(log_file, entries)
    payload = {"session_id": "test", "stop_hook_active": False, "transcript_path": str(log_file)}
    outcome = _capture_hook_output(hook_stop, payload, state_dir=tmp_path)
    assert outcome == {}
def test_stop_hook_blocks_at_interval(tmp_path):
    """At exactly SAVE_INTERVAL user messages the stop hook blocks with STOP_BLOCK_REASON."""
    entries = []
    for i in range(SAVE_INTERVAL):
        entries.append({"message": {"role": "user", "content": f"msg {i}"}})
    log_file = tmp_path / "t.jsonl"
    _write_transcript(log_file, entries)
    payload = {"session_id": "test", "stop_hook_active": False, "transcript_path": str(log_file)}
    outcome = _capture_hook_output(hook_stop, payload, state_dir=tmp_path)
    assert outcome["decision"] == "block"
    assert outcome["reason"] == STOP_BLOCK_REASON
def test_stop_hook_tracks_save_point(tmp_path):
    """After blocking once, a repeat call with the same message count passes through."""
    entries = []
    for i in range(SAVE_INTERVAL):
        entries.append({"message": {"role": "user", "content": f"msg {i}"}})
    log_file = tmp_path / "t.jsonl"
    _write_transcript(log_file, entries)
    payload = {"session_id": "test", "stop_hook_active": False, "transcript_path": str(log_file)}
    # First call blocks
    first = _capture_hook_output(hook_stop, payload, state_dir=tmp_path)
    assert first["decision"] == "block"
    # Second call with same count passes through (already saved)
    second = _capture_hook_output(hook_stop, payload, state_dir=tmp_path)
    assert second == {}
# --- hook_session_start ---
def test_session_start_passes_through(tmp_path):
    """The session-start hook outputs an empty dict."""
    payload = {"session_id": "test"}
    outcome = _capture_hook_output(hook_session_start, payload, state_dir=tmp_path)
    assert outcome == {}
# --- hook_precompact ---
def test_precompact_always_blocks(tmp_path):
    """The pre-compact hook blocks with PRECOMPACT_BLOCK_REASON."""
    payload = {"session_id": "test"}
    outcome = _capture_hook_output(hook_precompact, payload, state_dir=tmp_path)
    assert outcome["decision"] == "block"
    assert outcome["reason"] == PRECOMPACT_BLOCK_REASON
# --- _log ---
def test_log_writes_to_hook_log(tmp_path):
    """_log writes its message into hook.log under the patched STATE_DIR."""
    with patch("mempalace.hooks_cli.STATE_DIR", tmp_path):
        _log("test message")
    hook_log = tmp_path / "hook.log"
    assert hook_log.is_file()
    assert "test message" in hook_log.read_text()
def test_log_oserror_is_silenced(tmp_path):
    """_log should not raise if the directory cannot be created.

    STATE_DIR is placed beneath a regular file, so creating it (or opening a
    file inside it) fails with an OSError whatever the process privileges.
    A hard-coded absolute path such as /nonexistent/... could be created when
    the tests run as root, which would skip the error path and leave
    directories behind.
    """
    blocker = tmp_path / "blocker"
    blocker.write_text("not a directory")
    with patch("mempalace.hooks_cli.STATE_DIR", blocker / "nested" / "dir"):
        # Should not raise
        _log("this will fail silently")
# --- _maybe_auto_ingest ---
def test_maybe_auto_ingest_no_env(tmp_path):
    """Without MEMPAL_DIR set, does nothing.

    Popen is patched so the test verifies that no subprocess is spawned,
    rather than only checking that the call does not raise.
    """
    with patch.dict("os.environ", {}, clear=True):
        with patch("mempalace.hooks_cli.STATE_DIR", tmp_path):
            with patch("mempalace.hooks_cli.subprocess.Popen") as mock_popen:
                _maybe_auto_ingest()  # should not raise
    mock_popen.assert_not_called()
def test_maybe_auto_ingest_with_env(tmp_path):
    """With MEMPAL_DIR pointing at an existing directory, Popen is called exactly once."""
    project_dir = tmp_path / "project"
    project_dir.mkdir()
    env = {"MEMPAL_DIR": str(project_dir)}
    with patch.dict("os.environ", env), patch("mempalace.hooks_cli.STATE_DIR", tmp_path):
        with patch("mempalace.hooks_cli.subprocess.Popen") as popen_mock:
            _maybe_auto_ingest()
    popen_mock.assert_called_once()
def test_maybe_auto_ingest_oserror(tmp_path):
    """An OSError raised by Popen does not propagate out of _maybe_auto_ingest."""
    project_dir = tmp_path / "project"
    project_dir.mkdir()
    env = {"MEMPAL_DIR": str(project_dir)}
    with patch.dict("os.environ", env), patch("mempalace.hooks_cli.STATE_DIR", tmp_path):
        with patch("mempalace.hooks_cli.subprocess.Popen", side_effect=OSError("fail")):
            _maybe_auto_ingest()  # must complete without raising
# --- _parse_harness_input ---
def test_parse_harness_input_unknown():
    """An unrecognised harness name makes _parse_harness_input exit with code 1."""
    payload = {"session_id": "test"}
    with pytest.raises(SystemExit) as raised:
        _parse_harness_input(payload, "unknown-harness")
    assert raised.value.code == 1
def test_parse_harness_input_valid():
    """The claude-code harness keeps session_id and stop_hook_active from the payload."""
    payload = {
        "session_id": "abc-123",
        "stop_hook_active": True,
        "transcript_path": "/tmp/t.jsonl",
    }
    parsed = _parse_harness_input(payload, "claude-code")
    assert parsed["session_id"] == "abc-123"
    assert parsed["stop_hook_active"] is True
# --- hook_stop with OSError on write ---
def test_stop_hook_oserror_on_last_save_read(tmp_path):
    """Stop hook still blocks when the last-save file holds non-numeric text.

    Despite the test name, no OSError is provoked here: the file is readable
    but its content ("not_a_number") cannot be parsed as a count.
    """
    transcript_file = tmp_path / "t.jsonl"
    messages = [{"message": {"role": "user", "content": f"msg {i}"}} for i in range(SAVE_INTERVAL)]
    _write_transcript(transcript_file, messages)

    # Pre-populate the save marker for session "test" with unparsable content.
    marker = tmp_path / "test_last_save"
    marker.write_text("not_a_number")

    payload = {
        "session_id": "test",
        "stop_hook_active": False,
        "transcript_path": str(transcript_file),
    }
    outcome = _capture_hook_output(hook_stop, payload, state_dir=tmp_path)
    assert outcome["decision"] == "block"
def test_stop_hook_oserror_on_write(tmp_path):
    """Stop hook still emits a block decision while Path.write_text raises OSError."""
    transcript_file = tmp_path / "t.jsonl"
    messages = [{"message": {"role": "user", "content": f"msg {i}"}} for i in range(SAVE_INTERVAL)]
    # The transcript is written before Path.write_text is replaced below.
    _write_transcript(transcript_file, messages)

    def _fail_write(*args, **kwargs):
        raise OSError("disk full")

    payload = {
        "session_id": "test",
        "stop_hook_active": False,
        "transcript_path": str(transcript_file),
    }
    with patch("mempalace.hooks_cli.STATE_DIR", tmp_path):
        with patch.object(Path, "write_text", _fail_write):
            outcome = _capture_hook_output(hook_stop, payload, state_dir=tmp_path)
    assert outcome["decision"] == "block"
# --- hook_precompact with MEMPAL_DIR ---
def test_precompact_with_mempal_dir(tmp_path):
    """With MEMPAL_DIR set, precompact blocks and calls subprocess.run exactly once."""
    project_dir = tmp_path / "project"
    project_dir.mkdir()
    env = {"MEMPAL_DIR": str(project_dir)}
    payload = {"session_id": "test"}
    with patch.dict("os.environ", env), patch("mempalace.hooks_cli.subprocess.run") as run_mock:
        outcome = _capture_hook_output(hook_precompact, payload, state_dir=tmp_path)
    assert outcome["decision"] == "block"
    run_mock.assert_called_once()
def test_precompact_with_mempal_dir_oserror(tmp_path):
    """Precompact still blocks when subprocess.run raises OSError."""
    project_dir = tmp_path / "project"
    project_dir.mkdir()
    env = {"MEMPAL_DIR": str(project_dir)}
    payload = {"session_id": "test"}
    with patch.dict("os.environ", env):
        with patch("mempalace.hooks_cli.subprocess.run", side_effect=OSError("fail")):
            outcome = _capture_hook_output(hook_precompact, payload, state_dir=tmp_path)
    assert outcome["decision"] == "block"
# --- run_hook ---
def test_run_hook_dispatches_session_start(tmp_path):
    """run_hook("session-start") reads JSON from stdin and outputs an empty dict."""
    fake_stdin = io.StringIO(json.dumps({"session_id": "run-test"}))
    with patch("sys.stdin", fake_stdin), patch("mempalace.hooks_cli.STATE_DIR", tmp_path):
        with patch("mempalace.hooks_cli._output") as output_mock:
            run_hook("session-start", "claude-code")
    output_mock.assert_called_once_with({})
def test_run_hook_dispatches_stop(tmp_path):
    """run_hook("stop") with a three-message transcript outputs an empty dict."""
    transcript_file = tmp_path / "t.jsonl"
    messages = [{"message": {"role": "user", "content": f"msg {i}"}} for i in range(3)]
    _write_transcript(transcript_file, messages)
    payload = {
        "session_id": "run-test",
        "stop_hook_active": False,
        "transcript_path": str(transcript_file),
    }
    fake_stdin = io.StringIO(json.dumps(payload))
    with patch("sys.stdin", fake_stdin), patch("mempalace.hooks_cli.STATE_DIR", tmp_path):
        with patch("mempalace.hooks_cli._output") as output_mock:
            run_hook("stop", "claude-code")
    output_mock.assert_called_once_with({})
def test_run_hook_dispatches_precompact(tmp_path):
    """run_hook("precompact") outputs exactly one payload whose decision is "block"."""
    fake_stdin = io.StringIO(json.dumps({"session_id": "run-test"}))
    with patch("sys.stdin", fake_stdin), patch("mempalace.hooks_cli.STATE_DIR", tmp_path):
        with patch("mempalace.hooks_cli._output") as output_mock:
            run_hook("precompact", "claude-code")
    output_mock.assert_called_once()
    emitted = output_mock.call_args.args[0]
    assert emitted["decision"] == "block"
def test_run_hook_unknown_hook():
    """An unrecognised hook name makes run_hook exit with code 1."""
    fake_stdin = io.StringIO(json.dumps({"session_id": "test"}))
    with patch("sys.stdin", fake_stdin), pytest.raises(SystemExit) as raised:
        run_hook("nonexistent", "claude-code")
    assert raised.value.code == 1
def test_run_hook_invalid_json(tmp_path):
    """Unparsable stdin does not crash run_hook; session-start still outputs {}."""
    fake_stdin = io.StringIO("not valid json")
    with patch("sys.stdin", fake_stdin), patch("mempalace.hooks_cli.STATE_DIR", tmp_path):
        with patch("mempalace.hooks_cli._output") as output_mock:
            run_hook("session-start", "claude-code")
    output_mock.assert_called_once_with({})
+45
View File
@@ -0,0 +1,45 @@
"""Tests for mempalace.instructions_cli — instruction text output."""
from unittest.mock import patch
import pytest
from mempalace.instructions_cli import AVAILABLE, INSTRUCTIONS_DIR, run_instructions
def test_run_instructions_valid_name(capsys):
    """run_instructions("init") prints the content of init.md, ignoring outer whitespace."""
    name = "init"
    md_path = INSTRUCTIONS_DIR / f"{name}.md"
    expected = md_path.read_text()
    run_instructions(name)
    printed = capsys.readouterr().out
    assert printed.strip() == expected.strip()
def test_run_instructions_all_available(capsys):
    """Every name in AVAILABLE should succeed without error and print something."""
    # Without this guard the loop below would pass vacuously on an empty list.
    assert AVAILABLE, "AVAILABLE should list at least one instruction set"
    for name in AVAILABLE:
        run_instructions(name)
        out = capsys.readouterr().out
        # Name the failing entry so a regression is easy to locate.
        assert len(out) > 0, f"no output printed for instructions {name!r}"
def test_run_instructions_invalid_name(capsys):
    """An unknown name exits with code 1 and reports the available names on stderr."""
    with pytest.raises(SystemExit) as raised:
        run_instructions("nonexistent")
    assert raised.value.code == 1
    stderr_text = capsys.readouterr().err
    assert "Unknown instructions: nonexistent" in stderr_text
    assert "Available:" in stderr_text
def test_run_instructions_missing_md_file(capsys, tmp_path):
    """A listed name whose .md file is absent exits with code 1 and an stderr message."""
    # tmp_path is empty, so "fakecmd.md" cannot exist there.
    with patch("mempalace.instructions_cli.INSTRUCTIONS_DIR", tmp_path):
        with patch("mempalace.instructions_cli.AVAILABLE", ["fakecmd"]):
            with pytest.raises(SystemExit) as raised:
                run_instructions("fakecmd")
    assert raised.value.code == 1
    stderr_text = capsys.readouterr().err
    assert "Instructions file not found" in stderr_text
+105
View File
@@ -0,0 +1,105 @@
"""Extra knowledge graph tests for seed_from_entity_facts and query_relationship."""
import pytest
from mempalace.knowledge_graph import KnowledgeGraph
@pytest.fixture
def kg(tmp_path):
    """Provide a KnowledgeGraph whose database file is kg.db under tmp_path."""
    db_file = tmp_path / "kg.db"
    return KnowledgeGraph(db_path=str(db_file))
class TestSeedFromEntityFacts:
    """Checks the triples produced by KnowledgeGraph.seed_from_entity_facts."""

    def test_seed_person_with_partner(self, kg):
        """A person with a partner gets married_to and is_partner_of edges."""
        alice = {
            "full_name": "Alice Smith",
            "type": "person",
            "gender": "female",
            "partner": "bob",
            "relationship": "husband",
        }
        kg.seed_from_entity_facts({"alice": alice})
        assert kg.stats()["entities"] >= 1
        triples = kg.query_entity("Alice Smith", direction="outgoing")
        predicates = {t["predicate"] for t in triples}
        assert "married_to" in predicates
        assert "is_partner_of" in predicates

    def test_seed_child(self, kg):
        """A person with a parent gets child_of and is_child_of edges."""
        child = {
            "full_name": "Max",
            "type": "person",
            "birthday": "2015-04-01",
            "parent": "alice",
            "relationship": "daughter",
        }
        kg.seed_from_entity_facts({"max": child})
        triples = kg.query_entity("Max", direction="outgoing")
        predicates = {t["predicate"] for t in triples}
        assert "child_of" in predicates
        assert "is_child_of" in predicates

    def test_seed_sibling(self, kg):
        """A person with a sibling gets an is_sibling_of edge."""
        sibling = {
            "full_name": "Emma",
            "type": "person",
            "relationship": "brother",
            "sibling": "max",
        }
        kg.seed_from_entity_facts({"emma": sibling})
        triples = kg.query_entity("Emma", direction="outgoing")
        predicates = {t["predicate"] for t in triples}
        assert "is_sibling_of" in predicates

    def test_seed_dog(self, kg):
        """An animal with an owner gets an is_pet_of edge."""
        pet = {
            "full_name": "Rex",
            "type": "animal",
            "relationship": "dog",
            "owner": "alice",
        }
        kg.seed_from_entity_facts({"rex": pet})
        triples = kg.query_entity("Rex", direction="outgoing")
        predicates = {t["predicate"] for t in triples}
        assert "is_pet_of" in predicates

    def test_seed_with_interests(self, kg):
        """Each interest becomes a "loves" edge whose object is capitalised."""
        person = {
            "full_name": "Max",
            "type": "person",
            "interests": ["swimming", "chess"],
        }
        kg.seed_from_entity_facts({"max": person})
        triples = kg.query_entity("Max", direction="outgoing")
        loved = {t["object"] for t in triples if t["predicate"] == "loves"}
        assert "Swimming" in loved
        assert "Chess" in loved

    def test_seed_minimal_facts(self, kg):
        """Facts carrying only a full_name still create at least one entity."""
        kg.seed_from_entity_facts({"bob": {"full_name": "Bob"}})
        assert kg.stats()["entities"] >= 1
class TestQueryRelationshipWithTime:
    """Checks the as_of filter of KnowledgeGraph.query_relationship."""

    def test_query_relationship_with_as_of(self, kg):
        """Only the triple whose validity window covers as_of is returned."""
        kg.add_triple("Alice", "works_at", "Acme", valid_from="2020-01-01", valid_to="2024-12-31")
        kg.add_triple("Alice", "works_at", "NewCo", valid_from="2025-01-01")
        matches = kg.query_relationship("works_at", as_of="2023-06-01")
        employers = [m["object"] for m in matches]
        assert "Acme" in employers
        assert "NewCo" not in employers
+719
View File
@@ -0,0 +1,719 @@
"""Tests for mempalace.layers — Layer0, Layer1, Layer2, Layer3, MemoryStack."""
import os
from unittest.mock import MagicMock, patch
from mempalace.layers import Layer0, Layer1, Layer2, Layer3, MemoryStack
# ── Layer0 — with identity file ─────────────────────────────────────────
def test_layer0_reads_identity_file(tmp_path):
    """Layer0.render returns text taken from the identity file."""
    identity_path = tmp_path / "identity.txt"
    identity_path.write_text("I am Atlas, a personal AI assistant for Alice.")
    rendered = Layer0(identity_path=str(identity_path)).render()
    assert "Atlas" in rendered
    assert "Alice" in rendered
def test_layer0_caches_text(tmp_path):
    """A second render returns the first text even after the file changes on disk."""
    identity_path = tmp_path / "identity.txt"
    identity_path.write_text("Hello world")
    layer = Layer0(identity_path=str(identity_path))
    before = layer.render()
    identity_path.write_text("Changed content")
    after = layer.render()
    assert before == after
    assert after == "Hello world"
def test_layer0_missing_file_returns_default(tmp_path):
    """With no identity file, render returns a default notice naming identity.txt."""
    absent_path = str(tmp_path / "nonexistent.txt")
    rendered = Layer0(identity_path=absent_path).render()
    assert "No identity configured" in rendered
    assert "identity.txt" in rendered
def test_layer0_token_estimate(tmp_path):
    """A 400-character identity file yields a token estimate of 100."""
    identity_path = tmp_path / "identity.txt"
    identity_path.write_text("A" * 400)
    layer = Layer0(identity_path=str(identity_path))
    assert layer.token_estimate() == 100
def test_layer0_token_estimate_empty(tmp_path):
    """An empty identity file yields a token estimate of 0."""
    identity_path = tmp_path / "identity.txt"
    identity_path.write_text("")
    estimate = Layer0(identity_path=str(identity_path)).token_estimate()
    assert estimate == 0
def test_layer0_strips_whitespace(tmp_path):
    """Leading and trailing whitespace in the identity file is removed by render."""
    identity_path = tmp_path / "identity.txt"
    identity_path.write_text(" Hello world \n\n")
    rendered = Layer0(identity_path=str(identity_path)).render()
    assert rendered == "Hello world"
def test_layer0_default_path():
    """Without arguments, Layer0.path is ~/.mempalace/identity.txt with ~ expanded."""
    default_path = os.path.expanduser("~/.mempalace/identity.txt")
    assert Layer0().path == default_path
# ── Layer1 — mocked chromadb ────────────────────────────────────────────
def _mock_chromadb_for_layer(docs, metas, monkeypatch=None):
"""Return a mock PersistentClient whose collection.get returns docs/metas."""
mock_col = MagicMock()
# First batch returns data, second batch returns empty (end of pagination)
mock_col.get.side_effect = [
{"documents": docs, "metadatas": metas},
{"documents": [], "metadatas": []},
]
mock_client = MagicMock()
mock_client.get_collection.return_value = mock_col
return mock_client
def test_layer1_no_palace():
    """Layer1.generate reports a missing palace or missing memories for a bad path."""
    missing = "/nonexistent/palace"
    with patch("mempalace.layers.MempalaceConfig") as cfg_cls:
        cfg_cls.return_value.palace_path = missing
        layer = Layer1(palace_path=missing)
        text = layer.generate()
    assert "No palace found" in text or "No memories" in text
def test_layer1_generates_essential_story():
    """Layer1.generate renders an ESSENTIAL STORY section containing document text."""
    documents = [
        "Important memory about project decisions",
        "Key architectural choice for the backend",
    ]
    metadatas = [
        {"room": "decisions", "source_file": "meeting.txt", "importance": 5},
        {"room": "architecture", "source_file": "design.txt", "importance": 4},
    ]
    client = _mock_chromadb_for_layer(documents, metadatas)
    with (
        patch("mempalace.layers.MempalaceConfig") as cfg_cls,
        patch("mempalace.layers.chromadb.PersistentClient", return_value=client),
    ):
        cfg_cls.return_value.palace_path = "/fake"
        text = Layer1(palace_path="/fake").generate()
    assert "ESSENTIAL STORY" in text
    assert "project decisions" in text
def test_layer1_empty_palace():
    """Layer1.generate says "No memories" when the collection returns nothing."""
    collection = MagicMock()
    collection.get.return_value = {"documents": [], "metadatas": []}
    client = MagicMock()
    client.get_collection.return_value = collection
    with (
        patch("mempalace.layers.MempalaceConfig") as cfg_cls,
        patch("mempalace.layers.chromadb.PersistentClient", return_value=client),
    ):
        cfg_cls.return_value.palace_path = "/fake"
        text = Layer1(palace_path="/fake").generate()
    assert "No memories" in text
def test_layer1_with_wing_filter():
    """A Layer1 built with a wing passes where={"wing": ...} to the first get call."""
    client = _mock_chromadb_for_layer(
        ["Memory about project X"],
        [{"room": "general", "source_file": "x.txt", "importance": 3}],
    )
    with (
        patch("mempalace.layers.MempalaceConfig") as cfg_cls,
        patch("mempalace.layers.chromadb.PersistentClient", return_value=client),
    ):
        cfg_cls.return_value.palace_path = "/fake"
        text = Layer1(palace_path="/fake", wing="project_x").generate()
    assert "ESSENTIAL STORY" in text
    # Inspect the keyword arguments of the first collection.get call.
    first_get = client.get_collection.return_value.get.call_args_list[0]
    assert first_get.kwargs.get("where") == {"wing": "project_x"}
def test_layer1_truncates_long_snippets():
    """A 300-character document produces an ellipsis in the Layer1 output."""
    client = _mock_chromadb_for_layer(
        ["A" * 300],
        [{"room": "general", "source_file": "long.txt"}],
    )
    with (
        patch("mempalace.layers.MempalaceConfig") as cfg_cls,
        patch("mempalace.layers.chromadb.PersistentClient", return_value=client),
    ):
        cfg_cls.return_value.palace_path = "/fake"
        text = Layer1(palace_path="/fake").generate()
    assert "..." in text
def test_layer1_respects_max_chars():
    """With MAX_CHARS lowered to 200, the output points to L3 search for the rest."""
    count = 30
    documents = [f"Memory number {i} with substantial content padding here" for i in range(count)]
    metadatas = [
        {"room": "general", "source_file": f"f{i}.txt", "importance": 5} for i in range(count)
    ]
    client = _mock_chromadb_for_layer(documents, metadatas)
    with (
        patch("mempalace.layers.MempalaceConfig") as cfg_cls,
        patch("mempalace.layers.chromadb.PersistentClient", return_value=client),
    ):
        cfg_cls.return_value.palace_path = "/fake"
        layer = Layer1(palace_path="/fake")
        # A very small cap forces the overflow notice.
        layer.MAX_CHARS = 200
        text = layer.generate()
    assert "more in L3 search" in text
def test_layer1_importance_from_various_keys():
    """Layer1 renders when metadata uses emotional_weight, weight, or no weight key."""
    documents = ["mem1", "mem2", "mem3"]
    metadatas = [
        {"room": "r", "emotional_weight": 5},
        {"room": "r", "weight": 1},
        {"room": "r"},  # no weight-like key at all
    ]
    client = _mock_chromadb_for_layer(documents, metadatas)
    with (
        patch("mempalace.layers.MempalaceConfig") as cfg_cls,
        patch("mempalace.layers.chromadb.PersistentClient", return_value=client),
    ):
        cfg_cls.return_value.palace_path = "/fake"
        text = Layer1(palace_path="/fake").generate()
    assert "ESSENTIAL STORY" in text
def test_layer1_batch_exception_breaks():
    """A RuntimeError on the second batch still leaves the first batch rendered."""
    first_batch = {"documents": ["doc1"], "metadatas": [{"room": "r"}]}
    collection = MagicMock()
    collection.get.side_effect = [first_batch, RuntimeError("batch error")]
    client = MagicMock()
    client.get_collection.return_value = collection
    with (
        patch("mempalace.layers.MempalaceConfig") as cfg_cls,
        patch("mempalace.layers.chromadb.PersistentClient", return_value=client),
    ):
        cfg_cls.return_value.palace_path = "/fake"
        text = Layer1(palace_path="/fake").generate()
    assert "ESSENTIAL STORY" in text
# ── Layer2 — mocked chromadb ────────────────────────────────────────────
def test_layer2_no_palace():
    """Layer2.retrieve reports "No palace found" for a nonexistent palace path."""
    missing = "/nonexistent/palace"
    with patch("mempalace.layers.MempalaceConfig") as cfg_cls:
        cfg_cls.return_value.palace_path = missing
        text = Layer2(palace_path=missing).retrieve(wing="test")
    assert "No palace found" in text
def test_layer2_retrieve_with_wing():
    """Retrieving by wing yields an ON-DEMAND section containing the document text."""
    collection = MagicMock()
    collection.get.return_value = {
        "documents": ["Some memory about the project"],
        "metadatas": [{"room": "backend", "source_file": "notes.txt"}],
    }
    client = MagicMock()
    client.get_collection.return_value = collection
    with (
        patch("mempalace.layers.MempalaceConfig") as cfg_cls,
        patch("mempalace.layers.chromadb.PersistentClient", return_value=client),
    ):
        cfg_cls.return_value.palace_path = "/fake"
        text = Layer2(palace_path="/fake").retrieve(wing="project")
    assert "ON-DEMAND" in text
    assert "memory about the project" in text
def test_layer2_retrieve_with_room():
    """Retrieving by room yields an ON-DEMAND section."""
    collection = MagicMock()
    collection.get.return_value = {
        "documents": ["Backend architecture notes"],
        "metadatas": [{"room": "architecture", "source_file": "arch.txt"}],
    }
    client = MagicMock()
    client.get_collection.return_value = collection
    with (
        patch("mempalace.layers.MempalaceConfig") as cfg_cls,
        patch("mempalace.layers.chromadb.PersistentClient", return_value=client),
    ):
        cfg_cls.return_value.palace_path = "/fake"
        text = Layer2(palace_path="/fake").retrieve(room="architecture")
    assert "ON-DEMAND" in text
def test_layer2_retrieve_wing_and_room():
    """Passing both wing and room produces a where filter containing "$and"."""
    collection = MagicMock()
    collection.get.return_value = {
        "documents": ["Filtered result"],
        "metadatas": [{"room": "backend", "source_file": "x.txt"}],
    }
    client = MagicMock()
    client.get_collection.return_value = collection
    with (
        patch("mempalace.layers.MempalaceConfig") as cfg_cls,
        patch("mempalace.layers.chromadb.PersistentClient", return_value=client),
    ):
        cfg_cls.return_value.palace_path = "/fake"
        text = Layer2(palace_path="/fake").retrieve(wing="proj", room="backend")
    assert "ON-DEMAND" in text
    get_kwargs = collection.get.call_args.kwargs
    assert "$and" in get_kwargs.get("where", {})
def test_layer2_retrieve_empty():
    """An empty collection result is reported as "No drawers found"."""
    collection = MagicMock()
    collection.get.return_value = {"documents": [], "metadatas": []}
    client = MagicMock()
    client.get_collection.return_value = collection
    with (
        patch("mempalace.layers.MempalaceConfig") as cfg_cls,
        patch("mempalace.layers.chromadb.PersistentClient", return_value=client),
    ):
        cfg_cls.return_value.palace_path = "/fake"
        text = Layer2(palace_path="/fake").retrieve(wing="missing")
    assert "No drawers found" in text
def test_layer2_retrieve_no_filter():
    """Calling retrieve with no wing or room passes no "where" keyword to get."""
    collection = MagicMock()
    collection.get.return_value = {"documents": [], "metadatas": []}
    client = MagicMock()
    client.get_collection.return_value = collection
    with (
        patch("mempalace.layers.MempalaceConfig") as cfg_cls,
        patch("mempalace.layers.chromadb.PersistentClient", return_value=client),
    ):
        cfg_cls.return_value.palace_path = "/fake"
        Layer2(palace_path="/fake").retrieve()
    # The unfiltered call must omit the where argument entirely.
    get_kwargs = collection.get.call_args.kwargs
    assert "where" not in get_kwargs
def test_layer2_retrieve_error():
    """A RuntimeError from collection.get is reported as "Retrieval error"."""
    collection = MagicMock()
    collection.get.side_effect = RuntimeError("db error")
    client = MagicMock()
    client.get_collection.return_value = collection
    with (
        patch("mempalace.layers.MempalaceConfig") as cfg_cls,
        patch("mempalace.layers.chromadb.PersistentClient", return_value=client),
    ):
        cfg_cls.return_value.palace_path = "/fake"
        text = Layer2(palace_path="/fake").retrieve(wing="test")
    assert "Retrieval error" in text
def test_layer2_truncates_long_snippets():
    """A 400-character document produces an ellipsis in the Layer2 output."""
    collection = MagicMock()
    collection.get.return_value = {
        "documents": ["B" * 400],
        "metadatas": [{"room": "r", "source_file": "s.txt"}],
    }
    client = MagicMock()
    client.get_collection.return_value = collection
    with (
        patch("mempalace.layers.MempalaceConfig") as cfg_cls,
        patch("mempalace.layers.chromadb.PersistentClient", return_value=client),
    ):
        cfg_cls.return_value.palace_path = "/fake"
        text = Layer2(palace_path="/fake").retrieve(wing="test")
    assert "..." in text
# ── Layer3 — mocked chromadb ────────────────────────────────────────────
def _mock_query_results(docs, metas, dists):
return {
"documents": [docs],
"metadatas": [metas],
"distances": [dists],
}
def test_layer3_no_palace():
    """Layer3.search reports "No palace found" for a nonexistent palace path."""
    missing = "/nonexistent/palace"
    with patch("mempalace.layers.MempalaceConfig") as cfg_cls:
        cfg_cls.return_value.palace_path = missing
        text = Layer3(palace_path=missing).search("test query")
    assert "No palace found" in text
def test_layer3_search_raw_no_palace():
    """Layer3.search_raw returns an empty list for a nonexistent palace path."""
    missing = "/nonexistent/palace"
    with patch("mempalace.layers.MempalaceConfig") as cfg_cls:
        cfg_cls.return_value.palace_path = missing
        hits = Layer3(palace_path=missing).search_raw("test query")
    assert hits == []
def test_layer3_search_with_results():
    """A hit at distance 0.2 is rendered under SEARCH RESULTS with sim=0.8."""
    collection = MagicMock()
    collection.query.return_value = _mock_query_results(
        ["Found this important memory"],
        [{"wing": "project", "room": "backend", "source_file": "notes.txt"}],
        [0.2],
    )
    client = MagicMock()
    client.get_collection.return_value = collection
    with (
        patch("mempalace.layers.MempalaceConfig") as cfg_cls,
        patch("mempalace.layers.chromadb.PersistentClient", return_value=client),
    ):
        cfg_cls.return_value.palace_path = "/fake"
        text = Layer3(palace_path="/fake").search("important")
    assert "SEARCH RESULTS" in text
    assert "important memory" in text
    assert "sim=0.8" in text
def test_layer3_search_no_results():
    """An empty query result is reported as "No results found"."""
    collection = MagicMock()
    collection.query.return_value = _mock_query_results([], [], [])
    client = MagicMock()
    client.get_collection.return_value = collection
    with (
        patch("mempalace.layers.MempalaceConfig") as cfg_cls,
        patch("mempalace.layers.chromadb.PersistentClient", return_value=client),
    ):
        cfg_cls.return_value.palace_path = "/fake"
        text = Layer3(palace_path="/fake").search("nothing")
    assert "No results found" in text
def test_layer3_search_with_wing_filter():
    """Searching with a wing passes where={"wing": ...} to collection.query."""
    collection = MagicMock()
    collection.query.return_value = _mock_query_results(
        ["result"],
        [{"wing": "proj", "room": "r"}],
        [0.1],
    )
    client = MagicMock()
    client.get_collection.return_value = collection
    with (
        patch("mempalace.layers.MempalaceConfig") as cfg_cls,
        patch("mempalace.layers.chromadb.PersistentClient", return_value=client),
    ):
        cfg_cls.return_value.palace_path = "/fake"
        Layer3(palace_path="/fake").search("q", wing="proj")
    query_kwargs = collection.query.call_args.kwargs
    assert query_kwargs["where"] == {"wing": "proj"}
def test_layer3_search_with_room_filter():
    """Searching with a room passes where={"room": ...} to collection.query."""
    collection = MagicMock()
    collection.query.return_value = _mock_query_results(
        ["result"],
        [{"wing": "w", "room": "backend"}],
        [0.1],
    )
    client = MagicMock()
    client.get_collection.return_value = collection
    with (
        patch("mempalace.layers.MempalaceConfig") as cfg_cls,
        patch("mempalace.layers.chromadb.PersistentClient", return_value=client),
    ):
        cfg_cls.return_value.palace_path = "/fake"
        Layer3(palace_path="/fake").search("q", room="backend")
    query_kwargs = collection.query.call_args.kwargs
    assert query_kwargs["where"] == {"room": "backend"}
def test_layer3_search_with_wing_and_room():
    """Searching with both wing and room passes a where filter containing "$and"."""
    collection = MagicMock()
    collection.query.return_value = _mock_query_results(
        ["result"],
        [{"wing": "proj", "room": "backend"}],
        [0.1],
    )
    client = MagicMock()
    client.get_collection.return_value = collection
    with (
        patch("mempalace.layers.MempalaceConfig") as cfg_cls,
        patch("mempalace.layers.chromadb.PersistentClient", return_value=client),
    ):
        cfg_cls.return_value.palace_path = "/fake"
        Layer3(palace_path="/fake").search("q", wing="proj", room="backend")
    query_kwargs = collection.query.call_args.kwargs
    assert "$and" in query_kwargs["where"]
def test_layer3_search_error():
    """A RuntimeError from collection.query is reported as "Search error"."""
    collection = MagicMock()
    collection.query.side_effect = RuntimeError("search failed")
    client = MagicMock()
    client.get_collection.return_value = collection
    with (
        patch("mempalace.layers.MempalaceConfig") as cfg_cls,
        patch("mempalace.layers.chromadb.PersistentClient", return_value=client),
    ):
        cfg_cls.return_value.palace_path = "/fake"
        text = Layer3(palace_path="/fake").search("q")
    assert "Search error" in text
def test_layer3_search_truncates_long_docs():
    """A 400-character hit produces an ellipsis in the Layer3 search output."""
    collection = MagicMock()
    collection.query.return_value = _mock_query_results(
        ["C" * 400],
        [{"wing": "w", "room": "r", "source_file": "s.txt"}],
        [0.1],
    )
    client = MagicMock()
    client.get_collection.return_value = collection
    with (
        patch("mempalace.layers.MempalaceConfig") as cfg_cls,
        patch("mempalace.layers.chromadb.PersistentClient", return_value=client),
    ):
        cfg_cls.return_value.palace_path = "/fake"
        text = Layer3(palace_path="/fake").search("q")
    assert "..." in text
def test_layer3_search_raw_returns_dicts():
    """search_raw returns one dict per hit with text, wing, similarity and metadata."""
    mock_col = MagicMock()
    mock_col.query.return_value = _mock_query_results(
        ["doc text"],
        [{"wing": "proj", "room": "backend", "source_file": "f.txt"}],
        [0.3],
    )
    mock_client = MagicMock()
    mock_client.get_collection.return_value = mock_col
    with (
        patch("mempalace.layers.MempalaceConfig") as mock_cfg,
        patch("mempalace.layers.chromadb.PersistentClient", return_value=mock_client),
    ):
        mock_cfg.return_value.palace_path = "/fake"
        layer = Layer3(palace_path="/fake")
        hits = layer.search_raw("q")
    assert len(hits) == 1
    hit = hits[0]
    assert hit["text"] == "doc text"
    assert hit["wing"] == "proj"
    # Similarity is a float computed from the mocked distance (0.3); compare
    # with a tolerance so the test does not depend on exact float rounding.
    assert abs(hit["similarity"] - 0.7) < 1e-6
    assert "metadata" in hit
def test_layer3_search_raw_with_filters():
    """search_raw with both wing and room passes a where filter containing "$and"."""
    collection = MagicMock()
    collection.query.return_value = _mock_query_results(
        ["doc"],
        [{"wing": "w", "room": "r"}],
        [0.1],
    )
    client = MagicMock()
    client.get_collection.return_value = collection
    with (
        patch("mempalace.layers.MempalaceConfig") as cfg_cls,
        patch("mempalace.layers.chromadb.PersistentClient", return_value=client),
    ):
        cfg_cls.return_value.palace_path = "/fake"
        Layer3(palace_path="/fake").search_raw("q", wing="w", room="r")
    query_kwargs = collection.query.call_args.kwargs
    assert "$and" in query_kwargs["where"]
def test_layer3_search_raw_error():
    """A RuntimeError from collection.query makes search_raw return an empty list."""
    collection = MagicMock()
    collection.query.side_effect = RuntimeError("fail")
    client = MagicMock()
    client.get_collection.return_value = collection
    with (
        patch("mempalace.layers.MempalaceConfig") as cfg_cls,
        patch("mempalace.layers.chromadb.PersistentClient", return_value=client),
    ):
        cfg_cls.return_value.palace_path = "/fake"
        hits = Layer3(palace_path="/fake").search_raw("q")
    assert hits == []
# ── MemoryStack ─────────────────────────────────────────────────────────
def test_memory_stack_wake_up(tmp_path):
    """wake_up includes the identity text plus a no-palace or no-memories notice."""
    identity_path = tmp_path / "identity.txt"
    identity_path.write_text("I am Atlas.")
    with patch("mempalace.layers.MempalaceConfig") as cfg_cls:
        cfg_cls.return_value.palace_path = "/nonexistent"
        stack = MemoryStack(palace_path="/nonexistent", identity_path=str(identity_path))
        text = stack.wake_up()
    assert "Atlas" in text
    # The palace path does not exist, so the L1 part carries one of these notices.
    assert "No palace" in text or "No memories" in text
def test_memory_stack_wake_up_with_wing(tmp_path):
    """wake_up(wing=...) stores the wing on the stack's L1 layer."""
    identity_path = tmp_path / "identity.txt"
    identity_path.write_text("I am Atlas.")
    with patch("mempalace.layers.MempalaceConfig") as cfg_cls:
        cfg_cls.return_value.palace_path = "/nonexistent"
        stack = MemoryStack(palace_path="/nonexistent", identity_path=str(identity_path))
        text = stack.wake_up(wing="my_project")
    assert stack.l1.wing == "my_project"
    assert "Atlas" in text
def test_memory_stack_recall(tmp_path):
    """recall reports "No palace found" when the palace path does not exist."""
    identity_path = tmp_path / "identity.txt"
    identity_path.write_text("I am Atlas.")
    with patch("mempalace.layers.MempalaceConfig") as cfg_cls:
        cfg_cls.return_value.palace_path = "/nonexistent"
        stack = MemoryStack(palace_path="/nonexistent", identity_path=str(identity_path))
        text = stack.recall(wing="test")
    assert "No palace found" in text
def test_memory_stack_search(tmp_path):
    """search reports "No palace found" when the palace path does not exist."""
    identity_path = tmp_path / "identity.txt"
    identity_path.write_text("I am Atlas.")
    with patch("mempalace.layers.MempalaceConfig") as cfg_cls:
        cfg_cls.return_value.palace_path = "/nonexistent"
        stack = MemoryStack(palace_path="/nonexistent", identity_path=str(identity_path))
        text = stack.search("test query")
    assert "No palace found" in text
def test_memory_stack_status(tmp_path):
    """status reports the palace path, zero drawers and one entry per layer."""
    identity_path = tmp_path / "identity.txt"
    identity_path.write_text("I am Atlas.")
    with patch("mempalace.layers.MempalaceConfig") as cfg_cls:
        cfg_cls.return_value.palace_path = "/nonexistent"
        stack = MemoryStack(palace_path="/nonexistent", identity_path=str(identity_path))
        report = stack.status()
    assert report["palace_path"] == "/nonexistent"
    assert report["total_drawers"] == 0
    for layer_key in ("L0_identity", "L1_essential", "L2_on_demand", "L3_deep_search"):
        assert layer_key in report
def test_memory_stack_status_with_palace(tmp_path):
    """status reports the collection count and that the identity file exists."""
    identity_path = tmp_path / "identity.txt"
    identity_path.write_text("I am Atlas.")
    collection = MagicMock()
    collection.count.return_value = 42
    client = MagicMock()
    client.get_collection.return_value = collection
    with (
        patch("mempalace.layers.MempalaceConfig") as cfg_cls,
        patch("mempalace.layers.chromadb.PersistentClient", return_value=client),
    ):
        cfg_cls.return_value.palace_path = "/fake"
        stack = MemoryStack(palace_path="/fake", identity_path=str(identity_path))
        report = stack.status()
    assert report["total_drawers"] == 42
    assert report["L0_identity"]["exists"] is True
+45 -39
View File
@@ -9,25 +9,26 @@ via monkeypatch to avoid touching real data.
import json import json
def _patch_mcp_server(monkeypatch, config, palace_path, kg): def _patch_mcp_server(monkeypatch, config, kg):
"""Patch the mcp_server module globals to use test fixtures.""" """Patch the mcp_server module globals to use test fixtures."""
from mempalace import mcp_server from mempalace import mcp_server
assert getattr(config, "palace_path", None) == palace_path, (
f"config.palace_path ({getattr(config, 'palace_path', None)!r}) does not match palace_path fixture ({palace_path!r})"
)
monkeypatch.setattr(mcp_server, "_config", config) monkeypatch.setattr(mcp_server, "_config", config)
monkeypatch.setattr(mcp_server, "_kg", kg) monkeypatch.setattr(mcp_server, "_kg", kg)
def _get_collection(palace_path, create=False): def _get_collection(palace_path, create=False):
"""Helper to get collection from test palace.""" """Helper to get collection from test palace.
Returns (client, collection) so callers can clean up the client
when they are done.
"""
import chromadb import chromadb
client = chromadb.PersistentClient(path=palace_path) client = chromadb.PersistentClient(path=palace_path)
if create: if create:
return client.get_or_create_collection("mempalace_drawers") return client, client.get_or_create_collection("mempalace_drawers")
return client.get_collection("mempalace_drawers") return client, client.get_collection("mempalace_drawers")
# ── Protocol Layer ────────────────────────────────────────────────────── # ── Protocol Layer ──────────────────────────────────────────────────────
@@ -77,11 +78,12 @@ class TestHandleRequest:
assert resp["error"]["code"] == -32601 assert resp["error"]["code"] == -32601
def test_tools_call_dispatches(self, monkeypatch, config, palace_path, seeded_kg): def test_tools_call_dispatches(self, monkeypatch, config, palace_path, seeded_kg):
_patch_mcp_server(monkeypatch, config, palace_path, seeded_kg) _patch_mcp_server(monkeypatch, config, seeded_kg)
from mempalace.mcp_server import handle_request from mempalace.mcp_server import handle_request
# Create a collection so status works # Create a collection so status works
_get_collection(palace_path, create=True) _client, _col = _get_collection(palace_path, create=True)
del _client
resp = handle_request( resp = handle_request(
{ {
@@ -100,8 +102,9 @@ class TestHandleRequest:
class TestReadTools: class TestReadTools:
def test_status_empty_palace(self, monkeypatch, config, palace_path, kg): def test_status_empty_palace(self, monkeypatch, config, palace_path, kg):
_patch_mcp_server(monkeypatch, config, palace_path, kg) _patch_mcp_server(monkeypatch, config, kg)
_get_collection(palace_path, create=True) _client, _col = _get_collection(palace_path, create=True)
del _client
from mempalace.mcp_server import tool_status from mempalace.mcp_server import tool_status
result = tool_status() result = tool_status()
@@ -109,7 +112,7 @@ class TestReadTools:
assert result["wings"] == {} assert result["wings"] == {}
def test_status_with_data(self, monkeypatch, config, palace_path, seeded_collection, kg): def test_status_with_data(self, monkeypatch, config, palace_path, seeded_collection, kg):
_patch_mcp_server(monkeypatch, config, palace_path, kg) _patch_mcp_server(monkeypatch, config, kg)
from mempalace.mcp_server import tool_status from mempalace.mcp_server import tool_status
result = tool_status() result = tool_status()
@@ -118,7 +121,7 @@ class TestReadTools:
assert "notes" in result["wings"] assert "notes" in result["wings"]
def test_list_wings(self, monkeypatch, config, palace_path, seeded_collection, kg): def test_list_wings(self, monkeypatch, config, palace_path, seeded_collection, kg):
_patch_mcp_server(monkeypatch, config, palace_path, kg) _patch_mcp_server(monkeypatch, config, kg)
from mempalace.mcp_server import tool_list_wings from mempalace.mcp_server import tool_list_wings
result = tool_list_wings() result = tool_list_wings()
@@ -126,7 +129,7 @@ class TestReadTools:
assert result["wings"]["notes"] == 1 assert result["wings"]["notes"] == 1
def test_list_rooms_all(self, monkeypatch, config, palace_path, seeded_collection, kg): def test_list_rooms_all(self, monkeypatch, config, palace_path, seeded_collection, kg):
_patch_mcp_server(monkeypatch, config, palace_path, kg) _patch_mcp_server(monkeypatch, config, kg)
from mempalace.mcp_server import tool_list_rooms from mempalace.mcp_server import tool_list_rooms
result = tool_list_rooms() result = tool_list_rooms()
@@ -135,7 +138,7 @@ class TestReadTools:
assert "planning" in result["rooms"] assert "planning" in result["rooms"]
def test_list_rooms_filtered(self, monkeypatch, config, palace_path, seeded_collection, kg): def test_list_rooms_filtered(self, monkeypatch, config, palace_path, seeded_collection, kg):
_patch_mcp_server(monkeypatch, config, palace_path, kg) _patch_mcp_server(monkeypatch, config, kg)
from mempalace.mcp_server import tool_list_rooms from mempalace.mcp_server import tool_list_rooms
result = tool_list_rooms(wing="project") result = tool_list_rooms(wing="project")
@@ -143,7 +146,7 @@ class TestReadTools:
assert "planning" not in result["rooms"] assert "planning" not in result["rooms"]
def test_get_taxonomy(self, monkeypatch, config, palace_path, seeded_collection, kg): def test_get_taxonomy(self, monkeypatch, config, palace_path, seeded_collection, kg):
_patch_mcp_server(monkeypatch, config, palace_path, kg) _patch_mcp_server(monkeypatch, config, kg)
from mempalace.mcp_server import tool_get_taxonomy from mempalace.mcp_server import tool_get_taxonomy
result = tool_get_taxonomy() result = tool_get_taxonomy()
@@ -152,8 +155,7 @@ class TestReadTools:
assert result["taxonomy"]["notes"]["planning"] == 1 assert result["taxonomy"]["notes"]["planning"] == 1
def test_no_palace_returns_error(self, monkeypatch, config, kg): def test_no_palace_returns_error(self, monkeypatch, config, kg):
config._file_config["palace_path"] = "/nonexistent/path" _patch_mcp_server(monkeypatch, config, kg)
_patch_mcp_server(monkeypatch, config, "/nonexistent/path", kg)
from mempalace.mcp_server import tool_status from mempalace.mcp_server import tool_status
result = tool_status() result = tool_status()
@@ -165,7 +167,7 @@ class TestReadTools:
class TestSearchTool: class TestSearchTool:
def test_search_basic(self, monkeypatch, config, palace_path, seeded_collection, kg): def test_search_basic(self, monkeypatch, config, palace_path, seeded_collection, kg):
_patch_mcp_server(monkeypatch, config, palace_path, kg) _patch_mcp_server(monkeypatch, config, kg)
from mempalace.mcp_server import tool_search from mempalace.mcp_server import tool_search
result = tool_search(query="JWT authentication tokens") result = tool_search(query="JWT authentication tokens")
@@ -176,14 +178,14 @@ class TestSearchTool:
assert "JWT" in top["text"] or "authentication" in top["text"].lower() assert "JWT" in top["text"] or "authentication" in top["text"].lower()
def test_search_with_wing_filter(self, monkeypatch, config, palace_path, seeded_collection, kg): def test_search_with_wing_filter(self, monkeypatch, config, palace_path, seeded_collection, kg):
_patch_mcp_server(monkeypatch, config, palace_path, kg) _patch_mcp_server(monkeypatch, config, kg)
from mempalace.mcp_server import tool_search from mempalace.mcp_server import tool_search
result = tool_search(query="planning", wing="notes") result = tool_search(query="planning", wing="notes")
assert all(r["wing"] == "notes" for r in result["results"]) assert all(r["wing"] == "notes" for r in result["results"])
def test_search_with_room_filter(self, monkeypatch, config, palace_path, seeded_collection, kg): def test_search_with_room_filter(self, monkeypatch, config, palace_path, seeded_collection, kg):
_patch_mcp_server(monkeypatch, config, palace_path, kg) _patch_mcp_server(monkeypatch, config, kg)
from mempalace.mcp_server import tool_search from mempalace.mcp_server import tool_search
result = tool_search(query="database", room="backend") result = tool_search(query="database", room="backend")
@@ -195,8 +197,9 @@ class TestSearchTool:
class TestWriteTools: class TestWriteTools:
def test_add_drawer(self, monkeypatch, config, palace_path, kg): def test_add_drawer(self, monkeypatch, config, palace_path, kg):
_patch_mcp_server(monkeypatch, config, palace_path, kg) _patch_mcp_server(monkeypatch, config, kg)
_get_collection(palace_path, create=True) _client, _col = _get_collection(palace_path, create=True)
del _client
from mempalace.mcp_server import tool_add_drawer from mempalace.mcp_server import tool_add_drawer
result = tool_add_drawer( result = tool_add_drawer(
@@ -210,8 +213,9 @@ class TestWriteTools:
assert result["drawer_id"].startswith("drawer_test_wing_test_room_") assert result["drawer_id"].startswith("drawer_test_wing_test_room_")
def test_add_drawer_duplicate_detection(self, monkeypatch, config, palace_path, kg): def test_add_drawer_duplicate_detection(self, monkeypatch, config, palace_path, kg):
_patch_mcp_server(monkeypatch, config, palace_path, kg) _patch_mcp_server(monkeypatch, config, kg)
_get_collection(palace_path, create=True) _client, _col = _get_collection(palace_path, create=True)
del _client
from mempalace.mcp_server import tool_add_drawer from mempalace.mcp_server import tool_add_drawer
content = "This is a unique test memory about Rust ownership and borrowing." content = "This is a unique test memory about Rust ownership and borrowing."
@@ -219,11 +223,11 @@ class TestWriteTools:
assert result1["success"] is True assert result1["success"] is True
result2 = tool_add_drawer(wing="w", room="r", content=content) result2 = tool_add_drawer(wing="w", room="r", content=content)
assert result2["success"] is False assert result2["success"] is True
assert result2["reason"] == "duplicate" assert result2["reason"] == "already_exists"
def test_delete_drawer(self, monkeypatch, config, palace_path, seeded_collection, kg): def test_delete_drawer(self, monkeypatch, config, palace_path, seeded_collection, kg):
_patch_mcp_server(monkeypatch, config, palace_path, kg) _patch_mcp_server(monkeypatch, config, kg)
from mempalace.mcp_server import tool_delete_drawer from mempalace.mcp_server import tool_delete_drawer
result = tool_delete_drawer("drawer_proj_backend_aaa") result = tool_delete_drawer("drawer_proj_backend_aaa")
@@ -231,14 +235,14 @@ class TestWriteTools:
assert seeded_collection.count() == 3 assert seeded_collection.count() == 3
def test_delete_drawer_not_found(self, monkeypatch, config, palace_path, seeded_collection, kg): def test_delete_drawer_not_found(self, monkeypatch, config, palace_path, seeded_collection, kg):
_patch_mcp_server(monkeypatch, config, palace_path, kg) _patch_mcp_server(monkeypatch, config, kg)
from mempalace.mcp_server import tool_delete_drawer from mempalace.mcp_server import tool_delete_drawer
result = tool_delete_drawer("nonexistent_drawer") result = tool_delete_drawer("nonexistent_drawer")
assert result["success"] is False assert result["success"] is False
def test_check_duplicate(self, monkeypatch, config, palace_path, seeded_collection, kg): def test_check_duplicate(self, monkeypatch, config, palace_path, seeded_collection, kg):
_patch_mcp_server(monkeypatch, config, palace_path, kg) _patch_mcp_server(monkeypatch, config, kg)
from mempalace.mcp_server import tool_check_duplicate from mempalace.mcp_server import tool_check_duplicate
# Exact match text from seeded_collection should be flagged # Exact match text from seeded_collection should be flagged
@@ -262,7 +266,7 @@ class TestWriteTools:
class TestKGTools: class TestKGTools:
def test_kg_add(self, monkeypatch, config, palace_path, kg): def test_kg_add(self, monkeypatch, config, palace_path, kg):
_patch_mcp_server(monkeypatch, config, palace_path, kg) _patch_mcp_server(monkeypatch, config, kg)
from mempalace.mcp_server import tool_kg_add from mempalace.mcp_server import tool_kg_add
result = tool_kg_add( result = tool_kg_add(
@@ -274,14 +278,14 @@ class TestKGTools:
assert result["success"] is True assert result["success"] is True
def test_kg_query(self, monkeypatch, config, palace_path, seeded_kg): def test_kg_query(self, monkeypatch, config, palace_path, seeded_kg):
_patch_mcp_server(monkeypatch, config, palace_path, seeded_kg) _patch_mcp_server(monkeypatch, config, seeded_kg)
from mempalace.mcp_server import tool_kg_query from mempalace.mcp_server import tool_kg_query
result = tool_kg_query(entity="Max") result = tool_kg_query(entity="Max")
assert result["count"] > 0 assert result["count"] > 0
def test_kg_invalidate(self, monkeypatch, config, palace_path, seeded_kg): def test_kg_invalidate(self, monkeypatch, config, palace_path, seeded_kg):
_patch_mcp_server(monkeypatch, config, palace_path, seeded_kg) _patch_mcp_server(monkeypatch, config, seeded_kg)
from mempalace.mcp_server import tool_kg_invalidate from mempalace.mcp_server import tool_kg_invalidate
result = tool_kg_invalidate( result = tool_kg_invalidate(
@@ -293,14 +297,14 @@ class TestKGTools:
assert result["success"] is True assert result["success"] is True
def test_kg_timeline(self, monkeypatch, config, palace_path, seeded_kg): def test_kg_timeline(self, monkeypatch, config, palace_path, seeded_kg):
_patch_mcp_server(monkeypatch, config, palace_path, seeded_kg) _patch_mcp_server(monkeypatch, config, seeded_kg)
from mempalace.mcp_server import tool_kg_timeline from mempalace.mcp_server import tool_kg_timeline
result = tool_kg_timeline(entity="Alice") result = tool_kg_timeline(entity="Alice")
assert result["count"] > 0 assert result["count"] > 0
def test_kg_stats(self, monkeypatch, config, palace_path, seeded_kg): def test_kg_stats(self, monkeypatch, config, palace_path, seeded_kg):
_patch_mcp_server(monkeypatch, config, palace_path, seeded_kg) _patch_mcp_server(monkeypatch, config, seeded_kg)
from mempalace.mcp_server import tool_kg_stats from mempalace.mcp_server import tool_kg_stats
result = tool_kg_stats() result = tool_kg_stats()
@@ -312,8 +316,9 @@ class TestKGTools:
class TestDiaryTools: class TestDiaryTools:
def test_diary_write_and_read(self, monkeypatch, config, palace_path, kg): def test_diary_write_and_read(self, monkeypatch, config, palace_path, kg):
_patch_mcp_server(monkeypatch, config, palace_path, kg) _patch_mcp_server(monkeypatch, config, kg)
_get_collection(palace_path, create=True) _client, _col = _get_collection(palace_path, create=True)
del _client
from mempalace.mcp_server import tool_diary_write, tool_diary_read from mempalace.mcp_server import tool_diary_write, tool_diary_read
w = tool_diary_write( w = tool_diary_write(
@@ -330,8 +335,9 @@ class TestDiaryTools:
assert "authentication" in r["entries"][0]["content"] assert "authentication" in r["entries"][0]["content"]
def test_diary_read_empty(self, monkeypatch, config, palace_path, kg): def test_diary_read_empty(self, monkeypatch, config, palace_path, kg):
_patch_mcp_server(monkeypatch, config, palace_path, kg) _patch_mcp_server(monkeypatch, config, kg)
_get_collection(palace_path, create=True) _client, _col = _get_collection(palace_path, create=True)
del _client
from mempalace.mcp_server import tool_diary_read from mempalace.mcp_server import tool_diary_read
r = tool_diary_read(agent_name="Nobody") r = tool_diary_read(agent_name="Nobody")
+1 -1
View File
@@ -47,7 +47,7 @@ def test_project_mining():
col = client.get_collection("mempalace_drawers") col = client.get_collection("mempalace_drawers")
assert col.count() > 0 assert col.count() > 0
finally: finally:
shutil.rmtree(tmpdir) shutil.rmtree(tmpdir, ignore_errors=True)
def test_scan_project_respects_gitignore(): def test_scan_project_respects_gitignore():
+490 -20
View File
@@ -1,31 +1,501 @@
import os
import json import json
import tempfile from unittest.mock import patch
from mempalace.normalize import normalize
from mempalace.normalize import (
_extract_content,
_messages_to_transcript,
_try_chatgpt_json,
_try_claude_ai_json,
_try_claude_code_jsonl,
_try_codex_jsonl,
_try_normalize_json,
_try_slack_json,
normalize,
)
def test_plain_text(): # ── normalize() top-level ──────────────────────────────────────────────
f = tempfile.NamedTemporaryFile(mode="w", suffix=".txt", delete=False)
f.write("Hello world\nSecond line\n")
f.close() def test_plain_text(tmp_path):
result = normalize(f.name) f = tmp_path / "plain.txt"
f.write_text("Hello world\nSecond line\n")
result = normalize(str(f))
assert "Hello world" in result assert "Hello world" in result
os.unlink(f.name)
def test_claude_json(): def test_claude_json(tmp_path):
data = [{"role": "user", "content": "Hi"}, {"role": "assistant", "content": "Hello"}] data = [{"role": "user", "content": "Hi"}, {"role": "assistant", "content": "Hello"}]
f = tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) f = tmp_path / "claude.json"
json.dump(data, f) f.write_text(json.dumps(data))
f.close() result = normalize(str(f))
result = normalize(f.name)
assert "Hi" in result assert "Hi" in result
os.unlink(f.name)
def test_empty(): def test_empty(tmp_path):
f = tempfile.NamedTemporaryFile(mode="w", suffix=".txt", delete=False) f = tmp_path / "empty.txt"
f.close() f.write_text("")
result = normalize(f.name) result = normalize(str(f))
assert result.strip() == "" assert result.strip() == ""
os.unlink(f.name)
def test_normalize_io_error():
    """normalize raises IOError mentioning 'Could not read' for a missing file."""
    try:
        normalize("/nonexistent/path/file.txt")
    except IOError as e:
        assert "Could not read" in str(e)
    else:
        # Raise explicitly: `assert False` is removed when Python runs with -O,
        # which would let a missing exception go unnoticed.
        raise AssertionError("Should have raised")
def test_normalize_already_has_markers(tmp_path):
    """A file that already holds three '>' lines is returned verbatim."""
    expected = "> question 1\nanswer 1\n> question 2\nanswer 2\n> question 3\nanswer 3\n"
    path = tmp_path / "markers.txt"
    path.write_text(expected)
    assert normalize(str(path)) == expected
def test_normalize_json_content_detected_by_brace(tmp_path):
    """JSON parsing kicks in for a .txt file whose content starts with '['."""
    messages = [
        {"role": "user", "content": "Hey"},
        {"role": "assistant", "content": "Hi there"},
    ]
    path = tmp_path / "chat.txt"
    path.write_text(json.dumps(messages))
    transcript = normalize(str(path))
    assert "Hey" in transcript
def test_normalize_whitespace_only(tmp_path):
    """A whitespace-only file normalizes to nothing but whitespace."""
    path = tmp_path / "ws.txt"
    path.write_text(" \n \n ")
    normalized = normalize(str(path))
    assert normalized.strip() == ""
# ── _extract_content ───────────────────────────────────────────────────
def test_extract_content_string():
    """A plain string comes back unchanged."""
    extracted = _extract_content("hello")
    assert extracted == "hello"
def test_extract_content_list_of_strings():
    """A list of strings is joined with single spaces."""
    parts = ["hello", "world"]
    extracted = _extract_content(parts)
    assert extracted == "hello world"
def test_extract_content_list_of_blocks():
    """Only the text block contributes; the image block is left out."""
    text_block = {"type": "text", "text": "hello"}
    image_block = {"type": "image", "url": "x"}
    extracted = _extract_content([text_block, image_block])
    assert extracted == "hello"
def test_extract_content_dict():
    """A dict with a 'text' key yields that text."""
    extracted = _extract_content({"text": "hello"})
    assert extracted == "hello"
def test_extract_content_none():
    """None is mapped to the empty string."""
    extracted = _extract_content(None)
    assert extracted == ""
def test_extract_content_mixed_list():
    """Plain strings and text blocks in one list are combined in order."""
    mixed = ["plain", {"type": "text", "text": "block"}]
    extracted = _extract_content(mixed)
    assert extracted == "plain block"
# ── _try_claude_code_jsonl ─────────────────────────────────────────────
def test_claude_code_jsonl_valid():
    """A human/assistant pair becomes a transcript with a '>'-prefixed question."""
    records = [
        {"type": "human", "message": {"content": "What is X?"}},
        {"type": "assistant", "message": {"content": "X is Y."}},
    ]
    payload = "\n".join(json.dumps(record) for record in records)
    transcript = _try_claude_code_jsonl(payload)
    assert transcript is not None
    assert "> What is X?" in transcript
    assert "X is Y." in transcript
def test_claude_code_jsonl_user_type():
    """Entries typed 'user' are rendered as '>'-prefixed turns."""
    question = json.dumps({"type": "user", "message": {"content": "Q"}})
    answer = json.dumps({"type": "assistant", "message": {"content": "A"}})
    transcript = _try_claude_code_jsonl(question + "\n" + answer)
    assert transcript is not None
    assert "> Q" in transcript
def test_claude_code_jsonl_too_few_messages():
    """A single message is rejected (parser returns None)."""
    payload = json.dumps({"type": "human", "message": {"content": "only one"}})
    assert _try_claude_code_jsonl(payload) is None
def test_claude_code_jsonl_invalid_json_lines():
    """A non-JSON line ahead of a valid pair does not stop parsing."""
    valid_records = [
        {"type": "human", "message": {"content": "Q"}},
        {"type": "assistant", "message": {"content": "A"}},
    ]
    payload = "\n".join(["not json", *(json.dumps(r) for r in valid_records)])
    transcript = _try_claude_code_jsonl(payload)
    assert transcript is not None
def test_claude_code_jsonl_non_dict_entries():
    """A JSON line that is a list rather than an object is tolerated."""
    records = [
        [1, 2, 3],
        {"type": "human", "message": {"content": "Q"}},
        {"type": "assistant", "message": {"content": "A"}},
    ]
    payload = "\n".join(json.dumps(record) for record in records)
    transcript = _try_claude_code_jsonl(payload)
    assert transcript is not None
# ── _try_codex_jsonl ───────────────────────────────────────────────────
def test_codex_jsonl_valid():
    """session_meta plus a user/agent event pair yields a '>'-prefixed question."""
    events = [
        {"type": "session_meta", "payload": {}},
        {"type": "event_msg", "payload": {"type": "user_message", "message": "Q"}},
        {"type": "event_msg", "payload": {"type": "agent_message", "message": "A"}},
    ]
    payload = "\n".join(json.dumps(event) for event in events)
    transcript = _try_codex_jsonl(payload)
    assert transcript is not None
    assert "> Q" in transcript
def test_codex_jsonl_no_session_meta():
    """The codex parser gives up (None) when no session_meta entry is present."""
    events = [
        {"type": "event_msg", "payload": {"type": "user_message", "message": "Q"}},
        {"type": "event_msg", "payload": {"type": "agent_message", "message": "A"}},
    ]
    payload = "\n".join(json.dumps(event) for event in events)
    assert _try_codex_jsonl(payload) is None
def test_codex_jsonl_skips_non_event_msg():
    """Text from a response_item entry does not appear ahead of the first question."""
    events = [
        {"type": "session_meta"},
        {"type": "response_item", "payload": {"type": "user_message", "message": "X"}},
        {"type": "event_msg", "payload": {"type": "user_message", "message": "Q"}},
        {"type": "event_msg", "payload": {"type": "agent_message", "message": "A"}},
    ]
    payload = "\n".join(json.dumps(event) for event in events)
    transcript = _try_codex_jsonl(payload)
    assert transcript is not None
    # Everything before the first "> Q" marker must be free of the skipped text.
    before_question = transcript.partition("> Q")[0]
    assert "X" not in before_question
def test_codex_jsonl_non_string_message():
    """An integer message value does not break parsing of the remaining events."""
    events = [
        {"type": "session_meta"},
        {"type": "event_msg", "payload": {"type": "user_message", "message": 123}},
        {"type": "event_msg", "payload": {"type": "user_message", "message": "Q"}},
        {"type": "event_msg", "payload": {"type": "agent_message", "message": "A"}},
    ]
    payload = "\n".join(json.dumps(event) for event in events)
    transcript = _try_codex_jsonl(payload)
    assert transcript is not None
def test_codex_jsonl_empty_text_skipped():
    """A blank user message alongside a real pair still produces a transcript."""
    events = [
        {"type": "session_meta"},
        {"type": "event_msg", "payload": {"type": "user_message", "message": " "}},
        {"type": "event_msg", "payload": {"type": "user_message", "message": "Q"}},
        {"type": "event_msg", "payload": {"type": "agent_message", "message": "A"}},
    ]
    payload = "\n".join(json.dumps(event) for event in events)
    transcript = _try_codex_jsonl(payload)
    assert transcript is not None
def test_codex_jsonl_payload_not_dict():
    """An event whose payload is a string is tolerated."""
    events = [
        {"type": "session_meta"},
        {"type": "event_msg", "payload": "not a dict"},
        {"type": "event_msg", "payload": {"type": "user_message", "message": "Q"}},
        {"type": "event_msg", "payload": {"type": "agent_message", "message": "A"}},
    ]
    payload = "\n".join(json.dumps(event) for event in events)
    transcript = _try_codex_jsonl(payload)
    assert transcript is not None
# ── _try_claude_ai_json ───────────────────────────────────────────────
def test_claude_ai_flat_messages():
    """A flat role/content list renders the user turn with a '>' prefix."""
    user_turn = {"role": "user", "content": "Hello"}
    assistant_turn = {"role": "assistant", "content": "Hi there"}
    transcript = _try_claude_ai_json([user_turn, assistant_turn])
    assert transcript is not None
    assert "> Hello" in transcript
def test_claude_ai_dict_with_messages_key():
    """A dict wrapping the turns under 'messages' is accepted."""
    turns = [
        {"role": "user", "content": "Hello"},
        {"role": "assistant", "content": "Hi"},
    ]
    transcript = _try_claude_ai_json({"messages": turns})
    assert transcript is not None
def test_claude_ai_privacy_export():
    """A list of conversations with 'chat_messages' renders human turns with '>'."""
    conversation = {
        "chat_messages": [
            {"role": "human", "content": "Q1"},
            {"role": "ai", "content": "A1"},
        ]
    }
    transcript = _try_claude_ai_json([conversation])
    assert transcript is not None
    assert "> Q1" in transcript
def test_claude_ai_not_a_list():
    """A bare string input is rejected with None."""
    assert _try_claude_ai_json("not a list") is None
def test_claude_ai_too_few_messages():
    """A lone user message is rejected with None."""
    lone_message = {"role": "user", "content": "Hello"}
    assert _try_claude_ai_json([lone_message]) is None
def test_claude_ai_dict_with_chat_messages_key():
    """A dict wrapping the turns under 'chat_messages' is accepted."""
    turns = [
        {"role": "user", "content": "Hello"},
        {"role": "assistant", "content": "World"},
    ]
    transcript = _try_claude_ai_json({"chat_messages": turns})
    assert transcript is not None
def test_claude_ai_privacy_export_non_dict_items():
    """Stray non-dict messages and conversations do not prevent a transcript."""
    conversation = {
        "chat_messages": [
            "not a dict",
            {"role": "user", "content": "Q"},
            {"role": "assistant", "content": "A"},
        ]
    }
    export = [conversation, "not a convo"]
    transcript = _try_claude_ai_json(export)
    assert transcript is not None
# ── _try_chatgpt_json ─────────────────────────────────────────────────
def test_chatgpt_json_valid():
    """A root -> user -> assistant mapping yields a '>'-prefixed user turn."""
    user_message = {
        "author": {"role": "user"},
        "content": {"parts": ["Hello ChatGPT"]},
    }
    assistant_message = {
        "author": {"role": "assistant"},
        "content": {"parts": ["Hello! How can I help?"]},
    }
    # Nodes are inserted in tree order: synthetic root, then each reply.
    mapping = {}
    mapping["root"] = {"parent": None, "message": None, "children": ["msg1"]}
    mapping["msg1"] = {"parent": "root", "message": user_message, "children": ["msg2"]}
    mapping["msg2"] = {"parent": "msg1", "message": assistant_message, "children": []}
    transcript = _try_chatgpt_json({"mapping": mapping})
    assert transcript is not None
    assert "> Hello ChatGPT" in transcript
def test_chatgpt_json_no_mapping():
    """A dict without a 'mapping' key is rejected with None."""
    assert _try_chatgpt_json({"data": []}) is None
def test_chatgpt_json_not_dict():
    """A list input is rejected with None."""
    assert _try_chatgpt_json([1, 2, 3]) is None
def test_chatgpt_json_fallback_root():
    """A mapping whose root node carries a system message still yields a transcript."""
    system_message = {
        "author": {"role": "system"},
        "content": {"parts": ["system prompt"]},
    }
    user_message = {
        "author": {"role": "user"},
        "content": {"parts": ["Hello"]},
    }
    assistant_message = {
        "author": {"role": "assistant"},
        "content": {"parts": ["Hi there"]},
    }
    # Unlike the synthetic-root case, the parentless node has a real message.
    mapping = {}
    mapping["root"] = {"parent": None, "message": system_message, "children": ["msg1"]}
    mapping["msg1"] = {"parent": "root", "message": user_message, "children": ["msg2"]}
    mapping["msg2"] = {"parent": "msg1", "message": assistant_message, "children": []}
    transcript = _try_chatgpt_json({"mapping": mapping})
    assert transcript is not None
def test_chatgpt_json_too_few_messages():
    """A mapping holding a single user message is rejected with None."""
    only_message = {
        "author": {"role": "user"},
        "content": {"parts": ["Only one"]},
    }
    mapping = {}
    mapping["root"] = {"parent": None, "message": None, "children": ["msg1"]}
    mapping["msg1"] = {"parent": "root", "message": only_message, "children": []}
    assert _try_chatgpt_json({"mapping": mapping}) is None
# ── _try_slack_json ────────────────────────────────────────────────────
def test_slack_json_valid():
    """Two messages from two users produce a transcript containing the text."""
    first = {"type": "message", "user": "U1", "text": "Hello"}
    second = {"type": "message", "user": "U2", "text": "Hi there"}
    transcript = _try_slack_json([first, second])
    assert transcript is not None
    assert "Hello" in transcript
def test_slack_json_not_a_list():
    """A dict input is rejected with None."""
    assert _try_slack_json({"type": "message"}) is None
def test_slack_json_too_few_messages():
    """A single message is rejected with None."""
    lone_message = {"type": "message", "user": "U1", "text": "Hello"}
    assert _try_slack_json([lone_message]) is None
def test_slack_json_skips_non_message_types():
    """A channel_join event next to two real messages still gives a transcript."""
    events = [
        {"type": "channel_join", "user": "U1", "text": "joined"},
        {"type": "message", "user": "U1", "text": "Hello"},
        {"type": "message", "user": "U2", "text": "Hi"},
    ]
    transcript = _try_slack_json(events)
    assert transcript is not None
def test_slack_json_three_users():
    """An export with three distinct speakers is still accepted."""
    greetings = (("U1", "Hello"), ("U2", "Hi"), ("U3", "Hey"))
    events = [{"type": "message", "user": user, "text": text} for user, text in greetings]
    transcript = _try_slack_json(events)
    assert transcript is not None
def test_slack_json_empty_text_skipped():
    """An empty-text message alongside two real ones still gives a transcript."""
    events = [
        {"type": "message", "user": "U1", "text": ""},
        {"type": "message", "user": "U1", "text": "Hello"},
        {"type": "message", "user": "U2", "text": "Hi"},
    ]
    transcript = _try_slack_json(events)
    assert transcript is not None
def test_slack_json_username_fallback():
    """Messages identified only by 'username' (no 'user' key) are accepted."""
    first = {"type": "message", "username": "bot1", "text": "Hello"}
    second = {"type": "message", "username": "bot2", "text": "Hi"}
    transcript = _try_slack_json([first, second])
    assert transcript is not None
# ── _try_normalize_json ────────────────────────────────────────────────
def test_try_normalize_json_invalid_json():
    """Unparseable input yields None."""
    assert _try_normalize_json("not json at all {{{") is None
def test_try_normalize_json_valid_but_unknown_schema():
    """Well-formed JSON that matches no known chat schema yields None."""
    payload = json.dumps({"random": "data"})
    assert _try_normalize_json(payload) is None
# ── _messages_to_transcript ────────────────────────────────────────────
def test_messages_to_transcript_basic():
    """A user/assistant pair renders the user turn with a '>' prefix."""
    conversation = [("user", "Q"), ("assistant", "A")]
    # Identity stand-in for the spellchecker; create=True allows the attribute
    # to be absent from the module.
    passthrough = patch(
        "mempalace.normalize.spellcheck_user_text",
        side_effect=lambda x: x,
        create=True,
    )
    with passthrough:
        transcript = _messages_to_transcript(conversation, spellcheck=False)
    assert "> Q" in transcript
    assert "A" in transcript
def test_messages_to_transcript_consecutive_users():
    """Back-to-back user messages each get their own '>' line."""
    conversation = [("user", "Q1"), ("user", "Q2"), ("assistant", "A")]
    transcript = _messages_to_transcript(conversation, spellcheck=False)
    for marker in ("> Q1", "> Q2"):
        assert marker in transcript
def test_messages_to_transcript_assistant_first():
    """An assistant message that precedes any user turn is kept in the output."""
    conversation = [("assistant", "preamble"), ("user", "Q"), ("assistant", "A")]
    transcript = _messages_to_transcript(conversation, spellcheck=False)
    assert "preamble" in transcript
    assert "> Q" in transcript
+452
View File
@@ -0,0 +1,452 @@
"""Tests for mempalace.onboarding."""
import os
from unittest.mock import patch
from mempalace.onboarding import (
DEFAULT_WINGS,
_ask,
_ask_mode,
_ask_people,
_ask_projects,
_ask_wings,
_auto_detect,
_generate_aaak_bootstrap,
_header,
_hr,
_warn_ambiguous,
_yn,
quick_setup,
run_onboarding,
)
# Force UTF-8 for Windows (source file contains Unicode symbols like hearts/stars)
# NOTE(review): CPython reads PYTHONUTF8 at interpreter startup, so assigning it
# here presumably does not switch the already-running test process into UTF-8
# mode; it is only inherited by child processes started afterwards. Confirm
# that this is the intended effect.
os.environ["PYTHONUTF8"] = "1"
# ── DEFAULT_WINGS ───────────────────────────────────────────────────────
def test_default_wings_has_expected_keys():
    """DEFAULT_WINGS defines the work, personal and combo modes."""
    for mode in ("work", "personal", "combo"):
        assert mode in DEFAULT_WINGS
def test_default_wings_work_has_projects():
    """The work mode includes a 'projects' wing."""
    work_wings = DEFAULT_WINGS["work"]
    assert "projects" in work_wings
def test_default_wings_personal_has_family():
    """The personal mode includes a 'family' wing."""
    personal_wings = DEFAULT_WINGS["personal"]
    assert "family" in personal_wings
def test_default_wings_combo_has_both():
    """The combo mode includes both the 'family' and the 'work' wing."""
    combo_wings = DEFAULT_WINGS["combo"]
    for wing in ("family", "work"):
        assert wing in combo_wings
def test_default_wings_values_are_lists():
    """Every mode maps to a list holding at least three wings."""
    for mode in DEFAULT_WINGS:
        wings = DEFAULT_WINGS[mode]
        assert isinstance(wings, list), f"{mode} wings should be a list"
        assert len(wings) >= 3, f"{mode} should have at least 3 wings"
# ── _warn_ambiguous ─────────────────────────────────────────────────────
def test_warn_ambiguous_flags_common_words():
    """'Grace' is reported as ambiguous while 'Riley' is not."""
    grace = {"name": "Grace", "relationship": "friend"}
    riley = {"name": "Riley", "relationship": "daughter"}
    flagged = _warn_ambiguous([grace, riley])
    assert "Grace" in flagged
    # Riley is not a common English word
    assert "Riley" not in flagged
def test_warn_ambiguous_empty_list():
    """No people in, empty list out."""
    assert _warn_ambiguous([]) == []
def test_warn_ambiguous_no_ambiguous_names():
    """Neither 'Riley' nor 'Devon' is flagged, so the result is empty."""
    riley = {"name": "Riley", "relationship": "daughter"}
    devon = {"name": "Devon", "relationship": "friend"}
    flagged = _warn_ambiguous([riley, devon])
    assert flagged == []
def test_warn_ambiguous_multiple_hits():
    """'Grace', 'May' and 'Joy' are all reported when given together."""
    relationships = (("Grace", "friend"), ("May", "aunt"), ("Joy", "sister"))
    people = [{"name": name, "relationship": rel} for name, rel in relationships]
    flagged = _warn_ambiguous(people)
    for name in ("Grace", "May", "Joy"):
        assert name in flagged
# ── quick_setup ─────────────────────────────────────────────────────────
def test_quick_setup_creates_registry(tmp_path):
    """The returned registry carries the given person, project and mode."""
    riley = {"name": "Riley", "relationship": "daughter", "context": "personal"}
    registry = quick_setup(
        mode="personal",
        people=[riley],
        projects=["MemPalace"],
        config_dir=tmp_path,
    )
    assert "Riley" in registry.people
    assert "MemPalace" in registry.projects
    assert registry.mode == "personal"
def test_quick_setup_work_mode(tmp_path):
    """Work mode is recorded on the registry along with its person and project."""
    alice = {"name": "Alice", "relationship": "colleague", "context": "work"}
    registry = quick_setup(
        mode="work",
        people=[alice],
        projects=["Acme"],
        config_dir=tmp_path,
    )
    assert registry.mode == "work"
    assert "Alice" in registry.people
    assert "Acme" in registry.projects
def test_quick_setup_empty(tmp_path):
    """With no people and no projects argument, both registry collections are empty."""
    registry = quick_setup(mode="personal", people=[], config_dir=tmp_path)
    people_count = len(registry.people)
    project_count = len(registry.projects)
    assert people_count == 0
    assert project_count == 0
def test_quick_setup_saves_to_disk(tmp_path):
    """quick_setup leaves an entity_registry.json file in config_dir."""
    riley = {"name": "Riley", "relationship": "daughter", "context": "personal"}
    quick_setup(mode="personal", people=[riley], config_dir=tmp_path)
    registry_file = tmp_path / "entity_registry.json"
    assert registry_file.exists()
# ── _generate_aaak_bootstrap ───────────────────────────────────────────
def test_generate_aaak_bootstrap_creates_files(tmp_path):
    """Both bootstrap markdown files are created in config_dir."""
    people = [
        {"name": "Riley", "relationship": "daughter", "context": "personal"},
        {"name": "Devon", "relationship": "friend", "context": "personal"},
    ]
    _generate_aaak_bootstrap(
        people, ["MemPalace"], ["family", "creative"], "personal", config_dir=tmp_path
    )
    for filename in ("aaak_entities.md", "critical_facts.md"):
        assert (tmp_path / filename).exists()
def test_generate_aaak_bootstrap_entities_content(tmp_path):
    """aaak_entities.md mentions the person, a 'RIL' code and the project."""
    riley = {"name": "Riley", "relationship": "daughter", "context": "personal"}
    _generate_aaak_bootstrap([riley], ["MemPalace"], ["family"], "personal", config_dir=tmp_path)
    entities_text = (tmp_path / "aaak_entities.md").read_text()
    assert "Riley" in entities_text
    assert "RIL" in entities_text  # entity code
    assert "MemPalace" in entities_text
def test_generate_aaak_bootstrap_facts_content(tmp_path):
    """The critical-facts file mentions the person, the project and the word 'work'."""
    alice = {"name": "Alice", "relationship": "colleague", "context": "work"}
    _generate_aaak_bootstrap([alice], ["Acme"], ["projects"], "work", config_dir=tmp_path)

    text = (tmp_path / "critical_facts.md").read_text()
    assert "Alice" in text
    assert "Acme" in text
    assert "work" in text.lower()
def test_generate_aaak_bootstrap_empty_people(tmp_path):
    """Both bootstrap files are still written when there are no people or projects."""
    _generate_aaak_bootstrap([], [], ["general"], "personal", config_dir=tmp_path)

    entities_file = tmp_path / "aaak_entities.md"
    assert entities_file.exists()
    facts_file = tmp_path / "critical_facts.md"
    assert facts_file.exists()
def test_generate_aaak_bootstrap_collision(tmp_path):
    """Names sharing a three-letter prefix produce both an 'ALI' and an 'ALIS' code."""
    clashing = [
        {"name": "Alice", "relationship": "friend", "context": "work"},
        {"name": "Alison", "relationship": "coworker", "context": "work"},
    ]
    _generate_aaak_bootstrap(clashing, [], ["work"], "work", config_dir=tmp_path)

    text = (tmp_path / "aaak_entities.md").read_text()
    assert "ALI" in text
    assert "ALIS" in text
def test_generate_aaak_bootstrap_no_relationship(tmp_path):
    """A person dict lacking 'relationship' still yields a 'BOB=Bob' entry."""
    bob = {"name": "Bob", "context": "work"}
    _generate_aaak_bootstrap([bob], [], ["work"], "work", config_dir=tmp_path)

    text = (tmp_path / "aaak_entities.md").read_text()
    assert "BOB=Bob" in text
# ── _hr, _header ──────────────────────────────────────────────────────
def test_hr_prints_line(capsys):
    """_hr writes visible (non-whitespace) output to stdout."""
    _hr()
    out = capsys.readouterr().out
    # An empty-string membership check is true for any string, so it could
    # never fail; require that something visible was actually printed.
    assert out.strip()
def test_header_prints_banner(capsys):
    """_header prints the given title together with '=' characters."""
    _header("Test Title")
    printed = capsys.readouterr().out

    assert "Test Title" in printed
    assert "=" in printed
# ── _ask ──────────────────────────────────────────────────────────────
def test_ask_with_default_uses_default():
    """Empty input makes _ask return the supplied default."""
    with patch("builtins.input", return_value=""):
        answer = _ask("prompt", default="fallback")

    assert answer == "fallback"
def test_ask_with_default_uses_input():
    """Non-empty input is returned by _ask even when a default is supplied."""
    with patch("builtins.input", return_value="custom"):
        answer = _ask("prompt", default="fallback")

    assert answer == "custom"
def test_ask_no_default():
    """Without a default, _ask returns what the user typed."""
    with patch("builtins.input", return_value="answer"):
        answer = _ask("prompt")

    assert answer == "answer"
# ── _yn ───────────────────────────────────────────────────────────────
def test_yn_default_yes_empty_input():
    """Empty input with no explicit default makes _yn return True."""
    with patch("builtins.input", return_value=""):
        decision = _yn("continue?")

    assert decision is True
def test_yn_default_no_empty_input():
    """Empty input with default='n' makes _yn return False."""
    with patch("builtins.input", return_value=""):
        decision = _yn("continue?", default="n")

    assert decision is False
def test_yn_explicit_yes():
    """Typing 'yes' returns True even though the default is 'n'."""
    with patch("builtins.input", return_value="yes"):
        decision = _yn("continue?", default="n")

    assert decision is True
def test_yn_explicit_no():
    """Typing 'no' returns False when no default is passed."""
    with patch("builtins.input", return_value="no"):
        decision = _yn("continue?")

    assert decision is False
# ── _ask_mode ─────────────────────────────────────────────────────────
def test_ask_mode_work():
    """Entering '1' selects the 'work' mode."""
    with patch("builtins.input", return_value="1"):
        chosen = _ask_mode()

    assert chosen == "work"
def test_ask_mode_personal():
    """Entering '2' selects the 'personal' mode."""
    with patch("builtins.input", return_value="2"):
        chosen = _ask_mode()

    assert chosen == "personal"
def test_ask_mode_combo():
    """Entering '3' selects the 'combo' mode."""
    with patch("builtins.input", return_value="3"):
        chosen = _ask_mode()

    assert chosen == "combo"
def test_ask_mode_retries_on_bad_input():
    """Two unrecognised answers followed by '1' still end in 'work'."""
    keystrokes = ["x", "bad", "1"]
    with patch("builtins.input", side_effect=keystrokes):
        chosen = _ask_mode()

    assert chosen == "work"
# ── _ask_people ───────────────────────────────────────────────────────
def test_ask_people_personal_mode():
    """A 'name, relationship' entry in personal mode yields one person record."""
    typed = ["Alice, daughter", "", "done"]
    with patch("builtins.input", side_effect=typed):
        people, _ = _ask_people("personal")

    assert len(people) == 1
    first = people[0]
    assert first["name"] == "Alice"
    assert first["relationship"] == "daughter"
def test_ask_people_work_mode():
    """In work mode the collected person is tagged with context 'work'."""
    typed = ["Bob, manager", "", "done"]
    with patch("builtins.input", side_effect=typed):
        people, _ = _ask_people("work")

    assert len(people) == 1
    first = people[0]
    assert first["name"] == "Bob"
    assert first["context"] == "work"
def test_ask_people_combo_mode():
    """Combo mode collects one person from each of the two prompt rounds."""
    personal_round = ["Alice, daughter", "", "done"]  # personal
    work_round = ["Bob, boss", "done"]  # work
    with patch("builtins.input", side_effect=personal_round + work_round):
        people, _ = _ask_people("combo")

    assert len(people) == 2
def test_ask_people_with_nickname():
    """A nickname answer is returned as an alias mapping to the full name."""
    typed = ["Alice, daughter", "Ali", "done"]
    with patch("builtins.input", side_effect=typed):
        _, aliases = _ask_people("personal")

    assert aliases == {"Ali": "Alice"}
def test_ask_people_empty_name_skipped():
    """A blank entry followed by 'done' produces no people."""
    typed = ["", "done"]
    with patch("builtins.input", side_effect=typed):
        people, _ = _ask_people("personal")

    assert len(people) == 0
# ── _ask_projects ─────────────────────────────────────────────────────
def test_ask_projects_personal_returns_empty():
    """Personal mode returns an empty project list without any patched input."""
    projects = _ask_projects("personal")

    assert projects == []
def test_ask_projects_work_mode():
    """Project names typed before 'done' are returned in entry order."""
    typed = ["Acme", "BigCo", "done"]
    with patch("builtins.input", side_effect=typed):
        projects = _ask_projects("work")

    assert projects == ["Acme", "BigCo"]
def test_ask_projects_empty_entry_stops():
    """A blank entry ends the project prompt loop."""
    typed = ["Acme", ""]
    with patch("builtins.input", side_effect=typed):
        projects = _ask_projects("work")

    assert projects == ["Acme"]
# ── _ask_wings ────────────────────────────────────────────────────────
def test_ask_wings_accept_defaults():
    """Empty input returns DEFAULT_WINGS for the requested mode."""
    with patch("builtins.input", return_value=""):
        wings = _ask_wings("work")

    assert wings == DEFAULT_WINGS["work"]
def test_ask_wings_custom():
    """A comma-separated answer is split into individual wing names."""
    with patch("builtins.input", return_value="alpha, beta, gamma"):
        wings = _ask_wings("personal")

    assert wings == ["alpha", "beta", "gamma"]
# ── _auto_detect ──────────────────────────────────────────────────────
def test_auto_detect_no_files(tmp_path):
    """An empty directory with no known people yields an empty list."""
    found = _auto_detect(str(tmp_path), [])

    assert found == []
def test_auto_detect_filters_known(tmp_path):
    """People already in the known list are dropped from the detected result."""
    already_known = [{"name": "Alice"}]
    detection = {
        "people": [
            {"name": "Alice", "confidence": 0.9, "signals": ["test"]},
            {"name": "Bob", "confidence": 0.8, "signals": ["test"]},
        ],
        "projects": [],
        "uncertain": [],
    }

    with patch("mempalace.onboarding.scan_for_detection", return_value=["file.txt"]):
        with patch("mempalace.onboarding.detect_entities", return_value=detection):
            found = _auto_detect(str(tmp_path), already_known)

    found_names = [entry["name"] for entry in found]
    assert "Alice" not in found_names
    assert "Bob" in found_names
def test_auto_detect_filters_low_confidence(tmp_path):
    """A detected person with confidence 0.5 is not returned."""
    low_confidence = {"name": "Bob", "confidence": 0.5, "signals": ["test"]}
    detection = {"people": [low_confidence], "projects": [], "uncertain": []}

    with patch("mempalace.onboarding.scan_for_detection", return_value=["file.txt"]):
        with patch("mempalace.onboarding.detect_entities", return_value=detection):
            found = _auto_detect(str(tmp_path), [])

    assert len(found) == 0
def test_auto_detect_handles_exception(tmp_path):
    """An exception raised while scanning results in an empty list."""
    failure = Exception("boom")
    with patch("mempalace.onboarding.scan_for_detection", side_effect=failure):
        found = _auto_detect(str(tmp_path), [])

    assert found == []
# ── run_onboarding ────────────────────────────────────────────────────
def test_run_onboarding_basic_flow(tmp_path):
    """run_onboarding with every prompt helper patched returns a registry with Bob and Acme."""
    bob = {"name": "Bob", "relationship": "boss", "context": "work"}
    with (
        patch("mempalace.onboarding._ask_mode", return_value="work"),
        patch("mempalace.onboarding._ask_people", return_value=([bob], {})),
        patch("mempalace.onboarding._ask_projects", return_value=["Acme"]),
        patch("mempalace.onboarding._ask_wings", return_value=["projects", "team"]),
        patch("mempalace.onboarding._yn", return_value=False),
        patch("mempalace.onboarding._warn_ambiguous", return_value=[]),
    ):
        result = run_onboarding(directory=".", config_dir=tmp_path, auto_detect=False)

    assert "Bob" in result.people
    assert "Acme" in result.projects
def test_run_onboarding_with_ambiguous_names(tmp_path):
    """Onboarding completes and registers 'Grace' when _warn_ambiguous is left unpatched."""
    grace = {"name": "Grace", "relationship": "friend", "context": "personal"}
    with (
        patch("mempalace.onboarding._ask_mode", return_value="personal"),
        patch("mempalace.onboarding._ask_people", return_value=([grace], {})),
        patch("mempalace.onboarding._ask_projects", return_value=[]),
        patch("mempalace.onboarding._ask_wings", return_value=["family"]),
        patch("mempalace.onboarding._yn", return_value=False),
    ):
        result = run_onboarding(directory=".", config_dir=tmp_path, auto_detect=False)

    assert "Grace" in result.people
+244
View File
@@ -0,0 +1,244 @@
"""Tests for mempalace.palace_graph — graph traversal layer.
All ChromaDB access is mocked — no real database needed.
"""
from unittest.mock import MagicMock, patch
def _make_fake_collection(metadatas, ids=None):
"""Create a mock collection that returns the given metadata in batches."""
if ids is None:
ids = [f"id_{i}" for i in range(len(metadatas))]
col = MagicMock()
col.count.return_value = len(metadatas)
def fake_get(limit=1000, offset=0, include=None):
batch_meta = metadatas[offset : offset + limit]
batch_ids = ids[offset : offset + limit]
return {"ids": batch_ids, "metadatas": batch_meta}
col.get.side_effect = fake_get
return col
# Patch chromadb at import time so palace_graph can be imported
with patch.dict("sys.modules", {"chromadb": MagicMock()}):
from mempalace.palace_graph import (
_fuzzy_match,
build_graph,
find_tunnels,
graph_stats,
traverse,
)
# --- build_graph ---
class TestBuildGraph:
    """build_graph: node and edge construction from collection metadata."""

    def test_empty_collection(self):
        """No metadata gives no nodes and no edges."""
        nodes, edges = build_graph(col=_make_fake_collection([]))
        assert nodes == {}
        assert edges == []

    def test_falsy_collection(self):
        """A falsy ``col`` argument gives an empty graph."""
        nodes, edges = build_graph(col=0)
        assert nodes == {}
        assert edges == []

    def test_single_wing_no_edges(self):
        """Two entries for one room in one wing are counted together with no edge."""
        entries = [
            {"room": "auth", "wing": "wing_code", "hall": "security", "date": "2026-01-01"},
            {"room": "auth", "wing": "wing_code", "hall": "security", "date": "2026-01-02"},
        ]
        nodes, edges = build_graph(col=_make_fake_collection(entries))
        assert "auth" in nodes
        assert nodes["auth"]["count"] == 2
        assert edges == []

    def test_multi_wing_creates_edges(self):
        """A room present in two wings produces exactly one edge between them."""
        in_code = {
            "room": "chromadb",
            "wing": "wing_code",
            "hall": "databases",
            "date": "2026-01-01",
        }
        in_project = {
            "room": "chromadb",
            "wing": "wing_project",
            "hall": "databases",
            "date": "2026-01-02",
        }
        nodes, edges = build_graph(col=_make_fake_collection([in_code, in_project]))
        assert "chromadb" in nodes
        assert len(edges) == 1
        only_edge = edges[0]
        assert only_edge["wing_a"] == "wing_code"
        assert only_edge["wing_b"] == "wing_project"
        assert only_edge["hall"] == "databases"

    def test_general_room_excluded(self):
        """The room named 'general' does not become a node."""
        entry = {"room": "general", "wing": "wing_code", "hall": "misc", "date": ""}
        nodes, edges = build_graph(col=_make_fake_collection([entry]))
        assert "general" not in nodes

    def test_missing_wing_excluded(self):
        """An entry with an empty wing does not become a node."""
        entry = {"room": "orphan", "wing": "", "hall": "misc", "date": ""}
        nodes, edges = build_graph(col=_make_fake_collection([entry]))
        assert "orphan" not in nodes

    def test_dates_capped_at_five(self):
        """A room seen on nine dates keeps at most five of them."""
        entries = [
            {"room": "busy", "wing": "w", "hall": "h", "date": f"2026-01-{i:02d}"}
            for i in range(1, 10)
        ]
        nodes, _ = build_graph(col=_make_fake_collection(entries))
        assert len(nodes["busy"]["dates"]) <= 5
# --- traverse ---
class TestTraverse:
    """traverse: walking the graph outward from a start room."""

    def _build_col(self):
        """Return a fake collection with two rooms in wing_code and one in wing_ops."""
        entries = [
            {"room": "auth", "wing": "wing_code", "hall": "security", "date": "2026-01-01"},
            {"room": "login", "wing": "wing_code", "hall": "security", "date": "2026-01-01"},
            {"room": "deploy", "wing": "wing_ops", "hall": "infra", "date": "2026-01-01"},
        ]
        return _make_fake_collection(entries)

    def test_traverse_known_room(self):
        """Starting at 'auth' returns a list containing 'auth' and 'login'."""
        hits = traverse("auth", col=self._build_col())
        assert isinstance(hits, list)
        visited = [hit["room"] for hit in hits]
        assert "auth" in visited
        # login sits in wing_code, the same wing as auth
        assert "login" in visited

    def test_traverse_unknown_room(self):
        """An unknown start room returns a dict with 'error' and 'suggestions'."""
        outcome = traverse("nonexistent", col=self._build_col())
        assert isinstance(outcome, dict)
        assert "error" in outcome
        assert "suggestions" in outcome

    def test_traverse_max_hops(self):
        """max_hops=0 returns just the start room."""
        hits = traverse("auth", col=self._build_col(), max_hops=0)
        # With zero hops only the starting room is reported
        assert len(hits) == 1
        assert hits[0]["room"] == "auth"
# --- find_tunnels ---
class TestFindTunnels:
    """find_tunnels: rooms that appear in more than one wing."""

    def _build_tunnel_col(self):
        """Return a fake collection where 'chromadb' spans two wings and 'auth' one."""
        entries = [
            {"room": "chromadb", "wing": "wing_code", "hall": "db", "date": "2026-01-01"},
            {"room": "chromadb", "wing": "wing_project", "hall": "db", "date": "2026-01-02"},
            {"room": "auth", "wing": "wing_code", "hall": "security", "date": "2026-01-01"},
        ]
        return _make_fake_collection(entries)

    def test_find_all_tunnels(self):
        """Without filters the single tunnel is 'chromadb'."""
        found = find_tunnels(col=self._build_tunnel_col())
        assert len(found) == 1
        assert found[0]["room"] == "chromadb"

    def test_find_tunnels_with_wing_filter(self):
        """Filtering on wing_code still yields one tunnel."""
        found = find_tunnels(wing_a="wing_code", col=self._build_tunnel_col())
        assert len(found) == 1

    def test_find_tunnels_no_match(self):
        """Filtering on an absent wing yields no tunnels."""
        found = find_tunnels(wing_a="wing_nonexistent", col=self._build_tunnel_col())
        assert found == []

    def test_find_tunnels_both_wings(self):
        """Filtering on both wings yields the 'chromadb' tunnel."""
        found = find_tunnels(
            wing_a="wing_code", wing_b="wing_project", col=self._build_tunnel_col()
        )
        assert len(found) == 1
        assert found[0]["room"] == "chromadb"
# --- graph_stats ---
class TestGraphStats:
    """graph_stats: summary counters over the graph."""

    def test_empty_graph(self):
        """An empty collection reports zero rooms, tunnel rooms and edges."""
        summary = graph_stats(col=_make_fake_collection([]))
        assert summary["total_rooms"] == 0
        assert summary["tunnel_rooms"] == 0
        assert summary["total_edges"] == 0

    def test_stats_with_data(self):
        """Two rooms, one spanning two wings: 2 rooms, 1 tunnel room, 1 edge."""
        entries = [
            {"room": "chromadb", "wing": "wing_code", "hall": "db", "date": "2026-01-01"},
            {"room": "chromadb", "wing": "wing_project", "hall": "db", "date": "2026-01-02"},
            {"room": "auth", "wing": "wing_code", "hall": "security", "date": "2026-01-01"},
        ]
        summary = graph_stats(col=_make_fake_collection(entries))
        assert summary["total_rooms"] == 2
        assert summary["tunnel_rooms"] == 1
        assert summary["total_edges"] == 1
        assert "wing_code" in summary["rooms_per_wing"]
# --- _fuzzy_match ---
class TestFuzzyMatch:
    """_fuzzy_match: approximate room-name lookup."""

    def test_exact_substring(self):
        """A query that is a substring of a room name matches that room."""
        rooms = {"chromadb-setup": {}, "auth-module": {}, "deploy-config": {}}
        matches = _fuzzy_match("chromadb", rooms)
        assert "chromadb-setup" in matches

    def test_partial_word_match(self):
        """A query equal to one hyphen-separated word matches that room."""
        rooms = {"chromadb-setup": {}, "auth-module": {}, "deploy-config": {}}
        matches = _fuzzy_match("auth", rooms)
        assert "auth-module" in matches

    def test_no_match(self):
        """An unrelated query returns an empty list."""
        rooms = {"chromadb-setup": {}, "auth-module": {}}
        matches = _fuzzy_match("zzzzz", rooms)
        assert matches == []

    def test_hyphenated_query(self):
        """A hyphenated query matches the room that contains it."""
        rooms = {"riley-college-apps": {}, "college-prep": {}}
        matches = _fuzzy_match("riley-college", rooms)
        assert "riley-college-apps" in matches

    def test_max_results(self):
        """The ``n`` argument caps the number of results."""
        rooms = {f"room-{i}": {} for i in range(20)}
        matches = _fuzzy_match("room", rooms, n=3)
        assert len(matches) <= 3
+264
View File
@@ -0,0 +1,264 @@
"""Tests for mempalace.room_detector_local."""
from unittest.mock import MagicMock, patch
from mempalace.room_detector_local import (
FOLDER_ROOM_MAP,
detect_rooms_from_files,
detect_rooms_from_folders,
detect_rooms_local,
get_user_approval,
print_proposed_structure,
save_config,
)
# ── FOLDER_ROOM_MAP ────────────────────────────────────────────────────
def test_folder_room_map_has_expected_mappings():
    """FOLDER_ROOM_MAP maps the common folder names to their room names."""
    expected_pairs = (
        ("frontend", "frontend"),
        ("backend", "backend"),
        ("docs", "documentation"),
        ("tests", "testing"),
        ("config", "configuration"),
    )
    for folder, room in expected_pairs:
        assert FOLDER_ROOM_MAP[folder] == room
def test_folder_room_map_alternative_names():
    """Alternative folder spellings map onto 'frontend' or 'backend'."""
    expected_pairs = (
        ("front-end", "frontend"),
        ("back-end", "backend"),
        ("server", "backend"),
        ("client", "frontend"),
        ("api", "backend"),
    )
    for folder, room in expected_pairs:
        assert FOLDER_ROOM_MAP[folder] == room
# ── detect_rooms_from_folders ───────────────────────────────────────────
def test_detect_rooms_from_folders_standard_layout(tmp_path):
    """frontend/, backend/ and docs/ folders give the matching three rooms."""
    for folder in ("frontend", "backend", "docs"):
        (tmp_path / folder).mkdir()

    detected = detect_rooms_from_folders(str(tmp_path))
    detected_names = {room["name"] for room in detected}

    assert "frontend" in detected_names
    assert "backend" in detected_names
    assert "documentation" in detected_names
def test_detect_rooms_from_folders_always_has_general(tmp_path):
    """A directory without subfolders still gets a 'general' room."""
    detected = detect_rooms_from_folders(str(tmp_path))
    detected_names = {room["name"] for room in detected}

    assert "general" in detected_names
def test_detect_rooms_from_folders_empty_dir(tmp_path):
    """An empty directory yields at least one room, including 'general'."""
    detected = detect_rooms_from_folders(str(tmp_path))

    # "general" is expected even when nothing else is found
    assert len(detected) >= 1
    assert any(room["name"] == "general" for room in detected)
def test_detect_rooms_from_folders_skips_git(tmp_path):
    """.git and node_modules folders never appear as room names."""
    for folder in (".git", "node_modules", "frontend"):
        (tmp_path / folder).mkdir()

    detected = detect_rooms_from_folders(str(tmp_path))
    detected_names = {room["name"] for room in detected}

    assert ".git" not in detected_names
    assert "node_modules" not in detected_names
def test_detect_rooms_from_folders_nested_dirs(tmp_path):
    """src/components and src/routes surface as a frontend or backend room."""
    source_root = tmp_path / "src"
    source_root.mkdir()
    for child in ("components", "routes"):
        (source_root / child).mkdir()

    detected = detect_rooms_from_folders(str(tmp_path))
    detected_names = {room["name"] for room in detected}

    # Folders one level below the root are expected to be picked up
    assert "frontend" in detected_names or "backend" in detected_names
def test_detect_rooms_from_folders_room_has_description(tmp_path):
    """The documentation room carries a description that mentions 'docs'."""
    (tmp_path / "docs").mkdir()
    detected = detect_rooms_from_folders(str(tmp_path))

    documentation = None
    for room in detected:
        if room["name"] == "documentation":
            documentation = room
            break

    assert documentation is not None
    assert "description" in documentation
    assert "docs" in documentation["description"]
def test_detect_rooms_from_folders_room_has_keywords(tmp_path):
    """The frontend room carries a non-empty 'keywords' entry."""
    (tmp_path / "frontend").mkdir()
    detected = detect_rooms_from_folders(str(tmp_path))

    frontend = None
    for room in detected:
        if room["name"] == "frontend":
            frontend = room
            break

    assert frontend is not None
    assert "keywords" in frontend
    assert len(frontend["keywords"]) > 0
def test_detect_rooms_from_folders_custom_named_dirs(tmp_path):
    """A 'mylib' folder yields either a 'mylib' room or at least 'general'."""
    (tmp_path / "mylib").mkdir()

    detected = detect_rooms_from_folders(str(tmp_path))
    detected_names = {room["name"] for room in detected}

    # Either outcome is accepted: the folder name used as-is, or the fallback room
    assert "mylib" in detected_names or "general" in detected_names
# ── detect_rooms_from_files ─────────────────────────────────────────────
def test_detect_rooms_from_files_with_matching_filenames(tmp_path):
    """Files named test_*.py lead to a 'testing' room or the 'general' fallback."""
    # File names deliberately contain a room keyword
    filenames = ["test_auth.py", "test_login.py", "test_api.py"]
    for filename in filenames:
        (tmp_path / filename).write_text("content")

    detected = detect_rooms_from_files(str(tmp_path))
    detected_names = {room["name"] for room in detected}

    assert "testing" in detected_names or "general" in detected_names
def test_detect_rooms_from_files_empty_dir(tmp_path):
    """An empty directory yields at least one room, including 'general'."""
    detected = detect_rooms_from_files(str(tmp_path))

    assert len(detected) >= 1
    assert any(room["name"] == "general" for room in detected)
def test_detect_rooms_from_files_caps_at_six(tmp_path):
    """Eight keyword families of files never produce more than six rooms."""
    # Three files per keyword so that many candidate rooms compete for the cap
    keywords = ["test", "doc", "api", "config", "frontend", "backend", "design", "meeting"]
    for keyword in keywords:
        for i in range(3):
            target = tmp_path / f"{keyword}_file_{i}.txt"
            target.write_text("content")

    detected = detect_rooms_from_files(str(tmp_path))

    assert len(detected) <= 6
# ── save_config ─────────────────────────────────────────────────────────
def test_save_config_creates_yaml(tmp_path):
    """save_config writes mempalace.yaml mentioning the project and room names."""
    frontend = {"name": "frontend", "description": "UI files", "keywords": ["frontend"]}
    backend = {"name": "backend", "description": "Server files", "keywords": ["backend"]}
    save_config(str(tmp_path), "myproject", [frontend, backend])

    written = tmp_path / "mempalace.yaml"
    assert written.exists()

    text = written.read_text()
    assert "myproject" in text
    assert "frontend" in text
    assert "backend" in text
def test_save_config_valid_yaml(tmp_path):
    """The written config parses as YAML with the expected wing and single room."""
    import yaml

    only_room = {"name": "general", "description": "All files", "keywords": []}
    save_config(str(tmp_path), "test_proj", [only_room])

    raw = (tmp_path / "mempalace.yaml").read_text()
    parsed = yaml.safe_load(raw)

    assert parsed["wing"] == "test_proj"
    assert len(parsed["rooms"]) == 1
    assert parsed["rooms"][0]["name"] == "general"
# ── print_proposed_structure ──────────────────────────────────────────
def test_print_proposed_structure(capsys):
    """The printed proposal shows the project, a room, the file count and the source."""
    proposal = [
        {"name": "frontend", "description": "UI files"},
        {"name": "general", "description": "Everything else"},
    ]
    print_proposed_structure("myapp", proposal, 42, "folder structure")

    printed = capsys.readouterr().out
    assert "myapp" in printed
    assert "frontend" in printed
    assert "42 files" in printed
    assert "folder structure" in printed
# ── get_user_approval ─────────────────────────────────────────────────
def test_get_user_approval_accept_all():
    """Empty input accepts the proposed rooms unchanged."""
    proposed = [{"name": "frontend", "description": "UI"}]
    with patch("builtins.input", return_value=""):
        approved = get_user_approval(proposed)

    assert approved == proposed
def test_get_user_approval_edit_remove():
    """Answering 'edit', '1', 'n' leaves only the backend room."""
    proposed = [
        {"name": "frontend", "description": "UI"},
        {"name": "backend", "description": "Server"},
    ]
    keystrokes = ["edit", "1", "n"]
    with patch("builtins.input", side_effect=keystrokes):
        approved = get_user_approval(proposed)

    # Room number 1 (frontend) has been dropped
    assert len(approved) == 1
    assert approved[0]["name"] == "backend"
def test_get_user_approval_add_room():
    """Answering 'add' with a name and description adds that room to the result."""
    proposed = [{"name": "general", "description": "All files"}]
    keystrokes = ["add", "custom_room", "My custom room", ""]
    with patch("builtins.input", side_effect=keystrokes):
        approved = get_user_approval(proposed)

    approved_names = [room["name"] for room in approved]
    assert "custom_room" in approved_names
# ── detect_rooms_local ────────────────────────────────────────────────
def test_detect_rooms_local_yes_mode(tmp_path):
    """With yes=True and a mocked miner module, mempalace.yaml is written."""
    docs_dir = tmp_path / "docs"
    docs_dir.mkdir()
    (docs_dir / "readme.md").write_text("hello")

    fake_miner = MagicMock()
    fake_miner.scan_project.return_value = ["file1.py"]
    with patch.dict("sys.modules", {"mempalace.miner": fake_miner}):
        detect_rooms_local(str(tmp_path), yes=True)

    assert (tmp_path / "mempalace.yaml").exists()
def test_detect_rooms_local_fallback_to_files(tmp_path):
    """A directory holding only files (no subfolders) still ends with mempalace.yaml written."""
    for i in range(3):
        (tmp_path / f"test_file_{i}.py").write_text("content")

    fake_miner = MagicMock()
    fake_miner.scan_project.return_value = ["f1", "f2"]
    with patch.dict("sys.modules", {"mempalace.miner": fake_miner}):
        detect_rooms_local(str(tmp_path), yes=True)

    assert (tmp_path / "mempalace.yaml").exists()
def test_detect_rooms_local_missing_dir():
    """A path that does not exist makes detect_rooms_local raise SystemExit."""
    import pytest

    missing = "/nonexistent/path/that/does/not/exist"
    with pytest.raises(SystemExit):
        detect_rooms_local(missing, yes=True)
def test_detect_rooms_local_interactive(tmp_path):
    """With yes=False and get_user_approval patched, mempalace.yaml is written."""
    source_dir = tmp_path / "src"
    source_dir.mkdir()
    (source_dir / "main.py").write_text("code")

    fake_miner = MagicMock()
    fake_miner.scan_project.return_value = ["f1"]
    approved = [{"name": "general", "description": "All files", "keywords": []}]

    with patch.dict("sys.modules", {"mempalace.miner": fake_miner}):
        with patch("mempalace.room_detector_local.get_user_approval", return_value=approved):
            detect_rooms_local(str(tmp_path), yes=False)

    assert (tmp_path / "mempalace.yaml").exists()
+85 -5
View File
@@ -1,10 +1,18 @@
""" """
test_searcher.py Tests for the programmatic search_memories API. test_searcher.py -- Tests for both search() (CLI) and search_memories() (API).
Tests the library-facing search interface (not the CLI print variant). Uses the real ChromaDB fixtures from conftest.py for integration tests,
plus mock-based tests for error paths.
""" """
from mempalace.searcher import search_memories from unittest.mock import MagicMock, patch
import pytest
from mempalace.searcher import SearchError, search, search_memories
# ── search_memories (API) ──────────────────────────────────────────────
class TestSearchMemories: class TestSearchMemories:
@@ -30,8 +38,8 @@ class TestSearchMemories:
result = search_memories("code", palace_path, n_results=2) result = search_memories("code", palace_path, n_results=2)
assert len(result["results"]) <= 2 assert len(result["results"]) <= 2
def test_no_palace_returns_error(self): def test_no_palace_returns_error(self, tmp_path):
result = search_memories("anything", "/nonexistent/path") result = search_memories("anything", str(tmp_path / "missing"))
assert "error" in result assert "error" in result
def test_result_fields(self, palace_path, seeded_collection): def test_result_fields(self, palace_path, seeded_collection):
@@ -43,3 +51,75 @@ class TestSearchMemories:
assert "source_file" in hit assert "source_file" in hit
assert "similarity" in hit assert "similarity" in hit
assert isinstance(hit["similarity"], float) assert isinstance(hit["similarity"], float)
def test_search_memories_query_error(self):
    """A RuntimeError raised by the collection query comes back as an error dict."""
    failing_collection = MagicMock()
    failing_collection.query.side_effect = RuntimeError("query failed")
    fake_client = MagicMock()
    fake_client.get_collection.return_value = failing_collection

    with patch("mempalace.searcher.chromadb.PersistentClient", return_value=fake_client):
        outcome = search_memories("test", "/fake/path")

    assert "error" in outcome
    assert "query failed" in outcome["error"]
def test_search_memories_filters_in_result(self, palace_path, seeded_collection):
    """The wing and room passed in are echoed under result['filters']."""
    outcome = search_memories("test", palace_path, wing="project", room="backend")

    applied = outcome["filters"]
    assert applied["wing"] == "project"
    assert outcome["filters"]["room"] == "backend"
# ── search() (CLI print function) ─────────────────────────────────────
class TestSearchCLI:
    """search(): the printing variant of the search API."""

    def test_search_prints_results(self, palace_path, seeded_collection, capsys):
        """Output for a JWT query mentions 'JWT' or 'authentication'."""
        search("JWT authentication", palace_path)
        printed = capsys.readouterr().out
        assert "JWT" in printed or "authentication" in printed

    def test_search_with_wing_filter(self, palace_path, seeded_collection, capsys):
        """A wing-filtered search prints a 'Results for' header."""
        search("planning", palace_path, wing="notes")
        printed = capsys.readouterr().out
        assert "Results for" in printed

    def test_search_with_room_filter(self, palace_path, seeded_collection, capsys):
        """A room-filtered search prints a 'Room:' line."""
        search("database", palace_path, room="backend")
        printed = capsys.readouterr().out
        assert "Room:" in printed

    def test_search_with_wing_and_room(self, palace_path, seeded_collection, capsys):
        """Filtering by wing and room prints both 'Wing:' and 'Room:'."""
        search("code", palace_path, wing="project", room="frontend")
        printed = capsys.readouterr().out
        assert "Wing:" in printed
        assert "Room:" in printed

    def test_search_no_palace_raises(self, tmp_path):
        """A missing palace directory raises SearchError('No palace found')."""
        missing = str(tmp_path / "missing")
        with pytest.raises(SearchError, match="No palace found"):
            search("anything", missing)

    def test_search_no_results(self, palace_path, collection, capsys):
        """Searching an unseeded collection returns None or prints 'No results'."""
        # The `collection` fixture is used without seeding any data
        outcome = search("xyzzy_nonexistent_query", palace_path, n_results=1)
        printed = capsys.readouterr().out
        # Accept either a None return or the "No results" message
        assert outcome is None or "No results" in printed

    def test_search_query_error_raises(self):
        """A failing collection query surfaces as SearchError('Search error')."""
        failing_collection = MagicMock()
        failing_collection.query.side_effect = RuntimeError("boom")
        fake_client = MagicMock()
        fake_client.get_collection.return_value = failing_collection

        with patch("mempalace.searcher.chromadb.PersistentClient", return_value=fake_client):
            with pytest.raises(SearchError, match="Search error"):
                search("test", "/fake/path")

    def test_search_n_results(self, palace_path, seeded_collection, capsys):
        """With n_results=1 the output contains the '[1]' result marker."""
        search("code", palace_path, n_results=1)
        printed = capsys.readouterr().out
        # At least one numbered result block is expected
        assert "[1]" in printed
+160
View File
@@ -0,0 +1,160 @@
"""Tests for mempalace.spellcheck — spell-correction utilities."""
from unittest.mock import patch
from mempalace.spellcheck import (
_edit_distance,
_get_system_words,
_should_skip,
spellcheck_transcript,
spellcheck_transcript_line,
spellcheck_user_text,
)
# --- _should_skip ---
class TestShouldSkip:
    """_should_skip: which tokens are exempt from spell correction."""

    def test_short_tokens_skipped(self):
        """Tokens of one or two characters are skipped."""
        for token in ("hi", "ok", "I"):
            assert _should_skip(token, set()) is True

    def test_digits_skipped(self):
        """Tokens containing digits are skipped."""
        for token in ("3am", "top10", "bge-large-v1.5"):
            assert _should_skip(token, set()) is True

    def test_camelcase_skipped(self):
        """CamelCase tokens are skipped."""
        for token in ("ChromaDB", "MemPalace"):
            assert _should_skip(token, set()) is True

    def test_allcaps_skipped(self):
        """All-caps tokens, including ones with underscores, are skipped."""
        for token in ("NDCG", "MAX_RESULTS"):
            assert _should_skip(token, set()) is True

    def test_technical_skipped(self):
        """Hyphenated or underscored tokens are skipped."""
        for token in ("bge-large", "train_test"):
            assert _should_skip(token, set()) is True

    def test_url_skipped(self):
        """URL-like tokens are skipped."""
        for token in ("https://example.com", "www.google.com"):
            assert _should_skip(token, set()) is True

    def test_code_or_emoji_skipped(self):
        """Backtick- or asterisk-wrapped tokens are skipped."""
        for token in ("`code`", "**bold**"):
            assert _should_skip(token, set()) is True

    def test_known_name_skipped(self):
        """A token found in the known-names set is skipped."""
        known = {"mempalace"}
        assert _should_skip("mempalace", known) is True

    def test_normal_word_not_skipped(self):
        """Ordinary lowercase words are not skipped."""
        for token in ("hello", "question"):
            assert _should_skip(token, set()) is False
# --- _edit_distance ---
class TestEditDistance:
    """_edit_distance: edit distance between two strings."""

    def test_identical(self):
        """Equal strings are at distance zero."""
        distance = _edit_distance("hello", "hello")
        assert distance == 0

    def test_empty_strings(self):
        """Against an empty string the distance is the other string's length."""
        cases = ((("", "abc"), 3), (("abc", ""), 3), (("", ""), 0))
        for (left, right), expected in cases:
            assert _edit_distance(left, right) == expected

    def test_single_edit(self):
        """One substitution, insertion or deletion costs exactly one."""
        assert _edit_distance("cat", "bat") == 1  # substitution
        assert _edit_distance("cat", "cats") == 1  # insertion
        assert _edit_distance("cats", "cat") == 1  # deletion

    def test_known_distance(self):
        """'kitten' to 'sitting' is three edits."""
        distance = _edit_distance("kitten", "sitting")
        assert distance == 3
# --- _get_system_words ---
def test_get_system_words_returns_set():
    """_get_system_words returns a set."""
    words = _get_system_words()

    assert isinstance(words, set)
# --- spellcheck_user_text ---
def test_spellcheck_user_text_passthrough_no_autocorrect():
    """When _get_speller returns None the text comes back unchanged."""
    original = "somee misspeledd textt"
    with patch("mempalace.spellcheck._get_speller", return_value=None):
        corrected = spellcheck_user_text(original)
        assert corrected == original
def test_spellcheck_user_text_with_speller():
    """With a speller available, misspelled words are replaced by its suggestions."""
    fixes = {"knoe": "know", "befor": "before"}

    def stub_speller(word):
        return fixes.get(word, word)

    with patch("mempalace.spellcheck._get_speller", return_value=stub_speller), patch(
        "mempalace.spellcheck._get_system_words", return_value=set()
    ), patch("mempalace.spellcheck._load_known_names", return_value=set()):
        corrected = spellcheck_user_text("knoe the question befor")

    assert "know" in corrected
    assert "before" in corrected
def test_spellcheck_preserves_technical_terms():
    """CamelCase and hyphenated terms survive a speller that rewrites everything."""

    def stub_speller(word):
        return "WRONG"

    with patch("mempalace.spellcheck._get_speller", return_value=stub_speller), patch(
        "mempalace.spellcheck._get_system_words", return_value=set()
    ):
        corrected = spellcheck_user_text("ChromaDB bge-large", known_names=set())

    assert "ChromaDB" in corrected
    assert "bge-large" in corrected
    assert "WRONG" not in corrected
# --- spellcheck_transcript_line ---
def test_transcript_line_user_turn():
    """A line beginning with '>' is routed through spellcheck_user_text."""
    with patch("mempalace.spellcheck.spellcheck_user_text", return_value="corrected"):
        processed = spellcheck_transcript_line("> hello world")

    assert "corrected" in processed
def test_transcript_line_assistant_turn():
    """A line without a leading '>' is returned unchanged."""
    assistant_line = "This is an assistant response"
    processed = spellcheck_transcript_line(assistant_line)

    assert processed == assistant_line
def test_transcript_line_empty_user_turn():
    """A bare '> ' marker with no message text is returned unchanged."""
    bare_marker = "> "
    processed = spellcheck_transcript_line(bare_marker)

    assert processed == bare_marker
# --- spellcheck_transcript ---
def test_spellcheck_transcript_processes_content():
    """In a three-line transcript only the '>' line is rewritten."""
    transcript = "Assistant line\n> user line\nAnother assistant line"
    with patch("mempalace.spellcheck.spellcheck_user_text", return_value="fixed"):
        processed = spellcheck_transcript(transcript)

    first, second, third = processed.split("\n")[:3]
    assert first == "Assistant line"
    assert "fixed" in second
    assert third == "Another assistant line"
+72
View File
@@ -0,0 +1,72 @@
"""Extra spellcheck tests covering _load_known_names and speller edge cases."""
from unittest.mock import patch, MagicMock
from mempalace.spellcheck import (
_load_known_names,
spellcheck_user_text,
)
class TestLoadKnownNames:
    """_load_known_names: reading names out of the entity registry."""

    def test_returns_names_from_registry(self):
        """Canonical names and aliases come back lowercased."""
        fake_registry = MagicMock()
        fake_registry._data = {
            "entities": {
                "e1": {"canonical": "Alice", "aliases": ["ali"]},
                "e2": {"canonical": "Bob", "aliases": []},
            }
        }

        with patch("mempalace.entity_registry.EntityRegistry") as registry_cls:
            registry_cls.load.return_value = fake_registry
            loaded = _load_known_names()

        for expected in ("alice", "ali", "bob"):
            assert expected in loaded

    def test_returns_empty_on_exception(self):
        """A failing EntityRegistry.load gives an empty set."""
        failure = Exception("no registry")
        with patch("mempalace.entity_registry.EntityRegistry.load", side_effect=failure):
            loaded = _load_known_names()

        assert loaded == set()
class TestSpellerEdgeCases:
    """Inputs that spellcheck_user_text is expected to leave uncorrected."""

    @staticmethod
    def _spellcheck_with_stubs(text, suggestion, system_words):
        """Run spellcheck_user_text with speller, system words and known names stubbed.

        The stub speller answers ``suggestion`` for every word it is given;
        the known-names lookup is stubbed to an empty set.
        """

        def fake_speller(word):
            return suggestion

        with patch("mempalace.spellcheck._get_speller", return_value=fake_speller), \
                patch("mempalace.spellcheck._get_system_words", return_value=system_words), \
                patch("mempalace.spellcheck._load_known_names", return_value=set()):
            return spellcheck_user_text(text)

    def test_capitalized_word_skipped(self):
        """Capitalized words (likely proper nouns) are not corrected."""
        result = self._spellcheck_with_stubs("Alice went home", "WRONG", set())
        assert "Alice" in result
        assert "WRONG" not in result

    def test_system_word_not_corrected(self):
        """Words in system dict should not be corrected."""
        result = self._spellcheck_with_stubs("coherently", "WRONG", {"coherently"})
        assert "coherently" in result

    def test_high_edit_distance_rejected(self):
        """Corrections with too many edits are rejected."""
        result = self._spellcheck_with_stubs(
            "hello", "completely_different_word", set()
        )
        assert "hello" in result
+244
View File
@@ -3,6 +3,9 @@ import json
from mempalace import split_mega_files as smf from mempalace import split_mega_files as smf
# ── Config loading ─────────────────────────────────────────────────────
def test_load_known_people_falls_back_when_config_missing(monkeypatch, tmp_path): def test_load_known_people_falls_back_when_config_missing(monkeypatch, tmp_path):
monkeypatch.setattr(smf, "_KNOWN_NAMES_PATH", tmp_path / "missing.json") monkeypatch.setattr(smf, "_KNOWN_NAMES_PATH", tmp_path / "missing.json")
smf._KNOWN_NAMES_CACHE = None smf._KNOWN_NAMES_CACHE = None
@@ -46,3 +49,244 @@ def test_extract_people_detects_names_from_content(monkeypatch):
monkeypatch.setattr(smf, "KNOWN_PEOPLE", ["Alice", "Ben"]) monkeypatch.setattr(smf, "KNOWN_PEOPLE", ["Alice", "Ben"])
people = smf.extract_people(["> Alice reviewed the change with Ben\n"]) people = smf.extract_people(["> Alice reviewed the change with Ben\n"])
assert people == ["Alice", "Ben"] assert people == ["Alice", "Ben"]
# ── Config: force_reload and invalid JSON ──────────────────────────────
def test_load_known_names_force_reload(monkeypatch, tmp_path):
    """force_reload=True makes the loader pick up a rewritten config file."""
    config_path = tmp_path / "known_names.json"
    config_path.write_text(json.dumps(["Alice"]))
    monkeypatch.setattr(smf, "_KNOWN_NAMES_PATH", config_path)
    # Reset the cache via monkeypatch (not plain assignment) so the module
    # attribute is restored at teardown and ["Bob"] does not leak into
    # whichever test runs next.
    monkeypatch.setattr(smf, "_KNOWN_NAMES_CACHE", None)

    smf._load_known_names_config()
    assert smf._KNOWN_NAMES_CACHE == ["Alice"]

    config_path.write_text(json.dumps(["Bob"]))
    smf._load_known_names_config(force_reload=True)
    assert smf._KNOWN_NAMES_CACHE == ["Bob"]
def test_load_known_names_invalid_json(monkeypatch, tmp_path):
    """A config file that is not valid JSON makes the loader return None."""
    config_path = tmp_path / "known_names.json"
    config_path.write_text("not json {{{")
    monkeypatch.setattr(smf, "_KNOWN_NAMES_PATH", config_path)
    # Reset the cache via monkeypatch (not plain assignment) so the module
    # attribute gets its previous value back once the test finishes.
    monkeypatch.setattr(smf, "_KNOWN_NAMES_CACHE", None)

    result = smf._load_known_names_config()
    assert result is None
def test_load_known_names_caching(monkeypatch, tmp_path):
    """Without force_reload, a second call serves the first call's result."""
    config_path = tmp_path / "known_names.json"
    config_path.write_text(json.dumps(["Alice"]))
    monkeypatch.setattr(smf, "_KNOWN_NAMES_PATH", config_path)
    # Reset the cache via monkeypatch (not plain assignment) so the module
    # attribute is restored at teardown and ["Alice"] does not stay cached
    # for later tests.
    monkeypatch.setattr(smf, "_KNOWN_NAMES_CACHE", None)

    smf._load_known_names_config()

    # The file changes on disk, but the second call must not re-read it.
    config_path.write_text(json.dumps(["Changed"]))
    result = smf._load_known_names_config()
    assert result == ["Alice"]
# ── is_true_session_start ──────────────────────────────────────────────
def test_is_true_session_start_yes():
    """A header followed by ordinary content counts as a true session start."""
    transcript = ["Claude Code v1.0", "Some content", "More content"] + [""] * 3
    verdict = smf.is_true_session_start(transcript, 0)
    assert verdict is True
def test_is_true_session_start_no_ctrl_e():
    """A 'Ctrl+E to show ... previous messages' line after the header is not a start."""
    transcript = ["Claude Code v1.0", "Ctrl+E to show 5 previous messages"] + [""] * 4
    verdict = smf.is_true_session_start(transcript, 0)
    assert verdict is False
def test_is_true_session_start_no_previous_messages():
    """A 'previous messages' line shortly after the header is not a start."""
    transcript = [
        "Claude Code v1.0",
        "Some text",
        "previous messages here",
    ] + [""] * 3
    verdict = smf.is_true_session_start(transcript, 0)
    assert verdict is False
# ── find_session_boundaries ────────────────────────────────────────────
def test_find_session_boundaries_two_sessions():
    """Two headers seven lines apart yield boundaries at indices 0 and 7."""
    padding = [""] * 5
    transcript = (
        ["Claude Code v1.0", "content 1"]
        + padding
        + ["Claude Code v1.0", "content 2"]
        + padding
    )
    found = smf.find_session_boundaries(transcript)
    assert found == [0, 7]
def test_find_session_boundaries_none():
    """Text without any session header yields no boundaries."""
    transcript = ["Just some text", "No sessions here"]
    found = smf.find_session_boundaries(transcript)
    assert found == []
def test_find_session_boundaries_context_restore_skipped():
    """A header followed by the Ctrl+E line is not counted as a boundary."""
    real_session = ["Claude Code v1.0", "content"] + [""] * 5
    context_restore = [
        "Claude Code v1.0",
        "Ctrl+E to show 5 previous messages",
    ] + [""] * 4
    found = smf.find_session_boundaries(real_session + context_restore)
    assert len(found) == 1
# ── extract_timestamp ──────────────────────────────────────────────────
def test_extract_timestamp_found():
    """A '⏺ <time> <weekday>, <date>' line is parsed into both timestamp forms."""
    stamp_line = "⏺ 2:30 PM Wednesday, March 25, 2026"
    human, iso = smf.extract_timestamp([stamp_line])
    assert (human, iso) == ("2026-03-25_230PM", "2026-03-25")
def test_extract_timestamp_not_found():
    """With no timestamp line present, both returned values are None."""
    extracted = smf.extract_timestamp(["No timestamp here"])
    human, iso = extracted
    assert human is None
    assert iso is None
def test_extract_timestamp_only_checks_first_50():
    """A timestamp sitting after 51 filler lines is not picked up."""
    filler = ["filler\n"] * 51
    late_stamp = "⏺ 1:00 AM Monday, January 01, 2026"
    human, _ = smf.extract_timestamp(filler + [late_stamp])
    assert human is None
# ── extract_subject ────────────────────────────────────────────────────
def test_extract_subject_found():
    """The subject is derived from the user's '>' prompt."""
    prompt = "> How do we handle authentication?"
    subject = smf.extract_subject([prompt])
    lowered = subject.lower()
    assert "authentication" in lowered
def test_extract_subject_skips_commands():
    """Shell-style prompts (cd, git) are passed over for the later question."""
    prompts = [
        "> cd /some/dir",
        "> git status",
        "> What is the plan?",
    ]
    subject = smf.extract_subject(prompts)
    lowered = subject.lower()
    assert "plan" in lowered
def test_extract_subject_fallback():
    """With no '>' prompt at all, the subject falls back to "session"."""
    no_prompts = ["No prompts at all", "Just text"]
    assert smf.extract_subject(no_prompts) == "session"
def test_extract_subject_short_prompt_skipped():
    """Very short prompts ('ok', 'yes') are passed over for the longer one."""
    prompts = [
        "> ok",
        "> yes",
        "> What about the deployment strategy?",
    ]
    subject = smf.extract_subject(prompts)
    lowered = subject.lower()
    assert "deployment" in lowered
def test_extract_subject_truncated():
    """A 100-character prompt produces a subject of at most 60 characters."""
    long_prompt = "> " + "a" * 100
    subject = smf.extract_subject([long_prompt])
    assert len(subject) <= 60
# ── split_file ─────────────────────────────────────────────────────────
def _make_mega_file(tmp_path, n_sessions=3, lines_per_session=15):
    """Write ``mega.txt`` under *tmp_path* with *n_sessions* sessions; return its path.

    Each session consists of a ``Claude Code v1.<i>`` header, one ``>`` prompt
    line, and ``lines_per_session - 2`` numbered filler lines.
    """
    pieces = []
    for session in range(n_sessions):
        pieces.append(f"Claude Code v1.{session}\n")
        pieces.append(f"> What about topic {session} and how it works?\n")
        pieces.extend(
            f"Line {idx} of session {session}\n"
            for idx in range(lines_per_session - 2)
        )
    mega_path = tmp_path / "mega.txt"
    mega_path.write_text("".join(pieces))
    return mega_path
def test_split_file_creates_output(tmp_path):
    """Splitting a mega-file reports at least two paths, all present on disk."""
    source = _make_mega_file(tmp_path)
    target_dir = tmp_path / "output"
    target_dir.mkdir()
    produced = smf.split_file(str(source), str(target_dir))
    assert len(produced) >= 2
    assert all(piece.exists() for piece in produced)
def test_split_file_dry_run(tmp_path):
    """With dry_run=True the paths are reported but none is created on disk."""
    source = _make_mega_file(tmp_path)
    target_dir = tmp_path / "output"
    target_dir.mkdir()
    planned = smf.split_file(str(source), str(target_dir), dry_run=True)
    assert len(planned) >= 2
    assert not any(piece.exists() for piece in planned)
def test_split_file_not_mega(tmp_path):
    """A file holding a single session is left alone: nothing is written."""
    single = tmp_path / "single.txt"
    body = "Claude Code v1.0\nJust one session\n" + "line\n" * 20
    single.write_text(body)
    produced = smf.split_file(str(single), str(tmp_path))
    assert produced == []
def test_split_file_output_dir_none(tmp_path):
    """Passing output_dir=None places the pieces next to the source file."""
    source = _make_mega_file(tmp_path)
    produced = smf.split_file(str(source), None)
    assert len(produced) >= 2
    expected_parent = str(tmp_path)
    assert all(str(piece.parent) == expected_parent for piece in produced)
def test_split_file_tiny_fragments_skipped(tmp_path):
    """Tiny chunks (< 10 lines) are expected to be skipped by split_file."""
    tiny_sessions = "Claude Code v1.0\nline\n" * 2
    full_session = "Claude Code v1.0\n" + "line\n" * 20
    source = tmp_path / "tiny.txt"
    source.write_text(tiny_sessions + full_session)
    produced = smf.split_file(str(source), str(tmp_path))
    # Only checks that whatever was written is a non-empty file.
    assert all(piece.stat().st_size > 0 for piece in produced)