Merge branch 'main' into fix/chromadb-version-constraint

This commit is contained in:
Ben Sigman
2026-04-12 14:29:43 -07:00
committed by GitHub
23 changed files with 2027 additions and 56 deletions
+13
View File
@@ -0,0 +1,13 @@
# Default owners for everything
* @milla-jovovich @bensig @igorls
# Core library
mempalace/ @milla-jovovich @bensig
# CI and workflows
.github/ @bensig
# Plugins and integrations
.claude-plugin/ @bensig
.codex-plugin/ @bensig
integrations/ @bensig
+12
View File
@@ -0,0 +1,12 @@
version: 2
updates:
- package-ecosystem: "pip"
directory: "/"
schedule:
interval: "weekly"
open-pull-requests-limit: 5
- package-ecosystem: "github-actions"
directory: "/"
schedule:
interval: "weekly"
open-pull-requests-limit: 3
+3 -3
View File
@@ -18,7 +18,7 @@ jobs:
with:
python-version: ${{ matrix.python-version }}
- run: pip install -e ".[dev]"
- run: python -m pytest tests/ -v --ignore=tests/benchmarks --cov=mempalace --cov-report=term-missing --cov-fail-under=80
- run: python -m pytest tests/ -v --ignore=tests/benchmarks --cov=mempalace --cov-report=term-missing --cov-fail-under=80 --durations=10
test-windows:
runs-on: windows-latest
@@ -28,7 +28,7 @@ jobs:
with:
python-version: "3.9"
- run: pip install -e ".[dev]"
- run: python -m pytest tests/ -v --ignore=tests/benchmarks --cov=mempalace --cov-report=term-missing --cov-fail-under=80
- run: python -m pytest tests/ -v --ignore=tests/benchmarks --cov=mempalace --cov-report=term-missing --cov-fail-under=80 --durations=10
test-macos:
runs-on: macos-latest
@@ -38,7 +38,7 @@ jobs:
with:
python-version: "3.9"
- run: pip install -e ".[dev]"
- run: python -m pytest tests/ -v --ignore=tests/benchmarks --cov=mempalace --cov-report=term-missing --cov-fail-under=80
- run: python -m pytest tests/ -v --ignore=tests/benchmarks --cov=mempalace --cov-report=term-missing --cov-fail-under=80 --durations=10
lint:
runs-on: ubuntu-latest
steps:
+27
View File
@@ -6,3 +6,30 @@ __pycache__/
.pytest_cache/
mempal.yaml
.a5c/
# Environment
.env
.env.*
# OS
.DS_Store
Thumbs.db
# IDEs
.idea/
.vscode/
*.swp
*.swo
*~
# Coverage
htmlcov/
.coverage
coverage.xml
# Virtual environments
.venv/
venv/
# ChromaDB local data
*.sqlite3-journal
+78
View File
@@ -0,0 +1,78 @@
# AGENTS.md
> How to build, test, and contribute to MemPalace.
## Setup
```bash
pip install -e ".[dev]"
```
## Commands
```bash
# Run tests
python -m pytest tests/ -v --ignore=tests/benchmarks
# Run tests with coverage
python -m pytest tests/ -v --ignore=tests/benchmarks --cov=mempalace --cov-report=term-missing
# Lint
ruff check .
# Format
ruff format .
# Format check (CI mode)
ruff format --check .
```
## Project structure
```
mempalace/
├── mcp_server.py # MCP server — all read/write tools
├── miner.py # Project file miner
├── convo_miner.py # Conversation transcript miner
├── searcher.py # Semantic search
├── knowledge_graph.py # Temporal entity-relationship graph (SQLite)
├── palace.py # Shared palace operations (ChromaDB access)
├── config.py # Configuration + input validation
├── normalize.py # Transcript format detection + normalization
├── cli.py # CLI dispatcher
├── dialect.py # AAAK compression dialect
├── palace_graph.py # Room traversal + cross-wing tunnels
├── hooks_cli.py # Hook system for auto-save
└── version.py # Single source of truth for version
```
## Conventions
- **Python style**: snake_case for functions/variables, PascalCase for classes
- **Linter**: ruff with E/F/W rules
- **Formatter**: ruff format, double quotes
- **Commits**: conventional commits (`fix:`, `feat:`, `test:`, `docs:`, `ci:`)
- **Tests**: `tests/test_*.py`, fixtures in `tests/conftest.py`
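- **Coverage**: CI enforces an 80% threshold (`--cov-fail-under=80`) on every test job, including Windows and macOS; ChromaDB file lock cleanup on Windows is what limits the achievable coverage there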
## Architecture
```
User → CLI / MCP Server → ChromaDB (vector store) + SQLite (knowledge graph)
Palace structure:
WING (person/project)
└── ROOM (topic)
└── DRAWER (verbatim text chunk)
Knowledge Graph:
ENTITY → PREDICATE → ENTITY (with valid_from / valid_to dates)
```
## Key files for common tasks
- **Adding an MCP tool**: `mempalace/mcp_server.py` — add handler function + TOOLS dict entry
- **Changing search**: `mempalace/searcher.py`
- **Modifying mining**: `mempalace/miner.py` (project files) or `mempalace/convo_miner.py` (transcripts)
- **Input validation**: `mempalace/config.py` — `sanitize_name()` / `sanitize_content()`
- **Tests**: mirror source structure in `tests/test_<module>.py`
+4 -1
View File
@@ -5,8 +5,11 @@ Thanks for wanting to help. MemPalace is open source and we welcome contribution
## Getting Started
```bash
git clone https://github.com/milla-jovovich/mempalace.git
# Fork the repo on GitHub first, then clone your fork
git clone https://github.com/<your-username>/mempalace.git
cd mempalace
git remote add upstream https://github.com/milla-jovovich/mempalace.git
pip install -e ".[dev]" # installs with dev dependencies (pytest, build, twine)
```
+12
View File
@@ -84,6 +84,18 @@ Other memory systems try to fix this by letting AI decide what's worth rememberi
---
## An important follow-up note regarding fake MemPalace websites - April 11, 2026
Several community members (#267, #326, #506) have pointed out that fake MemPalace websites are popping up, including some that distribute malware.
To be super clear, MemPalace *has no website* (at least for now), so anything claiming to be one is false.
Thanks to our Community Members for letting us know about the problem.
Stay safe out there.
---
## Quick Start
```bash
+36
View File
@@ -0,0 +1,36 @@
-- MemPalace Knowledge Graph Schema
-- SQLite database at ~/.mempalace/knowledge_graph.db
--
-- Three tables: `entities` (graph nodes), `triples` (subject/predicate/object
-- edges with an optional validity window) and `attributes` (per-entity
-- key/value pairs with an optional validity window).
-- Every statement uses IF NOT EXISTS, so the script can be re-run safely.

-- Graph nodes. `properties` is stored as TEXT and defaults to the literal '{}'
-- (presumably a JSON object — confirm against the code that writes it).
CREATE TABLE IF NOT EXISTS entities (
id TEXT PRIMARY KEY,
name TEXT NOT NULL,
type TEXT DEFAULT 'unknown',
properties TEXT DEFAULT '{}'
);
-- Graph edges. `valid_from` / `valid_to` are nullable TEXT columns.
-- NOTE(review): `subject` / `object` are not declared as foreign keys to
-- entities(id); referential integrity is not enforced by this schema.
CREATE TABLE IF NOT EXISTS triples (
id TEXT PRIMARY KEY,
subject TEXT NOT NULL,
predicate TEXT NOT NULL,
object TEXT NOT NULL,
valid_from TEXT,
valid_to TEXT,
confidence REAL DEFAULT 1.0,
source_closet TEXT,
source_file TEXT
);
-- Time-scoped key/value attributes of an entity.
-- NOTE(review): `valid_from` is nullable but part of the primary key. In an
-- ordinary SQLite rowid table NULLs in a PRIMARY KEY are treated as distinct,
-- so several rows with the same (entity_id, key) and NULL valid_from would be
-- accepted — confirm whether that is intended.
CREATE TABLE IF NOT EXISTS attributes (
entity_id TEXT NOT NULL,
key TEXT NOT NULL,
value TEXT,
valid_from TEXT,
valid_to TEXT,
PRIMARY KEY (entity_id, key, valid_from)
);
-- Indexes
-- Lookups on triples by subject, object, predicate, and validity window.
CREATE INDEX IF NOT EXISTS idx_triples_subject ON triples(subject);
CREATE INDEX IF NOT EXISTS idx_triples_object ON triples(object);
CREATE INDEX IF NOT EXISTS idx_triples_predicate ON triples(predicate);
CREATE INDEX IF NOT EXISTS idx_triples_valid ON triples(valid_from, valid_to);
+1 -1
View File
@@ -72,6 +72,6 @@ fi
cat << 'HOOKJSON'
{
"decision": "block",
"reason": "COMPACTION IMMINENT. Save ALL topics, decisions, quotes, code, and important context from this session to your memory system. Be thorough — after compaction, detailed context will be lost. Organize into appropriate categories. Use verbatim quotes where possible. Save everything, then allow compaction to proceed."
"reason": "COMPACTION IMMINENT (MemPalace). Save ALL session content before context is lost:\n1. mempalace_diary_write — thorough AAAK-compressed session summary\n2. mempalace_add_drawer — ALL verbatim quotes, decisions, code, context\n3. mempalace_kg_add — entity relationships (optional)\nBe thorough — after compaction, detailed context will be lost. Do NOT write to Claude Code's native auto-memory (.md files). Save everything to MemPalace, then allow compaction to proceed."
}
HOOKJSON
+1 -1
View File
@@ -145,7 +145,7 @@ if [ "$SINCE_LAST" -ge "$SAVE_INTERVAL" ] && [ "$EXCHANGE_COUNT" -gt 0 ]; then
cat << 'HOOKJSON'
{
"decision": "block",
"reason": "AUTO-SAVE checkpoint. Save key topics, decisions, quotes, and code from this session to your memory system. Organize into appropriate categories. Use verbatim quotes where possible. Continue conversation after saving."
"reason": "AUTO-SAVE checkpoint (MemPalace). Save this session's key content:\n1. mempalace_diary_write — AAAK-compressed session summary\n2. mempalace_add_drawer — verbatim quotes, decisions, code snippets\n3. mempalace_kg_add — entity relationships (optional)\nDo NOT write to Claude Code's native auto-memory (.md files). Continue conversation after saving."
}
HOOKJSON
else
+20
View File
@@ -150,6 +150,14 @@ def cmd_split(args):
sys.argv = old_argv
def cmd_migrate(args):
"""Migrate palace from a different ChromaDB version."""
from .migrate import migrate
palace_path = os.path.expanduser(args.palace) if args.palace else MempalaceConfig().palace_path
migrate(palace_path=palace_path, dry_run=args.dry_run)
def cmd_status(args):
from .miner import status
@@ -531,6 +539,17 @@ def main():
)
# status
# migrate
p_migrate = sub.add_parser(
"migrate",
help="Migrate palace from a different ChromaDB version (fixes 3.0.0 → 3.1.0 upgrade)",
)
p_migrate.add_argument(
"--dry-run",
action="store_true",
help="Show what would be migrated without changing anything",
)
sub.add_parser("status", help="Show what's been filed")
args = parser.parse_args()
@@ -565,6 +584,7 @@ def main():
"compress": cmd_compress,
"wake-up": cmd_wakeup,
"repair": cmd_repair,
"migrate": cmd_migrate,
"status": cmd_status,
}
dispatch[args.command](args)
+1 -1
View File
@@ -334,7 +334,7 @@ def mine_convos(
room_counts[chunk_room] += 1
drawer_id = f"drawer_{wing}_{chunk_room}_{hashlib.sha256((source_file + str(chunk['chunk_index'])).encode()).hexdigest()[:24]}"
try:
collection.add(
collection.upsert(
documents=[chunk["content"]],
ids=[drawer_id],
metadatas=[
+239
View File
@@ -0,0 +1,239 @@
"""
dedup.py — Detect and remove near-duplicate drawers
====================================================
When the same files are mined multiple times, near-identical drawers
accumulate. This module finds drawers from the same source_file that
are too similar (cosine distance < threshold), keeps the longest/richest
version, and deletes the rest.
No API calls — uses ChromaDB's built-in embedding similarity.
Usage (standalone):
python -m mempalace.dedup # dedup all
python -m mempalace.dedup --dry-run # preview only
python -m mempalace.dedup --threshold 0.10 # stricter (near-identical only)
python -m mempalace.dedup --threshold 0.35 # looser (catches paraphrased content)
python -m mempalace.dedup --wing my_project # scope to one wing
python -m mempalace.dedup --stats # stats only
python -m mempalace.dedup --source "my_project" # filter by source
Usage (from CLI):
mempalace dedup [--dry-run] [--threshold 0.15] [--stats]
"""
import argparse
import os
import time
from collections import defaultdict
import chromadb
# Name of the ChromaDB collection that show_stats() and dedup_palace() open.
COLLECTION_NAME = "mempalace_drawers"
# Cosine DISTANCE threshold (not similarity). Lower = stricter.
# 0.15 = ~85% cosine similarity — catches near-identical chunks.
# For looser dedup of paraphrased content, try 0.3-0.4.
DEFAULT_THRESHOLD = 0.15
# Default for get_source_groups(min_count=...): source_file groups with fewer
# drawers than this are not returned, and therefore never deduplicated.
MIN_DRAWERS_TO_CHECK = 5
def _get_palace_path():
"""Resolve palace path from config."""
try:
from .config import MempalaceConfig
return MempalaceConfig().palace_path
except Exception:
return os.path.join(os.path.expanduser("~"), ".mempalace", "palace")
def get_source_groups(col, min_count=MIN_DRAWERS_TO_CHECK, source_pattern=None, wing=None):
    """Group drawer IDs by their ``source_file`` metadata.

    Args:
        col: ChromaDB collection (only ``count()`` and ``get()`` are used).
        min_count: Groups with fewer than this many drawers are dropped.
        source_pattern: Optional case-insensitive substring; only sources
            containing it are grouped.
        wing: Optional wing name. When given, only drawers whose ``wing``
            metadata equals it are fetched and grouped.

    Returns:
        dict mapping source_file -> list of drawer IDs, restricted to groups
        with at least ``min_count`` entries. Drawers without a ``source_file``
        (or without any metadata) are grouped under ``"unknown"``.
    """
    total = col.count()
    groups = defaultdict(list)
    needle = source_pattern.lower() if source_pattern else None
    offset = 0
    batch_size = 1000
    while offset < total:
        kwargs = {"limit": batch_size, "offset": offset, "include": ["metadatas"]}
        if wing:
            kwargs["where"] = {"wing": wing}
        batch = col.get(**kwargs)
        batch_ids = batch["ids"]
        if not batch_ids:
            break
        for did, meta in zip(batch_ids, batch["metadatas"]):
            # A drawer stored without metadata comes back as None, not {}.
            src = (meta or {}).get("source_file", "unknown")
            if needle and needle not in src.lower():
                continue
            groups[src].append(did)
        offset += len(batch_ids)
    return {src: ids for src, ids in groups.items() if len(ids) >= min_count}
def dedup_source_group(col, drawer_ids, threshold=DEFAULT_THRESHOLD, dry_run=True):
    """Dedup drawers within one source_file group.

    Greedy: drawers are visited longest document first; a drawer is kept unless
    a nearest-neighbour query finds an already-kept drawer closer than
    ``threshold``. Drawers with an empty document or fewer than 20 characters
    are always marked for deletion.

    Args:
        col: ChromaDB collection holding the drawers.
        drawer_ids: IDs of the drawers that share one source_file.
        threshold: Distance below which a drawer counts as a duplicate.
        dry_run: When True nothing is deleted; the lists are still returned.

    Returns:
        (kept_ids, deleted_ids) — two lists of drawer IDs.
    """
    data = col.get(ids=drawer_ids, include=["documents", "metadatas"])
    items = sorted(
        zip(data["ids"], data["documents"]),
        key=lambda pair: len(pair[1] or ""),
        reverse=True,
    )

    kept_ids = []
    kept_lookup = set()  # maintained incrementally instead of rebuilt per drawer
    to_delete = []

    for did, doc in items:
        if not doc or len(doc) < 20:
            to_delete.append(did)
            continue

        is_dup = False
        if kept_ids:
            try:
                # The queried document is itself stored in the collection, so it
                # comes back as its own nearest neighbour. Ask for one extra
                # result so that slot is not taken away from the kept drawers.
                # NOTE(review): the query still searches the whole collection;
                # neighbours outside this group can crowd out kept drawers.
                results = col.query(
                    query_texts=[doc],
                    n_results=min(len(kept_ids), 5) + 1,
                    include=["distances"],
                )
                id_rows = results["ids"] or [[]]
                dist_rows = results["distances"] or [[]]
                for rid, dist in zip(id_rows[0], dist_rows[0]):
                    if rid == did:
                        continue
                    if rid in kept_lookup and dist < threshold:
                        is_dup = True
                        break
            except Exception:
                # Deliberately conservative: if the similarity lookup fails,
                # keep the drawer rather than risk deleting unique content.
                is_dup = False

        if is_dup:
            to_delete.append(did)
        else:
            kept_ids.append(did)
            kept_lookup.add(did)

    if to_delete and not dry_run:
        # Delete in chunks to keep each request to the store small.
        for start in range(0, len(to_delete), 500):
            col.delete(ids=to_delete[start : start + 500])

    return kept_ids, to_delete
def show_stats(palace_path=None):
    """Print duplication statistics for the palace; nothing is modified."""
    resolved_path = palace_path or _get_palace_path()
    collection = chromadb.PersistentClient(path=resolved_path).get_collection(COLLECTION_NAME)

    by_source = get_source_groups(collection)
    sizes = [len(ids) for ids in by_source.values()]

    print(f"\n Sources with {MIN_DRAWERS_TO_CHECK}+ drawers: {len(by_source)}")
    print(f" Total drawers in those sources: {sum(sizes):,}")

    print("\n Top 15 by drawer count:")
    ranked = sorted(by_source.items(), key=lambda pair: len(pair[1]), reverse=True)
    for source, ids in ranked[:15]:
        print(f" {len(ids):4d} {source[:65]}")

    # Rough heuristic: assume 40% of every group larger than 20 is duplicated.
    guess = 0
    for size in sizes:
        if size > 20:
            guess += int(size * 0.4)
    print(f"\n Estimated duplicates (groups > 20): ~{guess:,}")
def dedup_palace(
    palace_path=None,
    threshold=DEFAULT_THRESHOLD,
    dry_run=True,
    source_pattern=None,
    min_count=MIN_DRAWERS_TO_CHECK,
    wing=None,
):
    """Main entry point: deduplicate near-identical drawers across the palace.

    Args:
        palace_path: Palace directory; resolved via _get_palace_path() if None.
        threshold: Distance threshold passed to dedup_source_group().
        dry_run: When True, report what would be removed without deleting.
        source_pattern: Optional substring filter on source_file.
        min_count: Minimum drawers per source_file for a group to be checked.
        wing: Optional wing name to restrict the scan to.

    Prints a progress report; returns None.
    """
    palace_path = palace_path or _get_palace_path()
    # Thin rule between report sections (was an empty string repeated 55 times,
    # which printed a blank line).
    rule = "─" * 55

    print(f"\n{'=' * 55}")
    print(" MemPalace Deduplicator")
    print(f"{'=' * 55}")

    client = chromadb.PersistentClient(path=palace_path)
    col = client.get_collection(COLLECTION_NAME)

    print(f" Palace: {palace_path}")
    print(f" Drawers: {col.count():,}")
    print(f" Threshold: {threshold}")
    print(f" Mode: {'DRY RUN' if dry_run else 'LIVE'}")
    print(rule)
    if wing:
        print(f" Wing: {wing}")

    groups = get_source_groups(col, min_count, source_pattern, wing=wing)
    print(f"\n Sources to check: {len(groups)}")

    t0 = time.time()
    total_kept = 0
    total_deleted = 0

    # Largest groups first.
    sorted_groups = sorted(groups.items(), key=lambda x: len(x[1]), reverse=True)
    for i, (src, drawer_ids) in enumerate(sorted_groups):
        kept, deleted = dedup_source_group(col, drawer_ids, threshold, dry_run)
        total_kept += len(kept)
        total_deleted += len(deleted)
        if deleted:
            # "before → after" counts; previously the two numbers were fused.
            print(
                f" [{i + 1:3d}/{len(groups)}] "
                f"{src[:50]:50s} {len(drawer_ids):4d} → {len(kept):4d} "
                f"(-{len(deleted)})"
            )

    elapsed = time.time() - t0
    print(f"\n{rule}")
    print(f" Done in {elapsed:.1f}s")
    print(
        f" Drawers: {total_kept + total_deleted:,} → {total_kept:,} (-{total_deleted:,} removed)"
    )
    print(f" Palace after: {col.count():,} drawers")
    if dry_run:
        print("\n [DRY RUN] No changes written. Re-run without --dry-run to apply.")
    print(f"{'=' * 55}\n")
if __name__ == "__main__":
    # Standalone entry point: python -m mempalace.dedup [options]
    cli = argparse.ArgumentParser(description="Deduplicate near-identical drawers")
    cli.add_argument("--palace", default=None, help="Palace directory path")
    cli.add_argument(
        "--threshold",
        type=float,
        default=DEFAULT_THRESHOLD,
        help=f"Cosine distance threshold (default: {DEFAULT_THRESHOLD})",
    )
    cli.add_argument("--dry-run", action="store_true", help="Preview without deleting")
    cli.add_argument("--stats", action="store_true", help="Show stats only")
    cli.add_argument("--wing", default=None, help="Scope dedup to a single wing")
    cli.add_argument("--source", default=None, help="Filter by source file pattern")
    opts = cli.parse_args()

    # None lets the callee resolve the palace path from the config.
    palace_dir = None
    if opts.palace:
        palace_dir = os.path.expanduser(opts.palace)

    if not opts.stats:
        dedup_palace(
            palace_path=palace_dir,
            threshold=opts.threshold,
            dry_run=opts.dry_run,
            source_pattern=opts.source,
            wing=opts.wing,
        )
    else:
        show_stats(palace_path=palace_dir)
+21 -9
View File
@@ -18,18 +18,22 @@ SAVE_INTERVAL = 15
STATE_DIR = Path.home() / ".mempalace" / "hook_state"
STOP_BLOCK_REASON = (
"AUTO-SAVE checkpoint. Save key topics, decisions, quotes, and code "
"from this session to your memory system. Organize into appropriate "
"categories. Use verbatim quotes where possible. Continue conversation "
"after saving."
"AUTO-SAVE checkpoint (MemPalace). Save this session's key content:\n"
"1. mempalace_diary_write — AAAK-compressed session summary\n"
"2. mempalace_add_drawer — verbatim quotes, decisions, code snippets\n"
"3. mempalace_kg_add — entity relationships (optional)\n"
"Do NOT write to Claude Code's native auto-memory (.md files). "
"Continue conversation after saving."
)
PRECOMPACT_BLOCK_REASON = (
"COMPACTION IMMINENT. Save ALL topics, decisions, quotes, code, and "
"important context from this session to your memory system. Be thorough "
"\u2014 after compaction, detailed context will be lost. Organize into "
"appropriate categories. Use verbatim quotes where possible. Save "
"everything, then allow compaction to proceed."
"COMPACTION IMMINENT (MemPalace). Save ALL session content before context is lost:\n"
"1. mempalace_diary_write — thorough AAAK-compressed session summary\n"
"2. mempalace_add_drawer — ALL verbatim quotes, decisions, code, context\n"
"3. mempalace_kg_add — entity relationships (optional)\n"
"Be thorough \u2014 after compaction, detailed context will be lost. "
"Do NOT write to Claude Code's native auto-memory (.md files). "
"Save everything to MemPalace, then allow compaction to proceed."
)
@@ -63,6 +67,14 @@ def _count_human_messages(transcript_path: str) -> int:
if "<command-message>" in text:
continue
count += 1
# Also handle Codex CLI transcript format
# {"type": "event_msg", "payload": {"type": "user_message", "message": "..."}}
elif entry.get("type") == "event_msg":
payload = entry.get("payload", {})
if isinstance(payload, dict) and payload.get("type") == "user_message":
msg_text = payload.get("message", "")
if isinstance(msg_text, str) and "<command-message>" not in msg_text:
count += 1
except (json.JSONDecodeError, AttributeError):
pass
except OSError:
+124 -39
View File
@@ -28,6 +28,7 @@ from pathlib import Path
from .config import MempalaceConfig, sanitize_name, sanitize_content
from .version import __version__
from .query_sanitizer import sanitize_query
from .searcher import search_memories
from .palace_graph import traverse, find_tunnels, graph_stats
import chromadb
@@ -143,16 +144,25 @@ def tool_status():
count = col.count()
wings = {}
rooms = {}
try:
all_meta = col.get(include=["metadatas"], limit=10000)["metadatas"]
for m in all_meta:
w = m.get("wing", "unknown")
r = m.get("room", "unknown")
wings[w] = wings.get(w, 0) + 1
rooms[r] = rooms.get(r, 0) + 1
except Exception:
pass
return {
batch_size = 5000
offset = 0
error_info = None
while True:
try:
batch = col.get(include=["metadatas"], limit=batch_size, offset=offset)
rows = batch["metadatas"]
for m in rows:
w = m.get("wing", "unknown")
r = m.get("room", "unknown")
wings[w] = wings.get(w, 0) + 1
rooms[r] = rooms.get(r, 0) + 1
offset += len(rows)
if len(rows) < batch_size:
break
except Exception as e:
error_info = f"Partial result, failed at offset {offset}: {str(e)}"
break
result = {
"total_drawers": count,
"wings": wings,
"rooms": rooms,
@@ -160,6 +170,10 @@ def tool_status():
"protocol": PALACE_PROTOCOL,
"aaak_dialect": AAAK_SPEC,
}
if error_info:
result["error"] = error_info
result["partial"] = True
return result
# ── AAAK Dialect Spec ─────────────────────────────────────────────────────────
@@ -200,13 +214,28 @@ def tool_list_wings():
if not col:
return _no_palace()
wings = {}
batch_size = 5000
offset = 0
try:
all_meta = col.get(include=["metadatas"], limit=10000)["metadatas"]
for m in all_meta:
w = m.get("wing", "unknown")
wings[w] = wings.get(w, 0) + 1
except Exception:
pass
col.count() # verify collection is accessible
except Exception as e:
return {"wings": {}, "error": str(e)}
while True:
try:
batch = col.get(include=["metadatas"], limit=batch_size, offset=offset)
rows = batch["metadatas"]
for m in rows:
w = m.get("wing", "unknown")
wings[w] = wings.get(w, 0) + 1
offset += len(rows)
if len(rows) < batch_size:
break
except Exception as e:
return {
"wings": wings,
"error": f"Partial result, failed at offset {offset}: {str(e)}",
"partial": True,
}
return {"wings": wings}
@@ -215,16 +244,33 @@ def tool_list_rooms(wing: str = None):
if not col:
return _no_palace()
rooms = {}
batch_size = 5000
offset = 0
where = {"wing": wing} if wing else None
try:
kwargs = {"include": ["metadatas"], "limit": 10000}
if wing:
kwargs["where"] = {"wing": wing}
all_meta = col.get(**kwargs)["metadatas"]
for m in all_meta:
r = m.get("room", "unknown")
rooms[r] = rooms.get(r, 0) + 1
except Exception:
pass
col.count() # verify collection is accessible
except Exception as e:
return {"wing": wing or "all", "rooms": {}, "error": str(e)}
while True:
try:
kwargs = {"include": ["metadatas"], "limit": batch_size, "offset": offset}
if where:
kwargs["where"] = where
batch = col.get(**kwargs)
rows = batch["metadatas"]
for m in rows:
r = m.get("room", "unknown")
rooms[r] = rooms.get(r, 0) + 1
offset += len(rows)
if len(rows) < batch_size:
break
except Exception as e:
return {
"wing": wing or "all",
"rooms": rooms,
"error": f"Partial result, failed at offset {offset}: {str(e)}",
"partial": True,
}
return {"wing": wing or "all", "rooms": rooms}
@@ -233,27 +279,58 @@ def tool_get_taxonomy():
if not col:
return _no_palace()
taxonomy = {}
batch_size = 5000
offset = 0
try:
all_meta = col.get(include=["metadatas"], limit=10000)["metadatas"]
for m in all_meta:
w = m.get("wing", "unknown")
r = m.get("room", "unknown")
if w not in taxonomy:
taxonomy[w] = {}
taxonomy[w][r] = taxonomy[w].get(r, 0) + 1
except Exception:
pass
col.count() # verify collection is accessible
except Exception as e:
return {"taxonomy": {}, "error": str(e)}
while True:
try:
batch = col.get(include=["metadatas"], limit=batch_size, offset=offset)
rows = batch["metadatas"]
for m in rows:
w = m.get("wing", "unknown")
r = m.get("room", "unknown")
if w not in taxonomy:
taxonomy[w] = {}
taxonomy[w][r] = taxonomy[w].get(r, 0) + 1
offset += len(rows)
if len(rows) < batch_size:
break
except Exception as e:
return {
"taxonomy": taxonomy,
"error": f"Partial result, failed at offset {offset}: {str(e)}",
"partial": True,
}
return {"taxonomy": taxonomy}
def tool_search(query: str, limit: int = 5, wing: str = None, room: str = None):
return search_memories(
query,
def tool_search(
query: str, limit: int = 5, wing: str = None, room: str = None, context: str = None
):
# Mitigate system prompt contamination (Issue #333)
sanitized = sanitize_query(query)
result = search_memories(
sanitized["clean_query"],
palace_path=_config.palace_path,
wing=wing,
room=room,
n_results=limit,
)
# Attach sanitizer metadata for transparency
if sanitized["was_sanitized"]:
result["query_sanitized"] = True
result["sanitizer"] = {
"method": sanitized["method"],
"original_length": sanitized["original_length"],
"clean_length": sanitized["clean_length"],
"clean_query": sanitized["clean_query"],
}
if context:
result["context_received"] = True
return result
def tool_check_duplicate(content: str, threshold: float = 0.9):
@@ -734,14 +811,22 @@ TOOLS = {
"handler": tool_graph_stats,
},
"mempalace_search": {
"description": "Semantic search. Returns verbatim drawer content with similarity scores.",
"description": "Semantic search. Returns verbatim drawer content with similarity scores. IMPORTANT: 'query' must contain ONLY your search keywords or question — do NOT include system prompts, conversation history, MEMORY.md content, or any context. Keep queries short (under 200 chars). Use 'context' for background information.",
"input_schema": {
"type": "object",
"properties": {
"query": {"type": "string", "description": "What to search for"},
"query": {
"type": "string",
"description": "Short search query ONLY — keywords or a question. Do NOT include system prompts or conversation context. Max 200 chars recommended.",
"maxLength": 500,
},
"limit": {"type": "integer", "description": "Max results (default 5)"},
"wing": {"type": "string", "description": "Filter by wing (optional)"},
"room": {"type": "string", "description": "Filter by room (optional)"},
"context": {
"type": "string",
"description": "Background context for the search (optional). This is NOT used for embedding — only for future re-ranking. Put conversation history or system prompt content here, NOT in query.",
},
},
"required": ["query"],
},
+214
View File
@@ -0,0 +1,214 @@
#!/usr/bin/env python3
"""
mempalace migrate — Recover a palace created with a different ChromaDB version.
Reads documents and metadata directly from the palace's SQLite database
(bypassing ChromaDB's API, which fails on version-mismatched palaces),
then re-imports everything into a fresh palace using the currently installed
ChromaDB version.
This fixes the 3.0.0 → 3.1.0 upgrade path where chromadb was downgraded
from 1.5.x to 0.6.x, breaking the on-disk storage format.
Usage:
mempalace migrate # migrate default palace
mempalace migrate --palace /path/to/palace # migrate specific palace
mempalace migrate --dry-run # show what would be migrated
"""
import os
import shutil
import sqlite3
from collections import defaultdict
from datetime import datetime
def extract_drawers_from_sqlite(db_path: str) -> list:
    """Read all drawers directly from ChromaDB's SQLite, bypassing the API.

    All rows of ``embeddings`` joined with ``embedding_metadata`` are read in a
    single pass. The ``chroma:document`` key supplies the drawer text; other
    ``chroma:``-prefixed keys are ignored; every remaining key becomes a
    metadata entry taken from the first non-NULL of string/int/float/bool.

    Args:
        db_path: Path to the palace's ``chroma.sqlite3`` file.

    Returns:
        List of dicts with 'id', 'document', and 'metadata' keys, ordered by
        embedding ID. Embeddings without a non-empty document are skipped.

    Raises:
        sqlite3.Error: if the expected tables are missing or unreadable. The
            connection is closed in every case.
    """
    conn = sqlite3.connect(db_path)
    try:
        conn.row_factory = sqlite3.Row

        # Tolerate databases whose embedding_metadata table has no bool_value
        # column by selecting NULL in its place.
        meta_columns = {info[1] for info in conn.execute("PRAGMA table_info(embedding_metadata)")}
        # Only ever one of two fixed literals, so the f-string is injection-safe.
        bool_expr = "em.bool_value" if "bool_value" in meta_columns else "NULL"

        cursor = conn.execute(
            f"""
            SELECT e.embedding_id,
                   em.key,
                   em.string_value,
                   em.int_value,
                   em.float_value,
                   {bool_expr} AS bool_value
            FROM embeddings e
            JOIN embedding_metadata em ON em.id = e.id
            ORDER BY e.embedding_id
            """
        )

        documents = {}
        metadata_by_id = {}
        for row in cursor:
            embedding_id = row["embedding_id"]
            key = row["key"]
            metadata = metadata_by_id.setdefault(embedding_id, {})
            if key is None:
                continue
            if key == "chroma:document":
                text = row["string_value"]
                # Mirrors MAX(...) should an ID ever carry several documents.
                if text is not None and (
                    embedding_id not in documents or text > documents[embedding_id]
                ):
                    documents[embedding_id] = text
                continue
            # SQLite's LIKE 'chroma:%' is case-insensitive for ASCII.
            if key.lower().startswith("chroma:"):
                continue
            if row["string_value"] is not None:
                metadata[key] = row["string_value"]
            elif row["int_value"] is not None:
                metadata[key] = row["int_value"]
            elif row["float_value"] is not None:
                metadata[key] = row["float_value"]
            elif row["bool_value"] is not None:
                metadata[key] = bool(row["bool_value"])
    finally:
        conn.close()

    return [
        {"id": embedding_id, "document": documents[embedding_id], "metadata": metadata}
        for embedding_id, metadata in metadata_by_id.items()
        if documents.get(embedding_id)
    ]
def detect_chromadb_version(db_path: str) -> str:
    """Guess which ChromaDB line wrote the database by inspecting its schema.

    Returns "1.x" when the ``collections`` table has a ``schema_str`` column,
    "0.6.x" when it does not but an ``embeddings_queue`` table exists, and
    "unknown" otherwise.
    """
    conn = sqlite3.connect(db_path)
    try:
        # PRAGMA table_info rows carry the column name at index 1.
        column_names = [info[1] for info in conn.execute("PRAGMA table_info(collections)")]
        if "schema_str" in column_names:
            return "1.x"

        table_names = [
            name
            for (name,) in conn.execute("SELECT name FROM sqlite_master WHERE type='table'")
        ]
        return "0.6.x" if "embeddings_queue" in table_names else "unknown"
    finally:
        conn.close()
def migrate(palace_path: str, dry_run: bool = False):
    """Migrate a palace to the currently installed ChromaDB version.

    Steps: detect the source schema; return early if the installed chromadb can
    already read the palace; otherwise extract every drawer via raw SQLite,
    back up the palace directory, rebuild it in a temporary directory, verify
    the drawer count, and only then swap the rebuilt palace into place.

    Args:
        palace_path: Palace directory (``~`` is expanded).
        dry_run: When True, print the summary and stop before any change.

    Returns:
        True when there was nothing to do, on a dry run, or after a successful
        migration. False when no database file exists or when the rebuilt
        palace fails verification (the original palace is then left untouched).

    Raises:
        Any exception from the extraction or re-import is propagated after the
        temporary directory has been removed.
    """
    import chromadb

    palace_path = os.path.expanduser(palace_path)
    db_path = os.path.join(palace_path, "chroma.sqlite3")

    if not os.path.isfile(db_path):
        print(f"\n No palace database found at {db_path}")
        return False

    print(f"\n{'=' * 60}")
    print(" MemPalace Migrate")
    print(f"{'=' * 60}\n")
    print(f" Palace: {palace_path}")
    print(f" Database: {db_path}")
    print(f" DB size: {os.path.getsize(db_path) / 1024 / 1024:.1f} MB")

    # Detect version
    source_version = detect_chromadb_version(db_path)
    print(f" Source: ChromaDB {source_version}")
    print(f" Target: ChromaDB {chromadb.__version__}")

    # Try reading with current chromadb first
    try:
        client = chromadb.PersistentClient(path=palace_path)
        col = client.get_collection("mempalace_drawers")
        count = col.count()
        print(f"\n Palace is already readable by chromadb {chromadb.__version__}.")
        print(f" {count} drawers found. No migration needed.")
        return True
    except Exception:
        # Any failure here is treated as "not readable by this version".
        print(f"\n Palace is NOT readable by chromadb {chromadb.__version__}.")
        print(" Extracting from SQLite directly...")

    # Extract all drawers via raw SQL
    drawers = extract_drawers_from_sqlite(db_path)
    print(f" Extracted {len(drawers)} drawers from SQLite")

    if not drawers:
        print(" Nothing to migrate.")
        return True

    # Show summary
    wings = defaultdict(lambda: defaultdict(int))
    for d in drawers:
        w = d["metadata"].get("wing", "?")
        r = d["metadata"].get("room", "?")
        wings[w][r] += 1

    print("\n Summary:")
    for wing, rooms in sorted(wings.items()):
        total = sum(rooms.values())
        print(f" WING: {wing} ({total} drawers)")
        for room, count in sorted(rooms.items(), key=lambda x: -x[1]):
            print(f" ROOM: {room:30} {count:5}")

    if dry_run:
        print("\n DRY RUN — no changes made.")
        print(f" Would migrate {len(drawers)} drawers.")
        return True

    # Backup the old palace
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    backup_path = f"{palace_path}.pre-migrate.{timestamp}"
    print(f"\n Backing up to {backup_path}...")
    shutil.copytree(palace_path, backup_path)

    # Build fresh palace in a temp directory (avoids chromadb reading old state)
    import tempfile

    temp_palace = tempfile.mkdtemp(prefix="mempalace_migrate_")
    print(f" Creating fresh palace in {temp_palace}...")
    try:
        client = chromadb.PersistentClient(path=temp_palace)
        col = client.get_or_create_collection("mempalace_drawers")

        # Re-import in batches
        batch_size = 500
        imported = 0
        for i in range(0, len(drawers), batch_size):
            batch = drawers[i : i + batch_size]
            col.add(
                ids=[d["id"] for d in batch],
                documents=[d["document"] for d in batch],
                metadatas=[d["metadata"] for d in batch],
            )
            imported += len(batch)
            print(f" Imported {imported}/{len(drawers)} drawers...")

        final_count = col.count()
        # Drop the references so the client is not held across the swap.
        del col
        del client
    except Exception:
        # Do not leave a half-built palace behind in the temp directory.
        shutil.rmtree(temp_palace, ignore_errors=True)
        raise

    # Verify before swapping: never delete the old palace for an incomplete copy.
    if final_count != len(drawers):
        print(f" WARNING: Expected {len(drawers)}, got {final_count}")
        print(" Migration aborted; the original palace was left untouched.")
        print(f" Backup at: {backup_path}")
        shutil.rmtree(temp_palace, ignore_errors=True)
        return False

    # Swap: remove old palace, move new one into place
    print(" Swapping old palace for migrated version...")
    shutil.rmtree(palace_path)
    shutil.move(temp_palace, palace_path)

    print("\n Migration complete.")
    print(f" Drawers migrated: {final_count}")
    print(f" Backup at: {backup_path}")
    print(f"\n{'=' * 60}\n")
    return True
+10
View File
@@ -436,6 +436,16 @@ def process_file(
print(f" [DRY RUN] {filepath.name} → room:{room} ({len(chunks)} drawers)")
return len(chunks), room
# Purge stale drawers for this file before re-inserting the fresh chunks.
# Converts modified-file re-mines from upsert-over-existing-IDs (which hits
# hnswlib's thread-unsafe updatePoint path and can segfault on macOS ARM
# with chromadb 0.6.3) into a clean delete+insert, bypassing the update
# path entirely.
try:
collection.delete(where={"source_file": source_file})
except Exception:
pass
drawers_added = 0
for chunk in chunks:
added = add_drawer(
+157
View File
@@ -0,0 +1,157 @@
"""
query_sanitizer.py — Mitigate system prompt contamination in search queries.
Problem: AI agents sometimes prepend system prompts (2000+ chars) to search queries.
Embedding models represent the concatenated string as a single vector where the
system prompt overwhelms the actual question (typically 10-50 chars), causing
near-total retrieval failure (89.8% → 1.0% R@10). See Issue #333.
Approach: "Mitigation" (減災) — not perfect prevention, but prevents the cliff.
Expected recovery:
Step 1 passthrough (≤200 chars) → no degradation, ~89.8%
Step 2 question extraction (found) → near-full recovery, ~85-89%
Step 3 tail sentence extraction → moderate recovery, ~80-89%
Step 4 tail truncation (fallback) → minimum viable, ~70-80%
Without sanitizer: 1.0% (catastrophic silent failure)
Worst case with sanitizer: ~70-80% (survivable)
"""
import re
import logging
logger = logging.getLogger("mempalace_mcp")

# --- Constants ---
MAX_QUERY_LENGTH = 500  # Above this, system prompt almost certainly dominates
SAFE_QUERY_LENGTH = 200  # Below this, query is almost certainly clean
MIN_QUERY_LENGTH = 10  # Extracted result shorter than this = extraction failed

# Sentence splitter: split on . ! ? and newlines, plus the CJK full stop (U+3002)
# and the fullwidth ! (U+FF01) and ? (U+FF1F). The non-ASCII characters are written
# as escapes so they cannot be dropped or normalised away by editors and tooling.
_SENTENCE_SPLIT = re.compile(r"[.!?\u3002\uff01\uff1f\n]+")

# Question detector: text ends with ? or fullwidth ? (U+FF1F), optionally followed
# by one closing quote and trailing whitespace.
_QUESTION_MARK = re.compile(r'[?\uff1f]\s*["\']?\s*$')


def _result(clean_query, original_length, method, was_sanitized):
    """Build the dict returned by sanitize_query()."""
    return {
        "clean_query": clean_query,
        "was_sanitized": was_sanitized,
        "original_length": original_length,
        "clean_length": len(clean_query),
        "method": method,
    }


def _sanitized(candidate, original_length, method):
    """Clamp candidate to its last MAX_QUERY_LENGTH chars, log it, build the result."""
    if len(candidate) > MAX_QUERY_LENGTH:
        # Keep the tail: the real query is expected at the end of the input.
        candidate = candidate[-MAX_QUERY_LENGTH:]
    logger.warning(
        "Query sanitized: %d -> %d chars (method=%s)",
        original_length,
        len(candidate),
        method,
    )
    return _result(candidate, original_length, method, True)


def sanitize_query(raw_query: str) -> dict:
    """
    Extract the actual search intent from a potentially contaminated query.

    Args:
        raw_query: The raw query string from the AI agent, possibly containing
            system prompt content prepended to the actual question. None and
            empty/whitespace-only strings are accepted and passed through.

    Returns:
        dict with keys:
            clean_query (str): The sanitized query to use for embedding search
            was_sanitized (bool): Whether any sanitization was applied
            original_length (int): Length of the (stripped) raw input
            clean_length (int): Length of the sanitized output
            method (str): Which extraction method was used
                - "passthrough": query was short enough, no action taken
                - "question_extraction": the last line ending in a question mark
                - "tail_sentence": the last line of at least MIN_QUERY_LENGTH chars
                - "tail_truncation": fallback — took the last MAX_QUERY_LENGTH chars
    """
    if not raw_query or not raw_query.strip():
        # None / empty / whitespace-only: hand back as-is (None becomes "").
        return _result(raw_query or "", len(raw_query) if raw_query else 0, "passthrough", False)

    raw_query = raw_query.strip()
    original_length = len(raw_query)

    # --- Step 1: Short query passthrough ---
    if original_length <= SAFE_QUERY_LENGTH:
        return _result(raw_query, original_length, "passthrough", False)

    # --- Step 2: Question extraction ---
    sentences = [s.strip() for s in _SENTENCE_SPLIT.split(raw_query) if s.strip()]
    # Non-empty lines, in order; a question on its own line is the common case.
    all_segments = [line.strip() for line in raw_query.split("\n") if line.strip()]

    # Later lines first: the real query is expected after the prepended prompt.
    question_sentences = [seg for seg in reversed(all_segments) if _QUESTION_MARK.search(seg)]
    if not question_sentences:
        # Secondary check on the sentence-split pieces. This must test for an actual
        # question mark (testing for "" matches every sentence). _SENTENCE_SPLIT
        # consumes the marks it splits on, so this is a defensive check only.
        question_sentences = [
            sent for sent in reversed(sentences) if "?" in sent or "\uff1f" in sent
        ]

    if question_sentences:
        # Index 0 is the last question in the input.
        candidate = question_sentences[0].strip()
        if len(candidate) >= MIN_QUERY_LENGTH:
            return _sanitized(candidate, original_length, "question_extraction")
        # Too short to be a usable query: fall through to Step 3.

    # --- Step 3: Tail sentence extraction ---
    # System prompts are prepended, so the actual query is near the end.
    for seg in reversed(all_segments):
        if len(seg) >= MIN_QUERY_LENGTH:
            return _sanitized(seg, original_length, "tail_sentence")

    # --- Step 4: Tail truncation (fallback) ---
    # Nothing worked — just take the last MAX_QUERY_LENGTH characters.
    return _sanitized(raw_query[-MAX_QUERY_LENGTH:].strip(), original_length, "tail_truncation")
+299
View File
@@ -0,0 +1,299 @@
"""
repair.py — Scan, prune corrupt entries, and rebuild HNSW index
================================================================
When ChromaDB's HNSW index accumulates duplicate entries (from repeated
add() calls with the same ID), link_lists.bin can grow unbounded —
terabytes on large palaces — eventually causing segfaults.
This module provides three operations:
scan — find every corrupt/unfetchable ID in the palace
prune — delete only the corrupt IDs (surgical)
rebuild — extract all drawers, delete the collection, recreate with
correct HNSW settings, and upsert everything back
The rebuild backs up ONLY chroma.sqlite3 (the source of truth), not the
full palace directory — so it works even when link_lists.bin is bloated.
Usage (standalone):
python -m mempalace.repair scan [--wing X]
python -m mempalace.repair prune --confirm
python -m mempalace.repair rebuild
Usage (from CLI):
mempalace repair
mempalace repair-scan [--wing X]
mempalace repair-prune --confirm
"""
import argparse
import os
import shutil
import time
import chromadb
# Name of the ChromaDB collection that every repair operation in this module opens.
COLLECTION_NAME = "mempalace_drawers"
def _get_palace_path():
    """Return the palace directory from MempalaceConfig, else ~/.mempalace/palace.

    Any failure while importing or reading the config selects the fallback path.
    """
    try:
        from .config import MempalaceConfig

        resolved = MempalaceConfig().palace_path
    except Exception:
        home_dir = os.path.expanduser("~")
        resolved = os.path.join(home_dir, ".mempalace", "palace")
    return resolved
def _paginate_ids(col, where=None):
    """Pull all IDs in a collection using pagination.

    Args:
        col: Collection-like object exposing get(where=, include=, limit=, offset=).
        where: Optional metadata filter forwarded unchanged to col.get().

    Returns:
        List of IDs in the order the collection returned them.

    If the offset-based get() raises, an offset-less get() is tried instead and
    only IDs not collected so far are kept; paging stops as soon as that yields
    nothing new or raises as well.
    """
    ids = []
    page = 1000
    offset = 0
    while True:
        try:
            r = col.get(where=where, include=[], limit=page, offset=offset)
        except Exception:
            try:
                r = col.get(where=where, include=[], limit=page)
            except Exception:
                break
            # Build the lookup set once per page rather than once per element.
            seen = set(ids)
            new_ids = [i for i in (r["ids"] or []) if i not in seen]
            if not new_ids:
                break
            ids.extend(new_ids)
            offset += len(new_ids)
            continue
        n = len(r["ids"]) if r["ids"] else 0
        if n == 0:
            break
        ids.extend(r["ids"])
        offset += n
        if n < page:
            # Short page: nothing left to fetch.
            break
    return ids
def scan_palace(palace_path=None, only_wing=None):
    """Scan the palace for corrupt/unfetchable IDs.

    Lists every ID (optionally restricted to one wing), probes them in batches
    of 100, and falls back to per-ID probing when a whole batch raises. The bad
    IDs are written, sorted and one per line, to corrupt_ids.txt in the palace
    directory for the prune step.

    Args:
        palace_path: Palace directory; resolved via _get_palace_path() when falsy.
        only_wing: If given, only IDs whose "wing" metadata equals it are scanned.

    Returns:
        (good_set, bad_set). Both are empty, and corrupt_ids.txt is not written,
        when no IDs were listed.
    """
    palace_path = palace_path or _get_palace_path()
    print(f"\n Palace: {palace_path}")
    print(" Loading...")

    client = chromadb.PersistentClient(path=palace_path)
    col = client.get_collection(COLLECTION_NAME)

    where = {"wing": only_wing} if only_wing else None
    total = col.count()
    print(f" Collection: {COLLECTION_NAME}, total: {total:,}")
    if only_wing:
        print(f" Scanning wing: {only_wing}")

    print("\n Step 1: listing all IDs...")
    t0 = time.time()
    all_ids = _paginate_ids(col, where=where)
    print(f" Found {len(all_ids):,} IDs in {time.time() - t0:.1f}s\n")

    if not all_ids:
        print(" Nothing to scan.")
        return set(), set()

    print(" Step 2: probing each ID (batches of 100)...")
    t0 = time.time()
    good_set = set()
    bad_set = set()
    batch = 100

    for i in range(0, len(all_ids), batch):
        chunk = all_ids[i : i + batch]
        try:
            r = col.get(ids=chunk, include=["documents"])
            for got in r["ids"]:
                good_set.add(got)
            # Anything requested but not returned counts as bad.
            for mid in chunk:
                if mid not in good_set:
                    bad_set.add(mid)
        except Exception:
            # The batch failed as a whole: probe one ID at a time to isolate
            # the entries that actually fail.
            for sid in chunk:
                try:
                    r = col.get(ids=[sid], include=["documents"])
                    if r["ids"]:
                        good_set.add(sid)
                    else:
                        bad_set.add(sid)
                except Exception:
                    bad_set.add(sid)

        if (i // batch) % 50 == 0:
            # Clamp to the real total: the last batch is usually partial, and an
            # unclamped count would overshoot and turn the ETA negative.
            done = min(i + batch, len(all_ids))
            elapsed = time.time() - t0
            rate = done / max(elapsed, 0.01)
            eta = (len(all_ids) - done) / max(rate, 0.01)
            print(
                f" {done:>6}/{len(all_ids):>6} "
                f"good={len(good_set):>6} bad={len(bad_set):>6} "
                f"eta={eta:.0f}s"
            )

    print(f"\n Scan complete in {time.time() - t0:.1f}s")
    print(f" GOOD: {len(good_set):,}")
    print(f" BAD: {len(bad_set):,} ({len(bad_set) / max(len(all_ids), 1) * 100:.1f}%)")

    bad_file = os.path.join(palace_path, "corrupt_ids.txt")
    with open(bad_file, "w") as f:
        for bid in sorted(bad_set):
            f.write(bid + "\n")
    print(f"\n Bad IDs written to: {bad_file}")
    return good_set, bad_set
def prune_corrupt(palace_path=None, confirm=False):
    """Delete corrupt IDs listed in corrupt_ids.txt.

    Args:
        palace_path: Palace directory; resolved via _get_palace_path() when falsy.
        confirm: When False (default) only report how many IDs are queued; no
            client is opened and nothing is deleted.

    Returns:
        None. Progress and a final summary are printed.

    IDs are deleted in batches of 100. If a batch delete raises, each ID of that
    batch is retried on its own and individual failures are counted, not raised.
    """
    palace_path = palace_path or _get_palace_path()
    bad_file = os.path.join(palace_path, "corrupt_ids.txt")

    if not os.path.exists(bad_file):
        print(" No corrupt_ids.txt found — run scan first.")
        return

    with open(bad_file) as f:
        bad_ids = [line.strip() for line in f if line.strip()]
    print(f" {len(bad_ids):,} corrupt IDs queued for deletion")

    if not confirm:
        print("\n DRY RUN — no deletions performed.")
        print(" Re-run with --confirm to actually delete.")
        return

    client = chromadb.PersistentClient(path=palace_path)
    col = client.get_collection(COLLECTION_NAME)
    before = col.count()
    print(f" Collection size before: {before:,}")

    batch = 100
    deleted = 0
    failed = 0
    for i in range(0, len(bad_ids), batch):
        chunk = bad_ids[i : i + batch]
        try:
            col.delete(ids=chunk)
            deleted += len(chunk)
        except Exception:
            # Retry one at a time so a single failing ID cannot block the rest.
            for sid in chunk:
                try:
                    col.delete(ids=[sid])
                    deleted += 1
                except Exception:
                    failed += 1
        if (i // batch) % 20 == 0:
            print(f" deleted {deleted}/{len(bad_ids)} (failed: {failed})")

    after = col.count()
    print(f"\n Deleted: {deleted:,}")
    print(f" Failed: {failed:,}")
    # Separate the two counts; printed back to back they read as one number.
    print(f" Collection size: {before:,} -> {after:,}")
def rebuild_index(palace_path=None):
    """Rebuild the HNSW index from scratch.

    1. Extract all drawers via ChromaDB get()
    2. Back up ONLY chroma.sqlite3 (not the bloated HNSW files)
    3. Delete and recreate the collection with hnsw:space=cosine
    4. Upsert all drawers back

    Args:
        palace_path: Palace directory; resolved via _get_palace_path() when falsy.

    Returns:
        None. Returns early, changing nothing, when the directory is missing,
        the collection cannot be read, it reports zero drawers, or no drawer
        could be extracted.
    """
    palace_path = palace_path or _get_palace_path()

    if not os.path.isdir(palace_path):
        print(f"\n No palace found at {palace_path}")
        return

    print(f"\n{'=' * 55}")
    print(" MemPalace Repair — Index Rebuild")
    print(f"{'=' * 55}\n")
    print(f" Palace: {palace_path}")

    client = chromadb.PersistentClient(path=palace_path)
    try:
        col = client.get_collection(COLLECTION_NAME)
        total = col.count()
    except Exception as e:
        print(f" Error reading palace: {e}")
        print(" Palace may need to be re-mined from source files.")
        return

    print(f" Drawers found: {total}")
    if total == 0:
        print(" Nothing to repair.")
        return

    # Extract all drawers in batches
    print("\n Extracting drawers...")
    batch_size = 5000
    all_ids = []
    all_docs = []
    all_metas = []
    offset = 0
    while offset < total:
        batch = col.get(limit=batch_size, offset=offset, include=["documents", "metadatas"])
        if not batch["ids"]:
            break
        all_ids.extend(batch["ids"])
        all_docs.extend(batch["documents"])
        all_metas.extend(batch["metadatas"])
        offset += len(batch["ids"])
    print(f" Extracted {len(all_ids)} drawers")

    if not all_ids:
        # The next step deletes the collection. With nothing extracted that would
        # replace a non-empty palace with an empty one, so stop here instead.
        print(" No drawers could be extracted — leaving the existing collection untouched.")
        return
    if len(all_ids) != total:
        print(f" WARNING: Expected {total}, extracted {len(all_ids)}")

    # Back up ONLY the SQLite database, not the bloated HNSW files
    sqlite_path = os.path.join(palace_path, "chroma.sqlite3")
    if os.path.exists(sqlite_path):
        backup_path = sqlite_path + ".backup"
        print(f" Backing up chroma.sqlite3 ({os.path.getsize(sqlite_path) / 1e6:.0f} MB)...")
        shutil.copy2(sqlite_path, backup_path)
        print(f" Backup: {backup_path}")

    # Rebuild with correct HNSW settings
    print(" Rebuilding collection with hnsw:space=cosine...")
    client.delete_collection(COLLECTION_NAME)
    new_col = client.create_collection(COLLECTION_NAME, metadata={"hnsw:space": "cosine"})

    filed = 0
    for i in range(0, len(all_ids), batch_size):
        batch_ids = all_ids[i : i + batch_size]
        batch_docs = all_docs[i : i + batch_size]
        batch_metas = all_metas[i : i + batch_size]
        new_col.upsert(documents=batch_docs, ids=batch_ids, metadatas=batch_metas)
        filed += len(batch_ids)
        print(f" Re-filed {filed}/{len(all_ids)} drawers...")

    print(f"\n Repair complete. {filed} drawers rebuilt.")
    print(" HNSW index is now clean with cosine distance metric.")
    print(f"\n{'=' * 55}\n")
if __name__ == "__main__":
    # Standalone entry point: python -m mempalace.repair {scan,prune,rebuild} [options]
    parser = argparse.ArgumentParser(description="MemPalace repair tools")
    parser.add_argument("command", choices=["scan", "prune", "rebuild"])
    parser.add_argument("--palace", default=None, help="Palace directory path")
    parser.add_argument("--wing", default=None, help="Scan only this wing")
    parser.add_argument("--confirm", action="store_true", help="Actually delete corrupt IDs")
    cli = parser.parse_args()

    # None lets each operation resolve the palace path from config itself.
    palace_dir = None
    if cli.palace:
        palace_dir = os.path.expanduser(cli.palace)

    if cli.command == "rebuild":
        rebuild_index(palace_path=palace_dir)
    elif cli.command == "prune":
        prune_corrupt(palace_path=palace_dir, confirm=cli.confirm)
    elif cli.command == "scan":
        scan_palace(palace_path=palace_dir, only_wing=cli.wing)
+5 -1
View File
@@ -54,11 +54,15 @@ packages = ["mempalace"]
[tool.ruff]
line-length = 100
target-version = "py39"
extend-exclude = ["benchmarks"]
[tool.ruff.lint]
select = ["E", "F", "W"]
select = ["E", "F", "W", "C901"]
ignore = ["E501"]
[tool.ruff.lint.mccabe]
max-complexity = 25
[tool.ruff.format]
quote-style = "double"
+272
View File
@@ -0,0 +1,272 @@
"""Tests for mempalace.dedup — near-duplicate drawer detection and removal."""
from unittest.mock import MagicMock, patch
from mempalace import dedup
# ── get_source_groups ─────────────────────────────────────────────────
def test_get_source_groups_basic():
    """Five drawers from one source file form a single group at min_count=5."""
    first_page = {
        "ids": ["d1", "d2", "d3", "d4", "d5"],
        "metadatas": [{"source_file": "a.txt"} for _ in range(5)],
    }
    col = MagicMock()
    col.count.return_value = 5
    col.get.side_effect = [first_page, {"ids": []}]

    groups = dedup.get_source_groups(col, min_count=5)

    assert "a.txt" in groups
    assert len(groups["a.txt"]) == 5
def test_get_source_groups_below_min():
    """A source with fewer drawers than min_count yields no group."""
    first_page = {
        "ids": ["d1", "d2"],
        "metadatas": [{"source_file": "a.txt"} for _ in range(2)],
    }
    col = MagicMock()
    col.count.return_value = 2
    col.get.side_effect = [first_page, {"ids": []}]

    groups = dedup.get_source_groups(col, min_count=5)

    assert len(groups) == 0
def test_get_source_groups_source_filter():
    """source_pattern keeps the matching source and leaves out the other one."""
    metadatas = [{"source_file": "project_a.txt"} for _ in range(5)]
    metadatas.append({"source_file": "other.txt"})
    first_page = {
        "ids": ["d1", "d2", "d3", "d4", "d5", "d6"],
        "metadatas": metadatas,
    }
    col = MagicMock()
    col.count.return_value = 6
    col.get.side_effect = [first_page, {"ids": []}]

    groups = dedup.get_source_groups(col, min_count=5, source_pattern="project_a")

    assert "project_a.txt" in groups
    assert "other.txt" not in groups
def test_get_source_groups_wing_filter():
    """The wing argument is forwarded to col.get() as a where filter."""
    first_page = {
        "ids": ["d1", "d2", "d3", "d4", "d5"],
        "metadatas": [{"source_file": "a.txt"} for _ in range(5)],
    }
    col = MagicMock()
    col.count.return_value = 5
    col.get.side_effect = [first_page, {"ids": []}]

    dedup.get_source_groups(col, min_count=5, wing="my_wing")

    # The very first get() must carry the wing filter.
    initial_call = col.get.call_args_list[0]
    assert initial_call.kwargs.get("where") == {"wing": "my_wing"}
def test_get_source_groups_missing_source_file():
    """Drawers without source_file metadata are grouped under "unknown"."""
    first_page = {
        "ids": ["d1", "d2", "d3", "d4", "d5"],
        "metadatas": [{} for _ in range(5)],
    }
    col = MagicMock()
    col.count.return_value = 5
    col.get.side_effect = [first_page, {"ids": []}]

    groups = dedup.get_source_groups(col, min_count=5)

    assert "unknown" in groups
# ── dedup_source_group ────────────────────────────────────────────────
def test_dedup_source_group_all_unique():
    """Drawers whose nearest neighbour is far away are all kept."""
    fetched = {
        "ids": ["d1", "d2"],
        "documents": ["long document one content here", "different document two here"],
        "metadatas": [{"wing": "a"}, {"wing": "a"}],
    }
    # Distance 0.8 is well above the 0.15 threshold, i.e. unique.
    nearest = {"ids": [["d1"]], "distances": [[0.8]]}
    col = MagicMock()
    col.get.return_value = fetched
    col.query.return_value = nearest

    kept, deleted = dedup.dedup_source_group(col, ["d1", "d2"], threshold=0.15, dry_run=True)

    assert len(kept) == 2
    assert len(deleted) == 0
def test_dedup_source_group_with_duplicate():
    """A drawer within the distance threshold of a kept one is marked deleted."""
    text = "long document content that is fairly long"
    fetched = {
        "ids": ["d1", "d2"],
        "documents": [text, text],
        "metadatas": [{"wing": "a"}, {"wing": "a"}],
    }
    # Distance 0.05 is below the 0.15 threshold, i.e. a duplicate.
    nearest = {"ids": [["d1"]], "distances": [[0.05]]}
    col = MagicMock()
    col.get.return_value = fetched
    col.query.return_value = nearest

    kept, deleted = dedup.dedup_source_group(col, ["d1", "d2"], threshold=0.15, dry_run=True)

    assert len(kept) == 1
    assert len(deleted) == 1
def test_dedup_source_group_short_docs_deleted():
    """A very short document is marked for deletion."""
    fetched = {
        "ids": ["d1", "d2"],
        "documents": ["long enough document to keep in the palace", "tiny"],
        "metadatas": [{"wing": "a"}, {"wing": "a"}],
    }
    col = MagicMock()
    col.get.return_value = fetched

    kept, deleted = dedup.dedup_source_group(col, ["d1", "d2"], threshold=0.15, dry_run=True)

    assert "d2" in deleted  # too short
def test_dedup_source_group_empty_doc_deleted():
    """A drawer whose document is None is marked for deletion."""
    fetched = {
        "ids": ["d1", "d2"],
        "documents": ["real document content here that is long enough", None],
        "metadatas": [{"wing": "a"}, {"wing": "a"}],
    }
    col = MagicMock()
    col.get.return_value = fetched

    kept, deleted = dedup.dedup_source_group(col, ["d1", "d2"], threshold=0.15, dry_run=True)

    assert "d2" in deleted
def test_dedup_source_group_live_deletes():
    """With dry_run=False the duplicate is removed through col.delete()."""
    text = "long document content here enough"
    fetched = {
        "ids": ["d1", "d2"],
        "documents": [text, text],
        "metadatas": [{"wing": "a"}, {"wing": "a"}],
    }
    nearest = {"ids": [["d1"]], "distances": [[0.05]]}
    col = MagicMock()
    col.get.return_value = fetched
    col.query.return_value = nearest

    dedup.dedup_source_group(col, ["d1", "d2"], threshold=0.15, dry_run=False)

    col.delete.assert_called_once()
def test_dedup_source_group_query_failure_keeps():
    """When the similarity query raises, both drawers are kept."""
    fetched = {
        "ids": ["d1", "d2"],
        "documents": [
            "long document one content here enough",
            "long document two content here enough",
        ],
        "metadatas": [{"wing": "a"}, {"wing": "a"}],
    }
    col = MagicMock()
    col.get.return_value = fetched
    col.query.side_effect = Exception("query failed")

    kept, deleted = dedup.dedup_source_group(col, ["d1", "d2"], threshold=0.15, dry_run=True)

    assert len(kept) == 2  # both kept on error
# ── show_stats ────────────────────────────────────────────────────────
@patch("mempalace.dedup.chromadb")
def test_show_stats(mock_chromadb, tmp_path):
    """show_stats() runs to completion against a mocked five-drawer collection."""
    first_page = {
        "ids": ["d1", "d2", "d3", "d4", "d5"],
        "metadatas": [{"source_file": "a.txt"} for _ in range(5)],
    }
    mock_col = MagicMock()
    mock_col.count.return_value = 5
    mock_col.get.side_effect = [first_page, {"ids": []}]
    mock_chromadb.PersistentClient.return_value.get_collection.return_value = mock_col

    dedup.show_stats(palace_path=str(tmp_path))  # should not raise
# ── dedup_palace ──────────────────────────────────────────────────────
@patch("mempalace.dedup.dedup_source_group")
@patch("mempalace.dedup.get_source_groups")
@patch("mempalace.dedup.chromadb")
def test_dedup_palace_dry_run(mock_chromadb, mock_groups, mock_dedup_group, tmp_path):
    """One source group leads to exactly one dedup_source_group() call."""
    mock_col = MagicMock()
    mock_col.count.return_value = 10
    mock_chromadb.PersistentClient.return_value.get_collection.return_value = mock_col
    mock_groups.return_value = {"a.txt": ["d1", "d2", "d3", "d4", "d5"]}
    mock_dedup_group.return_value = (["d1", "d2", "d3"], ["d4", "d5"])

    dedup.dedup_palace(palace_path=str(tmp_path), dry_run=True)

    mock_dedup_group.assert_called_once()
@patch("mempalace.dedup.dedup_source_group")
@patch("mempalace.dedup.get_source_groups")
@patch("mempalace.dedup.chromadb")
def test_dedup_palace_with_wing(mock_chromadb, mock_groups, mock_dedup_group, tmp_path):
    """The wing argument is passed through to get_source_groups()."""
    mock_col = MagicMock()
    mock_col.count.return_value = 10
    mock_chromadb.PersistentClient.return_value.get_collection.return_value = mock_col
    mock_groups.return_value = {}

    dedup.dedup_palace(palace_path=str(tmp_path), wing="test_wing", dry_run=True)

    mock_groups.assert_called_once_with(mock_col, 5, None, wing="test_wing")
@patch("mempalace.dedup.dedup_source_group")
@patch("mempalace.dedup.get_source_groups")
@patch("mempalace.dedup.chromadb")
def test_dedup_palace_no_groups(mock_chromadb, mock_groups, mock_dedup_group, tmp_path):
    """With no source groups, dedup_source_group() is never called."""
    mock_col = MagicMock()
    mock_col.count.return_value = 3
    mock_chromadb.PersistentClient.return_value.get_collection.return_value = mock_col
    mock_groups.return_value = {}

    dedup.dedup_palace(palace_path=str(tmp_path), dry_run=True)

    mock_dedup_group.assert_not_called()
+212
View File
@@ -0,0 +1,212 @@
"""
Tests for query_sanitizer.py — system prompt contamination mitigation (#333).
Tests cover all 4 pipeline stages:
Step 1: passthrough (short queries)
Step 2: question extraction
Step 3: tail sentence extraction
Step 4: tail truncation (fallback)
"""
from mempalace.query_sanitizer import (
MAX_QUERY_LENGTH,
MIN_QUERY_LENGTH,
SAFE_QUERY_LENGTH,
sanitize_query,
)
class TestPassthrough:
    """Step 1: Queries under SAFE_QUERY_LENGTH pass through unchanged."""

    def test_short_query_unchanged(self):
        text = "What is Rust error handling?"
        outcome = sanitize_query(text)
        assert outcome["clean_query"] == "What is Rust error handling?"
        assert outcome["was_sanitized"] is False
        assert outcome["method"] == "passthrough"

    def test_empty_query(self):
        outcome = sanitize_query("")
        assert outcome["clean_query"] == ""
        assert outcome["was_sanitized"] is False
        assert outcome["method"] == "passthrough"

    def test_none_query(self):
        outcome = sanitize_query(None)
        assert outcome["was_sanitized"] is False
        assert outcome["method"] == "passthrough"

    def test_exactly_safe_length(self):
        outcome = sanitize_query("a" * SAFE_QUERY_LENGTH)
        assert outcome["was_sanitized"] is False
        assert outcome["method"] == "passthrough"

    def test_one_over_safe_length_triggers_sanitization(self):
        too_long = SAFE_QUERY_LENGTH + 1
        outcome = sanitize_query("a" * too_long)
        # The pipeline runs; it may or may not change the query itself.
        assert outcome["original_length"] == SAFE_QUERY_LENGTH + 1
class TestQuestionExtraction:
    """Step 2: Extract question sentences (ending with ? or fullwidth ?)."""

    def test_question_at_end_of_long_text(self):
        system_prompt = "You are a helpful assistant. " * 50  # ~1400 chars
        query = system_prompt + "What is the best way to handle errors in Rust?"
        result = sanitize_query(query)
        assert result["was_sanitized"] is True
        assert "error" in result["clean_query"].lower() or "Rust" in result["clean_query"]
        assert result["method"] == "question_extraction"

    def test_japanese_question_mark(self):
        system_prompt = "You are a helpful assistant. " * 50
        # The query must end with the FULLWIDTH question mark (U+FF1F); with an
        # ASCII "?" this test would not exercise the fullwidth path at all. The
        # character is written as an escape so it cannot be normalised away.
        query = system_prompt + "Rustのエラーハンドリング方法は\uff1f"
        result = sanitize_query(query)
        assert result["was_sanitized"] is True
        assert "Rust" in result["clean_query"] or "エラー" in result["clean_query"]
        assert result["method"] == "question_extraction"

    def test_multiple_questions_takes_last(self):
        system_prompt = "You are a helpful assistant. " * 50
        query = system_prompt + "What is Python?\nHow does Rust handle errors?"
        result = sanitize_query(query)
        assert "Rust" in result["clean_query"] or "error" in result["clean_query"].lower()

    def test_question_in_system_prompt_ignored_when_real_question_exists(self):
        # System prompt contains a question, but real query also has one
        system_prompt = "Are you ready to help? " * 30 + "\n"
        real_query = "What databases does MemPalace support?"
        query = system_prompt + real_query
        result = sanitize_query(query)
        assert result["was_sanitized"] is True
        assert "MemPalace" in result["clean_query"] or "database" in result["clean_query"].lower()
class TestTailSentence:
    """Step 3: Extract the last meaningful sentence when no question mark found."""

    def test_command_style_query(self):
        prefix = "You are a helpful assistant. " * 50
        outcome = sanitize_query(prefix + "Show me all Rust error handling patterns")
        cleaned = outcome["clean_query"]
        assert outcome["was_sanitized"] is True
        assert "Rust" in cleaned or "error" in cleaned.lower()
        assert outcome["method"] in ("tail_sentence", "question_extraction")

    def test_keyword_style_query(self):
        prefix = "System configuration loaded. " * 60
        outcome = sanitize_query(prefix + "\nMemPalace ChromaDB integration setup")
        cleaned = outcome["clean_query"]
        assert outcome["was_sanitized"] is True
        assert "MemPalace" in cleaned or "ChromaDB" in cleaned
class TestTailTruncation:
    """Step 4: Fallback — take the last MAX_QUERY_LENGTH characters."""

    def test_single_long_line_no_sentences(self):
        # 200 two-character lines: no line reaches MIN_QUERY_LENGTH, so only the
        # tail-truncation fallback can produce a result.
        short_lines = "\n".join("ab" for _ in range(200))
        outcome = sanitize_query(short_lines)
        assert outcome["was_sanitized"] is True
        assert len(outcome["clean_query"]) <= MAX_QUERY_LENGTH
        assert outcome["method"] == "tail_truncation"

    def test_truncation_preserves_tail(self):
        marker = "IMPORTANT_QUERY_CONTENT"
        outcome = sanitize_query("x" * 1000 + marker)
        assert "IMPORTANT_QUERY_CONTENT" in outcome["clean_query"]
class TestLengthGuards:
    """Verify output length constraints."""

    def test_output_never_exceeds_max(self):
        # A single question sentence that is itself far longer than the cap.
        prefix = "Context. " * 100
        oversized_question = "a" * 1000 + "?"
        outcome = sanitize_query(prefix + oversized_question)
        assert len(outcome["clean_query"]) <= MAX_QUERY_LENGTH

    def test_extraction_too_short_falls_through(self):
        # "OK?" ends with a question mark but is only 3 chars (< MIN_QUERY_LENGTH),
        # so question extraction must give way to a later step.
        prefix = "You are helpful. " * 50
        outcome = sanitize_query(prefix + "\nOK?")
        assert outcome["was_sanitized"] is True
class TestMetadata:
    """Verify sanitizer metadata is correct."""

    def test_original_length_preserved(self):
        contaminated = "You are a helpful assistant. " * 50 + "What is Rust?"
        outcome = sanitize_query(contaminated)
        assert outcome["original_length"] == len(contaminated.strip())

    def test_clean_length_matches_clean_query(self):
        contaminated = "You are a helpful assistant. " * 50 + "What is Rust?"
        outcome = sanitize_query(contaminated)
        assert outcome["clean_length"] == len(outcome["clean_query"])

    def test_sanitized_flag_true_when_changed(self):
        contaminated = "You are a helpful assistant. " * 50 + "What is Rust?"
        outcome = sanitize_query(contaminated)
        assert outcome["was_sanitized"] is True

    def test_sanitized_flag_false_when_unchanged(self):
        outcome = sanitize_query("Short query")
        assert outcome["was_sanitized"] is False
class TestRealWorldScenarios:
    """Simulate realistic system prompt contamination patterns."""

    def test_mempalace_wakeup_prepended(self):
        """Simulates mempalace wake-up output prepended to a query."""
        wakeup_once = (
            "MemPalace loaded. Wings: technical, emotions, identity. "
            "Rooms: chromadb-setup, error-handling, project-planning. "
            "Total drawers: 234. Knowledge graph: 89 entities, 156 triples. "
            "AAAK dialect active. Protocol: verify before responding. "
        )
        contaminated = wakeup_once * 5 + "How did we decide on the database architecture?"
        outcome = sanitize_query(contaminated)
        cleaned_len = len(outcome["clean_query"])
        assert outcome["was_sanitized"] is True
        assert cleaned_len <= MAX_QUERY_LENGTH
        # Something meaningful must survive.
        assert cleaned_len >= MIN_QUERY_LENGTH

    def test_memory_md_prepended(self):
        """Simulates MEMORY.md content prepended to a query."""
        memory_once = (
            "# Project Memory\n"
            "## Architecture Decisions\n"
            "- Use ChromaDB for vector storage\n"
            "- MCP protocol for tool integration\n"
            "- AAAK compression for efficient storage\n"
        )
        question = "What were the performance benchmarks for the search system?"
        contaminated = memory_once * 10 + "\n" + question
        outcome = sanitize_query(contaminated)
        assert outcome["was_sanitized"] is True
        assert outcome["method"] in ("question_extraction", "tail_sentence")

    def test_2000_char_system_prompt_with_question(self):
        """The exact scenario from Issue #333 — 2000 chars prepended."""
        prefix = "You are an AI assistant with access to tools. " * 45  # ~2000 chars
        contaminated = prefix + "What is the status of the MemPalace project?"
        outcome = sanitize_query(contaminated)
        assert outcome["was_sanitized"] is True
        assert outcome["original_length"] > 2000
        assert outcome["clean_length"] <= MAX_QUERY_LENGTH
        assert outcome["method"] == "question_extraction"
+266
View File
@@ -0,0 +1,266 @@
"""Tests for mempalace.repair — scan, prune, and rebuild HNSW index."""
import os
from unittest.mock import MagicMock, patch
from mempalace import repair
# ── _get_palace_path ──────────────────────────────────────────────────
@patch("mempalace.config.MempalaceConfig")
def test_get_palace_path_from_config(mock_config_cls):
    """_get_palace_path() returns the palace_path of the configured MempalaceConfig.

    The class is patched in mempalace.config because _get_palace_path() imports it
    from there at call time; patching a name on mempalace.repair has no effect.
    """
    mock_config_cls.return_value.palace_path = "/configured/palace"
    result = repair._get_palace_path()
    assert result == "/configured/palace"
def test_get_palace_path_fallback():
    """When the config cannot be built, the default ~/.mempalace/palace is returned."""
    expected = os.path.join(os.path.expanduser("~"), ".mempalace", "palace")
    # Make the config constructor fail so the real fallback branch runs; mocking
    # _get_palace_path itself would only test the mock.
    with patch("mempalace.config.MempalaceConfig", side_effect=RuntimeError("no config")):
        result = repair._get_palace_path()
    assert result == expected
    assert ".mempalace" in result
# ── _paginate_ids ─────────────────────────────────────────────────────
def test_paginate_ids_single_batch():
    """A single short page is returned as-is."""
    col = MagicMock()
    col.get.return_value = {"ids": ["id1", "id2", "id3"]}
    collected = repair._paginate_ids(col)
    assert collected == ["id1", "id2", "id3"]
def test_paginate_ids_empty():
    """An empty collection yields an empty ID list."""
    col = MagicMock()
    col.get.return_value = {"ids": []}
    collected = repair._paginate_ids(col)
    assert collected == []
def test_paginate_ids_with_where():
    """The where filter is forwarded to col.get() along with the paging arguments."""
    wing_filter = {"wing": "test"}
    col = MagicMock()
    col.get.return_value = {"ids": ["id1"]}
    repair._paginate_ids(col, where=wing_filter)
    col.get.assert_called_with(where={"wing": "test"}, include=[], limit=1000, offset=0)
def test_paginate_ids_offset_exception_fallback():
    """When the offset-based get() raises, the offset-less fallback supplies the IDs."""
    col = MagicMock()
    # Sequence: offset call raises, fallback returns IDs, offset call raises again,
    # fallback returns the same IDs (nothing new, so paging stops).
    col.get.side_effect = [
        Exception("offset bug"),
        {"ids": ["id1", "id2"]},
        Exception("offset bug"),
        {"ids": ["id1", "id2"]},
    ]
    collected = repair._paginate_ids(col)
    assert "id1" in collected
# ── scan_palace ───────────────────────────────────────────────────────
@patch("mempalace.repair.chromadb")
def test_scan_palace_no_ids(mock_chromadb, tmp_path):
    """An empty collection produces two empty sets."""
    mock_col = MagicMock()
    mock_col.count.return_value = 0
    mock_col.get.return_value = {"ids": []}
    mock_chromadb.PersistentClient.return_value.get_collection.return_value = mock_col

    outcome = repair.scan_palace(palace_path=str(tmp_path))

    good, bad = outcome
    assert good == set()
    assert bad == set()
@patch("mempalace.repair.chromadb")
def test_scan_palace_all_good(mock_chromadb, tmp_path):
    """When the probe returns every listed ID, nothing is reported bad."""
    listing = {"ids": ["id1", "id2"]}  # consumed by _paginate_ids
    probe = {"ids": ["id1", "id2"]}  # batch probe returns both
    mock_col = MagicMock()
    mock_col.count.return_value = 2
    mock_col.get.side_effect = [listing, probe]
    mock_chromadb.PersistentClient.return_value.get_collection.return_value = mock_col

    good, bad = repair.scan_palace(palace_path=str(tmp_path))

    assert "id1" in good
    assert "id2" in good
    assert len(bad) == 0
@patch("mempalace.repair.chromadb")
def test_scan_palace_with_bad_ids(mock_chromadb, tmp_path):
    """A failing batch probe falls back to per-ID probing, which isolates the bad ID."""

    def fake_get(**kwargs):
        requested = kwargs.get("ids", None)
        if requested is None:
            # Listing call issued by _paginate_ids.
            return {"ids": ["good1", "bad1"]}
        if len(requested) == 1:
            if "bad1" in requested:
                raise Exception("corrupt")
            if "good1" in requested:
                return {"ids": ["good1"]}
        # Any multi-ID probe fails, forcing the per-ID path.
        raise Exception("batch fail")

    mock_col = MagicMock()
    mock_col.count.return_value = 2
    mock_col.get.side_effect = fake_get
    mock_chromadb.PersistentClient.return_value.get_collection.return_value = mock_col

    good, bad = repair.scan_palace(palace_path=str(tmp_path))

    assert "good1" in good
    assert "bad1" in bad
@patch("mempalace.repair.chromadb")
def test_scan_palace_with_wing_filter(mock_chromadb, tmp_path):
    """only_wing is turned into a where filter on the listing call."""
    mock_col = MagicMock()
    mock_col.count.return_value = 1
    mock_col.get.side_effect = [
        {"ids": ["id1"]},  # listing
        {"ids": ["id1"]},  # probe
    ]
    mock_chromadb.PersistentClient.return_value.get_collection.return_value = mock_col

    repair.scan_palace(palace_path=str(tmp_path), only_wing="test_wing")

    # The first get() is the listing call and must carry the wing filter.
    listing_call = mock_col.get.call_args_list[0]
    assert listing_call.kwargs.get("where") == {"wing": "test_wing"}
# ── prune_corrupt ─────────────────────────────────────────────────────
@patch("mempalace.repair.chromadb")
def test_prune_corrupt_no_file(mock_chromadb, tmp_path):
    """prune_corrupt on an empty palace directory completes without raising."""
    empty_palace = str(tmp_path)
    # Should print message and return without error
    repair.prune_corrupt(palace_path=empty_palace)
@patch("mempalace.repair.chromadb")
def test_prune_corrupt_dry_run(mock_chromadb, tmp_path):
    """With ``confirm=False`` no chromadb client is ever constructed."""
    (tmp_path / "corrupt_ids.txt").write_text("bad1\nbad2\n")

    repair.prune_corrupt(palace_path=str(tmp_path), confirm=False)

    # No chromadb calls in dry run
    mock_chromadb.PersistentClient.assert_not_called()
@patch("mempalace.repair.chromadb")
def test_prune_corrupt_confirmed(mock_chromadb, tmp_path):
    """With ``confirm=True`` the listed ids are removed through a single delete call."""
    (tmp_path / "corrupt_ids.txt").write_text("bad1\nbad2\n")

    collection = MagicMock()
    # count() is stubbed with two successive values: 10, then 8.
    collection.count.side_effect = [10, 8]
    client = MagicMock(**{"get_collection.return_value": collection})
    mock_chromadb.PersistentClient.return_value = client

    repair.prune_corrupt(palace_path=str(tmp_path), confirm=True)

    collection.delete.assert_called_once()
@patch("mempalace.repair.chromadb")
def test_prune_corrupt_delete_failure_fallback(mock_chromadb, tmp_path):
    """A failing batch delete is followed by one delete call per id."""
    (tmp_path / "corrupt_ids.txt").write_text("bad1\nbad2\n")

    collection = MagicMock()
    collection.count.side_effect = [10, 8]
    # Batch delete fails, per-id succeeds
    collection.delete.side_effect = [Exception("batch fail"), None, None]
    client = MagicMock(**{"get_collection.return_value": collection})
    mock_chromadb.PersistentClient.return_value = client

    repair.prune_corrupt(palace_path=str(tmp_path), confirm=True)

    assert collection.delete.call_count == 3  # 1 batch + 2 individual
# ── rebuild_index ─────────────────────────────────────────────────────
@patch("mempalace.repair.chromadb")
def test_rebuild_index_no_palace(mock_chromadb, tmp_path):
    """rebuild_index never opens a client when the palace path does not exist."""
    missing_dir = tmp_path / "nope"

    repair.rebuild_index(palace_path=str(missing_dir))

    mock_chromadb.PersistentClient.assert_not_called()
@patch("mempalace.repair.shutil")
@patch("mempalace.repair.chromadb")
def test_rebuild_index_empty_palace(mock_chromadb, mock_shutil, tmp_path):
    """A collection reporting zero items is left in place (no delete_collection)."""
    empty_collection = MagicMock()
    empty_collection.count.return_value = 0
    client = MagicMock(**{"get_collection.return_value": empty_collection})
    mock_chromadb.PersistentClient.return_value = client

    repair.rebuild_index(palace_path=str(tmp_path))

    client.delete_collection.assert_not_called()
@patch("mempalace.repair.shutil")
@patch("mempalace.repair.chromadb")
def test_rebuild_index_success(mock_chromadb, mock_shutil, tmp_path):
    """Happy path: sqlite backup, drop + recreate with cosine space, then upsert."""
    # Create a fake sqlite file
    (tmp_path / "chroma.sqlite3").write_text("fake")

    old_collection = MagicMock()
    old_collection.count.return_value = 2
    old_collection.get.return_value = {
        "ids": ["id1", "id2"],
        "documents": ["doc1", "doc2"],
        "metadatas": [{"wing": "a"}, {"wing": "b"}],
    }
    rebuilt_collection = MagicMock()
    client = MagicMock(
        **{
            "get_collection.return_value": old_collection,
            "create_collection.return_value": rebuilt_collection,
        }
    )
    mock_chromadb.PersistentClient.return_value = client

    repair.rebuild_index(palace_path=str(tmp_path))

    # Verify: backed up sqlite only (not copytree)
    mock_shutil.copy2.assert_called_once()
    assert "chroma.sqlite3" in str(mock_shutil.copy2.call_args)

    # Verify: deleted and recreated with cosine
    client.delete_collection.assert_called_once_with("mempalace_drawers")
    client.create_collection.assert_called_once_with(
        "mempalace_drawers", metadata={"hnsw:space": "cosine"}
    )

    # Verify: used upsert not add
    rebuilt_collection.upsert.assert_called_once()
    rebuilt_collection.add.assert_not_called()
@patch("mempalace.repair.shutil")
@patch("mempalace.repair.chromadb")
def test_rebuild_index_error_reading(mock_chromadb, mock_shutil, tmp_path):
    """If get_collection raises, rebuild_index must not delete the collection."""
    client = MagicMock(**{"get_collection.side_effect": Exception("corrupt")})
    mock_chromadb.PersistentClient.return_value = client

    repair.rebuild_index(palace_path=str(tmp_path))

    client.delete_collection.assert_not_called()