Merge pull request #387 from milla-jovovich/ben/security-hardening

security: harden inputs, fix shell injection, optimize DB access
This commit is contained in:
Ben Sigman
2026-04-09 09:13:09 -07:00
committed by GitHub
9 changed files with 434 additions and 187 deletions
+15 -8
View File
@@ -64,13 +64,20 @@ MEMPAL_DIR=""
# Read JSON input from stdin # Read JSON input from stdin
INPUT=$(cat) INPUT=$(cat)
# Parse fields from Claude Code's JSON # Parse all fields in a single Python call (3x faster than separate invocations)
SESSION_ID=$(echo "$INPUT" | python3 -c "import sys,json; print(json.load(sys.stdin).get('session_id','unknown'))" 2>/dev/null) eval $(echo "$INPUT" | python3 -c "
# Sanitize SESSION_ID to prevent path traversal (only allow alnum, dash, underscore) import sys, json
SESSION_ID=$(echo "$SESSION_ID" | tr -cd 'a-zA-Z0-9_-') data = json.load(sys.stdin)
[ -z "$SESSION_ID" ] && SESSION_ID="unknown" sid = data.get('session_id', 'unknown')
STOP_HOOK_ACTIVE=$(echo "$INPUT" | python3 -c "import sys,json; print(json.load(sys.stdin).get('stop_hook_active', False))" 2>/dev/null) sha = data.get('stop_hook_active', False)
TRANSCRIPT_PATH=$(echo "$INPUT" | python3 -c "import sys,json; print(json.load(sys.stdin).get('transcript_path',''))" 2>/dev/null) tp = data.get('transcript_path', '')
# Shell-safe output — only allow alphanumeric, underscore, hyphen, slash, dot, tilde
import re
safe = lambda s: re.sub(r'[^a-zA-Z0-9_/.\-~]', '', str(s))
print(f'SESSION_ID=\"{safe(sid)}\"')
print(f'STOP_HOOK_ACTIVE=\"{sha}\"')
print(f'TRANSCRIPT_PATH=\"{safe(tp)}\"')
" 2>/dev/null)
# Expand ~ in path # Expand ~ in path
TRANSCRIPT_PATH="${TRANSCRIPT_PATH/#\~/$HOME}" TRANSCRIPT_PATH="${TRANSCRIPT_PATH/#\~/$HOME}"
@@ -83,6 +90,7 @@ if [ "$STOP_HOOK_ACTIVE" = "True" ] || [ "$STOP_HOOK_ACTIVE" = "true" ]; then
fi fi
# Count human messages in the JSONL transcript # Count human messages in the JSONL transcript
# SECURITY: Pass transcript path as sys.argv to avoid shell injection via crafted paths
if [ -f "$TRANSCRIPT_PATH" ]; then if [ -f "$TRANSCRIPT_PATH" ]; then
EXCHANGE_COUNT=$(python3 - "$TRANSCRIPT_PATH" <<'PYEOF' EXCHANGE_COUNT=$(python3 - "$TRANSCRIPT_PATH" <<'PYEOF'
import json, sys import json, sys
@@ -94,7 +102,6 @@ with open(sys.argv[1]) as f:
msg = entry.get('message', {}) msg = entry.get('message', {})
if isinstance(msg, dict) and msg.get('role') == 'user': if isinstance(msg, dict) and msg.get('role') == 'user':
content = msg.get('content', '') content = msg.get('content', '')
# Skip system/command messages — only count real human input
if isinstance(content, str) and '<command-message>' in content: if isinstance(content, str) and '<command-message>' in content:
continue continue
count += 1 count += 1
+60
View File
@@ -6,8 +6,58 @@ Priority: env vars > config file (~/.mempalace/config.json) > defaults
import json import json
import os import os
import re
from pathlib import Path from pathlib import Path
# ── Input validation ──────────────────────────────────────────────────────────
# Shared sanitizers for wing/room/entity names. Prevents path traversal,
# excessively long strings, and special characters that could cause issues
# in file paths, SQLite, or ChromaDB metadata.
MAX_NAME_LENGTH = 128
_SAFE_NAME_RE = re.compile(r"^[a-zA-Z0-9][a-zA-Z0-9_ .'-]{0,126}[a-zA-Z0-9]?$")
def sanitize_name(value: str, field_name: str = "name") -> str:
"""Validate and sanitize a wing/room/entity name.
Raises ValueError if the name is invalid.
"""
if not isinstance(value, str) or not value.strip():
raise ValueError(f"{field_name} must be a non-empty string")
value = value.strip()
if len(value) > MAX_NAME_LENGTH:
raise ValueError(f"{field_name} exceeds maximum length of {MAX_NAME_LENGTH} characters")
# Block path traversal
if ".." in value or "/" in value or "\\" in value:
raise ValueError(f"{field_name} contains invalid path characters")
# Block null bytes
if "\x00" in value:
raise ValueError(f"{field_name} contains null bytes")
# Enforce safe character set
if not _SAFE_NAME_RE.match(value):
raise ValueError(f"{field_name} contains invalid characters")
return value
def sanitize_content(value: str, max_length: int = 100_000) -> str:
"""Validate drawer/diary content length."""
if not isinstance(value, str) or not value.strip():
raise ValueError("content must be a non-empty string")
if len(value) > max_length:
raise ValueError(f"content exceeds maximum length of {max_length} characters")
if "\x00" in value:
raise ValueError("content contains null bytes")
return value
DEFAULT_PALACE_PATH = os.path.expanduser("~/.mempalace/palace") DEFAULT_PALACE_PATH = os.path.expanduser("~/.mempalace/palace")
DEFAULT_COLLECTION_NAME = "mempalace_drawers" DEFAULT_COLLECTION_NAME = "mempalace_drawers"
@@ -126,6 +176,11 @@ class MempalaceConfig:
def init(self): def init(self):
"""Create config directory and write default config.json if it doesn't exist.""" """Create config directory and write default config.json if it doesn't exist."""
self._config_dir.mkdir(parents=True, exist_ok=True) self._config_dir.mkdir(parents=True, exist_ok=True)
# Restrict directory permissions to owner only (Unix)
try:
self._config_dir.chmod(0o700)
except (OSError, NotImplementedError):
pass # Windows doesn't support Unix permissions
if not self._config_file.exists(): if not self._config_file.exists():
default_config = { default_config = {
"palace_path": DEFAULT_PALACE_PATH, "palace_path": DEFAULT_PALACE_PATH,
@@ -135,6 +190,11 @@ class MempalaceConfig:
} }
with open(self._config_file, "w") as f: with open(self._config_file, "w") as f:
json.dump(default_config, f, indent=2) json.dump(default_config, f, indent=2)
# Restrict config file to owner read/write only
try:
self._config_file.chmod(0o600)
except (OSError, NotImplementedError):
pass
return self._config_file return self._config_file
def save_people_map(self, people_map): def save_people_map(self, people_map):
+11 -35
View File
@@ -15,9 +15,8 @@ from pathlib import Path
from datetime import datetime from datetime import datetime
from collections import defaultdict from collections import defaultdict
import chromadb
from .normalize import normalize from .normalize import normalize
from .palace import SKIP_DIRS, get_collection, file_already_mined
# File types that might contain conversations # File types that might contain conversations
@@ -28,22 +27,8 @@ CONVO_EXTENSIONS = {
".jsonl", ".jsonl",
} }
SKIP_DIRS = {
".git",
"node_modules",
"__pycache__",
".venv",
"venv",
"env",
"dist",
"build",
".next",
".mempalace",
"tool-results",
"memory",
}
MIN_CHUNK_SIZE = 30 MIN_CHUNK_SIZE = 30
MAX_FILE_SIZE = 10 * 1024 * 1024 # 10 MB — skip files larger than this
# ============================================================================= # =============================================================================
@@ -211,23 +196,6 @@ def detect_convo_room(content: str) -> str:
# ============================================================================= # =============================================================================
def get_collection(palace_path: str):
os.makedirs(palace_path, exist_ok=True)
client = chromadb.PersistentClient(path=palace_path)
try:
return client.get_collection("mempalace_drawers")
except Exception:
return client.create_collection("mempalace_drawers")
def file_already_mined(collection, source_file: str) -> bool:
try:
results = collection.get(where={"source_file": source_file}, limit=1)
return len(results.get("ids", [])) > 0
except Exception:
return False
# ============================================================================= # =============================================================================
# SCAN FOR CONVERSATION FILES # SCAN FOR CONVERSATION FILES
# ============================================================================= # =============================================================================
@@ -244,6 +212,14 @@ def scan_convos(convo_dir: str) -> list:
continue continue
filepath = Path(root) / filename filepath = Path(root) / filename
if filepath.suffix.lower() in CONVO_EXTENSIONS: if filepath.suffix.lower() in CONVO_EXTENSIONS:
# Skip symlinks and oversized files
if filepath.is_symlink():
continue
try:
if filepath.stat().st_size > MAX_FILE_SIZE:
continue
except OSError:
continue
files.append(filepath) files.append(filepath)
return files return files
@@ -356,7 +332,7 @@ def mine_convos(
chunk_room = chunk.get("memory_type", room) if extract_mode == "general" else room chunk_room = chunk.get("memory_type", room) if extract_mode == "general" else room
if extract_mode == "general": if extract_mode == "general":
room_counts[chunk_room] += 1 room_counts[chunk_room] += 1
drawer_id = f"drawer_{wing}_{chunk_room}_{hashlib.md5((source_file + str(chunk['chunk_index'])).encode(), usedforsecurity=False).hexdigest()[:16]}" drawer_id = f"drawer_{wing}_{chunk_room}_{hashlib.sha256((source_file + str(chunk['chunk_index'])).encode()).hexdigest()[:24]}"
try: try:
collection.add( collection.add(
documents=[chunk["content"]], documents=[chunk["content"]],
+83 -77
View File
@@ -50,11 +50,14 @@ class KnowledgeGraph:
def __init__(self, db_path: str = None): def __init__(self, db_path: str = None):
self.db_path = db_path or DEFAULT_KG_PATH self.db_path = db_path or DEFAULT_KG_PATH
Path(self.db_path).parent.mkdir(parents=True, exist_ok=True) Path(self.db_path).parent.mkdir(parents=True, exist_ok=True)
self._connection = None
self._init_db() self._init_db()
def _init_db(self): def _init_db(self):
conn = self._conn() conn = self._conn()
conn.executescript(""" conn.executescript("""
PRAGMA journal_mode=WAL;
CREATE TABLE IF NOT EXISTS entities ( CREATE TABLE IF NOT EXISTS entities (
id TEXT PRIMARY KEY, id TEXT PRIMARY KEY,
name TEXT NOT NULL, name TEXT NOT NULL,
@@ -84,12 +87,19 @@ class KnowledgeGraph:
CREATE INDEX IF NOT EXISTS idx_triples_valid ON triples(valid_from, valid_to); CREATE INDEX IF NOT EXISTS idx_triples_valid ON triples(valid_from, valid_to);
""") """)
conn.commit() conn.commit()
conn.close()
def _conn(self): def _conn(self):
conn = sqlite3.connect(self.db_path, timeout=10) if self._connection is None:
conn.execute("PRAGMA journal_mode=WAL") self._connection = sqlite3.connect(self.db_path, timeout=10, check_same_thread=False)
return conn self._connection.execute("PRAGMA journal_mode=WAL")
self._connection.row_factory = sqlite3.Row
return self._connection
def close(self):
"""Close the database connection."""
if self._connection is not None:
self._connection.close()
self._connection = None
def _entity_id(self, name: str) -> str: def _entity_id(self, name: str) -> str:
return name.lower().replace(" ", "_").replace("'", "") return name.lower().replace(" ", "_").replace("'", "")
@@ -101,12 +111,11 @@ class KnowledgeGraph:
eid = self._entity_id(name) eid = self._entity_id(name)
props = json.dumps(properties or {}) props = json.dumps(properties or {})
conn = self._conn() conn = self._conn()
conn.execute( with conn:
"INSERT OR REPLACE INTO entities (id, name, type, properties) VALUES (?, ?, ?, ?)", conn.execute(
(eid, name, entity_type, props), "INSERT OR REPLACE INTO entities (id, name, type, properties) VALUES (?, ?, ?, ?)",
) (eid, name, entity_type, props),
conn.commit() )
conn.close()
return eid return eid
def add_triple( def add_triple(
@@ -134,38 +143,38 @@ class KnowledgeGraph:
# Auto-create entities if they don't exist # Auto-create entities if they don't exist
conn = self._conn() conn = self._conn()
conn.execute("INSERT OR IGNORE INTO entities (id, name) VALUES (?, ?)", (sub_id, subject)) with conn:
conn.execute("INSERT OR IGNORE INTO entities (id, name) VALUES (?, ?)", (obj_id, obj)) conn.execute(
"INSERT OR IGNORE INTO entities (id, name) VALUES (?, ?)", (sub_id, subject)
)
conn.execute("INSERT OR IGNORE INTO entities (id, name) VALUES (?, ?)", (obj_id, obj))
# Check for existing identical triple # Check for existing identical triple
existing = conn.execute( existing = conn.execute(
"SELECT id FROM triples WHERE subject=? AND predicate=? AND object=? AND valid_to IS NULL", "SELECT id FROM triples WHERE subject=? AND predicate=? AND object=? AND valid_to IS NULL",
(sub_id, pred, obj_id), (sub_id, pred, obj_id),
).fetchone() ).fetchone()
if existing: if existing:
conn.close() return existing["id"] # Already exists and still valid
return existing[0] # Already exists and still valid
triple_id = f"t_{sub_id}_{pred}_{obj_id}_{hashlib.md5(f'{valid_from}{datetime.now().isoformat()}'.encode()).hexdigest()[:8]}" triple_id = f"t_{sub_id}_{pred}_{obj_id}_{hashlib.sha256(f'{valid_from}{datetime.now().isoformat()}'.encode()).hexdigest()[:12]}"
conn.execute( conn.execute(
"""INSERT INTO triples (id, subject, predicate, object, valid_from, valid_to, confidence, source_closet, source_file) """INSERT INTO triples (id, subject, predicate, object, valid_from, valid_to, confidence, source_closet, source_file)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)""", VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)""",
( (
triple_id, triple_id,
sub_id, sub_id,
pred, pred,
obj_id, obj_id,
valid_from, valid_from,
valid_to, valid_to,
confidence, confidence,
source_closet, source_closet,
source_file, source_file,
), ),
) )
conn.commit()
conn.close()
return triple_id return triple_id
def invalidate(self, subject: str, predicate: str, obj: str, ended: str = None): def invalidate(self, subject: str, predicate: str, obj: str, ended: str = None):
@@ -176,12 +185,11 @@ class KnowledgeGraph:
ended = ended or date.today().isoformat() ended = ended or date.today().isoformat()
conn = self._conn() conn = self._conn()
conn.execute( with conn:
"UPDATE triples SET valid_to=? WHERE subject=? AND predicate=? AND object=? AND valid_to IS NULL", conn.execute(
(ended, sub_id, pred, obj_id), "UPDATE triples SET valid_to=? WHERE subject=? AND predicate=? AND object=? AND valid_to IS NULL",
) (ended, sub_id, pred, obj_id),
conn.commit() )
conn.close()
# ── Query operations ────────────────────────────────────────────────── # ── Query operations ──────────────────────────────────────────────────
@@ -208,13 +216,13 @@ class KnowledgeGraph:
{ {
"direction": "outgoing", "direction": "outgoing",
"subject": name, "subject": name,
"predicate": row[2], "predicate": row["predicate"],
"object": row[10], # obj_name "object": row["obj_name"],
"valid_from": row[4], "valid_from": row["valid_from"],
"valid_to": row[5], "valid_to": row["valid_to"],
"confidence": row[6], "confidence": row["confidence"],
"source_closet": row[7], "source_closet": row["source_closet"],
"current": row[5] is None, "current": row["valid_to"] is None,
} }
) )
@@ -228,18 +236,17 @@ class KnowledgeGraph:
results.append( results.append(
{ {
"direction": "incoming", "direction": "incoming",
"subject": row[10], # sub_name "subject": row["sub_name"],
"predicate": row[2], "predicate": row["predicate"],
"object": name, "object": name,
"valid_from": row[4], "valid_from": row["valid_from"],
"valid_to": row[5], "valid_to": row["valid_to"],
"confidence": row[6], "confidence": row["confidence"],
"source_closet": row[7], "source_closet": row["source_closet"],
"current": row[5] is None, "current": row["valid_to"] is None,
} }
) )
conn.close()
return results return results
def query_relationship(self, predicate: str, as_of: str = None): def query_relationship(self, predicate: str, as_of: str = None):
@@ -262,15 +269,14 @@ class KnowledgeGraph:
for row in conn.execute(query, params).fetchall(): for row in conn.execute(query, params).fetchall():
results.append( results.append(
{ {
"subject": row[10], "subject": row["sub_name"],
"predicate": pred, "predicate": pred,
"object": row[11], "object": row["obj_name"],
"valid_from": row[4], "valid_from": row["valid_from"],
"valid_to": row[5], "valid_to": row["valid_to"],
"current": row[5] is None, "current": row["valid_to"] is None,
} }
) )
conn.close()
return results return results
def timeline(self, entity_name: str = None): def timeline(self, entity_name: str = None):
@@ -300,15 +306,14 @@ class KnowledgeGraph:
LIMIT 100 LIMIT 100
""").fetchall() """).fetchall()
conn.close()
return [ return [
{ {
"subject": r[10], "subject": r["sub_name"],
"predicate": r[2], "predicate": r["predicate"],
"object": r[11], "object": r["obj_name"],
"valid_from": r[4], "valid_from": r["valid_from"],
"valid_to": r[5], "valid_to": r["valid_to"],
"current": r[5] is None, "current": r["valid_to"] is None,
} }
for r in rows for r in rows
] ]
@@ -317,17 +322,18 @@ class KnowledgeGraph:
def stats(self): def stats(self):
conn = self._conn() conn = self._conn()
entities = conn.execute("SELECT COUNT(*) FROM entities").fetchone()[0] entities = conn.execute("SELECT COUNT(*) as cnt FROM entities").fetchone()["cnt"]
triples = conn.execute("SELECT COUNT(*) FROM triples").fetchone()[0] triples = conn.execute("SELECT COUNT(*) as cnt FROM triples").fetchone()["cnt"]
current = conn.execute("SELECT COUNT(*) FROM triples WHERE valid_to IS NULL").fetchone()[0] current = conn.execute(
"SELECT COUNT(*) as cnt FROM triples WHERE valid_to IS NULL"
).fetchone()["cnt"]
expired = triples - current expired = triples - current
predicates = [ predicates = [
r[0] r["predicate"]
for r in conn.execute( for r in conn.execute(
"SELECT DISTINCT predicate FROM triples ORDER BY predicate" "SELECT DISTINCT predicate FROM triples ORDER BY predicate"
).fetchall() ).fetchall()
] ]
conn.close()
return { return {
"entities": entities, "entities": entities,
"triples": triples, "triples": triples,
+126 -8
View File
@@ -24,8 +24,9 @@ import json
import logging import logging
import hashlib import hashlib
from datetime import datetime from datetime import datetime
from pathlib import Path
from .config import MempalaceConfig from .config import MempalaceConfig, sanitize_name, sanitize_content
from .version import __version__ from .version import __version__
from .searcher import search_memories from .searcher import search_memories
from .palace_graph import traverse, find_tunnels, graph_stats from .palace_graph import traverse, find_tunnels, graph_stats
@@ -66,16 +67,60 @@ _client_cache = None
_collection_cache = None _collection_cache = None
# ==================== WRITE-AHEAD LOG ====================
# Every write operation is logged to a JSONL file before execution.
# This provides an audit trail for detecting memory poisoning and
# enables review/rollback of writes from external or untrusted sources.
_WAL_DIR = Path(os.path.expanduser("~/.mempalace/wal"))
_WAL_DIR.mkdir(parents=True, exist_ok=True)
try:
_WAL_DIR.chmod(0o700)
except (OSError, NotImplementedError):
pass
_WAL_FILE = _WAL_DIR / "write_log.jsonl"
def _wal_log(operation: str, params: dict, result: dict = None):
"""Append a write operation to the write-ahead log."""
entry = {
"timestamp": datetime.now().isoformat(),
"operation": operation,
"params": params,
"result": result,
}
try:
with open(_WAL_FILE, "a", encoding="utf-8") as f:
f.write(json.dumps(entry, default=str) + "\n")
try:
_WAL_FILE.chmod(0o600)
except (OSError, NotImplementedError):
pass
except Exception as e:
logger.error(f"WAL write failed: {e}")
_client_cache = None
_collection_cache = None
def _get_client():
"""Return a singleton ChromaDB PersistentClient."""
global _client_cache
if _client_cache is None:
_client_cache = chromadb.PersistentClient(path=_config.palace_path)
return _client_cache
def _get_collection(create=False): def _get_collection(create=False):
"""Return the ChromaDB collection, caching the client between calls.""" """Return the ChromaDB collection, caching the client between calls."""
global _client_cache, _collection_cache global _collection_cache
try: try:
if _client_cache is None: client = _get_client()
_client_cache = chromadb.PersistentClient(path=_config.palace_path)
if create: if create:
_collection_cache = _client_cache.get_or_create_collection(_config.collection_name) _collection_cache = client.get_or_create_collection(_config.collection_name)
elif _collection_cache is None: elif _collection_cache is None:
_collection_cache = _client_cache.get_collection(_config.collection_name) _collection_cache = client.get_collection(_config.collection_name)
return _collection_cache return _collection_cache
except Exception: except Exception:
return None return None
@@ -282,11 +327,30 @@ def tool_add_drawer(
wing: str, room: str, content: str, source_file: str = None, added_by: str = "mcp" wing: str, room: str, content: str, source_file: str = None, added_by: str = "mcp"
): ):
"""File verbatim content into a wing/room. Checks for duplicates first.""" """File verbatim content into a wing/room. Checks for duplicates first."""
try:
wing = sanitize_name(wing, "wing")
room = sanitize_name(room, "room")
content = sanitize_content(content)
except ValueError as e:
return {"success": False, "error": str(e)}
col = _get_collection(create=True) col = _get_collection(create=True)
if not col: if not col:
return _no_palace() return _no_palace()
drawer_id = f"drawer_{wing}_{room}_{hashlib.md5(content.encode()).hexdigest()[:16]}" drawer_id = f"drawer_{wing}_{room}_{hashlib.sha256((wing + room + content[:100]).encode()).hexdigest()[:24]}"
_wal_log(
"add_drawer",
{
"drawer_id": drawer_id,
"wing": wing,
"room": room,
"added_by": added_by,
"content_length": len(content),
"content_preview": content[:200],
},
)
# Idempotency: if the deterministic ID already exists, return success as a no-op. # Idempotency: if the deterministic ID already exists, return success as a no-op.
try: try:
@@ -325,6 +389,19 @@ def tool_delete_drawer(drawer_id: str):
existing = col.get(ids=[drawer_id]) existing = col.get(ids=[drawer_id])
if not existing["ids"]: if not existing["ids"]:
return {"success": False, "error": f"Drawer not found: {drawer_id}"} return {"success": False, "error": f"Drawer not found: {drawer_id}"}
# Log the deletion with the content being removed for audit trail
deleted_content = existing.get("documents", [""])[0] if existing.get("documents") else ""
deleted_meta = existing.get("metadatas", [{}])[0] if existing.get("metadatas") else {}
_wal_log(
"delete_drawer",
{
"drawer_id": drawer_id,
"deleted_meta": deleted_meta,
"content_preview": deleted_content[:200],
},
)
try: try:
col.delete(ids=[drawer_id]) col.delete(ids=[drawer_id])
logger.info(f"Deleted drawer: {drawer_id}") logger.info(f"Deleted drawer: {drawer_id}")
@@ -346,6 +423,23 @@ def tool_kg_add(
subject: str, predicate: str, object: str, valid_from: str = None, source_closet: str = None subject: str, predicate: str, object: str, valid_from: str = None, source_closet: str = None
): ):
"""Add a relationship to the knowledge graph.""" """Add a relationship to the knowledge graph."""
try:
subject = sanitize_name(subject, "subject")
predicate = sanitize_name(predicate, "predicate")
object = sanitize_name(object, "object")
except ValueError as e:
return {"success": False, "error": str(e)}
_wal_log(
"kg_add",
{
"subject": subject,
"predicate": predicate,
"object": object,
"valid_from": valid_from,
"source_closet": source_closet,
},
)
triple_id = _kg.add_triple( triple_id = _kg.add_triple(
subject, predicate, object, valid_from=valid_from, source_closet=source_closet subject, predicate, object, valid_from=valid_from, source_closet=source_closet
) )
@@ -354,6 +448,10 @@ def tool_kg_add(
def tool_kg_invalidate(subject: str, predicate: str, object: str, ended: str = None): def tool_kg_invalidate(subject: str, predicate: str, object: str, ended: str = None):
"""Mark a fact as no longer true (set end date).""" """Mark a fact as no longer true (set end date)."""
_wal_log(
"kg_invalidate",
{"subject": subject, "predicate": predicate, "object": object, "ended": ended},
)
_kg.invalidate(subject, predicate, object, ended=ended) _kg.invalidate(subject, predicate, object, ended=ended)
return { return {
"success": True, "success": True,
@@ -384,6 +482,12 @@ def tool_diary_write(agent_name: str, entry: str, topic: str = "general"):
This is the agent's personal journal — observations, thoughts, This is the agent's personal journal — observations, thoughts,
what it worked on, what it noticed, what it thinks matters. what it worked on, what it noticed, what it thinks matters.
""" """
try:
agent_name = sanitize_name(agent_name, "agent_name")
entry = sanitize_content(entry)
except ValueError as e:
return {"success": False, "error": str(e)}
wing = f"wing_{agent_name.lower().replace(' ', '_')}" wing = f"wing_{agent_name.lower().replace(' ', '_')}"
room = "diary" room = "diary"
col = _get_collection(create=True) col = _get_collection(create=True)
@@ -391,9 +495,23 @@ def tool_diary_write(agent_name: str, entry: str, topic: str = "general"):
return _no_palace() return _no_palace()
now = datetime.now() now = datetime.now()
entry_id = f"diary_{wing}_{now.strftime('%Y%m%d_%H%M%S')}_{hashlib.md5(entry[:50].encode()).hexdigest()[:8]}" entry_id = f"diary_{wing}_{now.strftime('%Y%m%d_%H%M%S')}_{hashlib.sha256(entry[:50].encode()).hexdigest()[:12]}"
_wal_log(
"diary_write",
{
"agent_name": agent_name,
"topic": topic,
"entry_id": entry_id,
"entry_preview": entry[:200],
},
)
try: try:
# TODO: Future versions should expand AAAK before embedding to improve
# semantic search quality. For now, store raw AAAK in metadata so it's
# preserved, and keep the document as-is for embedding (even though
# compressed AAAK degrades embedding quality).
col.add( col.add(
ids=[entry_id], ids=[entry_id],
documents=[entry], documents=[entry],
+14 -58
View File
@@ -17,6 +17,8 @@ from collections import defaultdict
import chromadb import chromadb
from .palace import SKIP_DIRS, get_collection, file_already_mined
READABLE_EXTENSIONS = { READABLE_EXTENSIONS = {
".txt", ".txt",
".md", ".md",
@@ -40,32 +42,6 @@ READABLE_EXTENSIONS = {
".toml", ".toml",
} }
SKIP_DIRS = {
".git",
"node_modules",
"__pycache__",
".venv",
"venv",
"env",
"dist",
"build",
".next",
"coverage",
".mempalace",
".ruff_cache",
".mypy_cache",
".pytest_cache",
".cache",
".tox",
".nox",
".idea",
".vscode",
".ipynb_checkpoints",
".eggs",
"htmlcov",
"target",
}
SKIP_FILENAMES = { SKIP_FILENAMES = {
"mempalace.yaml", "mempalace.yaml",
"mempalace.yml", "mempalace.yml",
@@ -78,6 +54,7 @@ SKIP_FILENAMES = {
CHUNK_SIZE = 800 # chars per drawer CHUNK_SIZE = 800 # chars per drawer
CHUNK_OVERLAP = 100 # overlap between chunks CHUNK_OVERLAP = 100 # overlap between chunks
MIN_CHUNK_SIZE = 50 # skip tiny chunks MIN_CHUNK_SIZE = 50 # skip tiny chunks
MAX_FILE_SIZE = 10 * 1024 * 1024 # 10 MB — skip files larger than this
# ============================================================================= # =============================================================================
@@ -393,41 +370,11 @@ def chunk_text(content: str, source_file: str) -> list:
# ============================================================================= # =============================================================================
def get_collection(palace_path: str):
os.makedirs(palace_path, exist_ok=True)
client = chromadb.PersistentClient(path=palace_path)
try:
return client.get_collection("mempalace_drawers")
except Exception:
return client.create_collection("mempalace_drawers")
def file_already_mined(collection, source_file: str) -> bool:
"""Fast check: has this file been filed before and is unchanged?
Compares the stored mtime in drawer metadata against the file's current
mtime. Returns False (needs re-mining) when the file has been modified
since it was last mined, or when no mtime was stored.
"""
try:
results = collection.get(where={"source_file": source_file}, limit=1)
if not results.get("ids"):
return False
stored_meta = results["metadatas"][0] if results.get("metadatas") else {}
stored_mtime = stored_meta.get("source_mtime")
if stored_mtime is None:
return False
current_mtime = os.path.getmtime(source_file)
return float(stored_mtime) == current_mtime
except Exception:
return False
def add_drawer( def add_drawer(
collection, wing: str, room: str, content: str, source_file: str, chunk_index: int, agent: str collection, wing: str, room: str, content: str, source_file: str, chunk_index: int, agent: str
): ):
"""Add one drawer to the palace.""" """Add one drawer to the palace."""
drawer_id = f"drawer_{wing}_{room}_{hashlib.md5((source_file + str(chunk_index)).encode(), usedforsecurity=False).hexdigest()[:16]}" drawer_id = f"drawer_{wing}_{room}_{hashlib.sha256((source_file + str(chunk_index)).encode()).hexdigest()[:24]}"
try: try:
metadata = { metadata = {
"wing": wing, "wing": wing,
@@ -470,7 +417,7 @@ def process_file(
# Skip if already filed # Skip if already filed
source_file = str(filepath) source_file = str(filepath)
if not dry_run and file_already_mined(collection, source_file): if not dry_run and file_already_mined(collection, source_file, check_mtime=True):
return 0, None return 0, None
try: try:
@@ -562,6 +509,15 @@ def scan_project(
if respect_gitignore and active_matchers and not force_include: if respect_gitignore and active_matchers and not force_include:
if is_gitignored(filepath, active_matchers, is_dir=False): if is_gitignored(filepath, active_matchers, is_dir=False):
continue continue
# Skip symlinks — prevents following links to /dev/urandom, etc.
if filepath.is_symlink():
continue
# Skip files exceeding size limit
try:
if filepath.stat().st_size > MAX_FILE_SIZE:
continue
except OSError:
continue
files.append(filepath) files.append(filepath)
return files return files
+71
View File
@@ -0,0 +1,71 @@
"""
palace.py — Shared palace operations.
Consolidates ChromaDB access patterns used by both miners and the MCP server.
"""
import os
import chromadb
SKIP_DIRS = {
".git",
"node_modules",
"__pycache__",
".venv",
"venv",
"env",
"dist",
"build",
".next",
"coverage",
".mempalace",
".ruff_cache",
".mypy_cache",
".pytest_cache",
".cache",
".tox",
".nox",
".idea",
".vscode",
".ipynb_checkpoints",
".eggs",
"htmlcov",
"target",
}
def get_collection(palace_path: str, collection_name: str = "mempalace_drawers"):
"""Get or create the palace ChromaDB collection."""
os.makedirs(palace_path, exist_ok=True)
try:
os.chmod(palace_path, 0o700)
except (OSError, NotImplementedError):
pass
client = chromadb.PersistentClient(path=palace_path)
try:
return client.get_collection(collection_name)
except Exception:
return client.create_collection(collection_name)
def file_already_mined(collection, source_file: str, check_mtime: bool = False) -> bool:
"""Check if a file has already been filed in the palace.
When check_mtime=True (used by project miner), returns False if the file
has been modified since it was last mined, so it gets re-mined.
When check_mtime=False (used by convo miner), just checks existence.
"""
try:
results = collection.get(where={"source_file": source_file}, limit=1)
if not results.get("ids"):
return False
if check_mtime:
stored_meta = results.get("metadatas", [{}])[0]
stored_mtime = stored_meta.get("source_mtime")
if stored_mtime is None:
return False
current_mtime = os.path.getmtime(source_file)
return float(stored_mtime) == current_mtime
return True
except Exception:
return False
+1 -1
View File
@@ -26,7 +26,7 @@ classifiers = [
] ]
dependencies = [ dependencies = [
"chromadb>=0.5.0,<0.7", "chromadb>=0.5.0,<0.7",
"pyyaml>=6.0", "pyyaml>=6.0,<7",
] ]
[project.urls] [project.urls]
+53
View File
@@ -1,12 +1,14 @@
import os import os
import shutil import shutil
import tempfile import tempfile
import time
from pathlib import Path from pathlib import Path
import chromadb import chromadb
import yaml import yaml
from mempalace.miner import mine, scan_project from mempalace.miner import mine, scan_project
from mempalace.palace import file_already_mined
def write_file(path: Path, content: str): def write_file(path: Path, content: str):
@@ -206,3 +208,54 @@ def test_scan_project_skip_dirs_still_apply_without_override():
assert scanned_files(project_root, respect_gitignore=False) == ["main.py"] assert scanned_files(project_root, respect_gitignore=False) == ["main.py"]
finally: finally:
shutil.rmtree(tmpdir) shutil.rmtree(tmpdir)
def test_file_already_mined_check_mtime():
tmpdir = tempfile.mkdtemp()
try:
palace_path = os.path.join(tmpdir, "palace")
os.makedirs(palace_path)
client = chromadb.PersistentClient(path=palace_path)
col = client.get_or_create_collection("mempalace_drawers")
test_file = os.path.join(tmpdir, "test.txt")
with open(test_file, "w") as f:
f.write("hello world")
mtime = os.path.getmtime(test_file)
# Not mined yet
assert file_already_mined(col, test_file) is False
assert file_already_mined(col, test_file, check_mtime=True) is False
# Add it with mtime
col.add(
ids=["d1"],
documents=["hello world"],
metadatas=[{"source_file": test_file, "source_mtime": str(mtime)}],
)
# Already mined (no mtime check)
assert file_already_mined(col, test_file) is True
# Already mined (mtime matches)
assert file_already_mined(col, test_file, check_mtime=True) is True
# Modify file so mtime changes
time.sleep(0.1)
with open(test_file, "w") as f:
f.write("modified content")
# Still mined without mtime check
assert file_already_mined(col, test_file) is True
# Needs re-mining with mtime check
assert file_already_mined(col, test_file, check_mtime=True) is False
# Record with no mtime stored should return False for check_mtime
col.add(
ids=["d2"],
documents=["other"],
metadatas=[{"source_file": "/fake/no_mtime.txt"}],
)
assert file_already_mined(col, "/fake/no_mtime.txt", check_mtime=True) is False
finally:
shutil.rmtree(tmpdir)