fix: harden palace security checks

Agent-Logs-Url: https://github.com/MemPalace/mempalace/sessions/775f2fc4-3051-462e-8586-6d694b55da0d

Co-authored-by: igorls <4753812+igorls@users.noreply.github.com>
This commit is contained in:
copilot-swe-agent[bot]
2026-04-12 22:21:42 +00:00
committed by Igor Lins e Silva
parent bb577bb41f
commit c478dfa173
8 changed files with 238 additions and 15 deletions
+17 -3
View File
@@ -156,7 +156,7 @@ def cmd_migrate(args):
from .migrate import migrate
palace_path = os.path.expanduser(args.palace) if args.palace else MempalaceConfig().palace_path
migrate(palace_path=palace_path, dry_run=args.dry_run)
migrate(palace_path=palace_path, dry_run=args.dry_run, confirm=getattr(args, "yes", False))
def cmd_status(args):
@@ -170,12 +170,19 @@ def cmd_repair(args):
"""Rebuild palace vector index from SQLite metadata."""
import chromadb
import shutil
from .migrate import confirm_destructive_action, has_palace_database
palace_path = os.path.expanduser(args.palace) if args.palace else MempalaceConfig().palace_path
palace_path = os.path.abspath(
os.path.expanduser(args.palace) if args.palace else MempalaceConfig().palace_path
)
db_path = os.path.join(palace_path, "chroma.sqlite3")
if not os.path.isdir(palace_path):
print(f"\n No palace found at {palace_path}")
return
if not has_palace_database(palace_path):
print(f"\n No palace database found at {db_path}")
return
print(f"\n{'=' * 55}")
print(" MemPalace Repair")
@@ -197,6 +204,9 @@ def cmd_repair(args):
print(" Nothing to repair.")
return
if not confirm_destructive_action("Repair", palace_path, assume_yes=getattr(args, "yes", False)):
return
# Extract all drawers in batches
print("\n Extracting drawers...")
batch_size = 5000
@@ -216,6 +226,9 @@ def cmd_repair(args):
palace_path = palace_path.rstrip(os.sep)
backup_path = palace_path + ".backup"
if os.path.exists(backup_path):
if not has_palace_database(backup_path):
print(f" Refusing to delete non-palace backup path: {backup_path}")
return
shutil.rmtree(backup_path)
print(f" Backing up to {backup_path}...")
shutil.copytree(palace_path, backup_path)
@@ -532,7 +545,7 @@ def main():
sub.add_parser(
"repair",
help="Rebuild palace vector index from stored data (fixes segfaults after corruption)",
)
).add_argument("--yes", action="store_true", help="Skip confirmation for destructive changes")
# mcp
sub.add_parser(
@@ -551,6 +564,7 @@ def main():
action="store_true",
help="Show what would be migrated without changing anything",
)
p_migrate.add_argument("--yes", action="store_true", help="Skip confirmation for destructive changes")
sub.add_parser("status", help="Show what's been filed")
+31 -3
View File
@@ -94,7 +94,9 @@ else:
pass
# Keys whose values should be redacted in WAL entries to avoid logging sensitive content
_WAL_REDACT_KEYS = frozenset({"content_preview", "entry_preview"})
_WAL_REDACT_KEYS = frozenset(
{"content", "content_preview", "document", "entry", "entry_preview", "query", "text"}
)
def _wal_log(operation: str, params: dict, result: dict = None):
@@ -212,6 +214,13 @@ def _get_cached_metadata(col, where=None):
return result
def _sanitize_optional_name(value: str = None, field_name: str = "name") -> str:
"""Validate optional wing/room-style filters."""
if value is None:
return None
return sanitize_name(value, field_name)
# ==================== READ TOOLS ====================
@@ -296,6 +305,10 @@ def tool_list_wings():
def tool_list_rooms(wing: str = None):
try:
wing = _sanitize_optional_name(wing, "wing")
except ValueError as e:
return {"error": str(e)}
col = _get_collection()
if not col:
return _no_palace()
@@ -345,6 +358,11 @@ def tool_search(
context: str = None,
):
limit = max(1, min(limit, _MAX_RESULTS))
try:
wing = _sanitize_optional_name(wing, "wing")
room = _sanitize_optional_name(room, "room")
except ValueError as e:
return {"error": str(e)}
# Backwards compat: accept old name
# Backwards compat: convert old similarity scale (higher=stricter) to
# distance scale (lower=stricter). Similarity 0.8 → distance 0.2.
@@ -425,6 +443,11 @@ def tool_traverse_graph(start_room: str, max_hops: int = 2):
def tool_find_tunnels(wing_a: str = None, wing_b: str = None):
"""Find rooms that bridge two wings — the hallways connecting domains."""
try:
wing_a = _sanitize_optional_name(wing_a, "wing_a")
wing_b = _sanitize_optional_name(wing_b, "wing_b")
except ValueError as e:
return {"error": str(e)}
col = _get_collection()
if not col:
return _no_palace()
@@ -559,6 +582,11 @@ def tool_list_drawers(wing: str = None, room: str = None, limit: int = 20, offse
"""List drawers with pagination. Optional wing/room filter."""
limit = max(1, min(limit, _MAX_RESULTS))
offset = max(0, offset)
try:
wing = _sanitize_optional_name(wing, "wing")
room = _sanitize_optional_name(room, "room")
except ValueError as e:
return {"error": str(e)}
col = _get_collection()
if not col:
return _no_palace()
@@ -1098,8 +1126,8 @@ TOOLS = {
"properties": {
"query": {
"type": "string",
"description": "Short search query ONLY — keywords or a question. Max 200 chars recommended.",
"maxLength": 500,
"description": "Short search query ONLY — keywords or a question. Max 250 chars.",
"maxLength": 250,
},
"limit": {
"type": "integer",
+30 -3
View File
@@ -104,14 +104,38 @@ def detect_chromadb_version(db_path: str) -> str:
conn.close()
def migrate(palace_path: str, dry_run: bool = False):
def has_palace_database(path: str) -> bool:
"""Return True when path looks like a MemPalace ChromaDB directory."""
return os.path.isfile(os.path.join(path, "chroma.sqlite3"))
def confirm_destructive_action(action: str, palace_path: str, assume_yes: bool = False) -> bool:
"""Require confirmation before destructive palace operations."""
if assume_yes:
return True
print(f"\n {action} will replace data in: {palace_path}")
print(" A backup will be created first, but the original directory will be deleted.")
try:
answer = input(" Continue? [y/N]: ").strip().lower()
except EOFError:
print(" Aborted. Re-run with --yes to confirm destructive changes.")
return False
if answer not in {"y", "yes"}:
print(" Aborted.")
return False
return True
def migrate(palace_path: str, dry_run: bool = False, confirm: bool = False):
"""Migrate a palace to the currently installed ChromaDB version."""
import chromadb
palace_path = os.path.expanduser(palace_path)
palace_path = os.path.abspath(os.path.expanduser(palace_path))
db_path = os.path.join(palace_path, "chroma.sqlite3")
if not os.path.isfile(db_path):
if not os.path.isdir(palace_path) or not has_palace_database(palace_path):
print(f"\n No palace database found at {db_path}")
return False
@@ -166,6 +190,9 @@ def migrate(palace_path: str, dry_run: bool = False):
print(f" Would migrate {len(drawers)} drawers.")
return True
if not confirm_destructive_action("Migration", palace_path, assume_yes=confirm):
return False
# Backup the old palace
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
backup_path = f"{palace_path}.pre-migrate.{timestamp}"
+17 -5
View File
@@ -24,7 +24,7 @@ import logging
logger = logging.getLogger("mempalace_mcp")
# --- Constants ---
MAX_QUERY_LENGTH = 500 # Above this, system prompt almost certainly dominates
MAX_QUERY_LENGTH = 250 # Above this, prompt contamination increasingly dominates
SAFE_QUERY_LENGTH = 200 # Below this, query is almost certainly clean
MIN_QUERY_LENGTH = 10 # Extracted result shorter than this = extraction failed
@@ -67,6 +67,20 @@ def sanitize_query(raw_query: str) -> dict:
raw_query = raw_query.strip()
original_length = len(raw_query)
def _trim_candidate(candidate: str) -> str:
candidate = candidate.strip().strip("\"'")
if len(candidate) <= MAX_QUERY_LENGTH:
return candidate
nested_fragments = [
frag.strip().strip("\"'") for frag in _SENTENCE_SPLIT.split(candidate) if frag.strip()
]
for frag in reversed(nested_fragments):
if MIN_QUERY_LENGTH <= len(frag) <= MAX_QUERY_LENGTH:
return frag
return candidate[-MAX_QUERY_LENGTH:].strip()
# --- Step 1: Short query passthrough ---
if original_length <= SAFE_QUERY_LENGTH:
return {
@@ -106,7 +120,7 @@ def sanitize_query(raw_query: str) -> dict:
if len(candidate) >= MIN_QUERY_LENGTH:
# Apply length guard
if len(candidate) > MAX_QUERY_LENGTH:
candidate = candidate[-MAX_QUERY_LENGTH:]
candidate = _trim_candidate(candidate)
logger.warning(
"Query sanitized: %d%d chars (method=question_extraction)",
original_length,
@@ -126,9 +140,7 @@ def sanitize_query(raw_query: str) -> dict:
for seg in reversed(all_segments):
seg = seg.strip()
if len(seg) >= MIN_QUERY_LENGTH:
candidate = seg
if len(candidate) > MAX_QUERY_LENGTH:
candidate = candidate[-MAX_QUERY_LENGTH:]
candidate = _trim_candidate(seg)
logger.warning(
"Query sanitized: %d%d chars (method=tail_sentence)",
original_length,
+40 -1
View File
@@ -423,10 +423,24 @@ def test_cmd_repair_no_palace(mock_config_cls, tmp_path, capsys):
assert "No palace found" in out
@patch("mempalace.cli.MempalaceConfig")
def test_cmd_repair_requires_palace_database(mock_config_cls, tmp_path, capsys):
palace_dir = tmp_path / "palace"
palace_dir.mkdir()
mock_config_cls.return_value.palace_path = str(palace_dir)
args = argparse.Namespace(palace=None)
mock_chromadb = MagicMock()
with patch.dict("sys.modules", {"chromadb": mock_chromadb}):
cmd_repair(args)
out = capsys.readouterr().out
assert "No palace database found" in out
@patch("mempalace.cli.MempalaceConfig")
def test_cmd_repair_error_reading(mock_config_cls, tmp_path, capsys):
palace_dir = tmp_path / "palace"
palace_dir.mkdir()
(palace_dir / "chroma.sqlite3").write_text("db")
mock_config_cls.return_value.palace_path = str(palace_dir)
args = argparse.Namespace(palace=None)
mock_chromadb = MagicMock()
@@ -443,6 +457,7 @@ def test_cmd_repair_error_reading(mock_config_cls, tmp_path, capsys):
def test_cmd_repair_zero_drawers(mock_config_cls, tmp_path, capsys):
palace_dir = tmp_path / "palace"
palace_dir.mkdir()
(palace_dir / "chroma.sqlite3").write_text("db")
mock_config_cls.return_value.palace_path = str(palace_dir)
args = argparse.Namespace(palace=None)
mock_chromadb = MagicMock()
@@ -461,8 +476,9 @@ def test_cmd_repair_zero_drawers(mock_config_cls, tmp_path, capsys):
def test_cmd_repair_success(mock_config_cls, tmp_path, capsys):
palace_dir = tmp_path / "palace"
palace_dir.mkdir()
(palace_dir / "chroma.sqlite3").write_text("db")
mock_config_cls.return_value.palace_path = str(palace_dir)
args = argparse.Namespace(palace=None)
args = argparse.Namespace(palace=None, yes=True)
mock_chromadb = MagicMock()
mock_col = MagicMock()
mock_col.count.return_value = 2
@@ -483,6 +499,29 @@ def test_cmd_repair_success(mock_config_cls, tmp_path, capsys):
assert "2 drawers rebuilt" in out
@patch("mempalace.cli.MempalaceConfig")
def test_cmd_repair_aborts_without_confirmation(mock_config_cls, tmp_path, capsys):
palace_dir = tmp_path / "palace"
palace_dir.mkdir()
(palace_dir / "chroma.sqlite3").write_text("db")
mock_config_cls.return_value.palace_path = str(palace_dir)
args = argparse.Namespace(palace=None)
mock_chromadb = MagicMock()
mock_col = MagicMock()
mock_col.count.return_value = 1
mock_client = MagicMock()
mock_client.get_collection.return_value = mock_col
mock_chromadb.PersistentClient.return_value = mock_client
with (
patch.dict("sys.modules", {"chromadb": mock_chromadb}),
patch("builtins.input", return_value="n"),
):
cmd_repair(args)
out = capsys.readouterr().out
assert "Aborted." in out
mock_client.create_collection.assert_not_called()
# ── cmd_compress ───────────────────────────────────────────────────────
+55
View File
@@ -8,6 +8,8 @@ via monkeypatch to avoid touching real data.
import json
import pytest
def _patch_mcp_server(monkeypatch, config, kg):
"""Patch the mcp_server module globals to use test fixtures."""
@@ -311,6 +313,59 @@ class TestSearchTool:
result_loose = tool_search(query="JWT", max_distance=0.01, min_similarity=999.0)
assert len(result_strict["results"]) <= len(result_loose["results"])
def test_list_rooms_rejects_invalid_wing(self, monkeypatch, config, kg):
_patch_mcp_server(monkeypatch, config, kg)
from mempalace import mcp_server
monkeypatch.setattr(mcp_server, "_get_collection", lambda *args, **kwargs: pytest.fail())
result = mcp_server.tool_list_rooms(wing="../etc/passwd")
assert "error" in result
def test_search_rejects_invalid_room(self, monkeypatch, config, kg):
_patch_mcp_server(monkeypatch, config, kg)
from mempalace import mcp_server
monkeypatch.setattr(mcp_server, "search_memories", lambda *args, **kwargs: pytest.fail())
result = mcp_server.tool_search(query="JWT", room="../backend")
assert "error" in result
def test_list_drawers_rejects_invalid_wing(self, monkeypatch, config, kg):
_patch_mcp_server(monkeypatch, config, kg)
from mempalace import mcp_server
monkeypatch.setattr(mcp_server, "_get_collection", lambda *args, **kwargs: pytest.fail())
result = mcp_server.tool_list_drawers(wing="../notes")
assert "error" in result
def test_find_tunnels_rejects_invalid_wing(self, monkeypatch, config, kg):
_patch_mcp_server(monkeypatch, config, kg)
from mempalace import mcp_server
monkeypatch.setattr(mcp_server, "_get_collection", lambda *args, **kwargs: pytest.fail())
result = mcp_server.tool_find_tunnels(wing_a="../project")
assert "error" in result
def test_wal_redacts_sensitive_fields(self, monkeypatch, config, kg, tmp_path):
_patch_mcp_server(monkeypatch, config, kg)
from mempalace import mcp_server
wal_file = tmp_path / "write_log.jsonl"
monkeypatch.setattr(mcp_server, "_WAL_FILE", wal_file)
mcp_server._wal_log(
"test",
{"content": "secret note", "query": "private search", "safe": "ok"},
)
entry = json.loads(wal_file.read_text().strip())
assert entry["params"]["content"].startswith("[REDACTED")
assert entry["params"]["query"].startswith("[REDACTED")
assert entry["params"]["safe"] == "ok"
# ── Write Tools ─────────────────────────────────────────────────────────
+45
View File
@@ -0,0 +1,45 @@
"""Tests for destructive-operation safety in mempalace.migrate."""
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
from mempalace.migrate import migrate
def test_migrate_requires_palace_database(tmp_path, capsys):
palace_dir = tmp_path / "palace"
palace_dir.mkdir()
result = migrate(str(palace_dir))
out = capsys.readouterr().out
assert result is False
assert "No palace database found" in out
def test_migrate_aborts_without_confirmation(tmp_path, capsys):
palace_dir = tmp_path / "palace"
palace_dir.mkdir()
(palace_dir / "chroma.sqlite3").write_text("db")
mock_chromadb = SimpleNamespace(__version__="0.6.0", PersistentClient=MagicMock())
mock_chromadb.PersistentClient.side_effect = Exception("unreadable")
with (
patch.dict("sys.modules", {"chromadb": mock_chromadb}),
patch("mempalace.migrate.detect_chromadb_version", return_value="0.5.x"),
patch(
"mempalace.migrate.extract_drawers_from_sqlite",
return_value=[{"id": "id1", "document": "doc", "metadata": {"wing": "w", "room": "r"}}],
),
patch("builtins.input", return_value="n"),
patch("mempalace.migrate.shutil.copytree") as mock_copytree,
patch("mempalace.migrate.shutil.rmtree") as mock_rmtree,
):
result = migrate(str(palace_dir))
out = capsys.readouterr().out
assert result is False
assert "Aborted." in out
mock_copytree.assert_not_called()
mock_rmtree.assert_not_called()
+3
View File
@@ -123,6 +123,9 @@ class TestTailTruncation:
class TestLengthGuards:
"""Verify output length constraints."""
def test_max_query_length_reduced(self):
assert MAX_QUERY_LENGTH == 250
def test_output_never_exceeds_max(self):
# Very long question sentence
long_question = "a" * 1000 + "?"