Merge branch 'main' into fix/query-sanitizer-prompt-contamination
This commit is contained in:
@@ -0,0 +1,20 @@
|
||||
{
|
||||
"name": "mempalace",
|
||||
"interface": {
|
||||
"displayName": "MemPalace"
|
||||
},
|
||||
"plugins": [
|
||||
{
|
||||
"name": "mempalace",
|
||||
"source": {
|
||||
"source": "local",
|
||||
"path": "./.codex-plugin"
|
||||
},
|
||||
"policy": {
|
||||
"installation": "AVAILABLE",
|
||||
"authentication": "NONE"
|
||||
},
|
||||
"category": "Coding"
|
||||
}
|
||||
]
|
||||
}
|
||||
@@ -0,0 +1,9 @@
|
||||
{
|
||||
"mempalace": {
|
||||
"command": "python3",
|
||||
"args": [
|
||||
"-m",
|
||||
"mempalace.mcp_server"
|
||||
]
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,57 @@
|
||||
# MemPalace Claude Code Plugin
|
||||
|
||||
A Claude Code plugin that gives your AI a persistent memory system. Mine projects and conversations into a searchable palace backed by ChromaDB, with 19 MCP tools, auto-save hooks, and 5 guided skills.
|
||||
|
||||
## Prerequisites
|
||||
|
||||
- Python 3.9+
|
||||
|
||||
## Installation
|
||||
|
||||
### Claude Code Marketplace
|
||||
|
||||
```bash
|
||||
claude plugin marketplace add milla-jovovich/mempalace
|
||||
claude plugin install --scope user mempalace
|
||||
```
|
||||
|
||||
### Local Clone
|
||||
|
||||
```bash
|
||||
claude plugin add /path/to/mempalace
|
||||
```
|
||||
|
||||
## Post-Install Setup
|
||||
|
||||
After installing the plugin, run the init command to complete setup (pip install, MCP configuration, etc.):
|
||||
|
||||
```
|
||||
/mempalace:init
|
||||
```
|
||||
|
||||
## Available Slash Commands
|
||||
|
||||
| Command | Description |
|
||||
|---------|-------------|
|
||||
| `/mempalace:help` | Show available tools, skills, and architecture |
|
||||
| `/mempalace:init` | Set up MemPalace -- install, configure MCP, onboard |
|
||||
| `/mempalace:search` | Search your memories across the palace |
|
||||
| `/mempalace:mine` | Mine projects and conversations into the palace |
|
||||
| `/mempalace:status` | Show palace overview -- wings, rooms, drawer counts |
|
||||
|
||||
## Hooks
|
||||
|
||||
MemPalace registers two hooks that run automatically:
|
||||
|
||||
- **Stop** -- Every 15 human messages, pauses the stop and prompts the AI to save the conversation context (topics, decisions, quotes, code).
|
||||
- **PreCompact** -- Preserves important memories before context compaction.
|
||||
|
||||
Set the `MEMPAL_DIR` environment variable to a directory path to automatically run `mempalace mine` on that directory during each save trigger.
|
||||
|
||||
## MCP Server
|
||||
|
||||
The plugin automatically configures a local MCP server with 19 tools for storing, searching, and managing memories. No manual MCP setup is required -- `/mempalace:init` handles everything.
|
||||
|
||||
## Full Documentation
|
||||
|
||||
See the main [README](../README.md) for complete documentation, architecture details, and advanced usage.
|
||||
@@ -0,0 +1,6 @@
|
||||
---
|
||||
description: Show comprehensive MemPalace help — available skills, MCP tools, CLI commands, hooks, and architecture.
|
||||
allowed-tools: Bash, Read
|
||||
---
|
||||
|
||||
Invoke the generic mempalace skill (using the Skill tool) with the `help` command, then follow its instructions.
|
||||
@@ -0,0 +1,6 @@
|
||||
---
|
||||
description: Set up MemPalace — install the package, initialize a palace, configure MCP server, and verify everything works.
|
||||
allowed-tools: Bash, Read, Write, Edit, Glob, Grep
|
||||
---
|
||||
|
||||
Invoke the generic mempalace skill (using the Skill tool) with the `init` command, then follow its instructions.
|
||||
@@ -0,0 +1,7 @@
|
||||
---
|
||||
description: Mine projects and conversations into the MemPalace. Supports project files, conversation exports, and auto-classification.
|
||||
argument-hint: Path to project or conversation export to mine.
|
||||
allowed-tools: Bash, Read, Write, Edit, Glob, Grep
|
||||
---
|
||||
|
||||
Invoke the generic mempalace skill (using the Skill tool) with the `mine` command, then follow its instructions.
|
||||
@@ -0,0 +1,7 @@
|
||||
---
|
||||
description: Search your memories across the MemPalace using semantic search with wing/room filtering.
|
||||
argument-hint: Search query, optionally with wing/room filters.
|
||||
allowed-tools: Bash, Read
|
||||
---
|
||||
|
||||
Invoke the generic mempalace skill (using the Skill tool) with the `search` command, then follow its instructions.
|
||||
@@ -0,0 +1,6 @@
|
||||
---
|
||||
description: Show the current state of your memory palace — wings, rooms, drawer counts, and suggestions.
|
||||
allowed-tools: Bash, Read
|
||||
---
|
||||
|
||||
Invoke the generic mempalace skill (using the Skill tool) with the `status` command, then follow its instructions.
|
||||
@@ -0,0 +1,25 @@
|
||||
{
|
||||
"description": "MemPalace auto-save and pre-compact hooks",
|
||||
"hooks": {
|
||||
"Stop": [
|
||||
{
|
||||
"hooks": [
|
||||
{
|
||||
"type": "command",
|
||||
"command": "bash ${CLAUDE_PLUGIN_ROOT}/hooks/mempal-stop-hook.sh"
|
||||
}
|
||||
]
|
||||
}
|
||||
],
|
||||
"PreCompact": [
|
||||
{
|
||||
"hooks": [
|
||||
{
|
||||
"type": "command",
|
||||
"command": "bash ${CLAUDE_PLUGIN_ROOT}/hooks/mempal-precompact-hook.sh"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,5 @@
|
||||
#!/bin/bash
# MemPalace PreCompact Hook — thin wrapper calling Python CLI
# All logic lives in mempalace.hooks_cli for cross-harness extensibility
#
# Reads the harness JSON payload from stdin and forwards it unchanged to
# `python3 -m mempalace hook run`; the exit status is that of the Python CLI.
INPUT=$(cat)
# printf instead of echo: echo may treat a payload such as "-n" or "-e ..."
# as an option, printf '%s' always emits the data verbatim.
printf '%s\n' "$INPUT" | python3 -m mempalace hook run --hook precompact --harness claude-code
|
||||
@@ -0,0 +1,5 @@
|
||||
#!/bin/bash
# MemPalace Stop Hook — thin wrapper calling Python CLI
# All logic lives in mempalace.hooks_cli for cross-harness extensibility
#
# Reads the harness JSON payload from stdin and forwards it unchanged to
# `python3 -m mempalace hook run`; the exit status is that of the Python CLI.
INPUT=$(cat)
# printf instead of echo: echo may treat a payload such as "-n" or "-e ..."
# as an option, printf '%s' always emits the data verbatim.
printf '%s\n' "$INPUT" | python3 -m mempalace hook run --hook stop --harness claude-code
|
||||
@@ -0,0 +1,18 @@
|
||||
{
|
||||
"name": "mempalace",
|
||||
"owner": {
|
||||
"name": "milla-jovovich",
|
||||
"url": "https://github.com/milla-jovovich"
|
||||
},
|
||||
"plugins": [
|
||||
{
|
||||
"name": "mempalace",
|
||||
"source": "./.claude-plugin",
|
||||
"description": "AI memory system — mine projects and conversations into a searchable palace. 19 MCP tools, auto-save hooks, guided setup.",
|
||||
"version": "3.0.14",
|
||||
"author": {
|
||||
"name": "milla-jovovich"
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
@@ -0,0 +1,29 @@
|
||||
{
|
||||
"name": "mempalace",
|
||||
"version": "3.0.14",
|
||||
"description": "Give your AI a memory — mine projects and conversations into a searchable palace. 19 MCP tools, auto-save hooks, and guided setup.",
|
||||
"author": {
|
||||
"name": "milla-jovovich"
|
||||
},
|
||||
"license": "MIT",
|
||||
"commands": [],
|
||||
"mcpServers": {
|
||||
"mempalace": {
|
||||
"command": "python3",
|
||||
"args": [
|
||||
"-m",
|
||||
"mempalace.mcp_server"
|
||||
]
|
||||
}
|
||||
},
|
||||
"keywords": [
|
||||
"memory",
|
||||
"ai",
|
||||
"rag",
|
||||
"mcp",
|
||||
"chromadb",
|
||||
"palace",
|
||||
"search"
|
||||
],
|
||||
"repository": "https://github.com/milla-jovovich/mempalace"
|
||||
}
|
||||
@@ -0,0 +1,35 @@
|
||||
---
|
||||
name: mempalace
|
||||
description: MemPalace — mine projects and conversations into a searchable memory palace. Use when asked about mempalace, memory palace, mining memories, searching memories, or palace setup.
|
||||
allowed-tools: Bash, Read, Write, Edit, Glob, Grep
|
||||
---
|
||||
|
||||
# MemPalace
|
||||
|
||||
A searchable memory palace for AI — mine projects and conversations, then search them semantically.
|
||||
|
||||
## Prerequisites
|
||||
|
||||
Ensure `mempalace` is installed:
|
||||
|
||||
```bash
|
||||
mempalace --version
|
||||
```
|
||||
|
||||
If not installed:
|
||||
|
||||
```bash
|
||||
pip install mempalace
|
||||
```
|
||||
|
||||
## Usage
|
||||
|
||||
MemPalace provides dynamic instructions via the CLI. To get instructions for any operation:
|
||||
|
||||
```bash
|
||||
mempalace instructions <command>
|
||||
```
|
||||
|
||||
Where `<command>` is one of: `help`, `init`, `mine`, `search`, `status`.
|
||||
|
||||
Run the appropriate instructions command, then follow the returned instructions step by step.
|
||||
@@ -0,0 +1,75 @@
|
||||
# MemPalace - Codex CLI Plugin
|
||||
|
||||
Give your AI a persistent memory -- mine projects and conversations into a searchable palace backed by ChromaDB, with 19 MCP tools, auto-save hooks, and guided skills.
|
||||
|
||||
## Prerequisites
|
||||
|
||||
- Python 3.9+
|
||||
- Codex CLI installed and configured
|
||||
- `pip install mempalace`
|
||||
|
||||
## Installation
|
||||
|
||||
### Local Install
|
||||
|
||||
1. Copy or symlink the `.codex-plugin` directory into your project root:
|
||||
|
||||
```bash
|
||||
cp -r .codex-plugin /path/to/your/project/.codex-plugin
|
||||
```
|
||||
|
||||
2. Verify the plugin is detected:
|
||||
|
||||
```bash
|
||||
codex --plugins
|
||||
```
|
||||
|
||||
3. Initialize your palace:
|
||||
|
||||
```bash
|
||||
codex /init
|
||||
```
|
||||
|
||||
### Git Install
|
||||
|
||||
1. Clone the MemPalace repository:
|
||||
|
||||
```bash
|
||||
git clone https://github.com/milla-jovovich/mempalace.git
|
||||
cd mempalace
|
||||
```
|
||||
|
||||
2. Install the Python package:
|
||||
|
||||
```bash
|
||||
pip install -e .
|
||||
```
|
||||
|
||||
3. The `.codex-plugin` directory is already in the repo root. Codex CLI will detect it automatically when you run Codex from inside the repository.
|
||||
|
||||
4. Initialize your palace:
|
||||
|
||||
```bash
|
||||
codex /init
|
||||
```
|
||||
|
||||
## Available Skills
|
||||
|
||||
| Skill | Description |
|
||||
|-------|-------------|
|
||||
| `/help` | Show available commands and usage tips |
|
||||
| `/init` | Initialize a new memory palace |
|
||||
| `/search` | Semantic search across all mined memories |
|
||||
| `/mine` | Mine a project or conversation into your palace |
|
||||
| `/status` | Show palace status, room counts, and health |
|
||||
|
||||
## Hooks
|
||||
|
||||
The plugin includes auto-save hooks: a Stop hook that triggers a save every 15 human messages, and a PreCompact hook that triggers a full save before context compaction, so conversation context is preserved in your palace.
|
||||
|
||||
Set the `MEMPAL_DIR` environment variable to a directory path to automatically run `mempalace mine` on that directory during each save trigger.
|
||||
|
||||
## Support
|
||||
|
||||
- Repository: https://github.com/milla-jovovich/mempalace
|
||||
- Issues: https://github.com/milla-jovovich/mempalace/issues
|
||||
@@ -0,0 +1,37 @@
|
||||
{
|
||||
"hooks": {
|
||||
"SessionStart": [
|
||||
{
|
||||
"matcher": "*",
|
||||
"hooks": [
|
||||
{
|
||||
"type": "command",
|
||||
"command": "${CODEX_PLUGIN_ROOT}/hooks/mempal-hook.sh session-start"
|
||||
}
|
||||
]
|
||||
}
|
||||
],
|
||||
"Stop": [
|
||||
{
|
||||
"matcher": "*",
|
||||
"hooks": [
|
||||
{
|
||||
"type": "command",
|
||||
"command": "${CODEX_PLUGIN_ROOT}/hooks/mempal-hook.sh stop"
|
||||
}
|
||||
]
|
||||
}
|
||||
],
|
||||
"PreCompact": [
|
||||
{
|
||||
"matcher": "*",
|
||||
"hooks": [
|
||||
{
|
||||
"type": "command",
|
||||
"command": "${CODEX_PLUGIN_ROOT}/hooks/mempal-hook.sh precompact"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,9 @@
|
||||
#!/usr/bin/env bash
# MemPalace Codex hook wrapper.
#
# Usage: mempal-hook.sh <hook-name>
#
# Buffers the JSON payload from stdin into a temp file, feeds it to
# `python3 -m mempalace hook run --harness codex`, and exits with the
# Python CLI's exit status.
set -euo pipefail

HOOK_NAME="${1:?Usage: mempal-hook.sh <hook-name>}"

INPUT_FILE=$(mktemp) || { echo "Failed to create temp file" >&2; exit 1; }
# Clean up on every exit path. With `set -e` a failing command would
# otherwise terminate the script before an explicit `rm` is reached.
trap 'rm -f "$INPUT_FILE" 2>/dev/null' EXIT

cat > "$INPUT_FILE"

# `|| EXIT_CODE=$?` keeps `set -e` from aborting on a non-zero status so the
# real exit code can be propagated to the harness.
EXIT_CODE=0
python3 -m mempalace hook run --hook "$HOOK_NAME" --harness codex < "$INPUT_FILE" || EXIT_CODE=$?

exit "$EXIT_CODE"
|
||||
@@ -0,0 +1,52 @@
|
||||
{
|
||||
"name": "mempalace",
|
||||
"version": "3.0.14",
|
||||
"description": "Give your AI a memory — mine projects and conversations into a searchable palace. 19 MCP tools, auto-save hooks, and guided setup.",
|
||||
"author": {
|
||||
"name": "milla-jovovich"
|
||||
},
|
||||
"homepage": "https://github.com/milla-jovovich/mempalace",
|
||||
"repository": "https://github.com/milla-jovovich/mempalace",
|
||||
"license": "MIT",
|
||||
"keywords": [
|
||||
"memory",
|
||||
"ai",
|
||||
"rag",
|
||||
"mcp",
|
||||
"chromadb",
|
||||
"palace",
|
||||
"search"
|
||||
],
|
||||
"skills": "./skills/",
|
||||
"hooks": "./hooks.json",
|
||||
"mcpServers": {
|
||||
"mempalace": {
|
||||
"command": "python3",
|
||||
"args": [
|
||||
"-m",
|
||||
"mempalace.mcp_server"
|
||||
]
|
||||
}
|
||||
},
|
||||
"interface": {
|
||||
"displayName": "MemPalace",
|
||||
"shortDescription": "AI memory system for Codex",
|
||||
"longDescription": "Give your AI a persistent memory — mine projects and conversations into a searchable palace backed by ChromaDB, with 19 MCP tools, auto-save hooks, and guided skills.",
|
||||
"developerName": "milla-jovovich",
|
||||
"category": "Coding",
|
||||
"capabilities": [
|
||||
"Interactive",
|
||||
"Read",
|
||||
"Write"
|
||||
],
|
||||
"websiteURL": "https://github.com/milla-jovovich/mempalace",
|
||||
"privacyPolicyURL": "https://github.com/milla-jovovich/mempalace",
|
||||
"termsOfServiceURL": "https://github.com/milla-jovovich/mempalace",
|
||||
"defaultPrompt": [
|
||||
"Search my memories for recent decisions",
|
||||
"Mine this project into my memory palace",
|
||||
"Show my palace status and room counts"
|
||||
],
|
||||
"brandColor": "#7C3AED"
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,13 @@
|
||||
---
|
||||
name: help
|
||||
description: Show MemPalace help — available commands, usage tips, and getting started guidance.
|
||||
allowed-tools: Bash, Read
|
||||
---
|
||||
|
||||
# MemPalace Help
|
||||
|
||||
Run the following command and follow the returned instructions step by step:
|
||||
|
||||
```bash
|
||||
mempalace instructions help
|
||||
```
|
||||
@@ -0,0 +1,13 @@
|
||||
---
|
||||
name: init
|
||||
description: Initialize a new MemPalace — guided setup for your AI memory palace with ChromaDB backend.
|
||||
allowed-tools: Bash, Read, Write, Edit
|
||||
---
|
||||
|
||||
# MemPalace Init
|
||||
|
||||
Run the following command and follow the returned instructions step by step:
|
||||
|
||||
```bash
|
||||
mempalace instructions init
|
||||
```
|
||||
@@ -0,0 +1,13 @@
|
||||
---
|
||||
name: mine
|
||||
description: Mine a project or conversation into your MemPalace — extract and store memories for later retrieval.
|
||||
allowed-tools: Bash, Read, Glob, Grep
|
||||
---
|
||||
|
||||
# MemPalace Mine
|
||||
|
||||
Run the following command and follow the returned instructions step by step:
|
||||
|
||||
```bash
|
||||
mempalace instructions mine
|
||||
```
|
||||
@@ -0,0 +1,13 @@
|
||||
---
|
||||
name: search
|
||||
description: Search your MemPalace — semantic search across all mined memories, projects, and conversations.
|
||||
allowed-tools: Bash, Read
|
||||
---
|
||||
|
||||
# MemPalace Search
|
||||
|
||||
Run the following command and follow the returned instructions step by step:
|
||||
|
||||
```bash
|
||||
mempalace instructions search
|
||||
```
|
||||
@@ -0,0 +1,13 @@
|
||||
---
|
||||
name: status
|
||||
description: Show MemPalace status — room counts, storage usage, and palace health.
|
||||
allowed-tools: Bash, Read
|
||||
---
|
||||
|
||||
# MemPalace Status
|
||||
|
||||
Run the following command and follow the returned instructions step by step:
|
||||
|
||||
```bash
|
||||
mempalace instructions status
|
||||
```
|
||||
@@ -0,0 +1,51 @@
|
||||
name: Bump Version
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [main]
|
||||
|
||||
jobs:
|
||||
bump-version:
|
||||
runs-on: ubuntu-latest
|
||||
permissions:
|
||||
contents: write
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
|
||||
- name: Bump patch version
|
||||
run: |
|
||||
CURRENT=$(python3 -c "exec(open('mempalace/version.py').read()); print(__version__)")
|
||||
IFS='.' read -r MAJOR MINOR PATCH <<< "$CURRENT"
|
||||
PATCH=$((PATCH + 1))
|
||||
NEW="${MAJOR}.${MINOR}.${PATCH}"
|
||||
echo "__version__ = \"${NEW}\"" > mempalace/version.py
|
||||
# Prepend docstring
|
||||
sed -i '1i"""Single source of truth for the MemPalace package version."""\n' mempalace/version.py
|
||||
echo "version=$NEW" >> "$GITHUB_OUTPUT"
|
||||
id: version
|
||||
|
||||
- name: Sync plugin.json
|
||||
run: |
|
||||
jq --arg v "${{ steps.version.outputs.version }}" '.version = $v' .claude-plugin/plugin.json > tmp.json && mv tmp.json .claude-plugin/plugin.json
|
||||
|
||||
- name: Sync marketplace.json
|
||||
run: |
|
||||
jq --arg v "${{ steps.version.outputs.version }}" '.plugins[0].version = $v' .claude-plugin/marketplace.json > tmp.json && mv tmp.json .claude-plugin/marketplace.json
|
||||
|
||||
- name: Sync codex plugin.json
|
||||
run: |
|
||||
jq --arg v "${{ steps.version.outputs.version }}" '.version = $v' .codex-plugin/plugin.json > tmp.json && mv tmp.json .codex-plugin/plugin.json
|
||||
|
||||
- name: Sync pyproject.toml
|
||||
run: |
|
||||
sed -i "s/^version = \".*\"/version = \"${{ steps.version.outputs.version }}\"/" pyproject.toml
|
||||
|
||||
- name: Commit and push
|
||||
run: |
|
||||
git config user.name "github-actions[bot]"
|
||||
git config user.email "github-actions[bot]@users.noreply.github.com"
|
||||
git add mempalace/version.py .claude-plugin/plugin.json .claude-plugin/marketplace.json .codex-plugin/plugin.json pyproject.toml
|
||||
if ! git diff --staged --quiet; then
|
||||
git commit -m "chore: bump version to ${{ steps.version.outputs.version }}"
|
||||
git push
|
||||
fi
|
||||
@@ -7,7 +7,7 @@ on:
|
||||
branches: [main]
|
||||
|
||||
jobs:
|
||||
test:
|
||||
test-linux:
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
matrix:
|
||||
@@ -18,8 +18,27 @@ jobs:
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
- run: pip install -e ".[dev]"
|
||||
- run: python -m pytest tests/ -v
|
||||
- run: python -m pytest tests/ -v --ignore=tests/benchmarks --cov=mempalace --cov-report=term-missing --cov-fail-under=85
|
||||
|
||||
test-windows:
|
||||
runs-on: windows-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
- uses: actions/setup-python@v6
|
||||
with:
|
||||
python-version: "3.9"
|
||||
- run: pip install -e ".[dev]"
|
||||
- run: python -m pytest tests/ -v --ignore=tests/benchmarks --cov=mempalace --cov-report=term-missing --cov-fail-under=85
|
||||
|
||||
test-macos:
|
||||
runs-on: macos-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
- uses: actions/setup-python@v6
|
||||
with:
|
||||
python-version: "3.9"
|
||||
- run: pip install -e ".[dev]"
|
||||
- run: python -m pytest tests/ -v --ignore=tests/benchmarks --cov=mempalace --cov-report=term-missing --cov-fail-under=85
|
||||
lint:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
@@ -27,6 +46,6 @@ jobs:
|
||||
- uses: actions/setup-python@v6
|
||||
with:
|
||||
python-version: "3.11"
|
||||
- run: pip install ruff
|
||||
- run: pip install "ruff>=0.4.0,<0.5"
|
||||
- run: ruff check .
|
||||
- run: ruff format --check .
|
||||
|
||||
@@ -5,3 +5,4 @@ __pycache__/
|
||||
*.pyc
|
||||
.pytest_cache/
|
||||
mempal.yaml
|
||||
.a5c/
|
||||
|
||||
@@ -29,7 +29,7 @@ Other memory systems try to fix this by letting AI decide what's worth rememberi
|
||||
|
||||
<br>
|
||||
|
||||
[Quick Start](#quick-start) · [The Palace](#the-palace) · [AAAK Dialect](#aaak-compression) · [Benchmarks](#benchmarks) · [MCP Tools](#mcp-server)
|
||||
[Quick Start](#quick-start) · [The Palace](#the-palace) · [AAAK Dialect](#aaak-dialect-experimental) · [Benchmarks](#benchmarks) · [MCP Tools](#mcp-server)
|
||||
|
||||
<br>
|
||||
|
||||
@@ -112,6 +112,17 @@ Three mining modes: **projects** (code and docs), **convos** (conversation expor
|
||||
|
||||
After the one-time setup (install → init → mine), you don't run MemPalace commands manually. Your AI uses it for you. There are two ways, depending on which AI you use.
|
||||
|
||||
### With Claude Code (recommended)
|
||||
|
||||
Native marketplace install:
|
||||
|
||||
```bash
|
||||
claude plugin marketplace add milla-jovovich/mempalace
|
||||
claude plugin install --scope user mempalace
|
||||
```
|
||||
|
||||
Restart Claude Code, then type `/skills` to verify "mempalace" appears.
|
||||
|
||||
### With Claude, ChatGPT, Cursor, Gemini (MCP-compatible tools)
|
||||
|
||||
```bash
|
||||
@@ -439,6 +450,11 @@ Letta charges $20–200/mo for agent-managed memory. MemPalace does it with a wi
|
||||
## MCP Server
|
||||
|
||||
```bash
|
||||
# Via plugin (recommended)
|
||||
claude plugin marketplace add milla-jovovich/mempalace
|
||||
claude plugin install --scope user mempalace
|
||||
|
||||
# Or manually
|
||||
claude mcp add mempalace -- python -m mempalace.mcp_server
|
||||
```
|
||||
|
||||
@@ -509,6 +525,8 @@ Two hooks for Claude Code that automatically save memories during work:
|
||||
}
|
||||
```
|
||||
|
||||
**Optional auto-ingest:** Set the `MEMPAL_DIR` environment variable to a directory path and the hooks will automatically run `mempalace mine` on that directory during each save trigger (background on stop, synchronous on precompact).
|
||||
|
||||
---
|
||||
|
||||
## Benchmarks
|
||||
|
||||
+17
-2
@@ -1,6 +1,21 @@
|
||||
"""MemPalace — Give your AI a memory. No API key required."""
|
||||
|
||||
from .cli import main
|
||||
from .version import __version__
|
||||
import logging
|
||||
import os
|
||||
import platform
|
||||
|
||||
from .cli import main # noqa: E402
|
||||
from .version import __version__ # noqa: E402
|
||||
|
||||
# ChromaDB 0.6.x ships a Posthog telemetry client whose capture() signature is
|
||||
# incompatible with the bundled posthog library, producing noisy stderr warnings
|
||||
# on every client operation ("Failed to send telemetry event … capture() takes
|
||||
# 1 positional argument but 3 were given"). Silence just that logger.
|
||||
logging.getLogger("chromadb.telemetry.product.posthog").setLevel(logging.CRITICAL)
|
||||
|
||||
# ONNX Runtime's CoreML provider segfaults during vector queries on Apple Silicon.
|
||||
# Force CPU execution unless the user has explicitly set a preference.
|
||||
if platform.machine() == "arm64" and platform.system() == "Darwin":
|
||||
os.environ.setdefault("ORT_DISABLE_COREML", "1")
|
||||
|
||||
__all__ = ["main", "__version__"]
|
||||
|
||||
@@ -226,6 +226,20 @@ def cmd_repair(args):
|
||||
print(f"\n{'=' * 55}\n")
|
||||
|
||||
|
||||
def cmd_hook(args):
    """Handle the ``hook`` subcommand.

    Forwards ``args.hook`` and ``args.harness`` to ``hooks_cli.run_hook``,
    which reads the hook payload from stdin and prints the response.
    """
    # Function-scope import, as in the rest of this module's command handlers.
    from . import hooks_cli

    hooks_cli.run_hook(hook_name=args.hook, harness=args.harness)
|
||||
|
||||
|
||||
def cmd_instructions(args):
    """Handle the ``instructions`` subcommand.

    Forwards ``args.name`` to ``instructions_cli.run_instructions``.
    """
    # Function-scope import, as in the rest of this module's command handlers.
    from . import instructions_cli

    instructions_cli.run_instructions(name=args.name)
|
||||
|
||||
|
||||
def cmd_compress(args):
|
||||
"""Compress drawers in a wing using AAAK Dialect."""
|
||||
import chromadb
|
||||
@@ -451,6 +465,35 @@ def main():
|
||||
help="Only split files containing at least N sessions (default: 2)",
|
||||
)
|
||||
|
||||
# hook
|
||||
p_hook = sub.add_parser(
|
||||
"hook",
|
||||
help="Run hook logic (reads JSON from stdin, outputs JSON to stdout)",
|
||||
)
|
||||
hook_sub = p_hook.add_subparsers(dest="hook_action")
|
||||
p_hook_run = hook_sub.add_parser("run", help="Execute a hook")
|
||||
p_hook_run.add_argument(
|
||||
"--hook",
|
||||
required=True,
|
||||
choices=["session-start", "stop", "precompact"],
|
||||
help="Hook name to run",
|
||||
)
|
||||
p_hook_run.add_argument(
|
||||
"--harness",
|
||||
required=True,
|
||||
choices=["claude-code", "codex"],
|
||||
help="Harness type (determines stdin JSON format)",
|
||||
)
|
||||
|
||||
# instructions
|
||||
p_instructions = sub.add_parser(
|
||||
"instructions",
|
||||
help="Output skill instructions to stdout",
|
||||
)
|
||||
instructions_sub = p_instructions.add_subparsers(dest="instructions_name")
|
||||
for instr_name in ["init", "search", "mine", "help", "status"]:
|
||||
instructions_sub.add_parser(instr_name, help=f"Output {instr_name} instructions")
|
||||
|
||||
# repair
|
||||
sub.add_parser(
|
||||
"repair",
|
||||
@@ -466,6 +509,23 @@ def main():
|
||||
parser.print_help()
|
||||
return
|
||||
|
||||
# Handle two-level subcommands
|
||||
if args.command == "hook":
|
||||
if not getattr(args, "hook_action", None):
|
||||
p_hook.print_help()
|
||||
return
|
||||
cmd_hook(args)
|
||||
return
|
||||
|
||||
if args.command == "instructions":
|
||||
name = getattr(args, "instructions_name", None)
|
||||
if not name:
|
||||
p_instructions.print_help()
|
||||
return
|
||||
args.name = name
|
||||
cmd_instructions(args)
|
||||
return
|
||||
|
||||
dispatch = {
|
||||
"init": cmd_init,
|
||||
"mine": cmd_mine,
|
||||
|
||||
@@ -309,7 +309,7 @@ class EntityRegistry:
|
||||
|
||||
def save(self):
|
||||
self._path.parent.mkdir(parents=True, exist_ok=True)
|
||||
self._path.write_text(json.dumps(self._data, indent=2))
|
||||
self._path.write_text(json.dumps(self._data, indent=2), encoding="utf-8")
|
||||
|
||||
@staticmethod
|
||||
def _empty() -> dict:
|
||||
|
||||
@@ -0,0 +1,226 @@
|
||||
"""
|
||||
Hook logic for MemPalace — Python implementation of session-start, stop, and precompact hooks.
|
||||
|
||||
Reads JSON from stdin, outputs JSON to stdout.
|
||||
Supported hooks: session-start, stop, precompact
|
||||
Supported harnesses: claude-code, codex (extensible to cursor, gemini, etc.)
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import subprocess
|
||||
import sys
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
|
||||
SAVE_INTERVAL = 15
|
||||
STATE_DIR = Path.home() / ".mempalace" / "hook_state"
|
||||
|
||||
STOP_BLOCK_REASON = (
|
||||
"AUTO-SAVE checkpoint. Save key topics, decisions, quotes, and code "
|
||||
"from this session to your memory system. Organize into appropriate "
|
||||
"categories. Use verbatim quotes where possible. Continue conversation "
|
||||
"after saving."
|
||||
)
|
||||
|
||||
PRECOMPACT_BLOCK_REASON = (
|
||||
"COMPACTION IMMINENT. Save ALL topics, decisions, quotes, code, and "
|
||||
"important context from this session to your memory system. Be thorough "
|
||||
"\u2014 after compaction, detailed context will be lost. Organize into "
|
||||
"appropriate categories. Use verbatim quotes where possible. Save "
|
||||
"everything, then allow compaction to proceed."
|
||||
)
|
||||
|
||||
|
||||
def _sanitize_session_id(session_id: str) -> str:
|
||||
"""Only allow alnum, dash, underscore to prevent path traversal."""
|
||||
sanitized = re.sub(r"[^a-zA-Z0-9_-]", "", session_id)
|
||||
return sanitized or "unknown"
|
||||
|
||||
|
||||
def _count_human_messages(transcript_path: str) -> int:
|
||||
"""Count human messages in a JSONL transcript, skipping command-messages."""
|
||||
path = Path(transcript_path).expanduser()
|
||||
if not path.is_file():
|
||||
return 0
|
||||
count = 0
|
||||
try:
|
||||
with open(path, encoding="utf-8", errors="replace") as f:
|
||||
for line in f:
|
||||
try:
|
||||
entry = json.loads(line)
|
||||
msg = entry.get("message", {})
|
||||
if isinstance(msg, dict) and msg.get("role") == "user":
|
||||
content = msg.get("content", "")
|
||||
if isinstance(content, str):
|
||||
if "<command-message>" in content:
|
||||
continue
|
||||
elif isinstance(content, list):
|
||||
text = " ".join(
|
||||
b.get("text", "") for b in content if isinstance(b, dict)
|
||||
)
|
||||
if "<command-message>" in text:
|
||||
continue
|
||||
count += 1
|
||||
except (json.JSONDecodeError, AttributeError):
|
||||
pass
|
||||
except OSError:
|
||||
return 0
|
||||
return count
|
||||
|
||||
|
||||
def _log(message: str):
    """Append a timestamped line to ``STATE_DIR / "hook.log"`` (best effort).

    Creates ``STATE_DIR`` if needed. Each line is prefixed with the local
    ``HH:MM:SS`` time. Filesystem errors are swallowed so that logging can
    never break a hook.

    Args:
        message: Text to append; a trailing newline is added.
    """
    try:
        STATE_DIR.mkdir(parents=True, exist_ok=True)
        log_path = STATE_DIR / "hook.log"
        timestamp = datetime.now().strftime("%H:%M:%S")
        # Explicit UTF-8 with replacement: under a narrow locale default
        # encoding, a non-encodable message would raise UnicodeEncodeError,
        # which the OSError handler below does not cover.
        with open(log_path, "a", encoding="utf-8", errors="replace") as f:
            f.write(f"[{timestamp}] {message}\n")
    except OSError:
        pass
|
||||
|
||||
|
||||
def _output(data: dict):
|
||||
"""Print JSON to stdout with consistent formatting (pretty-printed)."""
|
||||
print(json.dumps(data, indent=2, ensure_ascii=False))
|
||||
|
||||
|
||||
def _maybe_auto_ingest():
    """Start ``mempalace mine`` on ``MEMPAL_DIR`` without waiting for it.

    Does nothing unless the ``MEMPAL_DIR`` environment variable names an
    existing directory. The child's stdout and stderr are appended to
    ``STATE_DIR / "hook.log"``. Any ``OSError`` (log file or process spawn)
    is ignored.
    """
    target = os.environ.get("MEMPAL_DIR", "")
    if not target or not os.path.isdir(target):
        return

    mine_cmd = [sys.executable, "-m", "mempalace", "mine", target]
    try:
        with open(STATE_DIR / "hook.log", "a") as sink:
            # Popen returns immediately; the mine runs in the background.
            subprocess.Popen(mine_cmd, stdout=sink, stderr=sink)
    except OSError:
        pass
|
||||
|
||||
|
||||
SUPPORTED_HARNESSES = {"claude-code", "codex"}
|
||||
|
||||
|
||||
def _parse_harness_input(data: dict, harness: str) -> dict:
    """Extract the fields the hooks need from the harness payload.

    Exits the process with status 1 (after a message on stderr) when
    *harness* is not in ``SUPPORTED_HARNESSES``.

    Returns:
        A dict with ``session_id`` (sanitized, default ``"unknown"``),
        ``stop_hook_active`` (raw value, default ``False``) and
        ``transcript_path`` (as a string, default ``""``).
    """
    if harness in SUPPORTED_HARNESSES:
        raw_session = data.get("session_id", "unknown")
        fields = {
            "session_id": _sanitize_session_id(str(raw_session)),
            "stop_hook_active": data.get("stop_hook_active", False),
            "transcript_path": str(data.get("transcript_path", "")),
        }
        return fields

    print(f"Unknown harness: {harness}", file=sys.stderr)
    sys.exit(1)
|
||||
|
||||
|
||||
def hook_stop(data: dict, harness: str):
    """Stop hook: request an auto-save every ``SAVE_INTERVAL`` human messages.

    Prints ``{}`` (let the stop through) or a ``{"decision": "block", ...}``
    object carrying ``STOP_BLOCK_REASON``. The human-message count at the
    last save is kept per session in ``STATE_DIR / "<session_id>_last_save"``.
    """
    fields = _parse_harness_input(data, harness)
    sid = fields["session_id"]

    # Already inside a save cycle: pass through so the hook cannot loop.
    already_saving = str(fields["stop_hook_active"]).lower() in ("true", "1", "yes")
    if already_saving:
        _output({})
        return

    n_human = _count_human_messages(fields["transcript_path"])

    # Load the count recorded at the previous save (0 if absent or unreadable).
    STATE_DIR.mkdir(parents=True, exist_ok=True)
    marker = STATE_DIR / f"{sid}_last_save"
    previous = 0
    if marker.is_file():
        try:
            previous = int(marker.read_text().strip())
        except (ValueError, OSError):
            previous = 0

    delta = n_human - previous
    _log(f"Session {sid}: {n_human} exchanges, {delta} since last save")

    if n_human <= 0 or delta < SAVE_INTERVAL:
        _output({})
        return

    # Record the new save point; failure to persist it is tolerated.
    try:
        marker.write_text(str(n_human), encoding="utf-8")
    except OSError:
        pass

    _log(f"TRIGGERING SAVE at exchange {n_human}")

    # Optional background mine when MEMPAL_DIR is set.
    _maybe_auto_ingest()

    _output({"decision": "block", "reason": STOP_BLOCK_REASON})
|
||||
|
||||
|
||||
def hook_session_start(data: dict, harness: str):
    """Session-start hook: log the session, ensure ``STATE_DIR`` exists, never block."""
    sid = _parse_harness_input(data, harness)["session_id"]
    _log(f"SESSION START for session {sid}")

    # Make sure the state directory is present for later hooks.
    STATE_DIR.mkdir(parents=True, exist_ok=True)

    # Empty object = pass through.
    _output({})
|
||||
|
||||
|
||||
def hook_precompact(data: dict, harness: str):
    """Precompact hook: always block with a comprehensive save instruction.

    When the ``MEMPAL_DIR`` environment variable names an existing directory,
    ``mempalace mine`` is first run on it synchronously (60 s limit) with its
    output appended to ``STATE_DIR / "hook.log"``. A failed or timed-out mine
    never prevents the block decision from being printed.

    Args:
        data: Parsed stdin payload from the harness.
        harness: Harness name; validated by ``_parse_harness_input``.
    """
    parsed = _parse_harness_input(data, harness)
    session_id = parsed["session_id"]

    _log(f"PRE-COMPACT triggered for session {session_id}")

    # Optional: auto-ingest synchronously before compaction (so memories land first)
    mempal_dir = os.environ.get("MEMPAL_DIR", "")
    if mempal_dir and os.path.isdir(mempal_dir):
        try:
            log_path = STATE_DIR / "hook.log"
            with open(log_path, "a") as log_f:
                subprocess.run(
                    [sys.executable, "-m", "mempalace", "mine", mempal_dir],
                    stdout=log_f,
                    stderr=log_f,
                    timeout=60,
                )
        except subprocess.TimeoutExpired:
            # TimeoutExpired is not an OSError; left uncaught it would abort
            # the hook before the block decision below is emitted.
            _log("WARNING: auto-ingest timed out after 60s, continuing")
        except OSError:
            pass

    # Always block -- compaction = save everything
    _output({"decision": "block", "reason": PRECOMPACT_BLOCK_REASON})
|
||||
|
||||
|
||||
def run_hook(hook_name: str, harness: str):
    """Main entry point: read stdin JSON, dispatch to hook handler.

    Stdin that cannot be decoded or parsed as JSON, or that parses to
    something other than a JSON object, is replaced by an empty dict (with a
    warning logged) so the handler always receives a ``dict``.

    Args:
        hook_name: One of ``"session-start"``, ``"stop"`` or ``"precompact"``.
        harness: Harness identifier, passed through to the handler.

    An unknown ``hook_name`` prints a message to stderr and exits with
    status 1.
    """
    try:
        data = json.load(sys.stdin)
    except (json.JSONDecodeError, UnicodeDecodeError, EOFError):
        _log("WARNING: Failed to parse stdin JSON, proceeding with empty data")
        data = {}

    # Handlers are annotated to take a dict; valid JSON such as a list or a
    # bare string must not reach them.
    if not isinstance(data, dict):
        _log("WARNING: stdin JSON is not an object, proceeding with empty data")
        data = {}

    hooks = {
        "session-start": hook_session_start,
        "stop": hook_stop,
        "precompact": hook_precompact,
    }

    handler = hooks.get(hook_name)
    if handler is None:
        print(f"Unknown hook: {hook_name}", file=sys.stderr)
        sys.exit(1)

    handler(data, harness)
|
||||
@@ -0,0 +1,105 @@
|
||||
# MemPalace
|
||||
|
||||
AI memory system. Store everything, find anything. Local, free, no API key.
|
||||
|
||||
---
|
||||
|
||||
## Slash Commands
|
||||
|
||||
| Command | Description |
|
||||
|----------------------|--------------------------------|
|
||||
| /mempalace:init | Install and set up MemPalace |
|
||||
| /mempalace:search | Search your memories |
|
||||
| /mempalace:mine | Mine projects and conversations|
|
||||
| /mempalace:status | Palace overview and stats |
|
||||
| /mempalace:help | This help message |
|
||||
|
||||
---
|
||||
|
||||
## MCP Tools (19)
|
||||
|
||||
### Palace (read)
|
||||
- mempalace_status -- Palace status and stats
|
||||
- mempalace_list_wings -- List all wings
|
||||
- mempalace_list_rooms -- List rooms in a wing
|
||||
- mempalace_get_taxonomy -- Get the full taxonomy tree
|
||||
- mempalace_search -- Search memories by query
|
||||
- mempalace_check_duplicate -- Check if a memory already exists
|
||||
- mempalace_get_aaak_spec -- Get the AAAK specification
|
||||
|
||||
### Palace (write)
|
||||
- mempalace_add_drawer -- Add a new memory (drawer)
|
||||
- mempalace_delete_drawer -- Delete a memory (drawer)
|
||||
|
||||
### Knowledge Graph
|
||||
- mempalace_kg_query -- Query the knowledge graph
|
||||
- mempalace_kg_add -- Add a knowledge graph entry
|
||||
- mempalace_kg_invalidate -- Invalidate a knowledge graph entry
|
||||
- mempalace_kg_timeline -- View knowledge graph timeline
|
||||
- mempalace_kg_stats -- Knowledge graph statistics
|
||||
|
||||
### Navigation
|
||||
- mempalace_traverse -- Traverse the palace structure
|
||||
- mempalace_find_tunnels -- Find cross-wing connections
|
||||
- mempalace_graph_stats -- Graph connectivity statistics
|
||||
|
||||
### Agent Diary
|
||||
- mempalace_diary_write -- Write a diary entry
|
||||
- mempalace_diary_read -- Read diary entries
|
||||
|
||||
---
|
||||
|
||||
## CLI Commands
|
||||
|
||||
mempalace init <dir> Initialize a new palace
|
||||
mempalace mine <dir> Mine a project (default mode)
|
||||
mempalace mine <dir> --mode convos Mine conversation exports
|
||||
mempalace search "query" Search your memories
|
||||
mempalace split <dir> Split large transcript files
|
||||
mempalace wake-up Load palace into context
|
||||
mempalace compress Compress palace storage
|
||||
mempalace status Show palace status
|
||||
mempalace repair Rebuild vector index
|
||||
mempalace hook run Run hook logic (for harness integration)
|
||||
mempalace instructions <name> Output skill instructions
|
||||
|
||||
---
|
||||
|
||||
## Auto-Save Hooks
|
||||
|
||||
- Stop hook -- Automatically saves memories every 15 messages. Counts human
|
||||
messages in the session transcript (skipping command-messages). When the
|
||||
threshold is reached, blocks the AI with a save instruction. Uses
|
||||
~/.mempalace/hook_state/ to track save points per session. If
|
||||
stop_hook_active is true, passes through to prevent infinite loops.
|
||||
|
||||
- PreCompact hook -- Emergency save before context compaction. Always blocks
|
||||
with a comprehensive save instruction because compaction means the AI is
|
||||
about to lose detailed context.
|
||||
|
||||
Hooks read JSON from stdin and output JSON to stdout. They can be invoked via:
|
||||
|
||||
echo '{"session_id":"abc","stop_hook_active":false,"transcript_path":"..."}' | mempalace hook run --hook stop --harness claude-code
|
||||
|
||||
---
|
||||
|
||||
## Architecture
|
||||
|
||||
Wings (projects/people)
|
||||
+-- Rooms (topics)
|
||||
+-- Closets (summaries)
|
||||
+-- Drawers (verbatim memories)
|
||||
|
||||
Halls connect rooms within a wing.
|
||||
Tunnels connect rooms across wings.
|
||||
|
||||
The palace is stored locally using ChromaDB for vector search and SQLite for
|
||||
metadata. No cloud services or API keys required.
|
||||
|
||||
---
|
||||
|
||||
## Getting Started
|
||||
|
||||
1. /mempalace:init -- Set up your palace
|
||||
2. /mempalace:mine -- Mine a project or conversation
|
||||
3. /mempalace:search -- Find what you stored
|
||||
@@ -0,0 +1,69 @@
|
||||
# MemPalace Init
|
||||
|
||||
Guide the user through a complete MemPalace setup. Follow each step in order,
|
||||
stopping to report errors and attempt remediation before proceeding.
|
||||
|
||||
## Step 1: Check Python version
|
||||
|
||||
Run `python3 --version` (or `python --version` on Windows) and confirm the
|
||||
version is 3.9 or higher. If Python is not found or the version is too old,
|
||||
tell the user they need Python 3.9+ installed and stop.
|
||||
|
||||
## Step 2: Check if mempalace is already installed
|
||||
|
||||
Run `pip show mempalace` to see if the package is already present. If it is,
|
||||
report the installed version and skip to Step 4.
|
||||
|
||||
## Step 3: Install mempalace
|
||||
|
||||
Run `pip install mempalace`.
|
||||
|
||||
### Error handling -- pip failures
|
||||
|
||||
If `pip install mempalace` fails, try these fallbacks in order:
|
||||
|
||||
1. Try `pip3 install mempalace`
|
||||
2. Try `python -m pip install mempalace` (or `python3 -m pip install mempalace`)
|
||||
3. If the error mentions missing build tools or compilation failures (commonly
|
||||
from chromadb or its native dependencies):
|
||||
- On Linux/macOS: suggest `sudo apt-get install build-essential python3-dev`
|
||||
(Debian/Ubuntu) or `xcode-select --install` (macOS)
|
||||
- On Windows: suggest installing Microsoft C++ Build Tools from
|
||||
https://visualstudio.microsoft.com/visual-cpp-build-tools/
|
||||
- Then retry the install command
|
||||
4. If all attempts fail, report the error clearly and stop.
|
||||
|
||||
## Step 4: Ask for project directory
|
||||
|
||||
Ask the user which project directory they want to initialize with MemPalace.
|
||||
Offer the current working directory as the default. Wait for their response
|
||||
before continuing.
|
||||
|
||||
## Step 5: Initialize the palace
|
||||
|
||||
Run `mempalace init <dir>` where `<dir>` is the directory from Step 4.
|
||||
|
||||
If this fails, report the error and stop.
|
||||
|
||||
## Step 6: Configure MCP server
|
||||
|
||||
Run the following command to register the MemPalace MCP server with Claude:
|
||||
|
||||
claude mcp add mempalace -- python -m mempalace.mcp_server
|
||||
|
||||
If this fails, report the error but continue to the next step (MCP
|
||||
configuration can be done manually later).
|
||||
|
||||
## Step 7: Verify installation
|
||||
|
||||
Run `mempalace status` and confirm the output shows a healthy palace.
|
||||
|
||||
If the command fails or reports errors, walk the user through troubleshooting
|
||||
based on the output.
|
||||
|
||||
## Step 8: Show next steps
|
||||
|
||||
Tell the user setup is complete and suggest these next actions:
|
||||
|
||||
- Use /mempalace:mine to start adding data to their palace
|
||||
- Use /mempalace:search to query their palace and retrieve stored knowledge
|
||||
@@ -0,0 +1,64 @@
|
||||
# MemPalace Mine
|
||||
|
||||
When the user invokes this skill, follow these steps:
|
||||
|
||||
## 1. Ask what to mine
|
||||
|
||||
Ask the user what they want to mine and where the source data is located.
|
||||
Clarify:
|
||||
- Is it a project directory (code, docs, notes)?
|
||||
- Is it conversation exports (Claude, ChatGPT, Slack)?
|
||||
- Do they want auto-classification (decisions, milestones, problems)?
|
||||
|
||||
## 2. Choose the mining mode
|
||||
|
||||
There are three mining modes:
|
||||
|
||||
### Project mining
|
||||
|
||||
mempalace mine <dir>
|
||||
|
||||
Mines code files, documentation, and notes from a project directory.
|
||||
|
||||
### Conversation mining
|
||||
|
||||
mempalace mine <dir> --mode convos
|
||||
|
||||
Mines conversation exports from Claude, ChatGPT, or Slack into the palace.
|
||||
|
||||
### General extraction (auto-classify)
|
||||
|
||||
mempalace mine <dir> --mode convos --extract general
|
||||
|
||||
Auto-classifies mined content into decisions, milestones, and problems.
|
||||
|
||||
## 3. Optionally split mega-files first
|
||||
|
||||
If the source directory contains very large files, suggest splitting them
|
||||
before mining:
|
||||
|
||||
mempalace split <dir> [--dry-run]
|
||||
|
||||
Use --dry-run first to preview what will be split without making changes.
|
||||
|
||||
## 4. Optionally tag with a wing
|
||||
|
||||
If the user wants to organize mined content under a specific wing, add the
|
||||
--wing flag:
|
||||
|
||||
mempalace mine <dir> --wing <name>
|
||||
|
||||
## 5. Show progress and results
|
||||
|
||||
Run the selected mining command and display progress as it executes. After
|
||||
completion, summarize the results including:
|
||||
- Number of items mined
|
||||
- Categories or classifications applied
|
||||
- Any warnings or skipped files
|
||||
|
||||
## 6. Suggest next steps
|
||||
|
||||
After mining completes, suggest the user try:
|
||||
- /mempalace:search -- search the newly mined content
|
||||
- /mempalace:status -- check the current state of their palace
|
||||
- Mine more data from additional sources
|
||||
@@ -0,0 +1,57 @@
|
||||
# MemPalace Search
|
||||
|
||||
When the user wants to search their MemPalace memories, follow these steps:
|
||||
|
||||
## 1. Parse the Search Query
|
||||
|
||||
Extract the core search intent from the user's message. Identify any explicit
|
||||
or implicit filters:
|
||||
- Wing -- a top-level category (e.g., "work", "personal", "research")
|
||||
- Room -- a sub-category within a wing
|
||||
- Keywords / semantic query -- the actual search terms
|
||||
|
||||
## 2. Determine Wing/Room Filters
|
||||
|
||||
If the user mentions a specific domain, topic area, or context, map it to the
|
||||
appropriate wing and/or room. If unsure, omit filters to search globally. You
|
||||
can discover the taxonomy first if needed.
|
||||
|
||||
## 3. Use MCP Tools (Preferred)
|
||||
|
||||
If MCP tools are available, use them in this priority order:
|
||||
|
||||
- mempalace_search(query, wing, room) -- Primary search tool. Pass the semantic
|
||||
query and any wing/room filters.
|
||||
- mempalace_list_wings -- Discover all available wings. Use when the user asks
|
||||
what categories exist or you need to resolve a wing name.
|
||||
- mempalace_list_rooms(wing) -- List rooms within a specific wing. Use to help
|
||||
the user navigate or to resolve a room name.
|
||||
- mempalace_get_taxonomy -- Retrieve the full wing/room/drawer tree. Use when
|
||||
the user wants an overview of their entire memory structure.
|
||||
- mempalace_traverse(room) -- Walk the knowledge graph starting from a room.
|
||||
Use when the user wants to explore connections and related memories.
|
||||
- mempalace_find_tunnels(wing1, wing2) -- Find cross-wing connections (tunnels)
|
||||
between two wings. Use when the user asks about relationships between
|
||||
different knowledge domains.
|
||||
|
||||
## 4. CLI Fallback
|
||||
|
||||
If MCP tools are not available, fall back to the CLI:
|
||||
|
||||
mempalace search "query" [--wing X] [--room Y]
|
||||
|
||||
## 5. Present Results
|
||||
|
||||
When presenting search results:
|
||||
- Always include source attribution: wing, room, and drawer for each result
|
||||
- Show relevance or similarity scores if available
|
||||
- Group results by wing/room when returning multiple hits
|
||||
- Quote or summarize the memory content clearly
|
||||
|
||||
## 6. Offer Next Steps
|
||||
|
||||
After presenting results, offer the user options to go deeper:
|
||||
- Drill deeper -- search within a specific room or narrow the query
|
||||
- Traverse -- explore the knowledge graph from a related room
|
||||
- Check tunnels -- look for cross-wing connections if the topic spans domains
|
||||
- Browse taxonomy -- show the full structure for manual exploration
|
||||
@@ -0,0 +1,49 @@
|
||||
# MemPalace Status
|
||||
|
||||
Display the current state of the user's memory palace.
|
||||
|
||||
## Step 1: Gather Palace Status
|
||||
|
||||
Check if MCP tools are available (look for mempalace_status in available tools).
|
||||
|
||||
- If MCP is available: Call the mempalace_status tool to retrieve palace state.
|
||||
- If MCP is not available: Run the CLI command: mempalace status
|
||||
|
||||
## Step 2: Display Wing/Room/Drawer Counts
|
||||
|
||||
Present the palace structure counts clearly:
|
||||
- Number of wings
|
||||
- Number of rooms
|
||||
- Number of drawers
|
||||
- Total memories stored
|
||||
|
||||
Keep the output concise -- use a brief summary format, not verbose tables.
|
||||
|
||||
## Step 3: Knowledge Graph Stats (MCP only)
|
||||
|
||||
If MCP tools are available, also call:
|
||||
- mempalace_kg_stats -- for a knowledge graph overview (triple count, entity
|
||||
count, relationship types)
|
||||
- mempalace_graph_stats -- for connectivity information (connected components,
|
||||
average connections per entity)
|
||||
|
||||
Present these alongside the palace counts in a unified summary.
|
||||
|
||||
## Step 4: Suggest Next Actions
|
||||
|
||||
Based on the current state, suggest one relevant action:
|
||||
|
||||
- Empty palace (zero memories): Suggest "Try /mempalace:mine to add data from
|
||||
project directories or conversation exports."
|
||||
- Has data but no knowledge graph (memories exist but KG stats show zero
|
||||
triples): Suggest "Consider adding knowledge graph triples for richer
|
||||
queries."
|
||||
- Healthy palace (has memories and KG data): Suggest "Use /mempalace:search to
|
||||
query your memories."
|
||||
|
||||
## Output Style
|
||||
|
||||
- Be concise and informative -- aim for a quick glance, not a report.
|
||||
- Use short labels and numbers, not prose paragraphs.
|
||||
- If any step fails or a tool is unavailable, note it briefly and continue
|
||||
with what is available.
|
||||
@@ -0,0 +1,28 @@
|
||||
"""
|
||||
Instruction text output for MemPalace CLI commands.
|
||||
|
||||
Each instruction lives as a .md file in the instructions/ directory
|
||||
inside the package. The CLI reads and prints the file content.
|
||||
"""
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
INSTRUCTIONS_DIR = Path(__file__).parent / "instructions"
|
||||
|
||||
AVAILABLE = ["init", "search", "mine", "help", "status"]
|
||||
|
||||
|
||||
def run_instructions(name: str):
    """Read and print the instruction .md file for the given name.

    Args:
        name: Instruction set to print; must be one of ``AVAILABLE``.

    Exits with status 1, after writing a message to stderr, when ``name`` is
    not in ``AVAILABLE`` or when the matching ``<name>.md`` file does not
    exist in ``INSTRUCTIONS_DIR``.
    """
    if name not in AVAILABLE:
        print(f"Unknown instructions: {name}", file=sys.stderr)
        print(f"Available: {', '.join(sorted(AVAILABLE))}", file=sys.stderr)
        sys.exit(1)

    md_path = INSTRUCTIONS_DIR / f"{name}.md"
    if not md_path.is_file():
        print(f"Instructions file not found: {md_path}", file=sys.stderr)
        sys.exit(1)

    # Decode explicitly as UTF-8 instead of the platform's locale encoding,
    # so the output does not depend on the host configuration.
    print(md_path.read_text(encoding="utf-8"))
|
||||
+47
-17
@@ -2,7 +2,7 @@
|
||||
"""
|
||||
MemPalace MCP Server — read/write palace access for Claude Code
|
||||
================================================================
|
||||
Install: claude mcp add mempalace -- python -m mempalace.mcp_server
|
||||
Install: claude mcp add mempalace -- python -m mempalace.mcp_server [--palace /path/to/palace]
|
||||
|
||||
Tools (read):
|
||||
mempalace_status — total drawers, wing/room breakdown
|
||||
@@ -17,6 +17,8 @@ Tools (write):
|
||||
mempalace_delete_drawer — remove a drawer by ID
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import sys
|
||||
import json
|
||||
import logging
|
||||
@@ -32,21 +34,50 @@ import chromadb
|
||||
|
||||
from .knowledge_graph import KnowledgeGraph
|
||||
|
||||
_kg = KnowledgeGraph()
|
||||
|
||||
logging.basicConfig(level=logging.INFO, format="%(message)s", stream=sys.stderr)
|
||||
logger = logging.getLogger("mempalace_mcp")
|
||||
|
||||
|
||||
def _parse_args():
    """Parse the MCP server's command-line options.

    Only ``--palace PATH`` is recognised. Any other arguments are tolerated
    and merely reported through ``logger`` at debug level.

    Returns:
        The ``argparse.Namespace`` holding the recognised options.
    """
    cli = argparse.ArgumentParser(description="MemPalace MCP Server")
    cli.add_argument(
        "--palace",
        metavar="PATH",
        help="Path to the palace directory (overrides config file and env var)",
    )
    known, extras = cli.parse_known_args()
    if extras:
        logger.debug("Ignoring unknown args: %s", extras)
    return known
|
||||
|
||||
|
||||
_args = _parse_args()
|
||||
|
||||
if _args.palace:
|
||||
os.environ["MEMPALACE_PALACE_PATH"] = os.path.abspath(_args.palace)
|
||||
|
||||
_config = MempalaceConfig()
|
||||
if _args.palace:
|
||||
_kg = KnowledgeGraph(db_path=os.path.join(_config.palace_path, "knowledge_graph.sqlite3"))
|
||||
else:
|
||||
_kg = KnowledgeGraph()
|
||||
|
||||
|
||||
_client_cache = None
|
||||
_collection_cache = None
|
||||
|
||||
|
||||
def _get_collection(create=False):
|
||||
"""Return the ChromaDB collection, or None on failure."""
|
||||
"""Return the ChromaDB collection, caching the client between calls."""
|
||||
global _client_cache, _collection_cache
|
||||
try:
|
||||
client = chromadb.PersistentClient(path=_config.palace_path)
|
||||
if _client_cache is None:
|
||||
_client_cache = chromadb.PersistentClient(path=_config.palace_path)
|
||||
if create:
|
||||
return client.get_or_create_collection(_config.collection_name)
|
||||
return client.get_collection(_config.collection_name)
|
||||
_collection_cache = _client_cache.get_or_create_collection(_config.collection_name)
|
||||
elif _collection_cache is None:
|
||||
_collection_cache = _client_cache.get_collection(_config.collection_name)
|
||||
return _collection_cache
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
@@ -270,19 +301,18 @@ def tool_add_drawer(
|
||||
if not col:
|
||||
return _no_palace()
|
||||
|
||||
# Duplicate check
|
||||
dup = tool_check_duplicate(content, threshold=0.9)
|
||||
if dup.get("is_duplicate"):
|
||||
return {
|
||||
"success": False,
|
||||
"reason": "duplicate",
|
||||
"matches": dup["matches"],
|
||||
}
|
||||
drawer_id = f"drawer_{wing}_{room}_{hashlib.md5(content.encode()).hexdigest()[:16]}"
|
||||
|
||||
drawer_id = f"drawer_{wing}_{room}_{hashlib.md5((content[:100] + datetime.now().isoformat()).encode()).hexdigest()[:16]}"
|
||||
# Idempotency: if the deterministic ID already exists, return success as a no-op.
|
||||
try:
|
||||
existing = col.get(ids=[drawer_id])
|
||||
if existing and existing["ids"]:
|
||||
return {"success": True, "reason": "already_exists", "drawer_id": drawer_id}
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
try:
|
||||
col.add(
|
||||
col.upsert(
|
||||
ids=[drawer_id],
|
||||
documents=[content],
|
||||
metadatas=[
|
||||
|
||||
+38
-25
@@ -403,10 +403,22 @@ def get_collection(palace_path: str):
|
||||
|
||||
|
||||
def file_already_mined(collection, source_file: str) -> bool:
|
||||
"""Fast check: has this file been filed before?"""
|
||||
"""Fast check: has this file been filed before and is unchanged?
|
||||
|
||||
Compares the stored mtime in drawer metadata against the file's current
|
||||
mtime. Returns False (needs re-mining) when the file has been modified
|
||||
since it was last mined, or when no mtime was stored.
|
||||
"""
|
||||
try:
|
||||
results = collection.get(where={"source_file": source_file}, limit=1)
|
||||
return len(results.get("ids", [])) > 0
|
||||
if not results.get("ids"):
|
||||
return False
|
||||
stored_meta = results["metadatas"][0] if results.get("metadatas") else {}
|
||||
stored_mtime = stored_meta.get("source_mtime")
|
||||
if stored_mtime is None:
|
||||
return False
|
||||
current_mtime = os.path.getmtime(source_file)
|
||||
return float(stored_mtime) == current_mtime
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
@@ -417,24 +429,26 @@ def add_drawer(
|
||||
"""Add one drawer to the palace."""
|
||||
drawer_id = f"drawer_{wing}_{room}_{hashlib.md5((source_file + str(chunk_index)).encode(), usedforsecurity=False).hexdigest()[:16]}"
|
||||
try:
|
||||
collection.add(
|
||||
metadata = {
|
||||
"wing": wing,
|
||||
"room": room,
|
||||
"source_file": source_file,
|
||||
"chunk_index": chunk_index,
|
||||
"added_by": agent,
|
||||
"filed_at": datetime.now().isoformat(),
|
||||
}
|
||||
# Store file mtime so we can detect modifications later.
|
||||
try:
|
||||
metadata["source_mtime"] = os.path.getmtime(source_file)
|
||||
except OSError:
|
||||
pass
|
||||
collection.upsert(
|
||||
documents=[content],
|
||||
ids=[drawer_id],
|
||||
metadatas=[
|
||||
{
|
||||
"wing": wing,
|
||||
"room": room,
|
||||
"source_file": source_file,
|
||||
"chunk_index": chunk_index,
|
||||
"added_by": agent,
|
||||
"filed_at": datetime.now().isoformat(),
|
||||
}
|
||||
],
|
||||
metadatas=[metadata],
|
||||
)
|
||||
return True
|
||||
except Exception as e:
|
||||
if "already exists" in str(e).lower() or "duplicate" in str(e).lower():
|
||||
return False
|
||||
except Exception:
|
||||
raise
|
||||
|
||||
|
||||
@@ -451,29 +465,29 @@ def process_file(
|
||||
rooms: list,
|
||||
agent: str,
|
||||
dry_run: bool,
|
||||
) -> int:
|
||||
"""Read, chunk, route, and file one file. Returns drawer count."""
|
||||
) -> tuple:
|
||||
"""Read, chunk, route, and file one file. Returns (drawer_count, room_name)."""
|
||||
|
||||
# Skip if already filed
|
||||
source_file = str(filepath)
|
||||
if not dry_run and file_already_mined(collection, source_file):
|
||||
return 0
|
||||
return 0, None
|
||||
|
||||
try:
|
||||
content = filepath.read_text(encoding="utf-8", errors="replace")
|
||||
except OSError:
|
||||
return 0
|
||||
return 0, None
|
||||
|
||||
content = content.strip()
|
||||
if len(content) < MIN_CHUNK_SIZE:
|
||||
return 0
|
||||
return 0, None
|
||||
|
||||
room = detect_room(filepath, content, rooms, project_path)
|
||||
chunks = chunk_text(content, source_file)
|
||||
|
||||
if dry_run:
|
||||
print(f" [DRY RUN] {filepath.name} → room:{room} ({len(chunks)} drawers)")
|
||||
return len(chunks)
|
||||
return len(chunks), room
|
||||
|
||||
drawers_added = 0
|
||||
for chunk in chunks:
|
||||
@@ -489,7 +503,7 @@ def process_file(
|
||||
if added:
|
||||
drawers_added += 1
|
||||
|
||||
return drawers_added
|
||||
return drawers_added, room
|
||||
|
||||
|
||||
# =============================================================================
|
||||
@@ -608,7 +622,7 @@ def mine(
|
||||
room_counts = defaultdict(int)
|
||||
|
||||
for i, filepath in enumerate(files, 1):
|
||||
drawers = process_file(
|
||||
drawers, room = process_file(
|
||||
filepath=filepath,
|
||||
project_path=project_path,
|
||||
collection=collection,
|
||||
@@ -621,7 +635,6 @@ def mine(
|
||||
files_skipped += 1
|
||||
else:
|
||||
total_drawers += drawers
|
||||
room = detect_room(filepath, "", rooms, project_path)
|
||||
room_counts[room] += 1
|
||||
if not dry_run:
|
||||
print(f" ✓ [{i:4}/{len(files)}] {filepath.name[:50]:50} +{drawers}")
|
||||
|
||||
@@ -312,7 +312,7 @@ def _generate_aaak_bootstrap(
|
||||
]
|
||||
)
|
||||
|
||||
(mempalace_dir / "aaak_entities.md").write_text("\n".join(registry_lines))
|
||||
(mempalace_dir / "aaak_entities.md").write_text("\n".join(registry_lines), encoding="utf-8")
|
||||
|
||||
# Critical facts bootstrap (pre-palace — before any mining)
|
||||
facts_lines = [
|
||||
@@ -359,7 +359,7 @@ def _generate_aaak_bootstrap(
|
||||
]
|
||||
)
|
||||
|
||||
(mempalace_dir / "critical_facts.md").write_text("\n".join(facts_lines))
|
||||
(mempalace_dir / "critical_facts.md").write_text("\n".join(facts_lines), encoding="utf-8")
|
||||
|
||||
|
||||
def run_onboarding(
|
||||
|
||||
@@ -219,7 +219,7 @@ def split_file(filepath, output_dir, dry_run=False):
|
||||
if dry_run:
|
||||
print(f" [{i + 1}/{len(boundaries) - 1}] {name} ({len(chunk)} lines)")
|
||||
else:
|
||||
out_path.write_text("".join(chunk))
|
||||
out_path.write_text("".join(chunk), encoding="utf-8")
|
||||
print(f" ✓ {name} ({len(chunk)} lines)")
|
||||
|
||||
written.append(out_path)
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
"""Single source of truth for the MemPalace package version."""
|
||||
|
||||
__version__ = "3.0.0"
|
||||
__version__ = "3.0.14"
|
||||
|
||||
+21
-3
@@ -1,6 +1,6 @@
|
||||
[project]
|
||||
name = "mempalace"
|
||||
version = "3.0.0"
|
||||
version = "3.0.14"
|
||||
description = "Give your AI a memory — mine projects and conversations into a searchable palace. No API key required."
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.9"
|
||||
@@ -38,11 +38,11 @@ Repository = "https://github.com/milla-jovovich/mempalace"
|
||||
mempalace = "mempalace:main"
|
||||
|
||||
[project.optional-dependencies]
|
||||
dev = ["pytest>=7.0", "ruff>=0.4.0"]
|
||||
dev = ["pytest>=7.0", "pytest-cov>=4.0", "ruff>=0.4.0", "psutil>=5.9"]
|
||||
spellcheck = ["autocorrect>=2.0"]
|
||||
|
||||
[dependency-groups]
|
||||
dev = ["pytest>=7.0", "ruff>=0.4.0"]
|
||||
dev = ["pytest>=7.0", "pytest-cov>=4.0", "ruff>=0.4.0", "psutil>=5.9"]
|
||||
|
||||
[build-system]
|
||||
requires = ["hatchling"]
|
||||
@@ -64,3 +64,21 @@ quote-style = "double"
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
testpaths = ["tests"]
|
||||
pythonpath = ["."]
|
||||
addopts = "-m 'not benchmark and not slow and not stress'"
|
||||
markers = [
|
||||
"benchmark: scale/performance benchmark tests",
|
||||
"slow: tests that take more than 30 seconds",
|
||||
"stress: destructive scale tests (100K+ drawers)",
|
||||
]
|
||||
|
||||
[tool.coverage.run]
|
||||
source = ["mempalace"]
|
||||
|
||||
[tool.coverage.report]
|
||||
fail_under = 85
|
||||
show_missing = true
|
||||
exclude_lines = [
|
||||
"if __name__",
|
||||
"pragma: no cover",
|
||||
]
|
||||
|
||||
@@ -0,0 +1,138 @@
|
||||
# MemPalace Scale Benchmark Suite
|
||||
|
||||
106 tests that benchmark mempalace at scale to validate real-world performance limits.
|
||||
|
||||
## Why
|
||||
|
||||
MemPalace has strong academic scores (96.6% R@5 on LongMemEval) but no empirical data on how it behaves at scale. Key unknowns:
|
||||
|
||||
- `tool_status()` loads ALL metadata into memory — at what palace size does this OOM?
|
||||
- `PersistentClient` is re-instantiated on every MCP call — what's the overhead?
|
||||
- Modified files are never re-ingested — what's the skip-check cost at scale?
|
||||
- How does query latency degrade as the palace grows from 1K to 100K drawers?
|
||||
- Does wing/room filtering actually improve retrieval, and by how much?
|
||||
- At what per-room drawer count does recall break regardless of filtering?
|
||||
|
||||
This suite finds those answers.
|
||||
|
||||
## Quick Start
|
||||
|
||||
```bash
|
||||
# Fast smoke test (~2 min)
|
||||
uv run pytest tests/benchmarks/ -v --bench-scale=small -m "benchmark and not slow"
|
||||
|
||||
# Full small scale (~35 min)
|
||||
uv run pytest tests/benchmarks/ -v --bench-scale=small
|
||||
|
||||
# Medium scale with JSON report
|
||||
uv run pytest tests/benchmarks/ -v --bench-scale=medium --bench-report=results.json
|
||||
|
||||
# Stress test (local only, very slow)
|
||||
uv run pytest tests/benchmarks/ -v --bench-scale=stress -m stress
|
||||
```
|
||||
|
||||
## Scale Levels
|
||||
|
||||
| Level | Drawers | Wings | Rooms/Wing | KG Triples | Use case |
|
||||
|---------|---------|-------|------------|------------|---------------------|
|
||||
| small | 1,000 | 3 | 5 | 200 | CI, quick checks |
|
||||
| medium | 10,000 | 8 | 12 | 2,000 | Pre-release testing |
|
||||
| large | 50,000 | 15 | 20 | 10,000 | Scale limit finding |
|
||||
| stress | 100,000 | 25 | 30 | 50,000 | Breaking point |
|
||||
|
||||
## Test Modules
|
||||
|
||||
### Critical Path
|
||||
|
||||
| File | What it tests |
|
||||
|------|--------------|
|
||||
| `test_mcp_bench.py` | MCP tool response times, unbounded metadata fetch, client re-instantiation overhead |
|
||||
| `test_chromadb_stress.py` | ChromaDB breaking point, query degradation curve, batch vs sequential insert |
|
||||
| `test_memory_profile.py` | RSS/heap growth over repeated operations, leak detection |
|
||||
|
||||
### Performance Baselines
|
||||
|
||||
| File | What it tests |
|
||||
|------|--------------|
|
||||
| `test_ingest_bench.py` | Mining throughput (files/sec, drawers/sec), peak RSS, chunking speed, re-ingest skip overhead |
|
||||
| `test_search_bench.py` | Query latency vs palace size, recall@k with planted needles, concurrent queries, n_results scaling |
|
||||
|
||||
### Architectural Validation
|
||||
|
||||
| File | What it tests |
|
||||
|------|--------------|
|
||||
| `test_palace_boost.py` | Retrieval improvement from wing/room filtering at different scales |
|
||||
| `test_recall_threshold.py` | Per-room recall ceiling — isolates embedding model limit with all drawers in one bucket |
|
||||
| `test_knowledge_graph_bench.py` | Triple insertion rate, temporal query accuracy, SQLite concurrent access |
|
||||
| `test_layers_bench.py` | MemoryStack wake-up cost, Layer1 unbounded fetch, token budget compliance |
|
||||
|
||||
## Architecture
|
||||
|
||||
```
|
||||
tests/benchmarks/
|
||||
conftest.py # --bench-scale / --bench-report CLI options, fixtures, markers
|
||||
data_generator.py # Deterministic data factory (seeded RNG, planted needles)
|
||||
report.py # JSON report writer + regression checker
|
||||
test_*.py # 9 test modules (106 tests total)
|
||||
```
|
||||
|
||||
### Data Generator
|
||||
|
||||
`PalaceDataGenerator(seed=42, scale="small")` produces deterministic, realistic test data:
|
||||
|
||||
- **`generate_project_tree()`** — writes real files + `mempalace.yaml` for `mine()` to ingest
|
||||
- **`populate_palace_directly()`** — bypasses mining, inserts directly into ChromaDB (10-100x faster for search/MCP benchmarks)
|
||||
- **`generate_kg_triples()`** — entity-relationship triples with temporal validity
|
||||
- **`generate_search_queries()`** — queries with known-good answers for recall measurement
|
||||
|
||||
**Planted needles**: Unique identifiable content (e.g., `NEEDLE_0042: PostgreSQL vacuum autovacuum threshold...`) seeded into specific wings/rooms. Search queries target these needles, enabling recall@k measurement without an LLM judge.
|
||||
|
||||
### JSON Reports
|
||||
|
||||
When run with `--bench-report=path.json`, produces machine-readable output:
|
||||
|
||||
```json
|
||||
{
|
||||
"timestamp": "2026-04-07T...",
|
||||
"git_sha": "abc123",
|
||||
"scale": "small",
|
||||
"system": {"os": "linux", "cpu_count": 8},
|
||||
"results": {
|
||||
"mcp_status": {"latency_ms_at_1000": 45.2, "rss_delta_mb_at_5000": 12.3},
|
||||
"search": {"avg_latency_ms_at_5000": 23.1, "recall_at_5": 0.92},
|
||||
"chromadb_insert": {"sequential_ms": 8500, "batched_ms": 1200, "speedup_ratio": 7.1}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### Regression Detection
|
||||
|
||||
```python
|
||||
from tests.benchmarks.report import check_regression
|
||||
|
||||
regressions = check_regression("current.json", "baseline.json", threshold=0.2)
|
||||
# Returns list of metric descriptions that degraded beyond 20%
|
||||
```
|
||||
|
||||
## CI Integration
|
||||
|
||||
The GitHub Actions workflow runs benchmarks on PRs at small scale:
|
||||
|
||||
```yaml
|
||||
benchmark:
|
||||
runs-on: ubuntu-latest
|
||||
if: github.event_name == 'pull_request'
|
||||
# Runs: pytest tests/benchmarks/ -m "benchmark and not stress and not slow" --bench-scale=small
|
||||
```
|
||||
|
||||
Existing unit tests are isolated with `--ignore=tests/benchmarks`.
|
||||
|
||||
## Markers
|
||||
|
||||
- `@pytest.mark.benchmark` — all benchmark tests
|
||||
- `@pytest.mark.slow` — tests taking >30s even at small scale
|
||||
- `@pytest.mark.stress` — tests that should only run at large/stress scale
|
||||
|
||||
## Dependencies
|
||||
|
||||
Only one new dependency beyond the existing dev stack: `psutil` (for cross-platform RSS measurement). `tracemalloc` and `resource` are stdlib.
|
||||
@@ -0,0 +1 @@
|
||||
# MemPalace scale benchmark suite
|
||||
@@ -0,0 +1,144 @@
|
||||
"""Benchmark-specific pytest configuration, fixtures, and CLI options."""
|
||||
|
||||
import json
|
||||
import os
|
||||
import tempfile
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
SCALE_OPTIONS = ["small", "medium", "large", "stress"]


def pytest_addoption(parser):
    """Register the benchmark-specific command line options with pytest.

    ``--bench-scale`` selects one of ``SCALE_OPTIONS`` (default ``small``);
    ``--bench-report`` optionally names the JSON report file to write.
    """
    option_specs = (
        (
            "--bench-scale",
            {
                "default": "small",
                "choices": SCALE_OPTIONS,
                "help": "Scale level for benchmark tests: small (1K), medium (10K), large (50K), stress (100K)",
            },
        ),
        (
            "--bench-report",
            {
                "default": None,
                "help": "Path for JSON benchmark report output",
            },
        ),
    )
    # Registration order matches the original: scale first, then report.
    for flag, settings in option_specs:
        parser.addoption(flag, **settings)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
def bench_scale(request):
    """Session fixture: the scale level chosen via ``--bench-scale``."""
    selected_scale = request.config.getoption("--bench-scale")
    return selected_scale
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
def bench_report_path(request):
    """Session fixture: the ``--bench-report`` path, or None when not given."""
    requested_path = request.config.getoption("--bench-report")
    return requested_path
|
||||
|
||||
|
||||
@pytest.fixture
def palace_dir(tmp_path):
    """Create a fresh ``palace`` directory under ``tmp_path`` and return it as a string."""
    palace = tmp_path.joinpath("palace")
    palace.mkdir()
    return str(palace)
|
||||
|
||||
|
||||
@pytest.fixture
def kg_db(tmp_path):
    """Return a per-test SQLite path for the knowledge graph (the file is not created here)."""
    db_file = tmp_path / "test_kg.sqlite3"
    return str(db_file)
|
||||
|
||||
|
||||
@pytest.fixture
def config_dir(tmp_path):
    """Create an isolated config directory containing a ``config.json``.

    The written config points ``palace_path`` at ``tmp_path / "palace"`` and
    uses the ``mempalace_drawers`` collection. Returns the directory as a string.
    """
    cfg_root = tmp_path / "config"
    cfg_root.mkdir()
    payload = json.dumps(
        {
            "palace_path": str(tmp_path / "palace"),
            "collection_name": "mempalace_drawers",
        }
    )
    (cfg_root / "config.json").write_text(payload)
    return str(cfg_root)
|
||||
|
||||
|
||||
@pytest.fixture
def project_dir(tmp_path):
    """Create and return (as a ``Path``) an empty ``project`` directory for mining tests."""
    project_root = tmp_path.joinpath("project")
    project_root.mkdir()
    return project_root
|
||||
|
||||
|
||||
# ── Session-scoped result collector ──────────────────────────────────────
|
||||
|
||||
|
||||
class BenchmarkResults:
    """In-memory store of benchmark metrics, grouped as ``{category: {metric: value}}``."""

    def __init__(self):
        # Nested mapping: category name -> metric name -> recorded value.
        self.results = {}

    def record(self, category: str, metric: str, value):
        """Store ``value`` under ``category``/``metric``, overwriting any earlier value."""
        bucket = self.results.setdefault(category, {})
        bucket[metric] = value
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
def bench_results():
    """Session fixture: one ``BenchmarkResults`` collector shared by every benchmark test."""
    collector = BenchmarkResults()
    return collector
|
||||
|
||||
|
||||
def pytest_terminal_summary(terminalreporter, config):
    """Write the JSON benchmark report after all tests complete.

    Metrics are accumulated on disk by ``record_metric()`` in a file named
    ``mempalace_bench_results.json`` inside the system temp directory. This
    hook always consumes (reads and deletes) that file so that metrics from
    one session can never leak into the report of a later session, and then
    writes the report only when ``--bench-report`` was given.

    Args:
        terminalreporter: pytest's terminal reporter; used to print status lines.
        config: the pytest config object; read for ``--bench-report`` and
            ``--bench-scale``.
    """
    import platform
    import subprocess
    from datetime import datetime

    report_path = config.getoption("--bench-report", default=None)

    # Consume the temp file written by record_metric(), even when no report
    # was requested or the file is unreadable: leaving it behind would mix
    # stale metrics into the next run's report.
    results_file = os.path.join(tempfile.gettempdir(), "mempalace_bench_results.json")
    results = {}
    read_error = None
    if os.path.exists(results_file):
        try:
            with open(results_file) as f:
                results = json.load(f)
        except (OSError, ValueError) as exc:  # ValueError covers JSONDecodeError
            read_error = exc
        try:
            os.unlink(results_file)
        except OSError:
            # Best effort: cleanup failure must not fail the test session.
            pass

    if not report_path:
        return

    if read_error is not None:
        terminalreporter.write_line(
            f"Warning: could not read benchmark results from {results_file}: {read_error}"
        )

    # Git SHA is informational only; fall back when git or a repo is unavailable.
    try:
        git_sha = subprocess.check_output(
            ["git", "rev-parse", "--short", "HEAD"], text=True, stderr=subprocess.DEVNULL
        ).strip()
    except Exception:
        git_sha = "unknown"

    # chromadb is optional at report time; record its version when importable.
    try:
        import chromadb

        chromadb_version = chromadb.__version__
    except Exception:
        chromadb_version = "unknown"

    report = {
        "timestamp": datetime.now().isoformat(),
        "git_sha": git_sha,
        "python_version": platform.python_version(),
        "chromadb_version": chromadb_version,
        "scale": config.getoption("--bench-scale", default="small"),
        "system": {
            "os": platform.system().lower(),
            "cpu_count": os.cpu_count(),
            "platform": platform.platform(),
        },
        "results": results,
    }

    os.makedirs(os.path.dirname(os.path.abspath(report_path)), exist_ok=True)
    with open(report_path, "w") as f:
        json.dump(report, f, indent=2)
    terminalreporter.write_line(f"\nBenchmark report written to: {report_path}")
|
||||
@@ -0,0 +1,568 @@
|
||||
"""
|
||||
Deterministic data factory for MemPalace scale benchmarks.
|
||||
|
||||
Generates realistic project files, conversations, and KG triples at
|
||||
configurable scale levels. All randomness uses seeded RNG for reproducibility.
|
||||
|
||||
Planted "needle" drawers enable recall measurement without an LLM judge.
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import os
|
||||
import random
|
||||
from datetime import datetime, timedelta
|
||||
from pathlib import Path
|
||||
|
||||
import chromadb
|
||||
import yaml
|
||||
|
||||
|
||||
# ── Scale configurations ─────────────────────────────────────────────────
|
||||
|
||||
# Per-scale sizing knobs read by PalaceDataGenerator:
#   drawers        - default drawer count for populate_palace_directly()
#   wings          - how many entries of WING_NAMES are used
#   rooms_per_wing - how many ROOM_NAMES are sampled for each wing
#   kg_entities / kg_triples - defaults for generate_kg_triples()
#   needles        - number of planted needle drawers
#   search_queries - default query count for generate_search_queries()
SCALE_CONFIGS = {
    "small": {
        "drawers": 1_000,
        "wings": 3,
        "rooms_per_wing": 5,
        "kg_entities": 50,
        "kg_triples": 200,
        "needles": 20,
        "search_queries": 20,
    },
    "medium": {
        "drawers": 10_000,
        "wings": 8,
        "rooms_per_wing": 12,
        "kg_entities": 200,
        "kg_triples": 2_000,
        "needles": 50,
        "search_queries": 50,
    },
    "large": {
        "drawers": 50_000,
        "wings": 15,
        "rooms_per_wing": 20,
        "kg_entities": 500,
        "kg_triples": 10_000,
        "needles": 100,
        "search_queries": 100,
    },
    "stress": {
        "drawers": 100_000,
        "wings": 25,
        "rooms_per_wing": 30,
        "kg_entities": 1_000,
        "kg_triples": 50_000,
        "needles": 200,
        "search_queries": 200,
    },
}
|
||||
|
||||
# ── Vocabulary banks for realistic content ───────────────────────────────
|
||||
|
||||
# Wing names; PalaceDataGenerator uses the first cfg["wings"] entries.
# The list holds 25 names, matching the largest ("stress") wing count.
WING_NAMES = [
    "webapp",
    "backend_api",
    "mobile_app",
    "data_pipeline",
    "ml_platform",
    "devops",
    "auth_service",
    "payments",
    "analytics",
    "docs_site",
    "cli_tool",
    "dashboard",
    "notification_service",
    "search_engine",
    "user_mgmt",
    "inventory",
    "reporting",
    "testing_infra",
    "monitoring",
    "email_service",
    "chat_bot",
    "file_storage",
    "scheduler",
    "gateway",
    "marketplace",
]
|
||||
|
||||
# Room names; each wing gets a seeded random sample of these.
# Also reused as {component} / test file names when filling prose templates.
ROOM_NAMES = [
    "backend",
    "frontend",
    "api",
    "database",
    "auth",
    "tests",
    "docs",
    "config",
    "deployment",
    "models",
    "views",
    "controllers",
    "middleware",
    "utils",
    "schemas",
    "migrations",
    "fixtures",
    "scripts",
    "styles",
    "components",
    "hooks",
    "services",
    "routes",
    "templates",
    "static",
    "media",
    "logging",
    "cache",
    "queue",
    "workers",
]
|
||||
|
||||
# Technical vocabulary used to fill prose templates, keyword-style filler
# text, conversation prompts, and the generic (non-needle) search queries.
TECH_TERMS = [
    "authentication",
    "authorization",
    "middleware",
    "endpoint",
    "REST API",
    "GraphQL",
    "WebSocket",
    "database migration",
    "ORM",
    "query optimization",
    "caching strategy",
    "load balancer",
    "rate limiting",
    "pagination",
    "serialization",
    "validation",
    "error handling",
    "logging framework",
    "monitoring",
    "deployment pipeline",
    "CI/CD",
    "containerization",
    "microservice",
    "event sourcing",
    "message queue",
    "pub/sub",
    "connection pooling",
    "session management",
    "token refresh",
    "CORS",
    "SSL termination",
    "health check",
    "circuit breaker",
    "retry logic",
    "batch processing",
    "stream processing",
    "data pipeline",
    "ETL",
    "feature flag",
    "A/B testing",
    "blue-green deployment",
    "canary release",
]
|
||||
|
||||
# Canned code fragments (Python, JavaScript, SQL) that _random_text() mixes
# into generated content. They are only ever used as text, never executed.
CODE_SNIPPETS = [
    "def process_request(data):\n validated = schema.validate(data)\n result = handler.execute(validated)\n return Response(result, status=200)\n",
    "class UserRepository:\n def __init__(self, db):\n self.db = db\n def find_by_id(self, user_id):\n return self.db.query(User).filter(User.id == user_id).first()\n",
    "async def fetch_data(url, timeout=30):\n async with aiohttp.ClientSession() as session:\n async with session.get(url, timeout=timeout) as resp:\n return await resp.json()\n",
    "const handleSubmit = async (formData) => {\n try {\n const response = await api.post('/users', formData);\n dispatch({ type: 'USER_CREATED', payload: response.data });\n } catch (error) {\n setError(error.message);\n }\n};\n",
    "SELECT u.name, COUNT(o.id) as order_count\nFROM users u\nLEFT JOIN orders o ON u.id = o.user_id\nWHERE u.created_at > '2025-01-01'\nGROUP BY u.name\nHAVING COUNT(o.id) > 5\nORDER BY order_count DESC;\n",
]
|
||||
|
||||
# str.format() templates for prose filler. _random_text() supplies every
# placeholder used by any template in a single format() call, so adding a
# new placeholder here requires adding a matching keyword there.
PROSE_TEMPLATES = [
    "The {component} module handles {task}. It was refactored in {month} to improve {quality}. Key design decision: {decision}.",
    "Bug report: {component} fails when {condition}. Root cause: {cause}. Fixed by {fix}. Regression test added in {test_file}.",
    "Architecture decision: switched from {old_tech} to {new_tech} for {reason}. Migration completed {date}. Performance improved by {percent}%.",
    "Meeting notes: discussed {topic} with {person}. Agreed to {action}. Deadline: {deadline}. Follow-up: {followup}.",
    "Feature spec: {feature_name} allows users to {capability}. Dependencies: {deps}. Estimated effort: {effort} days.",
]
|
||||
|
||||
# Human-readable names for the first knowledge-graph entities; once these
# run out generate_kg_triples() falls back to "Entity_NNNN". Also used as
# the {person} placeholder in prose templates.
ENTITY_NAMES = [
    "Alice",
    "Bob",
    "Carol",
    "Dave",
    "Eve",
    "Frank",
    "Grace",
    "Heidi",
    "Ivan",
    "Judy",
    "Karl",
    "Linda",
    "Mike",
    "Nina",
    "Oscar",
    "Pat",
    "Quinn",
    "Rita",
    "Steve",
    "Tina",
    "Ursula",
    "Victor",
    "Wendy",
    "Xander",
]
|
||||
|
||||
# Entity type labels; one is picked at random for each generated KG entity.
ENTITY_TYPES = ["person", "project", "tool", "concept", "team", "service"]
|
||||
|
||||
# Relationship names; one is picked at random for each generated KG triple.
PREDICATES = [
    "works_on",
    "manages",
    "reports_to",
    "collaborates_with",
    "created",
    "maintains",
    "uses",
    "depends_on",
    "replaced",
    "reviewed",
    "deployed",
    "tested",
    "documented",
    "mentors",
    "leads",
    "contributes_to",
]
|
||||
|
||||
|
||||
class PalaceDataGenerator:
    """Generate deterministic, realistic test data at configurable scale.

    All randomness flows through one ``random.Random(seed)`` instance, so the
    same seed, scale, and sequence of method calls produce the same data.
    (Drawer ``filed_at`` metadata uses the wall clock and is the exception.)

    Attributes set by ``__init__``:
        rng: the seeded random generator.
        scale: the scale name passed in.
        cfg: the matching entry of ``SCALE_CONFIGS``.
        wings: wing names in use at this scale.
        rooms_by_wing: mapping of wing name to its sampled room names.
        needles: planted needle records (see ``_generate_needles``).
    """

    def __init__(self, seed=42, scale="small"):
        """Seed the RNG and derive wings, rooms, and needles for ``scale``.

        Raises:
            KeyError: if ``scale`` is not a key of ``SCALE_CONFIGS``.
        """
        self.rng = random.Random(seed)
        self.scale = scale
        self.cfg = SCALE_CONFIGS[scale]
        self.wings = WING_NAMES[: self.cfg["wings"]]
        self.rooms_by_wing = {}
        for wing in self.wings:
            n = self.cfg["rooms_per_wing"]
            rooms = self.rng.sample(ROOM_NAMES, min(n, len(ROOM_NAMES)))
            self.rooms_by_wing[wing] = rooms
        # Planted needles for recall measurement
        self.needles = []
        self._generate_needles()

    def _generate_needles(self):
        """Create needle records for recall testing and append them to ``self.needles``.

        Each record has ``id`` (``NEEDLE_NNNN``), ``content``, ``wing``,
        ``room``, and ``query``. Topics are reused cyclically when the scale
        asks for more needles than there are topics, so at larger scales
        several needles share the same query text (their ids still differ).
        """
        topics = [
            "Fibonacci sequence optimization uses memoization with O(n) space complexity",
            "PostgreSQL vacuum autovacuum threshold set to 50 percent for table users",
            "Redis cluster failover timeout configured at 30 seconds with sentinel monitoring",
            "Kubernetes horizontal pod autoscaler targets 70 percent CPU utilization",
            "GraphQL subscription uses WebSocket transport with heartbeat interval 25 seconds",
            "JWT token rotation policy requires refresh every 15 minutes with sliding window",
            "Elasticsearch index sharding strategy uses 5 primary shards with 1 replica each",
            "Docker multi-stage build reduces image size from 1.2GB to 180MB for production",
            "Apache Kafka consumer group rebalance timeout set to 45 seconds",
            "MongoDB change streams resume token persisted every 100 operations",
            "gRPC streaming uses bidirectional flow control with 64KB window size",
            "Prometheus alerting rule fires when p99 latency exceeds 500ms for 5 minutes",
            "Terraform state locking uses DynamoDB with consistent reads enabled",
            "Nginx rate limiting configured at 100 requests per second with burst of 50",
            "SQLAlchemy connection pool size set to 20 with max overflow of 10 connections",
            "React concurrent mode uses startTransition for non-urgent state updates",
            "AWS Lambda cold start mitigation uses provisioned concurrency of 10 instances",
            "Git bisect automated with custom test script for regression hunting",
            "OpenTelemetry trace sampling rate set to 10 percent in production environment",
            "Celery worker prefetch multiplier set to 1 for fair task distribution",
        ]
        for i in range(self.cfg["needles"]):
            topic = topics[i % len(topics)]
            wing = self.rng.choice(self.wings)
            room = self.rng.choice(self.rooms_by_wing[wing])
            needle_id = f"NEEDLE_{i:04d}"
            content = f"{needle_id}: {topic}. This is a unique planted needle for recall benchmarking at scale."

            # The query is the leading part of the topic: everything before
            # " uses " or " set to " when present, else the first 60 chars.
            if " uses " in topic:
                query = topic.split(" uses ")[0]
            elif " set to " in topic:
                query = topic.split(" set to ")[0]
            else:
                query = topic[:60]

            self.needles.append(
                {
                    "id": needle_id,
                    "content": content,
                    "wing": wing,
                    "room": room,
                    "query": query,
                }
            )

    def _random_text(self, min_chars=600, max_chars=900):
        """Generate a random text block of realistic content.

        Fragments (code snippet, filled prose template, or keyword filler)
        are appended until a random target length in
        ``[min_chars, max_chars]`` is reached; the joined text is then
        truncated to ``max_chars``.
        """
        parts = []
        total = 0
        target = self.rng.randint(min_chars, max_chars)
        while total < target:
            choice = self.rng.random()
            if choice < 0.3:
                text = self.rng.choice(CODE_SNIPPETS)
            elif choice < 0.7:
                template = self.rng.choice(PROSE_TEMPLATES)
                # Every placeholder of every template is supplied; the keyword
                # order fixes the RNG call order and must not be changed.
                text = template.format(
                    component=self.rng.choice(ROOM_NAMES),
                    task=self.rng.choice(TECH_TERMS),
                    month=self.rng.choice(["January", "February", "March", "April", "May"]),
                    quality=self.rng.choice(
                        ["performance", "readability", "test coverage", "latency"]
                    ),
                    decision=self.rng.choice(TECH_TERMS),
                    condition=self.rng.choice(TECH_TERMS) + " is null",
                    cause=self.rng.choice(["race condition", "null pointer", "timeout", "OOM"]),
                    fix="adding " + self.rng.choice(TECH_TERMS),
                    test_file=f"test_{self.rng.choice(ROOM_NAMES)}.py",
                    old_tech=self.rng.choice(["MySQL", "Flask", "REST", "Jenkins"]),
                    new_tech=self.rng.choice(
                        ["PostgreSQL", "FastAPI", "GraphQL", "GitHub Actions"]
                    ),
                    reason=self.rng.choice(TECH_TERMS),
                    date=f"2025-{self.rng.randint(1, 12):02d}-{self.rng.randint(1, 28):02d}",
                    percent=self.rng.randint(10, 80),
                    topic=self.rng.choice(TECH_TERMS),
                    person=self.rng.choice(ENTITY_NAMES),
                    action=self.rng.choice(["refactor", "migrate", "optimize", "test"]),
                    deadline=f"2025-{self.rng.randint(1, 12):02d}-{self.rng.randint(1, 28):02d}",
                    followup=self.rng.choice(TECH_TERMS),
                    feature_name=self.rng.choice(TECH_TERMS),
                    capability=self.rng.choice(TECH_TERMS),
                    deps=", ".join(self.rng.sample(TECH_TERMS, 2)),
                    effort=self.rng.randint(1, 15),
                )
            else:
                words = self.rng.sample(TECH_TERMS, min(5, len(TECH_TERMS)))
                text = (
                    " ".join(words)
                    + ". "
                    + self.rng.choice(TECH_TERMS)
                    + " implementation details follow.\n"
                )
            parts.append(text)
            total += len(text)
        return "\n".join(parts)[:max_chars]

    # ── Project tree generation (for mine() tests) ───────────────────────

    def generate_project_tree(self, base_path, wing=None, rooms=None, n_files=50):
        """Write realistic project files plus ``mempalace.yaml`` under ``base_path``.

        Files are distributed round-robin across one sub-directory per room.

        Args:
            base_path: directory to create/populate.
            wing: wing name; a random configured wing when omitted.
            rooms: room names; the wing's sampled rooms (or ``["general"]``
                for an unknown wing) when omitted or empty.
            n_files: number of files to write.

        Returns:
            ``(project_path, wing, rooms, files_written)`` where
            ``project_path`` is a string.
        """
        base = Path(base_path)
        base.mkdir(parents=True, exist_ok=True)
        wing = wing or self.rng.choice(self.wings)
        rooms = rooms or self.rooms_by_wing.get(wing, ["general"])

        # Write mempalace.yaml
        room_defs = [{"name": r, "description": f"{r} code and docs"} for r in rooms]
        with open(base / "mempalace.yaml", "w", encoding="utf-8") as f:
            yaml.dump({"wing": wing, "rooms": room_defs}, f)

        # Write files distributed across room directories
        files_written = 0
        for i in range(n_files):
            room = rooms[i % len(rooms)]
            room_dir = base / room
            room_dir.mkdir(parents=True, exist_ok=True)

            ext = self.rng.choice([".py", ".js", ".md", ".ts", ".yaml"])
            filename = f"file_{i:04d}{ext}"
            content = self._random_text(400, 2000)
            (room_dir / filename).write_text(content, encoding="utf-8")
            files_written += 1

        return str(base), wing, rooms, files_written

    # ── Conversation file generation (for mine_convos() tests) ───────────

    def generate_conversation_files(self, base_path, wing=None, n_files=20):
        """Write ``n_files`` conversation transcripts (``convo_NNNN.txt``) under ``base_path``.

        Each transcript has 5-20 exchanges of a ``> User:`` line followed by a
        generated reply and a blank line.

        Returns:
            ``(base_path_as_str, wing)``; ``wing`` is a random configured wing
            when not supplied.
        """
        base = Path(base_path)
        base.mkdir(parents=True, exist_ok=True)
        wing = wing or self.rng.choice(self.wings)

        for i in range(n_files):
            lines = []
            n_exchanges = self.rng.randint(5, 20)
            for _ in range(n_exchanges):
                user_msg = f"> User: {self.rng.choice(TECH_TERMS)}? How does {self.rng.choice(TECH_TERMS)} work with {self.rng.choice(TECH_TERMS)}?"
                ai_msg = self._random_text(200, 600)
                lines.append(user_msg)
                lines.append(ai_msg)
                lines.append("")

            (base / f"convo_{i:04d}.txt").write_text("\n".join(lines), encoding="utf-8")

        return str(base), wing

    # ── Direct palace population (bypasses mining for speed) ─────────────

    def populate_palace_directly(self, palace_path, n_drawers=None, include_needles=True):
        """Insert drawers directly into ChromaDB, bypassing the mining pipeline.

        Much faster than mining for benchmarks that only care about
        search/MCP behavior on a pre-populated palace.

        Args:
            palace_path: directory for the persistent ChromaDB client.
            n_drawers: total drawers to insert (needles included); falls back
                to the scale's ``drawers`` value when ``None`` or 0.
            include_needles: whether to insert the planted needles first.

        Returns:
            ``(client, collection, needle_info)``. Each ``needle_info`` entry
            holds the needle's drawer ``id`` in the collection plus its
            ``query``, ``wing``, and ``room``.
        """
        n_drawers = n_drawers or self.cfg["drawers"]
        os.makedirs(palace_path, exist_ok=True)
        client = chromadb.PersistentClient(path=palace_path)
        col = client.get_or_create_collection("mempalace_drawers")

        batch_size = 500
        docs = []
        ids = []
        metas = []

        # Insert needles first
        needle_info = []
        if include_needles:
            for needle in self.needles:
                needle_id = f"drawer_{needle['wing']}_{needle['room']}_{hashlib.md5(needle['id'].encode()).hexdigest()[:16]}"
                docs.append(needle["content"])
                ids.append(needle_id)
                metas.append(
                    {
                        "wing": needle["wing"],
                        "room": needle["room"],
                        "source_file": f"needle_{needle['id']}.txt",
                        "chunk_index": 0,
                        "added_by": "benchmark",
                        "filed_at": datetime.now().isoformat(),
                    }
                )
                needle_info.append(
                    {
                        "id": needle_id,
                        "query": needle["query"],
                        "wing": needle["wing"],
                        "room": needle["room"],
                    }
                )

        # Fill remaining drawers with realistic content. A non-positive
        # remainder simply yields no filler drawers.
        remaining = n_drawers - len(docs)
        for i in range(remaining):
            wing = self.wings[i % len(self.wings)]
            rooms = self.rooms_by_wing[wing]
            room = rooms[i % len(rooms)]
            content = self._random_text(400, 800)
            drawer_id = f"drawer_{wing}_{room}_{hashlib.md5(f'gen_{i}'.encode()).hexdigest()[:16]}"

            docs.append(content)
            ids.append(drawer_id)
            metas.append(
                {
                    "wing": wing,
                    "room": room,
                    "source_file": f"generated_{i:06d}.txt",
                    "chunk_index": i % 10,
                    "added_by": "benchmark",
                    "filed_at": datetime.now().isoformat(),
                }
            )

            # Flush in batches
            if len(docs) >= batch_size:
                col.add(documents=docs, ids=ids, metadatas=metas)
                docs, ids, metas = [], [], []

        # Flush remainder
        if docs:
            col.add(documents=docs, ids=ids, metadatas=metas)

        return client, col, needle_info

    # ── KG triple generation ─────────────────────────────────────────────

    def generate_kg_triples(self, n_entities=None, n_triples=None):
        """Generate realistic entity-relationship triples.

        Args:
            n_entities: number of entities; scale default when ``None`` or 0.
            n_triples: number of triples; scale default when ``None`` or 0.

        Returns:
            ``(entities, triples)`` where:
                entities = [(name, type), ...]
                triples = [(subject, predicate, object, valid_from, valid_to), ...]
            Dates are ``YYYY-MM-DD`` strings; ``valid_to`` is ``None`` for
            roughly 70% of triples.

        Raises:
            ValueError: if triples are requested with fewer than two
                entities, since subject and object must differ.
        """
        n_entities = n_entities or self.cfg["kg_entities"]
        n_triples = n_triples or self.cfg["kg_triples"]

        # Generate entities
        entities = []
        entity_names = []
        for i in range(n_entities):
            if i < len(ENTITY_NAMES):
                name = ENTITY_NAMES[i]
            else:
                name = f"Entity_{i:04d}"
            etype = self.rng.choice(ENTITY_TYPES)
            entities.append((name, etype))
            entity_names.append(name)

        # With a single entity the "object != subject" rejection loop below
        # could never terminate; fail fast instead of hanging.
        if n_triples > 0 and len(entity_names) < 2:
            raise ValueError(
                "generate_kg_triples needs at least 2 entities to build triples "
                f"(got n_entities={n_entities})"
            )

        # Generate triples
        triples = []
        base_date = datetime(2024, 1, 1)
        for _ in range(n_triples):
            subject = self.rng.choice(entity_names)
            obj = self.rng.choice(entity_names)
            while obj == subject:
                obj = self.rng.choice(entity_names)
            predicate = self.rng.choice(PREDICATES)
            days_offset = self.rng.randint(0, 730)
            valid_from = (base_date + timedelta(days=days_offset)).strftime("%Y-%m-%d")
            # 30% chance of having a valid_to
            valid_to = None
            if self.rng.random() < 0.3:
                end_offset = self.rng.randint(30, 365)
                valid_to = (base_date + timedelta(days=days_offset + end_offset)).strftime(
                    "%Y-%m-%d"
                )
            triples.append((subject, predicate, obj, valid_from, valid_to))

        return entities, triples

    # ── Search query generation ──────────────────────────────────────────

    def generate_search_queries(self, n_queries=None):
        """Generate a shuffled list of search queries with expected results.

        Up to half of the queries target planted needles (known-good answers
        for recall measurement); the rest are generic two-term queries meant
        for latency measurement only.

        Returns:
            List of dicts with keys ``query`` (str), ``expected_wing`` and
            ``expected_room`` (str or None), ``needle_id`` (the needle's
            ``NEEDLE_NNNN`` id or None — not the ChromaDB drawer id), and
            ``is_needle`` (bool).
        """
        n_queries = n_queries or self.cfg["search_queries"]
        queries = []

        # Half are needle queries (known-good answers)
        n_needle = min(n_queries // 2, len(self.needles))
        for needle in self.needles[:n_needle]:
            queries.append(
                {
                    "query": needle["query"],
                    "expected_wing": needle["wing"],
                    "expected_room": needle["room"],
                    "needle_id": needle["id"],
                    "is_needle": True,
                }
            )

        # Other half are generic queries (measure latency, not recall)
        n_generic = n_queries - n_needle
        for _ in range(n_generic):
            queries.append(
                {
                    "query": self.rng.choice(TECH_TERMS) + " " + self.rng.choice(TECH_TERMS),
                    "expected_wing": None,
                    "expected_room": None,
                    "needle_id": None,
                    "is_needle": False,
                }
            )

        self.rng.shuffle(queries)
        return queries
|
||||
@@ -0,0 +1,117 @@
|
||||
"""
|
||||
Benchmark report utilities — JSON output and regression detection.
|
||||
|
||||
Each test records metrics via record_metric(). At session end, the
|
||||
conftest.py pytest_terminal_summary hook writes the collected results.
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import tempfile
|
||||
|
||||
|
||||
RESULTS_FILE = os.path.join(tempfile.gettempdir(), "mempalace_bench_results.json")


def record_metric(category: str, metric: str, value):
    """Persist one metric into the shared on-disk results file.

    The whole file is re-read and rewritten on every call; an unreadable or
    malformed file is treated as empty. A later value for the same
    ``category``/``metric`` replaces the earlier one.
    """
    collected = {}
    if os.path.exists(RESULTS_FILE):
        try:
            with open(RESULTS_FILE) as handle:
                collected = json.load(handle)
        except (json.JSONDecodeError, OSError):
            # Start over rather than fail the benchmark on a bad file.
            collected = {}

    if category not in collected:
        collected[category] = {}
    collected[category][metric] = value

    serialized = json.dumps(collected, indent=2)
    with open(RESULTS_FILE, "w") as handle:
        handle.write(serialized)
|
||||
|
||||
|
||||
def check_regression(current_report: str, baseline_report: str, threshold: float = 0.2):
    """Compare current benchmark results against a baseline.

    Only metrics present in both reports, numeric in both, and with a
    non-zero baseline are compared. The metric name decides the direction:
    names containing a "higher is better" keyword (checked first) regress
    when the value drops; names containing a "higher is worse" keyword
    regress when the value rises. Metrics matching neither are ignored.

    The tolerance band is ``abs(baseline) * threshold`` around the baseline,
    so negative baselines (e.g. an RSS delta after garbage collection) are
    judged in the same direction as positive ones.

    Args:
        current_report: path to the current JSON report.
        baseline_report: path to the baseline JSON report.
        threshold: fractional degradation allowed (0.2 = up to 20% worse is OK).

    Returns:
        A list of regression descriptions. Empty list = no regressions.
    """
    with open(current_report) as f:
        current = json.load(f)
    with open(baseline_report) as f:
        baseline = json.load(f)

    # Keywords for metric direction — checked in order, first match wins.
    # "improvement" is checked before "latency" so that composite names
    # like "latency_improvement_pct" are classified correctly.
    higher_is_better_kw = (
        "improvement",
        "recall",
        "throughput",
        "per_sec",
        "files_per_sec",
        "drawers_per_sec",
        "triples_per_sec",
        "speedup",
    )
    higher_is_worse_kw = (
        "latency",
        "rss",
        "memory",
        "oom",
        "lock_failures",
        "elapsed",
        "p50_ms",
        "p95_ms",
        "p99_ms",
        "rss_delta_mb",
        "peak_rss_mb",
        "errors",
        "failures",
    )

    def _metric_direction(name: str) -> str:
        """Return 'higher_better', 'higher_worse', or 'unknown'."""
        low = name.lower()
        if any(kw in low for kw in higher_is_better_kw):
            return "higher_better"
        if any(kw in low for kw in higher_is_worse_kw):
            return "higher_worse"
        return "unknown"

    regressions = []
    current_results = current.get("results", {})
    for category, base_metrics in baseline.get("results", {}).items():
        if category not in current_results:
            continue
        curr_metrics = current_results[category]
        for metric, base_val in base_metrics.items():
            if metric not in curr_metrics:
                continue
            curr_val = curr_metrics[metric]
            if not isinstance(base_val, (int, float)) or not isinstance(curr_val, (int, float)):
                continue
            if base_val == 0:
                # No meaningful relative change against a zero baseline.
                continue

            direction = _metric_direction(metric)
            # Use the magnitude of the baseline so the band stays on the
            # correct side of the baseline even when the baseline is negative.
            margin = abs(base_val) * threshold

            if direction == "higher_worse":
                regressed = curr_val > base_val + margin
            elif direction == "higher_better":
                regressed = curr_val < base_val - margin
            else:
                regressed = False

            if regressed:
                pct = ((curr_val - base_val) / abs(base_val)) * 100
                regressions.append(
                    f"{category}/{metric}: {base_val:.2f} -> {curr_val:.2f} ({pct:+.1f}%, threshold {threshold * 100:.0f}%)"
                )

    return regressions
|
||||
@@ -0,0 +1,206 @@
|
||||
"""
|
||||
ChromaDB stress tests — find the breaking point.
|
||||
|
||||
Tests the raw ChromaDB patterns used by mempalace to determine:
|
||||
- At what collection size does col.get(include=["metadatas"]) become dangerous?
|
||||
- How does query latency degrade as collection grows?
|
||||
- How much faster is batched insertion vs sequential?
|
||||
"""
|
||||
|
||||
import os
|
||||
import time
|
||||
|
||||
import chromadb
|
||||
import pytest
|
||||
|
||||
from tests.benchmarks.data_generator import PalaceDataGenerator
|
||||
from tests.benchmarks.report import record_metric
|
||||
|
||||
|
||||
def _get_rss_mb():
|
||||
try:
|
||||
import psutil
|
||||
|
||||
return psutil.Process().memory_info().rss / (1024 * 1024)
|
||||
except ImportError:
|
||||
import resource
|
||||
import platform
|
||||
|
||||
usage = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
|
||||
if platform.system() == "Darwin":
|
||||
return usage / (1024 * 1024)
|
||||
return usage / 1024
|
||||
|
||||
|
||||
@pytest.mark.benchmark
class TestGetAllMetadatasOOM:
    """
    Exercise the unbounded ``col.get(include=["metadatas"])`` call
    (the pattern behind finding #3) at several collection sizes.

    Records RSS growth and latency so the size at which the call becomes
    dangerous can be read off the report.
    """

    SIZES = [1_000, 2_500, 5_000, 10_000]

    @pytest.mark.parametrize("n_drawers", SIZES)
    def test_get_all_metadatas_rss(self, n_drawers, tmp_path, bench_scale):
        """Measure RSS delta and latency of fetching every metadata row at once."""
        palace_path = str(tmp_path / "palace")
        generator = PalaceDataGenerator(seed=42, scale=bench_scale)
        generator.populate_palace_directly(
            palace_path, n_drawers=n_drawers, include_needles=False
        )

        # Re-open the palace with a fresh client before measuring.
        client = chromadb.PersistentClient(path=palace_path)
        collection = client.get_collection("mempalace_drawers")

        rss_before = _get_rss_mb()
        t0 = time.perf_counter()
        fetched = collection.get(include=["metadatas"])
        all_meta = fetched["metadatas"]
        elapsed_ms = (time.perf_counter() - t0) * 1000
        rss_after = _get_rss_mb()

        assert len(all_meta) == n_drawers

        record_metric(
            "chromadb_get_all",
            f"rss_delta_mb_at_{n_drawers}",
            round(rss_after - rss_before, 2),
        )
        record_metric("chromadb_get_all", f"latency_ms_at_{n_drawers}", round(elapsed_ms, 1))
|
||||
|
||||
|
||||
@pytest.mark.benchmark
class TestQueryDegradation:
    """Track how query latency changes as the collection grows."""

    SIZES = [1_000, 2_500, 5_000, 10_000]

    @pytest.mark.parametrize("n_drawers", SIZES)
    def test_query_latency_at_size(self, n_drawers, tmp_path, bench_scale):
        """Run five fixed queries against ``n_drawers`` drawers; record avg and p95 latency."""
        palace_path = str(tmp_path / "palace")
        generator = PalaceDataGenerator(seed=42, scale=bench_scale)
        generator.populate_palace_directly(
            palace_path, n_drawers=n_drawers, include_needles=False
        )

        client = chromadb.PersistentClient(path=palace_path)
        collection = client.get_collection("mempalace_drawers")

        query_texts = (
            "authentication middleware optimization",
            "database connection pooling strategy",
            "error handling retry logic",
            "deployment pipeline configuration",
            "load balancer health check",
        )

        timings_ms = []
        for text in query_texts:
            t0 = time.perf_counter()
            hits = collection.query(
                query_texts=[text], n_results=5, include=["documents", "distances"]
            )
            timings_ms.append((time.perf_counter() - t0) * 1000)
            assert hits["documents"][0]  # got results

        sample_count = len(timings_ms)
        avg_ms = sum(timings_ms) / sample_count
        # With five samples this index selects the slowest query.
        ordered = sorted(timings_ms)
        p95_ms = ordered[int(sample_count * 0.95)]

        record_metric("chromadb_query", f"avg_latency_ms_at_{n_drawers}", round(avg_ms, 1))
        record_metric("chromadb_query", f"p95_latency_ms_at_{n_drawers}", round(p95_ms, 1))
|
||||
|
||||
|
||||
@pytest.mark.benchmark
class TestBulkInsertPerformance:
    """Contrast one-document-per-call inserts with batched inserts."""

    def test_sequential_vs_batched(self, tmp_path):
        """Insert the same 500 documents one at a time and in batches of 100.

        The single-document loop mimics the add_drawer pattern. The recorded
        speedup ratio is the sequential time divided by the batched time.
        """
        n_docs = 500
        generator = PalaceDataGenerator(seed=42)

        # Shared payload for both strategies.
        contents = [generator._random_text(400, 800) for _ in range(n_docs)]

        # --- Strategy 1: one add() call per document ---
        seq_dir = str(tmp_path / "seq")
        os.makedirs(seq_dir)
        client_seq = chromadb.PersistentClient(path=seq_dir)
        col_seq = client_seq.get_or_create_collection("mempalace_drawers")

        t0 = time.perf_counter()
        for i in range(n_docs):
            col_seq.add(
                documents=[contents[i]],
                ids=[f"seq_{i}"],
                metadatas=[{"wing": "test", "room": "bench", "chunk_index": i}],
            )
        sequential_ms = (time.perf_counter() - t0) * 1000

        # --- Strategy 2: one add() call per batch ---
        batch_dir = str(tmp_path / "batch")
        os.makedirs(batch_dir)
        client_batch = chromadb.PersistentClient(path=batch_dir)
        col_batch = client_batch.get_or_create_collection("mempalace_drawers")

        batch_size = 100
        t0 = time.perf_counter()
        for lo in range(0, n_docs, batch_size):
            hi = min(lo + batch_size, n_docs)
            indices = range(lo, hi)
            col_batch.add(
                documents=contents[lo:hi],
                ids=[f"batch_{i}" for i in indices],
                metadatas=[
                    {"wing": "test", "room": "bench", "chunk_index": i} for i in indices
                ],
            )
        batched_ms = (time.perf_counter() - t0) * 1000

        # Floor the denominator so a near-zero batched time cannot divide by zero.
        speedup = sequential_ms / max(batched_ms, 0.01)

        assert col_seq.count() == n_docs
        assert col_batch.count() == n_docs

        record_metric("chromadb_insert", "sequential_ms", round(sequential_ms, 1))
        record_metric("chromadb_insert", "batched_ms", round(batched_ms, 1))
        record_metric("chromadb_insert", "speedup_ratio", round(speedup, 2))
        record_metric("chromadb_insert", "n_docs", n_docs)
        record_metric("chromadb_insert", "batch_size", batch_size)
|
||||
|
||||
|
||||
@pytest.mark.benchmark
@pytest.mark.slow
class TestMaxCollectionSize:
    """Grow a collection batch by batch to look for practical size limits."""

    def test_incremental_growth(self, tmp_path, bench_scale):
        """Insert drawers in batches of 500 and time every ``add()`` call.

        The target is the generator's configured drawer count, capped at 10K.
        """
        gen = PalaceDataGenerator(seed=42, scale=bench_scale)
        # Never exceed 10K drawers here, whatever the configured scale asks for.
        target = min(gen.cfg["drawers"], 10_000)

        palace_path = str(tmp_path / "palace")
        os.makedirs(palace_path)
        client = chromadb.PersistentClient(path=palace_path)
        col = client.get_or_create_collection("mempalace_drawers")

        batch_size = 500
        batch_times = []
        total_inserted = 0

        for first in range(0, target, batch_size):
            stop = min(first + batch_size, target)
            indices = range(first, stop)

            docs = [gen._random_text(400, 800) for _ in indices]
            ids = [f"growth_{i}" for i in indices]
            metas = [
                {"wing": gen.wings[i % len(gen.wings)], "room": "bench", "chunk_index": i}
                for i in indices
            ]

            # Only the add() call itself is timed, not the data generation.
            t0 = time.perf_counter()
            col.add(documents=docs, ids=ids, metadatas=metas)
            batch_ms = (time.perf_counter() - t0) * 1000

            total_inserted = stop
            batch_times.append({"at_size": total_inserted, "batch_ms": round(batch_ms, 1)})

        assert col.count() == total_inserted

        # First vs. last batch time shows how insertion degrades with size.
        first_batch, last_batch = batch_times[0], batch_times[-1]
        record_metric("chromadb_growth", "first_batch_ms", first_batch["batch_ms"])
        record_metric("chromadb_growth", "last_batch_ms", last_batch["batch_ms"])
        record_metric("chromadb_growth", "total_inserted", total_inserted)
        record_metric("chromadb_growth", "batch_times", batch_times)
|
||||
@@ -0,0 +1,169 @@
|
||||
"""
|
||||
Ingestion throughput benchmarks.
|
||||
|
||||
Measures mining performance at scale:
|
||||
- Files/sec and drawers/sec through the full mine() pipeline
|
||||
- Peak RSS during mining
|
||||
- Chunking throughput isolated from ChromaDB
|
||||
- Re-ingest skip overhead (finding #11: file_already_mined check)
|
||||
"""
|
||||
|
||||
import time
|
||||
|
||||
import chromadb
|
||||
import pytest
|
||||
|
||||
from tests.benchmarks.data_generator import PalaceDataGenerator
|
||||
from tests.benchmarks.report import record_metric
|
||||
|
||||
|
||||
def _get_rss_mb():
|
||||
try:
|
||||
import psutil
|
||||
|
||||
return psutil.Process().memory_info().rss / (1024 * 1024)
|
||||
except ImportError:
|
||||
import resource
|
||||
import platform
|
||||
|
||||
usage = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
|
||||
if platform.system() == "Darwin":
|
||||
return usage / (1024 * 1024)
|
||||
return usage / 1024
|
||||
|
||||
|
||||
@pytest.mark.benchmark
class TestMineThroughput:
    """Measure the full mine() pipeline throughput."""

    @pytest.mark.parametrize("n_files", [20, 50, 100])
    def test_mine_files_per_second(self, n_files, tmp_path, bench_scale):
        """End-to-end mining throughput: generate files, mine, count drawers.

        Records files/sec, drawers/sec, elapsed seconds and the number of
        drawers created, each keyed by ``n_files``.
        """
        gen = PalaceDataGenerator(seed=42, scale=bench_scale)
        project_path, _wing, _rooms, files_written = gen.generate_project_tree(
            tmp_path / "project", n_files=n_files
        )
        palace_path = str(tmp_path / "palace")

        from mempalace.miner import mine

        start = time.perf_counter()
        mine(project_path, palace_path)
        elapsed = time.perf_counter() - start

        client = chromadb.PersistentClient(path=palace_path)
        col = client.get_collection("mempalace_drawers")
        drawer_count = col.count()

        # Floor the denominator so a near-instant run cannot divide by zero.
        safe_elapsed = max(elapsed, 0.001)
        files_per_sec = files_written / safe_elapsed
        drawers_per_sec = drawer_count / safe_elapsed

        record_metric("ingest", f"files_per_sec_at_{n_files}", round(files_per_sec, 1))
        record_metric("ingest", f"drawers_per_sec_at_{n_files}", round(drawers_per_sec, 1))
        record_metric("ingest", f"elapsed_sec_at_{n_files}", round(elapsed, 2))
        record_metric("ingest", f"drawers_created_at_{n_files}", drawer_count)

    def test_mine_peak_rss(self, tmp_path, bench_scale):
        """Track peak RSS during a mining run.

        A background thread samples ``_get_rss_mb()`` roughly every 0.1 s
        while mine() processes a 100-file project. The largest sample is
        recorded as the peak, along with its difference from the reading
        taken just before mining started.
        """
        import threading

        gen = PalaceDataGenerator(seed=42, scale=bench_scale)
        project_path, _wing, _rooms, _files_written = gen.generate_project_tree(
            tmp_path / "project", n_files=100
        )
        palace_path = str(tmp_path / "palace")

        from mempalace.miner import mine

        rss_samples = []
        stop_sampling = threading.Event()

        def sample_rss():
            while not stop_sampling.is_set():
                rss_samples.append(_get_rss_mb())
                # wait() doubles as the sleep and returns early once stopped.
                stop_sampling.wait(0.1)

        sampler = threading.Thread(target=sample_rss, daemon=True)
        sampler.start()

        try:
            rss_before = _get_rss_mb()
            mine(project_path, palace_path)
        finally:
            # Always stop the sampler, even when mine() raises; otherwise the
            # daemon thread keeps appending samples for the rest of the session.
            stop_sampling.set()
            sampler.join(timeout=1)

        peak_rss = max(rss_samples) if rss_samples else _get_rss_mb()
        rss_delta = peak_rss - rss_before

        record_metric("ingest", "peak_rss_mb", round(peak_rss, 1))
        record_metric("ingest", "rss_delta_mb", round(rss_delta, 1))
|
||||
|
||||
|
||||
@pytest.mark.benchmark
class TestChunkThroughput:
    """Benchmark chunk_text() on its own, with no ChromaDB insertion involved."""

    @pytest.mark.parametrize("content_size_kb", [1, 10, 100])
    def test_chunk_text_throughput(self, content_size_kb):
        """Run chunk_text 50 times over generated text and record its rates.

        The text is padded until it holds at least ``content_size_kb * 1024``
        characters. Chunks/sec and KB/sec are recorded per content size.
        """
        from mempalace.miner import chunk_text

        generator = PalaceDataGenerator(seed=42)
        min_chars = content_size_kb * 1024

        content = generator._random_text(content_size_kb * 500, content_size_kb * 1200)
        # Keep appending filler until the text reaches the requested size.
        while len(content) < min_chars:
            content += "\n" + generator._random_text(200, 500)

        n_iterations = 50
        t0 = time.perf_counter()
        total_chunks = sum(
            len(chunk_text(content, "bench_file.py")) for _ in range(n_iterations)
        )
        elapsed = time.perf_counter() - t0

        # Floor the denominator so a near-instant run cannot divide by zero.
        safe_elapsed = max(elapsed, 0.001)
        chunks_per_sec = total_chunks / safe_elapsed
        kb_per_sec = (len(content) * n_iterations / 1024) / safe_elapsed

        record_metric(
            "chunking", f"chunks_per_sec_at_{content_size_kb}kb", round(chunks_per_sec, 1)
        )
        record_metric("chunking", f"kb_per_sec_at_{content_size_kb}kb", round(kb_per_sec, 1))
|
||||
|
||||
|
||||
@pytest.mark.benchmark
class TestReingestSkipOverhead:
    """Finding #11: cost of the file_already_mined() skip check on a re-mine."""

    def test_skip_check_cost(self, tmp_path):
        """Mine a 50-file project twice and time the second pass.

        The second pass must leave the drawer count unchanged. Its total and
        per-file duration are recorded as metrics.
        """
        generator = PalaceDataGenerator(seed=42, scale="small")
        project_path, _wing, _rooms, files_written = generator.generate_project_tree(
            tmp_path / "project", n_files=50
        )
        palace_path = str(tmp_path / "palace")

        from mempalace.miner import mine

        # Initial pass populates the palace.
        mine(project_path, palace_path)
        client = chromadb.PersistentClient(path=palace_path)
        drawers = client.get_collection("mempalace_drawers")
        count_before = drawers.count()

        # Second pass over the unchanged project: every file should be skipped.
        t0 = time.perf_counter()
        mine(project_path, palace_path)
        skip_elapsed = time.perf_counter() - t0

        # The collection must be the same size as after the first pass.
        count_after = drawers.count()
        assert count_after == count_before, "Re-mine should not add new drawers"

        per_file_ms = skip_elapsed * 1000 / max(files_written, 1)
        record_metric("reingest", "skip_check_elapsed_sec", round(skip_elapsed, 2))
        record_metric("reingest", "files_checked", files_written)
        record_metric("reingest", "skip_check_per_file_ms", round(per_file_ms, 1))
|
||||
@@ -0,0 +1,290 @@
|
||||
"""
|
||||
Knowledge graph benchmarks — SQLite temporal KG at scale.
|
||||
|
||||
Tests triple insertion throughput, query latency, temporal accuracy,
|
||||
and SQLite concurrent access behavior.
|
||||
"""
|
||||
|
||||
import threading
|
||||
import time
|
||||
|
||||
import pytest
|
||||
|
||||
from tests.benchmarks.data_generator import PalaceDataGenerator
|
||||
from tests.benchmarks.report import record_metric
|
||||
|
||||
|
||||
@pytest.mark.benchmark
class TestTripleInsertionRate:
    """Benchmark knowledge-graph triple insertion rate at several sizes."""

    @pytest.mark.parametrize("n_triples", [200, 1_000, 5_000])
    def test_insertion_throughput(self, n_triples, tmp_path):
        """Time ``add_triple`` over ``n_triples`` generated triples.

        Entities are inserted before the timer starts, so only triple
        insertion contributes to the recorded throughput.
        """
        generator = PalaceDataGenerator(seed=42, scale="small")
        entity_budget = min(n_triples // 2, 200)
        entities, triples = generator.generate_kg_triples(
            n_entities=entity_budget, n_triples=n_triples
        )

        from mempalace.knowledge_graph import KnowledgeGraph

        kg = KnowledgeGraph(db_path=str(tmp_path / "kg.sqlite3"))

        # Untimed setup: register every entity up front.
        for entity_name, entity_type in entities:
            kg.add_entity(entity_name, entity_type)

        # Timed section: triple insertion only.
        t0 = time.perf_counter()
        for subj, pred, target, since, until in triples:
            kg.add_triple(subj, pred, target, valid_from=since, valid_to=until)
        elapsed = time.perf_counter() - t0

        # Floor the denominator so a near-instant run cannot divide by zero.
        triples_per_sec = n_triples / max(elapsed, 0.001)

        record_metric("kg_insert", f"triples_per_sec_at_{n_triples}", round(triples_per_sec, 1))
        record_metric("kg_insert", f"elapsed_sec_at_{n_triples}", round(elapsed, 3))
|
||||
|
||||
|
||||
@pytest.mark.benchmark
class TestQueryEntityLatency:
    """Benchmark ``query_entity`` on an entity with many relationships."""

    def test_query_latency_vs_relationships(self, tmp_path):
        """Link one hub entity to 10 + 50 + 100 nodes, then time 20 hub queries."""
        from mempalace.knowledge_graph import KnowledgeGraph

        kg = KnowledgeGraph(db_path=str(tmp_path / "kg.sqlite3"))

        # A single hub entity is the subject of every relationship.
        kg.add_entity("Hub", "person")
        target_counts = [10, 50, 100]

        node_names = [
            f"Node_{target}_{i}" for target in target_counts for i in range(target)
        ]
        for node in node_names:
            kg.add_entity(node, "project")
            kg.add_triple("Hub", "works_on", node, valid_from="2025-01-01")

        # The hub now carries sum(target_counts) relationships; query it repeatedly.
        latencies = []
        for _ in range(20):
            t0 = time.perf_counter()
            kg.query_entity("Hub")
            latencies.append((time.perf_counter() - t0) * 1000)

        total_rels = sum(target_counts)
        avg_ms = sum(latencies) / len(latencies)

        record_metric("kg_query", f"avg_ms_with_{total_rels}_rels", round(avg_ms, 2))
        record_metric("kg_query", "total_relationships", total_rels)
|
||||
|
||||
|
||||
@pytest.mark.benchmark
class TestTimelinePerformance:
    """Benchmark ``timeline()`` called with no entity filter.

    The original author notes that the unfiltered call is a full table scan
    with LIMIT 100; that is not verified by this test.
    """

    @pytest.mark.parametrize("n_triples", [200, 1_000, 5_000])
    def test_timeline_latency(self, n_triples, tmp_path):
        """Fill a graph with ``n_triples`` triples and average 10 timeline() calls."""
        from mempalace.knowledge_graph import KnowledgeGraph

        generator = PalaceDataGenerator(seed=42)
        entity_budget = min(n_triples // 2, 200)
        entities, triples = generator.generate_kg_triples(
            n_entities=entity_budget, n_triples=n_triples
        )

        kg = KnowledgeGraph(db_path=str(tmp_path / "kg.sqlite3"))
        for entity_name, entity_type in entities:
            kg.add_entity(entity_name, entity_type)
        for subj, pred, target, since, until in triples:
            kg.add_triple(subj, pred, target, valid_from=since, valid_to=until)

        # Time the unfiltered timeline call.
        latencies = []
        for _ in range(10):
            t0 = time.perf_counter()
            kg.timeline()
            latencies.append((time.perf_counter() - t0) * 1000)

        avg_ms = sum(latencies) / len(latencies)
        record_metric("kg_timeline", f"avg_ms_at_{n_triples}", round(avg_ms, 2))
|
||||
|
||||
|
||||
@pytest.mark.benchmark
class TestTemporalQueryAccuracy:
    """Exercise ``as_of`` temporal filtering with noise triples present."""

    def test_as_of_filtering(self, tmp_path):
        """Query one entity at two dates after inserting dated and noise triples.

        Only the size of each query result is recorded as a metric; nothing
        is asserted about which triples come back.
        """
        from mempalace.knowledge_graph import KnowledgeGraph

        kg = KnowledgeGraph(db_path=str(tmp_path / "kg.sqlite3"))

        known_entities = (
            ("Alice", "person"),
            ("ProjectA", "project"),
            ("ProjectB", "project"),
        )
        for entity_name, entity_type in known_entities:
            kg.add_entity(entity_name, entity_type)

        # ProjectA assignment: 2024-01-01 through 2024-06-30.
        kg.add_triple(
            "Alice", "works_on", "ProjectA", valid_from="2024-01-01", valid_to="2024-06-30"
        )
        # ProjectB assignment: open-ended, starting 2024-07-01.
        kg.add_triple("Alice", "works_on", "ProjectB", valid_from="2024-07-01")

        # Surround the known facts with generated noise.
        generator = PalaceDataGenerator(seed=42)
        entities, triples = generator.generate_kg_triples(n_entities=50, n_triples=500)
        for entity_name, entity_type in entities:
            kg.add_entity(entity_name, entity_type)
        for subj, pred, target, since, until in triples:
            kg.add_triple(subj, pred, target, valid_from=since, valid_to=until)

        # Mid-March 2024 falls inside the ProjectA range.
        result_march = kg.query_entity("Alice", as_of="2024-03-15")
        # Mid-September 2024 falls inside the ProjectB range.
        result_sept = kg.query_entity("Alice", as_of="2024-09-15")

        def result_size(result):
            # Anything that is not a list is reported as zero results.
            return len(result) if isinstance(result, list) else 0

        record_metric("kg_temporal", "march_query_results", result_size(result_march))
        record_metric("kg_temporal", "sept_query_results", result_size(result_sept))
|
||||
|
||||
|
||||
@pytest.mark.benchmark
class TestSQLiteConcurrentAccess:
    """Exercise the SQLite-backed graph under concurrent access (finding #8)."""

    def test_concurrent_writers(self, tmp_path):
        """Run four threads that each attempt 50 ``add_triple`` calls.

        Calls that raise are counted as failures, the rest as successes.
        Totals, elapsed time and the thread settings are recorded as metrics.
        """
        from mempalace.knowledge_graph import KnowledgeGraph

        kg = KnowledgeGraph(db_path=str(tmp_path / "kg.sqlite3"))

        # The entities the writers reference are created up front.
        for i in range(100):
            kg.add_entity(f"Entity_{i}", "concept")

        n_threads = 4
        triples_per_thread = 50
        lock_failures = []
        successes = []

        def writer(thread_id):
            base = thread_id * 10
            failed = 0
            written = 0
            for i in range(triples_per_thread):
                try:
                    kg.add_triple(
                        f"Entity_{base}",
                        "relates_to",
                        f"Entity_{(base + i) % 100}",
                        valid_from="2025-01-01",
                    )
                except Exception:
                    failed += 1
                else:
                    written += 1
            # Each thread reports its own tallies once, at the end.
            lock_failures.append(failed)
            successes.append(written)

        workers = [threading.Thread(target=writer, args=(t,)) for t in range(n_threads)]

        t0 = time.perf_counter()
        for worker in workers:
            worker.start()
        for worker in workers:
            worker.join(timeout=30)
        elapsed = time.perf_counter() - t0

        record_metric("kg_concurrent", "total_failures", sum(lock_failures))
        record_metric("kg_concurrent", "total_successes", sum(successes))
        record_metric("kg_concurrent", "elapsed_sec", round(elapsed, 2))
        record_metric("kg_concurrent", "threads", n_threads)
        record_metric("kg_concurrent", "triples_per_thread", triples_per_thread)

    def test_concurrent_read_write(self, tmp_path):
        """Run two reader and two writer threads against one seeded graph.

        Each thread makes 50 calls and reports how many of them raised.
        """
        from mempalace.knowledge_graph import KnowledgeGraph

        kg = KnowledgeGraph(db_path=str(tmp_path / "kg.sqlite3"))

        # Seed 50 entities and 200 triples between neighbouring entities.
        for i in range(50):
            kg.add_entity(f"E_{i}", "concept")
        for i in range(200):
            kg.add_triple(f"E_{i % 50}", "links", f"E_{(i + 1) % 50}", valid_from="2025-01-01")

        read_errors = []
        write_errors = []

        def reader():
            failed = 0
            for i in range(50):
                try:
                    kg.query_entity(f"E_{i % 50}")
                except Exception:
                    failed += 1
            read_errors.append(failed)

        def writer():
            failed = 0
            for i in range(50):
                try:
                    kg.add_triple(
                        f"E_{i % 50}", "new_rel", f"E_{(i + 7) % 50}", valid_from="2025-06-01"
                    )
                except Exception:
                    failed += 1
            write_errors.append(failed)

        workers = [
            threading.Thread(target=job) for job in (reader, reader, writer, writer)
        ]
        for worker in workers:
            worker.start()
        for worker in workers:
            worker.join(timeout=30)

        record_metric("kg_concurrent_rw", "read_errors", sum(read_errors))
        record_metric("kg_concurrent_rw", "write_errors", sum(write_errors))
|
||||
|
||||
|
||||
@pytest.mark.benchmark
class TestKGStats:
    """Benchmark ``stats()`` as the knowledge graph gets larger."""

    @pytest.mark.parametrize("n_triples", [200, 1_000, 5_000])
    def test_stats_latency(self, n_triples, tmp_path):
        """Fill a graph with ``n_triples`` triples and average 10 stats() calls."""
        from mempalace.knowledge_graph import KnowledgeGraph

        generator = PalaceDataGenerator(seed=42)
        entity_budget = min(n_triples // 2, 200)
        entities, triples = generator.generate_kg_triples(
            n_entities=entity_budget, n_triples=n_triples
        )

        kg = KnowledgeGraph(db_path=str(tmp_path / "kg.sqlite3"))
        for entity_name, entity_type in entities:
            kg.add_entity(entity_name, entity_type)
        for subj, pred, target, since, until in triples:
            kg.add_triple(subj, pred, target, valid_from=since, valid_to=until)

        latencies = []
        for _ in range(10):
            t0 = time.perf_counter()
            kg.stats()
            latencies.append((time.perf_counter() - t0) * 1000)

        avg_ms = sum(latencies) / len(latencies)
        record_metric("kg_stats", f"avg_ms_at_{n_triples}", round(avg_ms, 2))
|
||||
@@ -0,0 +1,209 @@
|
||||
"""
|
||||
Memory stack (layers.py) benchmarks.
|
||||
|
||||
Tests MemoryStack.wake_up(), Layer1.generate(), and Layer2/L3
|
||||
at scale. Layer1 has the same unbounded col.get() as tool_status.
|
||||
"""
|
||||
|
||||
import time
|
||||
|
||||
import pytest
|
||||
|
||||
from tests.benchmarks.data_generator import PalaceDataGenerator
|
||||
from tests.benchmarks.report import record_metric
|
||||
|
||||
|
||||
def _get_rss_mb():
|
||||
try:
|
||||
import psutil
|
||||
|
||||
return psutil.Process().memory_info().rss / (1024 * 1024)
|
||||
except ImportError:
|
||||
import resource
|
||||
import platform
|
||||
|
||||
usage = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
|
||||
if platform.system() == "Darwin":
|
||||
return usage / (1024 * 1024)
|
||||
return usage / 1024
|
||||
|
||||
|
||||
@pytest.mark.benchmark
class TestWakeUpCost:
    """Benchmark ``MemoryStack.wake_up()`` (L0 + L1) at several palace sizes."""

    SIZES = [500, 1_000, 2_500, 5_000]

    @pytest.mark.parametrize("n_drawers", SIZES)
    def test_wakeup_latency(self, n_drawers, tmp_path, bench_scale):
        """Average five wake_up() calls on a palace of ``n_drawers`` drawers.

        Each call's text must mention at least one of the expected section
        markers ("L0", "L1", "IDENTITY", "ESSENTIAL").
        """
        generator = PalaceDataGenerator(seed=42, scale=bench_scale)
        palace_path = str(tmp_path / "palace")
        generator.populate_palace_directly(
            palace_path, n_drawers=n_drawers, include_needles=False
        )

        # wake_up() needs an identity file; write a minimal one.
        identity_file = tmp_path / "identity.txt"
        identity_file.write_text("I am a test AI. Traits: precise, fast.\n")

        from mempalace.layers import MemoryStack

        stack = MemoryStack(palace_path=palace_path, identity_path=str(identity_file))

        expected_markers = ("L0", "L1", "IDENTITY", "ESSENTIAL")
        latencies = []
        for _ in range(5):
            t0 = time.perf_counter()
            text = stack.wake_up()
            latencies.append((time.perf_counter() - t0) * 1000)
            assert any(marker in text for marker in expected_markers)

        avg_ms = sum(latencies) / len(latencies)
        record_metric("layers_wakeup", f"avg_ms_at_{n_drawers}", round(avg_ms, 1))
|
||||
|
||||
|
||||
@pytest.mark.benchmark
class TestLayer1UnboundedFetch:
    """Benchmark ``Layer1.generate()`` at several palace sizes.

    The original author notes that Layer1 fetches every drawer, the same
    pattern as tool_status; that is not verified by these tests.
    """

    SIZES = [500, 1_000, 2_500, 5_000]

    @pytest.mark.parametrize("n_drawers", SIZES)
    def test_layer1_rss_growth(self, n_drawers, tmp_path):
        """Record latency and RSS change across one ``generate()`` call."""
        generator = PalaceDataGenerator(seed=42, scale="small")
        palace_path = str(tmp_path / "palace")
        generator.populate_palace_directly(
            palace_path, n_drawers=n_drawers, include_needles=False
        )

        from mempalace.layers import Layer1

        layer = Layer1(palace_path=palace_path)

        # RSS readings bracket the timed call.
        rss_before = _get_rss_mb()
        t0 = time.perf_counter()
        text = layer.generate()
        elapsed_ms = (time.perf_counter() - t0) * 1000
        rss_after = _get_rss_mb()

        rss_delta = rss_after - rss_before
        assert "L1" in text

        record_metric("layer1", f"latency_ms_at_{n_drawers}", round(elapsed_ms, 1))
        record_metric("layer1", f"rss_delta_mb_at_{n_drawers}", round(rss_delta, 2))

    def test_layer1_wing_filtered(self, tmp_path):
        """Compare ``generate()`` time with and without a wing filter (2,000 drawers)."""
        generator = PalaceDataGenerator(seed=42, scale="small")
        palace_path = str(tmp_path / "palace")
        generator.populate_palace_directly(
            palace_path, n_drawers=2_000, include_needles=False
        )

        from mempalace.layers import Layer1

        first_wing = generator.wings[0]

        # Baseline: no wing filter.
        unfiltered_layer = Layer1(palace_path=palace_path)
        t0 = time.perf_counter()
        unfiltered_layer.generate()
        unfiltered_ms = (time.perf_counter() - t0) * 1000

        # Same call restricted to the generator's first wing.
        filtered_layer = Layer1(palace_path=palace_path, wing=first_wing)
        t0 = time.perf_counter()
        filtered_layer.generate()
        filtered_ms = (time.perf_counter() - t0) * 1000

        record_metric("layer1_filter", "unfiltered_ms", round(unfiltered_ms, 1))
        record_metric("layer1_filter", "filtered_ms", round(filtered_ms, 1))
        # Skip the percentage when the baseline is zero, to avoid dividing by zero.
        if unfiltered_ms > 0:
            speedup_pct = (1 - filtered_ms / unfiltered_ms) * 100
            record_metric("layer1_filter", "speedup_pct", round(speedup_pct, 1))
|
||||
|
||||
|
||||
@pytest.mark.benchmark
class TestWakeUpTokenBudget:
    """Check that wake-up text (L0 + L1) stays within its token budget."""

    SIZES = [500, 1_000, 2_500, 5_000]

    @pytest.mark.parametrize("n_drawers", SIZES)
    def test_token_budget(self, n_drawers, tmp_path):
        """Fail when the estimated wake-up size reaches 1200 tokens.

        Tokens are estimated as ``len(text) // 4``. The original author notes
        that L1 caps its output at MAX_CHARS=3200; that is not verified here.
        """
        generator = PalaceDataGenerator(seed=42, scale="small")
        palace_path = str(tmp_path / "palace")
        generator.populate_palace_directly(
            palace_path, n_drawers=n_drawers, include_needles=False
        )

        identity_file = tmp_path / "identity.txt"
        identity_file.write_text("I am a benchmark AI.\n")

        from mempalace.layers import MemoryStack

        stack = MemoryStack(palace_path=palace_path, identity_path=str(identity_file))
        text = stack.wake_up()
        char_count = len(text)
        token_estimate = char_count // 4

        # Metrics are recorded before the assertion so they survive a failure.
        record_metric("wakeup_budget", f"tokens_at_{n_drawers}", token_estimate)
        record_metric("wakeup_budget", f"chars_at_{n_drawers}", char_count)

        # Nominal budget is ~600-900 tokens; 1200 leaves a safety margin.
        assert (
            token_estimate < 1200
        ), f"Wake-up exceeded budget: ~{token_estimate} tokens at {n_drawers} drawers"
|
||||
|
||||
|
||||
@pytest.mark.benchmark
class TestLayer2Retrieval:
    """Benchmark ``Layer2.retrieve()`` with a wing filter."""

    def test_layer2_latency(self, tmp_path, bench_scale):
        """Average 10 wing-filtered retrievals on a 2,000-drawer palace."""
        generator = PalaceDataGenerator(seed=42, scale=bench_scale)
        palace_path = str(tmp_path / "palace")
        generator.populate_palace_directly(
            palace_path, n_drawers=2_000, include_needles=False
        )

        from mempalace.layers import Layer2

        layer = Layer2(palace_path=palace_path)
        first_wing = generator.wings[0]

        latencies = []
        for _ in range(10):
            t0 = time.perf_counter()
            layer.retrieve(wing=first_wing, n_results=10)
            latencies.append((time.perf_counter() - t0) * 1000)

        avg_ms = sum(latencies) / len(latencies)
        record_metric("layer2", "avg_retrieval_ms", round(avg_ms, 1))
|
||||
|
||||
|
||||
@pytest.mark.benchmark
class TestLayer3Search:
    """Benchmark search through the ``MemoryStack.search()`` interface."""

    def test_layer3_latency(self, tmp_path, bench_scale):
        """Average five single-word searches on a 2,000-drawer palace."""
        generator = PalaceDataGenerator(seed=42, scale=bench_scale)
        palace_path = str(tmp_path / "palace")
        generator.populate_palace_directly(
            palace_path, n_drawers=2_000, include_needles=False
        )

        identity_file = tmp_path / "identity.txt"
        identity_file.write_text("I am a benchmark AI.\n")

        from mempalace.layers import MemoryStack

        stack = MemoryStack(palace_path=palace_path, identity_path=str(identity_file))

        search_terms = ("authentication", "database", "deployment", "testing", "monitoring")
        latencies = []
        for term in search_terms:
            t0 = time.perf_counter()
            stack.search(term, n_results=5)
            latencies.append((time.perf_counter() - t0) * 1000)

        avg_ms = sum(latencies) / len(latencies)
        record_metric("layer3", "avg_search_ms", round(avg_ms, 1))
|
||||
@@ -0,0 +1,226 @@
|
||||
"""
|
||||
MCP server tool performance benchmarks.
|
||||
|
||||
Validates production readiness findings:
|
||||
- Finding #3: tool_status() unbounded col.get(include=["metadatas"]) → OOM
|
||||
- Finding #7: _get_collection() re-instantiates PersistentClient every call
|
||||
- Finding #3 variants: tool_list_wings(), tool_get_taxonomy() same pattern
|
||||
|
||||
Calls MCP tool handler functions directly with monkeypatched _config.
|
||||
"""
|
||||
|
||||
import time
|
||||
|
||||
import chromadb
|
||||
import pytest
|
||||
|
||||
from tests.benchmarks.data_generator import PalaceDataGenerator
|
||||
from tests.benchmarks.report import record_metric
|
||||
|
||||
|
||||
# ── Helpers ──────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def _make_palace(tmp_path, n_drawers, scale="small"):
    """Populate a palace under ``tmp_path / "palace"`` and return its path string.

    The palace is filled through ``PalaceDataGenerator.populate_palace_directly``
    with ``n_drawers`` drawers and no needle documents.
    """
    palace_dir = str(tmp_path / "palace")
    generator = PalaceDataGenerator(seed=42, scale=scale)
    generator.populate_palace_directly(
        palace_dir, n_drawers=n_drawers, include_needles=False
    )
    return palace_dir
|
||||
|
||||
|
||||
def _patch_mcp_config(monkeypatch, palace_path, tmp_path):
    """Point ``mcp_server._config`` and ``mcp_server._kg`` at test directories.

    Both replacements go through ``monkeypatch``, so pytest undoes them when
    the requesting test finishes.
    """
    from mempalace.config import MempalaceConfig
    from mempalace.knowledge_graph import KnowledgeGraph

    test_config = MempalaceConfig(config_dir=str(tmp_path / "cfg"))
    # Swap in a file-config dict that carries only the test palace path.
    monkeypatch.setattr(test_config, "_file_config", {"palace_path": palace_path})

    import mempalace.mcp_server as mcp_mod

    monkeypatch.setattr(mcp_mod, "_config", test_config)

    test_kg = KnowledgeGraph(db_path=str(tmp_path / "kg.sqlite3"))
    monkeypatch.setattr(mcp_mod, "_kg", test_kg)
|
||||
|
||||
|
||||
def _get_rss_mb():
|
||||
"""Get current process RSS in MB."""
|
||||
try:
|
||||
import psutil
|
||||
|
||||
return psutil.Process().memory_info().rss / (1024 * 1024)
|
||||
except ImportError:
|
||||
import resource
|
||||
|
||||
# ru_maxrss is in KB on Linux, bytes on macOS
|
||||
import platform
|
||||
|
||||
usage = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
|
||||
if platform.system() == "Darwin":
|
||||
return usage / (1024 * 1024)
|
||||
return usage / 1024
|
||||
|
||||
|
||||
# ── Tests ────────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@pytest.mark.benchmark
class TestToolStatusOOM:
    """Finding #3: memory and latency of ``tool_status()`` as the palace grows."""

    SIZES = [500, 1_000, 2_500, 5_000]

    @pytest.mark.parametrize("n_drawers", SIZES)
    def test_tool_status_rss_growth(self, n_drawers, tmp_path, monkeypatch):
        """Record the RSS change across one ``tool_status()`` call.

        The call must succeed and report exactly ``n_drawers`` total drawers.
        """
        _patch_mcp_config(monkeypatch, _make_palace(tmp_path, n_drawers), tmp_path)

        from mempalace.mcp_server import tool_status

        rss_before = _get_rss_mb()
        result = tool_status()
        rss_delta = _get_rss_mb() - rss_before

        assert "error" not in result, f"tool_status failed: {result}"
        assert result["total_drawers"] == n_drawers

        record_metric("mcp_status", f"rss_delta_mb_at_{n_drawers}", round(rss_delta, 2))

    @pytest.mark.parametrize("n_drawers", SIZES)
    def test_tool_status_latency(self, n_drawers, tmp_path, monkeypatch):
        """Record the latency of a second ``tool_status()`` call.

        The first call is an untimed warm-up.
        """
        _patch_mcp_config(monkeypatch, _make_palace(tmp_path, n_drawers), tmp_path)

        from mempalace.mcp_server import tool_status

        # Untimed warm-up call.
        tool_status()

        t0 = time.perf_counter()
        result = tool_status()
        elapsed_ms = (time.perf_counter() - t0) * 1000

        assert "error" not in result
        record_metric("mcp_status", f"latency_ms_at_{n_drawers}", round(elapsed_ms, 1))
|
||||
|
||||
|
||||
@pytest.mark.benchmark
class TestToolListWingsUnbounded:
    """Finding #3 variant: tool_list_wings also fetches ALL metadata."""

    @pytest.mark.parametrize("n_drawers", [500, 1_000, 2_500, 5_000])
    def test_list_wings_latency(self, n_drawers, tmp_path, monkeypatch):
        """Time a single tool_list_wings() call at each palace size."""
        palace = _make_palace(tmp_path, n_drawers)
        _patch_mcp_config(monkeypatch, palace, tmp_path)

        from mempalace.mcp_server import tool_list_wings

        t0 = time.perf_counter()
        listing = tool_list_wings()
        took_ms = (time.perf_counter() - t0) * 1000

        assert "wings" in listing
        record_metric("mcp_list_wings", f"latency_ms_at_{n_drawers}", round(took_ms, 1))
|
||||
|
||||
|
||||
@pytest.mark.benchmark
class TestToolGetTaxonomyUnbounded:
    """Finding #3 variant: tool_get_taxonomy also fetches ALL metadata."""

    @pytest.mark.parametrize("n_drawers", [500, 1_000, 2_500, 5_000])
    def test_get_taxonomy_latency(self, n_drawers, tmp_path, monkeypatch):
        """Time a single tool_get_taxonomy() call at each palace size."""
        palace = _make_palace(tmp_path, n_drawers)
        _patch_mcp_config(monkeypatch, palace, tmp_path)

        from mempalace.mcp_server import tool_get_taxonomy

        t0 = time.perf_counter()
        taxonomy_reply = tool_get_taxonomy()
        took_ms = (time.perf_counter() - t0) * 1000

        assert "taxonomy" in taxonomy_reply
        record_metric("mcp_taxonomy", f"latency_ms_at_{n_drawers}", round(took_ms, 1))
|
||||
|
||||
|
||||
@pytest.mark.benchmark
class TestClientReinstantiation:
    """Finding #7: _get_collection() creates new PersistentClient every call."""

    def test_reinstantiation_overhead(self, tmp_path, monkeypatch):
        """Time 50 _get_collection() calls against 50 calls on a reused handle."""
        palace_path = _make_palace(tmp_path, 500)
        _patch_mcp_config(monkeypatch, palace_path, tmp_path)

        from mempalace.mcp_server import _get_collection

        n_calls = 50

        # Current behaviour: every iteration goes back through _get_collection().
        t0 = time.perf_counter()
        for _ in range(n_calls):
            assert _get_collection() is not None
        uncached_ms = (time.perf_counter() - t0) * 1000

        # Reference: one client and one collection handle reused for every call.
        # NOTE(review): this loop times count() on the handle, while the loop
        # above times only obtaining the collection -- confirm the ratio is
        # meant to compare those two operations.
        client = chromadb.PersistentClient(path=palace_path)
        cached_col = client.get_collection("mempalace_drawers")
        t0 = time.perf_counter()
        for _ in range(n_calls):
            cached_col.count()
        cached_ms = (time.perf_counter() - t0) * 1000

        # Floor the denominator so a near-zero cached time cannot divide by zero.
        overhead_ratio = uncached_ms / max(cached_ms, 0.01)

        metrics = {
            "uncached_total_ms": round(uncached_ms, 1),
            "cached_total_ms": round(cached_ms, 1),
            "overhead_ratio": round(overhead_ratio, 2),
            "n_calls": n_calls,
        }
        for key, value in metrics.items():
            record_metric("client_reinstantiation", key, value)
|
||||
|
||||
|
||||
@pytest.mark.benchmark
class TestToolSearchLatency:
    """tool_search uses query() not get(), should scale better."""

    @pytest.mark.parametrize("n_drawers", [500, 1_000, 2_500, 5_000])
    def test_search_latency(self, n_drawers, tmp_path, monkeypatch):
        """Average the latency of three tool_search() calls per palace size."""
        palace = _make_palace(tmp_path, n_drawers)
        _patch_mcp_config(monkeypatch, palace, tmp_path)

        from mempalace.mcp_server import tool_search

        timings_ms = []
        for text in ("authentication middleware", "database migration", "error handling"):
            t0 = time.perf_counter()
            outcome = tool_search(query=text, limit=5)
            timings_ms.append((time.perf_counter() - t0) * 1000)
            assert "error" not in outcome

        mean_ms = sum(timings_ms) / len(timings_ms)
        record_metric("mcp_search", f"avg_latency_ms_at_{n_drawers}", round(mean_ms, 1))
|
||||
|
||||
|
||||
@pytest.mark.benchmark
class TestDuplicateCheckCost:
    """tool_add_drawer calls tool_check_duplicate first — measure overhead."""

    @pytest.mark.parametrize("n_drawers", [500, 1_000, 2_500])
    def test_duplicate_check_latency(self, n_drawers, tmp_path, monkeypatch):
        """Time one tool_check_duplicate() call at each palace size."""
        palace = _make_palace(tmp_path, n_drawers)
        _patch_mcp_config(monkeypatch, palace, tmp_path)

        from mempalace.mcp_server import tool_check_duplicate

        probe = "This is unique test content for duplicate checking benchmark."

        t0 = time.perf_counter()
        verdict = tool_check_duplicate(content=probe)
        took_ms = (time.perf_counter() - t0) * 1000

        assert "error" not in verdict
        record_metric("mcp_duplicate_check", f"latency_ms_at_{n_drawers}", round(took_ms, 1))
|
||||
@@ -0,0 +1,181 @@
|
||||
"""
|
||||
Memory profiling benchmarks — detect leaks and measure RSS growth.
|
||||
|
||||
Uses tracemalloc for heap snapshots and psutil/resource for RSS.
|
||||
Targets the highest-risk code paths:
|
||||
- Repeated search() calls (PersistentClient re-instantiation)
|
||||
- Repeated tool_status() calls (unbounded metadata fetch)
|
||||
- Layer1.generate() (fetches all drawers)
|
||||
"""
|
||||
|
||||
import tracemalloc
|
||||
|
||||
import pytest
|
||||
|
||||
from tests.benchmarks.data_generator import PalaceDataGenerator
|
||||
from tests.benchmarks.report import record_metric
|
||||
|
||||
|
||||
def _get_rss_mb():
|
||||
try:
|
||||
import psutil
|
||||
|
||||
return psutil.Process().memory_info().rss / (1024 * 1024)
|
||||
except ImportError:
|
||||
import resource
|
||||
import platform
|
||||
|
||||
usage = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
|
||||
if platform.system() == "Darwin":
|
||||
return usage / (1024 * 1024)
|
||||
return usage / 1024
|
||||
|
||||
|
||||
@pytest.mark.benchmark
class TestSearchMemoryProfile:
    """Track RSS growth over repeated search_memories() calls."""

    def test_search_rss_growth(self, tmp_path):
        """Run 200 searches, sampling RSS after every 50th call."""
        generator = PalaceDataGenerator(seed=42, scale="small")
        palace_path = str(tmp_path / "palace")
        generator.populate_palace_directly(palace_path, n_drawers=1_000, include_needles=False)

        from mempalace.searcher import search_memories

        n_calls = 200
        check_interval = 50
        queries = ["authentication", "database", "deployment", "error handling", "testing"]

        # (label, rss_mb) samples; the first entry is the pre-search baseline.
        samples = [("start", _get_rss_mb())]

        for call_no in range(1, n_calls + 1):
            # Cycle through the fixed query list.
            text = queries[(call_no - 1) % len(queries)]
            search_memories(text, palace_path=palace_path, n_results=5)
            if call_no % check_interval == 0:
                samples.append((f"after_{call_no}", _get_rss_mb()))

        _, start_rss = samples[0]
        _, end_rss = samples[-1]
        growth = end_rss - start_rss

        record_metric("memory_search", "rss_start_mb", round(start_rss, 2))
        record_metric("memory_search", "rss_end_mb", round(end_rss, 2))
        record_metric("memory_search", "rss_growth_mb", round(growth, 2))
        record_metric("memory_search", "n_calls", n_calls)
        record_metric(
            "memory_search", "growth_per_100_calls_mb", round(growth / (n_calls / 100), 2)
        )
|
||||
|
||||
|
||||
@pytest.mark.benchmark
class TestToolStatusMemoryProfile:
    """Track RSS growth from repeated tool_status() calls."""

    def test_tool_status_repeated_calls(self, tmp_path, monkeypatch):
        """Call tool_status() 50 times on a 2,000-drawer palace, sampling RSS every 10 calls."""
        palace_size = 2_000
        generator = PalaceDataGenerator(seed=42, scale="small")
        palace_path = str(tmp_path / "palace")
        generator.populate_palace_directly(
            palace_path, n_drawers=palace_size, include_needles=False
        )

        from mempalace.config import MempalaceConfig
        from mempalace.knowledge_graph import KnowledgeGraph
        import mempalace.mcp_server as mcp_mod

        # Point the MCP server module at the freshly generated palace and a
        # throwaway knowledge-graph database under tmp_path.
        cfg = MempalaceConfig(config_dir=str(tmp_path / "cfg"))
        monkeypatch.setattr(cfg, "_file_config", {"palace_path": palace_path})
        monkeypatch.setattr(mcp_mod, "_config", cfg)
        monkeypatch.setattr(mcp_mod, "_kg", KnowledgeGraph(db_path=str(tmp_path / "kg.sqlite3")))

        from mempalace.mcp_server import tool_status

        n_calls = 50
        samples = [("start", _get_rss_mb())]

        for call_no in range(1, n_calls + 1):
            status = tool_status()
            assert status["total_drawers"] == palace_size
            if call_no % 10 == 0:
                samples.append((f"after_{call_no}", _get_rss_mb()))

        _, start_rss = samples[0]
        _, end_rss = samples[-1]
        growth = end_rss - start_rss

        record_metric("memory_tool_status", "rss_start_mb", round(start_rss, 2))
        record_metric("memory_tool_status", "rss_end_mb", round(end_rss, 2))
        record_metric("memory_tool_status", "rss_growth_mb", round(growth, 2))
        record_metric("memory_tool_status", "n_calls", n_calls)
        record_metric("memory_tool_status", "palace_size", palace_size)
|
||||
|
||||
|
||||
@pytest.mark.benchmark
class TestLayer1MemoryProfile:
    """Layer1.generate() fetches ALL drawers — same risk as tool_status."""

    def test_layer1_repeated_generate(self, tmp_path):
        """Call Layer1.generate() 30 times, sampling RSS after every 10th call."""
        generator = PalaceDataGenerator(seed=42, scale="small")
        palace_path = str(tmp_path / "palace")
        generator.populate_palace_directly(palace_path, n_drawers=2_000, include_needles=False)

        from mempalace.layers import Layer1

        layer = Layer1(palace_path=palace_path)

        n_calls = 30
        samples = [("start", _get_rss_mb())]

        for call_no in range(1, n_calls + 1):
            rendered = layer.generate()
            assert "L1" in rendered
            if call_no % 10 == 0:
                samples.append((f"after_{call_no}", _get_rss_mb()))

        _, start_rss = samples[0]
        _, end_rss = samples[-1]
        growth = end_rss - start_rss

        record_metric("memory_layer1", "rss_start_mb", round(start_rss, 2))
        record_metric("memory_layer1", "rss_end_mb", round(end_rss, 2))
        record_metric("memory_layer1", "rss_growth_mb", round(growth, 2))
        record_metric("memory_layer1", "n_calls", n_calls)
|
||||
|
||||
|
||||
@pytest.mark.benchmark
class TestHeapSnapshot:
    """Use tracemalloc to identify top memory allocators during search."""

    def test_search_heap_top_allocators(self, tmp_path):
        """Record heap growth of the ten largest allocation sites over 100 searches.

        Two snapshots are taken around the search loop and compared per source
        line. The metric sums ``size_diff`` (bytes gained between snapshots),
        not ``size`` (total bytes held at the second snapshot), so it reflects
        growth caused by the searches.
        """
        gen = PalaceDataGenerator(seed=42, scale="small")
        palace_path = str(tmp_path / "palace")
        gen.populate_palace_directly(palace_path, n_drawers=1_000, include_needles=False)

        from mempalace.searcher import search_memories

        n_searches = 100

        tracemalloc.start()
        try:
            snap_before = tracemalloc.take_snapshot()
            for _ in range(n_searches):
                search_memories("test query", palace_path=palace_path, n_results=5)
            snap_after = tracemalloc.take_snapshot()
        finally:
            # Always switch tracing off again, even when a search raises, so
            # later tests in the same process do not run with tracing enabled.
            tracemalloc.stop()

        # compare_to() sorts by absolute size_diff first, so the head of the
        # list holds the sites whose footprint changed the most.
        top_diffs = snap_after.compare_to(snap_before, "lineno")[:10]
        total_growth_kb = sum(stat.size_diff for stat in top_diffs) / 1024

        record_metric("heap_search", "top_10_growth_kb", round(total_growth_kb, 1))
        record_metric("heap_search", "n_searches", n_searches)
|
||||
@@ -0,0 +1,176 @@
|
||||
"""
|
||||
Palace boost validation — does wing/room filtering actually help?
|
||||
|
||||
Quantifies the retrieval improvement from the palace spatial metaphor.
|
||||
Uses planted needles to measure recall with and without filtering
|
||||
at different scales.
|
||||
"""
|
||||
|
||||
import time
|
||||
|
||||
import pytest
|
||||
|
||||
from tests.benchmarks.data_generator import PalaceDataGenerator
|
||||
from tests.benchmarks.report import record_metric
|
||||
|
||||
|
||||
@pytest.mark.benchmark
class TestFilteredVsUnfilteredRecall:
    """Quantify palace boost: recall improvement from wing/room filtering."""

    SIZES = [1_000, 2_500, 5_000]

    @staticmethod
    def _needle_in_top5(result):
        """Return True when one of the first five result texts contains "NEEDLE_".

        NOTE(review): any needle marker counts as a hit, not specifically the
        needle the query was written for -- confirm that is intended.
        """
        texts = [hit["text"] for hit in result.get("results", [])]
        return any("NEEDLE_" in text for text in texts[:5])

    @pytest.mark.parametrize("n_drawers", SIZES)
    def test_palace_boost_recall(self, n_drawers, tmp_path, bench_scale):
        """Compare recall@5 for unfiltered, wing-filtered and wing+room-filtered search."""
        generator = PalaceDataGenerator(seed=42, scale=bench_scale)
        palace_path = str(tmp_path / "palace")
        _, _, needle_info = generator.populate_palace_directly(
            palace_path, n_drawers=n_drawers, include_needles=True
        )

        from mempalace.searcher import search_memories

        n_queries = min(10, len(needle_info))
        unfiltered_hits = wing_hits = room_hits = 0

        for needle in needle_info[:n_queries]:
            # No filter at all.
            plain = search_memories(needle["query"], palace_path=palace_path, n_results=5)
            if self._needle_in_top5(plain):
                unfiltered_hits += 1

            # Restricted to the needle's wing.
            by_wing = search_memories(
                needle["query"], palace_path=palace_path, wing=needle["wing"], n_results=5
            )
            if self._needle_in_top5(by_wing):
                wing_hits += 1

            # Restricted to the needle's wing and room.
            by_room = search_memories(
                needle["query"],
                palace_path=palace_path,
                wing=needle["wing"],
                room=needle["room"],
                n_results=5,
            )
            if self._needle_in_top5(by_room):
                room_hits += 1

        # max(..., 1) keeps the division defined when no needles were planted.
        denominator = max(n_queries, 1)
        recall_none = unfiltered_hits / denominator
        recall_wing = wing_hits / denominator
        recall_room = room_hits / denominator

        record_metric("palace_boost", f"recall_unfiltered_at_{n_drawers}", round(recall_none, 3))
        record_metric("palace_boost", f"recall_wing_filtered_at_{n_drawers}", round(recall_wing, 3))
        record_metric("palace_boost", f"recall_room_filtered_at_{n_drawers}", round(recall_room, 3))
        record_metric(
            "palace_boost", f"wing_boost_at_{n_drawers}", round(recall_wing - recall_none, 3)
        )
        record_metric(
            "palace_boost", f"room_boost_at_{n_drawers}", round(recall_room - recall_none, 3)
        )
|
||||
|
||||
|
||||
@pytest.mark.benchmark
class TestFilterLatencyBenefit:
    """Does filtering reduce query latency by narrowing the search space?"""

    def test_filter_speedup(self, tmp_path, bench_scale):
        """Compare mean latency of unfiltered, wing-filtered and wing+room-filtered search."""
        generator = PalaceDataGenerator(seed=42, scale=bench_scale)
        palace_path = str(tmp_path / "palace")
        generator.populate_palace_directly(palace_path, n_drawers=5_000, include_needles=False)

        from mempalace.searcher import search_memories

        wing = generator.wings[0]
        room = generator.rooms_by_wing[wing][0]
        query = "authentication middleware optimization"
        n_runs = 10

        def mean_latency_ms(**filters):
            """Run the query n_runs times with the given filters; return the mean in ms."""
            samples = []
            for _ in range(n_runs):
                t0 = time.perf_counter()
                search_memories(query, palace_path=palace_path, n_results=5, **filters)
                samples.append((time.perf_counter() - t0) * 1000)
            return sum(samples) / len(samples)

        avg_none = mean_latency_ms()
        avg_wing = mean_latency_ms(wing=wing)
        avg_room = mean_latency_ms(wing=wing, room=room)

        record_metric("filter_latency", "avg_unfiltered_ms", round(avg_none, 1))
        record_metric("filter_latency", "avg_wing_filtered_ms", round(avg_wing, 1))
        record_metric("filter_latency", "avg_room_filtered_ms", round(avg_room, 1))

        # Speedups are relative to the unfiltered mean, so skip them when it is zero.
        if avg_none > 0:
            record_metric(
                "filter_latency", "wing_speedup_pct", round((1 - avg_wing / avg_none) * 100, 1)
            )
            record_metric(
                "filter_latency", "room_speedup_pct", round((1 - avg_room / avg_none) * 100, 1)
            )
|
||||
|
||||
|
||||
@pytest.mark.benchmark
class TestBoostAtIncreasingScale:
    """Does the palace boost increase as the palace grows?"""

    def test_boost_scaling(self, tmp_path, bench_scale):
        """Record the wing-filter recall@5 boost for palaces of 500, 1,000 and 2,500 drawers."""
        boosts = []

        for size in (500, 1_000, 2_500):
            # Each size gets its own generator and its own palace directory.
            generator = PalaceDataGenerator(seed=42, scale=bench_scale)
            palace_path = str(tmp_path / f"palace_{size}")
            _, _, needle_info = generator.populate_palace_directly(
                palace_path, n_drawers=size, include_needles=True
            )

            from mempalace.searcher import search_memories

            n_queries = min(8, len(needle_info))
            plain_hits = 0
            wing_hits = 0

            for needle in needle_info[:n_queries]:
                plain = search_memories(needle["query"], palace_path=palace_path, n_results=5)
                if any("NEEDLE_" in hit["text"] for hit in plain.get("results", [])[:5]):
                    plain_hits += 1

                narrowed = search_memories(
                    needle["query"], palace_path=palace_path, wing=needle["wing"], n_results=5
                )
                if any("NEEDLE_" in hit["text"] for hit in narrowed.get("results", [])[:5]):
                    wing_hits += 1

            denominator = max(n_queries, 1)
            boost = wing_hits / denominator - plain_hits / denominator
            boosts.append({"size": size, "boost": boost})

        record_metric("boost_scaling", "boosts_by_size", boosts)

        # Hypothesis under test: the boost at the largest size is at least
        # as large as the boost at the smallest size.
        if len(boosts) >= 2:
            trend_positive = boosts[-1]["boost"] >= boosts[0]["boost"]
            record_metric("boost_scaling", "trend_positive", trend_positive)
|
||||
@@ -0,0 +1,182 @@
|
||||
"""
|
||||
Recall threshold test — find the per-bucket size where retrieval breaks.
|
||||
|
||||
The palace_boost tests showed room-filtered recall of 1.0, but only because
|
||||
each room had ~333 drawers. This test concentrates ALL drawers into a single
|
||||
wing+room to find the actual embedding model limit.
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import os
|
||||
from datetime import datetime
|
||||
|
||||
import chromadb
|
||||
import pytest
|
||||
|
||||
from tests.benchmarks.data_generator import PalaceDataGenerator
|
||||
from tests.benchmarks.report import record_metric
|
||||
|
||||
|
||||
# Each entry pairs the text planted in a needle drawer with the query used to
# look it up; the two public lists below are index-aligned views of this table.
_NEEDLE_PAIRS = (
    (
        "Fibonacci sequence optimization uses memoization with O(n) space complexity",
        "Fibonacci sequence optimization memoization",
    ),
    (
        "PostgreSQL vacuum autovacuum threshold set to 50 percent for table users",
        "PostgreSQL vacuum autovacuum threshold",
    ),
    (
        "Redis cluster failover timeout configured at 30 seconds with sentinel monitoring",
        "Redis cluster failover timeout sentinel",
    ),
    (
        "Kubernetes horizontal pod autoscaler targets 70 percent CPU utilization",
        "Kubernetes horizontal pod autoscaler CPU",
    ),
    (
        "GraphQL subscription uses WebSocket transport with heartbeat interval 25 seconds",
        "GraphQL subscription WebSocket heartbeat",
    ),
    (
        "JWT token rotation policy requires refresh every 15 minutes with sliding window",
        "JWT token rotation policy refresh",
    ),
    (
        "Elasticsearch index sharding strategy uses 5 primary shards with 1 replica each",
        "Elasticsearch index sharding primary replica",
    ),
    (
        "Docker multi-stage build reduces image size from 1.2GB to 180MB for production",
        "Docker multi-stage build image size production",
    ),
    (
        "Apache Kafka consumer group rebalance timeout set to 45 seconds",
        "Apache Kafka consumer group rebalance",
    ),
    (
        "MongoDB change streams resume token persisted every 100 operations",
        "MongoDB change streams resume token",
    ),
)

NEEDLE_TOPICS = [topic for topic, _ in _NEEDLE_PAIRS]

NEEDLE_QUERIES = [query for _, query in _NEEDLE_PAIRS]
|
||||
|
||||
|
||||
def _populate_single_room(palace_path, n_drawers, n_needles=10):
    """Fill one wing/room ("concentrated"/"single_room") with needles plus noise.

    The first ``n_needles`` drawers carry a ``NEEDLE_<index>`` marker and the
    matching ``NEEDLE_TOPICS`` entry (so ``n_needles`` must not exceed the
    length of that list). The remaining ``n_drawers - n_needles`` drawers are
    random text from the data generator. Drawers are written to the
    "mempalace_drawers" collection in batches of 500.

    Returns:
        The ``(client, collection)`` pair used to write the drawers.
    """
    generator = PalaceDataGenerator(seed=42, scale="small")
    os.makedirs(palace_path, exist_ok=True)
    client = chromadb.PersistentClient(path=palace_path)
    col = client.get_or_create_collection("mempalace_drawers")

    batch_size = 500

    def drawer_id_for(key):
        """Derive the drawer id from the first 16 hex digits of md5(key)."""
        return f"drawer_single_room_{hashlib.md5(key.encode()).hexdigest()[:16]}"

    def metadata_for(source_file, chunk_index):
        """Build the metadata shared by every drawer in the single room."""
        return {
            "wing": "concentrated",
            "room": "single_room",
            "source_file": source_file,
            "chunk_index": chunk_index,
            "added_by": "threshold_bench",
            "filed_at": datetime.now().isoformat(),
        }

    docs, ids, metas = [], [], []

    # Needles go in first; they are flushed together with the first noise batch.
    for idx in range(n_needles):
        marker = f"NEEDLE_{idx:04d}"
        docs.append(f"{marker}: {NEEDLE_TOPICS[idx]}. Unique planted needle for threshold test.")
        ids.append(drawer_id_for(marker))
        metas.append(metadata_for(f"needle_{idx}.txt", 0))

    # Noise drawers share the needles' wing and room.
    for idx in range(n_drawers - len(docs)):
        docs.append(generator._random_text(400, 800))
        ids.append(drawer_id_for(f"noise_{idx}"))
        metas.append(metadata_for(f"noise_{idx:06d}.txt", idx % 10))

        if len(docs) >= batch_size:
            col.add(documents=docs, ids=ids, metadatas=metas)
            docs, ids, metas = [], [], []

    # Write whatever is left over from the last partial batch.
    if docs:
        col.add(documents=docs, ids=ids, metadatas=metas)

    return client, col
|
||||
|
||||
|
||||
@pytest.mark.benchmark
class TestRecallThresholdSingleRoom:
    """
    All drawers in one room — isolates the embedding model's retrieval limit.

    Room filtering can't help here. This is the true ceiling.
    """

    SIZES = [250, 500, 1_000, 2_000, 3_000, 5_000]

    @staticmethod
    def _recall_at_5_and_10(palace_path, **filters):
        """Query every needle and return ``(recall@5, recall@10)``.

        A query counts as a hit only when the marker of its own needle
        (``NEEDLE_<index>``) appears in the returned texts. Queries whose
        result contains an "error" key count as misses.
        """
        from mempalace.searcher import search_memories

        hits_at_5 = 0
        hits_at_10 = 0

        for idx, query in enumerate(NEEDLE_QUERIES):
            result = search_memories(query, palace_path=palace_path, n_results=10, **filters)
            if "error" in result:
                continue

            texts = [hit["text"] for hit in result.get("results", [])]
            marker = f"NEEDLE_{idx:04d}"

            if any(marker in text for text in texts[:5]):
                hits_at_5 += 1
            if any(marker in text for text in texts[:10]):
                hits_at_10 += 1

        total = len(NEEDLE_QUERIES)
        return hits_at_5 / total, hits_at_10 / total

    @pytest.mark.parametrize("n_drawers", SIZES)
    def test_single_room_recall(self, n_drawers, tmp_path):
        """Recall@5 and @10 with all drawers in one bucket, searched with wing+room filter."""
        palace_path = str(tmp_path / "palace")
        _populate_single_room(palace_path, n_drawers, n_needles=10)

        recall_5, recall_10 = self._recall_at_5_and_10(
            palace_path, wing="concentrated", room="single_room"
        )

        record_metric("single_room_recall", f"recall_at_5_at_{n_drawers}", round(recall_5, 3))
        record_metric("single_room_recall", f"recall_at_10_at_{n_drawers}", round(recall_10, 3))

    @pytest.mark.parametrize("n_drawers", SIZES)
    def test_single_room_no_filter_recall(self, n_drawers, tmp_path):
        """Same measurement but WITHOUT wing/room filter — pure unfiltered search."""
        palace_path = str(tmp_path / "palace")
        _populate_single_room(palace_path, n_drawers, n_needles=10)

        recall_5, recall_10 = self._recall_at_5_and_10(palace_path)

        record_metric("single_room_unfiltered", f"recall_at_5_at_{n_drawers}", round(recall_5, 3))
        record_metric("single_room_unfiltered", f"recall_at_10_at_{n_drawers}", round(recall_10, 3))
|
||||
@@ -0,0 +1,234 @@
|
||||
"""
|
||||
Search performance benchmarks.
|
||||
|
||||
Measures query latency, recall@k, and concurrent search behavior
|
||||
as palace size grows. Uses planted needles for recall measurement.
|
||||
"""
|
||||
|
||||
import time
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
|
||||
import pytest
|
||||
|
||||
from tests.benchmarks.data_generator import PalaceDataGenerator
|
||||
from tests.benchmarks.report import record_metric
|
||||
|
||||
|
||||
@pytest.mark.benchmark
class TestSearchLatencyVsSize:
    """Query latency scaling as palace grows."""

    SIZES = [500, 1_000, 2_500, 5_000]

    @pytest.mark.parametrize("n_drawers", SIZES)
    def test_search_latency_curve(self, n_drawers, tmp_path, bench_scale):
        """Record mean, p50 and p95 latency of five searches at each palace size."""
        generator = PalaceDataGenerator(seed=42, scale=bench_scale)
        palace_path = str(tmp_path / "palace")
        generator.populate_palace_directly(
            palace_path, n_drawers=n_drawers, include_needles=False
        )

        from mempalace.searcher import search_memories

        queries = (
            "authentication middleware",
            "database optimization",
            "error handling patterns",
            "deployment configuration",
            "testing strategy",
        )

        timings_ms = []
        for text in queries:
            t0 = time.perf_counter()
            outcome = search_memories(text, palace_path=palace_path, n_results=5)
            timings_ms.append((time.perf_counter() - t0) * 1000)
            assert "error" not in outcome

        mean_ms = sum(timings_ms) / len(timings_ms)

        # Percentiles are read from the sorted samples by index.
        ordered = sorted(timings_ms)
        count = len(ordered)
        p50_ms = ordered[count // 2]
        p95_ms = ordered[int(count * 0.95)]

        record_metric("search", f"avg_latency_ms_at_{n_drawers}", round(mean_ms, 1))
        record_metric("search", f"p50_ms_at_{n_drawers}", round(p50_ms, 1))
        record_metric("search", f"p95_ms_at_{n_drawers}", round(p95_ms, 1))
|
||||
|
||||
|
||||
@pytest.mark.benchmark
class TestSearchRecallAtScale:
    """Planted needle recall — does accuracy degrade as palace grows?"""

    SIZES = [500, 1_000, 2_500, 5_000]

    @pytest.mark.parametrize("n_drawers", SIZES)
    def test_recall_at_k(self, n_drawers, tmp_path, bench_scale):
        """Record Recall@5 and Recall@10 over up to ten planted-needle queries."""
        generator = PalaceDataGenerator(seed=42, scale=bench_scale)
        palace_path = str(tmp_path / "palace")
        _, _, needle_info = generator.populate_palace_directly(
            palace_path, n_drawers=n_drawers, include_needles=True
        )

        from mempalace.searcher import search_memories

        total_needle_queries = min(10, len(needle_info))
        hits_at_5 = 0
        hits_at_10 = 0

        for needle in needle_info[:total_needle_queries]:
            response = search_memories(needle["query"], palace_path=palace_path, n_results=10)
            # An errored query stays in the denominator and counts as a miss.
            if "error" in response:
                continue

            texts = [hit["text"] for hit in response.get("results", [])]

            # A hit is any returned text carrying the needle marker.
            if any("NEEDLE_" in text for text in texts[:5]):
                hits_at_5 += 1
            if any("NEEDLE_" in text for text in texts[:10]):
                hits_at_10 += 1

        denominator = max(total_needle_queries, 1)
        recall_at_5 = hits_at_5 / denominator
        recall_at_10 = hits_at_10 / denominator

        record_metric("search_recall", f"recall_at_5_at_{n_drawers}", round(recall_at_5, 3))
        record_metric("search_recall", f"recall_at_10_at_{n_drawers}", round(recall_at_10, 3))
|
||||
|
||||
|
||||
@pytest.mark.benchmark
class TestSearchFilteredVsUnfiltered:
    """Compare search performance with and without wing/room filters."""

    def test_filter_impact(self, tmp_path, bench_scale):
        """Record latency and recall@5 with and without a wing filter on 2,000 drawers."""
        generator = PalaceDataGenerator(seed=42, scale=bench_scale)
        palace_path = str(tmp_path / "palace")
        _, _, needle_info = generator.populate_palace_directly(
            palace_path, n_drawers=2_000, include_needles=True
        )

        from mempalace.searcher import search_memories

        def timed_search(query, **filters):
            """Run one search; return ``(elapsed_ms, needle_marker_in_top5)``."""
            t0 = time.perf_counter()
            found = search_memories(query, palace_path=palace_path, n_results=5, **filters)
            elapsed_ms = (time.perf_counter() - t0) * 1000
            hit = any("NEEDLE_" in entry["text"] for entry in found.get("results", [])[:5])
            return elapsed_ms, hit

        unfiltered_latencies = []
        filtered_latencies = []
        unfiltered_hits = 0
        filtered_hits = 0
        n_queries = min(10, len(needle_info))

        for needle in needle_info[:n_queries]:
            # Unfiltered
            elapsed_ms, hit = timed_search(needle["query"])
            unfiltered_latencies.append(elapsed_ms)
            if hit:
                unfiltered_hits += 1

            # Filtered by wing
            elapsed_ms, hit = timed_search(needle["query"], wing=needle["wing"])
            filtered_latencies.append(elapsed_ms)
            if hit:
                filtered_hits += 1

        # The max(...) floors keep every division defined when no queries ran.
        avg_unfiltered = sum(unfiltered_latencies) / max(len(unfiltered_latencies), 1)
        avg_filtered = sum(filtered_latencies) / max(len(filtered_latencies), 1)
        latency_improvement = ((avg_unfiltered - avg_filtered) / max(avg_unfiltered, 0.01)) * 100

        record_metric("search_filter", "avg_unfiltered_ms", round(avg_unfiltered, 1))
        record_metric("search_filter", "avg_filtered_ms", round(avg_filtered, 1))
        record_metric("search_filter", "latency_improvement_pct", round(latency_improvement, 1))
        record_metric(
            "search_filter", "unfiltered_recall_at_5", round(unfiltered_hits / max(n_queries, 1), 3)
        )
        record_metric(
            "search_filter", "filtered_recall_at_5", round(filtered_hits / max(n_queries, 1), 3)
        )
|
||||
|
||||
|
||||
@pytest.mark.benchmark
class TestConcurrentSearch:
    """Benchmark: latency of search_memories when several threads query one palace."""

    def test_concurrent_queries(self, tmp_path):
        """Push 30 queries through a 4-thread pool and record p50/p95/p99 latency."""
        gen = PalaceDataGenerator(seed=42, scale="small")
        palace_path = str(tmp_path / "palace")
        gen.populate_palace_directly(palace_path, n_drawers=2_000, include_needles=False)

        from mempalace.searcher import search_memories

        topics = [
            "authentication",
            "database",
            "deployment",
            "error handling",
            "testing",
            "monitoring",
            "caching",
            "middleware",
            "serialization",
            "validation",
        ]
        queries = topics * 3  # 30 total queries
        worker_count = 4

        def timed_search(text):
            """Return (elapsed milliseconds, True when the result has no "error" key)."""
            began = time.perf_counter()
            outcome = search_memories(text, palace_path=palace_path, n_results=5)
            return (time.perf_counter() - began) * 1000, "error" not in outcome

        # Fan the queries out and gather timings in completion order.
        latencies = []
        errors = 0
        with ThreadPoolExecutor(max_workers=worker_count) as pool:
            pending = [pool.submit(timed_search, text) for text in queries]
            for finished in as_completed(pending):
                elapsed_ms, succeeded = finished.result()
                latencies.append(elapsed_ms)
                errors += 0 if succeeded else 1

        latencies.sort()
        count = len(latencies)

        metrics = [
            ("p50_ms", round(latencies[count // 2], 1)),
            ("p95_ms", round(latencies[int(count * 0.95)], 1)),
            ("p99_ms", round(latencies[int(count * 0.99)], 1)),
            ("avg_ms", round(sum(latencies) / count, 1)),
            ("error_count", errors),
            ("total_queries", len(queries)),
            ("workers", worker_count),
        ]
        for key, value in metrics:
            record_metric("concurrent_search", key, value)
|
||||
|
||||
|
||||
@pytest.mark.benchmark
class TestSearchNResultsScaling:
    """Benchmark: query latency as a function of the requested n_results."""

    @pytest.mark.parametrize("n_results", [1, 5, 10, 25, 50])
    def test_n_results_latency(self, n_results, tmp_path):
        """Time five identical searches at the given n_results and record the mean."""
        gen = PalaceDataGenerator(seed=42, scale="small")
        palace_path = str(tmp_path / "palace")
        gen.populate_palace_directly(palace_path, n_drawers=2_000, include_needles=False)

        from mempalace.searcher import search_memories

        def one_run_ms():
            """Run the fixed query once and return its wall time in milliseconds."""
            began = time.perf_counter()
            search_memories(
                "authentication middleware", palace_path=palace_path, n_results=n_results
            )
            return (time.perf_counter() - began) * 1000

        samples = [one_run_ms() for _ in range(5)]
        mean_ms = sum(samples) / len(samples)
        record_metric("search_n_results", f"avg_ms_at_n_{n_results}", round(mean_ms, 1))
|
||||
+21
-1
@@ -34,6 +34,24 @@ from mempalace.config import MempalaceConfig # noqa: E402
|
||||
from mempalace.knowledge_graph import KnowledgeGraph # noqa: E402
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
def _reset_mcp_cache():
    """Null out mcp_server's cached client/collection before and after each test."""

    def _forget_cached_handles():
        # Best effort: if the module cannot be imported (or lacks the attributes),
        # there is nothing to reset, so the failure is deliberately ignored.
        try:
            from mempalace import mcp_server

            for attr in ("_client_cache", "_collection_cache"):
                setattr(mcp_server, attr, None)
        except (ImportError, AttributeError):
            pass

    _forget_cached_handles()
    yield
    _forget_cached_handles()
|
||||
|
||||
|
||||
@pytest.fixture(scope="session", autouse=True)
|
||||
def _isolate_home():
|
||||
"""Ensure HOME points to a temp dir for the entire test session.
|
||||
@@ -84,7 +102,9 @@ def collection(palace_path):
|
||||
"""A ChromaDB collection pre-seeded in the temp palace."""
|
||||
client = chromadb.PersistentClient(path=palace_path)
|
||||
col = client.get_or_create_collection("mempalace_drawers")
|
||||
return col
|
||||
yield col
|
||||
client.delete_collection("mempalace_drawers")
|
||||
del client
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
||||
@@ -0,0 +1,609 @@
|
||||
"""Tests for mempalace.cli — the main CLI dispatcher."""
|
||||
|
||||
import argparse
|
||||
import sys
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from mempalace.cli import (
|
||||
cmd_compress,
|
||||
cmd_hook,
|
||||
cmd_init,
|
||||
cmd_instructions,
|
||||
cmd_mine,
|
||||
cmd_repair,
|
||||
cmd_search,
|
||||
cmd_split,
|
||||
cmd_status,
|
||||
cmd_wakeup,
|
||||
main,
|
||||
)
|
||||
|
||||
|
||||
# ── cmd_status ─────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@patch("mempalace.cli.MempalaceConfig")
def test_cmd_status_default_palace(mock_config_cls):
    """With palace=None, cmd_status passes the configured palace path to miner.status."""
    mock_config_cls.return_value.palace_path = "/fake/palace"
    fake_miner = MagicMock()
    namespace = argparse.Namespace(palace=None)

    # Substitute a mock for mempalace.miner in sys.modules for the duration of the call.
    with patch.dict("sys.modules", {"mempalace.miner": fake_miner}):
        cmd_status(namespace)

    fake_miner.status.assert_called_once_with(palace_path="/fake/palace")
|
||||
|
||||
|
||||
@patch("mempalace.cli.MempalaceConfig")
def test_cmd_status_custom_palace(mock_config_cls):
    """An explicit palace argument reaches miner.status with '~' expanded."""
    import os

    fake_miner = MagicMock()
    namespace = argparse.Namespace(palace="~/my_palace")

    with patch.dict("sys.modules", {"mempalace.miner": fake_miner}):
        cmd_status(namespace)

    fake_miner.status.assert_called_once_with(palace_path=os.path.expanduser("~/my_palace"))
|
||||
|
||||
|
||||
# ── cmd_search ─────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@patch("mempalace.cli.MempalaceConfig")
def test_cmd_search_calls_search(mock_config_cls):
    """cmd_search maps the CLI namespace onto searcher.search keyword arguments."""
    mock_config_cls.return_value.palace_path = "/fake/palace"
    namespace = argparse.Namespace(
        palace=None,
        query="test query",
        wing="mywing",
        room="myroom",
        results=3,
    )
    expected_kwargs = {
        "query": "test query",
        "palace_path": "/fake/palace",
        "wing": "mywing",
        "room": "myroom",
        "n_results": 3,
    }

    with patch("mempalace.searcher.search") as search_spy:
        cmd_search(namespace)

    search_spy.assert_called_once_with(**expected_kwargs)
|
||||
|
||||
|
||||
@patch("mempalace.cli.MempalaceConfig")
def test_cmd_search_error_exits(mock_config_cls):
    """A SearchError raised by searcher.search makes cmd_search exit with code 1."""
    mock_config_cls.return_value.palace_path = "/fake/palace"
    namespace = argparse.Namespace(palace=None, query="q", wing=None, room=None, results=5)
    from mempalace.searcher import SearchError

    failing_search = patch("mempalace.searcher.search", side_effect=SearchError("fail"))
    with failing_search, pytest.raises(SystemExit) as exit_info:
        cmd_search(namespace)

    assert exit_info.value.code == 1
|
||||
|
||||
|
||||
# ── cmd_instructions ───────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_cmd_instructions_calls_run_instructions():
    """cmd_instructions forwards args.name to instructions_cli.run_instructions."""
    namespace = argparse.Namespace(name="help")

    with patch("mempalace.instructions_cli.run_instructions") as run_spy:
        cmd_instructions(namespace)

    run_spy.assert_called_once_with(name="help")
|
||||
|
||||
|
||||
# ── cmd_hook ───────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_cmd_hook_calls_run_hook():
    """cmd_hook forwards the hook name and harness to hooks_cli.run_hook."""
    namespace = argparse.Namespace(hook="session-start", harness="claude-code")

    with patch("mempalace.hooks_cli.run_hook") as run_spy:
        cmd_hook(namespace)

    run_spy.assert_called_once_with(hook_name="session-start", harness="claude-code")
|
||||
|
||||
|
||||
# ── cmd_init ───────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@patch("mempalace.cli.MempalaceConfig")
def test_cmd_init_no_entities(mock_config_cls, tmp_path):
    """When the scan finds no files, cmd_init still detects rooms and inits config."""
    project_dir = str(tmp_path)
    namespace = argparse.Namespace(dir=project_dir, yes=True)

    with patch("mempalace.entity_detector.scan_for_detection", return_value=[]):
        with patch("mempalace.room_detector_local.detect_rooms_local") as rooms_spy:
            cmd_init(namespace)

    rooms_spy.assert_called_once_with(project_dir=project_dir, yes=True)
    mock_config_cls.return_value.init.assert_called_once()
|
||||
|
||||
|
||||
@patch("mempalace.cli.MempalaceConfig")
def test_cmd_init_with_entities(mock_config_cls, tmp_path):
    """Smoke test: cmd_init completes when entities are detected and confirmed."""
    scanned_files = [tmp_path / "a.txt"]
    detection = {"people": [{"name": "Alice"}], "projects": [], "uncertain": []}
    confirmation = {"people": ["Alice"], "projects": []}
    namespace = argparse.Namespace(dir=str(tmp_path), yes=True)

    # builtins.open is replaced so nothing cmd_init opens touches the real filesystem.
    with (
        patch("mempalace.entity_detector.scan_for_detection", return_value=scanned_files),
        patch("mempalace.entity_detector.detect_entities", return_value=detection),
        patch("mempalace.entity_detector.confirm_entities", return_value=confirmation),
        patch("mempalace.room_detector_local.detect_rooms_local"),
        patch("builtins.open", MagicMock()),
    ):
        cmd_init(namespace)
|
||||
|
||||
|
||||
@patch("mempalace.cli.MempalaceConfig")
def test_cmd_init_with_entities_zero_total(mock_config_cls, tmp_path, capsys):
    """Files are scanned but detection is empty: output says 'No entities detected'."""
    scanned_files = [tmp_path / "a.txt"]
    empty_detection = {"people": [], "projects": [], "uncertain": []}
    namespace = argparse.Namespace(dir=str(tmp_path), yes=False)

    with (
        patch("mempalace.entity_detector.scan_for_detection", return_value=scanned_files),
        patch("mempalace.entity_detector.detect_entities", return_value=empty_detection),
        patch("mempalace.room_detector_local.detect_rooms_local"),
    ):
        cmd_init(namespace)

    printed = capsys.readouterr().out
    assert "No entities detected" in printed
|
||||
|
||||
|
||||
# ── cmd_mine ───────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@patch("mempalace.cli.MempalaceConfig")
def test_cmd_mine_projects_mode(mock_config_cls):
    """mode='projects' routes to miner.mine with the translated keyword arguments."""
    mock_config_cls.return_value.palace_path = "/fake/palace"
    cli_values = {
        "dir": "/src",
        "palace": None,
        "mode": "projects",
        "wing": None,
        "agent": "mempalace",
        "limit": 0,
        "dry_run": False,
        "no_gitignore": False,
        "include_ignored": [],
        "extract": "exchange",
    }
    expected_kwargs = {
        "project_dir": "/src",
        "palace_path": "/fake/palace",
        "wing_override": None,
        "agent": "mempalace",
        "limit": 0,
        "dry_run": False,
        "respect_gitignore": True,
        "include_ignored": [],
    }

    with patch("mempalace.miner.mine") as mine_spy:
        cmd_mine(argparse.Namespace(**cli_values))

    mine_spy.assert_called_once_with(**expected_kwargs)
|
||||
|
||||
|
||||
@patch("mempalace.cli.MempalaceConfig")
def test_cmd_mine_convos_mode(mock_config_cls):
    """mode='convos' routes to convo_miner.mine_convos with the translated arguments."""
    mock_config_cls.return_value.palace_path = "/fake/palace"
    cli_values = {
        "dir": "/chats",
        "palace": None,
        "mode": "convos",
        "wing": "mywing",
        "agent": "me",
        "limit": 10,
        "dry_run": True,
        "no_gitignore": False,
        "include_ignored": [],
        "extract": "general",
    }
    expected_kwargs = {
        "convo_dir": "/chats",
        "palace_path": "/fake/palace",
        "wing": "mywing",
        "agent": "me",
        "limit": 10,
        "dry_run": True,
        "extract_mode": "general",
    }

    with patch("mempalace.convo_miner.mine_convos") as mine_spy:
        cmd_mine(argparse.Namespace(**cli_values))

    mine_spy.assert_called_once_with(**expected_kwargs)
|
||||
|
||||
|
||||
@patch("mempalace.cli.MempalaceConfig")
def test_cmd_mine_include_ignored_comma_split(mock_config_cls):
    """Comma-joined include_ignored entries reach miner.mine as one flat list."""
    mock_config_cls.return_value.palace_path = "/fake/palace"
    cli_values = {
        "dir": "/src",
        "palace": None,
        "mode": "projects",
        "wing": None,
        "agent": "mempalace",
        "limit": 0,
        "dry_run": False,
        "no_gitignore": False,
        "include_ignored": ["a.txt,b.txt", "c.txt"],
        "extract": "exchange",
    }

    with patch("mempalace.miner.mine") as mine_spy:
        cmd_mine(argparse.Namespace(**cli_values))

    mine_spy.assert_called_once()
    passed_kwargs = mine_spy.call_args[1]
    assert passed_kwargs["include_ignored"] == ["a.txt", "b.txt", "c.txt"]
|
||||
|
||||
|
||||
# ── cmd_wakeup ─────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@patch("mempalace.cli.MempalaceConfig")
def test_cmd_wakeup(mock_config_cls, capsys):
    """cmd_wakeup prints the stack's wake_up text plus a line mentioning tokens."""
    mock_config_cls.return_value.palace_path = "/fake/palace"
    fake_stack = MagicMock()
    fake_stack.wake_up.return_value = "Hello world context"
    namespace = argparse.Namespace(palace=None, wing=None)

    with patch("mempalace.layers.MemoryStack", return_value=fake_stack):
        cmd_wakeup(namespace)

    printed = capsys.readouterr().out
    for fragment in ("Hello world context", "tokens"):
        assert fragment in printed
|
||||
|
||||
|
||||
# ── cmd_split ──────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_cmd_split_basic():
    """With only the required options set, cmd_split calls the splitter's main once."""
    namespace = argparse.Namespace(
        dir="/chats",
        output_dir=None,
        dry_run=False,
        min_sessions=2,
    )

    with patch("mempalace.split_mega_files.main") as splitter_main:
        cmd_split(namespace)

    splitter_main.assert_called_once()
|
||||
|
||||
|
||||
def test_cmd_split_all_options():
    """With every option set, cmd_split calls the splitter once and restores sys.argv.

    NOTE(review): presumably cmd_split rewrites sys.argv for the splitter's own
    argument parser; this test only relies on it being put back afterwards.
    """
    args = argparse.Namespace(dir="/chats", output_dir="/out", dry_run=True, min_sessions=5)
    argv_before = list(sys.argv)

    with patch("mempalace.split_mega_files.main") as mock_main:
        cmd_split(args)

    mock_main.assert_called_once()
    # sys.argv should be restored: compare against the snapshot instead of only
    # ruling out one particular argv[0] value.
    assert sys.argv == argv_before
    assert sys.argv[0] != "mempalace split"
|
||||
|
||||
|
||||
# ── main() argparse dispatch ──────────────────────────────────────────
|
||||
|
||||
|
||||
def test_main_no_args_prints_help(capsys):
    """Running main() with no subcommand prints output that mentions MemPalace."""
    bare_argv = ["mempalace"]

    with patch("sys.argv", bare_argv):
        main()

    printed = capsys.readouterr().out
    assert "MemPalace" in printed
|
||||
|
||||
|
||||
def test_main_status_dispatches():
    """'mempalace status' is routed to cmd_status exactly once."""
    argv = ["mempalace", "status"]

    with patch("sys.argv", argv), patch("mempalace.cli.cmd_status") as handler:
        main()

    handler.assert_called_once()
|
||||
|
||||
|
||||
def test_main_search_dispatches():
    """'mempalace search <query>' is routed to cmd_search exactly once."""
    argv = ["mempalace", "search", "my query"]

    with patch("sys.argv", argv), patch("mempalace.cli.cmd_search") as handler:
        main()

    handler.assert_called_once()
|
||||
|
||||
|
||||
def test_main_init_dispatches():
    """'mempalace init <dir>' is routed to cmd_init exactly once."""
    argv = ["mempalace", "init", "/some/dir"]

    with patch("sys.argv", argv), patch("mempalace.cli.cmd_init") as handler:
        main()

    handler.assert_called_once()
|
||||
|
||||
|
||||
def test_main_mine_dispatches():
    """'mempalace mine <dir>' is routed to cmd_mine exactly once."""
    argv = ["mempalace", "mine", "/some/dir"]

    with patch("sys.argv", argv), patch("mempalace.cli.cmd_mine") as handler:
        main()

    handler.assert_called_once()
|
||||
|
||||
|
||||
def test_main_wakeup_dispatches():
    """'mempalace wake-up' is routed to cmd_wakeup exactly once."""
    argv = ["mempalace", "wake-up"]

    with patch("sys.argv", argv), patch("mempalace.cli.cmd_wakeup") as handler:
        main()

    handler.assert_called_once()
|
||||
|
||||
|
||||
def test_main_split_dispatches():
    """'mempalace split <dir>' is routed to cmd_split exactly once."""
    argv = ["mempalace", "split", "/chats"]

    with patch("sys.argv", argv), patch("mempalace.cli.cmd_split") as handler:
        main()

    handler.assert_called_once()
|
||||
|
||||
|
||||
def test_main_hook_no_subcommand_prints_help(capsys):
    """'mempalace hook' alone prints output mentioning 'hook' or 'run'."""
    argv = ["mempalace", "hook"]

    with patch("sys.argv", argv):
        main()

    lowered = capsys.readouterr().out.lower()
    assert any(word in lowered for word in ("hook", "run"))
|
||||
|
||||
|
||||
def test_main_hook_run_dispatches():
    """'mempalace hook run --hook … --harness …' is routed to cmd_hook exactly once."""
    argv = ["mempalace", "hook", "run", "--hook", "session-start", "--harness", "claude-code"]

    with patch("sys.argv", argv), patch("mempalace.cli.cmd_hook") as handler:
        main()

    handler.assert_called_once()
|
||||
|
||||
|
||||
def test_main_instructions_no_subcommand_prints_help(capsys):
    """'mempalace instructions' alone prints output mentioning 'instructions' or 'init'."""
    argv = ["mempalace", "instructions"]

    with patch("sys.argv", argv):
        main()

    lowered = capsys.readouterr().out.lower()
    assert any(word in lowered for word in ("instructions", "init"))
|
||||
|
||||
|
||||
def test_main_instructions_dispatches():
    """'mempalace instructions help' is routed to cmd_instructions exactly once."""
    argv = ["mempalace", "instructions", "help"]

    with patch("sys.argv", argv), patch("mempalace.cli.cmd_instructions") as handler:
        main()

    handler.assert_called_once()
|
||||
|
||||
|
||||
def test_main_repair_dispatches():
    """'mempalace repair' is routed to cmd_repair exactly once."""
    argv = ["mempalace", "repair"]

    with patch("sys.argv", argv), patch("mempalace.cli.cmd_repair") as handler:
        main()

    handler.assert_called_once()
|
||||
|
||||
|
||||
def test_main_compress_dispatches():
    """'mempalace compress' is routed to cmd_compress exactly once."""
    argv = ["mempalace", "compress"]

    with patch("sys.argv", argv), patch("mempalace.cli.cmd_compress") as handler:
        main()

    handler.assert_called_once()
|
||||
|
||||
|
||||
# ── cmd_repair ─────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@patch("mempalace.cli.MempalaceConfig")
def test_cmd_repair_no_palace(mock_config_cls, tmp_path, capsys):
    """Pointing repair at a directory that does not exist prints 'No palace found'."""
    missing_dir = tmp_path / "nonexistent"
    mock_config_cls.return_value.palace_path = str(missing_dir)
    namespace = argparse.Namespace(palace=None)
    fake_chromadb = MagicMock()

    with patch.dict("sys.modules", {"chromadb": fake_chromadb}):
        cmd_repair(namespace)

    printed = capsys.readouterr().out
    assert "No palace found" in printed
|
||||
|
||||
|
||||
@patch("mempalace.cli.MempalaceConfig")
def test_cmd_repair_error_reading(mock_config_cls, tmp_path, capsys):
    """If get_collection raises, repair reports 'Error reading palace'."""
    palace_dir = tmp_path / "palace"
    palace_dir.mkdir()
    mock_config_cls.return_value.palace_path = str(palace_dir)
    namespace = argparse.Namespace(palace=None)

    broken_client = MagicMock()
    broken_client.get_collection.side_effect = Exception("corrupt db")
    fake_chromadb = MagicMock()
    fake_chromadb.PersistentClient.return_value = broken_client

    with patch.dict("sys.modules", {"chromadb": fake_chromadb}):
        cmd_repair(namespace)

    printed = capsys.readouterr().out
    assert "Error reading palace" in printed
|
||||
|
||||
|
||||
@patch("mempalace.cli.MempalaceConfig")
def test_cmd_repair_zero_drawers(mock_config_cls, tmp_path, capsys):
    """A collection whose count() is 0 makes repair print 'Nothing to repair'."""
    palace_dir = tmp_path / "palace"
    palace_dir.mkdir()
    mock_config_cls.return_value.palace_path = str(palace_dir)
    namespace = argparse.Namespace(palace=None)

    empty_collection = MagicMock()
    empty_collection.count.return_value = 0
    client = MagicMock()
    client.get_collection.return_value = empty_collection
    fake_chromadb = MagicMock()
    fake_chromadb.PersistentClient.return_value = client

    with patch.dict("sys.modules", {"chromadb": fake_chromadb}):
        cmd_repair(namespace)

    printed = capsys.readouterr().out
    assert "Nothing to repair" in printed
|
||||
|
||||
|
||||
@patch("mempalace.cli.MempalaceConfig")
def test_cmd_repair_success(mock_config_cls, tmp_path, capsys):
    """With two drawers in the collection, repair reports completion and the count."""
    palace_dir = tmp_path / "palace"
    palace_dir.mkdir()
    mock_config_cls.return_value.palace_path = str(palace_dir)
    namespace = argparse.Namespace(palace=None)

    stored = {
        "ids": ["id1", "id2"],
        "documents": ["doc1", "doc2"],
        "metadatas": [{"wing": "a"}, {"wing": "b"}],
    }
    existing_collection = MagicMock()
    existing_collection.count.return_value = 2
    existing_collection.get.return_value = stored

    rebuilt_collection = MagicMock()
    client = MagicMock()
    client.get_collection.return_value = existing_collection
    client.create_collection.return_value = rebuilt_collection

    fake_chromadb = MagicMock()
    fake_chromadb.PersistentClient.return_value = client

    with patch.dict("sys.modules", {"chromadb": fake_chromadb}):
        cmd_repair(namespace)

    printed = capsys.readouterr().out
    for fragment in ("Repair complete", "2 drawers rebuilt"):
        assert fragment in printed
|
||||
|
||||
|
||||
# ── cmd_compress ───────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@patch("mempalace.cli.MempalaceConfig")
def test_cmd_compress_no_palace(mock_config_cls, capsys):
    """If creating the PersistentClient raises, cmd_compress exits via SystemExit."""
    mock_config_cls.return_value.palace_path = "/fake/palace"
    namespace = argparse.Namespace(palace=None, wing=None, dry_run=False, config=None)

    fake_chromadb = MagicMock()
    fake_chromadb.PersistentClient.side_effect = Exception("no palace")

    with patch.dict("sys.modules", {"chromadb": fake_chromadb}):
        with pytest.raises(SystemExit):
            cmd_compress(namespace)
|
||||
|
||||
|
||||
@patch("mempalace.cli.MempalaceConfig")
def test_cmd_compress_no_drawers(mock_config_cls, capsys):
    """An empty get() result for the wing makes compress print 'No drawers found'."""
    mock_config_cls.return_value.palace_path = "/fake/palace"
    namespace = argparse.Namespace(palace=None, wing="mywing", dry_run=False, config=None)

    empty_collection = MagicMock()
    empty_collection.get.return_value = {"documents": [], "metadatas": [], "ids": []}
    client = MagicMock()
    client.get_collection.return_value = empty_collection
    fake_chromadb = MagicMock()
    fake_chromadb.PersistentClient.return_value = client

    with patch.dict("sys.modules", {"chromadb": fake_chromadb}):
        cmd_compress(namespace)

    printed = capsys.readouterr().out
    assert "No drawers found" in printed
|
||||
|
||||
|
||||
def _make_mock_dialect_module(dialect_instance):
|
||||
"""Create a mock dialect module with a Dialect class that returns the given instance."""
|
||||
mock_mod = MagicMock()
|
||||
mock_mod.Dialect.return_value = dialect_instance
|
||||
mock_mod.Dialect.from_config.return_value = dialect_instance
|
||||
mock_mod.Dialect.count_tokens = MagicMock(side_effect=lambda x: len(x) // 4)
|
||||
return mock_mod
|
||||
|
||||
|
||||
@patch("mempalace.cli.MempalaceConfig")
def test_cmd_compress_dry_run(mock_config_cls, capsys):
    """A dry-run compress prints 'Compressing' and a dry-run notice."""
    mock_config_cls.return_value.palace_path = "/fake/palace"
    namespace = argparse.Namespace(palace=None, wing=None, dry_run=True, config=None)

    # First get() yields one drawer, the second yields nothing.
    one_drawer = {
        "documents": ["some long text here for testing"],
        "metadatas": [{"wing": "test", "room": "general", "source_file": "test.txt"}],
        "ids": ["id1"],
    }
    no_drawers = {"documents": [], "metadatas": [], "ids": []}
    drawers = MagicMock()
    drawers.get.side_effect = [one_drawer, no_drawers]

    client = MagicMock()
    client.get_collection.return_value = drawers
    fake_chromadb = MagicMock()
    fake_chromadb.PersistentClient.return_value = client

    dialect = MagicMock()
    dialect.compress.return_value = "compressed"
    dialect.compression_stats.return_value = {
        "original_chars": 100,
        "compressed_chars": 30,
        "original_tokens": 25,
        "compressed_tokens": 8,
        "ratio": 3.3,
    }
    replaced_modules = {
        "chromadb": fake_chromadb,
        "mempalace.dialect": _make_mock_dialect_module(dialect),
    }

    with patch.dict("sys.modules", replaced_modules):
        cmd_compress(namespace)

    printed = capsys.readouterr().out
    assert "dry run" in printed.lower()
    assert "Compressing" in printed
|
||||
|
||||
|
||||
@patch("mempalace.cli.MempalaceConfig")
def test_cmd_compress_with_config(mock_config_cls, tmp_path, capsys):
    """Passing an entity config file makes compress print 'Loaded entity config'."""
    mock_config_cls.return_value.palace_path = "/fake/palace"
    entity_config = tmp_path / "entities.json"
    entity_config.write_text('{"people": [], "projects": []}')
    namespace = argparse.Namespace(
        palace=None,
        wing=None,
        dry_run=True,
        config=str(entity_config),
    )

    drawers = MagicMock()
    drawers.get.return_value = {"documents": [], "metadatas": [], "ids": []}
    client = MagicMock()
    client.get_collection.return_value = drawers
    fake_chromadb = MagicMock()
    fake_chromadb.PersistentClient.return_value = client

    dialect = MagicMock()
    replaced_modules = {
        "chromadb": fake_chromadb,
        "mempalace.dialect": _make_mock_dialect_module(dialect),
    }

    with patch.dict("sys.modules", replaced_modules):
        cmd_compress(namespace)

    printed = capsys.readouterr().out
    assert "Loaded entity config" in printed
|
||||
|
||||
|
||||
@patch("mempalace.cli.MempalaceConfig")
def test_cmd_compress_stores_results(mock_config_cls, capsys):
    """Without dry_run, compress prints 'Stored' and upserts once into the target collection."""
    mock_config_cls.return_value.palace_path = "/fake/palace"
    namespace = argparse.Namespace(palace=None, wing=None, dry_run=False, config=None)

    # First get() yields one drawer, the second yields nothing.
    one_drawer = {
        "documents": ["text"],
        "metadatas": [{"wing": "w", "room": "r", "source_file": "f.txt"}],
        "ids": ["id1"],
    }
    no_drawers = {"documents": [], "metadatas": [], "ids": []}
    drawers = MagicMock()
    drawers.get.side_effect = [one_drawer, no_drawers]

    compressed_store = MagicMock()
    client = MagicMock()
    client.get_collection.return_value = drawers
    client.get_or_create_collection.return_value = compressed_store
    fake_chromadb = MagicMock()
    fake_chromadb.PersistentClient.return_value = client

    dialect = MagicMock()
    dialect.compress.return_value = "compressed"
    dialect.compression_stats.return_value = {
        "original_chars": 100,
        "compressed_chars": 30,
        "original_tokens": 25,
        "compressed_tokens": 8,
        "ratio": 3.3,
    }
    replaced_modules = {
        "chromadb": fake_chromadb,
        "mempalace.dialect": _make_mock_dialect_module(dialect),
    }

    with patch.dict("sys.modules", replaced_modules):
        cmd_compress(namespace)

    printed = capsys.readouterr().out
    assert "Stored" in printed
    compressed_store.upsert.assert_called_once()
|
||||
@@ -0,0 +1,79 @@
|
||||
"""Extra tests for mempalace.config to cover remaining gaps."""
|
||||
|
||||
import json
|
||||
import os
|
||||
|
||||
from mempalace.config import MempalaceConfig
|
||||
|
||||
|
||||
def test_config_bad_json(tmp_path):
    """A config.json that is not valid JSON still yields a truthy palace_path."""
    broken_file = tmp_path / "config.json"
    broken_file.write_text("not json", encoding="utf-8")

    config = MempalaceConfig(config_dir=str(tmp_path))

    assert config.palace_path  # still returns default
|
||||
|
||||
|
||||
def test_people_map_from_file(tmp_path):
    """people_map reflects the contents of people_map.json in the config dir."""
    payload = json.dumps({"bob": "Robert"})
    (tmp_path / "people_map.json").write_text(payload, encoding="utf-8")

    config = MempalaceConfig(config_dir=str(tmp_path))

    assert config.people_map == {"bob": "Robert"}
|
||||
|
||||
|
||||
def test_people_map_bad_json(tmp_path):
    """An unparsable people_map.json yields an empty people_map."""
    broken_file = tmp_path / "people_map.json"
    broken_file.write_text("bad", encoding="utf-8")

    config = MempalaceConfig(config_dir=str(tmp_path))

    assert config.people_map == {}
|
||||
|
||||
|
||||
def test_people_map_missing(tmp_path):
    """With no people_map.json in the config dir, people_map is an empty dict."""
    config = MempalaceConfig(config_dir=str(tmp_path))
    loaded = config.people_map
    assert loaded == {}
|
||||
|
||||
|
||||
def test_topic_wings_default(tmp_path):
    """Default topic_wings is a list that includes 'emotions'."""
    wings = MempalaceConfig(config_dir=str(tmp_path)).topic_wings
    assert isinstance(wings, list)
    assert "emotions" in wings
|
||||
|
||||
|
||||
def test_hall_keywords_default(tmp_path):
    """Default hall_keywords is a dict that has a 'technical' entry."""
    keywords = MempalaceConfig(config_dir=str(tmp_path)).hall_keywords
    assert isinstance(keywords, dict)
    assert "technical" in keywords
|
||||
|
||||
|
||||
def test_init_idempotent(tmp_path):
    """Calling init() twice leaves config.json exactly as the first call wrote it."""
    cfg = MempalaceConfig(config_dir=str(tmp_path))
    config_file = tmp_path / "config.json"

    cfg.init()
    after_first = config_file.read_text(encoding="utf-8")

    cfg.init()  # second call should not overwrite
    after_second = config_file.read_text(encoding="utf-8")

    # Compare the file across both calls so the "idempotent" claim is actually checked,
    # not just that the file is still parseable.
    assert after_second == after_first
    data = json.loads(after_second)
    assert "palace_path" in data
|
||||
|
||||
|
||||
def test_save_people_map(tmp_path):
    """save_people_map returns a path to an existing JSON file holding the mapping."""
    config = MempalaceConfig(config_dir=str(tmp_path))

    saved_path = config.save_people_map({"alice": "Alice Smith"})

    assert saved_path.exists()
    with open(saved_path) as handle:
        stored = json.load(handle)
    assert stored["alice"] == "Alice Smith"
|
||||
|
||||
|
||||
def test_env_mempal_palace_path(tmp_path):
    """MEMPAL_PALACE_PATH (legacy) should also work."""
    from unittest.mock import patch

    # patch.dict snapshots os.environ and restores it on exit, so the removed
    # MEMPALACE_PALACE_PATH and any pre-existing MEMPAL_PALACE_PATH value both
    # come back afterwards and no environment state leaks into other tests.
    with patch.dict(os.environ, {"MEMPAL_PALACE_PATH": "/legacy/path"}):
        os.environ.pop("MEMPALACE_PALACE_PATH", None)
        cfg = MempalaceConfig(config_dir=str(tmp_path))
        assert cfg.palace_path == "/legacy/path"
|
||||
|
||||
|
||||
def test_collection_name_from_config(tmp_path):
    """collection_name is read from the 'collection_name' key of config.json."""
    payload = json.dumps({"collection_name": "custom_col"})
    (tmp_path / "config.json").write_text(payload, encoding="utf-8")

    config = MempalaceConfig(config_dir=str(tmp_path))

    assert config.collection_name == "custom_col"
|
||||
@@ -23,4 +23,4 @@ def test_convo_mining():
|
||||
results = col.query(query_texts=["memory persistence"], n_results=1)
|
||||
assert len(results["documents"][0]) > 0
|
||||
|
||||
shutil.rmtree(tmpdir)
|
||||
shutil.rmtree(tmpdir, ignore_errors=True)
|
||||
|
||||
@@ -0,0 +1,102 @@
|
||||
"""Unit tests for convo_miner pure functions (no chromadb needed)."""
|
||||
|
||||
from mempalace.convo_miner import (
|
||||
chunk_exchanges,
|
||||
detect_convo_room,
|
||||
scan_convos,
|
||||
)
|
||||
|
||||
|
||||
class TestChunkExchanges:
    """Unit tests for chunk_exchanges on quoted, paragraph, line-only and tiny input."""

    def test_exchange_chunking(self):
        """Three '>'-prefixed exchanges yield at least two well-formed chunks."""
        exchanges = [
            "> What is memory?\n" "Memory is persistence of information over time.\n\n",
            "> Why does it matter?\n"
            "It enables continuity across sessions and conversations.\n\n",
            "> How do we build it?\n" "With structured storage and retrieval mechanisms.\n",
        ]
        chunks = chunk_exchanges("".join(exchanges))
        assert len(chunks) >= 2
        for chunk in chunks:
            assert "content" in chunk and "chunk_index" in chunk

    def test_paragraph_fallback(self):
        """Content without '>' lines falls back to paragraph chunking."""
        # NOTE: in the original expression the "\n\n" literals sit next to the following
        # sentence literal, so they are joined to it before the "* 10" is applied.
        # The pieces below spell out that same resulting text explicitly.
        opening = "This is a long paragraph about memory systems. " * 10
        middle = "\n\nThis is another paragraph about storage. " * 10
        closing = "\n\nAnd a third paragraph about retrieval. " * 10
        chunks = chunk_exchanges(opening + middle + closing)
        assert len(chunks) >= 2

    def test_paragraph_line_group_fallback(self):
        """Long content with no paragraph breaks chunks by line groups."""
        numbered = (f"Line {i}: some content that is meaningful" for i in range(60))
        chunks = chunk_exchanges("\n".join(numbered))
        assert len(chunks) >= 1

    def test_empty_content(self):
        """An empty string produces no chunks."""
        assert chunk_exchanges("") == []

    def test_short_content_skipped(self):
        """Very short input still returns a list (possibly empty)."""
        result = chunk_exchanges("> hi\nbye")
        # Too short to produce chunks (below MIN_CHUNK_SIZE)
        assert isinstance(result, list)
|
||||
|
||||
|
||||
class TestDetectConvoRoom:
    """Unit tests for detect_convo_room: one sample sentence per expected room."""

    def test_technical_room(self):
        """Debugging/code vocabulary maps to the 'technical' room."""
        sample = "Let me debug this python function and fix the code error in the api"
        room = detect_convo_room(sample)
        assert room == "technical"

    def test_planning_room(self):
        """Roadmap/sprint vocabulary maps to the 'planning' room."""
        sample = "We need to plan the roadmap for the next sprint and set milestone deadlines"
        room = detect_convo_room(sample)
        assert room == "planning"

    def test_architecture_room(self):
        """Layer/component vocabulary maps to the 'architecture' room."""
        sample = "The architecture uses a service layer with component interface and module design"
        room = detect_convo_room(sample)
        assert room == "architecture"

    def test_decisions_room(self):
        """Decided/chose vocabulary maps to the 'decisions' room."""
        sample = "We decided to switch and migrated to the new framework after we chose it"
        room = detect_convo_room(sample)
        assert room == "decisions"

    def test_general_fallback(self):
        """Small talk with no topical keywords falls back to 'general'."""
        sample = "Hello, how are you doing today? The weather is nice."
        room = detect_convo_room(sample)
        assert room == "general"
|
||||
|
||||
|
||||
class TestScanConvos:
    """Unit tests for scan_convos file discovery in a temporary directory."""

    def test_scan_finds_txt_and_md(self, tmp_path):
        """.txt and .md files are returned; a .png file is not."""
        (tmp_path / "chat.txt").write_text("hello", encoding="utf-8")
        (tmp_path / "notes.md").write_text("world", encoding="utf-8")
        (tmp_path / "image.png").write_bytes(b"fake")

        found_suffixes = {path.suffix for path in scan_convos(str(tmp_path))}

        assert ".txt" in found_suffixes
        assert ".md" in found_suffixes
        assert ".png" not in found_suffixes

    def test_scan_skips_git_dir(self, tmp_path):
        """A file inside .git is ignored, leaving only the top-level chat file."""
        git_dir = tmp_path / ".git"
        git_dir.mkdir()
        (git_dir / "config.txt").write_text("git stuff", encoding="utf-8")
        (tmp_path / "chat.txt").write_text("hello", encoding="utf-8")

        found = scan_convos(str(tmp_path))

        assert len(found) == 1

    def test_scan_skips_meta_json(self, tmp_path):
        """chat.json is returned but chat.meta.json is not."""
        for filename in ("chat.meta.json", "chat.json"):
            (tmp_path / filename).write_text("{}", encoding="utf-8")

        found_names = [path.name for path in scan_convos(str(tmp_path))]

        assert "chat.json" in found_names
        assert "chat.meta.json" not in found_names

    def test_scan_empty_dir(self, tmp_path):
        """An empty directory yields an empty list."""
        assert scan_convos(str(tmp_path)) == []
|
||||
@@ -0,0 +1,380 @@
|
||||
"""Tests for mempalace.entity_detector."""
|
||||
|
||||
import os
|
||||
from unittest.mock import patch
|
||||
|
||||
from mempalace.entity_detector import (
|
||||
PROSE_EXTENSIONS,
|
||||
STOPWORDS,
|
||||
_print_entity_list,
|
||||
classify_entity,
|
||||
confirm_entities,
|
||||
detect_entities,
|
||||
extract_candidates,
|
||||
scan_for_detection,
|
||||
score_entity,
|
||||
)
|
||||
|
||||
|
||||
# ── extract_candidates ──────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_extract_candidates_finds_frequent_names():
    """A name used four times is reported with a count of at least 3."""
    candidates = extract_candidates(
        "Riley said hello. Riley laughed. Riley smiled. Riley waved."
    )
    assert "Riley" in candidates
    assert candidates["Riley"] >= 3
|
||||
|
||||
|
||||
def test_extract_candidates_ignores_stopwords():
    """A repeated stopword does not become a candidate."""
    # "The" appears many times but is a stopword
    candidates = extract_candidates("The The The The The The")
    assert "The" not in candidates
|
||||
|
||||
|
||||
def test_extract_candidates_requires_min_frequency():
    """Names mentioned once are not reported."""
    candidates = extract_candidates("Riley said hi. Devon waved.")
    # Each name appears only once, below the threshold of 3
    for name in ("Riley", "Devon"):
        assert name not in candidates
|
||||
|
||||
|
||||
def test_extract_candidates_finds_multi_word_names():
    """A two-word capitalised name repeated four times is reported."""
    # Multi-word names need 3+ occurrences and no stopwords
    candidates = extract_candidates(
        "Claude Code is great. Claude Code rocks. Claude Code works. Claude Code rules."
    )
    assert "Claude Code" in candidates
|
||||
|
||||
|
||||
def test_extract_candidates_empty_text():
    """Empty input gives an empty mapping."""
    assert extract_candidates("") == {}
|
||||
|
||||
|
||||
# ── score_entity ────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_score_entity_person_verbs():
    """said/asked/told after a name give a positive person score with signals."""
    text = "Riley said hello. Riley asked why. Riley told me."
    scores = score_entity("Riley", text, text.splitlines())
    assert scores["person_score"] > 0
    assert len(scores["person_signals"]) > 0
|
||||
|
||||
|
||||
def test_score_entity_project_verbs():
    """building/deployed/install near a name give a positive project score with signals."""
    text = "We are building ChromaDB. We deployed ChromaDB. Install ChromaDB."
    scores = score_entity("ChromaDB", text, text.splitlines())
    assert scores["project_score"] > 0
    assert len(scores["project_signals"]) > 0
|
||||
|
||||
|
||||
def test_score_entity_dialogue_markers():
    """Lines of the form "Name: ..." give a positive person score."""
    text = "Riley: Hey, how are you?\nRiley: I'm fine."
    scores = score_entity("Riley", text, text.splitlines())
    assert scores["person_score"] > 0
|
||||
|
||||
|
||||
def test_score_entity_code_ref():
    """Mentions like Name.py / Name.js give a positive project score."""
    text = "Check out ChromaDB.py for details. Also ChromaDB.js is good."
    scores = score_entity("ChromaDB", text, text.splitlines())
    assert scores["project_score"] > 0
|
||||
|
||||
|
||||
def test_score_entity_no_signals():
    """A name absent from the text scores zero on both axes."""
    text = "Nothing interesting here at all."
    scores = score_entity("Riley", text, text.splitlines())
    for key in ("person_score", "project_score"):
        assert scores[key] == 0
|
||||
|
||||
|
||||
# ── classify_entity ─────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_classify_entity_no_signals_gives_uncertain():
    """Zero scores and no signals classify as "uncertain", keeping the name."""
    empty_scores = dict(
        person_score=0,
        project_score=0,
        person_signals=[],
        project_signals=[],
    )
    verdict = classify_entity("Foo", 10, empty_scores)
    assert verdict["type"] == "uncertain"
    assert verdict["name"] == "Foo"
|
||||
|
||||
|
||||
def test_classify_entity_strong_project():
    """A project-only score with two project signals classifies as "project"."""
    project_scores = dict(
        person_score=0,
        project_score=10,
        person_signals=[],
        project_signals=["project verb (5x)", "code file reference (2x)"],
    )
    verdict = classify_entity("ChromaDB", 5, project_scores)
    assert verdict["type"] == "project"
|
||||
|
||||
|
||||
def test_classify_entity_strong_person_needs_two_signal_types():
    """A person-only score backed by two different signal kinds classifies as "person"."""
    person_signals = ["dialogue marker (3x)", "'Riley ...' action (4x)"]
    person_scores = dict(
        person_score=10,
        project_score=0,
        person_signals=person_signals,
        project_signals=[],
    )
    verdict = classify_entity("Riley", 8, person_scores)
    assert verdict["type"] == "person"
|
||||
|
||||
|
||||
def test_classify_entity_pronoun_only_is_uncertain():
    """A person score supported only by a pronoun signal stays "uncertain"."""
    pronoun_scores = dict(
        person_score=8,
        project_score=0,
        person_signals=["pronoun nearby (4x)"],
        project_signals=[],
    )
    verdict = classify_entity("Riley", 5, pronoun_scores)
    assert verdict["type"] == "uncertain"
|
||||
|
||||
|
||||
def test_classify_entity_mixed_signals():
    """Equal person and project scores give "uncertain" with a trailing "mixed signals" note."""
    tied_scores = dict(
        person_score=5,
        project_score=5,
        person_signals=["pronoun nearby (2x)"],
        project_signals=["project verb (2x)"],
    )
    verdict = classify_entity("Lantern", 5, tied_scores)
    assert verdict["type"] == "uncertain"
    last_signal = verdict["signals"][-1]
    assert "mixed signals" in last_signal
|
||||
|
||||
|
||||
# ── detect_entities (integration) ───────────────────────────────────────
|
||||
|
||||
|
||||
def test_detect_entities_with_person_file(tmp_path):
    """A file full of person-style mentions of "Riley" gets Riley detected.

    The name may land in any of the returned categories; only its presence
    is checked.
    """
    notes = tmp_path / "notes.txt"
    sentences = [
        "Riley said hello today.",
        "Riley asked about the project.",
        "Riley told me she was happy.",
        "Riley: I think we should go.",
        "Hey Riley, thanks for the help.",
        "Riley laughed and smiled.",
        "Riley decided to join.",
        "Riley pushed the change.",
    ]
    # Explicit encoding keeps the fixture independent of the platform locale.
    notes.write_text("\n".join(sentences), encoding="utf-8")

    detected = detect_entities([notes])

    all_names = [entity["name"] for group in detected.values() for entity in group]
    assert "Riley" in all_names
|
||||
|
||||
|
||||
def test_detect_entities_with_project_file(tmp_path):
    """A file full of project-style mentions of "Lantern" gets Lantern detected.

    The name may land in any of the returned categories; only its presence
    is checked.
    """
    readme = tmp_path / "readme.txt"
    # "ChromaDB" has uppercase+lowercase mix but extract_candidates looks
    # for /[A-Z][a-z]{1,19}/ — so we need a name that matches that regex.
    # Use "Lantern" which matches the capitalized-word pattern.
    sentences = [
        "The Lantern project is great.",
        "Building Lantern was fun.",
        "We deployed Lantern today.",
        "Install Lantern with pip install Lantern.",
        "Check Lantern.py for the source.",
        "Lantern v2 is faster.",
    ]
    # Explicit encoding keeps the fixture independent of the platform locale.
    readme.write_text("\n".join(sentences), encoding="utf-8")

    detected = detect_entities([readme])

    all_names = [entity["name"] for group in detected.values() for entity in group]
    assert "Lantern" in all_names
|
||||
|
||||
|
||||
def test_detect_entities_empty_files(tmp_path):
    """An empty file yields three empty categories."""
    empty_file = tmp_path / "empty.txt"
    # Explicit encoding, consistent with the other file-writing tests.
    empty_file.write_text("", encoding="utf-8")
    detected = detect_entities([empty_file])
    assert detected == {"people": [], "projects": [], "uncertain": []}
|
||||
|
||||
|
||||
def test_detect_entities_handles_missing_file(tmp_path):
    """A path that does not exist yields three empty categories instead of raising."""
    detected = detect_entities([tmp_path / "nonexistent.txt"])
    assert detected == {"people": [], "projects": [], "uncertain": []}
|
||||
|
||||
|
||||
def test_detect_entities_respects_max_files(tmp_path):
    """Passing max_files smaller than the file count still returns a dict.

    Only the return type is asserted; the number of files actually read is
    not observable from this test.
    """
    files = []
    for i in range(5):
        path = tmp_path / f"file{i}.txt"
        # Explicit encoding keeps the fixture independent of the platform locale.
        path.write_text("Riley said hello. " * 10, encoding="utf-8")
        files.append(path)
    # max_files=2 should only read 2 files
    detected = detect_entities(files, max_files=2)
    # Should still work without error
    assert isinstance(detected, dict)
|
||||
|
||||
|
||||
# ── scan_for_detection ──────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_scan_for_detection_finds_prose(tmp_path):
    """At least one of the .md/.txt files in the directory is returned."""
    fixtures = (
        ("notes.md", "hello"),
        ("data.txt", "world"),
        ("code.py", "import os"),
    )
    for filename, text in fixtures:
        # Explicit encoding keeps the fixtures independent of the platform locale.
        (tmp_path / filename).write_text(text, encoding="utf-8")

    found = scan_for_detection(str(tmp_path))

    extensions = {os.path.splitext(str(path))[1] for path in found}
    # Prose files should be found
    assert ".md" in extensions or ".txt" in extensions
|
||||
|
||||
|
||||
def test_scan_for_detection_skips_git_dir(tmp_path):
    """No returned path mentions ".git", even though .git holds a .txt file."""
    dot_git = tmp_path / ".git"
    dot_git.mkdir()
    # Explicit encoding keeps the fixtures independent of the platform locale.
    (dot_git / "config.txt").write_text("git config", encoding="utf-8")
    (tmp_path / "readme.md").write_text("hello", encoding="utf-8")

    found = [str(path) for path in scan_for_detection(str(tmp_path))]

    assert not any(".git" in path for path in found)
|
||||
|
||||
|
||||
# ── module-level constants ──────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_stopwords_contains_common_words():
    """STOPWORDS includes an English article and two Python keywords."""
    for word in ("the", "import", "class"):
        assert word in STOPWORDS
|
||||
|
||||
|
||||
def test_prose_extensions():
    """PROSE_EXTENSIONS covers plain-text and Markdown suffixes."""
    for suffix in (".txt", ".md"):
        assert suffix in PROSE_EXTENSIONS
|
||||
|
||||
|
||||
# ── _print_entity_list ─────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_print_entity_list_with_entities(capsys):
    """The printed output contains the heading and every entity name."""
    alice = {"name": "Alice", "confidence": 0.9, "signals": ["dialogue marker (3x)"]}
    bob = {"name": "Bob", "confidence": 0.5, "signals": []}

    _print_entity_list([alice, bob], "PEOPLE")

    printed = capsys.readouterr().out
    for expected in ("PEOPLE", "Alice", "Bob"):
        assert expected in printed
|
||||
|
||||
|
||||
def test_print_entity_list_empty(capsys):
    """An empty list prints a "none detected" notice."""
    _print_entity_list([], "PEOPLE")
    printed = capsys.readouterr().out
    assert "none detected" in printed
|
||||
|
||||
|
||||
# ── confirm_entities ───────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_confirm_entities_yes_mode():
    """With yes=True, detected people and projects are returned as name lists."""

    def entity(name, confidence):
        # Minimal detected-entity record as used throughout these tests.
        return {"name": name, "confidence": confidence, "signals": ["test"]}

    detected = {
        "people": [entity("Alice", 0.9)],
        "projects": [entity("Acme", 0.8)],
        "uncertain": [entity("Foo", 0.4)],
    }

    confirmed = confirm_entities(detected, yes=True)

    assert confirmed["people"] == ["Alice"]
    assert confirmed["projects"] == ["Acme"]
|
||||
|
||||
|
||||
def test_confirm_entities_accept_all():
    """Answering the prompts with "" and then "n" keeps the detected person."""
    detected = {
        "people": [{"name": "Alice", "confidence": 0.9, "signals": ["test"]}],
        "projects": [],
        "uncertain": [],
    }
    scripted_answers = ["", "n"]
    with patch("builtins.input", side_effect=scripted_answers):
        confirmed = confirm_entities(detected, yes=False)
    assert "Alice" in confirmed["people"]
|
||||
|
||||
|
||||
def test_confirm_entities_edit_reclassify_uncertain():
    """In edit mode, one uncertain entity is made a person and the other is skipped."""
    detected = {
        "people": [],
        "projects": [],
        "uncertain": [
            {"name": "Foo", "confidence": 0.4, "signals": ["test"]},
            {"name": "Bar", "confidence": 0.4, "signals": ["test"]},
        ],
    }
    # Scripted replies, consumed by input() in this exact order.
    scripted_answers = [
        "edit",  # choice
        "p",  # Foo -> person
        "s",  # Bar -> skip
        "",  # no removals from people
        "",  # no removals from projects
        "n",  # don't add missing
    ]

    with patch("builtins.input", side_effect=scripted_answers):
        confirmed = confirm_entities(detected, yes=False)

    assert "Foo" in confirmed["people"]
    for bucket in ("people", "projects"):
        assert "Bar" not in confirmed[bucket]
|
||||
|
||||
|
||||
def test_confirm_entities_add_mode():
    """In add mode, a manually entered person and project end up in the result."""
    detected = {"people": [], "projects": [], "uncertain": []}
    # Scripted replies, consumed by input() in this exact order.
    scripted_answers = [
        "add",  # choice = add
        "NewPerson",  # name
        "p",  # person
        "NewProj",  # name
        "r",  # project
        "",  # stop adding
    ]

    with patch("builtins.input", side_effect=scripted_answers):
        confirmed = confirm_entities(detected, yes=False)

    assert "NewPerson" in confirmed["people"]
    assert "NewProj" in confirmed["projects"]
|
||||
|
||||
|
||||
# ── scan_for_detection fallback ────────────────────────────────────────
|
||||
|
||||
|
||||
def test_scan_for_detection_fallback_to_all_readable(tmp_path):
    """When fewer than 3 prose files, falls back to include all readable files."""
    fixtures = (
        ("one.md", "hello"),
        ("two.txt", "world"),
        # Only 2 prose files, so it should also include code files
        ("code.py", "import os"),
        ("app.js", "console.log()"),
    )
    for filename, text in fixtures:
        # Explicit encoding keeps the fixtures independent of the platform locale.
        (tmp_path / filename).write_text(text, encoding="utf-8")

    found = scan_for_detection(str(tmp_path))

    extensions = {os.path.splitext(str(path))[1] for path in found}
    assert ".py" in extensions or ".js" in extensions
|
||||
|
||||
|
||||
def test_scan_for_detection_max_files(tmp_path):
    """Caps to max_files."""
    for i in range(20):
        # Explicit encoding keeps the fixtures independent of the platform locale.
        (tmp_path / f"note{i}.md").write_text(f"content {i}", encoding="utf-8")
    found = scan_for_detection(str(tmp_path), max_files=5)
    assert len(found) <= 5
|
||||
@@ -0,0 +1,313 @@
|
||||
"""Tests for mempalace.entity_registry."""
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
from mempalace.entity_registry import (
|
||||
COMMON_ENGLISH_WORDS,
|
||||
PERSON_CONTEXT_PATTERNS,
|
||||
EntityRegistry,
|
||||
)
|
||||
|
||||
|
||||
# ── COMMON_ENGLISH_WORDS ────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_common_english_words_has_expected_entries():
    """A handful of words that double as names or dates are present."""
    for word in ("ever", "grace", "will", "may", "monday"):
        assert word in COMMON_ENGLISH_WORDS
|
||||
|
||||
|
||||
def test_common_english_words_is_lowercase():
    """Every entry equals its own lowercased form."""
    for word in COMMON_ENGLISH_WORDS:
        lowered = word.lower()
        assert word == lowered, f"{word} should be lowercase"
|
||||
|
||||
|
||||
# ── PERSON_CONTEXT_PATTERNS ─────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_person_context_patterns_is_nonempty():
    """At least one person-context pattern is defined."""
    pattern_count = len(PERSON_CONTEXT_PATTERNS)
    assert pattern_count > 0
|
||||
|
||||
|
||||
# ── EntityRegistry creation and empty state ─────────────────────────────
|
||||
|
||||
|
||||
def test_load_from_nonexistent_dir(tmp_path):
    """Loading from a directory without a saved registry gives the empty defaults."""
    fresh = EntityRegistry.load(config_dir=tmp_path)
    assert fresh.people == {}
    assert fresh.projects == []
    assert fresh.mode == "personal"
    assert fresh.ambiguous_flags == []
|
||||
|
||||
|
||||
def test_save_and_load_roundtrip(tmp_path):
    """Data seeded into one registry is visible after loading the same directory again."""
    alice = {"name": "Alice", "relationship": "colleague", "context": "work"}
    original = EntityRegistry.load(config_dir=tmp_path)
    original.seed(mode="work", people=[alice], projects=["MemPalace"])

    # Load again from same dir
    reloaded = EntityRegistry.load(config_dir=tmp_path)

    assert reloaded.mode == "work"
    assert "Alice" in reloaded.people
    assert "MemPalace" in reloaded.projects
|
||||
|
||||
|
||||
def test_save_creates_file(tmp_path):
    """save() writes entity_registry.json into the config directory."""
    EntityRegistry.load(config_dir=tmp_path).save()
    registry_file = tmp_path / "entity_registry.json"
    assert registry_file.exists()
|
||||
|
||||
|
||||
# ── seed ────────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_seed_registers_people(tmp_path):
    """Seeded people are stored by name with relationship, source and confidence."""
    seeded_people = [
        {"name": "Riley", "relationship": "daughter", "context": "personal"},
        {"name": "Devon", "relationship": "friend", "context": "personal"},
    ]
    registry = EntityRegistry.load(config_dir=tmp_path)
    registry.seed(mode="personal", people=seeded_people, projects=["MemPalace"])

    assert "Riley" in registry.people
    assert "Devon" in registry.people
    riley = registry.people["Riley"]
    assert riley["relationship"] == "daughter"
    assert riley["source"] == "onboarding"
    assert riley["confidence"] == 1.0
|
||||
|
||||
|
||||
def test_seed_registers_projects(tmp_path):
    """Seeded project names are stored in the given order."""
    project_names = ["Acme", "Widget"]
    registry = EntityRegistry.load(config_dir=tmp_path)
    registry.seed(mode="work", people=[], projects=project_names)
    assert registry.projects == ["Acme", "Widget"]
|
||||
|
||||
|
||||
def test_seed_sets_mode(tmp_path):
    """The mode passed to seed() is reflected on the registry."""
    registry = EntityRegistry.load(config_dir=tmp_path)
    registry.seed(people=[], projects=[], mode="combo")
    assert registry.mode == "combo"
|
||||
|
||||
|
||||
def test_seed_flags_ambiguous_names(tmp_path):
    """A name that is also a common word is flagged (lowercased); others are not."""
    seeded_people = [
        {"name": "Grace", "relationship": "friend", "context": "personal"},
        {"name": "Riley", "relationship": "daughter", "context": "personal"},
    ]
    registry = EntityRegistry.load(config_dir=tmp_path)
    registry.seed(mode="personal", people=seeded_people, projects=[])

    flags = registry.ambiguous_flags
    assert "grace" in flags
    # Riley is not a common English word
    assert "riley" not in flags
|
||||
|
||||
|
||||
def test_seed_with_aliases(tmp_path):
    """An alias gets its own people entry pointing at the canonical name."""
    maxwell = {"name": "Maxwell", "relationship": "friend", "context": "personal"}
    registry = EntityRegistry.load(config_dir=tmp_path)
    registry.seed(
        mode="personal",
        people=[maxwell],
        projects=[],
        aliases={"Max": "Maxwell"},
    )

    assert "Maxwell" in registry.people
    assert "Max" in registry.people
    alias_entry = registry.people["Max"]
    assert alias_entry.get("canonical") == "Maxwell"
|
||||
|
||||
|
||||
def test_seed_skips_empty_names(tmp_path):
    """A person record with an empty name is not registered."""
    nameless = {"name": "", "relationship": "", "context": "personal"}
    registry = EntityRegistry.load(config_dir=tmp_path)
    registry.seed(mode="personal", people=[nameless], projects=[])
    assert len(registry.people) == 0
|
||||
|
||||
|
||||
# ── lookup ──────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_lookup_known_person(tmp_path):
    """Looking up a seeded person returns type "person" with full confidence."""
    riley = {"name": "Riley", "relationship": "daughter", "context": "personal"}
    registry = EntityRegistry.load(config_dir=tmp_path)
    registry.seed(mode="personal", people=[riley], projects=[])

    match = registry.lookup("Riley")

    assert match["type"] == "person"
    assert match["confidence"] == 1.0
    assert match["name"] == "Riley"
|
||||
|
||||
|
||||
def test_lookup_known_project(tmp_path):
    """Looking up a seeded project returns type "project" with full confidence."""
    registry = EntityRegistry.load(config_dir=tmp_path)
    registry.seed(mode="work", people=[], projects=["MemPalace"])

    match = registry.lookup("MemPalace")

    assert match["type"] == "project"
    assert match["confidence"] == 1.0
|
||||
|
||||
|
||||
def test_lookup_unknown_word(tmp_path):
    """A word that was never registered comes back as "unknown" with zero confidence."""
    registry = EntityRegistry.load(config_dir=tmp_path)
    registry.seed(mode="personal", people=[], projects=[])

    match = registry.lookup("Xyzzy")

    assert match["type"] == "unknown"
    assert match["confidence"] == 0.0
|
||||
|
||||
|
||||
def test_lookup_case_insensitive(tmp_path):
    """A lowercase query still matches a person seeded with a capitalised name."""
    riley = {"name": "Riley", "relationship": "daughter", "context": "personal"}
    registry = EntityRegistry.load(config_dir=tmp_path)
    registry.seed(mode="personal", people=[riley], projects=[])

    match = registry.lookup("riley")

    assert match["type"] == "person"
|
||||
|
||||
|
||||
def test_lookup_alias(tmp_path):
    """Looking up a seeded alias resolves to a person."""
    maxwell = {"name": "Maxwell", "relationship": "friend", "context": "personal"}
    registry = EntityRegistry.load(config_dir=tmp_path)
    registry.seed(
        mode="personal",
        people=[maxwell],
        projects=[],
        aliases={"Max": "Maxwell"},
    )

    match = registry.lookup("Max")

    assert match["type"] == "person"
|
||||
|
||||
|
||||
# ── disambiguation ──────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_lookup_ambiguous_word_as_person(tmp_path):
    """"Grace" in a "went with Grace" context resolves to the seeded person."""
    grace = {"name": "Grace", "relationship": "friend", "context": "personal"}
    registry = EntityRegistry.load(config_dir=tmp_path)
    registry.seed(mode="personal", people=[grace], projects=[])

    match = registry.lookup("Grace", context="I went with Grace today")

    assert match["type"] == "person"
|
||||
|
||||
|
||||
def test_lookup_ambiguous_word_as_concept(tmp_path):
    """"Ever" in a "have you ever" context resolves to a concept, not the seeded person."""
    ever = {"name": "Ever", "relationship": "friend", "context": "personal"}
    registry = EntityRegistry.load(config_dir=tmp_path)
    registry.seed(mode="personal", people=[ever], projects=[])

    match = registry.lookup("Ever", context="have you ever tried this")

    assert match["type"] == "concept"
|
||||
|
||||
|
||||
# ── research (Wikipedia) — mocked ──────────────────────────────────────
|
||||
|
||||
|
||||
def test_research_caches_result(tmp_path):
    """A second research() call for the same word does not hit the Wikipedia lookup."""
    lookup_target = "mempalace.entity_registry._wikipedia_lookup"
    registry = EntityRegistry.load(config_dir=tmp_path)
    registry.seed(mode="personal", people=[], projects=[])

    wiki_answer = {
        "inferred_type": "person",
        "confidence": 0.80,
        "wiki_summary": "Saoirse is an Irish given name.",
        "wiki_title": "Saoirse",
    }

    with patch(lookup_target, return_value=wiki_answer):
        first = registry.research("Saoirse", auto_confirm=True)
    assert first["inferred_type"] == "person"

    # Second call should use cache, not call Wikipedia again
    lookup_must_not_run = AssertionError("should not be called")
    with patch(lookup_target, side_effect=lookup_must_not_run):
        cached = registry.research("Saoirse")
    assert cached["inferred_type"] == "person"
|
||||
|
||||
|
||||
def test_confirm_research_adds_to_people(tmp_path):
    """Confirming a researched word as a person registers it with source "wiki"."""
    registry = EntityRegistry.load(config_dir=tmp_path)
    registry.seed(mode="personal", people=[], projects=[])

    wiki_answer = {
        "inferred_type": "person",
        "confidence": 0.80,
        "wiki_summary": "Saoirse is a name",
        "wiki_title": "Saoirse",
    }
    with patch("mempalace.entity_registry._wikipedia_lookup", return_value=wiki_answer):
        registry.research("Saoirse", auto_confirm=False)

    registry.confirm_research("Saoirse", entity_type="person", relationship="friend")

    assert "Saoirse" in registry.people
    saoirse = registry.people["Saoirse"]
    assert saoirse["source"] == "wiki"
|
||||
|
||||
|
||||
# ── extract_people_from_query ───────────────────────────────────────────
|
||||
|
||||
|
||||
def test_extract_people_from_query(tmp_path):
    """Only the seeded person named in the query is returned."""
    seeded_people = [
        {"name": "Riley", "relationship": "daughter", "context": "personal"},
        {"name": "Devon", "relationship": "friend", "context": "personal"},
    ]
    registry = EntityRegistry.load(config_dir=tmp_path)
    registry.seed(mode="personal", people=seeded_people, projects=[])

    mentioned = registry.extract_people_from_query(
        "What did Riley say about the weather?"
    )

    assert "Riley" in mentioned
    assert "Devon" not in mentioned
|
||||
|
||||
|
||||
# ── extract_unknown_candidates ──────────────────────────────────────────
|
||||
|
||||
|
||||
def test_extract_unknown_candidates(tmp_path):
    """A capitalised word absent from the registry is reported as an unknown candidate."""
    registry = EntityRegistry.load(config_dir=tmp_path)
    registry.seed(mode="personal", people=[], projects=[])

    candidates = registry.extract_unknown_candidates("Saoirse went to the store")

    assert "Saoirse" in candidates
|
||||
|
||||
|
||||
def test_extract_unknown_candidates_skips_known(tmp_path):
    """A name already in the registry is not reported as unknown."""
    riley = {"name": "Riley", "relationship": "daughter", "context": "personal"}
    registry = EntityRegistry.load(config_dir=tmp_path)
    registry.seed(mode="personal", people=[riley], projects=[])

    candidates = registry.extract_unknown_candidates("Riley went to the store")

    assert "Riley" not in candidates
|
||||
|
||||
|
||||
# ── summary ─────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_summary(tmp_path):
    """The summary text mentions the mode, the seeded person and the seeded project."""
    riley = {"name": "Riley", "relationship": "daughter", "context": "personal"}
    registry = EntityRegistry.load(config_dir=tmp_path)
    registry.seed(mode="personal", people=[riley], projects=["MemPalace"])

    text = registry.summary()

    for expected in ("personal", "Riley", "MemPalace"):
        assert expected in text
|
||||
@@ -0,0 +1,248 @@
|
||||
"""Tests for mempalace.general_extractor."""
|
||||
|
||||
from mempalace.general_extractor import (
|
||||
ALL_MARKERS,
|
||||
NEGATIVE_WORDS,
|
||||
POSITIVE_WORDS,
|
||||
_extract_prose,
|
||||
_get_sentiment,
|
||||
_has_resolution,
|
||||
_is_code_line,
|
||||
_score_markers,
|
||||
_split_into_segments,
|
||||
extract_memories,
|
||||
)
|
||||
|
||||
|
||||
# ── extract_memories — empty / no markers ───────────────────────────────
|
||||
|
||||
|
||||
def test_extract_memories_empty_text():
    """Empty input produces no memories."""
    assert extract_memories("") == []
|
||||
|
||||
|
||||
def test_extract_memories_no_markers():
    """A sentence without any marker phrases produces no memories."""
    memories = extract_memories("The quick brown fox jumped over the lazy dog.")
    assert memories == []
|
||||
|
||||
|
||||
def test_extract_memories_short_text_skipped():
    """Very short input produces no memories."""
    # Paragraphs shorter than 20 chars are skipped
    memories = extract_memories("ok sure")
    assert memories == []
|
||||
|
||||
|
||||
# ── extract_memories — decision markers ─────────────────────────────────
|
||||
|
||||
|
||||
def test_extract_memories_decision():
    """Decision phrasing yields at least one memory typed "decision"."""
    memories = extract_memories(
        "We decided to go with PostgreSQL instead of MySQL "
        "because the performance was better for our use case. "
        "The trade-off was more complexity in setup."
    )
    assert len(memories) >= 1
    found_types = [memory["memory_type"] for memory in memories]
    assert "decision" in found_types
|
||||
|
||||
|
||||
# ── extract_memories — preference markers ───────────────────────────────
|
||||
|
||||
|
||||
def test_extract_memories_preference():
    """Preference phrasing yields at least one memory typed "preference"."""
    memories = extract_memories(
        "I prefer using snake_case in Python code. "
        "Please always use type hints. "
        "Never use wildcard imports."
    )
    assert len(memories) >= 1
    found_types = [memory["memory_type"] for memory in memories]
    assert "preference" in found_types
|
||||
|
||||
|
||||
# ── extract_memories — milestone markers ────────────────────────────────
|
||||
|
||||
|
||||
def test_extract_memories_milestone():
    """Breakthrough phrasing yields at least one memory typed "milestone"."""
    memories = extract_memories(
        "It finally works! After three days of debugging, "
        "I figured out the issue. The breakthrough was realizing "
        "the config file was cached. Got it working at 2am."
    )
    assert len(memories) >= 1
    found_types = [memory["memory_type"] for memory in memories]
    assert "milestone" in found_types
|
||||
|
||||
|
||||
# ── extract_memories — problem markers ──────────────────────────────────
|
||||
|
||||
|
||||
def test_extract_memories_problem():
    """Bug/error phrasing yields a memory typed "problem" or "milestone"."""
    memories = extract_memories(
        "There's a critical bug in the auth module. "
        "The error keeps crashing the server. "
        "The root cause was a missing null check. "
        "The problem is that tokens expire silently."
    )
    assert len(memories) >= 1
    found_types = {memory["memory_type"] for memory in memories}
    # resolved problems become milestones
    assert found_types & {"problem", "milestone"}
|
||||
|
||||
|
||||
# ── extract_memories — emotional markers ────────────────────────────────
|
||||
|
||||
|
||||
def test_extract_memories_emotional():
    """Feeling-laden phrasing yields at least one memory typed "emotional"."""
    memories = extract_memories(
        "I feel so proud of what we built together. "
        "I love working on this project, it makes me happy. "
        "I'm grateful for the team and the beautiful code we wrote."
    )
    assert len(memories) >= 1
    found_types = [memory["memory_type"] for memory in memories]
    assert "emotional" in found_types
|
||||
|
||||
|
||||
# ── extract_memories — chunk_index ──────────────────────────────────────
|
||||
|
||||
|
||||
def test_extract_memories_chunk_index_increments():
    """chunk_index values run 0, 1, 2, ... in result order.

    The check is unconditional: for zero or one extracted memory the
    expected sequence is simply ``[]`` or ``[0]``, so the test can no
    longer pass without asserting anything.
    """
    text = (
        "We decided to use React because it fits our team.\n\n"
        "I prefer functional components always.\n\n"
        "It works! We finally shipped the v1.0 release."
    )
    memories = extract_memories(text)
    indices = [memory["chunk_index"] for memory in memories]
    assert indices == list(range(len(memories)))
|
||||
|
||||
|
||||
# ── _score_markers ──────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_score_markers_with_matches():
    """Text containing decision phrasing scores above zero and reports keywords."""
    sample = "we decided to go with postgres because it is faster"
    score, keywords = _score_markers(sample, ALL_MARKERS["decision"])
    assert score > 0
    assert len(keywords) > 0
|
||||
|
||||
|
||||
def test_score_markers_no_matches():
    """Text without decision phrasing scores exactly 0.0."""
    decision_markers = ALL_MARKERS["decision"]
    score, _keywords = _score_markers("nothing relevant here", decision_markers)
    assert score == 0.0
|
||||
|
||||
|
||||
# ── _get_sentiment ──────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_get_sentiment_positive():
    """Happy/proud wording is classified as "positive"."""
    sentiment = _get_sentiment("I am so happy and proud of this breakthrough")
    assert sentiment == "positive"
|
||||
|
||||
|
||||
def test_get_sentiment_negative():
    """Bug/crash/failure wording is classified as "negative"."""
    sentiment = _get_sentiment("This bug caused a crash and total failure")
    assert sentiment == "negative"
|
||||
|
||||
|
||||
def test_get_sentiment_neutral():
    """Wording without sentiment words is classified as "neutral"."""
    sentiment = _get_sentiment("The meeting is at three")
    assert sentiment == "neutral"
|
||||
|
||||
|
||||
# ── _has_resolution ─────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_has_resolution_true():
    """A sentence reporting a fix returns exactly True."""
    resolved = _has_resolution("I fixed the auth bug and it works now")
    assert resolved is True
|
||||
|
||||
|
||||
def test_has_resolution_false():
    """A sentence describing an ongoing problem returns exactly False."""
    resolved = _has_resolution("The server keeps crashing")
    assert resolved is False
|
||||
|
||||
|
||||
# ── _is_code_line ───────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_is_code_line_detects_code():
    """An import, a shell prompt line and a code fence all count as code."""
    code_like_lines = (" import os", " $ pip install flask", " ```python")
    for line in code_like_lines:
        assert _is_code_line(line) is True
|
||||
|
||||
|
||||
def test_is_code_line_allows_prose():
    """An ordinary sentence and an empty string are not code."""
    prose_lines = ("This is a regular sentence about coding.", "")
    for line in prose_lines:
        assert _is_code_line(line) is False
|
||||
|
||||
|
||||
# ── _extract_prose ──────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_extract_prose_strips_code_blocks():
    """Fenced code is dropped while the surrounding prose lines survive."""
    prose = _extract_prose("Hello world\n```\nimport os\nprint('hi')\n```\nGoodbye")
    assert "import os" not in prose
    for kept in ("Hello world", "Goodbye"):
        assert kept in prose
|
||||
|
||||
|
||||
def test_extract_prose_returns_original_if_all_code():
    """Input made up only of code lines still gives a non-empty result."""
    prose = _extract_prose("import os\nfrom sys import argv")
    # Falls back to original text if nothing left
    assert len(prose) > 0
|
||||
|
||||
|
||||
# ── _split_into_segments ───────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_split_into_segments_by_paragraph():
    """Three blank-line-separated paragraphs give three segments."""
    segments = _split_into_segments("First paragraph.\n\nSecond paragraph.\n\nThird paragraph.")
    assert len(segments) == 3
|
||||
|
||||
|
||||
def test_split_into_segments_by_turns():
    """Five Human/Assistant exchanges are split into at least three segments."""
    dialogue = [
        line
        for i in range(5)
        for line in (f"Human: Question {i}", f"Assistant: Answer {i}")
    ]
    segments = _split_into_segments("\n".join(dialogue))
    assert len(segments) >= 3  # turn-based splitting should fire
|
||||
|
||||
|
||||
def test_split_into_segments_single_block():
    """Thirty single-newline-separated lines produce at least one segment."""
    # Many lines without double-newline produces chunked segments
    document = "\n".join(f"Line {i} of the document" for i in range(30))
    segments = _split_into_segments(document)
    assert len(segments) >= 1
|
||||
|
||||
|
||||
# ── ALL_MARKERS constant ───────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_all_markers_has_five_types():
    """ALL_MARKERS is keyed by exactly the five expected memory types."""
    expected_types = {"decision", "preference", "milestone", "problem", "emotional"}
    assert set(ALL_MARKERS.keys()) == expected_types
|
||||
|
||||
|
||||
# ── POSITIVE_WORDS / NEGATIVE_WORDS ────────────────────────────────────
|
||||
|
||||
|
||||
def test_positive_words():
    """POSITIVE_WORDS contains the words the sentiment tests rely on."""
    for word in ("happy", "proud"):
        assert word in POSITIVE_WORDS
|
||||
|
||||
|
||||
def test_negative_words():
    """NEGATIVE_WORDS contains the words the sentiment tests rely on."""
    for word in ("bug", "crash"):
        assert word in NEGATIVE_WORDS
|
||||
@@ -0,0 +1,420 @@
|
||||
import contextlib
|
||||
import io
|
||||
import json
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from mempalace.hooks_cli import (
|
||||
SAVE_INTERVAL,
|
||||
STOP_BLOCK_REASON,
|
||||
PRECOMPACT_BLOCK_REASON,
|
||||
_count_human_messages,
|
||||
_log,
|
||||
_maybe_auto_ingest,
|
||||
_parse_harness_input,
|
||||
_sanitize_session_id,
|
||||
hook_stop,
|
||||
hook_session_start,
|
||||
hook_precompact,
|
||||
run_hook,
|
||||
)
|
||||
|
||||
|
||||
# --- _sanitize_session_id ---
|
||||
|
||||
|
||||
def test_sanitize_normal_id():
    """An id made of letters, digits, dashes and underscores is returned unchanged."""
    session_id = "abc-123_XYZ"
    assert _sanitize_session_id(session_id) == session_id
|
||||
|
||||
|
||||
def test_sanitize_strips_dangerous_chars():
    """Dots and slashes of a path-traversal attempt are removed."""
    cleaned = _sanitize_session_id("../../etc/passwd")
    assert cleaned == "etcpasswd"
|
||||
|
||||
|
||||
def test_sanitize_empty_returns_unknown():
    """Ids with nothing left after sanitizing fall back to "unknown"."""
    for raw in ("", "!!!"):
        assert _sanitize_session_id(raw) == "unknown"
|
||||
|
||||
|
||||
# --- _count_human_messages ---
|
||||
|
||||
|
||||
def _write_transcript(path: Path, entries: list[dict]):
|
||||
with open(path, "w", encoding="utf-8") as f:
|
||||
for entry in entries:
|
||||
f.write(json.dumps(entry) + "\n")
|
||||
|
||||
|
||||
def test_count_human_messages_basic(tmp_path):
    """Two user entries and one assistant entry count as two human messages."""
    entries = [
        {"message": {"role": "user", "content": "hello"}},
        {"message": {"role": "assistant", "content": "hi"}},
        {"message": {"role": "user", "content": "bye"}},
    ]
    jsonl = tmp_path / "t.jsonl"
    _write_transcript(jsonl, entries)
    assert _count_human_messages(str(jsonl)) == 2
|
||||
|
||||
|
||||
def test_count_skips_command_messages(tmp_path):
    """A user entry holding a <command-message> payload is not counted."""
    entries = [
        {"message": {"role": "user", "content": "<command-message>status</command-message>"}},
        {"message": {"role": "user", "content": "real question"}},
    ]
    jsonl = tmp_path / "t.jsonl"
    _write_transcript(jsonl, entries)
    assert _count_human_messages(str(jsonl)) == 1
|
||||
|
||||
|
||||
def test_count_handles_list_content(tmp_path):
    """List-style content is counted, except when its text is a <command-message>."""
    plain_entry = {"message": {"role": "user", "content": [{"type": "text", "text": "hello"}]}}
    command_entry = {
        "message": {
            "role": "user",
            "content": [{"type": "text", "text": "<command-message>x</command-message>"}],
        }
    }
    jsonl = tmp_path / "t.jsonl"
    _write_transcript(jsonl, [plain_entry, command_entry])
    assert _count_human_messages(str(jsonl)) == 1
|
||||
|
||||
|
||||
def test_count_missing_file():
    """A transcript path that does not exist counts as zero messages."""
    missing_path = "/nonexistent/path.jsonl"
    assert _count_human_messages(missing_path) == 0
|
||||
|
||||
|
||||
def test_count_empty_file(tmp_path):
    """An empty transcript file counts as zero messages."""
    jsonl = tmp_path / "t.jsonl"
    jsonl.write_text("")
    count = _count_human_messages(str(jsonl))
    assert count == 0
|
||||
|
||||
|
||||
def test_count_malformed_json_lines(tmp_path):
    """An unparseable line is skipped; the valid user line after it is still counted."""
    jsonl = tmp_path / "t.jsonl"
    jsonl.write_text('not json\n{"message": {"role": "user", "content": "ok"}}\n')
    count = _count_human_messages(str(jsonl))
    assert count == 1
|
||||
|
||||
|
||||
# --- hook_stop ---
|
||||
|
||||
|
||||
def _capture_hook_output(hook_fn, data, harness="claude-code", state_dir=None):
    """Run a hook and return the JSON payload it handed to ``_output``.

    Args:
        hook_fn: Hook callable, invoked as ``hook_fn(data, harness)``.
        data: Input dict passed to the hook.
        harness: Harness name passed to the hook.
        state_dir: When truthy, ``mempalace.hooks_cli.STATE_DIR`` is patched
            to this value for the duration of the call.

    Returns:
        The object the hook passed to ``_output``, after a JSON round-trip.
    """
    # Uses the module-level ``io`` import; the former function-local re-import
    # was redundant.
    buf = io.StringIO()
    with contextlib.ExitStack() as stack:
        # Record what the hook passes to _output instead of calling the real one.
        stack.enter_context(
            patch("mempalace.hooks_cli._output", side_effect=lambda d: buf.write(json.dumps(d)))
        )
        if state_dir:
            stack.enter_context(patch("mempalace.hooks_cli.STATE_DIR", state_dir))
        hook_fn(data, harness)
    return json.loads(buf.getvalue())
|
||||
|
||||
|
||||
def test_stop_hook_passthrough_when_active(tmp_path):
    """With stop_hook_active=True the stop hook emits an empty response."""
    # _capture_hook_output already patches STATE_DIR via state_dir=, so no
    # extra outer patch is needed (matches the other stop-hook tests).
    payload = {"session_id": "test", "stop_hook_active": True, "transcript_path": ""}
    result = _capture_hook_output(hook_stop, payload, state_dir=tmp_path)
    assert result == {}
|
||||
|
||||
|
||||
def test_stop_hook_passthrough_when_active_string(tmp_path):
    """The string "true" for stop_hook_active also yields an empty response."""
    # _capture_hook_output already patches STATE_DIR via state_dir=, so no
    # extra outer patch is needed (matches the other stop-hook tests).
    payload = {"session_id": "test", "stop_hook_active": "true", "transcript_path": ""}
    result = _capture_hook_output(hook_stop, payload, state_dir=tmp_path)
    assert result == {}
|
||||
|
||||
|
||||
def test_stop_hook_passthrough_below_interval(tmp_path):
    """One user message short of SAVE_INTERVAL gives an empty response."""
    jsonl = tmp_path / "t.jsonl"
    messages = [
        {"message": {"role": "user", "content": f"msg {i}"}} for i in range(SAVE_INTERVAL - 1)
    ]
    _write_transcript(jsonl, messages)
    payload = {"session_id": "test", "stop_hook_active": False, "transcript_path": str(jsonl)}
    result = _capture_hook_output(hook_stop, payload, state_dir=tmp_path)
    assert result == {}
|
||||
|
||||
|
||||
def test_stop_hook_blocks_at_interval(tmp_path):
    """Exactly SAVE_INTERVAL user messages make the stop hook block with STOP_BLOCK_REASON."""
    jsonl = tmp_path / "t.jsonl"
    messages = [{"message": {"role": "user", "content": f"msg {i}"}} for i in range(SAVE_INTERVAL)]
    _write_transcript(jsonl, messages)
    payload = {"session_id": "test", "stop_hook_active": False, "transcript_path": str(jsonl)}
    result = _capture_hook_output(hook_stop, payload, state_dir=tmp_path)
    assert result["decision"] == "block"
    assert result["reason"] == STOP_BLOCK_REASON
|
||||
|
||||
|
||||
def test_stop_hook_tracks_save_point(tmp_path):
    """A second stop call with an unchanged message count does not block again."""
    jsonl = tmp_path / "t.jsonl"
    messages = [{"message": {"role": "user", "content": f"msg {i}"}} for i in range(SAVE_INTERVAL)]
    _write_transcript(jsonl, messages)
    payload = {"session_id": "test", "stop_hook_active": False, "transcript_path": str(jsonl)}

    # First call blocks
    first = _capture_hook_output(hook_stop, payload, state_dir=tmp_path)
    assert first["decision"] == "block"

    # Second call with same count passes through (already saved)
    second = _capture_hook_output(hook_stop, payload, state_dir=tmp_path)
    assert second == {}
|
||||
|
||||
|
||||
# --- hook_session_start ---
|
||||
|
||||
|
||||
def test_session_start_passes_through(tmp_path):
    """The session-start hook emits an empty response."""
    payload = {"session_id": "test"}
    result = _capture_hook_output(hook_session_start, payload, state_dir=tmp_path)
    assert result == {}
|
||||
|
||||
|
||||
# --- hook_precompact ---
|
||||
|
||||
|
||||
def test_precompact_always_blocks(tmp_path):
    """The precompact hook blocks with PRECOMPACT_BLOCK_REASON."""
    payload = {"session_id": "test"}
    result = _capture_hook_output(hook_precompact, payload, state_dir=tmp_path)
    assert result["decision"] == "block"
    assert result["reason"] == PRECOMPACT_BLOCK_REASON
|
||||
|
||||
|
||||
# --- _log ---
|
||||
|
||||
|
||||
def test_log_writes_to_hook_log(tmp_path):
    """_log writes its message to hook.log inside STATE_DIR."""
    with patch("mempalace.hooks_cli.STATE_DIR", tmp_path):
        _log("test message")
    hook_log = tmp_path / "hook.log"
    assert hook_log.is_file()
    assert "test message" in hook_log.read_text()
|
||||
|
||||
|
||||
def test_log_oserror_is_silenced(tmp_path):
    """_log should not raise if the directory cannot be created."""
    # Nest the state dir beneath a regular file: creating it then fails with
    # an OSError for every user. An absolute "/nonexistent/..." path can
    # actually be created when the suite runs with root privileges, which
    # would skip the failure path this test is meant to cover.
    blocker = tmp_path / "not_a_dir"
    blocker.write_text("")
    with patch("mempalace.hooks_cli.STATE_DIR", blocker / "deeply" / "nested" / "dir"):
        # Should not raise
        _log("this will fail silently")
|
||||
|
||||
|
||||
# --- _maybe_auto_ingest ---
|
||||
|
||||
|
||||
def test_maybe_auto_ingest_no_env(tmp_path):
    """Without MEMPAL_DIR set, does nothing."""
    empty_env = patch.dict("os.environ", {}, clear=True)
    state_dir = patch("mempalace.hooks_cli.STATE_DIR", tmp_path)
    with empty_env, state_dir:
        _maybe_auto_ingest()  # should not raise
|
||||
|
||||
|
||||
def test_maybe_auto_ingest_with_env(tmp_path):
    """With MEMPAL_DIR set to a valid directory, spawns subprocess."""
    project = tmp_path / "project"
    project.mkdir()
    with contextlib.ExitStack() as stack:
        stack.enter_context(patch.dict("os.environ", {"MEMPAL_DIR": str(project)}))
        stack.enter_context(patch("mempalace.hooks_cli.STATE_DIR", tmp_path))
        popen = stack.enter_context(patch("mempalace.hooks_cli.subprocess.Popen"))
        _maybe_auto_ingest()
        popen.assert_called_once()
|
||||
|
||||
|
||||
def test_maybe_auto_ingest_oserror(tmp_path):
    """OSError during subprocess spawn is silenced."""
    project = tmp_path / "project"
    project.mkdir()
    with contextlib.ExitStack() as stack:
        stack.enter_context(patch.dict("os.environ", {"MEMPAL_DIR": str(project)}))
        stack.enter_context(patch("mempalace.hooks_cli.STATE_DIR", tmp_path))
        stack.enter_context(
            patch("mempalace.hooks_cli.subprocess.Popen", side_effect=OSError("fail"))
        )
        _maybe_auto_ingest()  # should not raise
|
||||
|
||||
|
||||
# --- _parse_harness_input ---
|
||||
|
||||
|
||||
def test_parse_harness_input_unknown():
    """Unknown harness should sys.exit(1)."""
    payload = {"session_id": "test"}
    with pytest.raises(SystemExit) as raised:
        _parse_harness_input(payload, "unknown-harness")
    assert raised.value.code == 1
|
||||
|
||||
|
||||
def test_parse_harness_input_valid():
    """A claude-code payload keeps its session_id and stop_hook_active values."""
    payload = {"session_id": "abc-123", "stop_hook_active": True, "transcript_path": "/tmp/t.jsonl"}
    parsed = _parse_harness_input(payload, "claude-code")
    assert parsed["session_id"] == "abc-123"
    assert parsed["stop_hook_active"] is True
|
||||
|
||||
|
||||
# --- hook_stop with OSError on write ---
|
||||
|
||||
|
||||
def test_stop_hook_oserror_on_last_save_read(tmp_path):
    """When last_save_file has invalid content, falls back to 0."""
    jsonl = tmp_path / "t.jsonl"
    messages = [{"message": {"role": "user", "content": f"msg {i}"}} for i in range(SAVE_INTERVAL)]
    _write_transcript(jsonl, messages)
    # Write invalid content to last save file
    last_save = tmp_path / "test_last_save"
    last_save.write_text("not_a_number")
    payload = {"session_id": "test", "stop_hook_active": False, "transcript_path": str(jsonl)}
    result = _capture_hook_output(hook_stop, payload, state_dir=tmp_path)
    assert result["decision"] == "block"
|
||||
|
||||
|
||||
def test_stop_hook_oserror_on_write(tmp_path):
    """When write to last_save_file fails, hook still outputs correctly."""
    transcript = tmp_path / "t.jsonl"
    # The transcript is written before Path.write_text is patched below.
    _write_transcript(
        transcript,
        [{"message": {"role": "user", "content": f"msg {i}"}} for i in range(SAVE_INTERVAL)],
    )

    def bad_write_text(*args, **kwargs):
        raise OSError("disk full")

    payload = {
        "session_id": "test",
        "stop_hook_active": False,
        "transcript_path": str(transcript),
    }
    # STATE_DIR is patched by _capture_hook_output via state_dir=, so the
    # former outer STATE_DIR patch was redundant and has been dropped.
    with patch.object(Path, "write_text", bad_write_text):
        result = _capture_hook_output(hook_stop, payload, state_dir=tmp_path)
    assert result["decision"] == "block"
|
||||
|
||||
|
||||
# --- hook_precompact with MEMPAL_DIR ---
|
||||
|
||||
|
||||
def test_precompact_with_mempal_dir(tmp_path):
    """Precompact runs subprocess.run when MEMPAL_DIR is set."""
    project = tmp_path / "project"
    project.mkdir()
    env = patch.dict("os.environ", {"MEMPAL_DIR": str(project)})
    with env, patch("mempalace.hooks_cli.subprocess.run") as run_mock:
        result = _capture_hook_output(hook_precompact, {"session_id": "test"}, state_dir=tmp_path)
        assert result["decision"] == "block"
        run_mock.assert_called_once()
|
||||
|
||||
|
||||
def test_precompact_with_mempal_dir_oserror(tmp_path):
    """Precompact handles OSError from subprocess gracefully."""
    project = tmp_path / "project"
    project.mkdir()
    env = patch.dict("os.environ", {"MEMPAL_DIR": str(project)})
    failing_run = patch("mempalace.hooks_cli.subprocess.run", side_effect=OSError("fail"))
    with env, failing_run:
        result = _capture_hook_output(hook_precompact, {"session_id": "test"}, state_dir=tmp_path)
        assert result["decision"] == "block"
|
||||
|
||||
|
||||
# --- run_hook ---
|
||||
|
||||
|
||||
def test_run_hook_dispatches_session_start(tmp_path):
    """run_hook reads stdin JSON and dispatches to correct handler."""
    stdin = io.StringIO(json.dumps({"session_id": "run-test"}))
    with patch("sys.stdin", stdin), patch("mempalace.hooks_cli.STATE_DIR", tmp_path):
        with patch("mempalace.hooks_cli._output") as recorder:
            run_hook("session-start", "claude-code")
            recorder.assert_called_once_with({})
|
||||
|
||||
|
||||
def test_run_hook_dispatches_stop(tmp_path):
    """run_hook("stop") on a three-message transcript emits an empty response."""
    jsonl = tmp_path / "t.jsonl"
    messages = [{"message": {"role": "user", "content": f"msg {i}"}} for i in range(3)]
    _write_transcript(jsonl, messages)
    payload = {
        "session_id": "run-test",
        "stop_hook_active": False,
        "transcript_path": str(jsonl),
    }
    stdin = io.StringIO(json.dumps(payload))
    with patch("sys.stdin", stdin), patch("mempalace.hooks_cli.STATE_DIR", tmp_path):
        with patch("mempalace.hooks_cli._output") as recorder:
            run_hook("stop", "claude-code")
            recorder.assert_called_once_with({})
|
||||
|
||||
|
||||
def test_run_hook_dispatches_precompact(tmp_path):
    """run_hook("precompact") calls _output once with a block decision."""
    stdin = io.StringIO(json.dumps({"session_id": "run-test"}))
    with patch("sys.stdin", stdin), patch("mempalace.hooks_cli.STATE_DIR", tmp_path):
        with patch("mempalace.hooks_cli._output") as recorder:
            run_hook("precompact", "claude-code")
            recorder.assert_called_once()
            emitted = recorder.call_args[0][0]
            assert emitted["decision"] == "block"
|
||||
|
||||
|
||||
def test_run_hook_unknown_hook():
    """An unrecognised hook name exits with status 1."""
    stdin = io.StringIO(json.dumps({"session_id": "test"}))
    with patch("sys.stdin", stdin), pytest.raises(SystemExit) as raised:
        run_hook("nonexistent", "claude-code")
    assert raised.value.code == 1
|
||||
|
||||
|
||||
def test_run_hook_invalid_json(tmp_path):
    """Invalid stdin JSON should not crash — falls back to empty dict."""
    stdin = io.StringIO("not valid json")
    with patch("sys.stdin", stdin), patch("mempalace.hooks_cli.STATE_DIR", tmp_path):
        with patch("mempalace.hooks_cli._output") as recorder:
            run_hook("session-start", "claude-code")
            recorder.assert_called_once_with({})
|
||||
@@ -0,0 +1,45 @@
|
||||
"""Tests for mempalace.instructions_cli — instruction text output."""
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from mempalace.instructions_cli import AVAILABLE, INSTRUCTIONS_DIR, run_instructions
|
||||
|
||||
|
||||
def test_run_instructions_valid_name(capsys):
    """Valid name prints the .md file content."""
    name = "init"
    md_path = INSTRUCTIONS_DIR / f"{name}.md"
    expected = md_path.read_text()
    run_instructions(name)
    printed = capsys.readouterr().out
    assert printed.strip() == expected.strip()
|
||||
|
||||
|
||||
def test_run_instructions_all_available(capsys):
    """Every name in AVAILABLE should succeed without error."""
    for instruction_name in AVAILABLE:
        run_instructions(instruction_name)
        printed = capsys.readouterr().out
        assert len(printed) > 0
|
||||
|
||||
|
||||
def test_run_instructions_invalid_name(capsys):
    """Invalid name should sys.exit(1) and print error to stderr."""
    with pytest.raises(SystemExit) as raised:
        run_instructions("nonexistent")
    assert raised.value.code == 1
    stderr = capsys.readouterr().err
    assert "Unknown instructions: nonexistent" in stderr
    assert "Available:" in stderr
|
||||
|
||||
|
||||
def test_run_instructions_missing_md_file(capsys, tmp_path):
    """If the .md file is missing on disk, should sys.exit(1)."""
    empty_dir = patch("mempalace.instructions_cli.INSTRUCTIONS_DIR", tmp_path)
    fake_names = patch("mempalace.instructions_cli.AVAILABLE", ["fakecmd"])
    with empty_dir, fake_names:
        with pytest.raises(SystemExit) as raised:
            run_instructions("fakecmd")
    assert raised.value.code == 1
    stderr = capsys.readouterr().err
    assert "Instructions file not found" in stderr
|
||||
@@ -0,0 +1,105 @@
|
||||
"""Extra knowledge graph tests for seed_from_entity_facts and query_relationship."""
|
||||
|
||||
import pytest
|
||||
|
||||
from mempalace.knowledge_graph import KnowledgeGraph
|
||||
|
||||
|
||||
@pytest.fixture
def kg(tmp_path):
    """Provide a KnowledgeGraph whose database file lives under ``tmp_path``."""
    database = tmp_path / "kg.db"
    return KnowledgeGraph(db_path=str(database))
|
||||
|
||||
|
||||
class TestSeedFromEntityFacts:
    """Checks the triples produced by ``KnowledgeGraph.seed_from_entity_facts``."""

    @staticmethod
    def _outgoing_predicates(kg, name):
        """Return the set of predicates on the outgoing edges of *name*."""
        return {edge["predicate"] for edge in kg.query_entity(name, direction="outgoing")}

    def test_seed_person_with_partner(self, kg):
        """A person with a partner gets married_to and is_partner_of edges."""
        alice = {
            "full_name": "Alice Smith",
            "type": "person",
            "gender": "female",
            "partner": "bob",
            "relationship": "husband",
        }
        kg.seed_from_entity_facts({"alice": alice})
        assert kg.stats()["entities"] >= 1
        predicates = self._outgoing_predicates(kg, "Alice Smith")
        assert "married_to" in predicates
        assert "is_partner_of" in predicates

    def test_seed_child(self, kg):
        """An entity with a parent gets child_of and is_child_of edges."""
        child = {
            "full_name": "Max",
            "type": "person",
            "birthday": "2015-04-01",
            "parent": "alice",
            "relationship": "daughter",
        }
        kg.seed_from_entity_facts({"max": child})
        predicates = self._outgoing_predicates(kg, "Max")
        assert "child_of" in predicates
        assert "is_child_of" in predicates

    def test_seed_sibling(self, kg):
        """An entity with a sibling gets an is_sibling_of edge."""
        sibling = {
            "full_name": "Emma",
            "type": "person",
            "relationship": "brother",
            "sibling": "max",
        }
        kg.seed_from_entity_facts({"emma": sibling})
        assert "is_sibling_of" in self._outgoing_predicates(kg, "Emma")

    def test_seed_dog(self, kg):
        """An animal with an owner gets an is_pet_of edge."""
        pet = {
            "full_name": "Rex",
            "type": "animal",
            "relationship": "dog",
            "owner": "alice",
        }
        kg.seed_from_entity_facts({"rex": pet})
        assert "is_pet_of" in self._outgoing_predicates(kg, "Rex")

    def test_seed_with_interests(self, kg):
        """Each interest becomes a ``loves`` edge to its capitalised name."""
        person = {
            "full_name": "Max",
            "type": "person",
            "interests": ["swimming", "chess"],
        }
        kg.seed_from_entity_facts({"max": person})
        edges = kg.query_entity("Max", direction="outgoing")
        loved = {edge["object"] for edge in edges if edge["predicate"] == "loves"}
        assert "Swimming" in loved
        assert "Chess" in loved

    def test_seed_minimal_facts(self, kg):
        """Facts with no relationships just create entities."""
        kg.seed_from_entity_facts({"bob": {"full_name": "Bob"}})
        entity_count = kg.stats()["entities"]
        assert entity_count >= 1
|
||||
|
||||
|
||||
class TestQueryRelationshipWithTime:
    """Checks ``as_of`` filtering in ``KnowledgeGraph.query_relationship``."""

    def test_query_relationship_with_as_of(self, kg):
        """Only the triple whose validity window covers the as_of date is returned."""
        kg.add_triple("Alice", "works_at", "Acme", valid_from="2020-01-01", valid_to="2024-12-31")
        kg.add_triple("Alice", "works_at", "NewCo", valid_from="2025-01-01")
        matches = kg.query_relationship("works_at", as_of="2023-06-01")
        employers = [match["object"] for match in matches]
        assert "Acme" in employers
        assert "NewCo" not in employers
|
||||
@@ -0,0 +1,719 @@
|
||||
"""Tests for mempalace.layers — Layer0, Layer1, Layer2, Layer3, MemoryStack."""
|
||||
|
||||
import os
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from mempalace.layers import Layer0, Layer1, Layer2, Layer3, MemoryStack
|
||||
|
||||
|
||||
# ── Layer0 — with identity file ─────────────────────────────────────────
|
||||
|
||||
|
||||
def test_layer0_reads_identity_file(tmp_path):
    """Layer0.render returns the text stored in the identity file."""
    identity = tmp_path / "identity.txt"
    identity.write_text("I am Atlas, a personal AI assistant for Alice.")
    rendered = Layer0(identity_path=str(identity)).render()
    for name in ("Atlas", "Alice"):
        assert name in rendered
|
||||
|
||||
|
||||
def test_layer0_caches_text(tmp_path):
    """A second render returns the first text even after the file changes."""
    identity = tmp_path / "identity.txt"
    identity.write_text("Hello world")
    layer0 = Layer0(identity_path=str(identity))
    before_change = layer0.render()
    identity.write_text("Changed content")
    after_change = layer0.render()
    assert before_change == after_change
    assert after_change == "Hello world"
|
||||
|
||||
|
||||
def test_layer0_missing_file_returns_default(tmp_path):
    """A missing identity file renders the default notice that names identity.txt."""
    absent_path = str(tmp_path / "nonexistent.txt")
    rendered = Layer0(identity_path=absent_path).render()
    assert "No identity configured" in rendered
    assert "identity.txt" in rendered
|
||||
|
||||
|
||||
def test_layer0_token_estimate(tmp_path):
    """A 400-character identity is estimated at 100 tokens."""
    identity = tmp_path / "identity.txt"
    identity.write_text("A" * 400)
    layer0 = Layer0(identity_path=str(identity))
    assert layer0.token_estimate() == 100
|
||||
|
||||
|
||||
def test_layer0_token_estimate_empty(tmp_path):
    """An empty identity file is estimated at zero tokens."""
    identity = tmp_path / "identity.txt"
    identity.write_text("")
    estimate = Layer0(identity_path=str(identity)).token_estimate()
    assert estimate == 0
|
||||
|
||||
|
||||
def test_layer0_strips_whitespace(tmp_path):
    """Leading and trailing whitespace in the identity file is removed on render."""
    identity = tmp_path / "identity.txt"
    identity.write_text(" Hello world \n\n")
    rendered = Layer0(identity_path=str(identity)).render()
    assert rendered == "Hello world"
|
||||
|
||||
|
||||
def test_layer0_default_path():
    """Without an explicit path, Layer0 points at ~/.mempalace/identity.txt."""
    default_path = os.path.expanduser("~/.mempalace/identity.txt")
    assert Layer0().path == default_path
|
||||
|
||||
|
||||
# ── Layer1 — mocked chromadb ────────────────────────────────────────────
|
||||
|
||||
|
||||
def _mock_chromadb_for_layer(docs, metas, monkeypatch=None):
|
||||
"""Return a mock PersistentClient whose collection.get returns docs/metas."""
|
||||
mock_col = MagicMock()
|
||||
# First batch returns data, second batch returns empty (end of pagination)
|
||||
mock_col.get.side_effect = [
|
||||
{"documents": docs, "metadatas": metas},
|
||||
{"documents": [], "metadatas": []},
|
||||
]
|
||||
mock_client = MagicMock()
|
||||
mock_client.get_collection.return_value = mock_col
|
||||
return mock_client
|
||||
|
||||
|
||||
def test_layer1_no_palace():
    """Layer1 returns helpful message when no palace exists."""
    palace = "/nonexistent/palace"
    with patch("mempalace.layers.MempalaceConfig") as config_cls:
        config_cls.return_value.palace_path = palace
        story = Layer1(palace_path=palace).generate()
    assert "No palace found" in story or "No memories" in story
|
||||
|
||||
|
||||
def test_layer1_generates_essential_story():
    """Stored drawers are rendered under the ESSENTIAL STORY heading."""
    client = _mock_chromadb_for_layer(
        [
            "Important memory about project decisions",
            "Key architectural choice for the backend",
        ],
        [
            {"room": "decisions", "source_file": "meeting.txt", "importance": 5},
            {"room": "architecture", "source_file": "design.txt", "importance": 4},
        ],
    )

    with (
        patch("mempalace.layers.MempalaceConfig") as config_cls,
        patch("mempalace.layers.chromadb.PersistentClient", return_value=client),
    ):
        config_cls.return_value.palace_path = "/fake"
        story = Layer1(palace_path="/fake").generate()

    assert "ESSENTIAL STORY" in story
    assert "project decisions" in story
|
||||
|
||||
|
||||
def test_layer1_empty_palace():
    """A collection without documents produces the "No memories" message."""
    collection = MagicMock()
    collection.get.return_value = {"documents": [], "metadatas": []}
    client = MagicMock()
    client.get_collection.return_value = collection

    with (
        patch("mempalace.layers.MempalaceConfig") as config_cls,
        patch("mempalace.layers.chromadb.PersistentClient", return_value=client),
    ):
        config_cls.return_value.palace_path = "/fake"
        story = Layer1(palace_path="/fake").generate()

    assert "No memories" in story
|
||||
|
||||
|
||||
def test_layer1_with_wing_filter():
    """A Layer1 built with a wing passes that wing as the ``where`` filter."""
    client = _mock_chromadb_for_layer(
        ["Memory about project X"],
        [{"room": "general", "source_file": "x.txt", "importance": 3}],
    )

    with (
        patch("mempalace.layers.MempalaceConfig") as config_cls,
        patch("mempalace.layers.chromadb.PersistentClient", return_value=client),
    ):
        config_cls.return_value.palace_path = "/fake"
        story = Layer1(palace_path="/fake", wing="project_x").generate()

    assert "ESSENTIAL STORY" in story
    # Verify wing filter was passed
    first_get_call = client.get_collection.return_value.get.call_args_list[0]
    assert first_get_call.kwargs.get("where") == {"wing": "project_x"}
|
||||
|
||||
|
||||
def test_layer1_truncates_long_snippets():
    """A 300-character document is rendered with an ellipsis."""
    client = _mock_chromadb_for_layer(
        ["A" * 300],
        [{"room": "general", "source_file": "long.txt"}],
    )

    with (
        patch("mempalace.layers.MempalaceConfig") as config_cls,
        patch("mempalace.layers.chromadb.PersistentClient", return_value=client),
    ):
        config_cls.return_value.palace_path = "/fake"
        story = Layer1(palace_path="/fake").generate()

    assert "..." in story
|
||||
|
||||
|
||||
def test_layer1_respects_max_chars():
    """L1 stops adding entries once MAX_CHARS is reached."""
    entry_count = 30
    documents = [
        f"Memory number {i} with substantial content padding here" for i in range(entry_count)
    ]
    metadata = [
        {"room": "general", "source_file": f"f{i}.txt", "importance": 5}
        for i in range(entry_count)
    ]
    client = _mock_chromadb_for_layer(documents, metadata)

    with (
        patch("mempalace.layers.MempalaceConfig") as config_cls,
        patch("mempalace.layers.chromadb.PersistentClient", return_value=client),
    ):
        config_cls.return_value.palace_path = "/fake"
        layer1 = Layer1(palace_path="/fake")
        layer1.MAX_CHARS = 200  # Very low cap to trigger truncation
        story = layer1.generate()

    assert "more in L3 search" in story
|
||||
|
||||
|
||||
def test_layer1_importance_from_various_keys():
    """Layer1 tries importance, emotional_weight, weight keys."""
    metadata = [
        {"room": "r", "emotional_weight": 5},
        {"room": "r", "weight": 1},
        {"room": "r"},  # no weight key, defaults to 3
    ]
    client = _mock_chromadb_for_layer(["mem1", "mem2", "mem3"], metadata)

    with (
        patch("mempalace.layers.MempalaceConfig") as config_cls,
        patch("mempalace.layers.chromadb.PersistentClient", return_value=client),
    ):
        config_cls.return_value.palace_path = "/fake"
        story = Layer1(palace_path="/fake").generate()

    assert "ESSENTIAL STORY" in story
|
||||
|
||||
|
||||
def test_layer1_batch_exception_breaks():
    """If col.get raises on a batch, loop breaks gracefully."""
    first_batch = {"documents": ["doc1"], "metadatas": [{"room": "r"}]}
    collection = MagicMock()
    # The second batch fetch raises; the first batch must still be rendered.
    collection.get.side_effect = [first_batch, RuntimeError("batch error")]
    client = MagicMock()
    client.get_collection.return_value = collection

    with (
        patch("mempalace.layers.MempalaceConfig") as config_cls,
        patch("mempalace.layers.chromadb.PersistentClient", return_value=client),
    ):
        config_cls.return_value.palace_path = "/fake"
        story = Layer1(palace_path="/fake").generate()

    assert "ESSENTIAL STORY" in story
|
||||
|
||||
|
||||
# ── Layer2 — mocked chromadb ────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_layer2_no_palace():
    """Layer2.retrieve reports "No palace found" for a nonexistent palace path."""
    palace = "/nonexistent/palace"
    with patch("mempalace.layers.MempalaceConfig") as config_cls:
        config_cls.return_value.palace_path = palace
        retrieved = Layer2(palace_path=palace).retrieve(wing="test")
    assert "No palace found" in retrieved
|
||||
|
||||
|
||||
def test_layer2_retrieve_with_wing():
    """Retrieving by wing renders the matching drawer under ON-DEMAND."""
    collection = MagicMock()
    collection.get.return_value = {
        "documents": ["Some memory about the project"],
        "metadatas": [{"room": "backend", "source_file": "notes.txt"}],
    }
    client = MagicMock()
    client.get_collection.return_value = collection

    with (
        patch("mempalace.layers.MempalaceConfig") as config_cls,
        patch("mempalace.layers.chromadb.PersistentClient", return_value=client),
    ):
        config_cls.return_value.palace_path = "/fake"
        retrieved = Layer2(palace_path="/fake").retrieve(wing="project")

    assert "ON-DEMAND" in retrieved
    assert "memory about the project" in retrieved
|
||||
|
||||
|
||||
def test_layer2_retrieve_with_room():
    """Retrieving by room renders the ON-DEMAND section."""
    collection = MagicMock()
    collection.get.return_value = {
        "documents": ["Backend architecture notes"],
        "metadatas": [{"room": "architecture", "source_file": "arch.txt"}],
    }
    client = MagicMock()
    client.get_collection.return_value = collection

    with (
        patch("mempalace.layers.MempalaceConfig") as config_cls,
        patch("mempalace.layers.chromadb.PersistentClient", return_value=client),
    ):
        config_cls.return_value.palace_path = "/fake"
        retrieved = Layer2(palace_path="/fake").retrieve(room="architecture")

    assert "ON-DEMAND" in retrieved
|
||||
|
||||
|
||||
def test_layer2_retrieve_wing_and_room():
    """Passing both wing and room sends a combined ``$and`` filter to the collection."""
    collection = MagicMock()
    collection.get.return_value = {
        "documents": ["Filtered result"],
        "metadatas": [{"room": "backend", "source_file": "x.txt"}],
    }
    client = MagicMock(**{"get_collection.return_value": collection})

    with patch("mempalace.layers.MempalaceConfig") as cfg_cls:
        cfg_cls.return_value.palace_path = "/fake"
        with patch("mempalace.layers.chromadb.PersistentClient", return_value=client):
            on_demand = Layer2(palace_path="/fake")
            output = on_demand.retrieve(wing="proj", room="backend")

    assert "ON-DEMAND" in output
    # Inspect the keyword arguments of the most recent collection.get() call.
    where_clause = collection.get.call_args.kwargs.get("where", {})
    assert "$and" in where_clause
|
||||
|
||||
|
||||
def test_layer2_retrieve_empty():
    """An empty collection response produces the 'No drawers found' message."""
    collection = MagicMock(**{"get.return_value": {"documents": [], "metadatas": []}})
    client = MagicMock(**{"get_collection.return_value": collection})

    with patch("mempalace.layers.MempalaceConfig") as cfg_cls:
        cfg_cls.return_value.palace_path = "/fake"
        with patch("mempalace.layers.chromadb.PersistentClient", return_value=client):
            on_demand = Layer2(palace_path="/fake")
            output = on_demand.retrieve(wing="missing")

    assert "No drawers found" in output
|
||||
|
||||
|
||||
def test_layer2_retrieve_no_filter():
    """Calling retrieve() without wing or room omits the ``where`` keyword."""
    collection = MagicMock(**{"get.return_value": {"documents": [], "metadatas": []}})
    client = MagicMock(**{"get_collection.return_value": collection})

    with patch("mempalace.layers.MempalaceConfig") as cfg_cls:
        cfg_cls.return_value.palace_path = "/fake"
        with patch("mempalace.layers.chromadb.PersistentClient", return_value=client):
            on_demand = Layer2(palace_path="/fake")
            on_demand.retrieve()

    # No where filter should be passed
    passed_kwargs = collection.get.call_args.kwargs
    assert "where" not in passed_kwargs
|
||||
|
||||
|
||||
def test_layer2_retrieve_error():
    """A RuntimeError raised by collection.get surfaces as a 'Retrieval error' message."""
    collection = MagicMock(**{"get.side_effect": RuntimeError("db error")})
    client = MagicMock(**{"get_collection.return_value": collection})

    with patch("mempalace.layers.MempalaceConfig") as cfg_cls:
        cfg_cls.return_value.palace_path = "/fake"
        with patch("mempalace.layers.chromadb.PersistentClient", return_value=client):
            on_demand = Layer2(palace_path="/fake")
            output = on_demand.retrieve(wing="test")

    assert "Retrieval error" in output
|
||||
|
||||
|
||||
def test_layer2_truncates_long_snippets():
    """A 400-character document shows up with an ellipsis in the retrieve output."""
    long_document = "B" * 400
    collection = MagicMock()
    collection.get.return_value = {
        "documents": [long_document],
        "metadatas": [{"room": "r", "source_file": "s.txt"}],
    }
    client = MagicMock(**{"get_collection.return_value": collection})

    with patch("mempalace.layers.MempalaceConfig") as cfg_cls:
        cfg_cls.return_value.palace_path = "/fake"
        with patch("mempalace.layers.chromadb.PersistentClient", return_value=client):
            on_demand = Layer2(palace_path="/fake")
            output = on_demand.retrieve(wing="test")

    assert "..." in output
|
||||
|
||||
|
||||
# ── Layer3 — mocked chromadb ────────────────────────────────────────────
|
||||
|
||||
|
||||
def _mock_query_results(docs, metas, dists):
|
||||
return {
|
||||
"documents": [docs],
|
||||
"metadatas": [metas],
|
||||
"distances": [dists],
|
||||
}
|
||||
|
||||
|
||||
def test_layer3_no_palace():
    """Layer3.search reports a missing palace when the path does not exist."""
    missing = "/nonexistent/palace"
    with patch("mempalace.layers.MempalaceConfig") as cfg_cls:
        cfg_cls.return_value.palace_path = missing
        output = Layer3(palace_path=missing).search("test query")
    assert "No palace found" in output
|
||||
|
||||
|
||||
def test_layer3_search_raw_no_palace():
    """Layer3.search_raw returns an empty list when the palace path does not exist."""
    missing = "/nonexistent/palace"
    with patch("mempalace.layers.MempalaceConfig") as cfg_cls:
        cfg_cls.return_value.palace_path = missing
        hits = Layer3(palace_path=missing).search_raw("test query")
    assert hits == []
|
||||
|
||||
|
||||
def test_layer3_search_with_results():
    """A single hit at distance 0.2 is rendered with its text and ``sim=0.8``."""
    collection = MagicMock()
    collection.query.return_value = _mock_query_results(
        docs=["Found this important memory"],
        metas=[{"wing": "project", "room": "backend", "source_file": "notes.txt"}],
        dists=[0.2],
    )
    client = MagicMock(**{"get_collection.return_value": collection})

    with patch("mempalace.layers.MempalaceConfig") as cfg_cls:
        cfg_cls.return_value.palace_path = "/fake"
        with patch("mempalace.layers.chromadb.PersistentClient", return_value=client):
            deep = Layer3(palace_path="/fake")
            output = deep.search("important")

    for expected in ("SEARCH RESULTS", "important memory", "sim=0.8"):
        assert expected in output
|
||||
|
||||
|
||||
def test_layer3_search_no_results():
    """An empty query response produces the 'No results found' message."""
    collection = MagicMock()
    collection.query.return_value = _mock_query_results(docs=[], metas=[], dists=[])
    client = MagicMock(**{"get_collection.return_value": collection})

    with patch("mempalace.layers.MempalaceConfig") as cfg_cls:
        cfg_cls.return_value.palace_path = "/fake"
        with patch("mempalace.layers.chromadb.PersistentClient", return_value=client):
            deep = Layer3(palace_path="/fake")
            output = deep.search("nothing")

    assert "No results found" in output
|
||||
|
||||
|
||||
def test_layer3_search_with_wing_filter():
    """Searching with only a wing passes ``where={"wing": ...}`` to collection.query."""
    collection = MagicMock()
    collection.query.return_value = _mock_query_results(
        docs=["result"],
        metas=[{"wing": "proj", "room": "r"}],
        dists=[0.1],
    )
    client = MagicMock(**{"get_collection.return_value": collection})

    with patch("mempalace.layers.MempalaceConfig") as cfg_cls:
        cfg_cls.return_value.palace_path = "/fake"
        with patch("mempalace.layers.chromadb.PersistentClient", return_value=client):
            deep = Layer3(palace_path="/fake")
            deep.search("q", wing="proj")

    passed_kwargs = collection.query.call_args.kwargs
    assert passed_kwargs["where"] == {"wing": "proj"}
|
||||
|
||||
|
||||
def test_layer3_search_with_room_filter():
    """Searching with only a room passes ``where={"room": ...}`` to collection.query."""
    collection = MagicMock()
    collection.query.return_value = _mock_query_results(
        docs=["result"],
        metas=[{"wing": "w", "room": "backend"}],
        dists=[0.1],
    )
    client = MagicMock(**{"get_collection.return_value": collection})

    with patch("mempalace.layers.MempalaceConfig") as cfg_cls:
        cfg_cls.return_value.palace_path = "/fake"
        with patch("mempalace.layers.chromadb.PersistentClient", return_value=client):
            deep = Layer3(palace_path="/fake")
            deep.search("q", room="backend")

    passed_kwargs = collection.query.call_args.kwargs
    assert passed_kwargs["where"] == {"room": "backend"}
|
||||
|
||||
|
||||
def test_layer3_search_with_wing_and_room():
    """Searching with wing and room passes a ``$and`` filter to collection.query."""
    collection = MagicMock()
    collection.query.return_value = _mock_query_results(
        docs=["result"],
        metas=[{"wing": "proj", "room": "backend"}],
        dists=[0.1],
    )
    client = MagicMock(**{"get_collection.return_value": collection})

    with patch("mempalace.layers.MempalaceConfig") as cfg_cls:
        cfg_cls.return_value.palace_path = "/fake"
        with patch("mempalace.layers.chromadb.PersistentClient", return_value=client):
            deep = Layer3(palace_path="/fake")
            deep.search("q", wing="proj", room="backend")

    where_clause = collection.query.call_args.kwargs["where"]
    assert "$and" in where_clause
|
||||
|
||||
|
||||
def test_layer3_search_error():
    """A RuntimeError raised by collection.query surfaces as a 'Search error' message."""
    collection = MagicMock(**{"query.side_effect": RuntimeError("search failed")})
    client = MagicMock(**{"get_collection.return_value": collection})

    with patch("mempalace.layers.MempalaceConfig") as cfg_cls:
        cfg_cls.return_value.palace_path = "/fake"
        with patch("mempalace.layers.chromadb.PersistentClient", return_value=client):
            deep = Layer3(palace_path="/fake")
            output = deep.search("q")

    assert "Search error" in output
|
||||
|
||||
|
||||
def test_layer3_search_truncates_long_docs():
    """A 400-character hit shows up with an ellipsis in the search output."""
    long_document = "C" * 400
    collection = MagicMock()
    collection.query.return_value = _mock_query_results(
        docs=[long_document],
        metas=[{"wing": "w", "room": "r", "source_file": "s.txt"}],
        dists=[0.1],
    )
    client = MagicMock(**{"get_collection.return_value": collection})

    with patch("mempalace.layers.MempalaceConfig") as cfg_cls:
        cfg_cls.return_value.palace_path = "/fake"
        with patch("mempalace.layers.chromadb.PersistentClient", return_value=client):
            deep = Layer3(palace_path="/fake")
            output = deep.search("q")

    assert "..." in output
|
||||
|
||||
|
||||
def test_layer3_search_raw_returns_dicts():
    """search_raw returns one dict per hit with text, wing, similarity and metadata."""
    collection = MagicMock()
    collection.query.return_value = _mock_query_results(
        docs=["doc text"],
        metas=[{"wing": "proj", "room": "backend", "source_file": "f.txt"}],
        dists=[0.3],
    )
    client = MagicMock(**{"get_collection.return_value": collection})

    with patch("mempalace.layers.MempalaceConfig") as cfg_cls:
        cfg_cls.return_value.palace_path = "/fake"
        with patch("mempalace.layers.chromadb.PersistentClient", return_value=client):
            deep = Layer3(palace_path="/fake")
            hits = deep.search_raw("q")

    assert len(hits) == 1
    first = hits[0]
    assert first["text"] == "doc text"
    assert first["wing"] == "proj"
    # Distance 0.3 is expected to be reported as similarity 0.7.
    assert first["similarity"] == 0.7
    assert "metadata" in first
|
||||
|
||||
|
||||
def test_layer3_search_raw_with_filters():
    """search_raw with wing and room passes a ``$and`` filter to collection.query."""
    collection = MagicMock()
    collection.query.return_value = _mock_query_results(
        docs=["doc"],
        metas=[{"wing": "w", "room": "r"}],
        dists=[0.1],
    )
    client = MagicMock(**{"get_collection.return_value": collection})

    with patch("mempalace.layers.MempalaceConfig") as cfg_cls:
        cfg_cls.return_value.palace_path = "/fake"
        with patch("mempalace.layers.chromadb.PersistentClient", return_value=client):
            deep = Layer3(palace_path="/fake")
            deep.search_raw("q", wing="w", room="r")

    where_clause = collection.query.call_args.kwargs["where"]
    assert "$and" in where_clause
|
||||
|
||||
|
||||
def test_layer3_search_raw_error():
    """search_raw returns an empty list when collection.query raises RuntimeError."""
    collection = MagicMock(**{"query.side_effect": RuntimeError("fail")})
    client = MagicMock(**{"get_collection.return_value": collection})

    with patch("mempalace.layers.MempalaceConfig") as cfg_cls:
        cfg_cls.return_value.palace_path = "/fake"
        with patch("mempalace.layers.chromadb.PersistentClient", return_value=client):
            deep = Layer3(palace_path="/fake")
            hits = deep.search_raw("q")

    assert hits == []
|
||||
|
||||
|
||||
# ── MemoryStack ─────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_memory_stack_wake_up(tmp_path):
    """wake_up includes the identity text and a no-palace / no-memories notice."""
    identity = tmp_path / "identity.txt"
    identity.write_text("I am Atlas.")

    with patch("mempalace.layers.MempalaceConfig") as cfg_cls:
        cfg_cls.return_value.palace_path = "/nonexistent"
        stack = MemoryStack(palace_path="/nonexistent", identity_path=str(identity))
        woken = stack.wake_up()

    assert "Atlas" in woken
    # L1 will say no palace found
    assert any(notice in woken for notice in ("No palace", "No memories"))
|
||||
|
||||
|
||||
def test_memory_stack_wake_up_with_wing(tmp_path):
    """wake_up(wing=...) records the wing on ``stack.l1`` and still returns the identity."""
    identity = tmp_path / "identity.txt"
    identity.write_text("I am Atlas.")

    with patch("mempalace.layers.MempalaceConfig") as cfg_cls:
        cfg_cls.return_value.palace_path = "/nonexistent"
        stack = MemoryStack(palace_path="/nonexistent", identity_path=str(identity))
        woken = stack.wake_up(wing="my_project")

    assert stack.l1.wing == "my_project"
    assert "Atlas" in woken
|
||||
|
||||
|
||||
def test_memory_stack_recall(tmp_path):
    """recall on a nonexistent palace returns the 'No palace found' message."""
    identity = tmp_path / "identity.txt"
    identity.write_text("I am Atlas.")

    with patch("mempalace.layers.MempalaceConfig") as cfg_cls:
        cfg_cls.return_value.palace_path = "/nonexistent"
        stack = MemoryStack(palace_path="/nonexistent", identity_path=str(identity))
        recalled = stack.recall(wing="test")

    assert "No palace found" in recalled
|
||||
|
||||
|
||||
def test_memory_stack_search(tmp_path):
    """search on a nonexistent palace returns the 'No palace found' message."""
    identity = tmp_path / "identity.txt"
    identity.write_text("I am Atlas.")

    with patch("mempalace.layers.MempalaceConfig") as cfg_cls:
        cfg_cls.return_value.palace_path = "/nonexistent"
        stack = MemoryStack(palace_path="/nonexistent", identity_path=str(identity))
        found = stack.search("test query")

    assert "No palace found" in found
|
||||
|
||||
|
||||
def test_memory_stack_status(tmp_path):
    """status on a nonexistent palace reports the path, zero drawers and all four layers."""
    identity = tmp_path / "identity.txt"
    identity.write_text("I am Atlas.")

    with patch("mempalace.layers.MempalaceConfig") as cfg_cls:
        cfg_cls.return_value.palace_path = "/nonexistent"
        stack = MemoryStack(palace_path="/nonexistent", identity_path=str(identity))
        report = stack.status()

    assert report["palace_path"] == "/nonexistent"
    assert report["total_drawers"] == 0
    for layer_key in ("L0_identity", "L1_essential", "L2_on_demand", "L3_deep_search"):
        assert layer_key in report
|
||||
|
||||
|
||||
def test_memory_stack_status_with_palace(tmp_path):
    """status reports the mocked collection count and that the identity file exists."""
    identity = tmp_path / "identity.txt"
    identity.write_text("I am Atlas.")

    collection = MagicMock(**{"count.return_value": 42})
    client = MagicMock(**{"get_collection.return_value": collection})

    with patch("mempalace.layers.MempalaceConfig") as cfg_cls:
        cfg_cls.return_value.palace_path = "/fake"
        with patch("mempalace.layers.chromadb.PersistentClient", return_value=client):
            stack = MemoryStack(palace_path="/fake", identity_path=str(identity))
            report = stack.status()

    assert report["total_drawers"] == 42
    assert report["L0_identity"]["exists"] is True
|
||||
+45
-39
@@ -9,25 +9,26 @@ via monkeypatch to avoid touching real data.
|
||||
import json
|
||||
|
||||
|
||||
def _patch_mcp_server(monkeypatch, config, palace_path, kg):
|
||||
def _patch_mcp_server(monkeypatch, config, kg):
|
||||
"""Patch the mcp_server module globals to use test fixtures."""
|
||||
from mempalace import mcp_server
|
||||
|
||||
assert getattr(config, "palace_path", None) == palace_path, (
|
||||
f"config.palace_path ({getattr(config, 'palace_path', None)!r}) does not match palace_path fixture ({palace_path!r})"
|
||||
)
|
||||
monkeypatch.setattr(mcp_server, "_config", config)
|
||||
monkeypatch.setattr(mcp_server, "_kg", kg)
|
||||
|
||||
|
||||
def _get_collection(palace_path, create=False):
|
||||
"""Helper to get collection from test palace."""
|
||||
"""Helper to get collection from test palace.
|
||||
|
||||
Returns (client, collection) so callers can clean up the client
|
||||
when they are done.
|
||||
"""
|
||||
import chromadb
|
||||
|
||||
client = chromadb.PersistentClient(path=palace_path)
|
||||
if create:
|
||||
return client.get_or_create_collection("mempalace_drawers")
|
||||
return client.get_collection("mempalace_drawers")
|
||||
return client, client.get_or_create_collection("mempalace_drawers")
|
||||
return client, client.get_collection("mempalace_drawers")
|
||||
|
||||
|
||||
# ── Protocol Layer ──────────────────────────────────────────────────────
|
||||
@@ -77,11 +78,12 @@ class TestHandleRequest:
|
||||
assert resp["error"]["code"] == -32601
|
||||
|
||||
def test_tools_call_dispatches(self, monkeypatch, config, palace_path, seeded_kg):
|
||||
_patch_mcp_server(monkeypatch, config, palace_path, seeded_kg)
|
||||
_patch_mcp_server(monkeypatch, config, seeded_kg)
|
||||
from mempalace.mcp_server import handle_request
|
||||
|
||||
# Create a collection so status works
|
||||
_get_collection(palace_path, create=True)
|
||||
_client, _col = _get_collection(palace_path, create=True)
|
||||
del _client
|
||||
|
||||
resp = handle_request(
|
||||
{
|
||||
@@ -100,8 +102,9 @@ class TestHandleRequest:
|
||||
|
||||
class TestReadTools:
|
||||
def test_status_empty_palace(self, monkeypatch, config, palace_path, kg):
|
||||
_patch_mcp_server(monkeypatch, config, palace_path, kg)
|
||||
_get_collection(palace_path, create=True)
|
||||
_patch_mcp_server(monkeypatch, config, kg)
|
||||
_client, _col = _get_collection(palace_path, create=True)
|
||||
del _client
|
||||
from mempalace.mcp_server import tool_status
|
||||
|
||||
result = tool_status()
|
||||
@@ -109,7 +112,7 @@ class TestReadTools:
|
||||
assert result["wings"] == {}
|
||||
|
||||
def test_status_with_data(self, monkeypatch, config, palace_path, seeded_collection, kg):
|
||||
_patch_mcp_server(monkeypatch, config, palace_path, kg)
|
||||
_patch_mcp_server(monkeypatch, config, kg)
|
||||
from mempalace.mcp_server import tool_status
|
||||
|
||||
result = tool_status()
|
||||
@@ -118,7 +121,7 @@ class TestReadTools:
|
||||
assert "notes" in result["wings"]
|
||||
|
||||
def test_list_wings(self, monkeypatch, config, palace_path, seeded_collection, kg):
|
||||
_patch_mcp_server(monkeypatch, config, palace_path, kg)
|
||||
_patch_mcp_server(monkeypatch, config, kg)
|
||||
from mempalace.mcp_server import tool_list_wings
|
||||
|
||||
result = tool_list_wings()
|
||||
@@ -126,7 +129,7 @@ class TestReadTools:
|
||||
assert result["wings"]["notes"] == 1
|
||||
|
||||
def test_list_rooms_all(self, monkeypatch, config, palace_path, seeded_collection, kg):
|
||||
_patch_mcp_server(monkeypatch, config, palace_path, kg)
|
||||
_patch_mcp_server(monkeypatch, config, kg)
|
||||
from mempalace.mcp_server import tool_list_rooms
|
||||
|
||||
result = tool_list_rooms()
|
||||
@@ -135,7 +138,7 @@ class TestReadTools:
|
||||
assert "planning" in result["rooms"]
|
||||
|
||||
def test_list_rooms_filtered(self, monkeypatch, config, palace_path, seeded_collection, kg):
|
||||
_patch_mcp_server(monkeypatch, config, palace_path, kg)
|
||||
_patch_mcp_server(monkeypatch, config, kg)
|
||||
from mempalace.mcp_server import tool_list_rooms
|
||||
|
||||
result = tool_list_rooms(wing="project")
|
||||
@@ -143,7 +146,7 @@ class TestReadTools:
|
||||
assert "planning" not in result["rooms"]
|
||||
|
||||
def test_get_taxonomy(self, monkeypatch, config, palace_path, seeded_collection, kg):
|
||||
_patch_mcp_server(monkeypatch, config, palace_path, kg)
|
||||
_patch_mcp_server(monkeypatch, config, kg)
|
||||
from mempalace.mcp_server import tool_get_taxonomy
|
||||
|
||||
result = tool_get_taxonomy()
|
||||
@@ -152,8 +155,7 @@ class TestReadTools:
|
||||
assert result["taxonomy"]["notes"]["planning"] == 1
|
||||
|
||||
def test_no_palace_returns_error(self, monkeypatch, config, kg):
|
||||
config._file_config["palace_path"] = "/nonexistent/path"
|
||||
_patch_mcp_server(monkeypatch, config, "/nonexistent/path", kg)
|
||||
_patch_mcp_server(monkeypatch, config, kg)
|
||||
from mempalace.mcp_server import tool_status
|
||||
|
||||
result = tool_status()
|
||||
@@ -165,7 +167,7 @@ class TestReadTools:
|
||||
|
||||
class TestSearchTool:
|
||||
def test_search_basic(self, monkeypatch, config, palace_path, seeded_collection, kg):
|
||||
_patch_mcp_server(monkeypatch, config, palace_path, kg)
|
||||
_patch_mcp_server(monkeypatch, config, kg)
|
||||
from mempalace.mcp_server import tool_search
|
||||
|
||||
result = tool_search(query="JWT authentication tokens")
|
||||
@@ -176,14 +178,14 @@ class TestSearchTool:
|
||||
assert "JWT" in top["text"] or "authentication" in top["text"].lower()
|
||||
|
||||
def test_search_with_wing_filter(self, monkeypatch, config, palace_path, seeded_collection, kg):
|
||||
_patch_mcp_server(monkeypatch, config, palace_path, kg)
|
||||
_patch_mcp_server(monkeypatch, config, kg)
|
||||
from mempalace.mcp_server import tool_search
|
||||
|
||||
result = tool_search(query="planning", wing="notes")
|
||||
assert all(r["wing"] == "notes" for r in result["results"])
|
||||
|
||||
def test_search_with_room_filter(self, monkeypatch, config, palace_path, seeded_collection, kg):
|
||||
_patch_mcp_server(monkeypatch, config, palace_path, kg)
|
||||
_patch_mcp_server(monkeypatch, config, kg)
|
||||
from mempalace.mcp_server import tool_search
|
||||
|
||||
result = tool_search(query="database", room="backend")
|
||||
@@ -195,8 +197,9 @@ class TestSearchTool:
|
||||
|
||||
class TestWriteTools:
|
||||
def test_add_drawer(self, monkeypatch, config, palace_path, kg):
|
||||
_patch_mcp_server(monkeypatch, config, palace_path, kg)
|
||||
_get_collection(palace_path, create=True)
|
||||
_patch_mcp_server(monkeypatch, config, kg)
|
||||
_client, _col = _get_collection(palace_path, create=True)
|
||||
del _client
|
||||
from mempalace.mcp_server import tool_add_drawer
|
||||
|
||||
result = tool_add_drawer(
|
||||
@@ -210,8 +213,9 @@ class TestWriteTools:
|
||||
assert result["drawer_id"].startswith("drawer_test_wing_test_room_")
|
||||
|
||||
def test_add_drawer_duplicate_detection(self, monkeypatch, config, palace_path, kg):
|
||||
_patch_mcp_server(monkeypatch, config, palace_path, kg)
|
||||
_get_collection(palace_path, create=True)
|
||||
_patch_mcp_server(monkeypatch, config, kg)
|
||||
_client, _col = _get_collection(palace_path, create=True)
|
||||
del _client
|
||||
from mempalace.mcp_server import tool_add_drawer
|
||||
|
||||
content = "This is a unique test memory about Rust ownership and borrowing."
|
||||
@@ -219,11 +223,11 @@ class TestWriteTools:
|
||||
assert result1["success"] is True
|
||||
|
||||
result2 = tool_add_drawer(wing="w", room="r", content=content)
|
||||
assert result2["success"] is False
|
||||
assert result2["reason"] == "duplicate"
|
||||
assert result2["success"] is True
|
||||
assert result2["reason"] == "already_exists"
|
||||
|
||||
def test_delete_drawer(self, monkeypatch, config, palace_path, seeded_collection, kg):
|
||||
_patch_mcp_server(monkeypatch, config, palace_path, kg)
|
||||
_patch_mcp_server(monkeypatch, config, kg)
|
||||
from mempalace.mcp_server import tool_delete_drawer
|
||||
|
||||
result = tool_delete_drawer("drawer_proj_backend_aaa")
|
||||
@@ -231,14 +235,14 @@ class TestWriteTools:
|
||||
assert seeded_collection.count() == 3
|
||||
|
||||
def test_delete_drawer_not_found(self, monkeypatch, config, palace_path, seeded_collection, kg):
|
||||
_patch_mcp_server(monkeypatch, config, palace_path, kg)
|
||||
_patch_mcp_server(monkeypatch, config, kg)
|
||||
from mempalace.mcp_server import tool_delete_drawer
|
||||
|
||||
result = tool_delete_drawer("nonexistent_drawer")
|
||||
assert result["success"] is False
|
||||
|
||||
def test_check_duplicate(self, monkeypatch, config, palace_path, seeded_collection, kg):
|
||||
_patch_mcp_server(monkeypatch, config, palace_path, kg)
|
||||
_patch_mcp_server(monkeypatch, config, kg)
|
||||
from mempalace.mcp_server import tool_check_duplicate
|
||||
|
||||
# Exact match text from seeded_collection should be flagged
|
||||
@@ -262,7 +266,7 @@ class TestWriteTools:
|
||||
|
||||
class TestKGTools:
|
||||
def test_kg_add(self, monkeypatch, config, palace_path, kg):
|
||||
_patch_mcp_server(monkeypatch, config, palace_path, kg)
|
||||
_patch_mcp_server(monkeypatch, config, kg)
|
||||
from mempalace.mcp_server import tool_kg_add
|
||||
|
||||
result = tool_kg_add(
|
||||
@@ -274,14 +278,14 @@ class TestKGTools:
|
||||
assert result["success"] is True
|
||||
|
||||
def test_kg_query(self, monkeypatch, config, palace_path, seeded_kg):
|
||||
_patch_mcp_server(monkeypatch, config, palace_path, seeded_kg)
|
||||
_patch_mcp_server(monkeypatch, config, seeded_kg)
|
||||
from mempalace.mcp_server import tool_kg_query
|
||||
|
||||
result = tool_kg_query(entity="Max")
|
||||
assert result["count"] > 0
|
||||
|
||||
def test_kg_invalidate(self, monkeypatch, config, palace_path, seeded_kg):
|
||||
_patch_mcp_server(monkeypatch, config, palace_path, seeded_kg)
|
||||
_patch_mcp_server(monkeypatch, config, seeded_kg)
|
||||
from mempalace.mcp_server import tool_kg_invalidate
|
||||
|
||||
result = tool_kg_invalidate(
|
||||
@@ -293,14 +297,14 @@ class TestKGTools:
|
||||
assert result["success"] is True
|
||||
|
||||
def test_kg_timeline(self, monkeypatch, config, palace_path, seeded_kg):
|
||||
_patch_mcp_server(monkeypatch, config, palace_path, seeded_kg)
|
||||
_patch_mcp_server(monkeypatch, config, seeded_kg)
|
||||
from mempalace.mcp_server import tool_kg_timeline
|
||||
|
||||
result = tool_kg_timeline(entity="Alice")
|
||||
assert result["count"] > 0
|
||||
|
||||
def test_kg_stats(self, monkeypatch, config, palace_path, seeded_kg):
|
||||
_patch_mcp_server(monkeypatch, config, palace_path, seeded_kg)
|
||||
_patch_mcp_server(monkeypatch, config, seeded_kg)
|
||||
from mempalace.mcp_server import tool_kg_stats
|
||||
|
||||
result = tool_kg_stats()
|
||||
@@ -312,8 +316,9 @@ class TestKGTools:
|
||||
|
||||
class TestDiaryTools:
|
||||
def test_diary_write_and_read(self, monkeypatch, config, palace_path, kg):
|
||||
_patch_mcp_server(monkeypatch, config, palace_path, kg)
|
||||
_get_collection(palace_path, create=True)
|
||||
_patch_mcp_server(monkeypatch, config, kg)
|
||||
_client, _col = _get_collection(palace_path, create=True)
|
||||
del _client
|
||||
from mempalace.mcp_server import tool_diary_write, tool_diary_read
|
||||
|
||||
w = tool_diary_write(
|
||||
@@ -330,8 +335,9 @@ class TestDiaryTools:
|
||||
assert "authentication" in r["entries"][0]["content"]
|
||||
|
||||
def test_diary_read_empty(self, monkeypatch, config, palace_path, kg):
|
||||
_patch_mcp_server(monkeypatch, config, palace_path, kg)
|
||||
_get_collection(palace_path, create=True)
|
||||
_patch_mcp_server(monkeypatch, config, kg)
|
||||
_client, _col = _get_collection(palace_path, create=True)
|
||||
del _client
|
||||
from mempalace.mcp_server import tool_diary_read
|
||||
|
||||
r = tool_diary_read(agent_name="Nobody")
|
||||
|
||||
+1
-1
@@ -47,7 +47,7 @@ def test_project_mining():
|
||||
col = client.get_collection("mempalace_drawers")
|
||||
assert col.count() > 0
|
||||
finally:
|
||||
shutil.rmtree(tmpdir)
|
||||
shutil.rmtree(tmpdir, ignore_errors=True)
|
||||
|
||||
|
||||
def test_scan_project_respects_gitignore():
|
||||
|
||||
+490
-20
@@ -1,31 +1,501 @@
|
||||
import os
|
||||
import json
|
||||
import tempfile
|
||||
from mempalace.normalize import normalize
|
||||
from unittest.mock import patch
|
||||
|
||||
from mempalace.normalize import (
|
||||
_extract_content,
|
||||
_messages_to_transcript,
|
||||
_try_chatgpt_json,
|
||||
_try_claude_ai_json,
|
||||
_try_claude_code_jsonl,
|
||||
_try_codex_jsonl,
|
||||
_try_normalize_json,
|
||||
_try_slack_json,
|
||||
normalize,
|
||||
)
|
||||
|
||||
|
||||
def test_plain_text():
|
||||
f = tempfile.NamedTemporaryFile(mode="w", suffix=".txt", delete=False)
|
||||
f.write("Hello world\nSecond line\n")
|
||||
f.close()
|
||||
result = normalize(f.name)
|
||||
# ── normalize() top-level ──────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_plain_text(tmp_path):
|
||||
f = tmp_path / "plain.txt"
|
||||
f.write_text("Hello world\nSecond line\n")
|
||||
result = normalize(str(f))
|
||||
assert "Hello world" in result
|
||||
os.unlink(f.name)
|
||||
|
||||
|
||||
def test_claude_json():
|
||||
def test_claude_json(tmp_path):
|
||||
data = [{"role": "user", "content": "Hi"}, {"role": "assistant", "content": "Hello"}]
|
||||
f = tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False)
|
||||
json.dump(data, f)
|
||||
f.close()
|
||||
result = normalize(f.name)
|
||||
f = tmp_path / "claude.json"
|
||||
f.write_text(json.dumps(data))
|
||||
result = normalize(str(f))
|
||||
assert "Hi" in result
|
||||
os.unlink(f.name)
|
||||
|
||||
|
||||
def test_empty():
|
||||
f = tempfile.NamedTemporaryFile(mode="w", suffix=".txt", delete=False)
|
||||
f.close()
|
||||
result = normalize(f.name)
|
||||
def test_empty(tmp_path):
|
||||
f = tmp_path / "empty.txt"
|
||||
f.write_text("")
|
||||
result = normalize(str(f))
|
||||
assert result.strip() == ""
|
||||
os.unlink(f.name)
|
||||
|
||||
|
||||
def test_normalize_io_error():
    """normalize raises IOError for an unreadable file.

    The failure path raises AssertionError explicitly rather than using
    ``assert False``, which is removed when Python runs with ``-O`` and would
    let the test pass even if no exception was raised.
    """
    try:
        normalize("/nonexistent/path/file.txt")
    except IOError as e:
        assert "Could not read" in str(e)
    else:
        # Reached only when normalize() returned normally.
        raise AssertionError("Should have raised")
|
||||
|
||||
|
||||
def test_normalize_already_has_markers(tmp_path):
    """Content that already has three '>' lines is returned unchanged."""
    expected = "".join(f"> question {i}\nanswer {i}\n" for i in (1, 2, 3))
    source = tmp_path / "markers.txt"
    source.write_text(expected)
    assert normalize(str(source)) == expected
|
||||
|
||||
|
||||
def test_normalize_json_content_detected_by_brace(tmp_path):
    """A .txt file whose content is a JSON message list is still parsed as chat."""
    messages = [
        {"role": "user", "content": "Hey"},
        {"role": "assistant", "content": "Hi there"},
    ]
    source = tmp_path / "chat.txt"
    source.write_text(json.dumps(messages))
    normalized = normalize(str(source))
    assert "Hey" in normalized
|
||||
|
||||
|
||||
def test_normalize_whitespace_only(tmp_path):
    """A whitespace-only file normalizes to text that strips to the empty string."""
    source = tmp_path / "ws.txt"
    source.write_text(" \n \n ")
    normalized = normalize(str(source))
    assert normalized.strip() == ""
|
||||
|
||||
|
||||
# ── _extract_content ───────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_extract_content_string():
    """A plain string is returned as-is."""
    extracted = _extract_content("hello")
    assert extracted == "hello"
|
||||
|
||||
|
||||
def test_extract_content_list_of_strings():
    """A list of strings is joined with single spaces."""
    extracted = _extract_content(["hello", "world"])
    assert extracted == "hello world"
|
||||
|
||||
|
||||
def test_extract_content_list_of_blocks():
    """Only the text block contributes; the image block is ignored."""
    text_block = {"type": "text", "text": "hello"}
    image_block = {"type": "image", "url": "x"}
    extracted = _extract_content([text_block, image_block])
    assert extracted == "hello"
|
||||
|
||||
|
||||
def test_extract_content_dict():
    """A dict with a ``text`` key yields that text."""
    extracted = _extract_content({"text": "hello"})
    assert extracted == "hello"
|
||||
|
||||
|
||||
def test_extract_content_none():
    """None yields the empty string."""
    extracted = _extract_content(None)
    assert extracted == ""
|
||||
|
||||
|
||||
def test_extract_content_mixed_list():
    """Plain strings and text blocks in one list are joined in order."""
    mixed = ["plain", {"type": "text", "text": "block"}]
    extracted = _extract_content(mixed)
    assert extracted == "plain block"
|
||||
|
||||
|
||||
# ── _try_claude_code_jsonl ─────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_claude_code_jsonl_valid():
    """A human/assistant JSONL pair becomes a transcript with a '>'-prefixed question."""
    entries = [
        {"type": "human", "message": {"content": "What is X?"}},
        {"type": "assistant", "message": {"content": "X is Y."}},
    ]
    payload = "\n".join(json.dumps(entry) for entry in entries)
    transcript = _try_claude_code_jsonl(payload)
    assert transcript is not None
    assert "> What is X?" in transcript
    assert "X is Y." in transcript
|
||||
|
||||
|
||||
def test_claude_code_jsonl_user_type():
    """Entries typed "user" should be treated as the questioning side."""
    entries = [
        {"type": "user", "message": {"content": "Q"}},
        {"type": "assistant", "message": {"content": "A"}},
    ]
    transcript = _try_claude_code_jsonl("\n".join(json.dumps(entry) for entry in entries))
    assert transcript is not None
    assert "> Q" in transcript
|
||||
|
||||
|
||||
def test_claude_code_jsonl_too_few_messages():
    """A single message is expected to be rejected (None)."""
    lone_entry = json.dumps({"type": "human", "message": {"content": "only one"}})
    transcript = _try_claude_code_jsonl("\n".join([lone_entry]))
    assert transcript is None
|
||||
|
||||
|
||||
def test_claude_code_jsonl_invalid_json_lines():
    """A non-JSON line ahead of valid entries should not prevent parsing."""
    valid_entries = [
        {"type": "human", "message": {"content": "Q"}},
        {"type": "assistant", "message": {"content": "A"}},
    ]
    raw_lines = ["not json"] + [json.dumps(entry) for entry in valid_entries]
    transcript = _try_claude_code_jsonl("\n".join(raw_lines))
    assert transcript is not None
|
||||
|
||||
|
||||
def test_claude_code_jsonl_non_dict_entries():
    """A JSON line that is not an object should not prevent parsing."""
    payloads = [
        [1, 2, 3],
        {"type": "human", "message": {"content": "Q"}},
        {"type": "assistant", "message": {"content": "A"}},
    ]
    transcript = _try_claude_code_jsonl("\n".join(json.dumps(item) for item in payloads))
    assert transcript is not None
|
||||
|
||||
|
||||
# ── _try_codex_jsonl ───────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_codex_jsonl_valid():
    """session_meta plus user/agent event messages should yield a transcript."""
    records = [
        {"type": "session_meta", "payload": {}},
        {"type": "event_msg", "payload": {"type": "user_message", "message": "Q"}},
        {"type": "event_msg", "payload": {"type": "agent_message", "message": "A"}},
    ]
    transcript = _try_codex_jsonl("\n".join(json.dumps(rec) for rec in records))
    assert transcript is not None
    assert "> Q" in transcript
|
||||
|
||||
|
||||
def test_codex_jsonl_no_session_meta():
    """Without session_meta, codex parser returns None."""
    records = [
        {"type": "event_msg", "payload": {"type": "user_message", "message": "Q"}},
        {"type": "event_msg", "payload": {"type": "agent_message", "message": "A"}},
    ]
    transcript = _try_codex_jsonl("\n".join(json.dumps(rec) for rec in records))
    assert transcript is None
|
||||
|
||||
|
||||
def test_codex_jsonl_skips_non_event_msg():
    """A response_item record must not show up ahead of the first real question."""
    records = [
        {"type": "session_meta"},
        {"type": "response_item", "payload": {"type": "user_message", "message": "X"}},
        {"type": "event_msg", "payload": {"type": "user_message", "message": "Q"}},
        {"type": "event_msg", "payload": {"type": "agent_message", "message": "A"}},
    ]
    transcript = _try_codex_jsonl("\n".join(json.dumps(rec) for rec in records))
    assert transcript is not None
    # Everything before the "> Q" marker (same text as split(...)[0]).
    before_question = transcript.partition("> Q")[0]
    assert "X" not in before_question
|
||||
|
||||
|
||||
def test_codex_jsonl_non_string_message():
    """A non-string message value should not prevent parsing the remaining events."""
    records = [
        {"type": "session_meta"},
        {"type": "event_msg", "payload": {"type": "user_message", "message": 123}},
        {"type": "event_msg", "payload": {"type": "user_message", "message": "Q"}},
        {"type": "event_msg", "payload": {"type": "agent_message", "message": "A"}},
    ]
    transcript = _try_codex_jsonl("\n".join(json.dumps(rec) for rec in records))
    assert transcript is not None
|
||||
|
||||
|
||||
def test_codex_jsonl_empty_text_skipped():
    """A blank user message should not prevent parsing the remaining events."""
    records = [
        {"type": "session_meta"},
        {"type": "event_msg", "payload": {"type": "user_message", "message": " "}},
        {"type": "event_msg", "payload": {"type": "user_message", "message": "Q"}},
        {"type": "event_msg", "payload": {"type": "agent_message", "message": "A"}},
    ]
    transcript = _try_codex_jsonl("\n".join(json.dumps(rec) for rec in records))
    assert transcript is not None
|
||||
|
||||
|
||||
def test_codex_jsonl_payload_not_dict():
    """An event whose payload is a string should not prevent parsing the rest."""
    records = [
        {"type": "session_meta"},
        {"type": "event_msg", "payload": "not a dict"},
        {"type": "event_msg", "payload": {"type": "user_message", "message": "Q"}},
        {"type": "event_msg", "payload": {"type": "agent_message", "message": "A"}},
    ]
    transcript = _try_codex_jsonl("\n".join(json.dumps(rec) for rec in records))
    assert transcript is not None
|
||||
|
||||
|
||||
# ── _try_claude_ai_json ───────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_claude_ai_flat_messages():
    """A flat list of role/content dicts should yield a transcript."""
    user_msg = {"role": "user", "content": "Hello"}
    assistant_msg = {"role": "assistant", "content": "Hi there"}
    transcript = _try_claude_ai_json([user_msg, assistant_msg])
    assert transcript is not None
    assert "> Hello" in transcript
|
||||
|
||||
|
||||
def test_claude_ai_dict_with_messages_key():
    """A dict wrapping the conversation under "messages" should be accepted."""
    conversation = [
        {"role": "user", "content": "Hello"},
        {"role": "assistant", "content": "Hi"},
    ]
    transcript = _try_claude_ai_json({"messages": conversation})
    assert transcript is not None
|
||||
|
||||
|
||||
def test_claude_ai_privacy_export():
    """A list of conversations carrying "chat_messages" should be accepted."""
    conversation = {
        "chat_messages": [
            {"role": "human", "content": "Q1"},
            {"role": "ai", "content": "A1"},
        ]
    }
    transcript = _try_claude_ai_json([conversation])
    assert transcript is not None
    assert "> Q1" in transcript
|
||||
|
||||
|
||||
def test_claude_ai_not_a_list():
    """A bare string is expected to be rejected (None)."""
    assert _try_claude_ai_json("not a list") is None
|
||||
|
||||
|
||||
def test_claude_ai_too_few_messages():
    """A single message is expected to be rejected (None)."""
    only_message = {"role": "user", "content": "Hello"}
    transcript = _try_claude_ai_json([only_message])
    assert transcript is None
|
||||
|
||||
|
||||
def test_claude_ai_dict_with_chat_messages_key():
    """A dict wrapping the conversation under "chat_messages" should be accepted."""
    conversation = [
        {"role": "user", "content": "Hello"},
        {"role": "assistant", "content": "World"},
    ]
    transcript = _try_claude_ai_json({"chat_messages": conversation})
    assert transcript is not None
|
||||
|
||||
|
||||
def test_claude_ai_privacy_export_non_dict_items():
    """Non-dict items in privacy export are skipped."""
    conversation = {
        "chat_messages": [
            "not a dict",
            {"role": "user", "content": "Q"},
            {"role": "assistant", "content": "A"},
        ]
    }
    export = [conversation, "not a convo"]
    transcript = _try_claude_ai_json(export)
    assert transcript is not None
|
||||
|
||||
|
||||
# ── _try_chatgpt_json ─────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_chatgpt_json_valid():
    """A synthetic root followed by a user and an assistant node should yield a transcript."""
    root_node = {"parent": None, "message": None, "children": ["msg1"]}
    user_node = {
        "parent": "root",
        "message": {
            "author": {"role": "user"},
            "content": {"parts": ["Hello ChatGPT"]},
        },
        "children": ["msg2"],
    }
    assistant_node = {
        "parent": "msg1",
        "message": {
            "author": {"role": "assistant"},
            "content": {"parts": ["Hello! How can I help?"]},
        },
        "children": [],
    }
    export = {"mapping": {"root": root_node, "msg1": user_node, "msg2": assistant_node}}
    transcript = _try_chatgpt_json(export)
    assert transcript is not None
    assert "> Hello ChatGPT" in transcript
|
||||
|
||||
|
||||
def test_chatgpt_json_no_mapping():
    """A dict without a "mapping" key is expected to be rejected (None)."""
    assert _try_chatgpt_json({"data": []}) is None
|
||||
|
||||
|
||||
def test_chatgpt_json_not_dict():
    """A list input is expected to be rejected (None)."""
    assert _try_chatgpt_json([1, 2, 3]) is None
|
||||
|
||||
|
||||
def test_chatgpt_json_fallback_root():
    """Root node has a message (no synthetic root), uses fallback."""
    system_root = {
        "parent": None,
        "message": {
            "author": {"role": "system"},
            "content": {"parts": ["system prompt"]},
        },
        "children": ["msg1"],
    }
    user_node = {
        "parent": "root",
        "message": {
            "author": {"role": "user"},
            "content": {"parts": ["Hello"]},
        },
        "children": ["msg2"],
    }
    assistant_node = {
        "parent": "msg1",
        "message": {
            "author": {"role": "assistant"},
            "content": {"parts": ["Hi there"]},
        },
        "children": [],
    }
    export = {"mapping": {"root": system_root, "msg1": user_node, "msg2": assistant_node}}
    transcript = _try_chatgpt_json(export)
    assert transcript is not None
|
||||
|
||||
|
||||
def test_chatgpt_json_too_few_messages():
    """A mapping holding a single real message is expected to be rejected (None)."""
    root_node = {"parent": None, "message": None, "children": ["msg1"]}
    lone_node = {
        "parent": "root",
        "message": {
            "author": {"role": "user"},
            "content": {"parts": ["Only one"]},
        },
        "children": [],
    }
    transcript = _try_chatgpt_json({"mapping": {"root": root_node, "msg1": lone_node}})
    assert transcript is None
|
||||
|
||||
|
||||
# ── _try_slack_json ────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_slack_json_valid():
    """Two Slack messages from different users should yield a transcript."""
    export = [
        {"type": "message", "user": speaker, "text": text}
        for speaker, text in (("U1", "Hello"), ("U2", "Hi there"))
    ]
    transcript = _try_slack_json(export)
    assert transcript is not None
    assert "Hello" in transcript
|
||||
|
||||
|
||||
def test_slack_json_not_a_list():
    """A dict input is expected to be rejected (None)."""
    assert _try_slack_json({"type": "message"}) is None
|
||||
|
||||
|
||||
def test_slack_json_too_few_messages():
    """A single Slack message is expected to be rejected (None)."""
    only_message = {"type": "message", "user": "U1", "text": "Hello"}
    transcript = _try_slack_json([only_message])
    assert transcript is None
|
||||
|
||||
|
||||
def test_slack_json_skips_non_message_types():
    """A channel_join event ahead of two messages should not prevent parsing."""
    join_event = {"type": "channel_join", "user": "U1", "text": "joined"}
    first = {"type": "message", "user": "U1", "text": "Hello"}
    second = {"type": "message", "user": "U2", "text": "Hi"}
    transcript = _try_slack_json([join_event, first, second])
    assert transcript is not None
|
||||
|
||||
|
||||
def test_slack_json_three_users():
    """Three speakers get alternating roles."""
    export = [
        {"type": "message", "user": speaker, "text": text}
        for speaker, text in (("U1", "Hello"), ("U2", "Hi"), ("U3", "Hey"))
    ]
    transcript = _try_slack_json(export)
    assert transcript is not None
|
||||
|
||||
|
||||
def test_slack_json_empty_text_skipped():
    """An empty-text message ahead of two real ones should not prevent parsing."""
    export = [
        {"type": "message", "user": speaker, "text": text}
        for speaker, text in (("U1", ""), ("U1", "Hello"), ("U2", "Hi"))
    ]
    transcript = _try_slack_json(export)
    assert transcript is not None
|
||||
|
||||
|
||||
def test_slack_json_username_fallback():
    """Messages identified only by "username" should still be accepted."""
    export = [
        {"type": "message", "username": handle, "text": text}
        for handle, text in (("bot1", "Hello"), ("bot2", "Hi"))
    ]
    transcript = _try_slack_json(export)
    assert transcript is not None
|
||||
|
||||
|
||||
# ── _try_normalize_json ────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_try_normalize_json_invalid_json():
    """Text that is not valid JSON is expected to give None."""
    assert _try_normalize_json("not json at all {{{") is None
|
||||
|
||||
|
||||
def test_try_normalize_json_valid_but_unknown_schema():
    """Valid JSON matching no known chat schema is expected to give None."""
    unknown_payload = json.dumps({"random": "data"})
    assert _try_normalize_json(unknown_payload) is None
|
||||
|
||||
|
||||
# ── _messages_to_transcript ────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_messages_to_transcript_basic():
    """One user/assistant exchange renders the question quoted and the answer present."""
    exchange = [("user", "Q"), ("assistant", "A")]
    # Identity stand-in for the spellchecker; create=True tolerates a missing attribute.
    spellcheck_stub = patch(
        "mempalace.normalize.spellcheck_user_text", side_effect=lambda x: x, create=True
    )
    with spellcheck_stub:
        transcript = _messages_to_transcript(exchange, spellcheck=False)
    assert "> Q" in transcript
    assert "A" in transcript
|
||||
|
||||
|
||||
def test_messages_to_transcript_consecutive_users():
    """Two user messages in a row (no assistant between)."""
    turns = [("user", "Q1"), ("user", "Q2"), ("assistant", "A")]
    transcript = _messages_to_transcript(turns, spellcheck=False)
    for quoted in ("> Q1", "> Q2"):
        assert quoted in transcript
|
||||
|
||||
|
||||
def test_messages_to_transcript_assistant_first():
    """Leading assistant message (no user before it)."""
    turns = [("assistant", "preamble"), ("user", "Q"), ("assistant", "A")]
    transcript = _messages_to_transcript(turns, spellcheck=False)
    assert "preamble" in transcript
    assert "> Q" in transcript
|
||||
|
||||
@@ -0,0 +1,452 @@
|
||||
"""Tests for mempalace.onboarding."""
|
||||
|
||||
import os
|
||||
from unittest.mock import patch
|
||||
|
||||
from mempalace.onboarding import (
|
||||
DEFAULT_WINGS,
|
||||
_ask,
|
||||
_ask_mode,
|
||||
_ask_people,
|
||||
_ask_projects,
|
||||
_ask_wings,
|
||||
_auto_detect,
|
||||
_generate_aaak_bootstrap,
|
||||
_header,
|
||||
_hr,
|
||||
_warn_ambiguous,
|
||||
_yn,
|
||||
quick_setup,
|
||||
run_onboarding,
|
||||
)
|
||||
|
||||
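# Force UTF-8 for Windows (source file contains Unicode symbols like hearts/stars)
# NOTE(review): CPython reads PYTHONUTF8 at interpreter startup, so assigning it here
# presumably only affects child processes, not this already-running test process — confirm.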
|
||||
os.environ["PYTHONUTF8"] = "1"
|
||||
|
||||
|
||||
# ── DEFAULT_WINGS ───────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_default_wings_has_expected_keys():
    """DEFAULT_WINGS should define the work, personal and combo modes."""
    for mode in ("work", "personal", "combo"):
        assert mode in DEFAULT_WINGS
|
||||
|
||||
|
||||
def test_default_wings_work_has_projects():
    """The work mode should include a "projects" wing."""
    work_wings = DEFAULT_WINGS["work"]
    assert "projects" in work_wings
|
||||
|
||||
|
||||
def test_default_wings_personal_has_family():
    """The personal mode should include a "family" wing."""
    personal_wings = DEFAULT_WINGS["personal"]
    assert "family" in personal_wings
|
||||
|
||||
|
||||
def test_default_wings_combo_has_both():
    """The combo mode should include both a "family" and a "work" wing."""
    combo_wings = DEFAULT_WINGS["combo"]
    for expected in ("family", "work"):
        assert expected in combo_wings
|
||||
|
||||
|
||||
def test_default_wings_values_are_lists():
    """Every mode should map to a list holding at least three wings."""
    for mode in DEFAULT_WINGS:
        wings = DEFAULT_WINGS[mode]
        assert isinstance(wings, list), f"{mode} wings should be a list"
        assert len(wings) >= 3, f"{mode} should have at least 3 wings"
|
||||
|
||||
|
||||
# ── _warn_ambiguous ─────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_warn_ambiguous_flags_common_words():
    """A name that doubles as a common word is flagged; an uncommon one is not."""
    grace = {"name": "Grace", "relationship": "friend"}
    riley = {"name": "Riley", "relationship": "daughter"}
    flagged = _warn_ambiguous([grace, riley])
    assert "Grace" in flagged
    # Riley is not a common English word
    assert "Riley" not in flagged
|
||||
|
||||
|
||||
def test_warn_ambiguous_empty_list():
    """No people in, no warnings out."""
    assert _warn_ambiguous([]) == []
|
||||
|
||||
|
||||
def test_warn_ambiguous_no_ambiguous_names():
    """Names that are not common words should produce no warnings."""
    roster = [
        {"name": name, "relationship": relationship}
        for name, relationship in (("Riley", "daughter"), ("Devon", "friend"))
    ]
    flagged = _warn_ambiguous(roster)
    assert flagged == []
|
||||
|
||||
|
||||
def test_warn_ambiguous_multiple_hits():
    """Every ambiguous name in the roster should be reported."""
    roster = [
        {"name": "Grace", "relationship": "friend"},
        {"name": "May", "relationship": "aunt"},
        {"name": "Joy", "relationship": "sister"},
    ]
    flagged = _warn_ambiguous(roster)
    for name in ("Grace", "May", "Joy"):
        assert name in flagged
|
||||
|
||||
|
||||
# ── quick_setup ─────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_quick_setup_creates_registry(tmp_path):
    """quick_setup should return a registry holding the given person, project and mode."""
    riley = {"name": "Riley", "relationship": "daughter", "context": "personal"}
    registry = quick_setup(
        mode="personal",
        people=[riley],
        projects=["MemPalace"],
        config_dir=tmp_path,
    )
    assert "Riley" in registry.people
    assert "MemPalace" in registry.projects
    assert registry.mode == "personal"
|
||||
|
||||
|
||||
def test_quick_setup_work_mode(tmp_path):
    """Work mode should be recorded along with the colleague and project."""
    alice = {"name": "Alice", "relationship": "colleague", "context": "work"}
    registry = quick_setup(
        mode="work",
        people=[alice],
        projects=["Acme"],
        config_dir=tmp_path,
    )
    assert registry.mode == "work"
    assert "Alice" in registry.people
    assert "Acme" in registry.projects
|
||||
|
||||
|
||||
def test_quick_setup_empty(tmp_path):
    """With no people and no projects the registry should be empty."""
    registry = quick_setup(mode="personal", people=[], config_dir=tmp_path)
    people_count = len(registry.people)
    project_count = len(registry.projects)
    assert people_count == 0
    assert project_count == 0
|
||||
|
||||
|
||||
def test_quick_setup_saves_to_disk(tmp_path):
    """quick_setup should write entity_registry.json into config_dir."""
    riley = {"name": "Riley", "relationship": "daughter", "context": "personal"}
    quick_setup(mode="personal", people=[riley], config_dir=tmp_path)
    registry_file = tmp_path / "entity_registry.json"
    assert registry_file.exists()
|
||||
|
||||
|
||||
# ── _generate_aaak_bootstrap ───────────────────────────────────────────
|
||||
|
||||
|
||||
def test_generate_aaak_bootstrap_creates_files(tmp_path):
    """Both bootstrap markdown files should exist after generation."""
    roster = [
        {"name": "Riley", "relationship": "daughter", "context": "personal"},
        {"name": "Devon", "relationship": "friend", "context": "personal"},
    ]
    _generate_aaak_bootstrap(
        roster, ["MemPalace"], ["family", "creative"], "personal", config_dir=tmp_path
    )

    for filename in ("aaak_entities.md", "critical_facts.md"):
        assert (tmp_path / filename).exists()
|
||||
|
||||
|
||||
def test_generate_aaak_bootstrap_entities_content(tmp_path):
    """The entities file should mention the person, their code and the project."""
    riley = {"name": "Riley", "relationship": "daughter", "context": "personal"}
    _generate_aaak_bootstrap([riley], ["MemPalace"], ["family"], "personal", config_dir=tmp_path)

    entities_text = (tmp_path / "aaak_entities.md").read_text()
    assert "Riley" in entities_text
    assert "RIL" in entities_text  # entity code
    assert "MemPalace" in entities_text
|
||||
|
||||
|
||||
def test_generate_aaak_bootstrap_facts_content(tmp_path):
    """The facts file should mention the person, the project and the mode."""
    alice = {"name": "Alice", "relationship": "colleague", "context": "work"}
    _generate_aaak_bootstrap([alice], ["Acme"], ["projects"], "work", config_dir=tmp_path)

    facts_text = (tmp_path / "critical_facts.md").read_text()
    assert "Alice" in facts_text
    assert "Acme" in facts_text
    assert "work" in facts_text.lower()
|
||||
|
||||
|
||||
def test_generate_aaak_bootstrap_empty_people(tmp_path):
    """Both files should still be written when there are no people or projects."""
    _generate_aaak_bootstrap([], [], ["general"], "personal", config_dir=tmp_path)
    for filename in ("aaak_entities.md", "critical_facts.md"):
        assert (tmp_path / filename).exists()
|
||||
|
||||
|
||||
def test_generate_aaak_bootstrap_collision(tmp_path):
    """Two people with same 3-letter code get different codes."""
    people = [
        {"name": "Alice", "relationship": "friend", "context": "work"},
        {"name": "Alison", "relationship": "coworker", "context": "work"},
    ]
    _generate_aaak_bootstrap(people, [], ["work"], "work", config_dir=tmp_path)
    content = (tmp_path / "aaak_entities.md").read_text()
    assert "ALI" in content
    assert "ALIS" in content
    # "ALI" is a substring of "ALIS", so the two checks above cannot tell a
    # distinct short code from the long one alone. Pin each code with its "="
    # separator, following the CODE=Name form ("BOB=Bob") asserted by the
    # no-relationship test in this file.
    assert "ALI=" in content
    assert "ALIS=" in content
|
||||
|
||||
|
||||
def test_generate_aaak_bootstrap_no_relationship(tmp_path):
    """Person without relationship string still generates entry."""
    bob = {"name": "Bob", "context": "work"}
    _generate_aaak_bootstrap([bob], [], ["work"], "work", config_dir=tmp_path)
    entities_text = (tmp_path / "aaak_entities.md").read_text()
    assert "BOB=Bob" in entities_text
|
||||
|
||||
|
||||
# ── _hr, _header ──────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_hr_prints_line(capsys):
    """_hr should print a rule containing the box-drawing dash."""
    _hr()
    captured = capsys.readouterr()
    assert "─" in captured.out
|
||||
|
||||
|
||||
def test_header_prints_banner(capsys):
    """_header should print the title together with an "=" banner."""
    _header("Test Title")
    printed = capsys.readouterr().out
    for fragment in ("Test Title", "="):
        assert fragment in printed
|
||||
|
||||
|
||||
# ── _ask ──────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_ask_with_default_uses_default():
    """Empty input should fall back to the supplied default."""
    with patch("builtins.input", return_value=""):
        answer = _ask("prompt", default="fallback")
    assert answer == "fallback"
|
||||
|
||||
|
||||
def test_ask_with_default_uses_input():
    """Typed input should win over the default."""
    with patch("builtins.input", return_value="custom"):
        answer = _ask("prompt", default="fallback")
    assert answer == "custom"
|
||||
|
||||
|
||||
def test_ask_no_default():
    """Without a default, the typed input is expected back."""
    with patch("builtins.input", return_value="answer"):
        reply = _ask("prompt")
    assert reply == "answer"
|
||||
|
||||
|
||||
# ── _yn ───────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_yn_default_yes_empty_input():
    """Empty input with the implicit default should answer True."""
    with patch("builtins.input", return_value=""):
        verdict = _yn("continue?")
    assert verdict is True
|
||||
|
||||
|
||||
def test_yn_default_no_empty_input():
    """Empty input with default "n" should answer False."""
    with patch("builtins.input", return_value=""):
        verdict = _yn("continue?", default="n")
    assert verdict is False
|
||||
|
||||
|
||||
def test_yn_explicit_yes():
    """An explicit "yes" should override a default of "n"."""
    with patch("builtins.input", return_value="yes"):
        verdict = _yn("continue?", default="n")
    assert verdict is True
|
||||
|
||||
|
||||
def test_yn_explicit_no():
    """An explicit "no" should override the implicit default."""
    with patch("builtins.input", return_value="no"):
        verdict = _yn("continue?")
    assert verdict is False
|
||||
|
||||
|
||||
# ── _ask_mode ─────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_ask_mode_work():
    """Choice "1" should select the work mode."""
    with patch("builtins.input", return_value="1"):
        chosen = _ask_mode()
    assert chosen == "work"
|
||||
|
||||
|
||||
def test_ask_mode_personal():
    """Choice "2" should select the personal mode."""
    with patch("builtins.input", return_value="2"):
        chosen = _ask_mode()
    assert chosen == "personal"
|
||||
|
||||
|
||||
def test_ask_mode_combo():
    """Choice "3" should select the combo mode."""
    with patch("builtins.input", return_value="3"):
        chosen = _ask_mode()
    assert chosen == "combo"
|
||||
|
||||
|
||||
def test_ask_mode_retries_on_bad_input():
    """Invalid answers should be re-prompted until a valid choice arrives."""
    keystrokes = ["x", "bad", "1"]
    with patch("builtins.input", side_effect=keystrokes):
        chosen = _ask_mode()
    assert chosen == "work"
|
||||
|
||||
|
||||
# ── _ask_people ───────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_ask_people_personal_mode():
    """A "name, relationship" entry should be parsed into one person record."""
    scripted = ["Alice, daughter", "", "done"]
    with patch("builtins.input", side_effect=scripted):
        people, _aliases = _ask_people("personal")
    assert len(people) == 1
    only_person = people[0]
    assert only_person["name"] == "Alice"
    assert only_person["relationship"] == "daughter"
|
||||
|
||||
|
||||
def test_ask_people_work_mode():
    """In work mode the person should be tagged with the work context."""
    scripted = ["Bob, manager", "", "done"]
    with patch("builtins.input", side_effect=scripted):
        people, _aliases = _ask_people("work")
    assert len(people) == 1
    only_person = people[0]
    assert only_person["name"] == "Bob"
    assert only_person["context"] == "work"
|
||||
|
||||
|
||||
def test_ask_people_combo_mode():
    """Combo mode should collect people from both the personal and the work rounds."""
    personal_round = ["Alice, daughter", "", "done"]  # personal
    work_round = ["Bob, boss", "done"]  # work
    with patch("builtins.input", side_effect=personal_round + work_round):
        people, _aliases = _ask_people("combo")
    assert len(people) == 2
|
||||
|
||||
|
||||
def test_ask_people_with_nickname():
    """A nickname answer should be recorded as an alias of the full name."""
    scripted = ["Alice, daughter", "Ali", "done"]
    with patch("builtins.input", side_effect=scripted):
        _people, aliases = _ask_people("personal")
    assert aliases == {"Ali": "Alice"}
|
||||
|
||||
|
||||
def test_ask_people_empty_name_skipped():
    """A blank entry should not create a person record."""
    scripted = ["", "done"]
    with patch("builtins.input", side_effect=scripted):
        people, _aliases = _ask_people("personal")
    assert len(people) == 0
|
||||
|
||||
|
||||
# ── _ask_projects ─────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_ask_projects_personal_returns_empty():
    """Personal mode should yield no projects."""
    assert _ask_projects("personal") == []
|
||||
|
||||
|
||||
def test_ask_projects_work_mode():
    """Entries typed before "done" should be returned in order."""
    scripted = ["Acme", "BigCo", "done"]
    with patch("builtins.input", side_effect=scripted):
        projects = _ask_projects("work")
    assert projects == ["Acme", "BigCo"]
|
||||
|
||||
|
||||
def test_ask_projects_empty_entry_stops():
    """A blank entry should end project collection."""
    scripted = ["Acme", ""]
    with patch("builtins.input", side_effect=scripted):
        projects = _ask_projects("work")
    assert projects == ["Acme"]
|
||||
|
||||
|
||||
# ── _ask_wings ────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_ask_wings_accept_defaults():
    """Empty input should accept the default wings for the mode."""
    with patch("builtins.input", return_value=""):
        wings = _ask_wings("work")
    assert wings == DEFAULT_WINGS["work"]
|
||||
|
||||
|
||||
def test_ask_wings_custom():
    """A comma-separated answer should be split into trimmed wing names."""
    with patch("builtins.input", return_value="alpha, beta, gamma"):
        wings = _ask_wings("personal")
    assert wings == ["alpha", "beta", "gamma"]
|
||||
|
||||
|
||||
# ── _auto_detect ──────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_auto_detect_no_files(tmp_path):
    """An empty directory should produce no detections."""
    detected = _auto_detect(str(tmp_path), [])
    assert detected == []
|
||||
|
||||
|
||||
def test_auto_detect_filters_known(tmp_path):
    """People already known should be dropped from the detections."""
    already_known = [{"name": "Alice"}]
    detection = {
        "people": [
            {"name": "Alice", "confidence": 0.9, "signals": ["test"]},
            {"name": "Bob", "confidence": 0.8, "signals": ["test"]},
        ],
        "projects": [],
        "uncertain": [],
    }
    with (
        patch("mempalace.onboarding.scan_for_detection", return_value=["file.txt"]),
        patch("mempalace.onboarding.detect_entities", return_value=detection),
    ):
        found = _auto_detect(str(tmp_path), already_known)
    found_names = [person["name"] for person in found]
    assert "Alice" not in found_names
    assert "Bob" in found_names
|
||||
|
||||
|
||||
def test_auto_detect_filters_low_confidence(tmp_path):
    """A detection at confidence 0.5 should be discarded."""
    detection = {
        "people": [{"name": "Bob", "confidence": 0.5, "signals": ["test"]}],
        "projects": [],
        "uncertain": [],
    }
    with (
        patch("mempalace.onboarding.scan_for_detection", return_value=["file.txt"]),
        patch("mempalace.onboarding.detect_entities", return_value=detection),
    ):
        found = _auto_detect(str(tmp_path), [])
    assert len(found) == 0
|
||||
|
||||
|
||||
def test_auto_detect_handles_exception(tmp_path):
    """A failure while scanning should yield an empty list."""
    failing_scan = patch("mempalace.onboarding.scan_for_detection", side_effect=Exception("boom"))
    with failing_scan:
        found = _auto_detect(str(tmp_path), [])
    assert found == []
|
||||
|
||||
|
||||
# ── run_onboarding ────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_run_onboarding_basic_flow(tmp_path):
    """Test the full onboarding flow with minimal mocking."""
    bob = {"name": "Bob", "relationship": "boss", "context": "work"}
    with (
        patch("mempalace.onboarding._ask_mode", return_value="work"),
        patch("mempalace.onboarding._ask_people", return_value=([bob], {})),
        patch("mempalace.onboarding._ask_projects", return_value=["Acme"]),
        patch("mempalace.onboarding._ask_wings", return_value=["projects", "team"]),
        patch("mempalace.onboarding._yn", return_value=False),
        patch("mempalace.onboarding._warn_ambiguous", return_value=[]),
    ):
        registry = run_onboarding(directory=".", config_dir=tmp_path, auto_detect=False)
    assert "Bob" in registry.people
    assert "Acme" in registry.projects
|
||||
|
||||
|
||||
def test_run_onboarding_with_ambiguous_names(tmp_path):
    """Onboarding prints a warning for ambiguous names."""
    grace = {"name": "Grace", "relationship": "friend", "context": "personal"}
    with (
        patch("mempalace.onboarding._ask_mode", return_value="personal"),
        patch("mempalace.onboarding._ask_people", return_value=([grace], {})),
        patch("mempalace.onboarding._ask_projects", return_value=[]),
        patch("mempalace.onboarding._ask_wings", return_value=["family"]),
        patch("mempalace.onboarding._yn", return_value=False),
    ):
        registry = run_onboarding(directory=".", config_dir=tmp_path, auto_detect=False)
    assert "Grace" in registry.people
|
||||
@@ -0,0 +1,244 @@
|
||||
"""Tests for mempalace.palace_graph — graph traversal layer.
|
||||
|
||||
All ChromaDB access is mocked — no real database needed.
|
||||
"""
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
|
||||
def _make_fake_collection(metadatas, ids=None):
|
||||
"""Create a mock collection that returns the given metadata in batches."""
|
||||
if ids is None:
|
||||
ids = [f"id_{i}" for i in range(len(metadatas))]
|
||||
|
||||
col = MagicMock()
|
||||
col.count.return_value = len(metadatas)
|
||||
|
||||
def fake_get(limit=1000, offset=0, include=None):
|
||||
batch_meta = metadatas[offset : offset + limit]
|
||||
batch_ids = ids[offset : offset + limit]
|
||||
return {"ids": batch_ids, "metadatas": batch_meta}
|
||||
|
||||
col.get.side_effect = fake_get
|
||||
return col
|
||||
|
||||
|
||||
# Patch chromadb at import time so palace_graph can be imported
|
||||
with patch.dict("sys.modules", {"chromadb": MagicMock()}):
|
||||
from mempalace.palace_graph import (
|
||||
_fuzzy_match,
|
||||
build_graph,
|
||||
find_tunnels,
|
||||
graph_stats,
|
||||
traverse,
|
||||
)
|
||||
|
||||
|
||||
# --- build_graph ---
|
||||
|
||||
|
||||
class TestBuildGraph:
    """build_graph exercised against mocked collections."""

    def test_empty_collection(self):
        """No metadata rows: expect no nodes and no edges."""
        graph_nodes, graph_edges = build_graph(col=_make_fake_collection([]))
        assert graph_nodes == {}
        assert graph_edges == []

    def test_falsy_collection(self):
        """When col is explicitly falsy, build_graph returns empty."""
        graph_nodes, graph_edges = build_graph(col=0)
        assert graph_nodes == {}
        assert graph_edges == []

    def test_single_wing_no_edges(self):
        """A room seen twice in one wing is counted twice and creates no edge."""
        rows = [
            {"room": "auth", "wing": "wing_code", "hall": "security", "date": day}
            for day in ("2026-01-01", "2026-01-02")
        ]
        graph_nodes, graph_edges = build_graph(col=_make_fake_collection(rows))
        assert "auth" in graph_nodes
        assert graph_nodes["auth"]["count"] == 2
        assert graph_edges == []

    def test_multi_wing_creates_edges(self):
        """A room present in two wings should produce one edge between them."""
        rows = [
            {"room": "chromadb", "wing": wing, "hall": "databases", "date": day}
            for wing, day in (("wing_code", "2026-01-01"), ("wing_project", "2026-01-02"))
        ]
        graph_nodes, graph_edges = build_graph(col=_make_fake_collection(rows))
        assert "chromadb" in graph_nodes
        assert len(graph_edges) == 1
        only_edge = graph_edges[0]
        assert only_edge["wing_a"] == "wing_code"
        assert only_edge["wing_b"] == "wing_project"
        assert only_edge["hall"] == "databases"

    def test_general_room_excluded(self):
        """The catch-all "general" room should not become a node."""
        rows = [{"room": "general", "wing": "wing_code", "hall": "misc", "date": ""}]
        graph_nodes, _edges = build_graph(col=_make_fake_collection(rows))
        assert "general" not in graph_nodes

    def test_missing_wing_excluded(self):
        """A row with an empty wing should not become a node."""
        rows = [{"room": "orphan", "wing": "", "hall": "misc", "date": ""}]
        graph_nodes, _edges = build_graph(col=_make_fake_collection(rows))
        assert "orphan" not in graph_nodes

    def test_dates_capped_at_five(self):
        """Nine dated rows for one room should leave at most five dates on the node."""
        rows = [
            {"room": "busy", "wing": "w", "hall": "h", "date": f"2026-01-{day:02d}"}
            for day in range(1, 10)
        ]
        graph_nodes, _ = build_graph(col=_make_fake_collection(rows))
        assert len(graph_nodes["busy"]["dates"]) <= 5
|
||||
|
||||
|
||||
# --- traverse ---
|
||||
|
||||
|
||||
class TestTraverse:
    """traverse exercised against a small three-room mocked collection."""

    def _build_col(self):
        """Return a collection with two rooms in wing_code and one in wing_ops."""
        layout = (
            ("auth", "wing_code", "security"),
            ("login", "wing_code", "security"),
            ("deploy", "wing_ops", "infra"),
        )
        return _make_fake_collection(
            [
                {"room": room, "wing": wing, "hall": hall, "date": "2026-01-01"}
                for room, wing, hall in layout
            ]
        )

    def test_traverse_known_room(self):
        """Starting from a known room should list it and its wing neighbour."""
        visited = traverse("auth", col=self._build_col())
        assert isinstance(visited, list)
        room_names = [entry["room"] for entry in visited]
        assert "auth" in room_names
        # login shares wing_code with auth
        assert "login" in room_names

    def test_traverse_unknown_room(self):
        """An unknown start room should give an error dict with suggestions."""
        outcome = traverse("nonexistent", col=self._build_col())
        assert isinstance(outcome, dict)
        for key in ("error", "suggestions"):
            assert key in outcome

    def test_traverse_max_hops(self):
        """With max_hops=0 only the start room is expected."""
        visited = traverse("auth", col=self._build_col(), max_hops=0)
        # Only the start room itself at hop 0
        assert len(visited) == 1
        assert visited[0]["room"] == "auth"
|
||||
|
||||
|
||||
# --- find_tunnels ---
|
||||
|
||||
|
||||
class TestFindTunnels:
    """Checks find_tunnels() with one room ("chromadb") present in two wings."""

    def _build_tunnel_col(self):
        """Return a fake collection where chromadb spans two wings and auth sits in one."""
        drawers = [
            {"room": "chromadb", "wing": "wing_code", "hall": "db", "date": "2026-01-01"},
            {"room": "chromadb", "wing": "wing_project", "hall": "db", "date": "2026-01-02"},
            {"room": "auth", "wing": "wing_code", "hall": "security", "date": "2026-01-01"},
        ]
        return _make_fake_collection(drawers)

    def test_find_all_tunnels(self):
        """Without filters, exactly the chromadb tunnel is reported."""
        found = find_tunnels(col=self._build_tunnel_col())
        assert len(found) == 1
        assert found[0]["room"] == "chromadb"

    def test_find_tunnels_with_wing_filter(self):
        """Filtering on wing_code alone still reports one tunnel."""
        found = find_tunnels(wing_a="wing_code", col=self._build_tunnel_col())
        assert len(found) == 1

    def test_find_tunnels_no_match(self):
        """Filtering on a wing absent from the data reports nothing."""
        found = find_tunnels(wing_a="wing_nonexistent", col=self._build_tunnel_col())
        assert found == []

    def test_find_tunnels_both_wings(self):
        """Filtering on both wings reports the chromadb tunnel."""
        found = find_tunnels(
            wing_a="wing_code",
            wing_b="wing_project",
            col=self._build_tunnel_col(),
        )
        assert len(found) == 1
        assert found[0]["room"] == "chromadb"
|
||||
|
||||
|
||||
# --- graph_stats ---
|
||||
|
||||
|
||||
class TestGraphStats:
    """Checks the summary counters returned by graph_stats()."""

    def test_empty_graph(self):
        """An empty collection reports zero rooms, tunnel rooms and edges."""
        summary = graph_stats(col=_make_fake_collection([]))
        for counter in ("total_rooms", "tunnel_rooms", "total_edges"):
            assert summary[counter] == 0

    def test_stats_with_data(self):
        """Two rooms, one of them spanning two wings, give 2 rooms / 1 tunnel / 1 edge."""
        drawers = [
            {"room": "chromadb", "wing": "wing_code", "hall": "db", "date": "2026-01-01"},
            {"room": "chromadb", "wing": "wing_project", "hall": "db", "date": "2026-01-02"},
            {"room": "auth", "wing": "wing_code", "hall": "security", "date": "2026-01-01"},
        ]
        summary = graph_stats(col=_make_fake_collection(drawers))
        assert summary["total_rooms"] == 2
        assert summary["tunnel_rooms"] == 1
        assert summary["total_edges"] == 1
        assert "wing_code" in summary["rooms_per_wing"]
|
||||
|
||||
|
||||
# --- _fuzzy_match ---
|
||||
|
||||
|
||||
class TestFuzzyMatch:
    """Checks _fuzzy_match() suggestions against small node dictionaries."""

    def test_exact_substring(self):
        """A query that is a substring of a node name suggests that node."""
        candidates = {"chromadb-setup": {}, "auth-module": {}, "deploy-config": {}}
        suggestions = _fuzzy_match("chromadb", candidates)
        assert "chromadb-setup" in suggestions

    def test_partial_word_match(self):
        """A query equal to one hyphen-separated word of a node name suggests it."""
        candidates = {"chromadb-setup": {}, "auth-module": {}, "deploy-config": {}}
        suggestions = _fuzzy_match("auth", candidates)
        assert "auth-module" in suggestions

    def test_no_match(self):
        """A query unrelated to every node name returns an empty list."""
        candidates = {"chromadb-setup": {}, "auth-module": {}}
        assert _fuzzy_match("zzzzz", candidates) == []

    def test_hyphenated_query(self):
        """A hyphenated query suggests the node that starts with it."""
        candidates = {"riley-college-apps": {}, "college-prep": {}}
        suggestions = _fuzzy_match("riley-college", candidates)
        assert "riley-college-apps" in suggestions

    def test_max_results(self):
        """The n argument bounds how many suggestions come back."""
        candidates = {f"room-{idx}": {} for idx in range(20)}
        suggestions = _fuzzy_match("room", candidates, n=3)
        assert len(suggestions) <= 3
|
||||
@@ -0,0 +1,264 @@
|
||||
"""Tests for mempalace.room_detector_local."""
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from mempalace.room_detector_local import (
|
||||
FOLDER_ROOM_MAP,
|
||||
detect_rooms_from_files,
|
||||
detect_rooms_from_folders,
|
||||
detect_rooms_local,
|
||||
get_user_approval,
|
||||
print_proposed_structure,
|
||||
save_config,
|
||||
)
|
||||
|
||||
|
||||
# ── FOLDER_ROOM_MAP ────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_folder_room_map_has_expected_mappings():
    """FOLDER_ROOM_MAP sends the common folder names to their room names."""
    expected = {
        "frontend": "frontend",
        "backend": "backend",
        "docs": "documentation",
        "tests": "testing",
        "config": "configuration",
    }
    for folder, room in expected.items():
        assert FOLDER_ROOM_MAP[folder] == room
|
||||
|
||||
|
||||
def test_folder_room_map_alternative_names():
    """Alternative folder spellings resolve to the frontend/backend rooms."""
    expected = {
        "front-end": "frontend",
        "back-end": "backend",
        "server": "backend",
        "client": "frontend",
        "api": "backend",
    }
    for folder, room in expected.items():
        assert FOLDER_ROOM_MAP[folder] == room
|
||||
|
||||
|
||||
# ── detect_rooms_from_folders ───────────────────────────────────────────
|
||||
|
||||
|
||||
def test_detect_rooms_from_folders_standard_layout(tmp_path):
    """frontend/backend/docs folders produce frontend, backend and documentation rooms."""
    for folder in ("frontend", "backend", "docs"):
        (tmp_path / folder).mkdir()
    detected = detect_rooms_from_folders(str(tmp_path))
    names = {room["name"] for room in detected}
    for expected in ("frontend", "backend", "documentation"):
        assert expected in names
|
||||
|
||||
|
||||
def test_detect_rooms_from_folders_always_has_general(tmp_path):
    """Folder detection on a directory with no subfolders includes a "general" room."""
    detected = detect_rooms_from_folders(str(tmp_path))
    assert "general" in {room["name"] for room in detected}
|
||||
|
||||
|
||||
def test_detect_rooms_from_folders_empty_dir(tmp_path):
    """An empty directory still yields at least one room, one of them "general"."""
    detected = detect_rooms_from_folders(str(tmp_path))
    # Should at least have "general"
    assert len(detected) >= 1
    assert any(room["name"] == "general" for room in detected)
|
||||
|
||||
|
||||
def test_detect_rooms_from_folders_skips_git(tmp_path):
    """.git and node_modules folders never become room names."""
    for folder in (".git", "node_modules", "frontend"):
        (tmp_path / folder).mkdir()
    detected = detect_rooms_from_folders(str(tmp_path))
    names = {room["name"] for room in detected}
    for ignored in (".git", "node_modules"):
        assert ignored not in names
|
||||
|
||||
|
||||
def test_detect_rooms_from_folders_nested_dirs(tmp_path):
    """src/components and src/routes lead to a frontend or backend room."""
    source_root = tmp_path / "src"
    source_root.mkdir()
    for child in ("components", "routes"):
        (source_root / child).mkdir()
    detected = detect_rooms_from_folders(str(tmp_path))
    names = {room["name"] for room in detected}
    # Nested dirs should be detected at one level deep
    assert any(expected in names for expected in ("frontend", "backend"))
|
||||
|
||||
|
||||
def test_detect_rooms_from_folders_room_has_description(tmp_path):
    """The documentation room carries a description that mentions "docs"."""
    (tmp_path / "docs").mkdir()
    doc_room = None
    for room in detect_rooms_from_folders(str(tmp_path)):
        if room["name"] == "documentation":
            doc_room = room
            break
    assert doc_room is not None
    assert "description" in doc_room
    assert "docs" in doc_room["description"]
|
||||
|
||||
|
||||
def test_detect_rooms_from_folders_room_has_keywords(tmp_path):
    """The frontend room carries a non-empty "keywords" entry."""
    (tmp_path / "frontend").mkdir()
    fe_room = None
    for room in detect_rooms_from_folders(str(tmp_path)):
        if room["name"] == "frontend":
            fe_room = room
            break
    assert fe_room is not None
    assert "keywords" in fe_room
    assert len(fe_room["keywords"]) > 0
|
||||
|
||||
|
||||
def test_detect_rooms_from_folders_custom_named_dirs(tmp_path):
    """A folder name outside FOLDER_ROOM_MAP yields that name or the "general" room."""
    (tmp_path / "mylib").mkdir()
    detected = detect_rooms_from_folders(str(tmp_path))
    names = {room["name"] for room in detected}
    # Custom dir names that don't match FOLDER_ROOM_MAP get added as-is
    # NOTE(review): if "general" is always present this check can never fail -- confirm.
    assert any(expected in names for expected in ("mylib", "general"))
|
||||
|
||||
|
||||
# ── detect_rooms_from_files ─────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_detect_rooms_from_files_with_matching_filenames(tmp_path):
    """Files named test_*.py yield a "testing" room or the "general" room."""
    # Create files whose names contain room keywords
    for filename in ("test_auth.py", "test_login.py", "test_api.py"):
        (tmp_path / filename).write_text("content")
    detected = detect_rooms_from_files(str(tmp_path))
    names = {room["name"] for room in detected}
    assert any(expected in names for expected in ("testing", "general"))
|
||||
|
||||
|
||||
def test_detect_rooms_from_files_empty_dir(tmp_path):
    """File detection on an empty directory yields at least the "general" room."""
    detected = detect_rooms_from_files(str(tmp_path))
    assert len(detected) >= 1
    assert any(room["name"] == "general" for room in detected)
|
||||
|
||||
|
||||
def test_detect_rooms_from_files_caps_at_six(tmp_path):
    """Files covering eight different keywords still produce at most six rooms."""
    # Create many files with different keywords to hit the cap
    keywords = ["test", "doc", "api", "config", "frontend", "backend", "design", "meeting"]
    for word in keywords:
        for copy_no in range(3):
            (tmp_path / f"{word}_file_{copy_no}.txt").write_text("content")
    detected = detect_rooms_from_files(str(tmp_path))
    assert len(detected) <= 6
|
||||
|
||||
|
||||
# ── save_config ─────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_save_config_creates_yaml(tmp_path):
    """save_config writes mempalace.yaml containing the project and room names."""
    rooms = [
        {"name": "frontend", "description": "UI files", "keywords": ["frontend"]},
        {"name": "backend", "description": "Server files", "keywords": ["backend"]},
    ]
    save_config(str(tmp_path), "myproject", rooms)
    config_file = tmp_path / "mempalace.yaml"
    assert config_file.exists()
    written = config_file.read_text()
    for fragment in ("myproject", "frontend", "backend"):
        assert fragment in written
|
||||
|
||||
|
||||
def test_save_config_valid_yaml(tmp_path):
    """The written config parses as YAML with the expected wing and a single room."""
    import yaml

    save_config(
        str(tmp_path),
        "test_proj",
        [{"name": "general", "description": "All files", "keywords": []}],
    )
    parsed = yaml.safe_load((tmp_path / "mempalace.yaml").read_text())
    assert parsed["wing"] == "test_proj"
    assert len(parsed["rooms"]) == 1
    assert parsed["rooms"][0]["name"] == "general"
|
||||
|
||||
|
||||
# ── print_proposed_structure ──────────────────────────────────────────
|
||||
|
||||
|
||||
def test_print_proposed_structure(capsys):
    """The printed proposal mentions the project, a room, the file count and the source."""
    rooms = [
        {"name": "frontend", "description": "UI files"},
        {"name": "general", "description": "Everything else"},
    ]
    print_proposed_structure("myapp", rooms, 42, "folder structure")
    printed = capsys.readouterr().out
    for fragment in ("myapp", "frontend", "42 files", "folder structure"):
        assert fragment in printed
|
||||
|
||||
|
||||
# ── get_user_approval ─────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_get_user_approval_accept_all():
    """An empty answer at the prompt returns the proposed rooms unchanged."""
    proposed = [{"name": "frontend", "description": "UI"}]
    with patch("builtins.input", return_value=""):
        approved = get_user_approval(proposed)
    assert approved == proposed
|
||||
|
||||
|
||||
def test_get_user_approval_edit_remove():
    """Answering "edit", "1", "n" leaves only the backend room."""
    proposed = [
        {"name": "frontend", "description": "UI"},
        {"name": "backend", "description": "Server"},
    ]
    answers = ["edit", "1", "n"]
    with patch("builtins.input", side_effect=answers):
        approved = get_user_approval(proposed)
    # Room 1 (frontend) removed
    assert len(approved) == 1
    assert approved[0]["name"] == "backend"
|
||||
|
||||
|
||||
def test_get_user_approval_add_room():
    """Answering "add" with a name and description adds that room to the result."""
    proposed = [{"name": "general", "description": "All files"}]
    answers = ["add", "custom_room", "My custom room", ""]
    with patch("builtins.input", side_effect=answers):
        approved = get_user_approval(proposed)
    assert "custom_room" in [room["name"] for room in approved]
|
||||
|
||||
|
||||
# ── detect_rooms_local ────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_detect_rooms_local_yes_mode(tmp_path):
    """With yes=True and a stubbed miner module, a mempalace.yaml is written."""
    docs_dir = tmp_path / "docs"
    docs_dir.mkdir()
    (docs_dir / "readme.md").write_text("hello")
    fake_miner = MagicMock()
    fake_miner.scan_project.return_value = ["file1.py"]
    # Swap in the stub so the real mempalace.miner module is not used.
    with patch.dict("sys.modules", {"mempalace.miner": fake_miner}):
        detect_rooms_local(str(tmp_path), yes=True)
    assert (tmp_path / "mempalace.yaml").exists()
|
||||
|
||||
|
||||
def test_detect_rooms_local_fallback_to_files(tmp_path):
    """A directory holding only files (no subfolders) still gets a mempalace.yaml.

    Intended to cover the file-pattern fallback; only the config file's
    existence is asserted here.
    """
    for idx in range(3):
        (tmp_path / f"test_file_{idx}.py").write_text("content")
    fake_miner = MagicMock()
    fake_miner.scan_project.return_value = ["f1", "f2"]
    with patch.dict("sys.modules", {"mempalace.miner": fake_miner}):
        detect_rooms_local(str(tmp_path), yes=True)
    assert (tmp_path / "mempalace.yaml").exists()
|
||||
|
||||
|
||||
def test_detect_rooms_local_missing_dir():
    """Pointing detect_rooms_local at a missing directory raises SystemExit."""
    import pytest

    missing = "/nonexistent/path/that/does/not/exist"
    with pytest.raises(SystemExit):
        detect_rooms_local(missing, yes=True)
|
||||
|
||||
|
||||
def test_detect_rooms_local_interactive(tmp_path):
    """With yes=False and a stubbed approval prompt, a mempalace.yaml is written."""
    src_dir = tmp_path / "src"
    src_dir.mkdir()
    (src_dir / "main.py").write_text("code")
    fake_miner = MagicMock()
    fake_miner.scan_project.return_value = ["f1"]
    approved_rooms = [{"name": "general", "description": "All files", "keywords": []}]
    with patch.dict("sys.modules", {"mempalace.miner": fake_miner}):
        # Replace the interactive prompt with a canned answer.
        with patch(
            "mempalace.room_detector_local.get_user_approval",
            return_value=approved_rooms,
        ):
            detect_rooms_local(str(tmp_path), yes=False)
    assert (tmp_path / "mempalace.yaml").exists()
|
||||
+85
-5
@@ -1,10 +1,18 @@
|
||||
"""
|
||||
test_searcher.py — Tests for the programmatic search_memories API.
|
||||
test_searcher.py -- Tests for both search() (CLI) and search_memories() (API).
|
||||
|
||||
Tests the library-facing search interface (not the CLI print variant).
|
||||
Uses the real ChromaDB fixtures from conftest.py for integration tests,
|
||||
plus mock-based tests for error paths.
|
||||
"""
|
||||
|
||||
from mempalace.searcher import search_memories
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from mempalace.searcher import SearchError, search, search_memories
|
||||
|
||||
|
||||
# ── search_memories (API) ──────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestSearchMemories:
|
||||
@@ -30,8 +38,8 @@ class TestSearchMemories:
|
||||
result = search_memories("code", palace_path, n_results=2)
|
||||
assert len(result["results"]) <= 2
|
||||
|
||||
def test_no_palace_returns_error(self, tmp_path):
    """Searching a palace path that does not exist yields a dict with an "error" key.

    The missing path is built under tmp_path rather than hard-coded, so the
    test does not depend on what exists on the host filesystem.
    """
    result = search_memories("anything", str(tmp_path / "missing"))
    assert "error" in result
|
||||
|
||||
def test_result_fields(self, palace_path, seeded_collection):
|
||||
@@ -43,3 +51,75 @@ class TestSearchMemories:
|
||||
assert "source_file" in hit
|
||||
assert "similarity" in hit
|
||||
assert isinstance(hit["similarity"], float)
|
||||
|
||||
def test_search_memories_query_error(self):
    """A RuntimeError raised by the collection query surfaces in the "error" entry."""
    failing_col = MagicMock()
    failing_col.query.side_effect = RuntimeError("query failed")
    fake_client = MagicMock()
    fake_client.get_collection.return_value = failing_col

    client_patch = patch("mempalace.searcher.chromadb.PersistentClient", return_value=fake_client)
    with client_patch:
        outcome = search_memories("test", "/fake/path")
    assert "error" in outcome
    assert "query failed" in outcome["error"]
|
||||
|
||||
def test_search_memories_filters_in_result(self, palace_path, seeded_collection):
    """The wing and room passed in are echoed back under result["filters"]."""
    outcome = search_memories("test", palace_path, wing="project", room="backend")
    applied = outcome["filters"]
    assert applied["wing"] == "project"
    assert applied["room"] == "backend"
|
||||
|
||||
|
||||
# ── search() (CLI print function) ─────────────────────────────────────
|
||||
|
||||
|
||||
class TestSearchCLI:
    """Checks the printing search() entry point via captured stdout."""

    def test_search_prints_results(self, palace_path, seeded_collection, capsys):
        """A query on seeded data prints text containing one of the query words."""
        search("JWT authentication", palace_path)
        printed = capsys.readouterr().out
        assert any(word in printed for word in ("JWT", "authentication"))

    def test_search_with_wing_filter(self, palace_path, seeded_collection, capsys):
        """A wing-filtered search prints a "Results for" header."""
        search("planning", palace_path, wing="notes")
        printed = capsys.readouterr().out
        assert "Results for" in printed

    def test_search_with_room_filter(self, palace_path, seeded_collection, capsys):
        """A room-filtered search prints a "Room:" label."""
        search("database", palace_path, room="backend")
        printed = capsys.readouterr().out
        assert "Room:" in printed

    def test_search_with_wing_and_room(self, palace_path, seeded_collection, capsys):
        """Filtering on wing and room prints both labels."""
        search("code", palace_path, wing="project", room="frontend")
        printed = capsys.readouterr().out
        for label in ("Wing:", "Room:"):
            assert label in printed

    def test_search_no_palace_raises(self, tmp_path):
        """A missing palace directory raises SearchError matching "No palace found"."""
        missing = str(tmp_path / "missing")
        with pytest.raises(SearchError, match="No palace found"):
            search("anything", missing)

    def test_search_no_results(self, palace_path, collection, capsys):
        """Searching the unseeded collection returns None or prints "No results"."""
        # collection is empty (no seeded data)
        outcome = search("xyzzy_nonexistent_query", palace_path, n_results=1)
        printed = capsys.readouterr().out
        # Either prints "No results" or returns None
        assert outcome is None or "No results" in printed

    def test_search_query_error_raises(self):
        """A RuntimeError from the collection query becomes a SearchError."""
        failing_col = MagicMock()
        failing_col.query.side_effect = RuntimeError("boom")
        fake_client = MagicMock()
        fake_client.get_collection.return_value = failing_col

        client_patch = patch(
            "mempalace.searcher.chromadb.PersistentClient", return_value=fake_client
        )
        with client_patch, pytest.raises(SearchError, match="Search error"):
            search("test", "/fake/path")

    def test_search_n_results(self, palace_path, seeded_collection, capsys):
        """With n_results=1 the output contains the "[1]" result marker."""
        search("code", palace_path, n_results=1)
        printed = capsys.readouterr().out
        # Should have output with at least one result block
        assert "[1]" in printed
|
||||
|
||||
@@ -0,0 +1,160 @@
|
||||
"""Tests for mempalace.spellcheck — spell-correction utilities."""
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
from mempalace.spellcheck import (
|
||||
_edit_distance,
|
||||
_get_system_words,
|
||||
_should_skip,
|
||||
spellcheck_transcript,
|
||||
spellcheck_transcript_line,
|
||||
spellcheck_user_text,
|
||||
)
|
||||
|
||||
|
||||
# --- _should_skip ---
|
||||
|
||||
|
||||
class TestShouldSkip:
    """Checks which tokens _should_skip() exempts from correction."""

    def test_short_tokens_skipped(self):
        """One- and two-character tokens are skipped."""
        for token in ("hi", "ok", "I"):
            assert _should_skip(token, set()) is True

    def test_digits_skipped(self):
        """Tokens containing digits are skipped."""
        for token in ("3am", "top10", "bge-large-v1.5"):
            assert _should_skip(token, set()) is True

    def test_camelcase_skipped(self):
        """CamelCase tokens are skipped."""
        for token in ("ChromaDB", "MemPalace"):
            assert _should_skip(token, set()) is True

    def test_allcaps_skipped(self):
        """All-caps tokens, with or without underscores, are skipped."""
        for token in ("NDCG", "MAX_RESULTS"):
            assert _should_skip(token, set()) is True

    def test_technical_skipped(self):
        """Hyphenated and underscored tokens are skipped."""
        for token in ("bge-large", "train_test"):
            assert _should_skip(token, set()) is True

    def test_url_skipped(self):
        """URL-like tokens are skipped."""
        for token in ("https://example.com", "www.google.com"):
            assert _should_skip(token, set()) is True

    def test_code_or_emoji_skipped(self):
        """Tokens wrapped in backticks or asterisks are skipped."""
        for token in ("`code`", "**bold**"):
            assert _should_skip(token, set()) is True

    def test_known_name_skipped(self):
        """A token present in the known-names set is skipped."""
        assert _should_skip("mempalace", {"mempalace"}) is True

    def test_normal_word_not_skipped(self):
        """Plain lowercase words are not skipped."""
        for token in ("hello", "question"):
            assert _should_skip(token, set()) is False
|
||||
|
||||
|
||||
# --- _edit_distance ---
|
||||
|
||||
|
||||
class TestEditDistance:
    """Checks _edit_distance() on a handful of known string pairs."""

    def test_identical(self):
        """Equal strings are at distance 0."""
        assert _edit_distance("hello", "hello") == 0

    def test_empty_strings(self):
        """Distance to or from the empty string equals the other string's length."""
        cases = [("", "abc", 3), ("abc", "", 3), ("", "", 0)]
        for left, right, expected in cases:
            assert _edit_distance(left, right) == expected

    def test_single_edit(self):
        """One substitution, insertion or deletion each count as distance 1."""
        assert _edit_distance("cat", "bat") == 1  # substitution
        assert _edit_distance("cat", "cats") == 1  # insertion
        assert _edit_distance("cats", "cat") == 1  # deletion

    def test_known_distance(self):
        """kitten -> sitting is the textbook distance-3 pair."""
        assert _edit_distance("kitten", "sitting") == 3
|
||||
|
||||
|
||||
# --- _get_system_words ---
|
||||
|
||||
|
||||
def test_get_system_words_returns_set():
    """_get_system_words() hands back a set."""
    assert isinstance(_get_system_words(), set)
|
||||
|
||||
|
||||
# --- spellcheck_user_text ---
|
||||
|
||||
|
||||
def test_spellcheck_user_text_passthrough_no_autocorrect():
    """With _get_speller() returning None the input text is returned unchanged."""
    original = "somee misspeledd textt"
    with patch("mempalace.spellcheck._get_speller", return_value=None):
        assert spellcheck_user_text(original) == original
|
||||
|
||||
|
||||
def test_spellcheck_user_text_with_speller():
    """With a speller patched in, its corrections appear in the output."""
    fixes = {"knoe": "know", "befor": "before"}

    def fake_speller(word):
        return fixes.get(word, word)

    speller_patch = patch("mempalace.spellcheck._get_speller", return_value=fake_speller)
    words_patch = patch("mempalace.spellcheck._get_system_words", return_value=set())
    names_patch = patch("mempalace.spellcheck._load_known_names", return_value=set())
    with speller_patch, words_patch, names_patch:
        corrected = spellcheck_user_text("knoe the question befor")
    assert "know" in corrected
    assert "before" in corrected
|
||||
|
||||
|
||||
def test_spellcheck_preserves_technical_terms():
    """CamelCase and hyphenated tokens survive a speller that rewrites every word."""

    def fake_speller(word):
        return "WRONG"

    speller_patch = patch("mempalace.spellcheck._get_speller", return_value=fake_speller)
    words_patch = patch("mempalace.spellcheck._get_system_words", return_value=set())
    with speller_patch, words_patch:
        corrected = spellcheck_user_text("ChromaDB bge-large", known_names=set())
    for term in ("ChromaDB", "bge-large"):
        assert term in corrected
    assert "WRONG" not in corrected
|
||||
|
||||
|
||||
# --- spellcheck_transcript_line ---
|
||||
|
||||
|
||||
def test_transcript_line_user_turn():
    """A line beginning with '>' is routed through spellcheck_user_text."""
    stub = patch("mempalace.spellcheck.spellcheck_user_text", return_value="corrected")
    with stub:
        processed = spellcheck_transcript_line("> hello world")
    assert "corrected" in processed
|
||||
|
||||
|
||||
def test_transcript_line_assistant_turn():
    """A line without a leading '>' comes back unchanged."""
    assistant_line = "This is an assistant response"
    assert spellcheck_transcript_line(assistant_line) == assistant_line
|
||||
|
||||
|
||||
def test_transcript_line_empty_user_turn():
    """A bare '> ' marker with no message text comes back unchanged."""
    bare_marker = "> "
    assert spellcheck_transcript_line(bare_marker) == bare_marker
|
||||
|
||||
|
||||
# --- spellcheck_transcript ---
|
||||
|
||||
|
||||
def test_spellcheck_transcript_processes_content():
    """In a three-line transcript only the '>' line is rewritten."""
    transcript = "Assistant line\n> user line\nAnother assistant line"
    stub = patch("mempalace.spellcheck.spellcheck_user_text", return_value="fixed")
    with stub:
        processed = spellcheck_transcript(transcript)
    first, second, third = processed.split("\n")[:3]
    assert first == "Assistant line"
    assert "fixed" in second
    assert third == "Another assistant line"
|
||||
@@ -0,0 +1,72 @@
|
||||
"""Extra spellcheck tests covering _load_known_names and speller edge cases."""
|
||||
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
from mempalace.spellcheck import (
|
||||
_load_known_names,
|
||||
spellcheck_user_text,
|
||||
)
|
||||
|
||||
|
||||
class TestLoadKnownNames:
    """Checks _load_known_names() with the entity registry patched out."""

    def test_returns_names_from_registry(self):
        """Canonical names and aliases from the registry come back lowercased."""
        fake_registry = MagicMock()
        fake_registry._data = {
            "entities": {
                "e1": {"canonical": "Alice", "aliases": ["ali"]},
                "e2": {"canonical": "Bob", "aliases": []},
            }
        }
        with patch("mempalace.entity_registry.EntityRegistry") as registry_cls:
            registry_cls.load.return_value = fake_registry
            known = _load_known_names()
        for expected in ("alice", "ali", "bob"):
            assert expected in known

    def test_returns_empty_on_exception(self):
        """If EntityRegistry.load raises, an empty set is returned."""
        failing_load = patch(
            "mempalace.entity_registry.EntityRegistry.load",
            side_effect=Exception("no registry"),
        )
        with failing_load:
            known = _load_known_names()
        assert known == set()
|
||||
|
||||
|
||||
class TestSpellerEdgeCases:
    """Checks cases where spellcheck_user_text() must ignore the speller's suggestion."""

    def test_capitalized_word_skipped(self):
        """A capitalized word is kept even when the speller rewrites every word."""

        def fake_speller(word):
            return "WRONG"

        speller_patch = patch("mempalace.spellcheck._get_speller", return_value=fake_speller)
        words_patch = patch("mempalace.spellcheck._get_system_words", return_value=set())
        names_patch = patch("mempalace.spellcheck._load_known_names", return_value=set())
        with speller_patch, words_patch, names_patch:
            corrected = spellcheck_user_text("Alice went home")
        assert "Alice" in corrected
        assert "WRONG" not in corrected

    def test_system_word_not_corrected(self):
        """A word found in the system word set is left alone."""

        def fake_speller(word):
            return "WRONG"

        speller_patch = patch("mempalace.spellcheck._get_speller", return_value=fake_speller)
        words_patch = patch(
            "mempalace.spellcheck._get_system_words", return_value={"coherently"}
        )
        names_patch = patch("mempalace.spellcheck._load_known_names", return_value=set())
        with speller_patch, words_patch, names_patch:
            corrected = spellcheck_user_text("coherently")
        assert "coherently" in corrected

    def test_high_edit_distance_rejected(self):
        """A suggestion far from the original word is discarded."""

        def fake_speller(word):
            return "completely_different_word"

        speller_patch = patch("mempalace.spellcheck._get_speller", return_value=fake_speller)
        words_patch = patch("mempalace.spellcheck._get_system_words", return_value=set())
        names_patch = patch("mempalace.spellcheck._load_known_names", return_value=set())
        with speller_patch, words_patch, names_patch:
            corrected = spellcheck_user_text("hello")
        assert "hello" in corrected
|
||||
@@ -3,6 +3,9 @@ import json
|
||||
from mempalace import split_mega_files as smf
|
||||
|
||||
|
||||
# ── Config loading ─────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_load_known_people_falls_back_when_config_missing(monkeypatch, tmp_path):
|
||||
monkeypatch.setattr(smf, "_KNOWN_NAMES_PATH", tmp_path / "missing.json")
|
||||
smf._KNOWN_NAMES_CACHE = None
|
||||
@@ -46,3 +49,244 @@ def test_extract_people_detects_names_from_content(monkeypatch):
|
||||
monkeypatch.setattr(smf, "KNOWN_PEOPLE", ["Alice", "Ben"])
|
||||
people = smf.extract_people(["> Alice reviewed the change with Ben\n"])
|
||||
assert people == ["Alice", "Ben"]
|
||||
|
||||
|
||||
# ── Config: force_reload and invalid JSON ──────────────────────────────
|
||||
|
||||
|
||||
def test_load_known_names_force_reload(monkeypatch, tmp_path):
    """force_reload=True re-reads the config file and replaces the cached value."""
    config_path = tmp_path / "known_names.json"
    config_path.write_text(json.dumps(["Alice"]))
    monkeypatch.setattr(smf, "_KNOWN_NAMES_PATH", config_path)
    # Reset the cache through monkeypatch so its previous value is restored
    # at teardown and this test's state does not leak into later tests.
    monkeypatch.setattr(smf, "_KNOWN_NAMES_CACHE", None)

    smf._load_known_names_config()
    assert smf._KNOWN_NAMES_CACHE == ["Alice"]

    config_path.write_text(json.dumps(["Bob"]))
    smf._load_known_names_config(force_reload=True)
    assert smf._KNOWN_NAMES_CACHE == ["Bob"]
|
||||
|
||||
|
||||
def test_load_known_names_invalid_json(monkeypatch, tmp_path):
    """A config file that is not valid JSON makes the loader return None."""
    config_path = tmp_path / "known_names.json"
    config_path.write_text("not json {{{")
    monkeypatch.setattr(smf, "_KNOWN_NAMES_PATH", config_path)
    # Reset the cache through monkeypatch so its previous value is restored
    # at teardown and this test's state does not leak into later tests.
    monkeypatch.setattr(smf, "_KNOWN_NAMES_CACHE", None)

    result = smf._load_known_names_config()
    assert result is None
|
||||
|
||||
|
||||
def test_load_known_names_caching(monkeypatch, tmp_path):
    """A second load without force_reload returns the first result, not the file's new content."""
    config_path = tmp_path / "known_names.json"
    config_path.write_text(json.dumps(["Alice"]))
    monkeypatch.setattr(smf, "_KNOWN_NAMES_PATH", config_path)
    # Reset the cache through monkeypatch so its previous value is restored
    # at teardown and this test's state does not leak into later tests.
    monkeypatch.setattr(smf, "_KNOWN_NAMES_CACHE", None)

    smf._load_known_names_config()
    # Second call returns cached value without re-reading
    config_path.write_text(json.dumps(["Changed"]))
    result = smf._load_known_names_config()
    assert result == ["Alice"]
|
||||
|
||||
|
||||
# ── is_true_session_start ──────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_is_true_session_start_yes():
    """A header followed by ordinary content counts as a true session start."""
    transcript = ["Claude Code v1.0", "Some content", "More content"] + [""] * 3
    assert smf.is_true_session_start(transcript, 0) is True
|
||||
|
||||
|
||||
def test_is_true_session_start_no_ctrl_e():
    """A header followed by the "Ctrl+E to show ..." line is not a true session start."""
    transcript = ["Claude Code v1.0", "Ctrl+E to show 5 previous messages"] + [""] * 4
    assert smf.is_true_session_start(transcript, 0) is False
|
||||
|
||||
|
||||
def test_is_true_session_start_no_previous_messages():
    """A header followed within two lines by "previous messages" is not a true start."""
    transcript = ["Claude Code v1.0", "Some text", "previous messages here"] + [""] * 3
    assert smf.is_true_session_start(transcript, 0) is False
|
||||
|
||||
|
||||
# ── find_session_boundaries ────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_find_session_boundaries_two_sessions():
    """Two headers separated by content and blank lines give boundaries [0, 7]."""
    padding = [""] * 5
    transcript = (
        ["Claude Code v1.0", "content 1"]
        + padding
        + ["Claude Code v1.0", "content 2"]
        + padding
    )
    assert smf.find_session_boundaries(transcript) == [0, 7]
|
||||
|
||||
|
||||
def test_find_session_boundaries_none():
    """Text without any session header has no boundaries."""
    transcript = ["Just some text", "No sessions here"]
    found = smf.find_session_boundaries(transcript)
    assert found == []
|
||||
|
||||
|
||||
def test_find_session_boundaries_context_restore_skipped():
    """A second header followed by the "Ctrl+E" line does not count as a boundary."""
    transcript = (
        ["Claude Code v1.0", "content"]
        + [""] * 5
        + ["Claude Code v1.0", "Ctrl+E to show 5 previous messages"]
        + [""] * 4
    )
    found = smf.find_session_boundaries(transcript)
    assert len(found) == 1
|
||||
|
||||
|
||||
# ── extract_timestamp ──────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_extract_timestamp_found():
    """A timestamp line yields the filename-style stamp and the ISO date."""
    stamp, iso_date = smf.extract_timestamp(["⏺ 2:30 PM Wednesday, March 25, 2026"])
    assert stamp == "2026-03-25_230PM"
    assert iso_date == "2026-03-25"
|
||||
|
||||
|
||||
def test_extract_timestamp_not_found():
    """Without a timestamp line both returned values are None."""
    stamp, iso_date = smf.extract_timestamp(["No timestamp here"])
    assert stamp is None
    assert iso_date is None
|
||||
|
||||
|
||||
def test_extract_timestamp_only_checks_first_50():
    """A timestamp that appears after 51 filler lines is not picked up."""
    filler = ["filler\n"] * 51
    late_stamp = ["⏺ 1:00 AM Monday, January 01, 2026"]
    stamp, _iso_date = smf.extract_timestamp(filler + late_stamp)
    assert stamp is None
|
||||
|
||||
|
||||
# ── extract_subject ────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_extract_subject_found():
    """A ``>`` prompt line supplies the subject text."""
    prompt_lines = ["> How do we handle authentication?"]
    lowered = smf.extract_subject(prompt_lines).lower()
    assert "authentication" in lowered
|
||||
|
||||
|
||||
def test_extract_subject_skips_commands():
    """Shell-command prompts are expected to be passed over for the real question."""
    prompt_lines = [
        "> cd /some/dir",
        "> git status",
        "> What is the plan?",
    ]
    lowered = smf.extract_subject(prompt_lines).lower()
    assert "plan" in lowered
|
||||
|
||||
|
||||
def test_extract_subject_fallback():
    """With no prompt lines at all, the subject falls back to "session"."""
    no_prompts = ["No prompts at all", "Just text"]
    assert smf.extract_subject(no_prompts) == "session"
|
||||
|
||||
|
||||
def test_extract_subject_short_prompt_skipped():
    """Very short prompts ("ok", "yes") are expected to be ignored."""
    prompt_lines = [
        "> ok",
        "> yes",
        "> What about the deployment strategy?",
    ]
    lowered = smf.extract_subject(prompt_lines).lower()
    assert "deployment" in lowered
|
||||
|
||||
|
||||
def test_extract_subject_truncated():
    """A 100-character prompt must come back as at most 60 characters."""
    long_prompt = "> " + "a" * 100
    result = smf.extract_subject([long_prompt])
    assert 60 >= len(result)
|
||||
|
||||
|
||||
# ── split_file ─────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def _make_mega_file(tmp_path, n_sessions=3, lines_per_session=15):
    """Write ``mega.txt`` under *tmp_path* with N concatenated sessions; return its path.

    Each session is a ``Claude Code v1.<i>`` header, one ``>`` prompt line,
    and ``lines_per_session - 2`` filler lines.
    """
    pieces = []
    for session in range(n_sessions):
        pieces.append(f"Claude Code v1.{session}\n")
        pieces.append(f"> What about topic {session} and how it works?\n")
        pieces.extend(
            f"Line {row} of session {session}\n"
            for row in range(lines_per_session - 2)
        )
    mega_path = tmp_path / "mega.txt"
    mega_path.write_text("".join(pieces))
    return mega_path
|
||||
|
||||
|
||||
def test_split_file_creates_output(tmp_path):
    """Splitting a multi-session file returns at least two paths, all present on disk."""
    source = _make_mega_file(tmp_path)
    target_dir = tmp_path / "output"
    target_dir.mkdir()
    outputs = smf.split_file(str(source), str(target_dir))
    assert len(outputs) >= 2
    assert all(item.exists() for item in outputs)
|
||||
|
||||
|
||||
def test_split_file_dry_run(tmp_path):
    """With ``dry_run=True`` the planned paths are returned but none is created."""
    source = _make_mega_file(tmp_path)
    target_dir = tmp_path / "output"
    target_dir.mkdir()
    planned = smf.split_file(str(source), str(target_dir), dry_run=True)
    assert len(planned) >= 2
    assert not any(item.exists() for item in planned)
|
||||
|
||||
|
||||
def test_split_file_not_mega(tmp_path):
    """A file holding a single session is left alone: nothing is written."""
    single = tmp_path / "single.txt"
    header = "Claude Code v1.0\nJust one session\n"
    filler = "line\n" * 20
    single.write_text(header + filler)
    outputs = smf.split_file(str(single), str(tmp_path))
    assert outputs == []
|
||||
|
||||
|
||||
def test_split_file_output_dir_none(tmp_path):
    """Passing ``None`` as the output dir places results beside the source file."""
    source = _make_mega_file(tmp_path)
    outputs = smf.split_file(str(source), None)
    assert len(outputs) >= 2
    expected_parent = str(tmp_path)
    assert all(str(item.parent) == expected_parent for item in outputs)
|
||||
|
||||
|
||||
def test_split_file_tiny_fragments_skipped(tmp_path):
    """Tiny chunks (< 10 lines) are skipped."""
    tiny_sessions = "Claude Code v1.0\nline\n" * 2
    big_session = "Claude Code v1.0\n" + "line\n" * 20
    source = tmp_path / "tiny.txt"
    source.write_text(tiny_sessions + big_session)
    outputs = smf.split_file(str(source), str(tmp_path))
    # The leading two-line sessions are the tiny fragments under test.
    # NOTE(review): this only checks that whatever was written is non-empty;
    # it does not itself prove the tiny chunks were omitted — consider
    # asserting on the number of outputs once the threshold is confirmed.
    assert all(item.stat().st_size > 0 for item in outputs)
|
||||
|
||||
Reference in New Issue
Block a user