Merge branch 'main' into fix/query-sanitizer-prompt-contamination

This commit is contained in:
Ben Sigman
2026-04-10 22:39:31 -07:00
committed by GitHub
32 changed files with 2380 additions and 231 deletions
+236 -44
View File
@@ -24,8 +24,9 @@ import json
import logging
import hashlib
from datetime import datetime
from pathlib import Path
from .config import MempalaceConfig
from .config import MempalaceConfig, sanitize_name, sanitize_content
from .version import __version__
from .query_sanitizer import sanitize_query
from .searcher import search_memories
@@ -67,16 +68,60 @@ _client_cache = None
_collection_cache = None
# ==================== WRITE-AHEAD LOG ====================
# Every write operation is logged to a JSONL file before execution.
# This provides an audit trail for detecting memory poisoning and
# enables review/rollback of writes from external or untrusted sources.
_WAL_DIR = Path(os.path.expanduser("~/.mempalace/wal"))
_WAL_DIR.mkdir(parents=True, exist_ok=True)
try:
_WAL_DIR.chmod(0o700)
except (OSError, NotImplementedError):
pass
_WAL_FILE = _WAL_DIR / "write_log.jsonl"
def _wal_log(operation: str, params: dict, result: dict = None):
"""Append a write operation to the write-ahead log."""
entry = {
"timestamp": datetime.now().isoformat(),
"operation": operation,
"params": params,
"result": result,
}
try:
with open(_WAL_FILE, "a", encoding="utf-8") as f:
f.write(json.dumps(entry, default=str) + "\n")
try:
_WAL_FILE.chmod(0o600)
except (OSError, NotImplementedError):
pass
except Exception as e:
logger.error(f"WAL write failed: {e}")
_client_cache = None
_collection_cache = None
def _get_client():
"""Return a singleton ChromaDB PersistentClient."""
global _client_cache
if _client_cache is None:
_client_cache = chromadb.PersistentClient(path=_config.palace_path)
return _client_cache
def _get_collection(create=False):
"""Return the ChromaDB collection, caching the client between calls."""
global _client_cache, _collection_cache
global _collection_cache
try:
if _client_cache is None:
_client_cache = chromadb.PersistentClient(path=_config.palace_path)
client = _get_client()
if create:
_collection_cache = _client_cache.get_or_create_collection(_config.collection_name)
_collection_cache = client.get_or_create_collection(_config.collection_name)
elif _collection_cache is None:
_collection_cache = _client_cache.get_collection(_config.collection_name)
_collection_cache = client.get_collection(_config.collection_name)
return _collection_cache
except Exception:
return None
@@ -99,16 +144,25 @@ def tool_status():
count = col.count()
wings = {}
rooms = {}
try:
all_meta = col.get(include=["metadatas"], limit=10000)["metadatas"]
for m in all_meta:
w = m.get("wing", "unknown")
r = m.get("room", "unknown")
wings[w] = wings.get(w, 0) + 1
rooms[r] = rooms.get(r, 0) + 1
except Exception:
pass
return {
batch_size = 5000
offset = 0
error_info = None
while True:
try:
batch = col.get(include=["metadatas"], limit=batch_size, offset=offset)
rows = batch["metadatas"]
for m in rows:
w = m.get("wing", "unknown")
r = m.get("room", "unknown")
wings[w] = wings.get(w, 0) + 1
rooms[r] = rooms.get(r, 0) + 1
offset += len(rows)
if len(rows) < batch_size:
break
except Exception as e:
error_info = f"Partial result, failed at offset {offset}: {str(e)}"
break
result = {
"total_drawers": count,
"wings": wings,
"rooms": rooms,
@@ -116,6 +170,10 @@ def tool_status():
"protocol": PALACE_PROTOCOL,
"aaak_dialect": AAAK_SPEC,
}
if error_info:
result["error"] = error_info
result["partial"] = True
return result
# ── AAAK Dialect Spec ─────────────────────────────────────────────────────────
@@ -156,13 +214,28 @@ def tool_list_wings():
if not col:
return _no_palace()
wings = {}
batch_size = 5000
offset = 0
try:
all_meta = col.get(include=["metadatas"], limit=10000)["metadatas"]
for m in all_meta:
w = m.get("wing", "unknown")
wings[w] = wings.get(w, 0) + 1
except Exception:
pass
col.count() # verify collection is accessible
except Exception as e:
return {"wings": {}, "error": str(e)}
while True:
try:
batch = col.get(include=["metadatas"], limit=batch_size, offset=offset)
rows = batch["metadatas"]
for m in rows:
w = m.get("wing", "unknown")
wings[w] = wings.get(w, 0) + 1
offset += len(rows)
if len(rows) < batch_size:
break
except Exception as e:
return {
"wings": wings,
"error": f"Partial result, failed at offset {offset}: {str(e)}",
"partial": True,
}
return {"wings": wings}
@@ -171,16 +244,33 @@ def tool_list_rooms(wing: str = None):
if not col:
return _no_palace()
rooms = {}
batch_size = 5000
offset = 0
where = {"wing": wing} if wing else None
try:
kwargs = {"include": ["metadatas"], "limit": 10000}
if wing:
kwargs["where"] = {"wing": wing}
all_meta = col.get(**kwargs)["metadatas"]
for m in all_meta:
r = m.get("room", "unknown")
rooms[r] = rooms.get(r, 0) + 1
except Exception:
pass
col.count() # verify collection is accessible
except Exception as e:
return {"wing": wing or "all", "rooms": {}, "error": str(e)}
while True:
try:
kwargs = {"include": ["metadatas"], "limit": batch_size, "offset": offset}
if where:
kwargs["where"] = where
batch = col.get(**kwargs)
rows = batch["metadatas"]
for m in rows:
r = m.get("room", "unknown")
rooms[r] = rooms.get(r, 0) + 1
offset += len(rows)
if len(rows) < batch_size:
break
except Exception as e:
return {
"wing": wing or "all",
"rooms": rooms,
"error": f"Partial result, failed at offset {offset}: {str(e)}",
"partial": True,
}
return {"wing": wing or "all", "rooms": rooms}
@@ -189,16 +279,31 @@ def tool_get_taxonomy():
if not col:
return _no_palace()
taxonomy = {}
batch_size = 5000
offset = 0
try:
all_meta = col.get(include=["metadatas"], limit=10000)["metadatas"]
for m in all_meta:
w = m.get("wing", "unknown")
r = m.get("room", "unknown")
if w not in taxonomy:
taxonomy[w] = {}
taxonomy[w][r] = taxonomy[w].get(r, 0) + 1
except Exception:
pass
col.count() # verify collection is accessible
except Exception as e:
return {"taxonomy": {}, "error": str(e)}
while True:
try:
batch = col.get(include=["metadatas"], limit=batch_size, offset=offset)
rows = batch["metadatas"]
for m in rows:
w = m.get("wing", "unknown")
r = m.get("room", "unknown")
if w not in taxonomy:
taxonomy[w] = {}
taxonomy[w][r] = taxonomy[w].get(r, 0) + 1
offset += len(rows)
if len(rows) < batch_size:
break
except Exception as e:
return {
"taxonomy": taxonomy,
"error": f"Partial result, failed at offset {offset}: {str(e)}",
"partial": True,
}
return {"taxonomy": taxonomy}
@@ -299,11 +404,30 @@ def tool_add_drawer(
wing: str, room: str, content: str, source_file: str = None, added_by: str = "mcp"
):
"""File verbatim content into a wing/room. Checks for duplicates first."""
try:
wing = sanitize_name(wing, "wing")
room = sanitize_name(room, "room")
content = sanitize_content(content)
except ValueError as e:
return {"success": False, "error": str(e)}
col = _get_collection(create=True)
if not col:
return _no_palace()
drawer_id = f"drawer_{wing}_{room}_{hashlib.md5(content.encode()).hexdigest()[:16]}"
drawer_id = f"drawer_{wing}_{room}_{hashlib.sha256((wing + room + content[:100]).encode()).hexdigest()[:24]}"
_wal_log(
"add_drawer",
{
"drawer_id": drawer_id,
"wing": wing,
"room": room,
"added_by": added_by,
"content_length": len(content),
"content_preview": content[:200],
},
)
# Idempotency: if the deterministic ID already exists, return success as a no-op.
try:
@@ -342,6 +466,19 @@ def tool_delete_drawer(drawer_id: str):
existing = col.get(ids=[drawer_id])
if not existing["ids"]:
return {"success": False, "error": f"Drawer not found: {drawer_id}"}
# Log the deletion with the content being removed for audit trail
deleted_content = existing.get("documents", [""])[0] if existing.get("documents") else ""
deleted_meta = existing.get("metadatas", [{}])[0] if existing.get("metadatas") else {}
_wal_log(
"delete_drawer",
{
"drawer_id": drawer_id,
"deleted_meta": deleted_meta,
"content_preview": deleted_content[:200],
},
)
try:
col.delete(ids=[drawer_id])
logger.info(f"Deleted drawer: {drawer_id}")
@@ -363,6 +500,23 @@ def tool_kg_add(
subject: str, predicate: str, object: str, valid_from: str = None, source_closet: str = None
):
"""Add a relationship to the knowledge graph."""
try:
subject = sanitize_name(subject, "subject")
predicate = sanitize_name(predicate, "predicate")
object = sanitize_name(object, "object")
except ValueError as e:
return {"success": False, "error": str(e)}
_wal_log(
"kg_add",
{
"subject": subject,
"predicate": predicate,
"object": object,
"valid_from": valid_from,
"source_closet": source_closet,
},
)
triple_id = _kg.add_triple(
subject, predicate, object, valid_from=valid_from, source_closet=source_closet
)
@@ -371,6 +525,10 @@ def tool_kg_add(
def tool_kg_invalidate(subject: str, predicate: str, object: str, ended: str = None):
"""Mark a fact as no longer true (set end date)."""
_wal_log(
"kg_invalidate",
{"subject": subject, "predicate": predicate, "object": object, "ended": ended},
)
_kg.invalidate(subject, predicate, object, ended=ended)
return {
"success": True,
@@ -401,6 +559,12 @@ def tool_diary_write(agent_name: str, entry: str, topic: str = "general"):
This is the agent's personal journal — observations, thoughts,
what it worked on, what it noticed, what it thinks matters.
"""
try:
agent_name = sanitize_name(agent_name, "agent_name")
entry = sanitize_content(entry)
except ValueError as e:
return {"success": False, "error": str(e)}
wing = f"wing_{agent_name.lower().replace(' ', '_')}"
room = "diary"
col = _get_collection(create=True)
@@ -408,9 +572,23 @@ def tool_diary_write(agent_name: str, entry: str, topic: str = "general"):
return _no_palace()
now = datetime.now()
entry_id = f"diary_{wing}_{now.strftime('%Y%m%d_%H%M%S')}_{hashlib.md5(entry[:50].encode()).hexdigest()[:8]}"
entry_id = f"diary_{wing}_{now.strftime('%Y%m%d_%H%M%S')}_{hashlib.sha256(entry[:50].encode()).hexdigest()[:12]}"
_wal_log(
"diary_write",
{
"agent_name": agent_name,
"topic": topic,
"entry_id": entry_id,
"entry_preview": entry[:200],
},
)
try:
# TODO: Future versions should expand AAAK before embedding to improve
# semantic search quality. For now, store raw AAAK in metadata so it's
# preserved, and keep the document as-is for embedding (even though
# compressed AAAK degrades embedding quality).
col.add(
ids=[entry_id],
documents=[entry],
@@ -744,17 +922,31 @@ TOOLS = {
}
SUPPORTED_PROTOCOL_VERSIONS = [
"2025-11-25",
"2025-06-18",
"2025-03-26",
"2024-11-05",
]
def handle_request(request):
method = request.get("method", "")
params = request.get("params", {})
req_id = request.get("id")
if method == "initialize":
client_version = params.get("protocolVersion", SUPPORTED_PROTOCOL_VERSIONS[-1])
negotiated = (
client_version
if client_version in SUPPORTED_PROTOCOL_VERSIONS
else SUPPORTED_PROTOCOL_VERSIONS[0]
)
return {
"jsonrpc": "2.0",
"id": req_id,
"result": {
"protocolVersion": "2024-11-05",
"protocolVersion": negotiated,
"capabilities": {"tools": {}},
"serverInfo": {"name": "mempalace", "version": __version__},
},
@@ -774,7 +966,7 @@ def handle_request(request):
}
elif method == "tools/call":
tool_name = params.get("name")
tool_args = params.get("arguments", {})
tool_args = params.get("arguments") or {}
if tool_name not in TOOLS:
return {
"jsonrpc": "2.0",