Merge pull request #739 from MemPalace/copilot/fix-unauthorized-data-deletion
Harden palace deletion, WAL redaction, and MCP search input handling
This commit is contained in:
+25
-4
@@ -156,7 +156,7 @@ def cmd_migrate(args):
|
|||||||
from .migrate import migrate
|
from .migrate import migrate
|
||||||
|
|
||||||
palace_path = os.path.expanduser(args.palace) if args.palace else MempalaceConfig().palace_path
|
palace_path = os.path.expanduser(args.palace) if args.palace else MempalaceConfig().palace_path
|
||||||
migrate(palace_path=palace_path, dry_run=args.dry_run)
|
migrate(palace_path=palace_path, dry_run=args.dry_run, confirm=getattr(args, "yes", False))
|
||||||
|
|
||||||
|
|
||||||
def cmd_status(args):
|
def cmd_status(args):
|
||||||
@@ -170,12 +170,19 @@ def cmd_repair(args):
|
|||||||
"""Rebuild palace vector index from SQLite metadata."""
|
"""Rebuild palace vector index from SQLite metadata."""
|
||||||
import chromadb
|
import chromadb
|
||||||
import shutil
|
import shutil
|
||||||
|
from .migrate import confirm_destructive_action, contains_palace_database
|
||||||
|
|
||||||
palace_path = os.path.expanduser(args.palace) if args.palace else MempalaceConfig().palace_path
|
palace_path = os.path.abspath(
|
||||||
|
os.path.expanduser(args.palace) if args.palace else MempalaceConfig().palace_path
|
||||||
|
)
|
||||||
|
db_path = os.path.join(palace_path, "chroma.sqlite3")
|
||||||
|
|
||||||
if not os.path.isdir(palace_path):
|
if not os.path.isdir(palace_path):
|
||||||
print(f"\n No palace found at {palace_path}")
|
print(f"\n No palace found at {palace_path}")
|
||||||
return
|
return
|
||||||
|
if not contains_palace_database(palace_path):
|
||||||
|
print(f"\n No palace database found at {db_path}")
|
||||||
|
return
|
||||||
|
|
||||||
print(f"\n{'=' * 55}")
|
print(f"\n{'=' * 55}")
|
||||||
print(" MemPalace Repair")
|
print(" MemPalace Repair")
|
||||||
@@ -197,6 +204,11 @@ def cmd_repair(args):
|
|||||||
print(" Nothing to repair.")
|
print(" Nothing to repair.")
|
||||||
return
|
return
|
||||||
|
|
||||||
|
if not confirm_destructive_action(
|
||||||
|
"Repair", palace_path, assume_yes=getattr(args, "yes", False)
|
||||||
|
):
|
||||||
|
return
|
||||||
|
|
||||||
# Extract all drawers in batches
|
# Extract all drawers in batches
|
||||||
print("\n Extracting drawers...")
|
print("\n Extracting drawers...")
|
||||||
batch_size = 5000
|
batch_size = 5000
|
||||||
@@ -213,9 +225,15 @@ def cmd_repair(args):
|
|||||||
print(f" Extracted {len(all_ids)} drawers")
|
print(f" Extracted {len(all_ids)} drawers")
|
||||||
|
|
||||||
# Backup and rebuild
|
# Backup and rebuild
|
||||||
palace_path = palace_path.rstrip(os.sep)
|
palace_path = os.path.normpath(palace_path)
|
||||||
backup_path = palace_path + ".backup"
|
backup_path = palace_path + ".backup"
|
||||||
if os.path.exists(backup_path):
|
if os.path.exists(backup_path):
|
||||||
|
if not contains_palace_database(backup_path):
|
||||||
|
print(
|
||||||
|
" Backup validation failed: backup path exists but does not contain chroma.sqlite3. "
|
||||||
|
f"Please remove or rename: {backup_path}"
|
||||||
|
)
|
||||||
|
return
|
||||||
shutil.rmtree(backup_path)
|
shutil.rmtree(backup_path)
|
||||||
print(f" Backing up to {backup_path}...")
|
print(f" Backing up to {backup_path}...")
|
||||||
shutil.copytree(palace_path, backup_path)
|
shutil.copytree(palace_path, backup_path)
|
||||||
@@ -532,7 +550,7 @@ def main():
|
|||||||
sub.add_parser(
|
sub.add_parser(
|
||||||
"repair",
|
"repair",
|
||||||
help="Rebuild palace vector index from stored data (fixes segfaults after corruption)",
|
help="Rebuild palace vector index from stored data (fixes segfaults after corruption)",
|
||||||
)
|
).add_argument("--yes", action="store_true", help="Skip confirmation for destructive changes")
|
||||||
|
|
||||||
# mcp
|
# mcp
|
||||||
sub.add_parser(
|
sub.add_parser(
|
||||||
@@ -551,6 +569,9 @@ def main():
|
|||||||
action="store_true",
|
action="store_true",
|
||||||
help="Show what would be migrated without changing anything",
|
help="Show what would be migrated without changing anything",
|
||||||
)
|
)
|
||||||
|
p_migrate.add_argument(
|
||||||
|
"--yes", action="store_true", help="Skip confirmation for destructive changes"
|
||||||
|
)
|
||||||
|
|
||||||
sub.add_parser("status", help="Show what's been filed")
|
sub.add_parser("status", help="Show what's been filed")
|
||||||
|
|
||||||
|
|||||||
+31
-3
@@ -94,7 +94,9 @@ else:
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
# Keys whose values should be redacted in WAL entries to avoid logging sensitive content
|
# Keys whose values should be redacted in WAL entries to avoid logging sensitive content
|
||||||
_WAL_REDACT_KEYS = frozenset({"content_preview", "entry_preview"})
|
_WAL_REDACT_KEYS = frozenset(
|
||||||
|
{"content", "content_preview", "document", "entry", "entry_preview", "query", "text"}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def _wal_log(operation: str, params: dict, result: dict = None):
|
def _wal_log(operation: str, params: dict, result: dict = None):
|
||||||
@@ -212,6 +214,13 @@ def _get_cached_metadata(col, where=None):
|
|||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def _sanitize_optional_name(value: str = None, field_name: str = "name") -> str:
|
||||||
|
"""Validate optional wing/room-style filters."""
|
||||||
|
if value is None:
|
||||||
|
return None
|
||||||
|
return sanitize_name(value, field_name)
|
||||||
|
|
||||||
|
|
||||||
# ==================== READ TOOLS ====================
|
# ==================== READ TOOLS ====================
|
||||||
|
|
||||||
|
|
||||||
@@ -296,6 +305,10 @@ def tool_list_wings():
|
|||||||
|
|
||||||
|
|
||||||
def tool_list_rooms(wing: str = None):
|
def tool_list_rooms(wing: str = None):
|
||||||
|
try:
|
||||||
|
wing = _sanitize_optional_name(wing, "wing")
|
||||||
|
except ValueError as e:
|
||||||
|
return {"error": str(e)}
|
||||||
col = _get_collection()
|
col = _get_collection()
|
||||||
if not col:
|
if not col:
|
||||||
return _no_palace()
|
return _no_palace()
|
||||||
@@ -345,6 +358,11 @@ def tool_search(
|
|||||||
context: str = None,
|
context: str = None,
|
||||||
):
|
):
|
||||||
limit = max(1, min(limit, _MAX_RESULTS))
|
limit = max(1, min(limit, _MAX_RESULTS))
|
||||||
|
try:
|
||||||
|
wing = _sanitize_optional_name(wing, "wing")
|
||||||
|
room = _sanitize_optional_name(room, "room")
|
||||||
|
except ValueError as e:
|
||||||
|
return {"error": str(e)}
|
||||||
# Backwards compat: accept old name
|
# Backwards compat: accept old name
|
||||||
# Backwards compat: convert old similarity scale (higher=stricter) to
|
# Backwards compat: convert old similarity scale (higher=stricter) to
|
||||||
# distance scale (lower=stricter). Similarity 0.8 → distance 0.2.
|
# distance scale (lower=stricter). Similarity 0.8 → distance 0.2.
|
||||||
@@ -425,6 +443,11 @@ def tool_traverse_graph(start_room: str, max_hops: int = 2):
|
|||||||
|
|
||||||
def tool_find_tunnels(wing_a: str = None, wing_b: str = None):
|
def tool_find_tunnels(wing_a: str = None, wing_b: str = None):
|
||||||
"""Find rooms that bridge two wings — the hallways connecting domains."""
|
"""Find rooms that bridge two wings — the hallways connecting domains."""
|
||||||
|
try:
|
||||||
|
wing_a = _sanitize_optional_name(wing_a, "wing_a")
|
||||||
|
wing_b = _sanitize_optional_name(wing_b, "wing_b")
|
||||||
|
except ValueError as e:
|
||||||
|
return {"error": str(e)}
|
||||||
col = _get_collection()
|
col = _get_collection()
|
||||||
if not col:
|
if not col:
|
||||||
return _no_palace()
|
return _no_palace()
|
||||||
@@ -559,6 +582,11 @@ def tool_list_drawers(wing: str = None, room: str = None, limit: int = 20, offse
|
|||||||
"""List drawers with pagination. Optional wing/room filter."""
|
"""List drawers with pagination. Optional wing/room filter."""
|
||||||
limit = max(1, min(limit, _MAX_RESULTS))
|
limit = max(1, min(limit, _MAX_RESULTS))
|
||||||
offset = max(0, offset)
|
offset = max(0, offset)
|
||||||
|
try:
|
||||||
|
wing = _sanitize_optional_name(wing, "wing")
|
||||||
|
room = _sanitize_optional_name(room, "room")
|
||||||
|
except ValueError as e:
|
||||||
|
return {"error": str(e)}
|
||||||
col = _get_collection()
|
col = _get_collection()
|
||||||
if not col:
|
if not col:
|
||||||
return _no_palace()
|
return _no_palace()
|
||||||
@@ -1098,8 +1126,8 @@ TOOLS = {
|
|||||||
"properties": {
|
"properties": {
|
||||||
"query": {
|
"query": {
|
||||||
"type": "string",
|
"type": "string",
|
||||||
"description": "Short search query ONLY — keywords or a question. Max 200 chars recommended.",
|
"description": "Short search query ONLY — keywords or a question. Max 250 chars.",
|
||||||
"maxLength": 500,
|
"maxLength": 250,
|
||||||
},
|
},
|
||||||
"limit": {
|
"limit": {
|
||||||
"type": "integer",
|
"type": "integer",
|
||||||
|
|||||||
+32
-3
@@ -104,14 +104,40 @@ def detect_chromadb_version(db_path: str) -> str:
|
|||||||
conn.close()
|
conn.close()
|
||||||
|
|
||||||
|
|
||||||
def migrate(palace_path: str, dry_run: bool = False):
|
def contains_palace_database(path: str) -> bool:
|
||||||
|
"""Return True when path looks like a MemPalace ChromaDB directory."""
|
||||||
|
return os.path.isfile(os.path.join(path, "chroma.sqlite3"))
|
||||||
|
|
||||||
|
|
||||||
|
def confirm_destructive_action(
|
||||||
|
operation_name: str, palace_path: str, assume_yes: bool = False
|
||||||
|
) -> bool:
|
||||||
|
"""Require confirmation before destructive palace operations."""
|
||||||
|
if assume_yes:
|
||||||
|
return True
|
||||||
|
|
||||||
|
print(f"\n {operation_name} will replace data in: {palace_path}")
|
||||||
|
print(" A backup will be created first, then the palace will be rebuilt.")
|
||||||
|
try:
|
||||||
|
answer = input(" Continue? [y/N]: ").strip().lower()
|
||||||
|
except EOFError:
|
||||||
|
print(" Aborted. Re-run with --yes to confirm destructive changes.")
|
||||||
|
return False
|
||||||
|
|
||||||
|
if answer not in {"y", "yes"}:
|
||||||
|
print(" Aborted.")
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
def migrate(palace_path: str, dry_run: bool = False, confirm: bool = False):
|
||||||
"""Migrate a palace to the currently installed ChromaDB version."""
|
"""Migrate a palace to the currently installed ChromaDB version."""
|
||||||
import chromadb
|
import chromadb
|
||||||
|
|
||||||
palace_path = os.path.expanduser(palace_path)
|
palace_path = os.path.abspath(os.path.expanduser(palace_path))
|
||||||
db_path = os.path.join(palace_path, "chroma.sqlite3")
|
db_path = os.path.join(palace_path, "chroma.sqlite3")
|
||||||
|
|
||||||
if not os.path.isfile(db_path):
|
if not os.path.isdir(palace_path) or not contains_palace_database(palace_path):
|
||||||
print(f"\n No palace database found at {db_path}")
|
print(f"\n No palace database found at {db_path}")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
@@ -166,6 +192,9 @@ def migrate(palace_path: str, dry_run: bool = False):
|
|||||||
print(f" Would migrate {len(drawers)} drawers.")
|
print(f" Would migrate {len(drawers)} drawers.")
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
if not confirm_destructive_action("Migration", palace_path, assume_yes=confirm):
|
||||||
|
return False
|
||||||
|
|
||||||
# Backup the old palace
|
# Backup the old palace
|
||||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||||
backup_path = f"{palace_path}.pre-migrate.{timestamp}"
|
backup_path = f"{palace_path}.pre-migrate.{timestamp}"
|
||||||
|
|||||||
@@ -24,9 +24,10 @@ import logging
|
|||||||
logger = logging.getLogger("mempalace_mcp")
|
logger = logging.getLogger("mempalace_mcp")
|
||||||
|
|
||||||
# --- Constants ---
|
# --- Constants ---
|
||||||
MAX_QUERY_LENGTH = 500 # Above this, system prompt almost certainly dominates
|
MAX_QUERY_LENGTH = 250 # Above this, prompt contamination increasingly dominates
|
||||||
SAFE_QUERY_LENGTH = 200 # Below this, query is almost certainly clean
|
SAFE_QUERY_LENGTH = 200 # Below this, query is almost certainly clean
|
||||||
MIN_QUERY_LENGTH = 10 # Extracted result shorter than this = extraction failed
|
MIN_QUERY_LENGTH = 10 # Extracted result shorter than this = extraction failed
|
||||||
|
QUOTE_CHARS = {"'", '"'}
|
||||||
|
|
||||||
# Sentence splitter: split on . ! ? (including fullwidth) and newlines
|
# Sentence splitter: split on . ! ? (including fullwidth) and newlines
|
||||||
_SENTENCE_SPLIT = re.compile(r"[.!?。!?\n]+")
|
_SENTENCE_SPLIT = re.compile(r"[.!?。!?\n]+")
|
||||||
@@ -67,6 +68,36 @@ def sanitize_query(raw_query: str) -> dict:
|
|||||||
raw_query = raw_query.strip()
|
raw_query = raw_query.strip()
|
||||||
original_length = len(raw_query)
|
original_length = len(raw_query)
|
||||||
|
|
||||||
|
def _strip_wrapping_quotes(candidate: str) -> str:
|
||||||
|
candidate = candidate.strip()
|
||||||
|
while (
|
||||||
|
len(candidate) >= 2 and candidate[:1] in QUOTE_CHARS and candidate[:1] == candidate[-1:]
|
||||||
|
):
|
||||||
|
candidate = candidate[1:-1].strip()
|
||||||
|
if not candidate:
|
||||||
|
return ""
|
||||||
|
if candidate[:1] in QUOTE_CHARS:
|
||||||
|
candidate = candidate[1:].strip()
|
||||||
|
if candidate[-1:] in QUOTE_CHARS:
|
||||||
|
candidate = candidate[:-1].strip()
|
||||||
|
return candidate
|
||||||
|
|
||||||
|
def _trim_candidate(candidate: str) -> str:
|
||||||
|
candidate = _strip_wrapping_quotes(candidate)
|
||||||
|
if len(candidate) <= MAX_QUERY_LENGTH:
|
||||||
|
return candidate
|
||||||
|
|
||||||
|
nested_fragments = [
|
||||||
|
_strip_wrapping_quotes(frag)
|
||||||
|
for frag in _SENTENCE_SPLIT.split(candidate)
|
||||||
|
if frag.strip()
|
||||||
|
]
|
||||||
|
for frag in reversed(nested_fragments):
|
||||||
|
if MIN_QUERY_LENGTH <= len(frag) <= MAX_QUERY_LENGTH:
|
||||||
|
return frag
|
||||||
|
|
||||||
|
return candidate[-MAX_QUERY_LENGTH:].strip()
|
||||||
|
|
||||||
# --- Step 1: Short query passthrough ---
|
# --- Step 1: Short query passthrough ---
|
||||||
if original_length <= SAFE_QUERY_LENGTH:
|
if original_length <= SAFE_QUERY_LENGTH:
|
||||||
return {
|
return {
|
||||||
@@ -106,7 +137,7 @@ def sanitize_query(raw_query: str) -> dict:
|
|||||||
if len(candidate) >= MIN_QUERY_LENGTH:
|
if len(candidate) >= MIN_QUERY_LENGTH:
|
||||||
# Apply length guard
|
# Apply length guard
|
||||||
if len(candidate) > MAX_QUERY_LENGTH:
|
if len(candidate) > MAX_QUERY_LENGTH:
|
||||||
candidate = candidate[-MAX_QUERY_LENGTH:]
|
candidate = _trim_candidate(candidate)
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"Query sanitized: %d → %d chars (method=question_extraction)",
|
"Query sanitized: %d → %d chars (method=question_extraction)",
|
||||||
original_length,
|
original_length,
|
||||||
@@ -126,9 +157,9 @@ def sanitize_query(raw_query: str) -> dict:
|
|||||||
for seg in reversed(all_segments):
|
for seg in reversed(all_segments):
|
||||||
seg = seg.strip()
|
seg = seg.strip()
|
||||||
if len(seg) >= MIN_QUERY_LENGTH:
|
if len(seg) >= MIN_QUERY_LENGTH:
|
||||||
candidate = seg
|
candidate = _trim_candidate(seg)
|
||||||
if len(candidate) > MAX_QUERY_LENGTH:
|
if len(candidate) < MIN_QUERY_LENGTH:
|
||||||
candidate = candidate[-MAX_QUERY_LENGTH:]
|
continue
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"Query sanitized: %d → %d chars (method=tail_sentence)",
|
"Query sanitized: %d → %d chars (method=tail_sentence)",
|
||||||
original_length,
|
original_length,
|
||||||
|
|||||||
+40
-1
@@ -423,10 +423,24 @@ def test_cmd_repair_no_palace(mock_config_cls, tmp_path, capsys):
|
|||||||
assert "No palace found" in out
|
assert "No palace found" in out
|
||||||
|
|
||||||
|
|
||||||
|
@patch("mempalace.cli.MempalaceConfig")
|
||||||
|
def test_cmd_repair_requires_palace_database(mock_config_cls, tmp_path, capsys):
|
||||||
|
palace_dir = tmp_path / "palace"
|
||||||
|
palace_dir.mkdir()
|
||||||
|
mock_config_cls.return_value.palace_path = str(palace_dir)
|
||||||
|
args = argparse.Namespace(palace=None)
|
||||||
|
mock_chromadb = MagicMock()
|
||||||
|
with patch.dict("sys.modules", {"chromadb": mock_chromadb}):
|
||||||
|
cmd_repair(args)
|
||||||
|
out = capsys.readouterr().out
|
||||||
|
assert "No palace database found" in out
|
||||||
|
|
||||||
|
|
||||||
@patch("mempalace.cli.MempalaceConfig")
|
@patch("mempalace.cli.MempalaceConfig")
|
||||||
def test_cmd_repair_error_reading(mock_config_cls, tmp_path, capsys):
|
def test_cmd_repair_error_reading(mock_config_cls, tmp_path, capsys):
|
||||||
palace_dir = tmp_path / "palace"
|
palace_dir = tmp_path / "palace"
|
||||||
palace_dir.mkdir()
|
palace_dir.mkdir()
|
||||||
|
(palace_dir / "chroma.sqlite3").write_text("db")
|
||||||
mock_config_cls.return_value.palace_path = str(palace_dir)
|
mock_config_cls.return_value.palace_path = str(palace_dir)
|
||||||
args = argparse.Namespace(palace=None)
|
args = argparse.Namespace(palace=None)
|
||||||
mock_chromadb = MagicMock()
|
mock_chromadb = MagicMock()
|
||||||
@@ -443,6 +457,7 @@ def test_cmd_repair_error_reading(mock_config_cls, tmp_path, capsys):
|
|||||||
def test_cmd_repair_zero_drawers(mock_config_cls, tmp_path, capsys):
|
def test_cmd_repair_zero_drawers(mock_config_cls, tmp_path, capsys):
|
||||||
palace_dir = tmp_path / "palace"
|
palace_dir = tmp_path / "palace"
|
||||||
palace_dir.mkdir()
|
palace_dir.mkdir()
|
||||||
|
(palace_dir / "chroma.sqlite3").write_text("db")
|
||||||
mock_config_cls.return_value.palace_path = str(palace_dir)
|
mock_config_cls.return_value.palace_path = str(palace_dir)
|
||||||
args = argparse.Namespace(palace=None)
|
args = argparse.Namespace(palace=None)
|
||||||
mock_chromadb = MagicMock()
|
mock_chromadb = MagicMock()
|
||||||
@@ -461,8 +476,9 @@ def test_cmd_repair_zero_drawers(mock_config_cls, tmp_path, capsys):
|
|||||||
def test_cmd_repair_success(mock_config_cls, tmp_path, capsys):
|
def test_cmd_repair_success(mock_config_cls, tmp_path, capsys):
|
||||||
palace_dir = tmp_path / "palace"
|
palace_dir = tmp_path / "palace"
|
||||||
palace_dir.mkdir()
|
palace_dir.mkdir()
|
||||||
|
(palace_dir / "chroma.sqlite3").write_text("db")
|
||||||
mock_config_cls.return_value.palace_path = str(palace_dir)
|
mock_config_cls.return_value.palace_path = str(palace_dir)
|
||||||
args = argparse.Namespace(palace=None)
|
args = argparse.Namespace(palace=None, yes=True)
|
||||||
mock_chromadb = MagicMock()
|
mock_chromadb = MagicMock()
|
||||||
mock_col = MagicMock()
|
mock_col = MagicMock()
|
||||||
mock_col.count.return_value = 2
|
mock_col.count.return_value = 2
|
||||||
@@ -483,6 +499,29 @@ def test_cmd_repair_success(mock_config_cls, tmp_path, capsys):
|
|||||||
assert "2 drawers rebuilt" in out
|
assert "2 drawers rebuilt" in out
|
||||||
|
|
||||||
|
|
||||||
|
@patch("mempalace.cli.MempalaceConfig")
|
||||||
|
def test_cmd_repair_aborts_without_confirmation(mock_config_cls, tmp_path, capsys):
|
||||||
|
palace_dir = tmp_path / "palace"
|
||||||
|
palace_dir.mkdir()
|
||||||
|
(palace_dir / "chroma.sqlite3").write_text("db")
|
||||||
|
mock_config_cls.return_value.palace_path = str(palace_dir)
|
||||||
|
args = argparse.Namespace(palace=None)
|
||||||
|
mock_chromadb = MagicMock()
|
||||||
|
mock_col = MagicMock()
|
||||||
|
mock_col.count.return_value = 1
|
||||||
|
mock_client = MagicMock()
|
||||||
|
mock_client.get_collection.return_value = mock_col
|
||||||
|
mock_chromadb.PersistentClient.return_value = mock_client
|
||||||
|
with (
|
||||||
|
patch.dict("sys.modules", {"chromadb": mock_chromadb}),
|
||||||
|
patch("builtins.input", return_value="n"),
|
||||||
|
):
|
||||||
|
cmd_repair(args)
|
||||||
|
out = capsys.readouterr().out
|
||||||
|
assert "Aborted." in out
|
||||||
|
mock_client.create_collection.assert_not_called()
|
||||||
|
|
||||||
|
|
||||||
# ── cmd_compress ───────────────────────────────────────────────────────
|
# ── cmd_compress ───────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -8,6 +8,8 @@ via monkeypatch to avoid touching real data.
|
|||||||
|
|
||||||
import json
|
import json
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
def _patch_mcp_server(monkeypatch, config, kg):
|
def _patch_mcp_server(monkeypatch, config, kg):
|
||||||
"""Patch the mcp_server module globals to use test fixtures."""
|
"""Patch the mcp_server module globals to use test fixtures."""
|
||||||
@@ -311,6 +313,59 @@ class TestSearchTool:
|
|||||||
result_loose = tool_search(query="JWT", max_distance=0.01, min_similarity=999.0)
|
result_loose = tool_search(query="JWT", max_distance=0.01, min_similarity=999.0)
|
||||||
assert len(result_strict["results"]) <= len(result_loose["results"])
|
assert len(result_strict["results"]) <= len(result_loose["results"])
|
||||||
|
|
||||||
|
def test_list_rooms_rejects_invalid_wing(self, monkeypatch, config, kg):
|
||||||
|
_patch_mcp_server(monkeypatch, config, kg)
|
||||||
|
from mempalace import mcp_server
|
||||||
|
|
||||||
|
monkeypatch.setattr(mcp_server, "_get_collection", lambda *args, **kwargs: pytest.fail())
|
||||||
|
|
||||||
|
result = mcp_server.tool_list_rooms(wing="../etc/passwd")
|
||||||
|
assert "error" in result
|
||||||
|
|
||||||
|
def test_search_rejects_invalid_room(self, monkeypatch, config, kg):
|
||||||
|
_patch_mcp_server(monkeypatch, config, kg)
|
||||||
|
from mempalace import mcp_server
|
||||||
|
|
||||||
|
monkeypatch.setattr(mcp_server, "search_memories", lambda *args, **kwargs: pytest.fail())
|
||||||
|
|
||||||
|
result = mcp_server.tool_search(query="JWT", room="../backend")
|
||||||
|
assert "error" in result
|
||||||
|
|
||||||
|
def test_list_drawers_rejects_invalid_wing(self, monkeypatch, config, kg):
|
||||||
|
_patch_mcp_server(monkeypatch, config, kg)
|
||||||
|
from mempalace import mcp_server
|
||||||
|
|
||||||
|
monkeypatch.setattr(mcp_server, "_get_collection", lambda *args, **kwargs: pytest.fail())
|
||||||
|
|
||||||
|
result = mcp_server.tool_list_drawers(wing="../notes")
|
||||||
|
assert "error" in result
|
||||||
|
|
||||||
|
def test_find_tunnels_rejects_invalid_wing(self, monkeypatch, config, kg):
|
||||||
|
_patch_mcp_server(monkeypatch, config, kg)
|
||||||
|
from mempalace import mcp_server
|
||||||
|
|
||||||
|
monkeypatch.setattr(mcp_server, "_get_collection", lambda *args, **kwargs: pytest.fail())
|
||||||
|
|
||||||
|
result = mcp_server.tool_find_tunnels(wing_a="../project")
|
||||||
|
assert "error" in result
|
||||||
|
|
||||||
|
def test_wal_redacts_sensitive_fields(self, monkeypatch, config, kg, tmp_path):
|
||||||
|
_patch_mcp_server(monkeypatch, config, kg)
|
||||||
|
from mempalace import mcp_server
|
||||||
|
|
||||||
|
wal_file = tmp_path / "write_log.jsonl"
|
||||||
|
monkeypatch.setattr(mcp_server, "_WAL_FILE", wal_file)
|
||||||
|
|
||||||
|
mcp_server._wal_log(
|
||||||
|
"test",
|
||||||
|
{"content": "secret note", "query": "private search", "safe": "ok"},
|
||||||
|
)
|
||||||
|
|
||||||
|
entry = json.loads(wal_file.read_text().strip())
|
||||||
|
assert entry["params"]["content"].startswith("[REDACTED")
|
||||||
|
assert entry["params"]["query"].startswith("[REDACTED")
|
||||||
|
assert entry["params"]["safe"] == "ok"
|
||||||
|
|
||||||
|
|
||||||
# ── Write Tools ─────────────────────────────────────────────────────────
|
# ── Write Tools ─────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,48 @@
|
|||||||
|
"""Tests for destructive-operation safety in mempalace.migrate."""
|
||||||
|
|
||||||
|
from types import SimpleNamespace
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
from mempalace.migrate import migrate
|
||||||
|
|
||||||
|
|
||||||
|
def test_migrate_requires_palace_database(tmp_path, capsys):
|
||||||
|
palace_dir = tmp_path / "palace"
|
||||||
|
palace_dir.mkdir()
|
||||||
|
|
||||||
|
result = migrate(str(palace_dir))
|
||||||
|
|
||||||
|
out = capsys.readouterr().out
|
||||||
|
assert result is False
|
||||||
|
assert "No palace database found" in out
|
||||||
|
|
||||||
|
|
||||||
|
def test_migrate_aborts_without_confirmation(tmp_path, capsys):
|
||||||
|
palace_dir = tmp_path / "palace"
|
||||||
|
palace_dir.mkdir()
|
||||||
|
# Presence of chroma.sqlite3 is the safety gate; validity is mocked below.
|
||||||
|
(palace_dir / "chroma.sqlite3").write_text("db")
|
||||||
|
|
||||||
|
mock_chromadb = SimpleNamespace(
|
||||||
|
__version__="0.6.0",
|
||||||
|
PersistentClient=MagicMock(side_effect=Exception("unreadable")),
|
||||||
|
)
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch.dict("sys.modules", {"chromadb": mock_chromadb}),
|
||||||
|
patch("mempalace.migrate.detect_chromadb_version", return_value="0.5.x"),
|
||||||
|
patch(
|
||||||
|
"mempalace.migrate.extract_drawers_from_sqlite",
|
||||||
|
return_value=[{"id": "id1", "document": "doc", "metadata": {"wing": "w", "room": "r"}}],
|
||||||
|
),
|
||||||
|
patch("builtins.input", return_value="n"),
|
||||||
|
patch("mempalace.migrate.shutil.copytree") as mock_copytree,
|
||||||
|
patch("mempalace.migrate.shutil.rmtree") as mock_rmtree,
|
||||||
|
):
|
||||||
|
result = migrate(str(palace_dir))
|
||||||
|
|
||||||
|
out = capsys.readouterr().out
|
||||||
|
assert result is False
|
||||||
|
assert "Aborted." in out
|
||||||
|
mock_copytree.assert_not_called()
|
||||||
|
mock_rmtree.assert_not_called()
|
||||||
@@ -102,6 +102,21 @@ class TestTailSentence:
|
|||||||
assert result["was_sanitized"] is True
|
assert result["was_sanitized"] is True
|
||||||
assert "MemPalace" in result["clean_query"] or "ChromaDB" in result["clean_query"]
|
assert "MemPalace" in result["clean_query"] or "ChromaDB" in result["clean_query"]
|
||||||
|
|
||||||
|
def test_long_candidate_uses_last_sentence_fragment(self):
|
||||||
|
query = ("Prompt sentence. " * 30) + "Final search intent for architecture migration"
|
||||||
|
result = sanitize_query(query)
|
||||||
|
assert result["method"] == "tail_sentence"
|
||||||
|
assert result["clean_query"] == "Final search intent for architecture migration"
|
||||||
|
|
||||||
|
def test_long_candidate_strips_wrapping_quotes(self):
|
||||||
|
query = ("Prefix text " * 30) + '\n"' + ("x" * 260) + '"'
|
||||||
|
result = sanitize_query(query)
|
||||||
|
assert result["method"] == "tail_sentence"
|
||||||
|
assert result["clean_query"] == "x" * MAX_QUERY_LENGTH
|
||||||
|
assert not result["clean_query"].startswith('"')
|
||||||
|
assert not result["clean_query"].endswith('"')
|
||||||
|
assert len(result["clean_query"]) <= MAX_QUERY_LENGTH
|
||||||
|
|
||||||
|
|
||||||
class TestTailTruncation:
|
class TestTailTruncation:
|
||||||
"""Step 4: Fallback — take the last MAX_QUERY_LENGTH characters."""
|
"""Step 4: Fallback — take the last MAX_QUERY_LENGTH characters."""
|
||||||
@@ -119,10 +134,19 @@ class TestTailTruncation:
|
|||||||
result = sanitize_query(filler)
|
result = sanitize_query(filler)
|
||||||
assert "IMPORTANT_QUERY_CONTENT" in result["clean_query"]
|
assert "IMPORTANT_QUERY_CONTENT" in result["clean_query"]
|
||||||
|
|
||||||
|
def test_tail_sentence_fallback_preserves_tail_without_delimiters(self):
|
||||||
|
filler = ("x" * 260) + "IMPORTANT_QUERY_CONTENT"
|
||||||
|
result = sanitize_query(filler)
|
||||||
|
assert result["method"] == "tail_sentence"
|
||||||
|
assert "IMPORTANT_QUERY_CONTENT" in result["clean_query"]
|
||||||
|
|
||||||
|
|
||||||
class TestLengthGuards:
|
class TestLengthGuards:
|
||||||
"""Verify output length constraints."""
|
"""Verify output length constraints."""
|
||||||
|
|
||||||
|
def test_max_query_length_reduced(self):
|
||||||
|
assert MAX_QUERY_LENGTH == 250
|
||||||
|
|
||||||
def test_output_never_exceeds_max(self):
|
def test_output_never_exceeds_max(self):
|
||||||
# Very long question sentence
|
# Very long question sentence
|
||||||
long_question = "a" * 1000 + "?"
|
long_question = "a" * 1000 + "?"
|
||||||
|
|||||||
Reference in New Issue
Block a user