security: harden inputs, fix shell injection, optimize DB access

- Fix command injection in hook script (pass paths via sys.argv)
- Add sanitize_name/sanitize_content validators in config.py
- Add 10MB file size guard + symlink skip in miners
- Fix SQLite connection leak in knowledge_graph.py (reuse connection)
- Use `with conn:` for proper transaction handling
- Consolidate shared palace operations into palace.py
- Add write-ahead log for audit trail on writes/deletes
- Add metadata cache with 30s TTL for status/taxonomy calls
- Upgrade md5 → sha256 for drawer/triple IDs
- Harden file permissions (0o700/0o600)
- Pin chromadb>=0.5.0,<0.7

Based on PR #252 by @anthonyonazure with lint fixes applied.

Co-Authored-By: anthonyonazure <anthonyonazure@users.noreply.github.com>
This commit is contained in:
bensig
2026-04-09 08:06:30 -07:00
parent 963c04cf45
commit 1d19dfc9d5
8 changed files with 389 additions and 203 deletions
+15 -8
View File
@@ -64,13 +64,20 @@ MEMPAL_DIR=""
# Read JSON input from stdin # Read JSON input from stdin
INPUT=$(cat) INPUT=$(cat)
# Parse fields from Claude Code's JSON # Parse all fields in a single Python call (3x faster than separate invocations)
SESSION_ID=$(echo "$INPUT" | python3 -c "import sys,json; print(json.load(sys.stdin).get('session_id','unknown'))" 2>/dev/null) eval $(echo "$INPUT" | python3 -c "
# Sanitize SESSION_ID to prevent path traversal (only allow alnum, dash, underscore) import sys, json
SESSION_ID=$(echo "$SESSION_ID" | tr -cd 'a-zA-Z0-9_-') data = json.load(sys.stdin)
[ -z "$SESSION_ID" ] && SESSION_ID="unknown" sid = data.get('session_id', 'unknown')
STOP_HOOK_ACTIVE=$(echo "$INPUT" | python3 -c "import sys,json; print(json.load(sys.stdin).get('stop_hook_active', False))" 2>/dev/null) sha = data.get('stop_hook_active', False)
TRANSCRIPT_PATH=$(echo "$INPUT" | python3 -c "import sys,json; print(json.load(sys.stdin).get('transcript_path',''))" 2>/dev/null) tp = data.get('transcript_path', '')
# Shell-safe output — only allow alphanumeric, underscore, hyphen, slash, dot, tilde
import re
safe = lambda s: re.sub(r'[^a-zA-Z0-9_/.\-~]', '', str(s))
print(f'SESSION_ID=\"{safe(sid)}\"')
print(f'STOP_HOOK_ACTIVE=\"{sha}\"')
print(f'TRANSCRIPT_PATH=\"{safe(tp)}\"')
" 2>/dev/null)
# Expand ~ in path # Expand ~ in path
TRANSCRIPT_PATH="${TRANSCRIPT_PATH/#\~/$HOME}" TRANSCRIPT_PATH="${TRANSCRIPT_PATH/#\~/$HOME}"
@@ -83,6 +90,7 @@ if [ "$STOP_HOOK_ACTIVE" = "True" ] || [ "$STOP_HOOK_ACTIVE" = "true" ]; then
fi fi
# Count human messages in the JSONL transcript # Count human messages in the JSONL transcript
# SECURITY: Pass transcript path as sys.argv to avoid shell injection via crafted paths
if [ -f "$TRANSCRIPT_PATH" ]; then if [ -f "$TRANSCRIPT_PATH" ]; then
EXCHANGE_COUNT=$(python3 - "$TRANSCRIPT_PATH" <<'PYEOF' EXCHANGE_COUNT=$(python3 - "$TRANSCRIPT_PATH" <<'PYEOF'
import json, sys import json, sys
@@ -94,7 +102,6 @@ with open(sys.argv[1]) as f:
msg = entry.get('message', {}) msg = entry.get('message', {})
if isinstance(msg, dict) and msg.get('role') == 'user': if isinstance(msg, dict) and msg.get('role') == 'user':
content = msg.get('content', '') content = msg.get('content', '')
# Skip system/command messages — only count real human input
if isinstance(content, str) and '<command-message>' in content: if isinstance(content, str) and '<command-message>' in content:
continue continue
count += 1 count += 1
+56
View File
@@ -6,8 +6,54 @@ Priority: env vars > config file (~/.mempalace/config.json) > defaults
import json import json
import os import os
import re
from pathlib import Path from pathlib import Path
# ── Input validation ──────────────────────────────────────────────────────────
# Shared sanitizers for wing/room/entity names. Prevents path traversal,
# excessively long strings, and special characters that could cause issues
# in file paths, SQLite, or ChromaDB metadata.
MAX_NAME_LENGTH = 128
_SAFE_NAME_RE = re.compile(r"^[a-zA-Z0-9][a-zA-Z0-9_ .'-]{0,126}[a-zA-Z0-9]?$")
def sanitize_name(value: str, field_name: str = "name") -> str:
"""Validate and sanitize a wing/room/entity name.
Raises ValueError if the name is invalid.
"""
if not isinstance(value, str) or not value.strip():
raise ValueError(f"{field_name} must be a non-empty string")
value = value.strip()
if len(value) > MAX_NAME_LENGTH:
raise ValueError(f"{field_name} exceeds maximum length of {MAX_NAME_LENGTH} characters")
# Block path traversal
if ".." in value or "/" in value or "\\" in value:
raise ValueError(f"{field_name} contains invalid path characters")
# Block null bytes
if "\x00" in value:
raise ValueError(f"{field_name} contains null bytes")
return value
def sanitize_content(value: str, max_length: int = 100_000) -> str:
"""Validate drawer/diary content length."""
if not isinstance(value, str) or not value.strip():
raise ValueError("content must be a non-empty string")
if len(value) > max_length:
raise ValueError(f"content exceeds maximum length of {max_length} characters")
if "\x00" in value:
raise ValueError("content contains null bytes")
return value
DEFAULT_PALACE_PATH = os.path.expanduser("~/.mempalace/palace") DEFAULT_PALACE_PATH = os.path.expanduser("~/.mempalace/palace")
DEFAULT_COLLECTION_NAME = "mempalace_drawers" DEFAULT_COLLECTION_NAME = "mempalace_drawers"
@@ -126,6 +172,11 @@ class MempalaceConfig:
def init(self): def init(self):
"""Create config directory and write default config.json if it doesn't exist.""" """Create config directory and write default config.json if it doesn't exist."""
self._config_dir.mkdir(parents=True, exist_ok=True) self._config_dir.mkdir(parents=True, exist_ok=True)
# Restrict directory permissions to owner only (Unix)
try:
self._config_dir.chmod(0o700)
except (OSError, NotImplementedError):
pass # Windows doesn't support Unix permissions
if not self._config_file.exists(): if not self._config_file.exists():
default_config = { default_config = {
"palace_path": DEFAULT_PALACE_PATH, "palace_path": DEFAULT_PALACE_PATH,
@@ -135,6 +186,11 @@ class MempalaceConfig:
} }
with open(self._config_file, "w") as f: with open(self._config_file, "w") as f:
json.dump(default_config, f, indent=2) json.dump(default_config, f, indent=2)
# Restrict config file to owner read/write only
try:
self._config_file.chmod(0o600)
except (OSError, NotImplementedError):
pass
return self._config_file return self._config_file
def save_people_map(self, people_map): def save_people_map(self, people_map):
+11 -35
View File
@@ -15,9 +15,8 @@ from pathlib import Path
from datetime import datetime from datetime import datetime
from collections import defaultdict from collections import defaultdict
import chromadb
from .normalize import normalize from .normalize import normalize
from .palace import SKIP_DIRS, get_collection, file_already_mined
# File types that might contain conversations # File types that might contain conversations
@@ -28,22 +27,8 @@ CONVO_EXTENSIONS = {
".jsonl", ".jsonl",
} }
SKIP_DIRS = {
".git",
"node_modules",
"__pycache__",
".venv",
"venv",
"env",
"dist",
"build",
".next",
".mempalace",
"tool-results",
"memory",
}
MIN_CHUNK_SIZE = 30 MIN_CHUNK_SIZE = 30
MAX_FILE_SIZE = 10 * 1024 * 1024 # 10 MB — skip files larger than this
# ============================================================================= # =============================================================================
@@ -211,23 +196,6 @@ def detect_convo_room(content: str) -> str:
# ============================================================================= # =============================================================================
def get_collection(palace_path: str):
os.makedirs(palace_path, exist_ok=True)
client = chromadb.PersistentClient(path=palace_path)
try:
return client.get_collection("mempalace_drawers")
except Exception:
return client.create_collection("mempalace_drawers")
def file_already_mined(collection, source_file: str) -> bool:
try:
results = collection.get(where={"source_file": source_file}, limit=1)
return len(results.get("ids", [])) > 0
except Exception:
return False
# ============================================================================= # =============================================================================
# SCAN FOR CONVERSATION FILES # SCAN FOR CONVERSATION FILES
# ============================================================================= # =============================================================================
@@ -244,6 +212,14 @@ def scan_convos(convo_dir: str) -> list:
continue continue
filepath = Path(root) / filename filepath = Path(root) / filename
if filepath.suffix.lower() in CONVO_EXTENSIONS: if filepath.suffix.lower() in CONVO_EXTENSIONS:
# Skip symlinks and oversized files
if filepath.is_symlink():
continue
try:
if filepath.stat().st_size > MAX_FILE_SIZE:
continue
except OSError:
continue
files.append(filepath) files.append(filepath)
return files return files
@@ -356,7 +332,7 @@ def mine_convos(
chunk_room = chunk.get("memory_type", room) if extract_mode == "general" else room chunk_room = chunk.get("memory_type", room) if extract_mode == "general" else room
if extract_mode == "general": if extract_mode == "general":
room_counts[chunk_room] += 1 room_counts[chunk_room] += 1
drawer_id = f"drawer_{wing}_{chunk_room}_{hashlib.md5((source_file + str(chunk['chunk_index'])).encode(), usedforsecurity=False).hexdigest()[:16]}" drawer_id = f"drawer_{wing}_{chunk_room}_{hashlib.sha256((source_file + str(chunk['chunk_index'])).encode()).hexdigest()[:24]}"
try: try:
collection.add( collection.add(
documents=[chunk["content"]], documents=[chunk["content"]],
+87 -77
View File
@@ -50,11 +50,15 @@ class KnowledgeGraph:
def __init__(self, db_path: str = None): def __init__(self, db_path: str = None):
self.db_path = db_path or DEFAULT_KG_PATH self.db_path = db_path or DEFAULT_KG_PATH
Path(self.db_path).parent.mkdir(parents=True, exist_ok=True) Path(self.db_path).parent.mkdir(parents=True, exist_ok=True)
self._connection = None
self._init_db() self._init_db()
def _init_db(self): def _init_db(self):
conn = self._conn() conn = self._conn()
conn.executescript(""" conn.executescript("""
PRAGMA journal_mode=WAL;
PRAGMA foreign_keys=ON;
CREATE TABLE IF NOT EXISTS entities ( CREATE TABLE IF NOT EXISTS entities (
id TEXT PRIMARY KEY, id TEXT PRIMARY KEY,
name TEXT NOT NULL, name TEXT NOT NULL,
@@ -84,12 +88,22 @@ class KnowledgeGraph:
CREATE INDEX IF NOT EXISTS idx_triples_valid ON triples(valid_from, valid_to); CREATE INDEX IF NOT EXISTS idx_triples_valid ON triples(valid_from, valid_to);
""") """)
conn.commit() conn.commit()
conn.close()
def _conn(self): def _conn(self):
conn = sqlite3.connect(self.db_path, timeout=10) if self._connection is None:
conn.execute("PRAGMA journal_mode=WAL") self._connection = sqlite3.connect(self.db_path, timeout=10)
return conn self._connection.execute("PRAGMA journal_mode=WAL")
self._connection.row_factory = sqlite3.Row
return self._connection
def close(self):
"""Close the database connection."""
if self._connection is not None:
self._connection.close()
self._connection = None
def __del__(self):
self.close()
def _entity_id(self, name: str) -> str: def _entity_id(self, name: str) -> str:
return name.lower().replace(" ", "_").replace("'", "") return name.lower().replace(" ", "_").replace("'", "")
@@ -101,12 +115,11 @@ class KnowledgeGraph:
eid = self._entity_id(name) eid = self._entity_id(name)
props = json.dumps(properties or {}) props = json.dumps(properties or {})
conn = self._conn() conn = self._conn()
conn.execute( with conn:
"INSERT OR REPLACE INTO entities (id, name, type, properties) VALUES (?, ?, ?, ?)", conn.execute(
(eid, name, entity_type, props), "INSERT OR REPLACE INTO entities (id, name, type, properties) VALUES (?, ?, ?, ?)",
) (eid, name, entity_type, props),
conn.commit() )
conn.close()
return eid return eid
def add_triple( def add_triple(
@@ -134,38 +147,38 @@ class KnowledgeGraph:
# Auto-create entities if they don't exist # Auto-create entities if they don't exist
conn = self._conn() conn = self._conn()
conn.execute("INSERT OR IGNORE INTO entities (id, name) VALUES (?, ?)", (sub_id, subject)) with conn:
conn.execute("INSERT OR IGNORE INTO entities (id, name) VALUES (?, ?)", (obj_id, obj)) conn.execute(
"INSERT OR IGNORE INTO entities (id, name) VALUES (?, ?)", (sub_id, subject)
)
conn.execute("INSERT OR IGNORE INTO entities (id, name) VALUES (?, ?)", (obj_id, obj))
# Check for existing identical triple # Check for existing identical triple
existing = conn.execute( existing = conn.execute(
"SELECT id FROM triples WHERE subject=? AND predicate=? AND object=? AND valid_to IS NULL", "SELECT id FROM triples WHERE subject=? AND predicate=? AND object=? AND valid_to IS NULL",
(sub_id, pred, obj_id), (sub_id, pred, obj_id),
).fetchone() ).fetchone()
if existing: if existing:
conn.close() return existing["id"] # Already exists and still valid
return existing[0] # Already exists and still valid
triple_id = f"t_{sub_id}_{pred}_{obj_id}_{hashlib.md5(f'{valid_from}{datetime.now().isoformat()}'.encode()).hexdigest()[:8]}" triple_id = f"t_{sub_id}_{pred}_{obj_id}_{hashlib.sha256(f'{valid_from}{datetime.now().isoformat()}'.encode()).hexdigest()[:12]}"
conn.execute( conn.execute(
"""INSERT INTO triples (id, subject, predicate, object, valid_from, valid_to, confidence, source_closet, source_file) """INSERT INTO triples (id, subject, predicate, object, valid_from, valid_to, confidence, source_closet, source_file)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)""", VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)""",
( (
triple_id, triple_id,
sub_id, sub_id,
pred, pred,
obj_id, obj_id,
valid_from, valid_from,
valid_to, valid_to,
confidence, confidence,
source_closet, source_closet,
source_file, source_file,
), ),
) )
conn.commit()
conn.close()
return triple_id return triple_id
def invalidate(self, subject: str, predicate: str, obj: str, ended: str = None): def invalidate(self, subject: str, predicate: str, obj: str, ended: str = None):
@@ -176,12 +189,11 @@ class KnowledgeGraph:
ended = ended or date.today().isoformat() ended = ended or date.today().isoformat()
conn = self._conn() conn = self._conn()
conn.execute( with conn:
"UPDATE triples SET valid_to=? WHERE subject=? AND predicate=? AND object=? AND valid_to IS NULL", conn.execute(
(ended, sub_id, pred, obj_id), "UPDATE triples SET valid_to=? WHERE subject=? AND predicate=? AND object=? AND valid_to IS NULL",
) (ended, sub_id, pred, obj_id),
conn.commit() )
conn.close()
# ── Query operations ────────────────────────────────────────────────── # ── Query operations ──────────────────────────────────────────────────
@@ -208,13 +220,13 @@ class KnowledgeGraph:
{ {
"direction": "outgoing", "direction": "outgoing",
"subject": name, "subject": name,
"predicate": row[2], "predicate": row["predicate"],
"object": row[10], # obj_name "object": row["obj_name"],
"valid_from": row[4], "valid_from": row["valid_from"],
"valid_to": row[5], "valid_to": row["valid_to"],
"confidence": row[6], "confidence": row["confidence"],
"source_closet": row[7], "source_closet": row["source_closet"],
"current": row[5] is None, "current": row["valid_to"] is None,
} }
) )
@@ -228,18 +240,17 @@ class KnowledgeGraph:
results.append( results.append(
{ {
"direction": "incoming", "direction": "incoming",
"subject": row[10], # sub_name "subject": row["sub_name"],
"predicate": row[2], "predicate": row["predicate"],
"object": name, "object": name,
"valid_from": row[4], "valid_from": row["valid_from"],
"valid_to": row[5], "valid_to": row["valid_to"],
"confidence": row[6], "confidence": row["confidence"],
"source_closet": row[7], "source_closet": row["source_closet"],
"current": row[5] is None, "current": row["valid_to"] is None,
} }
) )
conn.close()
return results return results
def query_relationship(self, predicate: str, as_of: str = None): def query_relationship(self, predicate: str, as_of: str = None):
@@ -262,15 +273,14 @@ class KnowledgeGraph:
for row in conn.execute(query, params).fetchall(): for row in conn.execute(query, params).fetchall():
results.append( results.append(
{ {
"subject": row[10], "subject": row["sub_name"],
"predicate": pred, "predicate": pred,
"object": row[11], "object": row["obj_name"],
"valid_from": row[4], "valid_from": row["valid_from"],
"valid_to": row[5], "valid_to": row["valid_to"],
"current": row[5] is None, "current": row["valid_to"] is None,
} }
) )
conn.close()
return results return results
def timeline(self, entity_name: str = None): def timeline(self, entity_name: str = None):
@@ -300,15 +310,14 @@ class KnowledgeGraph:
LIMIT 100 LIMIT 100
""").fetchall() """).fetchall()
conn.close()
return [ return [
{ {
"subject": r[10], "subject": r["sub_name"],
"predicate": r[2], "predicate": r["predicate"],
"object": r[11], "object": r["obj_name"],
"valid_from": r[4], "valid_from": r["valid_from"],
"valid_to": r[5], "valid_to": r["valid_to"],
"current": r[5] is None, "current": r["valid_to"] is None,
} }
for r in rows for r in rows
] ]
@@ -317,17 +326,18 @@ class KnowledgeGraph:
def stats(self): def stats(self):
conn = self._conn() conn = self._conn()
entities = conn.execute("SELECT COUNT(*) FROM entities").fetchone()[0] entities = conn.execute("SELECT COUNT(*) as cnt FROM entities").fetchone()["cnt"]
triples = conn.execute("SELECT COUNT(*) FROM triples").fetchone()[0] triples = conn.execute("SELECT COUNT(*) as cnt FROM triples").fetchone()["cnt"]
current = conn.execute("SELECT COUNT(*) FROM triples WHERE valid_to IS NULL").fetchone()[0] current = conn.execute(
"SELECT COUNT(*) as cnt FROM triples WHERE valid_to IS NULL"
).fetchone()["cnt"]
expired = triples - current expired = triples - current
predicates = [ predicates = [
r[0] r["predicate"]
for r in conn.execute( for r in conn.execute(
"SELECT DISTINCT predicate FROM triples ORDER BY predicate" "SELECT DISTINCT predicate FROM triples ORDER BY predicate"
).fetchall() ).fetchall()
] ]
conn.close()
return { return {
"entities": entities, "entities": entities,
"triples": triples, "triples": triples,
+161 -25
View File
@@ -23,9 +23,11 @@ import sys
import json import json
import logging import logging
import hashlib import hashlib
import time
from datetime import datetime from datetime import datetime
from pathlib import Path
from .config import MempalaceConfig from .config import MempalaceConfig, sanitize_name, sanitize_content
from .version import __version__ from .version import __version__
from .searcher import search_memories from .searcher import search_memories
from .palace_graph import traverse, find_tunnels, graph_stats from .palace_graph import traverse, find_tunnels, graph_stats
@@ -66,12 +68,64 @@ _client_cache = None
_collection_cache = None _collection_cache = None
# ==================== WRITE-AHEAD LOG ====================
# Every write operation is logged to a JSONL file before execution.
# This provides an audit trail for detecting memory poisoning and
# enables review/rollback of writes from external or untrusted sources.
_WAL_DIR = Path(os.path.expanduser("~/.mempalace/wal"))
_WAL_DIR.mkdir(parents=True, exist_ok=True)
_WAL_FILE = _WAL_DIR / "write_log.jsonl"
def _wal_log(operation: str, params: dict, result: dict = None):
"""Append a write operation to the write-ahead log."""
entry = {
"timestamp": datetime.now().isoformat(),
"operation": operation,
"params": params,
"result": result,
}
try:
with open(_WAL_FILE, "a", encoding="utf-8") as f:
f.write(json.dumps(entry, default=str) + "\n")
except Exception as e:
logger.error(f"WAL write failed: {e}")
_client = None
def _get_client():
"""Return a singleton ChromaDB PersistentClient."""
global _client
if _client is None:
_client = chromadb.PersistentClient(path=_config.palace_path)
return _client
_meta_cache = {"data": None, "timestamp": 0, "ttl": 30} # 30 second TTL
def _get_cached_metadata():
"""Return all record metadatas with a time-based cache to avoid repeated full scans."""
now = time.time()
if _meta_cache["data"] is not None and (now - _meta_cache["timestamp"]) < _meta_cache["ttl"]:
return _meta_cache["data"]
col = _get_collection()
if not col:
return None
all_meta = col.get(include=["metadatas"])["metadatas"]
_meta_cache["data"] = all_meta
_meta_cache["timestamp"] = now
return all_meta
def _get_collection(create=False): def _get_collection(create=False):
"""Return the ChromaDB collection, caching the client between calls.""" """Return the ChromaDB collection, caching the client between calls."""
global _client_cache, _collection_cache global _client_cache, _collection_cache
try: try:
if _client_cache is None: _get_client()
_client_cache = chromadb.PersistentClient(path=_config.palace_path)
if create: if create:
_collection_cache = _client_cache.get_or_create_collection(_config.collection_name) _collection_cache = _client_cache.get_or_create_collection(_config.collection_name)
elif _collection_cache is None: elif _collection_cache is None:
@@ -99,12 +153,13 @@ def tool_status():
wings = {} wings = {}
rooms = {} rooms = {}
try: try:
all_meta = col.get(include=["metadatas"], limit=10000)["metadatas"] all_meta = _get_cached_metadata()
for m in all_meta: if all_meta:
w = m.get("wing", "unknown") for m in all_meta:
r = m.get("room", "unknown") w = m.get("wing", "unknown")
wings[w] = wings.get(w, 0) + 1 r = m.get("room", "unknown")
rooms[r] = rooms.get(r, 0) + 1 wings[w] = wings.get(w, 0) + 1
rooms[r] = rooms.get(r, 0) + 1
except Exception: except Exception:
pass pass
return { return {
@@ -156,10 +211,11 @@ def tool_list_wings():
return _no_palace() return _no_palace()
wings = {} wings = {}
try: try:
all_meta = col.get(include=["metadatas"], limit=10000)["metadatas"] all_meta = _get_cached_metadata()
for m in all_meta: if all_meta:
w = m.get("wing", "unknown") for m in all_meta:
wings[w] = wings.get(w, 0) + 1 w = m.get("wing", "unknown")
wings[w] = wings.get(w, 0) + 1
except Exception: except Exception:
pass pass
return {"wings": wings} return {"wings": wings}
@@ -171,10 +227,12 @@ def tool_list_rooms(wing: str = None):
return _no_palace() return _no_palace()
rooms = {} rooms = {}
try: try:
kwargs = {"include": ["metadatas"], "limit": 10000}
if wing: if wing:
kwargs["where"] = {"wing": wing} # Filtered query — cannot use the full metadata cache
all_meta = col.get(**kwargs)["metadatas"] all_meta = col.get(include=["metadatas"], where={"wing": wing})["metadatas"]
else:
# No filter — use the cached metadata
all_meta = _get_cached_metadata() or []
for m in all_meta: for m in all_meta:
r = m.get("room", "unknown") r = m.get("room", "unknown")
rooms[r] = rooms.get(r, 0) + 1 rooms[r] = rooms.get(r, 0) + 1
@@ -189,13 +247,14 @@ def tool_get_taxonomy():
return _no_palace() return _no_palace()
taxonomy = {} taxonomy = {}
try: try:
all_meta = col.get(include=["metadatas"], limit=10000)["metadatas"] all_meta = _get_cached_metadata()
for m in all_meta: if all_meta:
w = m.get("wing", "unknown") for m in all_meta:
r = m.get("room", "unknown") w = m.get("wing", "unknown")
if w not in taxonomy: r = m.get("room", "unknown")
taxonomy[w] = {} if w not in taxonomy:
taxonomy[w][r] = taxonomy[w].get(r, 0) + 1 taxonomy[w] = {}
taxonomy[w][r] = taxonomy[w].get(r, 0) + 1
except Exception: except Exception:
pass pass
return {"taxonomy": taxonomy} return {"taxonomy": taxonomy}
@@ -282,11 +341,30 @@ def tool_add_drawer(
wing: str, room: str, content: str, source_file: str = None, added_by: str = "mcp" wing: str, room: str, content: str, source_file: str = None, added_by: str = "mcp"
): ):
"""File verbatim content into a wing/room. Checks for duplicates first.""" """File verbatim content into a wing/room. Checks for duplicates first."""
try:
wing = sanitize_name(wing, "wing")
room = sanitize_name(room, "room")
content = sanitize_content(content)
except ValueError as e:
return {"success": False, "error": str(e)}
col = _get_collection(create=True) col = _get_collection(create=True)
if not col: if not col:
return _no_palace() return _no_palace()
drawer_id = f"drawer_{wing}_{room}_{hashlib.md5(content.encode()).hexdigest()[:16]}" drawer_id = f"drawer_{wing}_{room}_{hashlib.sha256((content[:100] + datetime.now().isoformat()).encode()).hexdigest()[:24]}"
_wal_log(
"add_drawer",
{
"drawer_id": drawer_id,
"wing": wing,
"room": room,
"added_by": added_by,
"content_length": len(content),
"content_preview": content[:200],
},
)
# Idempotency: if the deterministic ID already exists, return success as a no-op. # Idempotency: if the deterministic ID already exists, return success as a no-op.
try: try:
@@ -311,6 +389,7 @@ def tool_add_drawer(
} }
], ],
) )
_meta_cache["data"] = None # Invalidate metadata cache
logger.info(f"Filed drawer: {drawer_id}{wing}/{room}") logger.info(f"Filed drawer: {drawer_id}{wing}/{room}")
return {"success": True, "drawer_id": drawer_id, "wing": wing, "room": room} return {"success": True, "drawer_id": drawer_id, "wing": wing, "room": room}
except Exception as e: except Exception as e:
@@ -325,8 +404,22 @@ def tool_delete_drawer(drawer_id: str):
existing = col.get(ids=[drawer_id]) existing = col.get(ids=[drawer_id])
if not existing["ids"]: if not existing["ids"]:
return {"success": False, "error": f"Drawer not found: {drawer_id}"} return {"success": False, "error": f"Drawer not found: {drawer_id}"}
# Log the deletion with the content being removed for audit trail
deleted_content = existing.get("documents", [""])[0] if existing.get("documents") else ""
deleted_meta = existing.get("metadatas", [{}])[0] if existing.get("metadatas") else {}
_wal_log(
"delete_drawer",
{
"drawer_id": drawer_id,
"deleted_meta": deleted_meta,
"content_preview": deleted_content[:200],
},
)
try: try:
col.delete(ids=[drawer_id]) col.delete(ids=[drawer_id])
_meta_cache["data"] = None # Invalidate metadata cache
logger.info(f"Deleted drawer: {drawer_id}") logger.info(f"Deleted drawer: {drawer_id}")
return {"success": True, "drawer_id": drawer_id} return {"success": True, "drawer_id": drawer_id}
except Exception as e: except Exception as e:
@@ -346,6 +439,23 @@ def tool_kg_add(
subject: str, predicate: str, object: str, valid_from: str = None, source_closet: str = None subject: str, predicate: str, object: str, valid_from: str = None, source_closet: str = None
): ):
"""Add a relationship to the knowledge graph.""" """Add a relationship to the knowledge graph."""
try:
subject = sanitize_name(subject, "subject")
predicate = sanitize_name(predicate, "predicate")
object = sanitize_name(object, "object")
except ValueError as e:
return {"success": False, "error": str(e)}
_wal_log(
"kg_add",
{
"subject": subject,
"predicate": predicate,
"object": object,
"valid_from": valid_from,
"source_closet": source_closet,
},
)
triple_id = _kg.add_triple( triple_id = _kg.add_triple(
subject, predicate, object, valid_from=valid_from, source_closet=source_closet subject, predicate, object, valid_from=valid_from, source_closet=source_closet
) )
@@ -354,6 +464,10 @@ def tool_kg_add(
def tool_kg_invalidate(subject: str, predicate: str, object: str, ended: str = None): def tool_kg_invalidate(subject: str, predicate: str, object: str, ended: str = None):
"""Mark a fact as no longer true (set end date).""" """Mark a fact as no longer true (set end date)."""
_wal_log(
"kg_invalidate",
{"subject": subject, "predicate": predicate, "object": object, "ended": ended},
)
_kg.invalidate(subject, predicate, object, ended=ended) _kg.invalidate(subject, predicate, object, ended=ended)
return { return {
"success": True, "success": True,
@@ -384,6 +498,12 @@ def tool_diary_write(agent_name: str, entry: str, topic: str = "general"):
This is the agent's personal journal — observations, thoughts, This is the agent's personal journal — observations, thoughts,
what it worked on, what it noticed, what it thinks matters. what it worked on, what it noticed, what it thinks matters.
""" """
try:
agent_name = sanitize_name(agent_name, "agent_name")
entry = sanitize_content(entry)
except ValueError as e:
return {"success": False, "error": str(e)}
wing = f"wing_{agent_name.lower().replace(' ', '_')}" wing = f"wing_{agent_name.lower().replace(' ', '_')}"
room = "diary" room = "diary"
col = _get_collection(create=True) col = _get_collection(create=True)
@@ -391,9 +511,23 @@ def tool_diary_write(agent_name: str, entry: str, topic: str = "general"):
return _no_palace() return _no_palace()
now = datetime.now() now = datetime.now()
entry_id = f"diary_{wing}_{now.strftime('%Y%m%d_%H%M%S')}_{hashlib.md5(entry[:50].encode()).hexdigest()[:8]}" entry_id = f"diary_{wing}_{now.strftime('%Y%m%d_%H%M%S')}_{hashlib.sha256(entry[:50].encode()).hexdigest()[:12]}"
_wal_log(
"diary_write",
{
"agent_name": agent_name,
"topic": topic,
"entry_id": entry_id,
"entry_preview": entry[:200],
},
)
try: try:
# TODO: Future versions should expand AAAK before embedding to improve
# semantic search quality. For now, store raw AAAK in metadata so it's
# preserved, and keep the document as-is for embedding (even though
# compressed AAAK degrades embedding quality).
col.add( col.add(
ids=[entry_id], ids=[entry_id],
documents=[entry], documents=[entry],
@@ -407,9 +541,11 @@ def tool_diary_write(agent_name: str, entry: str, topic: str = "general"):
"agent": agent_name, "agent": agent_name,
"filed_at": now.isoformat(), "filed_at": now.isoformat(),
"date": now.strftime("%Y-%m-%d"), "date": now.strftime("%Y-%m-%d"),
"raw_aaak": entry,
} }
], ],
) )
_meta_cache["data"] = None # Invalidate metadata cache
logger.info(f"Diary entry: {entry_id}{wing}/diary/{topic}") logger.info(f"Diary entry: {entry_id}{wing}/diary/{topic}")
return { return {
"success": True, "success": True,
+13 -57
View File
@@ -17,6 +17,8 @@ from collections import defaultdict
import chromadb import chromadb
from .palace import SKIP_DIRS, get_collection, file_already_mined
READABLE_EXTENSIONS = { READABLE_EXTENSIONS = {
".txt", ".txt",
".md", ".md",
@@ -40,32 +42,6 @@ READABLE_EXTENSIONS = {
".toml", ".toml",
} }
SKIP_DIRS = {
".git",
"node_modules",
"__pycache__",
".venv",
"venv",
"env",
"dist",
"build",
".next",
"coverage",
".mempalace",
".ruff_cache",
".mypy_cache",
".pytest_cache",
".cache",
".tox",
".nox",
".idea",
".vscode",
".ipynb_checkpoints",
".eggs",
"htmlcov",
"target",
}
SKIP_FILENAMES = { SKIP_FILENAMES = {
"mempalace.yaml", "mempalace.yaml",
"mempalace.yml", "mempalace.yml",
@@ -78,6 +54,7 @@ SKIP_FILENAMES = {
CHUNK_SIZE = 800 # chars per drawer CHUNK_SIZE = 800 # chars per drawer
CHUNK_OVERLAP = 100 # overlap between chunks CHUNK_OVERLAP = 100 # overlap between chunks
MIN_CHUNK_SIZE = 50 # skip tiny chunks MIN_CHUNK_SIZE = 50 # skip tiny chunks
MAX_FILE_SIZE = 10 * 1024 * 1024 # 10 MB — skip files larger than this
# ============================================================================= # =============================================================================
@@ -393,41 +370,11 @@ def chunk_text(content: str, source_file: str) -> list:
# ============================================================================= # =============================================================================
def get_collection(palace_path: str):
os.makedirs(palace_path, exist_ok=True)
client = chromadb.PersistentClient(path=palace_path)
try:
return client.get_collection("mempalace_drawers")
except Exception:
return client.create_collection("mempalace_drawers")
def file_already_mined(collection, source_file: str) -> bool:
"""Fast check: has this file been filed before and is unchanged?
Compares the stored mtime in drawer metadata against the file's current
mtime. Returns False (needs re-mining) when the file has been modified
since it was last mined, or when no mtime was stored.
"""
try:
results = collection.get(where={"source_file": source_file}, limit=1)
if not results.get("ids"):
return False
stored_meta = results["metadatas"][0] if results.get("metadatas") else {}
stored_mtime = stored_meta.get("source_mtime")
if stored_mtime is None:
return False
current_mtime = os.path.getmtime(source_file)
return float(stored_mtime) == current_mtime
except Exception:
return False
def add_drawer( def add_drawer(
collection, wing: str, room: str, content: str, source_file: str, chunk_index: int, agent: str collection, wing: str, room: str, content: str, source_file: str, chunk_index: int, agent: str
): ):
"""Add one drawer to the palace.""" """Add one drawer to the palace."""
drawer_id = f"drawer_{wing}_{room}_{hashlib.md5((source_file + str(chunk_index)).encode(), usedforsecurity=False).hexdigest()[:16]}" drawer_id = f"drawer_{wing}_{room}_{hashlib.sha256((source_file + str(chunk_index)).encode()).hexdigest()[:24]}"
try: try:
metadata = { metadata = {
"wing": wing, "wing": wing,
@@ -562,6 +509,15 @@ def scan_project(
if respect_gitignore and active_matchers and not force_include: if respect_gitignore and active_matchers and not force_include:
if is_gitignored(filepath, active_matchers, is_dir=False): if is_gitignored(filepath, active_matchers, is_dir=False):
continue continue
# Skip symlinks — prevents following links to /dev/urandom, etc.
if filepath.is_symlink():
continue
# Skip files exceeding size limit
try:
if filepath.stat().st_size > MAX_FILE_SIZE:
continue
except OSError:
continue
files.append(filepath) files.append(filepath)
return files return files
+45
View File
@@ -0,0 +1,45 @@
"""
palace.py — Shared palace operations.
Consolidates ChromaDB access patterns used by both miners and the MCP server.
"""
import os
import chromadb
SKIP_DIRS = {
".git",
"node_modules",
"__pycache__",
".venv",
"venv",
"env",
"dist",
"build",
".next",
"coverage",
".mempalace",
}
def get_collection(palace_path: str, collection_name: str = "mempalace_drawers"):
"""Get or create the palace ChromaDB collection."""
os.makedirs(palace_path, exist_ok=True)
try:
os.chmod(palace_path, 0o700)
except (OSError, NotImplementedError):
pass
client = chromadb.PersistentClient(path=palace_path)
try:
return client.get_collection(collection_name)
except Exception:
return client.create_collection(collection_name)
def file_already_mined(collection, source_file: str) -> bool:
"""Check if a file has already been filed in the palace."""
try:
results = collection.get(where={"source_file": source_file}, limit=1)
return len(results.get("ids", [])) > 0
except Exception:
return False
+1 -1
View File
@@ -26,7 +26,7 @@ classifiers = [
] ]
dependencies = [ dependencies = [
"chromadb>=0.5.0,<0.7", "chromadb>=0.5.0,<0.7",
"pyyaml>=6.0", "pyyaml>=6.0,<7",
] ]
[project.urls] [project.urls]