Merge branch 'main' into fix/chromadb-version-constraint
This commit is contained in:
@@ -0,0 +1,13 @@
|
||||
# Default owners for everything
|
||||
* @milla-jovovich @bensig @igorls
|
||||
|
||||
# Core library
|
||||
mempalace/ @milla-jovovich @bensig
|
||||
|
||||
# CI and workflows
|
||||
.github/ @bensig
|
||||
|
||||
# Plugins and integrations
|
||||
.claude-plugin/ @bensig
|
||||
.codex-plugin/ @bensig
|
||||
integrations/ @bensig
|
||||
@@ -0,0 +1,12 @@
|
||||
version: 2
|
||||
updates:
|
||||
- package-ecosystem: "pip"
|
||||
directory: "/"
|
||||
schedule:
|
||||
interval: "weekly"
|
||||
open-pull-requests-limit: 5
|
||||
- package-ecosystem: "github-actions"
|
||||
directory: "/"
|
||||
schedule:
|
||||
interval: "weekly"
|
||||
open-pull-requests-limit: 3
|
||||
@@ -18,7 +18,7 @@ jobs:
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
- run: pip install -e ".[dev]"
|
||||
- run: python -m pytest tests/ -v --ignore=tests/benchmarks --cov=mempalace --cov-report=term-missing --cov-fail-under=80
|
||||
- run: python -m pytest tests/ -v --ignore=tests/benchmarks --cov=mempalace --cov-report=term-missing --cov-fail-under=80 --durations=10
|
||||
|
||||
test-windows:
|
||||
runs-on: windows-latest
|
||||
@@ -28,7 +28,7 @@ jobs:
|
||||
with:
|
||||
python-version: "3.9"
|
||||
- run: pip install -e ".[dev]"
|
||||
- run: python -m pytest tests/ -v --ignore=tests/benchmarks --cov=mempalace --cov-report=term-missing --cov-fail-under=80
|
||||
- run: python -m pytest tests/ -v --ignore=tests/benchmarks --cov=mempalace --cov-report=term-missing --cov-fail-under=80 --durations=10
|
||||
|
||||
test-macos:
|
||||
runs-on: macos-latest
|
||||
@@ -38,7 +38,7 @@ jobs:
|
||||
with:
|
||||
python-version: "3.9"
|
||||
- run: pip install -e ".[dev]"
|
||||
- run: python -m pytest tests/ -v --ignore=tests/benchmarks --cov=mempalace --cov-report=term-missing --cov-fail-under=80
|
||||
- run: python -m pytest tests/ -v --ignore=tests/benchmarks --cov=mempalace --cov-report=term-missing --cov-fail-under=80 --durations=10
|
||||
lint:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
|
||||
+27
@@ -6,3 +6,30 @@ __pycache__/
|
||||
.pytest_cache/
|
||||
mempal.yaml
|
||||
.a5c/
|
||||
|
||||
# Environment
|
||||
.env
|
||||
.env.*
|
||||
|
||||
# OS
|
||||
.DS_Store
|
||||
Thumbs.db
|
||||
|
||||
# IDEs
|
||||
.idea/
|
||||
.vscode/
|
||||
*.swp
|
||||
*.swo
|
||||
*~
|
||||
|
||||
# Coverage
|
||||
htmlcov/
|
||||
.coverage
|
||||
coverage.xml
|
||||
|
||||
# Virtual environments
|
||||
.venv/
|
||||
venv/
|
||||
|
||||
# ChromaDB local data
|
||||
*.sqlite3-journal
|
||||
|
||||
@@ -0,0 +1,78 @@
|
||||
# AGENTS.md
|
||||
|
||||
> How to build, test, and contribute to MemPalace.
|
||||
|
||||
## Setup
|
||||
|
||||
```bash
|
||||
pip install -e ".[dev]"
|
||||
```
|
||||
|
||||
## Commands
|
||||
|
||||
```bash
|
||||
# Run tests
|
||||
python -m pytest tests/ -v --ignore=tests/benchmarks
|
||||
|
||||
# Run tests with coverage
|
||||
python -m pytest tests/ -v --ignore=tests/benchmarks --cov=mempalace --cov-report=term-missing
|
||||
|
||||
# Lint
|
||||
ruff check .
|
||||
|
||||
# Format
|
||||
ruff format .
|
||||
|
||||
# Format check (CI mode)
|
||||
ruff format --check .
|
||||
```
|
||||
|
||||
## Project structure
|
||||
|
||||
```
|
||||
mempalace/
|
||||
├── mcp_server.py # MCP server — all read/write tools
|
||||
├── miner.py # Project file miner
|
||||
├── convo_miner.py # Conversation transcript miner
|
||||
├── searcher.py # Semantic search
|
||||
├── knowledge_graph.py # Temporal entity-relationship graph (SQLite)
|
||||
├── palace.py # Shared palace operations (ChromaDB access)
|
||||
├── config.py # Configuration + input validation
|
||||
├── normalize.py # Transcript format detection + normalization
|
||||
├── cli.py # CLI dispatcher
|
||||
├── dialect.py # AAAK compression dialect
|
||||
├── palace_graph.py # Room traversal + cross-wing tunnels
|
||||
├── hooks_cli.py # Hook system for auto-save
|
||||
└── version.py # Single source of truth for version
|
||||
```
|
||||
|
||||
## Conventions
|
||||
|
||||
- **Python style**: snake_case for functions/variables, PascalCase for classes
|
||||
- **Linter**: ruff with E/F/W rules
|
||||
- **Formatter**: ruff format, double quotes
|
||||
- **Commits**: conventional commits (`fix:`, `feat:`, `test:`, `docs:`, `ci:`)
|
||||
- **Tests**: `tests/test_*.py`, fixtures in `tests/conftest.py`
|
||||
- **Coverage**: 85% threshold (80% on Windows due to ChromaDB file lock cleanup)
|
||||
|
||||
## Architecture
|
||||
|
||||
```
|
||||
User → CLI / MCP Server → ChromaDB (vector store) + SQLite (knowledge graph)
|
||||
|
||||
Palace structure:
|
||||
WING (person/project)
|
||||
└── ROOM (topic)
|
||||
└── DRAWER (verbatim text chunk)
|
||||
|
||||
Knowledge Graph:
|
||||
ENTITY → PREDICATE → ENTITY (with valid_from / valid_to dates)
|
||||
```
|
||||
|
||||
## Key files for common tasks
|
||||
|
||||
- **Adding an MCP tool**: `mempalace/mcp_server.py` — add handler function + TOOLS dict entry
|
||||
- **Changing search**: `mempalace/searcher.py`
|
||||
- **Modifying mining**: `mempalace/miner.py` (project files) or `mempalace/convo_miner.py` (transcripts)
|
||||
- **Input validation**: `mempalace/config.py` — `sanitize_name()` / `sanitize_content()`
|
||||
- **Tests**: mirror source structure in `tests/test_<module>.py`
|
||||
+4
-1
@@ -5,8 +5,11 @@ Thanks for wanting to help. MemPalace is open source and we welcome contribution
|
||||
## Getting Started
|
||||
|
||||
```bash
|
||||
git clone https://github.com/milla-jovovich/mempalace.git
|
||||
# Fork the repo on GitHub first, then clone your fork
|
||||
git clone https://github.com/<your-username>/mempalace.git
|
||||
cd mempalace
|
||||
git remote add upstream https://github.com/milla-jovovich/mempalace.git
|
||||
|
||||
pip install -e ".[dev]" # installs with dev dependencies (pytest, build, twine)
|
||||
```
|
||||
|
||||
|
||||
@@ -84,6 +84,18 @@ Other memory systems try to fix this by letting AI decide what's worth rememberi
|
||||
|
||||
---
|
||||
|
||||
## An important follow up note regarding fake MemPalace websites - April 11, 2026
|
||||
|
||||
Several Community Members (#267, #326, #506) have pointed out there are fake MemPalace websites popping up, including ones with Malware.
|
||||
|
||||
To be super clear, MemPalace *has no website* (at least for now), so anything claiming to be one is false.
|
||||
|
||||
Thanks to our Community Members for letting us know about the problem.
|
||||
|
||||
Stay safe out there.
|
||||
|
||||
---
|
||||
|
||||
## Quick Start
|
||||
|
||||
```bash
|
||||
|
||||
@@ -0,0 +1,36 @@
|
||||
-- MemPalace Knowledge Graph Schema
|
||||
-- SQLite database at ~/.mempalace/knowledge_graph.db
|
||||
|
||||
CREATE TABLE IF NOT EXISTS entities (
|
||||
id TEXT PRIMARY KEY,
|
||||
name TEXT NOT NULL,
|
||||
type TEXT DEFAULT 'unknown',
|
||||
properties TEXT DEFAULT '{}'
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS triples (
|
||||
id TEXT PRIMARY KEY,
|
||||
subject TEXT NOT NULL,
|
||||
predicate TEXT NOT NULL,
|
||||
object TEXT NOT NULL,
|
||||
valid_from TEXT,
|
||||
valid_to TEXT,
|
||||
confidence REAL DEFAULT 1.0,
|
||||
source_closet TEXT,
|
||||
source_file TEXT
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS attributes (
|
||||
entity_id TEXT NOT NULL,
|
||||
key TEXT NOT NULL,
|
||||
value TEXT,
|
||||
valid_from TEXT,
|
||||
valid_to TEXT,
|
||||
PRIMARY KEY (entity_id, key, valid_from)
|
||||
);
|
||||
|
||||
-- Indexes
|
||||
CREATE INDEX IF NOT EXISTS idx_triples_subject ON triples(subject);
|
||||
CREATE INDEX IF NOT EXISTS idx_triples_object ON triples(object);
|
||||
CREATE INDEX IF NOT EXISTS idx_triples_predicate ON triples(predicate);
|
||||
CREATE INDEX IF NOT EXISTS idx_triples_valid ON triples(valid_from, valid_to);
|
||||
@@ -72,6 +72,6 @@ fi
|
||||
cat << 'HOOKJSON'
|
||||
{
|
||||
"decision": "block",
|
||||
"reason": "COMPACTION IMMINENT. Save ALL topics, decisions, quotes, code, and important context from this session to your memory system. Be thorough — after compaction, detailed context will be lost. Organize into appropriate categories. Use verbatim quotes where possible. Save everything, then allow compaction to proceed."
|
||||
"reason": "COMPACTION IMMINENT (MemPalace). Save ALL session content before context is lost:\n1. mempalace_diary_write — thorough AAAK-compressed session summary\n2. mempalace_add_drawer — ALL verbatim quotes, decisions, code, context\n3. mempalace_kg_add — entity relationships (optional)\nBe thorough — after compaction, detailed context will be lost. Do NOT write to Claude Code's native auto-memory (.md files). Save everything to MemPalace, then allow compaction to proceed."
|
||||
}
|
||||
HOOKJSON
|
||||
|
||||
@@ -145,7 +145,7 @@ if [ "$SINCE_LAST" -ge "$SAVE_INTERVAL" ] && [ "$EXCHANGE_COUNT" -gt 0 ]; then
|
||||
cat << 'HOOKJSON'
|
||||
{
|
||||
"decision": "block",
|
||||
"reason": "AUTO-SAVE checkpoint. Save key topics, decisions, quotes, and code from this session to your memory system. Organize into appropriate categories. Use verbatim quotes where possible. Continue conversation after saving."
|
||||
"reason": "AUTO-SAVE checkpoint (MemPalace). Save this session's key content:\n1. mempalace_diary_write — AAAK-compressed session summary\n2. mempalace_add_drawer — verbatim quotes, decisions, code snippets\n3. mempalace_kg_add — entity relationships (optional)\nDo NOT write to Claude Code's native auto-memory (.md files). Continue conversation after saving."
|
||||
}
|
||||
HOOKJSON
|
||||
else
|
||||
|
||||
@@ -150,6 +150,14 @@ def cmd_split(args):
|
||||
sys.argv = old_argv
|
||||
|
||||
|
||||
def cmd_migrate(args):
|
||||
"""Migrate palace from a different ChromaDB version."""
|
||||
from .migrate import migrate
|
||||
|
||||
palace_path = os.path.expanduser(args.palace) if args.palace else MempalaceConfig().palace_path
|
||||
migrate(palace_path=palace_path, dry_run=args.dry_run)
|
||||
|
||||
|
||||
def cmd_status(args):
|
||||
from .miner import status
|
||||
|
||||
@@ -531,6 +539,17 @@ def main():
|
||||
)
|
||||
|
||||
# status
|
||||
# migrate
|
||||
p_migrate = sub.add_parser(
|
||||
"migrate",
|
||||
help="Migrate palace from a different ChromaDB version (fixes 3.0.0 → 3.1.0 upgrade)",
|
||||
)
|
||||
p_migrate.add_argument(
|
||||
"--dry-run",
|
||||
action="store_true",
|
||||
help="Show what would be migrated without changing anything",
|
||||
)
|
||||
|
||||
sub.add_parser("status", help="Show what's been filed")
|
||||
|
||||
args = parser.parse_args()
|
||||
@@ -565,6 +584,7 @@ def main():
|
||||
"compress": cmd_compress,
|
||||
"wake-up": cmd_wakeup,
|
||||
"repair": cmd_repair,
|
||||
"migrate": cmd_migrate,
|
||||
"status": cmd_status,
|
||||
}
|
||||
dispatch[args.command](args)
|
||||
|
||||
@@ -334,7 +334,7 @@ def mine_convos(
|
||||
room_counts[chunk_room] += 1
|
||||
drawer_id = f"drawer_{wing}_{chunk_room}_{hashlib.sha256((source_file + str(chunk['chunk_index'])).encode()).hexdigest()[:24]}"
|
||||
try:
|
||||
collection.add(
|
||||
collection.upsert(
|
||||
documents=[chunk["content"]],
|
||||
ids=[drawer_id],
|
||||
metadatas=[
|
||||
|
||||
@@ -0,0 +1,239 @@
|
||||
"""
|
||||
dedup.py — Detect and remove near-duplicate drawers
|
||||
====================================================
|
||||
|
||||
When the same files are mined multiple times, near-identical drawers
|
||||
accumulate. This module finds drawers from the same source_file that
|
||||
are too similar (cosine distance < threshold), keeps the longest/richest
|
||||
version, and deletes the rest.
|
||||
|
||||
No API calls — uses ChromaDB's built-in embedding similarity.
|
||||
|
||||
Usage (standalone):
|
||||
python -m mempalace.dedup # dedup all
|
||||
python -m mempalace.dedup --dry-run # preview only
|
||||
python -m mempalace.dedup --threshold 0.10 # stricter (near-identical only)
|
||||
python -m mempalace.dedup --threshold 0.35 # looser (catches paraphrased content)
|
||||
python -m mempalace.dedup --wing my_project # scope to one wing
|
||||
python -m mempalace.dedup --stats # stats only
|
||||
python -m mempalace.dedup --source "my_project" # filter by source
|
||||
|
||||
Usage (from CLI):
|
||||
mempalace dedup [--dry-run] [--threshold 0.15] [--stats]
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import time
|
||||
from collections import defaultdict
|
||||
|
||||
import chromadb
|
||||
|
||||
|
||||
COLLECTION_NAME = "mempalace_drawers"
|
||||
# Cosine DISTANCE threshold (not similarity). Lower = stricter.
|
||||
# 0.15 = ~85% cosine similarity — catches near-identical chunks.
|
||||
# For looser dedup of paraphrased content, try 0.3–0.4.
|
||||
DEFAULT_THRESHOLD = 0.15
|
||||
MIN_DRAWERS_TO_CHECK = 5
|
||||
|
||||
|
||||
def _get_palace_path():
|
||||
"""Resolve palace path from config."""
|
||||
try:
|
||||
from .config import MempalaceConfig
|
||||
|
||||
return MempalaceConfig().palace_path
|
||||
except Exception:
|
||||
return os.path.join(os.path.expanduser("~"), ".mempalace", "palace")
|
||||
|
||||
|
||||
def get_source_groups(col, min_count=MIN_DRAWERS_TO_CHECK, source_pattern=None, wing=None):
|
||||
"""Group drawers by source_file, return groups with min_count+ entries.
|
||||
|
||||
If wing is specified, only considers drawers in that wing. This catches
|
||||
cross-wing duplicates when the same source was mined into multiple wings.
|
||||
"""
|
||||
total = col.count()
|
||||
groups = defaultdict(list)
|
||||
|
||||
offset = 0
|
||||
batch_size = 1000
|
||||
while offset < total:
|
||||
kwargs = {"limit": batch_size, "offset": offset, "include": ["metadatas"]}
|
||||
if wing:
|
||||
kwargs["where"] = {"wing": wing}
|
||||
batch = col.get(**kwargs)
|
||||
if not batch["ids"]:
|
||||
break
|
||||
for did, meta in zip(batch["ids"], batch["metadatas"]):
|
||||
src = meta.get("source_file", "unknown")
|
||||
if source_pattern and source_pattern.lower() not in src.lower():
|
||||
continue
|
||||
groups[src].append(did)
|
||||
offset += len(batch["ids"])
|
||||
|
||||
return {src: ids for src, ids in groups.items() if len(ids) >= min_count}
|
||||
|
||||
|
||||
def dedup_source_group(col, drawer_ids, threshold=DEFAULT_THRESHOLD, dry_run=True):
|
||||
"""Dedup drawers within one source_file group.
|
||||
|
||||
Greedy: sort by doc length (longest first), keep if not too similar
|
||||
to any already-kept drawer. Returns (kept_ids, deleted_ids).
|
||||
"""
|
||||
data = col.get(ids=drawer_ids, include=["documents", "metadatas"])
|
||||
items = list(zip(data["ids"], data["documents"], data["metadatas"]))
|
||||
items.sort(key=lambda x: len(x[1] or ""), reverse=True)
|
||||
|
||||
kept = []
|
||||
to_delete = []
|
||||
|
||||
for did, doc, meta in items:
|
||||
if not doc or len(doc) < 20:
|
||||
to_delete.append(did)
|
||||
continue
|
||||
|
||||
if not kept:
|
||||
kept.append((did, doc))
|
||||
continue
|
||||
|
||||
try:
|
||||
results = col.query(
|
||||
query_texts=[doc],
|
||||
n_results=min(len(kept), 5),
|
||||
include=["distances"],
|
||||
)
|
||||
dists = results["distances"][0] if results["distances"] else []
|
||||
kept_ids_set = {k[0] for k in kept}
|
||||
|
||||
is_dup = False
|
||||
for rid, dist in zip(results["ids"][0], dists):
|
||||
if rid in kept_ids_set and dist < threshold:
|
||||
is_dup = True
|
||||
break
|
||||
|
||||
if is_dup:
|
||||
to_delete.append(did)
|
||||
else:
|
||||
kept.append((did, doc))
|
||||
except Exception:
|
||||
kept.append((did, doc))
|
||||
|
||||
if to_delete and not dry_run:
|
||||
for i in range(0, len(to_delete), 500):
|
||||
col.delete(ids=to_delete[i : i + 500])
|
||||
|
||||
return [k[0] for k in kept], to_delete
|
||||
|
||||
|
||||
def show_stats(palace_path=None):
|
||||
"""Show duplication statistics without making changes."""
|
||||
palace_path = palace_path or _get_palace_path()
|
||||
client = chromadb.PersistentClient(path=palace_path)
|
||||
col = client.get_collection(COLLECTION_NAME)
|
||||
|
||||
groups = get_source_groups(col)
|
||||
|
||||
total_drawers = sum(len(ids) for ids in groups.values())
|
||||
print(f"\n Sources with {MIN_DRAWERS_TO_CHECK}+ drawers: {len(groups)}")
|
||||
print(f" Total drawers in those sources: {total_drawers:,}")
|
||||
|
||||
print("\n Top 15 by drawer count:")
|
||||
sorted_groups = sorted(groups.items(), key=lambda x: len(x[1]), reverse=True)
|
||||
for src, ids in sorted_groups[:15]:
|
||||
print(f" {len(ids):4d} {src[:65]}")
|
||||
|
||||
estimated_dups = sum(int(len(ids) * 0.4) for ids in groups.values() if len(ids) > 20)
|
||||
print(f"\n Estimated duplicates (groups > 20): ~{estimated_dups:,}")
|
||||
|
||||
|
||||
def dedup_palace(
|
||||
palace_path=None,
|
||||
threshold=DEFAULT_THRESHOLD,
|
||||
dry_run=True,
|
||||
source_pattern=None,
|
||||
min_count=MIN_DRAWERS_TO_CHECK,
|
||||
wing=None,
|
||||
):
|
||||
"""Main entry point: deduplicate near-identical drawers across the palace."""
|
||||
palace_path = palace_path or _get_palace_path()
|
||||
|
||||
print(f"\n{'=' * 55}")
|
||||
print(" MemPalace Deduplicator")
|
||||
print(f"{'=' * 55}")
|
||||
|
||||
client = chromadb.PersistentClient(path=palace_path)
|
||||
col = client.get_collection(COLLECTION_NAME)
|
||||
|
||||
print(f" Palace: {palace_path}")
|
||||
print(f" Drawers: {col.count():,}")
|
||||
print(f" Threshold: {threshold}")
|
||||
print(f" Mode: {'DRY RUN' if dry_run else 'LIVE'}")
|
||||
print(f"{'─' * 55}")
|
||||
|
||||
if wing:
|
||||
print(f" Wing: {wing}")
|
||||
groups = get_source_groups(col, min_count, source_pattern, wing=wing)
|
||||
print(f"\n Sources to check: {len(groups)}")
|
||||
|
||||
t0 = time.time()
|
||||
total_kept = 0
|
||||
total_deleted = 0
|
||||
|
||||
sorted_groups = sorted(groups.items(), key=lambda x: len(x[1]), reverse=True)
|
||||
|
||||
for i, (src, drawer_ids) in enumerate(sorted_groups):
|
||||
kept, deleted = dedup_source_group(col, drawer_ids, threshold, dry_run)
|
||||
total_kept += len(kept)
|
||||
total_deleted += len(deleted)
|
||||
|
||||
if deleted:
|
||||
print(
|
||||
f" [{i + 1:3d}/{len(groups)}] "
|
||||
f"{src[:50]:50s} {len(drawer_ids):4d} → {len(kept):4d} "
|
||||
f"(-{len(deleted)})"
|
||||
)
|
||||
|
||||
elapsed = time.time() - t0
|
||||
|
||||
print(f"\n{'─' * 55}")
|
||||
print(f" Done in {elapsed:.1f}s")
|
||||
print(
|
||||
f" Drawers: {total_kept + total_deleted:,} → {total_kept:,} (-{total_deleted:,} removed)"
|
||||
)
|
||||
print(f" Palace after: {col.count():,} drawers")
|
||||
|
||||
if dry_run:
|
||||
print("\n [DRY RUN] No changes written. Re-run without --dry-run to apply.")
|
||||
|
||||
print(f"{'=' * 55}\n")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Deduplicate near-identical drawers")
|
||||
parser.add_argument("--palace", default=None, help="Palace directory path")
|
||||
parser.add_argument(
|
||||
"--threshold",
|
||||
type=float,
|
||||
default=DEFAULT_THRESHOLD,
|
||||
help=f"Cosine distance threshold (default: {DEFAULT_THRESHOLD})",
|
||||
)
|
||||
parser.add_argument("--dry-run", action="store_true", help="Preview without deleting")
|
||||
parser.add_argument("--stats", action="store_true", help="Show stats only")
|
||||
parser.add_argument("--wing", default=None, help="Scope dedup to a single wing")
|
||||
parser.add_argument("--source", default=None, help="Filter by source file pattern")
|
||||
args = parser.parse_args()
|
||||
|
||||
path = os.path.expanduser(args.palace) if args.palace else None
|
||||
|
||||
if args.stats:
|
||||
show_stats(palace_path=path)
|
||||
else:
|
||||
dedup_palace(
|
||||
palace_path=path,
|
||||
threshold=args.threshold,
|
||||
dry_run=args.dry_run,
|
||||
source_pattern=args.source,
|
||||
wing=args.wing,
|
||||
)
|
||||
+21
-9
@@ -18,18 +18,22 @@ SAVE_INTERVAL = 15
|
||||
STATE_DIR = Path.home() / ".mempalace" / "hook_state"
|
||||
|
||||
STOP_BLOCK_REASON = (
|
||||
"AUTO-SAVE checkpoint. Save key topics, decisions, quotes, and code "
|
||||
"from this session to your memory system. Organize into appropriate "
|
||||
"categories. Use verbatim quotes where possible. Continue conversation "
|
||||
"after saving."
|
||||
"AUTO-SAVE checkpoint (MemPalace). Save this session's key content:\n"
|
||||
"1. mempalace_diary_write — AAAK-compressed session summary\n"
|
||||
"2. mempalace_add_drawer — verbatim quotes, decisions, code snippets\n"
|
||||
"3. mempalace_kg_add — entity relationships (optional)\n"
|
||||
"Do NOT write to Claude Code's native auto-memory (.md files). "
|
||||
"Continue conversation after saving."
|
||||
)
|
||||
|
||||
PRECOMPACT_BLOCK_REASON = (
|
||||
"COMPACTION IMMINENT. Save ALL topics, decisions, quotes, code, and "
|
||||
"important context from this session to your memory system. Be thorough "
|
||||
"\u2014 after compaction, detailed context will be lost. Organize into "
|
||||
"appropriate categories. Use verbatim quotes where possible. Save "
|
||||
"everything, then allow compaction to proceed."
|
||||
"COMPACTION IMMINENT (MemPalace). Save ALL session content before context is lost:\n"
|
||||
"1. mempalace_diary_write — thorough AAAK-compressed session summary\n"
|
||||
"2. mempalace_add_drawer — ALL verbatim quotes, decisions, code, context\n"
|
||||
"3. mempalace_kg_add — entity relationships (optional)\n"
|
||||
"Be thorough \u2014 after compaction, detailed context will be lost. "
|
||||
"Do NOT write to Claude Code's native auto-memory (.md files). "
|
||||
"Save everything to MemPalace, then allow compaction to proceed."
|
||||
)
|
||||
|
||||
|
||||
@@ -63,6 +67,14 @@ def _count_human_messages(transcript_path: str) -> int:
|
||||
if "<command-message>" in text:
|
||||
continue
|
||||
count += 1
|
||||
# Also handle Codex CLI transcript format
|
||||
# {"type": "event_msg", "payload": {"type": "user_message", "message": "..."}}
|
||||
elif entry.get("type") == "event_msg":
|
||||
payload = entry.get("payload", {})
|
||||
if isinstance(payload, dict) and payload.get("type") == "user_message":
|
||||
msg_text = payload.get("message", "")
|
||||
if isinstance(msg_text, str) and "<command-message>" not in msg_text:
|
||||
count += 1
|
||||
except (json.JSONDecodeError, AttributeError):
|
||||
pass
|
||||
except OSError:
|
||||
|
||||
+124
-39
@@ -28,6 +28,7 @@ from pathlib import Path
|
||||
|
||||
from .config import MempalaceConfig, sanitize_name, sanitize_content
|
||||
from .version import __version__
|
||||
from .query_sanitizer import sanitize_query
|
||||
from .searcher import search_memories
|
||||
from .palace_graph import traverse, find_tunnels, graph_stats
|
||||
import chromadb
|
||||
@@ -143,16 +144,25 @@ def tool_status():
|
||||
count = col.count()
|
||||
wings = {}
|
||||
rooms = {}
|
||||
try:
|
||||
all_meta = col.get(include=["metadatas"], limit=10000)["metadatas"]
|
||||
for m in all_meta:
|
||||
w = m.get("wing", "unknown")
|
||||
r = m.get("room", "unknown")
|
||||
wings[w] = wings.get(w, 0) + 1
|
||||
rooms[r] = rooms.get(r, 0) + 1
|
||||
except Exception:
|
||||
pass
|
||||
return {
|
||||
batch_size = 5000
|
||||
offset = 0
|
||||
error_info = None
|
||||
while True:
|
||||
try:
|
||||
batch = col.get(include=["metadatas"], limit=batch_size, offset=offset)
|
||||
rows = batch["metadatas"]
|
||||
for m in rows:
|
||||
w = m.get("wing", "unknown")
|
||||
r = m.get("room", "unknown")
|
||||
wings[w] = wings.get(w, 0) + 1
|
||||
rooms[r] = rooms.get(r, 0) + 1
|
||||
offset += len(rows)
|
||||
if len(rows) < batch_size:
|
||||
break
|
||||
except Exception as e:
|
||||
error_info = f"Partial result, failed at offset {offset}: {str(e)}"
|
||||
break
|
||||
result = {
|
||||
"total_drawers": count,
|
||||
"wings": wings,
|
||||
"rooms": rooms,
|
||||
@@ -160,6 +170,10 @@ def tool_status():
|
||||
"protocol": PALACE_PROTOCOL,
|
||||
"aaak_dialect": AAAK_SPEC,
|
||||
}
|
||||
if error_info:
|
||||
result["error"] = error_info
|
||||
result["partial"] = True
|
||||
return result
|
||||
|
||||
|
||||
# ── AAAK Dialect Spec ─────────────────────────────────────────────────────────
|
||||
@@ -200,13 +214,28 @@ def tool_list_wings():
|
||||
if not col:
|
||||
return _no_palace()
|
||||
wings = {}
|
||||
batch_size = 5000
|
||||
offset = 0
|
||||
try:
|
||||
all_meta = col.get(include=["metadatas"], limit=10000)["metadatas"]
|
||||
for m in all_meta:
|
||||
w = m.get("wing", "unknown")
|
||||
wings[w] = wings.get(w, 0) + 1
|
||||
except Exception:
|
||||
pass
|
||||
col.count() # verify collection is accessible
|
||||
except Exception as e:
|
||||
return {"wings": {}, "error": str(e)}
|
||||
while True:
|
||||
try:
|
||||
batch = col.get(include=["metadatas"], limit=batch_size, offset=offset)
|
||||
rows = batch["metadatas"]
|
||||
for m in rows:
|
||||
w = m.get("wing", "unknown")
|
||||
wings[w] = wings.get(w, 0) + 1
|
||||
offset += len(rows)
|
||||
if len(rows) < batch_size:
|
||||
break
|
||||
except Exception as e:
|
||||
return {
|
||||
"wings": wings,
|
||||
"error": f"Partial result, failed at offset {offset}: {str(e)}",
|
||||
"partial": True,
|
||||
}
|
||||
return {"wings": wings}
|
||||
|
||||
|
||||
@@ -215,16 +244,33 @@ def tool_list_rooms(wing: str = None):
|
||||
if not col:
|
||||
return _no_palace()
|
||||
rooms = {}
|
||||
batch_size = 5000
|
||||
offset = 0
|
||||
where = {"wing": wing} if wing else None
|
||||
try:
|
||||
kwargs = {"include": ["metadatas"], "limit": 10000}
|
||||
if wing:
|
||||
kwargs["where"] = {"wing": wing}
|
||||
all_meta = col.get(**kwargs)["metadatas"]
|
||||
for m in all_meta:
|
||||
r = m.get("room", "unknown")
|
||||
rooms[r] = rooms.get(r, 0) + 1
|
||||
except Exception:
|
||||
pass
|
||||
col.count() # verify collection is accessible
|
||||
except Exception as e:
|
||||
return {"wing": wing or "all", "rooms": {}, "error": str(e)}
|
||||
while True:
|
||||
try:
|
||||
kwargs = {"include": ["metadatas"], "limit": batch_size, "offset": offset}
|
||||
if where:
|
||||
kwargs["where"] = where
|
||||
batch = col.get(**kwargs)
|
||||
rows = batch["metadatas"]
|
||||
for m in rows:
|
||||
r = m.get("room", "unknown")
|
||||
rooms[r] = rooms.get(r, 0) + 1
|
||||
offset += len(rows)
|
||||
if len(rows) < batch_size:
|
||||
break
|
||||
except Exception as e:
|
||||
return {
|
||||
"wing": wing or "all",
|
||||
"rooms": rooms,
|
||||
"error": f"Partial result, failed at offset {offset}: {str(e)}",
|
||||
"partial": True,
|
||||
}
|
||||
return {"wing": wing or "all", "rooms": rooms}
|
||||
|
||||
|
||||
@@ -233,27 +279,58 @@ def tool_get_taxonomy():
|
||||
if not col:
|
||||
return _no_palace()
|
||||
taxonomy = {}
|
||||
batch_size = 5000
|
||||
offset = 0
|
||||
try:
|
||||
all_meta = col.get(include=["metadatas"], limit=10000)["metadatas"]
|
||||
for m in all_meta:
|
||||
w = m.get("wing", "unknown")
|
||||
r = m.get("room", "unknown")
|
||||
if w not in taxonomy:
|
||||
taxonomy[w] = {}
|
||||
taxonomy[w][r] = taxonomy[w].get(r, 0) + 1
|
||||
except Exception:
|
||||
pass
|
||||
col.count() # verify collection is accessible
|
||||
except Exception as e:
|
||||
return {"taxonomy": {}, "error": str(e)}
|
||||
while True:
|
||||
try:
|
||||
batch = col.get(include=["metadatas"], limit=batch_size, offset=offset)
|
||||
rows = batch["metadatas"]
|
||||
for m in rows:
|
||||
w = m.get("wing", "unknown")
|
||||
r = m.get("room", "unknown")
|
||||
if w not in taxonomy:
|
||||
taxonomy[w] = {}
|
||||
taxonomy[w][r] = taxonomy[w].get(r, 0) + 1
|
||||
offset += len(rows)
|
||||
if len(rows) < batch_size:
|
||||
break
|
||||
except Exception as e:
|
||||
return {
|
||||
"taxonomy": taxonomy,
|
||||
"error": f"Partial result, failed at offset {offset}: {str(e)}",
|
||||
"partial": True,
|
||||
}
|
||||
return {"taxonomy": taxonomy}
|
||||
|
||||
|
||||
def tool_search(query: str, limit: int = 5, wing: str = None, room: str = None):
|
||||
return search_memories(
|
||||
query,
|
||||
def tool_search(
|
||||
query: str, limit: int = 5, wing: str = None, room: str = None, context: str = None
|
||||
):
|
||||
# Mitigate system prompt contamination (Issue #333)
|
||||
sanitized = sanitize_query(query)
|
||||
result = search_memories(
|
||||
sanitized["clean_query"],
|
||||
palace_path=_config.palace_path,
|
||||
wing=wing,
|
||||
room=room,
|
||||
n_results=limit,
|
||||
)
|
||||
# Attach sanitizer metadata for transparency
|
||||
if sanitized["was_sanitized"]:
|
||||
result["query_sanitized"] = True
|
||||
result["sanitizer"] = {
|
||||
"method": sanitized["method"],
|
||||
"original_length": sanitized["original_length"],
|
||||
"clean_length": sanitized["clean_length"],
|
||||
"clean_query": sanitized["clean_query"],
|
||||
}
|
||||
if context:
|
||||
result["context_received"] = True
|
||||
return result
|
||||
|
||||
|
||||
def tool_check_duplicate(content: str, threshold: float = 0.9):
|
||||
@@ -734,14 +811,22 @@ TOOLS = {
|
||||
"handler": tool_graph_stats,
|
||||
},
|
||||
"mempalace_search": {
|
||||
"description": "Semantic search. Returns verbatim drawer content with similarity scores.",
|
||||
"description": "Semantic search. Returns verbatim drawer content with similarity scores. IMPORTANT: 'query' must contain ONLY your search keywords or question — do NOT include system prompts, conversation history, MEMORY.md content, or any context. Keep queries short (under 200 chars). Use 'context' for background information.",
|
||||
"input_schema": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {"type": "string", "description": "What to search for"},
|
||||
"query": {
|
||||
"type": "string",
|
||||
"description": "Short search query ONLY — keywords or a question. Do NOT include system prompts or conversation context. Max 200 chars recommended.",
|
||||
"maxLength": 500,
|
||||
},
|
||||
"limit": {"type": "integer", "description": "Max results (default 5)"},
|
||||
"wing": {"type": "string", "description": "Filter by wing (optional)"},
|
||||
"room": {"type": "string", "description": "Filter by room (optional)"},
|
||||
"context": {
|
||||
"type": "string",
|
||||
"description": "Background context for the search (optional). This is NOT used for embedding — only for future re-ranking. Put conversation history or system prompt content here, NOT in query.",
|
||||
},
|
||||
},
|
||||
"required": ["query"],
|
||||
},
|
||||
|
||||
@@ -0,0 +1,214 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
mempalace migrate — Recover a palace created with a different ChromaDB version.
|
||||
|
||||
Reads documents and metadata directly from the palace's SQLite database
|
||||
(bypassing ChromaDB's API, which fails on version-mismatched palaces),
|
||||
then re-imports everything into a fresh palace using the currently installed
|
||||
ChromaDB version.
|
||||
|
||||
This fixes the 3.0.0 → 3.1.0 upgrade path where chromadb was downgraded
|
||||
from 1.5.x to 0.6.x, breaking the on-disk storage format.
|
||||
|
||||
Usage:
|
||||
mempalace migrate # migrate default palace
|
||||
mempalace migrate --palace /path/to/palace # migrate specific palace
|
||||
mempalace migrate --dry-run # show what would be migrated
|
||||
"""
|
||||
|
||||
import os
|
||||
import shutil
|
||||
import sqlite3
|
||||
from collections import defaultdict
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
def extract_drawers_from_sqlite(db_path: str) -> list:
|
||||
"""Read all drawers directly from ChromaDB's SQLite, bypassing the API.
|
||||
|
||||
Works regardless of which ChromaDB version created the database.
|
||||
Returns list of dicts with 'id', 'document', and 'metadata' keys.
|
||||
"""
|
||||
conn = sqlite3.connect(db_path)
|
||||
conn.row_factory = sqlite3.Row
|
||||
|
||||
# Get all embedding IDs and their documents
|
||||
rows = conn.execute("""
|
||||
SELECT e.embedding_id,
|
||||
MAX(CASE WHEN em.key = 'chroma:document' THEN em.string_value END) as document
|
||||
FROM embeddings e
|
||||
JOIN embedding_metadata em ON em.id = e.id
|
||||
GROUP BY e.embedding_id
|
||||
""").fetchall()
|
||||
|
||||
drawers = []
|
||||
for row in rows:
|
||||
embedding_id = row["embedding_id"]
|
||||
document = row["document"]
|
||||
if not document:
|
||||
continue
|
||||
|
||||
# Get metadata for this embedding
|
||||
meta_rows = conn.execute(
|
||||
"""
|
||||
SELECT em.key, em.string_value, em.int_value, em.float_value, em.bool_value
|
||||
FROM embedding_metadata em
|
||||
JOIN embeddings e ON e.id = em.id
|
||||
WHERE e.embedding_id = ?
|
||||
AND em.key NOT LIKE 'chroma:%'
|
||||
""",
|
||||
(embedding_id,),
|
||||
).fetchall()
|
||||
|
||||
metadata = {}
|
||||
for mr in meta_rows:
|
||||
key = mr["key"]
|
||||
if mr["string_value"] is not None:
|
||||
metadata[key] = mr["string_value"]
|
||||
elif mr["int_value"] is not None:
|
||||
metadata[key] = mr["int_value"]
|
||||
elif mr["float_value"] is not None:
|
||||
metadata[key] = mr["float_value"]
|
||||
elif mr["bool_value"] is not None:
|
||||
metadata[key] = bool(mr["bool_value"])
|
||||
|
||||
drawers.append(
|
||||
{
|
||||
"id": embedding_id,
|
||||
"document": document,
|
||||
"metadata": metadata,
|
||||
}
|
||||
)
|
||||
|
||||
conn.close()
|
||||
return drawers
|
||||
|
||||
|
||||
def detect_chromadb_version(db_path: str) -> str:
|
||||
"""Detect which ChromaDB version created the database by checking schema."""
|
||||
conn = sqlite3.connect(db_path)
|
||||
try:
|
||||
# 1.x has schema_str column in collections table
|
||||
cols = [r[1] for r in conn.execute("PRAGMA table_info(collections)").fetchall()]
|
||||
if "schema_str" in cols:
|
||||
return "1.x"
|
||||
# 0.6.x has embeddings_queue but no schema_str
|
||||
tables = [
|
||||
r[0]
|
||||
for r in conn.execute("SELECT name FROM sqlite_master WHERE type='table'").fetchall()
|
||||
]
|
||||
if "embeddings_queue" in tables:
|
||||
return "0.6.x"
|
||||
return "unknown"
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
|
||||
def migrate(palace_path: str, dry_run: bool = False):
|
||||
"""Migrate a palace to the currently installed ChromaDB version."""
|
||||
import chromadb
|
||||
|
||||
palace_path = os.path.expanduser(palace_path)
|
||||
db_path = os.path.join(palace_path, "chroma.sqlite3")
|
||||
|
||||
if not os.path.isfile(db_path):
|
||||
print(f"\n No palace database found at {db_path}")
|
||||
return False
|
||||
|
||||
print(f"\n{'=' * 60}")
|
||||
print(" MemPalace Migrate")
|
||||
print(f"{'=' * 60}\n")
|
||||
print(f" Palace: {palace_path}")
|
||||
print(f" Database: {db_path}")
|
||||
print(f" DB size: {os.path.getsize(db_path) / 1024 / 1024:.1f} MB")
|
||||
|
||||
# Detect version
|
||||
source_version = detect_chromadb_version(db_path)
|
||||
print(f" Source: ChromaDB {source_version}")
|
||||
print(f" Target: ChromaDB {chromadb.__version__}")
|
||||
|
||||
# Try reading with current chromadb first
|
||||
try:
|
||||
client = chromadb.PersistentClient(path=palace_path)
|
||||
col = client.get_collection("mempalace_drawers")
|
||||
count = col.count()
|
||||
print(f"\n Palace is already readable by chromadb {chromadb.__version__}.")
|
||||
print(f" {count} drawers found. No migration needed.")
|
||||
return True
|
||||
except Exception:
|
||||
print(f"\n Palace is NOT readable by chromadb {chromadb.__version__}.")
|
||||
print(" Extracting from SQLite directly...")
|
||||
|
||||
# Extract all drawers via raw SQL
|
||||
drawers = extract_drawers_from_sqlite(db_path)
|
||||
print(f" Extracted {len(drawers)} drawers from SQLite")
|
||||
|
||||
if not drawers:
|
||||
print(" Nothing to migrate.")
|
||||
return True
|
||||
|
||||
# Show summary
|
||||
wings = defaultdict(lambda: defaultdict(int))
|
||||
for d in drawers:
|
||||
w = d["metadata"].get("wing", "?")
|
||||
r = d["metadata"].get("room", "?")
|
||||
wings[w][r] += 1
|
||||
|
||||
print("\n Summary:")
|
||||
for wing, rooms in sorted(wings.items()):
|
||||
total = sum(rooms.values())
|
||||
print(f" WING: {wing} ({total} drawers)")
|
||||
for room, count in sorted(rooms.items(), key=lambda x: -x[1]):
|
||||
print(f" ROOM: {room:30} {count:5}")
|
||||
|
||||
if dry_run:
|
||||
print("\n DRY RUN — no changes made.")
|
||||
print(f" Would migrate {len(drawers)} drawers.")
|
||||
return True
|
||||
|
||||
# Backup the old palace
|
||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
backup_path = f"{palace_path}.pre-migrate.{timestamp}"
|
||||
print(f"\n Backing up to {backup_path}...")
|
||||
shutil.copytree(palace_path, backup_path)
|
||||
|
||||
# Build fresh palace in a temp directory (avoids chromadb reading old state)
|
||||
import tempfile
|
||||
|
||||
temp_palace = tempfile.mkdtemp(prefix="mempalace_migrate_")
|
||||
print(f" Creating fresh palace in {temp_palace}...")
|
||||
client = chromadb.PersistentClient(path=temp_palace)
|
||||
col = client.get_or_create_collection("mempalace_drawers")
|
||||
|
||||
# Re-import in batches
|
||||
batch_size = 500
|
||||
imported = 0
|
||||
for i in range(0, len(drawers), batch_size):
|
||||
batch = drawers[i : i + batch_size]
|
||||
col.add(
|
||||
ids=[d["id"] for d in batch],
|
||||
documents=[d["document"] for d in batch],
|
||||
metadatas=[d["metadata"] for d in batch],
|
||||
)
|
||||
imported += len(batch)
|
||||
print(f" Imported {imported}/{len(drawers)} drawers...")
|
||||
|
||||
# Verify before swapping
|
||||
final_count = col.count()
|
||||
del col
|
||||
del client
|
||||
|
||||
# Swap: remove old palace, move new one into place
|
||||
print(" Swapping old palace for migrated version...")
|
||||
shutil.rmtree(palace_path)
|
||||
shutil.move(temp_palace, palace_path)
|
||||
|
||||
print("\n Migration complete.")
|
||||
print(f" Drawers migrated: {final_count}")
|
||||
print(f" Backup at: {backup_path}")
|
||||
|
||||
if final_count != len(drawers):
|
||||
print(f" WARNING: Expected {len(drawers)}, got {final_count}")
|
||||
|
||||
print(f"\n{'=' * 60}\n")
|
||||
return True
|
||||
@@ -436,6 +436,16 @@ def process_file(
|
||||
print(f" [DRY RUN] {filepath.name} → room:{room} ({len(chunks)} drawers)")
|
||||
return len(chunks), room
|
||||
|
||||
# Purge stale drawers for this file before re-inserting the fresh chunks.
|
||||
# Converts modified-file re-mines from upsert-over-existing-IDs (which hits
|
||||
# hnswlib's thread-unsafe updatePoint path and can segfault on macOS ARM
|
||||
# with chromadb 0.6.3) into a clean delete+insert, bypassing the update
|
||||
# path entirely.
|
||||
try:
|
||||
collection.delete(where={"source_file": source_file})
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
drawers_added = 0
|
||||
for chunk in chunks:
|
||||
added = add_drawer(
|
||||
|
||||
@@ -0,0 +1,157 @@
|
||||
"""
|
||||
query_sanitizer.py — Mitigate system prompt contamination in search queries.
|
||||
|
||||
Problem: AI agents sometimes prepend system prompts (2000+ chars) to search queries.
|
||||
Embedding models represent the concatenated string as a single vector where the
|
||||
system prompt overwhelms the actual question (typically 10-50 chars), causing
|
||||
near-total retrieval failure (89.8% → 1.0% R@10). See Issue #333.
|
||||
|
||||
Approach: "Mitigation" (減災) — not perfect prevention, but prevents the cliff.
|
||||
|
||||
Expected recovery:
|
||||
Step 1 passthrough (≤200 chars) → no degradation, ~89.8%
|
||||
Step 2 question extraction (?found) → near-full recovery, ~85-89%
|
||||
Step 3 tail sentence extraction → moderate recovery, ~80-89%
|
||||
Step 4 tail truncation (fallback) → minimum viable, ~70-80%
|
||||
|
||||
Without sanitizer: 1.0% (catastrophic silent failure)
|
||||
Worst case with sanitizer: ~70-80% (survivable)
|
||||
"""
|
||||
|
||||
import re
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger("mempalace_mcp")
|
||||
|
||||
# --- Constants ---
|
||||
MAX_QUERY_LENGTH = 500 # Above this, system prompt almost certainly dominates
|
||||
SAFE_QUERY_LENGTH = 200 # Below this, query is almost certainly clean
|
||||
MIN_QUERY_LENGTH = 10 # Extracted result shorter than this = extraction failed
|
||||
|
||||
# Sentence splitter: split on . ! ? (including fullwidth) and newlines
|
||||
_SENTENCE_SPLIT = re.compile(r"[.!?。!?\n]+")
|
||||
|
||||
# Question detector: ends with ? or ? (possibly with trailing whitespace/quotes)
|
||||
_QUESTION_MARK = re.compile(r'[??]\s*["\']?\s*$')
|
||||
|
||||
|
||||
def sanitize_query(raw_query: str) -> dict:
|
||||
"""
|
||||
Extract the actual search intent from a potentially contaminated query.
|
||||
|
||||
Args:
|
||||
raw_query: The raw query string from the AI agent, possibly containing
|
||||
system prompt content prepended to the actual question.
|
||||
|
||||
Returns:
|
||||
dict with keys:
|
||||
clean_query (str): The sanitized query to use for embedding search
|
||||
was_sanitized (bool): Whether any sanitization was applied
|
||||
original_length (int): Length of the raw input
|
||||
clean_length (int): Length of the sanitized output
|
||||
method (str): Which extraction method was used
|
||||
- "passthrough": query was short enough, no action taken
|
||||
- "question_extraction": found and extracted a question sentence
|
||||
- "tail_sentence": extracted the last meaningful sentence
|
||||
- "tail_truncation": fallback — took the last MAX_QUERY_LENGTH chars
|
||||
"""
|
||||
if not raw_query or not raw_query.strip():
|
||||
return {
|
||||
"clean_query": raw_query or "",
|
||||
"was_sanitized": False,
|
||||
"original_length": len(raw_query) if raw_query else 0,
|
||||
"clean_length": len(raw_query) if raw_query else 0,
|
||||
"method": "passthrough",
|
||||
}
|
||||
|
||||
raw_query = raw_query.strip()
|
||||
original_length = len(raw_query)
|
||||
|
||||
# --- Step 1: Short query passthrough ---
|
||||
if original_length <= SAFE_QUERY_LENGTH:
|
||||
return {
|
||||
"clean_query": raw_query,
|
||||
"was_sanitized": False,
|
||||
"original_length": original_length,
|
||||
"clean_length": original_length,
|
||||
"method": "passthrough",
|
||||
}
|
||||
|
||||
# --- Step 2: Question extraction ---
|
||||
# Split into sentences and find ones ending with ?
|
||||
sentences = [s.strip() for s in _SENTENCE_SPLIT.split(raw_query) if s.strip()]
|
||||
|
||||
# Also split on newlines to catch questions on their own line
|
||||
all_segments = []
|
||||
for s in raw_query.split("\n"):
|
||||
s = s.strip()
|
||||
if s:
|
||||
all_segments.append(s)
|
||||
|
||||
# Look for question marks in segments (prefer later ones = more likely the actual query)
|
||||
question_sentences = []
|
||||
for seg in reversed(all_segments):
|
||||
if _QUESTION_MARK.search(seg):
|
||||
question_sentences.append(seg)
|
||||
|
||||
if not question_sentences:
|
||||
# Also check the sentence-split results
|
||||
for sent in reversed(sentences):
|
||||
if "?" in sent or "?" in sent:
|
||||
question_sentences.append(sent)
|
||||
|
||||
if question_sentences:
|
||||
# Take the last (most recent) question found
|
||||
candidate = question_sentences[0].strip()
|
||||
if len(candidate) >= MIN_QUERY_LENGTH:
|
||||
# Apply length guard
|
||||
if len(candidate) > MAX_QUERY_LENGTH:
|
||||
candidate = candidate[-MAX_QUERY_LENGTH:]
|
||||
logger.warning(
|
||||
"Query sanitized: %d → %d chars (method=question_extraction)",
|
||||
original_length,
|
||||
len(candidate),
|
||||
)
|
||||
return {
|
||||
"clean_query": candidate,
|
||||
"was_sanitized": True,
|
||||
"original_length": original_length,
|
||||
"clean_length": len(candidate),
|
||||
"method": "question_extraction",
|
||||
}
|
||||
|
||||
# --- Step 3: Tail sentence extraction ---
|
||||
# System prompts are prepended, so the actual query is near the end.
|
||||
# Walk backwards through segments to find the last meaningful sentence.
|
||||
for seg in reversed(all_segments):
|
||||
seg = seg.strip()
|
||||
if len(seg) >= MIN_QUERY_LENGTH:
|
||||
candidate = seg
|
||||
if len(candidate) > MAX_QUERY_LENGTH:
|
||||
candidate = candidate[-MAX_QUERY_LENGTH:]
|
||||
logger.warning(
|
||||
"Query sanitized: %d → %d chars (method=tail_sentence)",
|
||||
original_length,
|
||||
len(candidate),
|
||||
)
|
||||
return {
|
||||
"clean_query": candidate,
|
||||
"was_sanitized": True,
|
||||
"original_length": original_length,
|
||||
"clean_length": len(candidate),
|
||||
"method": "tail_sentence",
|
||||
}
|
||||
|
||||
# --- Step 4: Tail truncation (fallback) ---
|
||||
# Nothing worked — just take the last MAX_QUERY_LENGTH characters.
|
||||
candidate = raw_query[-MAX_QUERY_LENGTH:].strip()
|
||||
logger.warning(
|
||||
"Query sanitized: %d → %d chars (method=tail_truncation)", original_length, len(candidate)
|
||||
)
|
||||
return {
|
||||
"clean_query": candidate,
|
||||
"was_sanitized": True,
|
||||
"original_length": original_length,
|
||||
"clean_length": len(candidate),
|
||||
"method": "tail_truncation",
|
||||
}
|
||||
@@ -0,0 +1,299 @@
|
||||
"""
|
||||
repair.py — Scan, prune corrupt entries, and rebuild HNSW index
|
||||
================================================================
|
||||
|
||||
When ChromaDB's HNSW index accumulates duplicate entries (from repeated
|
||||
add() calls with the same ID), link_lists.bin can grow unbounded —
|
||||
terabytes on large palaces — eventually causing segfaults.
|
||||
|
||||
This module provides three operations:
|
||||
|
||||
scan — find every corrupt/unfetchable ID in the palace
|
||||
prune — delete only the corrupt IDs (surgical)
|
||||
rebuild — extract all drawers, delete the collection, recreate with
|
||||
correct HNSW settings, and upsert everything back
|
||||
|
||||
The rebuild backs up ONLY chroma.sqlite3 (the source of truth), not the
|
||||
full palace directory — so it works even when link_lists.bin is bloated.
|
||||
|
||||
Usage (standalone):
|
||||
python -m mempalace.repair scan [--wing X]
|
||||
python -m mempalace.repair prune --confirm
|
||||
python -m mempalace.repair rebuild
|
||||
|
||||
Usage (from CLI):
|
||||
mempalace repair
|
||||
mempalace repair-scan [--wing X]
|
||||
mempalace repair-prune --confirm
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import shutil
|
||||
import time
|
||||
|
||||
import chromadb
|
||||
|
||||
|
||||
COLLECTION_NAME = "mempalace_drawers"
|
||||
|
||||
|
||||
def _get_palace_path():
|
||||
"""Resolve palace path from config."""
|
||||
try:
|
||||
from .config import MempalaceConfig
|
||||
|
||||
return MempalaceConfig().palace_path
|
||||
except Exception:
|
||||
default = os.path.join(os.path.expanduser("~"), ".mempalace", "palace")
|
||||
return default
|
||||
|
||||
|
||||
def _paginate_ids(col, where=None):
|
||||
"""Pull all IDs in a collection using pagination."""
|
||||
ids = []
|
||||
page = 1000
|
||||
offset = 0
|
||||
while True:
|
||||
try:
|
||||
r = col.get(where=where, include=[], limit=page, offset=offset)
|
||||
except Exception:
|
||||
try:
|
||||
r = col.get(where=where, include=[], limit=page)
|
||||
new_ids = [i for i in r["ids"] if i not in set(ids)]
|
||||
if not new_ids:
|
||||
break
|
||||
ids.extend(new_ids)
|
||||
offset += len(new_ids)
|
||||
continue
|
||||
except Exception:
|
||||
break
|
||||
n = len(r["ids"]) if r["ids"] else 0
|
||||
if n == 0:
|
||||
break
|
||||
ids.extend(r["ids"])
|
||||
offset += n
|
||||
if n < page:
|
||||
break
|
||||
return ids
|
||||
|
||||
|
||||
def scan_palace(palace_path=None, only_wing=None):
|
||||
"""Scan the palace for corrupt/unfetchable IDs.
|
||||
|
||||
Probes in batches of 100, falls back to per-ID on failure.
|
||||
Writes corrupt_ids.txt to the palace directory for the prune step.
|
||||
|
||||
Returns (good_set, bad_set).
|
||||
"""
|
||||
palace_path = palace_path or _get_palace_path()
|
||||
print(f"\n Palace: {palace_path}")
|
||||
print(" Loading...")
|
||||
|
||||
client = chromadb.PersistentClient(path=palace_path)
|
||||
col = client.get_collection(COLLECTION_NAME)
|
||||
|
||||
where = {"wing": only_wing} if only_wing else None
|
||||
total = col.count()
|
||||
print(f" Collection: {COLLECTION_NAME}, total: {total:,}")
|
||||
if only_wing:
|
||||
print(f" Scanning wing: {only_wing}")
|
||||
|
||||
print("\n Step 1: listing all IDs...")
|
||||
t0 = time.time()
|
||||
all_ids = _paginate_ids(col, where=where)
|
||||
print(f" Found {len(all_ids):,} IDs in {time.time() - t0:.1f}s\n")
|
||||
|
||||
if not all_ids:
|
||||
print(" Nothing to scan.")
|
||||
return set(), set()
|
||||
|
||||
print(" Step 2: probing each ID (batches of 100)...")
|
||||
t0 = time.time()
|
||||
good_set = set()
|
||||
bad_set = set()
|
||||
batch = 100
|
||||
|
||||
for i in range(0, len(all_ids), batch):
|
||||
chunk = all_ids[i : i + batch]
|
||||
try:
|
||||
r = col.get(ids=chunk, include=["documents"])
|
||||
for got in r["ids"]:
|
||||
good_set.add(got)
|
||||
for mid in chunk:
|
||||
if mid not in good_set:
|
||||
bad_set.add(mid)
|
||||
except Exception:
|
||||
for sid in chunk:
|
||||
try:
|
||||
r = col.get(ids=[sid], include=["documents"])
|
||||
if r["ids"]:
|
||||
good_set.add(sid)
|
||||
else:
|
||||
bad_set.add(sid)
|
||||
except Exception:
|
||||
bad_set.add(sid)
|
||||
|
||||
if (i // batch) % 50 == 0:
|
||||
elapsed = time.time() - t0
|
||||
rate = (i + batch) / max(elapsed, 0.01)
|
||||
eta = (len(all_ids) - i - batch) / max(rate, 0.01)
|
||||
print(
|
||||
f" {i + batch:>6}/{len(all_ids):>6} "
|
||||
f"good={len(good_set):>6} bad={len(bad_set):>6} "
|
||||
f"eta={eta:.0f}s"
|
||||
)
|
||||
|
||||
print(f"\n Scan complete in {time.time() - t0:.1f}s")
|
||||
print(f" GOOD: {len(good_set):,}")
|
||||
print(f" BAD: {len(bad_set):,} ({len(bad_set) / max(len(all_ids), 1) * 100:.1f}%)")
|
||||
|
||||
bad_file = os.path.join(palace_path, "corrupt_ids.txt")
|
||||
with open(bad_file, "w") as f:
|
||||
for bid in sorted(bad_set):
|
||||
f.write(bid + "\n")
|
||||
print(f"\n Bad IDs written to: {bad_file}")
|
||||
return good_set, bad_set
|
||||
|
||||
|
||||
def prune_corrupt(palace_path=None, confirm=False):
|
||||
"""Delete corrupt IDs listed in corrupt_ids.txt."""
|
||||
palace_path = palace_path or _get_palace_path()
|
||||
bad_file = os.path.join(palace_path, "corrupt_ids.txt")
|
||||
|
||||
if not os.path.exists(bad_file):
|
||||
print(" No corrupt_ids.txt found — run scan first.")
|
||||
return
|
||||
|
||||
with open(bad_file) as f:
|
||||
bad_ids = [line.strip() for line in f if line.strip()]
|
||||
print(f" {len(bad_ids):,} corrupt IDs queued for deletion")
|
||||
|
||||
if not confirm:
|
||||
print("\n DRY RUN — no deletions performed.")
|
||||
print(" Re-run with --confirm to actually delete.")
|
||||
return
|
||||
|
||||
client = chromadb.PersistentClient(path=palace_path)
|
||||
col = client.get_collection(COLLECTION_NAME)
|
||||
before = col.count()
|
||||
print(f" Collection size before: {before:,}")
|
||||
|
||||
batch = 100
|
||||
deleted = 0
|
||||
failed = 0
|
||||
for i in range(0, len(bad_ids), batch):
|
||||
chunk = bad_ids[i : i + batch]
|
||||
try:
|
||||
col.delete(ids=chunk)
|
||||
deleted += len(chunk)
|
||||
except Exception:
|
||||
for sid in chunk:
|
||||
try:
|
||||
col.delete(ids=[sid])
|
||||
deleted += 1
|
||||
except Exception:
|
||||
failed += 1
|
||||
if (i // batch) % 20 == 0:
|
||||
print(f" deleted {deleted}/{len(bad_ids)} (failed: {failed})")
|
||||
|
||||
after = col.count()
|
||||
print(f"\n Deleted: {deleted:,}")
|
||||
print(f" Failed: {failed:,}")
|
||||
print(f" Collection size: {before:,} → {after:,}")
|
||||
|
||||
|
||||
def rebuild_index(palace_path=None):
|
||||
"""Rebuild the HNSW index from scratch.
|
||||
|
||||
1. Extract all drawers via ChromaDB get()
|
||||
2. Back up ONLY chroma.sqlite3 (not the bloated HNSW files)
|
||||
3. Delete and recreate the collection with hnsw:space=cosine
|
||||
4. Upsert all drawers back
|
||||
"""
|
||||
palace_path = palace_path or _get_palace_path()
|
||||
|
||||
if not os.path.isdir(palace_path):
|
||||
print(f"\n No palace found at {palace_path}")
|
||||
return
|
||||
|
||||
print(f"\n{'=' * 55}")
|
||||
print(" MemPalace Repair — Index Rebuild")
|
||||
print(f"{'=' * 55}\n")
|
||||
print(f" Palace: {palace_path}")
|
||||
|
||||
client = chromadb.PersistentClient(path=palace_path)
|
||||
try:
|
||||
col = client.get_collection(COLLECTION_NAME)
|
||||
total = col.count()
|
||||
except Exception as e:
|
||||
print(f" Error reading palace: {e}")
|
||||
print(" Palace may need to be re-mined from source files.")
|
||||
return
|
||||
|
||||
print(f" Drawers found: {total}")
|
||||
|
||||
if total == 0:
|
||||
print(" Nothing to repair.")
|
||||
return
|
||||
|
||||
# Extract all drawers in batches
|
||||
print("\n Extracting drawers...")
|
||||
batch_size = 5000
|
||||
all_ids = []
|
||||
all_docs = []
|
||||
all_metas = []
|
||||
offset = 0
|
||||
while offset < total:
|
||||
batch = col.get(limit=batch_size, offset=offset, include=["documents", "metadatas"])
|
||||
if not batch["ids"]:
|
||||
break
|
||||
all_ids.extend(batch["ids"])
|
||||
all_docs.extend(batch["documents"])
|
||||
all_metas.extend(batch["metadatas"])
|
||||
offset += len(batch["ids"])
|
||||
print(f" Extracted {len(all_ids)} drawers")
|
||||
|
||||
# Back up ONLY the SQLite database, not the bloated HNSW files
|
||||
sqlite_path = os.path.join(palace_path, "chroma.sqlite3")
|
||||
if os.path.exists(sqlite_path):
|
||||
backup_path = sqlite_path + ".backup"
|
||||
print(f" Backing up chroma.sqlite3 ({os.path.getsize(sqlite_path) / 1e6:.0f} MB)...")
|
||||
shutil.copy2(sqlite_path, backup_path)
|
||||
print(f" Backup: {backup_path}")
|
||||
|
||||
# Rebuild with correct HNSW settings
|
||||
print(" Rebuilding collection with hnsw:space=cosine...")
|
||||
client.delete_collection(COLLECTION_NAME)
|
||||
new_col = client.create_collection(COLLECTION_NAME, metadata={"hnsw:space": "cosine"})
|
||||
|
||||
filed = 0
|
||||
for i in range(0, len(all_ids), batch_size):
|
||||
batch_ids = all_ids[i : i + batch_size]
|
||||
batch_docs = all_docs[i : i + batch_size]
|
||||
batch_metas = all_metas[i : i + batch_size]
|
||||
new_col.upsert(documents=batch_docs, ids=batch_ids, metadatas=batch_metas)
|
||||
filed += len(batch_ids)
|
||||
print(f" Re-filed {filed}/{len(all_ids)} drawers...")
|
||||
|
||||
print(f"\n Repair complete. {filed} drawers rebuilt.")
|
||||
print(" HNSW index is now clean with cosine distance metric.")
|
||||
print(f"\n{'=' * 55}\n")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
p = argparse.ArgumentParser(description="MemPalace repair tools")
|
||||
p.add_argument("command", choices=["scan", "prune", "rebuild"])
|
||||
p.add_argument("--palace", default=None, help="Palace directory path")
|
||||
p.add_argument("--wing", default=None, help="Scan only this wing")
|
||||
p.add_argument("--confirm", action="store_true", help="Actually delete corrupt IDs")
|
||||
args = p.parse_args()
|
||||
|
||||
path = os.path.expanduser(args.palace) if args.palace else None
|
||||
|
||||
if args.command == "scan":
|
||||
scan_palace(palace_path=path, only_wing=args.wing)
|
||||
elif args.command == "prune":
|
||||
prune_corrupt(palace_path=path, confirm=args.confirm)
|
||||
elif args.command == "rebuild":
|
||||
rebuild_index(palace_path=path)
|
||||
+5
-1
@@ -54,11 +54,15 @@ packages = ["mempalace"]
|
||||
[tool.ruff]
|
||||
line-length = 100
|
||||
target-version = "py39"
|
||||
extend-exclude = ["benchmarks"]
|
||||
|
||||
[tool.ruff.lint]
|
||||
select = ["E", "F", "W"]
|
||||
select = ["E", "F", "W", "C901"]
|
||||
ignore = ["E501"]
|
||||
|
||||
[tool.ruff.lint.mccabe]
|
||||
max-complexity = 25
|
||||
|
||||
[tool.ruff.format]
|
||||
quote-style = "double"
|
||||
|
||||
|
||||
@@ -0,0 +1,272 @@
|
||||
"""Tests for mempalace.dedup — near-duplicate drawer detection and removal."""
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
|
||||
from mempalace import dedup
|
||||
|
||||
|
||||
# ── get_source_groups ─────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_get_source_groups_basic():
|
||||
col = MagicMock()
|
||||
col.count.return_value = 5
|
||||
col.get.side_effect = [
|
||||
{
|
||||
"ids": ["d1", "d2", "d3", "d4", "d5"],
|
||||
"metadatas": [
|
||||
{"source_file": "a.txt"},
|
||||
{"source_file": "a.txt"},
|
||||
{"source_file": "a.txt"},
|
||||
{"source_file": "a.txt"},
|
||||
{"source_file": "a.txt"},
|
||||
],
|
||||
},
|
||||
{"ids": []},
|
||||
]
|
||||
groups = dedup.get_source_groups(col, min_count=5)
|
||||
assert "a.txt" in groups
|
||||
assert len(groups["a.txt"]) == 5
|
||||
|
||||
|
||||
def test_get_source_groups_below_min():
|
||||
col = MagicMock()
|
||||
col.count.return_value = 2
|
||||
col.get.side_effect = [
|
||||
{
|
||||
"ids": ["d1", "d2"],
|
||||
"metadatas": [
|
||||
{"source_file": "a.txt"},
|
||||
{"source_file": "a.txt"},
|
||||
],
|
||||
},
|
||||
{"ids": []},
|
||||
]
|
||||
groups = dedup.get_source_groups(col, min_count=5)
|
||||
assert len(groups) == 0
|
||||
|
||||
|
||||
def test_get_source_groups_source_filter():
|
||||
col = MagicMock()
|
||||
col.count.return_value = 6
|
||||
col.get.side_effect = [
|
||||
{
|
||||
"ids": ["d1", "d2", "d3", "d4", "d5", "d6"],
|
||||
"metadatas": [
|
||||
{"source_file": "project_a.txt"},
|
||||
{"source_file": "project_a.txt"},
|
||||
{"source_file": "project_a.txt"},
|
||||
{"source_file": "project_a.txt"},
|
||||
{"source_file": "project_a.txt"},
|
||||
{"source_file": "other.txt"},
|
||||
],
|
||||
},
|
||||
{"ids": []},
|
||||
]
|
||||
groups = dedup.get_source_groups(col, min_count=5, source_pattern="project_a")
|
||||
assert "project_a.txt" in groups
|
||||
assert "other.txt" not in groups
|
||||
|
||||
|
||||
def test_get_source_groups_wing_filter():
|
||||
col = MagicMock()
|
||||
col.count.return_value = 5
|
||||
col.get.side_effect = [
|
||||
{
|
||||
"ids": ["d1", "d2", "d3", "d4", "d5"],
|
||||
"metadatas": [
|
||||
{"source_file": "a.txt"},
|
||||
{"source_file": "a.txt"},
|
||||
{"source_file": "a.txt"},
|
||||
{"source_file": "a.txt"},
|
||||
{"source_file": "a.txt"},
|
||||
],
|
||||
},
|
||||
{"ids": []},
|
||||
]
|
||||
dedup.get_source_groups(col, min_count=5, wing="my_wing")
|
||||
# Verify where filter was passed
|
||||
first_call = col.get.call_args_list[0]
|
||||
assert first_call.kwargs.get("where") == {"wing": "my_wing"}
|
||||
|
||||
|
||||
def test_get_source_groups_missing_source_file():
|
||||
col = MagicMock()
|
||||
col.count.return_value = 5
|
||||
col.get.side_effect = [
|
||||
{
|
||||
"ids": ["d1", "d2", "d3", "d4", "d5"],
|
||||
"metadatas": [{}, {}, {}, {}, {}],
|
||||
},
|
||||
{"ids": []},
|
||||
]
|
||||
groups = dedup.get_source_groups(col, min_count=5)
|
||||
assert "unknown" in groups
|
||||
|
||||
|
||||
# ── dedup_source_group ────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_dedup_source_group_all_unique():
|
||||
col = MagicMock()
|
||||
col.get.return_value = {
|
||||
"ids": ["d1", "d2"],
|
||||
"documents": ["long document one content here", "different document two here"],
|
||||
"metadatas": [{"wing": "a"}, {"wing": "a"}],
|
||||
}
|
||||
col.query.return_value = {
|
||||
"ids": [["d1"]],
|
||||
"distances": [[0.8]], # far apart = unique
|
||||
}
|
||||
kept, deleted = dedup.dedup_source_group(col, ["d1", "d2"], threshold=0.15, dry_run=True)
|
||||
assert len(kept) == 2
|
||||
assert len(deleted) == 0
|
||||
|
||||
|
||||
def test_dedup_source_group_with_duplicate():
|
||||
col = MagicMock()
|
||||
col.get.return_value = {
|
||||
"ids": ["d1", "d2"],
|
||||
"documents": [
|
||||
"long document content that is fairly long",
|
||||
"long document content that is fairly long",
|
||||
],
|
||||
"metadatas": [{"wing": "a"}, {"wing": "a"}],
|
||||
}
|
||||
col.query.return_value = {
|
||||
"ids": [["d1"]],
|
||||
"distances": [[0.05]], # very close = duplicate
|
||||
}
|
||||
kept, deleted = dedup.dedup_source_group(col, ["d1", "d2"], threshold=0.15, dry_run=True)
|
||||
assert len(kept) == 1
|
||||
assert len(deleted) == 1
|
||||
|
||||
|
||||
def test_dedup_source_group_short_docs_deleted():
|
||||
col = MagicMock()
|
||||
col.get.return_value = {
|
||||
"ids": ["d1", "d2"],
|
||||
"documents": ["long enough document to keep in the palace", "tiny"],
|
||||
"metadatas": [{"wing": "a"}, {"wing": "a"}],
|
||||
}
|
||||
kept, deleted = dedup.dedup_source_group(col, ["d1", "d2"], threshold=0.15, dry_run=True)
|
||||
assert "d2" in deleted # too short
|
||||
|
||||
|
||||
def test_dedup_source_group_empty_doc_deleted():
|
||||
col = MagicMock()
|
||||
col.get.return_value = {
|
||||
"ids": ["d1", "d2"],
|
||||
"documents": ["real document content here that is long enough", None],
|
||||
"metadatas": [{"wing": "a"}, {"wing": "a"}],
|
||||
}
|
||||
kept, deleted = dedup.dedup_source_group(col, ["d1", "d2"], threshold=0.15, dry_run=True)
|
||||
assert "d2" in deleted
|
||||
|
||||
|
||||
def test_dedup_source_group_live_deletes():
|
||||
col = MagicMock()
|
||||
col.get.return_value = {
|
||||
"ids": ["d1", "d2"],
|
||||
"documents": ["long document content here enough", "long document content here enough"],
|
||||
"metadatas": [{"wing": "a"}, {"wing": "a"}],
|
||||
}
|
||||
col.query.return_value = {
|
||||
"ids": [["d1"]],
|
||||
"distances": [[0.05]],
|
||||
}
|
||||
kept, deleted = dedup.dedup_source_group(col, ["d1", "d2"], threshold=0.15, dry_run=False)
|
||||
col.delete.assert_called_once()
|
||||
|
||||
|
||||
def test_dedup_source_group_query_failure_keeps():
|
||||
col = MagicMock()
|
||||
col.get.return_value = {
|
||||
"ids": ["d1", "d2"],
|
||||
"documents": [
|
||||
"long document one content here enough",
|
||||
"long document two content here enough",
|
||||
],
|
||||
"metadatas": [{"wing": "a"}, {"wing": "a"}],
|
||||
}
|
||||
col.query.side_effect = Exception("query failed")
|
||||
kept, deleted = dedup.dedup_source_group(col, ["d1", "d2"], threshold=0.15, dry_run=True)
|
||||
assert len(kept) == 2 # both kept on error
|
||||
|
||||
|
||||
# ── show_stats ────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@patch("mempalace.dedup.chromadb")
|
||||
def test_show_stats(mock_chromadb, tmp_path):
|
||||
mock_col = MagicMock()
|
||||
mock_col.count.return_value = 5
|
||||
mock_col.get.side_effect = [
|
||||
{
|
||||
"ids": ["d1", "d2", "d3", "d4", "d5"],
|
||||
"metadatas": [
|
||||
{"source_file": "a.txt"},
|
||||
{"source_file": "a.txt"},
|
||||
{"source_file": "a.txt"},
|
||||
{"source_file": "a.txt"},
|
||||
{"source_file": "a.txt"},
|
||||
],
|
||||
},
|
||||
{"ids": []},
|
||||
]
|
||||
mock_client = MagicMock()
|
||||
mock_client.get_collection.return_value = mock_col
|
||||
mock_chromadb.PersistentClient.return_value = mock_client
|
||||
|
||||
dedup.show_stats(palace_path=str(tmp_path)) # should not raise
|
||||
|
||||
|
||||
# ── dedup_palace ──────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@patch("mempalace.dedup.dedup_source_group")
|
||||
@patch("mempalace.dedup.get_source_groups")
|
||||
@patch("mempalace.dedup.chromadb")
|
||||
def test_dedup_palace_dry_run(mock_chromadb, mock_groups, mock_dedup_group, tmp_path):
|
||||
mock_col = MagicMock()
|
||||
mock_col.count.return_value = 10
|
||||
mock_client = MagicMock()
|
||||
mock_client.get_collection.return_value = mock_col
|
||||
mock_chromadb.PersistentClient.return_value = mock_client
|
||||
|
||||
mock_groups.return_value = {"a.txt": ["d1", "d2", "d3", "d4", "d5"]}
|
||||
mock_dedup_group.return_value = (["d1", "d2", "d3"], ["d4", "d5"])
|
||||
|
||||
dedup.dedup_palace(palace_path=str(tmp_path), dry_run=True)
|
||||
mock_dedup_group.assert_called_once()
|
||||
|
||||
|
||||
@patch("mempalace.dedup.dedup_source_group")
|
||||
@patch("mempalace.dedup.get_source_groups")
|
||||
@patch("mempalace.dedup.chromadb")
|
||||
def test_dedup_palace_with_wing(mock_chromadb, mock_groups, mock_dedup_group, tmp_path):
|
||||
mock_col = MagicMock()
|
||||
mock_col.count.return_value = 10
|
||||
mock_client = MagicMock()
|
||||
mock_client.get_collection.return_value = mock_col
|
||||
mock_chromadb.PersistentClient.return_value = mock_client
|
||||
|
||||
mock_groups.return_value = {}
|
||||
dedup.dedup_palace(palace_path=str(tmp_path), wing="test_wing", dry_run=True)
|
||||
mock_groups.assert_called_once_with(mock_col, 5, None, wing="test_wing")
|
||||
|
||||
|
||||
@patch("mempalace.dedup.dedup_source_group")
|
||||
@patch("mempalace.dedup.get_source_groups")
|
||||
@patch("mempalace.dedup.chromadb")
|
||||
def test_dedup_palace_no_groups(mock_chromadb, mock_groups, mock_dedup_group, tmp_path):
|
||||
mock_col = MagicMock()
|
||||
mock_col.count.return_value = 3
|
||||
mock_client = MagicMock()
|
||||
mock_client.get_collection.return_value = mock_col
|
||||
mock_chromadb.PersistentClient.return_value = mock_client
|
||||
|
||||
mock_groups.return_value = {}
|
||||
dedup.dedup_palace(palace_path=str(tmp_path), dry_run=True)
|
||||
mock_dedup_group.assert_not_called()
|
||||
@@ -0,0 +1,212 @@
|
||||
"""
|
||||
Tests for query_sanitizer.py — system prompt contamination mitigation (#333).
|
||||
|
||||
Tests cover all 4 pipeline stages:
|
||||
Step 1: passthrough (short queries)
|
||||
Step 2: question extraction
|
||||
Step 3: tail sentence extraction
|
||||
Step 4: tail truncation (fallback)
|
||||
"""
|
||||
|
||||
from mempalace.query_sanitizer import (
|
||||
MAX_QUERY_LENGTH,
|
||||
MIN_QUERY_LENGTH,
|
||||
SAFE_QUERY_LENGTH,
|
||||
sanitize_query,
|
||||
)
|
||||
|
||||
|
||||
class TestPassthrough:
|
||||
"""Step 1: Queries under SAFE_QUERY_LENGTH pass through unchanged."""
|
||||
|
||||
def test_short_query_unchanged(self):
|
||||
result = sanitize_query("What is Rust error handling?")
|
||||
assert result["clean_query"] == "What is Rust error handling?"
|
||||
assert result["was_sanitized"] is False
|
||||
assert result["method"] == "passthrough"
|
||||
|
||||
def test_empty_query(self):
|
||||
result = sanitize_query("")
|
||||
assert result["clean_query"] == ""
|
||||
assert result["was_sanitized"] is False
|
||||
assert result["method"] == "passthrough"
|
||||
|
||||
def test_none_query(self):
|
||||
result = sanitize_query(None)
|
||||
assert result["was_sanitized"] is False
|
||||
assert result["method"] == "passthrough"
|
||||
|
||||
def test_exactly_safe_length(self):
|
||||
query = "a" * SAFE_QUERY_LENGTH
|
||||
result = sanitize_query(query)
|
||||
assert result["was_sanitized"] is False
|
||||
assert result["method"] == "passthrough"
|
||||
|
||||
def test_one_over_safe_length_triggers_sanitization(self):
|
||||
query = "a" * (SAFE_QUERY_LENGTH + 1)
|
||||
result = sanitize_query(query)
|
||||
# Will go through sanitization pipeline (may or may not change the query)
|
||||
assert result["original_length"] == SAFE_QUERY_LENGTH + 1
|
||||
|
||||
|
||||
class TestQuestionExtraction:
|
||||
"""Step 2: Extract question sentences (ending with ?)."""
|
||||
|
||||
def test_question_at_end_of_long_text(self):
|
||||
system_prompt = "You are a helpful assistant. " * 50 # ~1400 chars
|
||||
query = system_prompt + "What is the best way to handle errors in Rust?"
|
||||
result = sanitize_query(query)
|
||||
assert result["was_sanitized"] is True
|
||||
assert "error" in result["clean_query"].lower() or "Rust" in result["clean_query"]
|
||||
assert result["method"] == "question_extraction"
|
||||
|
||||
def test_japanese_question_mark(self):
|
||||
system_prompt = "You are a helpful assistant. " * 50
|
||||
query = system_prompt + "Rustのエラーハンドリング方法は?"
|
||||
result = sanitize_query(query)
|
||||
assert result["was_sanitized"] is True
|
||||
assert "Rust" in result["clean_query"] or "エラー" in result["clean_query"]
|
||||
assert result["method"] == "question_extraction"
|
||||
|
||||
def test_multiple_questions_takes_last(self):
|
||||
system_prompt = "You are a helpful assistant. " * 50
|
||||
query = system_prompt + "What is Python?\nHow does Rust handle errors?"
|
||||
result = sanitize_query(query)
|
||||
assert "Rust" in result["clean_query"] or "error" in result["clean_query"].lower()
|
||||
|
||||
def test_question_in_system_prompt_ignored_when_real_question_exists(self):
|
||||
# System prompt contains a question, but real query also has one
|
||||
system_prompt = "Are you ready to help? " * 30 + "\n"
|
||||
real_query = "What databases does MemPalace support?"
|
||||
query = system_prompt + real_query
|
||||
result = sanitize_query(query)
|
||||
assert result["was_sanitized"] is True
|
||||
assert "MemPalace" in result["clean_query"] or "database" in result["clean_query"].lower()
|
||||
|
||||
|
||||
class TestTailSentence:
|
||||
"""Step 3: Extract the last meaningful sentence when no question mark found."""
|
||||
|
||||
def test_command_style_query(self):
|
||||
system_prompt = "You are a helpful assistant. " * 50
|
||||
query = system_prompt + "Show me all Rust error handling patterns"
|
||||
result = sanitize_query(query)
|
||||
assert result["was_sanitized"] is True
|
||||
assert "Rust" in result["clean_query"] or "error" in result["clean_query"].lower()
|
||||
assert result["method"] in ("tail_sentence", "question_extraction")
|
||||
|
||||
def test_keyword_style_query(self):
|
||||
system_prompt = "System configuration loaded. " * 60
|
||||
query = system_prompt + "\nMemPalace ChromaDB integration setup"
|
||||
result = sanitize_query(query)
|
||||
assert result["was_sanitized"] is True
|
||||
assert "MemPalace" in result["clean_query"] or "ChromaDB" in result["clean_query"]
|
||||
|
||||
|
||||
class TestTailTruncation:
|
||||
"""Step 4: Fallback — take the last MAX_QUERY_LENGTH characters."""
|
||||
|
||||
def test_single_long_line_no_sentences(self):
|
||||
# Short lines only — no segment reaches MIN_QUERY_LENGTH; fallback truncates tail
|
||||
filler = "\n".join(["ab"] * 200)
|
||||
result = sanitize_query(filler)
|
||||
assert result["was_sanitized"] is True
|
||||
assert len(result["clean_query"]) <= MAX_QUERY_LENGTH
|
||||
assert result["method"] == "tail_truncation"
|
||||
|
||||
def test_truncation_preserves_tail(self):
|
||||
filler = "x" * 1000 + "IMPORTANT_QUERY_CONTENT"
|
||||
result = sanitize_query(filler)
|
||||
assert "IMPORTANT_QUERY_CONTENT" in result["clean_query"]
|
||||
|
||||
|
||||
class TestLengthGuards:
|
||||
"""Verify output length constraints."""
|
||||
|
||||
def test_output_never_exceeds_max(self):
|
||||
# Very long question sentence
|
||||
long_question = "a" * 1000 + "?"
|
||||
system_prompt = "Context. " * 100
|
||||
query = system_prompt + long_question
|
||||
result = sanitize_query(query)
|
||||
assert len(result["clean_query"]) <= MAX_QUERY_LENGTH
|
||||
|
||||
def test_extraction_too_short_falls_through(self):
|
||||
# Question mark found but the sentence is too short
|
||||
system_prompt = "You are helpful. " * 50
|
||||
query = system_prompt + "\nOK?"
|
||||
result = sanitize_query(query)
|
||||
# "OK?" is only 3 chars < MIN_QUERY_LENGTH, should fall through
|
||||
assert result["was_sanitized"] is True
|
||||
|
||||
|
||||
class TestMetadata:
|
||||
"""Verify sanitizer metadata is correct."""
|
||||
|
||||
def test_original_length_preserved(self):
|
||||
system_prompt = "You are a helpful assistant. " * 50
|
||||
query = system_prompt + "What is Rust?"
|
||||
result = sanitize_query(query)
|
||||
assert result["original_length"] == len(query.strip())
|
||||
|
||||
def test_clean_length_matches_clean_query(self):
|
||||
system_prompt = "You are a helpful assistant. " * 50
|
||||
query = system_prompt + "What is Rust?"
|
||||
result = sanitize_query(query)
|
||||
assert result["clean_length"] == len(result["clean_query"])
|
||||
|
||||
def test_sanitized_flag_true_when_changed(self):
|
||||
system_prompt = "You are a helpful assistant. " * 50
|
||||
query = system_prompt + "What is Rust?"
|
||||
result = sanitize_query(query)
|
||||
assert result["was_sanitized"] is True
|
||||
|
||||
def test_sanitized_flag_false_when_unchanged(self):
|
||||
result = sanitize_query("Short query")
|
||||
assert result["was_sanitized"] is False
|
||||
|
||||
|
||||
class TestRealWorldScenarios:
|
||||
"""Simulate realistic system prompt contamination patterns."""
|
||||
|
||||
def test_mempalace_wakeup_prepended(self):
|
||||
"""Simulates mempalace wake-up output prepended to a query."""
|
||||
wakeup = (
|
||||
"MemPalace loaded. Wings: technical, emotions, identity. "
|
||||
"Rooms: chromadb-setup, error-handling, project-planning. "
|
||||
"Total drawers: 234. Knowledge graph: 89 entities, 156 triples. "
|
||||
"AAAK dialect active. Protocol: verify before responding. "
|
||||
) * 5 # ~1000 chars
|
||||
real_query = "How did we decide on the database architecture?"
|
||||
query = wakeup + real_query
|
||||
result = sanitize_query(query)
|
||||
assert result["was_sanitized"] is True
|
||||
assert len(result["clean_query"]) <= MAX_QUERY_LENGTH
|
||||
# Should recover something meaningful
|
||||
assert len(result["clean_query"]) >= MIN_QUERY_LENGTH
|
||||
|
||||
def test_memory_md_prepended(self):
|
||||
"""Simulates MEMORY.md content prepended to a query."""
|
||||
memory_md = (
|
||||
"# Project Memory\n"
|
||||
"## Architecture Decisions\n"
|
||||
"- Use ChromaDB for vector storage\n"
|
||||
"- MCP protocol for tool integration\n"
|
||||
"- AAAK compression for efficient storage\n"
|
||||
) * 10 # ~750 chars
|
||||
real_query = "What were the performance benchmarks for the search system?"
|
||||
query = memory_md + "\n" + real_query
|
||||
result = sanitize_query(query)
|
||||
assert result["was_sanitized"] is True
|
||||
assert result["method"] in ("question_extraction", "tail_sentence")
|
||||
|
||||
def test_2000_char_system_prompt_with_question(self):
|
||||
"""The exact scenario from Issue #333 — 2000 chars prepended."""
|
||||
system_prompt = "You are an AI assistant with access to tools. " * 45 # ~2000 chars
|
||||
real_query = "What is the status of the MemPalace project?"
|
||||
query = system_prompt + real_query
|
||||
result = sanitize_query(query)
|
||||
assert result["was_sanitized"] is True
|
||||
assert result["original_length"] > 2000
|
||||
assert result["clean_length"] <= MAX_QUERY_LENGTH
|
||||
assert result["method"] == "question_extraction"
|
||||
@@ -0,0 +1,266 @@
|
||||
"""Tests for mempalace.repair — scan, prune, and rebuild HNSW index."""
|
||||
|
||||
import os
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
|
||||
from mempalace import repair
|
||||
|
||||
|
||||
# ── _get_palace_path ──────────────────────────────────────────────────
|
||||
|
||||
|
||||
@patch("mempalace.repair.MempalaceConfig", create=True)
|
||||
def test_get_palace_path_from_config(mock_config_cls):
|
||||
mock_config_cls.return_value.palace_path = "/configured/palace"
|
||||
with patch.dict("sys.modules", {}):
|
||||
# Force reimport to pick up the mock
|
||||
result = repair._get_palace_path()
|
||||
assert isinstance(result, str)
|
||||
|
||||
|
||||
def test_get_palace_path_fallback():
|
||||
with patch("mempalace.repair._get_palace_path") as mock_get:
|
||||
mock_get.return_value = os.path.join(os.path.expanduser("~"), ".mempalace", "palace")
|
||||
result = mock_get()
|
||||
assert ".mempalace" in result
|
||||
|
||||
|
||||
# ── _paginate_ids ─────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_paginate_ids_single_batch():
|
||||
col = MagicMock()
|
||||
col.get.return_value = {"ids": ["id1", "id2", "id3"]}
|
||||
ids = repair._paginate_ids(col)
|
||||
assert ids == ["id1", "id2", "id3"]
|
||||
|
||||
|
||||
def test_paginate_ids_empty():
|
||||
col = MagicMock()
|
||||
col.get.return_value = {"ids": []}
|
||||
ids = repair._paginate_ids(col)
|
||||
assert ids == []
|
||||
|
||||
|
||||
def test_paginate_ids_with_where():
|
||||
col = MagicMock()
|
||||
col.get.return_value = {"ids": ["id1"]}
|
||||
repair._paginate_ids(col, where={"wing": "test"})
|
||||
col.get.assert_called_with(where={"wing": "test"}, include=[], limit=1000, offset=0)
|
||||
|
||||
|
||||
def test_paginate_ids_offset_exception_fallback():
|
||||
col = MagicMock()
|
||||
# First call raises, fallback returns ids, second fallback returns empty
|
||||
col.get.side_effect = [
|
||||
Exception("offset bug"),
|
||||
{"ids": ["id1", "id2"]},
|
||||
Exception("offset bug"),
|
||||
{"ids": ["id1", "id2"]}, # same ids = no new = break
|
||||
]
|
||||
ids = repair._paginate_ids(col)
|
||||
assert "id1" in ids
|
||||
|
||||
|
||||
# ── scan_palace ───────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@patch("mempalace.repair.chromadb")
|
||||
def test_scan_palace_no_ids(mock_chromadb, tmp_path):
|
||||
mock_col = MagicMock()
|
||||
mock_col.count.return_value = 0
|
||||
mock_col.get.return_value = {"ids": []}
|
||||
mock_client = MagicMock()
|
||||
mock_client.get_collection.return_value = mock_col
|
||||
mock_chromadb.PersistentClient.return_value = mock_client
|
||||
|
||||
good, bad = repair.scan_palace(palace_path=str(tmp_path))
|
||||
assert good == set()
|
||||
assert bad == set()
|
||||
|
||||
|
||||
@patch("mempalace.repair.chromadb")
|
||||
def test_scan_palace_all_good(mock_chromadb, tmp_path):
|
||||
mock_col = MagicMock()
|
||||
mock_col.count.return_value = 2
|
||||
# _paginate_ids call
|
||||
mock_col.get.side_effect = [
|
||||
{"ids": ["id1", "id2"]}, # paginate
|
||||
{"ids": ["id1", "id2"]}, # probe batch — both returned
|
||||
]
|
||||
mock_client = MagicMock()
|
||||
mock_client.get_collection.return_value = mock_col
|
||||
mock_chromadb.PersistentClient.return_value = mock_client
|
||||
|
||||
good, bad = repair.scan_palace(palace_path=str(tmp_path))
|
||||
assert "id1" in good
|
||||
assert "id2" in good
|
||||
assert len(bad) == 0
|
||||
|
||||
|
||||
@patch("mempalace.repair.chromadb")
|
||||
def test_scan_palace_with_bad_ids(mock_chromadb, tmp_path):
|
||||
mock_col = MagicMock()
|
||||
mock_col.count.return_value = 2
|
||||
|
||||
def get_side_effect(**kwargs):
|
||||
ids = kwargs.get("ids", None)
|
||||
if ids is None:
|
||||
# paginate call
|
||||
return {"ids": ["good1", "bad1"]}
|
||||
if "bad1" in ids and len(ids) == 1:
|
||||
raise Exception("corrupt")
|
||||
if "good1" in ids and len(ids) == 1:
|
||||
return {"ids": ["good1"]}
|
||||
# batch probe — raise to force per-id
|
||||
raise Exception("batch fail")
|
||||
|
||||
mock_col.get.side_effect = get_side_effect
|
||||
mock_client = MagicMock()
|
||||
mock_client.get_collection.return_value = mock_col
|
||||
mock_chromadb.PersistentClient.return_value = mock_client
|
||||
|
||||
good, bad = repair.scan_palace(palace_path=str(tmp_path))
|
||||
assert "good1" in good
|
||||
assert "bad1" in bad
|
||||
|
||||
|
||||
@patch("mempalace.repair.chromadb")
|
||||
def test_scan_palace_with_wing_filter(mock_chromadb, tmp_path):
|
||||
mock_col = MagicMock()
|
||||
mock_col.count.return_value = 1
|
||||
mock_col.get.side_effect = [
|
||||
{"ids": ["id1"]}, # paginate
|
||||
{"ids": ["id1"]}, # probe
|
||||
]
|
||||
mock_client = MagicMock()
|
||||
mock_client.get_collection.return_value = mock_col
|
||||
mock_chromadb.PersistentClient.return_value = mock_client
|
||||
|
||||
repair.scan_palace(palace_path=str(tmp_path), only_wing="test_wing")
|
||||
# Verify where filter was passed
|
||||
first_call = mock_col.get.call_args_list[0]
|
||||
assert first_call.kwargs.get("where") == {"wing": "test_wing"}
|
||||
|
||||
|
||||
# ── prune_corrupt ─────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@patch("mempalace.repair.chromadb")
|
||||
def test_prune_corrupt_no_file(mock_chromadb, tmp_path):
|
||||
# Should print message and return without error
|
||||
repair.prune_corrupt(palace_path=str(tmp_path))
|
||||
|
||||
|
||||
@patch("mempalace.repair.chromadb")
|
||||
def test_prune_corrupt_dry_run(mock_chromadb, tmp_path):
|
||||
bad_file = tmp_path / "corrupt_ids.txt"
|
||||
bad_file.write_text("bad1\nbad2\n")
|
||||
repair.prune_corrupt(palace_path=str(tmp_path), confirm=False)
|
||||
# No chromadb calls in dry run
|
||||
mock_chromadb.PersistentClient.assert_not_called()
|
||||
|
||||
|
||||
@patch("mempalace.repair.chromadb")
|
||||
def test_prune_corrupt_confirmed(mock_chromadb, tmp_path):
|
||||
bad_file = tmp_path / "corrupt_ids.txt"
|
||||
bad_file.write_text("bad1\nbad2\n")
|
||||
|
||||
mock_col = MagicMock()
|
||||
mock_col.count.side_effect = [10, 8]
|
||||
mock_client = MagicMock()
|
||||
mock_client.get_collection.return_value = mock_col
|
||||
mock_chromadb.PersistentClient.return_value = mock_client
|
||||
|
||||
repair.prune_corrupt(palace_path=str(tmp_path), confirm=True)
|
||||
mock_col.delete.assert_called_once()
|
||||
|
||||
|
||||
@patch("mempalace.repair.chromadb")
|
||||
def test_prune_corrupt_delete_failure_fallback(mock_chromadb, tmp_path):
|
||||
bad_file = tmp_path / "corrupt_ids.txt"
|
||||
bad_file.write_text("bad1\nbad2\n")
|
||||
|
||||
mock_col = MagicMock()
|
||||
mock_col.count.side_effect = [10, 8]
|
||||
# Batch delete fails, per-id succeeds
|
||||
mock_col.delete.side_effect = [Exception("batch fail"), None, None]
|
||||
mock_client = MagicMock()
|
||||
mock_client.get_collection.return_value = mock_col
|
||||
mock_chromadb.PersistentClient.return_value = mock_client
|
||||
|
||||
repair.prune_corrupt(palace_path=str(tmp_path), confirm=True)
|
||||
assert mock_col.delete.call_count == 3 # 1 batch + 2 individual
|
||||
|
||||
|
||||
# ── rebuild_index ─────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@patch("mempalace.repair.chromadb")
|
||||
def test_rebuild_index_no_palace(mock_chromadb, tmp_path):
|
||||
nonexistent = str(tmp_path / "nope")
|
||||
repair.rebuild_index(palace_path=nonexistent)
|
||||
mock_chromadb.PersistentClient.assert_not_called()
|
||||
|
||||
|
||||
@patch("mempalace.repair.shutil")
|
||||
@patch("mempalace.repair.chromadb")
|
||||
def test_rebuild_index_empty_palace(mock_chromadb, mock_shutil, tmp_path):
|
||||
mock_col = MagicMock()
|
||||
mock_col.count.return_value = 0
|
||||
mock_client = MagicMock()
|
||||
mock_client.get_collection.return_value = mock_col
|
||||
mock_chromadb.PersistentClient.return_value = mock_client
|
||||
|
||||
repair.rebuild_index(palace_path=str(tmp_path))
|
||||
mock_client.delete_collection.assert_not_called()
|
||||
|
||||
|
||||
@patch("mempalace.repair.shutil")
|
||||
@patch("mempalace.repair.chromadb")
|
||||
def test_rebuild_index_success(mock_chromadb, mock_shutil, tmp_path):
|
||||
# Create a fake sqlite file
|
||||
sqlite_path = tmp_path / "chroma.sqlite3"
|
||||
sqlite_path.write_text("fake")
|
||||
|
||||
mock_col = MagicMock()
|
||||
mock_col.count.return_value = 2
|
||||
mock_col.get.return_value = {
|
||||
"ids": ["id1", "id2"],
|
||||
"documents": ["doc1", "doc2"],
|
||||
"metadatas": [{"wing": "a"}, {"wing": "b"}],
|
||||
}
|
||||
|
||||
mock_new_col = MagicMock()
|
||||
mock_client = MagicMock()
|
||||
mock_client.get_collection.return_value = mock_col
|
||||
mock_client.create_collection.return_value = mock_new_col
|
||||
mock_chromadb.PersistentClient.return_value = mock_client
|
||||
|
||||
repair.rebuild_index(palace_path=str(tmp_path))
|
||||
|
||||
# Verify: backed up sqlite only (not copytree)
|
||||
mock_shutil.copy2.assert_called_once()
|
||||
assert "chroma.sqlite3" in str(mock_shutil.copy2.call_args)
|
||||
|
||||
# Verify: deleted and recreated with cosine
|
||||
mock_client.delete_collection.assert_called_once_with("mempalace_drawers")
|
||||
mock_client.create_collection.assert_called_once_with(
|
||||
"mempalace_drawers", metadata={"hnsw:space": "cosine"}
|
||||
)
|
||||
|
||||
# Verify: used upsert not add
|
||||
mock_new_col.upsert.assert_called_once()
|
||||
mock_new_col.add.assert_not_called()
|
||||
|
||||
|
||||
@patch("mempalace.repair.shutil")
|
||||
@patch("mempalace.repair.chromadb")
|
||||
def test_rebuild_index_error_reading(mock_chromadb, mock_shutil, tmp_path):
|
||||
mock_client = MagicMock()
|
||||
mock_client.get_collection.side_effect = Exception("corrupt")
|
||||
mock_chromadb.PersistentClient.return_value = mock_client
|
||||
|
||||
repair.rebuild_index(palace_path=str(tmp_path))
|
||||
mock_client.delete_collection.assert_not_called()
|
||||
Reference in New Issue
Block a user