fix: address code review — restore mtime check, bound metadata reads, harden security

Review fixes (from Sage's review):
- Restore mtime check in file_already_mined (check_mtime=True for miner)
- Restore limit=10000 on MCP metadata fetches to prevent OOM on large palaces
- Apply _SAFE_NAME_RE regex in sanitize_name (was dead code)
- Drop raw_aaak metadata duplication in diary_write
- chmod 0o700 on WAL dir, 0o600 on WAL file
- Add check_same_thread=False on KnowledgeGraph SQLite connection
- Remove __del__ (unreliable) and dead PRAGMA foreign_keys=ON
This commit is contained in:
bensig
2026-04-09 08:52:24 -07:00
parent 0717caea5c
commit c2308a1e36
5 changed files with 35 additions and 14 deletions
+4
View File
@@ -40,6 +40,10 @@ def sanitize_name(value: str, field_name: str = "name") -> str:
if "\x00" in value: if "\x00" in value:
raise ValueError(f"{field_name} contains null bytes") raise ValueError(f"{field_name} contains null bytes")
# Enforce safe character set
if not _SAFE_NAME_RE.match(value):
raise ValueError(f"{field_name} contains invalid characters")
return value return value
+1 -5
View File
@@ -57,7 +57,6 @@ class KnowledgeGraph:
conn = self._conn() conn = self._conn()
conn.executescript(""" conn.executescript("""
PRAGMA journal_mode=WAL; PRAGMA journal_mode=WAL;
PRAGMA foreign_keys=ON;
CREATE TABLE IF NOT EXISTS entities ( CREATE TABLE IF NOT EXISTS entities (
id TEXT PRIMARY KEY, id TEXT PRIMARY KEY,
@@ -91,7 +90,7 @@ class KnowledgeGraph:
def _conn(self): def _conn(self):
if self._connection is None: if self._connection is None:
self._connection = sqlite3.connect(self.db_path, timeout=10) self._connection = sqlite3.connect(self.db_path, timeout=10, check_same_thread=False)
self._connection.execute("PRAGMA journal_mode=WAL") self._connection.execute("PRAGMA journal_mode=WAL")
self._connection.row_factory = sqlite3.Row self._connection.row_factory = sqlite3.Row
return self._connection return self._connection
@@ -102,9 +101,6 @@ class KnowledgeGraph:
self._connection.close() self._connection.close()
self._connection = None self._connection = None
def __del__(self):
self.close()
def _entity_id(self, name: str) -> str: def _entity_id(self, name: str) -> str:
return name.lower().replace(" ", "_").replace("'", "") return name.lower().replace(" ", "_").replace("'", "")
+12 -5
View File
@@ -74,6 +74,10 @@ _collection_cache = None
_WAL_DIR = Path(os.path.expanduser("~/.mempalace/wal")) _WAL_DIR = Path(os.path.expanduser("~/.mempalace/wal"))
_WAL_DIR.mkdir(parents=True, exist_ok=True) _WAL_DIR.mkdir(parents=True, exist_ok=True)
try:
_WAL_DIR.chmod(0o700)
except (OSError, NotImplementedError):
pass
_WAL_FILE = _WAL_DIR / "write_log.jsonl" _WAL_FILE = _WAL_DIR / "write_log.jsonl"
@@ -88,6 +92,10 @@ def _wal_log(operation: str, params: dict, result: dict = None):
try: try:
with open(_WAL_FILE, "a", encoding="utf-8") as f: with open(_WAL_FILE, "a", encoding="utf-8") as f:
f.write(json.dumps(entry, default=str) + "\n") f.write(json.dumps(entry, default=str) + "\n")
try:
_WAL_FILE.chmod(0o600)
except (OSError, NotImplementedError):
pass
except Exception as e: except Exception as e:
logger.error(f"WAL write failed: {e}") logger.error(f"WAL write failed: {e}")
@@ -136,7 +144,7 @@ def tool_status():
wings = {} wings = {}
rooms = {} rooms = {}
try: try:
all_meta = col.get(include=["metadatas"])["metadatas"] all_meta = col.get(include=["metadatas"], limit=10000)["metadatas"]
for m in all_meta: for m in all_meta:
w = m.get("wing", "unknown") w = m.get("wing", "unknown")
r = m.get("room", "unknown") r = m.get("room", "unknown")
@@ -193,7 +201,7 @@ def tool_list_wings():
return _no_palace() return _no_palace()
wings = {} wings = {}
try: try:
all_meta = col.get(include=["metadatas"])["metadatas"] all_meta = col.get(include=["metadatas"], limit=10000)["metadatas"]
for m in all_meta: for m in all_meta:
w = m.get("wing", "unknown") w = m.get("wing", "unknown")
wings[w] = wings.get(w, 0) + 1 wings[w] = wings.get(w, 0) + 1
@@ -208,7 +216,7 @@ def tool_list_rooms(wing: str = None):
return _no_palace() return _no_palace()
rooms = {} rooms = {}
try: try:
kwargs = {"include": ["metadatas"]} kwargs = {"include": ["metadatas"], "limit": 10000}
if wing: if wing:
kwargs["where"] = {"wing": wing} kwargs["where"] = {"wing": wing}
all_meta = col.get(**kwargs)["metadatas"] all_meta = col.get(**kwargs)["metadatas"]
@@ -226,7 +234,7 @@ def tool_get_taxonomy():
return _no_palace() return _no_palace()
taxonomy = {} taxonomy = {}
try: try:
all_meta = col.get(include=["metadatas"])["metadatas"] all_meta = col.get(include=["metadatas"], limit=10000)["metadatas"]
for m in all_meta: for m in all_meta:
w = m.get("wing", "unknown") w = m.get("wing", "unknown")
r = m.get("room", "unknown") r = m.get("room", "unknown")
@@ -517,7 +525,6 @@ def tool_diary_write(agent_name: str, entry: str, topic: str = "general"):
"agent": agent_name, "agent": agent_name,
"filed_at": now.isoformat(), "filed_at": now.isoformat(),
"date": now.strftime("%Y-%m-%d"), "date": now.strftime("%Y-%m-%d"),
"raw_aaak": entry,
} }
], ],
) )
+1 -1
View File
@@ -417,7 +417,7 @@ def process_file(
# Skip if already filed # Skip if already filed
source_file = str(filepath) source_file = str(filepath)
if not dry_run and file_already_mined(collection, source_file): if not dry_run and file_already_mined(collection, source_file, check_mtime=True):
return 0, None return 0, None
try: try:
+17 -3
View File
@@ -48,10 +48,24 @@ def get_collection(palace_path: str, collection_name: str = "mempalace_drawers")
return client.create_collection(collection_name) return client.create_collection(collection_name)
def file_already_mined(collection, source_file: str) -> bool: def file_already_mined(collection, source_file: str, check_mtime: bool = False) -> bool:
"""Check if a file has already been filed in the palace.""" """Check if a file has already been filed in the palace.
When check_mtime=True (used by project miner), returns False if the file
has been modified since it was last mined, so it gets re-mined.
When check_mtime=False (used by convo miner), just checks existence.
"""
try: try:
results = collection.get(where={"source_file": source_file}, limit=1) results = collection.get(where={"source_file": source_file}, limit=1)
return len(results.get("ids", [])) > 0 if not results.get("ids"):
return False
if check_mtime:
stored_meta = results.get("metadatas", [{}])[0]
stored_mtime = stored_meta.get("source_mtime")
if stored_mtime is None:
return False
current_mtime = os.path.getmtime(source_file)
return float(stored_mtime) == current_mtime
return True
except Exception: except Exception:
return False return False