From 58eca5075a37ebbe8fe2790ac629e5dbc2670cef Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Arturo=20Dom=C3=ADnguez?= Date: Sat, 11 Apr 2026 21:44:17 -0600 Subject: [PATCH] Security hardening: consistent input validation, argument whitelisting, concurrency safety, and WAL fixes (#647) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Addresses findings from security audit (ref #401): inconsistent sanitization across MCP tools, unfiltered argument dispatch allowing audit trail spoofing, SQLite concurrency issues, WAL permission race condition, sensitive content in logs, and internal error leakage to MCP callers. Co-authored-by: Israel Domínguez Co-authored-by: Claude Opus 4.6 (1M context) --- mempalace/knowledge_graph.py | 172 ++++++++++++++++++----------------- mempalace/mcp_server.py | 81 ++++++++++++++--- 2 files changed, 156 insertions(+), 97 deletions(-) diff --git a/mempalace/knowledge_graph.py b/mempalace/knowledge_graph.py index b094f06..6ede054 100644 --- a/mempalace/knowledge_graph.py +++ b/mempalace/knowledge_graph.py @@ -39,6 +39,7 @@ import hashlib import json import os import sqlite3 +import threading from datetime import date, datetime from pathlib import Path @@ -51,6 +52,7 @@ class KnowledgeGraph: self.db_path = db_path or DEFAULT_KG_PATH Path(self.db_path).parent.mkdir(parents=True, exist_ok=True) self._connection = None + self._lock = threading.Lock() self._init_db() def _init_db(self): @@ -110,12 +112,13 @@ class KnowledgeGraph: """Add or update an entity node.""" eid = self._entity_id(name) props = json.dumps(properties or {}) - conn = self._conn() - with conn: - conn.execute( - "INSERT OR REPLACE INTO entities (id, name, type, properties) VALUES (?, ?, ?, ?)", - (eid, name, entity_type, props), - ) + with self._lock: + conn = self._conn() + with conn: + conn.execute( + "INSERT OR REPLACE INTO entities (id, name, type, properties) VALUES (?, ?, ?, ?)", + (eid, name, entity_type, props), + ) return eid def add_triple( @@ -142,39 +145,42 @@ class KnowledgeGraph: pred = predicate.lower().replace(" ", "_") # Auto-create entities if they don't exist - conn = self._conn() - with conn: - conn.execute( - "INSERT OR IGNORE INTO entities (id, name) VALUES (?, ?)", (sub_id, subject) - ) - conn.execute("INSERT OR IGNORE INTO entities (id, name) VALUES (?, ?)", (obj_id, obj)) + with self._lock: + conn = self._conn() + with conn: + conn.execute( + "INSERT OR IGNORE INTO entities (id, name) VALUES (?, ?)", (sub_id, subject) + ) + conn.execute( + "INSERT OR IGNORE INTO entities (id, name) VALUES (?, ?)", (obj_id, obj) + ) - # Check for existing identical triple - existing = conn.execute( - "SELECT id FROM triples WHERE subject=? AND predicate=? AND object=? AND valid_to IS NULL", - (sub_id, pred, obj_id), - ).fetchone() + # Check for existing identical triple + existing = conn.execute( + "SELECT id FROM triples WHERE subject=? AND predicate=? AND object=? AND valid_to IS NULL", + (sub_id, pred, obj_id), + ).fetchone() - if existing: - return existing["id"] # Already exists and still valid + if existing: + return existing["id"] # Already exists and still valid - triple_id = f"t_{sub_id}_{pred}_{obj_id}_{hashlib.sha256(f'{valid_from}{datetime.now().isoformat()}'.encode()).hexdigest()[:12]}" + triple_id = f"t_{sub_id}_{pred}_{obj_id}_{hashlib.sha256(f'{valid_from}{datetime.now().isoformat()}'.encode()).hexdigest()[:12]}" - conn.execute( - """INSERT INTO triples (id, subject, predicate, object, valid_from, valid_to, confidence, source_closet, source_file) - VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)""", - ( - triple_id, - sub_id, - pred, - obj_id, - valid_from, - valid_to, - confidence, - source_closet, - source_file, - ), - ) + conn.execute( + """INSERT INTO triples (id, subject, predicate, object, valid_from, valid_to, confidence, source_closet, source_file) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)""", + ( + triple_id, + sub_id, + pred, + obj_id, + valid_from, + valid_to, + confidence, + source_closet, + source_file, + ), + ) return triple_id def invalidate(self, subject: str, predicate: str, obj: str, ended: str = None): @@ -184,12 +190,13 @@ class KnowledgeGraph: pred = predicate.lower().replace(" ", "_") ended = ended or date.today().isoformat() - conn = self._conn() - with conn: - conn.execute( - "UPDATE triples SET valid_to=? WHERE subject=? AND predicate=? AND object=? AND valid_to IS NULL", - (ended, sub_id, pred, obj_id), - ) + with self._lock: + conn = self._conn() + with conn: + conn.execute( + "UPDATE triples SET valid_to=? WHERE subject=? AND predicate=? AND object=? AND valid_to IS NULL", + (ended, sub_id, pred, obj_id), + ) # ── Query operations ────────────────────────────────────────────────── @@ -201,51 +208,52 @@ class KnowledgeGraph: as_of: date string — only return facts valid at that time """ eid = self._entity_id(name) - conn = self._conn() results = [] + with self._lock: + conn = self._conn() - if direction in ("outgoing", "both"): - query = "SELECT t.*, e.name as obj_name FROM triples t JOIN entities e ON t.object = e.id WHERE t.subject = ?" - params = [eid] - if as_of: - query += " AND (t.valid_from IS NULL OR t.valid_from <= ?) AND (t.valid_to IS NULL OR t.valid_to >= ?)" - params.extend([as_of, as_of]) - for row in conn.execute(query, params).fetchall(): - results.append( - { - "direction": "outgoing", - "subject": name, - "predicate": row["predicate"], - "object": row["obj_name"], - "valid_from": row["valid_from"], - "valid_to": row["valid_to"], - "confidence": row["confidence"], - "source_closet": row["source_closet"], - "current": row["valid_to"] is None, - } - ) + if direction in ("outgoing", "both"): + query = "SELECT t.*, e.name as obj_name FROM triples t JOIN entities e ON t.object = e.id WHERE t.subject = ?" + params = [eid] + if as_of: + query += " AND (t.valid_from IS NULL OR t.valid_from <= ?) AND (t.valid_to IS NULL OR t.valid_to >= ?)" + params.extend([as_of, as_of]) + for row in conn.execute(query, params).fetchall(): + results.append( + { + "direction": "outgoing", + "subject": name, + "predicate": row["predicate"], + "object": row["obj_name"], + "valid_from": row["valid_from"], + "valid_to": row["valid_to"], + "confidence": row["confidence"], + "source_closet": row["source_closet"], + "current": row["valid_to"] is None, + } + ) - if direction in ("incoming", "both"): - query = "SELECT t.*, e.name as sub_name FROM triples t JOIN entities e ON t.subject = e.id WHERE t.object = ?" - params = [eid] - if as_of: - query += " AND (t.valid_from IS NULL OR t.valid_from <= ?) AND (t.valid_to IS NULL OR t.valid_to >= ?)" - params.extend([as_of, as_of]) - for row in conn.execute(query, params).fetchall(): - results.append( - { - "direction": "incoming", - "subject": row["sub_name"], - "predicate": row["predicate"], - "object": name, - "valid_from": row["valid_from"], - "valid_to": row["valid_to"], - "confidence": row["confidence"], - "source_closet": row["source_closet"], - "current": row["valid_to"] is None, - } - ) + if direction in ("incoming", "both"): + query = "SELECT t.*, e.name as sub_name FROM triples t JOIN entities e ON t.subject = e.id WHERE t.object = ?" + params = [eid] + if as_of: + query += " AND (t.valid_from IS NULL OR t.valid_from <= ?) AND (t.valid_to IS NULL OR t.valid_to >= ?)" + params.extend([as_of, as_of]) + for row in conn.execute(query, params).fetchall(): + results.append( + { + "direction": "incoming", + "subject": row["sub_name"], + "predicate": row["predicate"], + "object": name, + "valid_from": row["valid_from"], + "valid_to": row["valid_to"], + "confidence": row["confidence"], + "source_closet": row["source_closet"], + "current": row["valid_to"] is None, + } + ) return results diff --git a/mempalace/mcp_server.py b/mempalace/mcp_server.py index bf8a281..4771b63 100644 --- a/mempalace/mcp_server.py +++ b/mempalace/mcp_server.py @@ -79,23 +79,38 @@ try: except (OSError, NotImplementedError): pass _WAL_FILE = _WAL_DIR / "write_log.jsonl" +# Pre-create WAL file with restricted permissions to avoid race condition +if not _WAL_FILE.exists(): + _WAL_FILE.touch(mode=0o600) +else: + try: + _WAL_FILE.chmod(0o600) + except (OSError, NotImplementedError): + pass + +# Keys whose values should be redacted in WAL entries to avoid logging sensitive content +_WAL_REDACT_KEYS = frozenset({"content_preview", "entry_preview"}) def _wal_log(operation: str, params: dict, result: dict = None): """Append a write operation to the write-ahead log.""" + # Redact sensitive content from params before logging + safe_params = {} + for k, v in params.items(): + if k in _WAL_REDACT_KEYS: + safe_params[k] = f"[REDACTED {len(v)} chars]" if isinstance(v, str) else "[REDACTED]" + else: + safe_params[k] = v entry = { "timestamp": datetime.now().isoformat(), "operation": operation, - "params": params, + "params": safe_params, "result": result, } try: - with open(_WAL_FILE, "a", encoding="utf-8") as f: + fd = os.open(str(_WAL_FILE), os.O_WRONLY | os.O_APPEND | os.O_CREAT, 0o600) + with os.fdopen(fd, "a", encoding="utf-8") as f: f.write(json.dumps(entry, default=str) + "\n") - try: - _WAL_FILE.chmod(0o600) - except (OSError, NotImplementedError): - pass except Exception as e: logger.error(f"WAL write failed: {e}") @@ -298,6 +313,7 @@ def tool_get_taxonomy(): def tool_search( query: str, limit: int = 5, wing: str = None, room: str = None, context: str = None ): + limit = max(1, min(limit, 50)) # Mitigate system prompt contamination (Issue #333) sanitized = sanitize_query(query) result = search_memories( @@ -352,8 +368,9 @@ def tool_check_duplicate(content: str, threshold: float = 0.9): "is_duplicate": len(duplicates) > 0, "matches": duplicates, } - except Exception as e: - return {"error": str(e)} + except Exception: + logger.exception("check_duplicate failed") + return {"error": "Duplicate check failed"} def tool_get_aaak_spec(): @@ -363,6 +380,7 @@ def tool_get_aaak_spec(): def tool_traverse_graph(start_room: str, max_hops: int = 2): """Walk the palace graph from a room. Find connected ideas across wings.""" + max_hops = max(1, min(max_hops, 10)) col = _get_collection() if not col: return _no_palace() @@ -480,6 +498,12 @@ def tool_delete_drawer(drawer_id: str): def tool_kg_query(entity: str, as_of: str = None, direction: str = "both"): """Query the knowledge graph for an entity's relationships.""" + try: + entity = sanitize_name(entity, "entity") + except ValueError as e: + return {"error": str(e)} + if direction not in ("outgoing", "incoming", "both"): + return {"error": "direction must be 'outgoing', 'incoming', or 'both'"} results = _kg.query_entity(entity, as_of=as_of, direction=direction) return {"entity": entity, "as_of": as_of, "facts": results, "count": len(results)} @@ -513,6 +537,12 @@ def tool_kg_add( def tool_kg_invalidate(subject: str, predicate: str, object: str, ended: str = None): """Mark a fact as no longer true (set end date).""" + try: + subject = sanitize_name(subject, "subject") + predicate = sanitize_name(predicate, "predicate") + object = sanitize_name(object, "object") + except ValueError as e: + return {"success": False, "error": str(e)} _wal_log( "kg_invalidate", {"subject": subject, "predicate": predicate, "object": object, "ended": ended}, @@ -527,6 +557,11 @@ def tool_kg_invalidate(subject: str, predicate: str, object: str, ended: str = N def tool_kg_timeline(entity: str = None): """Get chronological timeline of facts, optionally for one entity.""" + if entity is not None: + try: + entity = sanitize_name(entity, "entity") + except ValueError as e: + return {"error": str(e)} results = _kg.timeline(entity) return {"entity": entity or "all", "timeline": results, "count": len(results)} @@ -610,6 +645,11 @@ def tool_diary_read(agent_name: str, last_n: int = 10): Read an agent's recent diary entries. Returns the last N entries in chronological order — the agent's personal journal. """ + try: + agent_name = sanitize_name(agent_name, "agent_name") + except ValueError as e: + return {"error": str(e)} + last_n = max(1, min(last_n, 100)) wing = f"wing_{agent_name.lower().replace(' ', '_')}" col = _get_collection() if not col: @@ -646,8 +686,9 @@ def tool_diary_read(agent_name: str, last_n: int = 10): "total": len(results["ids"]), "showing": len(entries), } - except Exception as e: - return {"error": str(e)} + except Exception: + logger.exception("diary_read failed") + return {"error": "Failed to read diary entries"} # ==================== MCP PROTOCOL ==================== @@ -963,17 +1004,27 @@ def handle_request(request): "id": req_id, "error": {"code": -32601, "message": f"Unknown tool: {tool_name}"}, } + # Whitelist arguments to declared schema properties only. + # Prevents callers from spoofing internal params like added_by/source_file. + schema_props = TOOLS[tool_name]["input_schema"].get("properties", {}) + tool_args = {k: v for k, v in tool_args.items() if k in schema_props} # Coerce argument types based on input_schema. # MCP JSON transport may deliver integers as floats or strings; # ChromaDB and Python slicing require native int. - schema_props = TOOLS[tool_name]["input_schema"].get("properties", {}) for key, value in list(tool_args.items()): prop_schema = schema_props.get(key, {}) declared_type = prop_schema.get("type") - if declared_type == "integer" and not isinstance(value, int): - tool_args[key] = int(value) - elif declared_type == "number" and not isinstance(value, (int, float)): - tool_args[key] = float(value) + try: + if declared_type == "integer" and not isinstance(value, int): + tool_args[key] = int(value) + elif declared_type == "number" and not isinstance(value, (int, float)): + tool_args[key] = float(value) + except (ValueError, TypeError): + return { + "jsonrpc": "2.0", + "id": req_id, + "error": {"code": -32602, "message": f"Invalid value for parameter '{key}'"}, + } try: result = TOOLS[tool_name]["handler"](**tool_args) return {