345 lines
12 KiB
Python
345 lines
12 KiB
Python
"""
|
|
Hook logic for MemPalace — Python implementation of session-start, stop, and precompact hooks.

Reads JSON from stdin, outputs JSON to stdout.
Supported hooks: session-start, stop, precompact
Supported harnesses: claude-code, codex (extensible to cursor, gemini, etc.)
|
|
"""
|
|
|
|
import json
|
|
import os
|
|
import re
|
|
import subprocess
|
|
import sys
|
|
from datetime import datetime
|
|
from pathlib import Path
|
|
|
|
# Number of new human messages since the last recorded checkpoint that makes
# the stop hook block and request an auto-save.
SAVE_INTERVAL = 15
# Directory holding "hook.log" and the per-session "<session_id>_last_save" files.
STATE_DIR = Path.home() / ".mempalace" / "hook_state"
|
|
|
|
# "reason" text emitted together with {"decision": "block"} by the stop hook
# when an auto-save checkpoint is due.
STOP_BLOCK_REASON = (
    "AUTO-SAVE checkpoint (MemPalace). Save this session's key content:\n"
    "1. mempalace_diary_write — AAAK-compressed session summary\n"
    "2. mempalace_add_drawer — verbatim quotes, decisions, code snippets\n"
    "3. mempalace_kg_add — entity relationships (optional)\n"
    "Do NOT write to Claude Code's native auto-memory (.md files). "
    "Continue conversation after saving."
)
|
|
|
|
# Block-reason text intended for the precompact hook.
# NOTE(review): no function in the visible part of this file references this
# constant (hook_precompact emits {} rather than a block decision) — confirm
# whether it is still used elsewhere before removing or relying on it.
PRECOMPACT_BLOCK_REASON = (
    "COMPACTION IMMINENT (MemPalace). Save ALL session content before context is lost:\n"
    "1. mempalace_diary_write — thorough AAAK-compressed session summary\n"
    "2. mempalace_add_drawer — ALL verbatim quotes, decisions, code, context\n"
    "3. mempalace_kg_add — entity relationships (optional)\n"
    "Be thorough \u2014 after compaction, detailed context will be lost. "
    "Do NOT write to Claude Code's native auto-memory (.md files). "
    "Save everything to MemPalace, then allow compaction to proceed."
)
|
|
|
|
|
|
def _sanitize_session_id(session_id: str) -> str:
    """Reduce a session id to ASCII letters, digits, '-' and '_'.

    Every other character is dropped so the result is safe to embed in a
    file name (no separators, no dots, hence no path traversal). When
    nothing survives the filter, the placeholder "unknown" is returned.
    """
    # Collect the runs of permitted characters and glue them back together.
    kept = "".join(re.findall(r"[a-zA-Z0-9_-]+", session_id))
    if not kept:
        return "unknown"
    return kept
|
|
|
|
|
|
def _validate_transcript_path(transcript_path: str) -> "Path | None":
    """Validate and resolve a transcript path taken from the hook input.

    Args:
        transcript_path: Raw path string as supplied on stdin.

    Returns:
        The user-expanded, resolved ``Path`` when the input is acceptable,
        otherwise ``None``. This function never raises for a bad path.

    A path is rejected when it:
    - is empty or contains an embedded NUL character,
    - has a ``..`` component in the *original* input (this is checked on the
      raw string because ``resolve()`` collapses ``..`` away),
    - cannot be expanded or resolved (unknown ``~user``, symlink loop, ...),
    - does not have a ``.jsonl`` or ``.json`` suffix after resolution.

    Existence of the file is not checked here.
    """
    if not transcript_path or "\x00" in transcript_path:
        return None

    raw = Path(transcript_path)
    # Reject traversal components before touching the filesystem.
    if ".." in raw.parts:
        return None

    try:
        resolved = raw.expanduser().resolve()
    except (OSError, RuntimeError, ValueError):
        # expanduser() raises RuntimeError for an unresolvable home directory;
        # resolve() can raise on symlink loops or malformed paths. A validator
        # must answer "rejected" for these, not crash the hook.
        return None

    if resolved.suffix not in (".jsonl", ".json"):
        return None
    return resolved
|
|
|
|
|
|
def _count_human_messages(transcript_path: str) -> int:
    """Count human messages in a JSONL transcript, skipping command-messages.

    Args:
        transcript_path: Raw transcript path from the hook input. It is run
            through ``_validate_transcript_path`` first; a non-empty path that
            the validator rejects is logged.

    Returns:
        The number of lines recognised as human messages, or 0 when the path
        is rejected, is not a regular file, or cannot be read.

    Lines that are not valid JSON, or that decode to something other than a
    JSON object, are skipped.
    """
    path = _validate_transcript_path(transcript_path)
    if path is None:
        if transcript_path:
            _log(f"WARNING: transcript_path rejected by validator: {transcript_path!r}")
        return 0
    if not path.is_file():
        return 0

    count = 0
    try:
        # errors="replace" keeps one badly encoded line from aborting the scan.
        with open(path, encoding="utf-8", errors="replace") as f:
            for line in f:
                try:
                    entry = json.loads(line)
                except json.JSONDecodeError:
                    continue
                if _is_human_message_entry(entry):
                    count += 1
    except OSError:
        return 0
    return count


def _is_human_message_entry(entry) -> bool:
    """Return True when one decoded transcript line is a countable human message.

    Two shapes are recognised:
    - ``{"message": {"role": "user", "content": ...}}``
    - ``{"type": "event_msg", "payload": {"type": "user_message", "message": "..."}}``
      (the Codex CLI transcript format)

    In both shapes, text containing ``<command-message>`` is not counted.
    """
    if not isinstance(entry, dict):
        return False

    msg = entry.get("message", {})
    if isinstance(msg, dict) and msg.get("role") == "user":
        content = msg.get("content", "")
        if isinstance(content, str):
            return "<command-message>" not in content
        if isinstance(content, list):
            # Only string "text" fields are joined: a block with a null or
            # non-string "text" previously raised TypeError inside join().
            text = " ".join(
                b["text"] for b in content if isinstance(b, dict) and isinstance(b.get("text"), str)
            )
            return "<command-message>" not in text
        # Any other content type is still counted as a user message.
        return True

    if entry.get("type") == "event_msg":
        payload = entry.get("payload", {})
        if isinstance(payload, dict) and payload.get("type") == "user_message":
            msg_text = payload.get("message", "")
            return isinstance(msg_text, str) and "<command-message>" not in msg_text
    return False
|
|
|
|
|
|
# Flipped to True by _log() once STATE_DIR has been created (the chmod that
# follows is best effort), so later log calls in this process skip that setup.
_state_dir_initialized = False
|
|
|
|
|
|
def _log(message: str):
    """Append a timestamped line to ``STATE_DIR / "hook.log"`` (best effort).

    On the first call in a process the state directory is created and its
    mode set to 0o700; a log file created by this call gets mode 0o600.
    Permission changes that fail or are unsupported are ignored, and any
    ``OSError`` while creating the directory or writing is swallowed so that
    logging can never break a hook.

    Args:
        message: Text to record; it is prefixed with ``[HH:MM:SS]``.
    """
    global _state_dir_initialized
    try:
        if not _state_dir_initialized:
            STATE_DIR.mkdir(parents=True, exist_ok=True)
            try:
                STATE_DIR.chmod(0o700)
            except (OSError, NotImplementedError):
                pass
            _state_dir_initialized = True

        log_path = STATE_DIR / "hook.log"
        # Checked before opening: append mode would create the file.
        is_new = not log_path.exists()
        timestamp = datetime.now().strftime("%H:%M:%S")
        # Explicit UTF-8 with replacement: under a non-UTF-8 locale a message
        # with non-ASCII text would raise UnicodeEncodeError, which is not an
        # OSError and would escape this best-effort logger.
        with open(log_path, "a", encoding="utf-8", errors="replace") as f:
            f.write(f"[{timestamp}] {message}\n")

        if is_new:
            try:
                log_path.chmod(0o600)
            except (OSError, NotImplementedError):
                pass
    except OSError:
        pass
|
|
|
|
|
|
def _output(data: dict):
    """Serialize *data* as indented UTF-8 JSON and write it to the real stdout.

    No module is imported here. If an ``mcp_server`` module is already present
    in ``sys.modules`` and exposes a non-None ``_REAL_STDOUT_FD``, that file
    descriptor is used; otherwise descriptor 1 is written directly, so the
    response does not depend on what ``sys.stdout`` currently points at.
    Should the descriptor write fail with ``OSError``, the payload is written
    through ``sys.stdout.buffer`` instead.
    """
    encoded = (json.dumps(data, indent=2, ensure_ascii=False) + "\n").encode("utf-8")

    # Look the module up under its absolute name first, then under the name
    # derived from this module's own package (or the bare name when unpackaged).
    alternate_name = f"{__package__}.mcp_server" if __package__ else "mcp_server"
    server_module = sys.modules.get("mempalace.mcp_server") or sys.modules.get(alternate_name)

    target_fd = 1
    if server_module is not None:
        saved_fd = getattr(server_module, "_REAL_STDOUT_FD", None)
        if saved_fd is not None:
            target_fd = saved_fd

    written = 0
    total = len(encoded)
    try:
        # os.write may accept fewer bytes than offered; keep going until done.
        while written < total:
            try:
                written += os.write(target_fd, encoded[written:])
            except InterruptedError:
                continue
    except OSError:
        pass
    else:
        return

    # Descriptor write failed: fall back to the Python-level stream.
    sys.stdout.buffer.write(encoded)
    sys.stdout.buffer.flush()
|
|
|
|
|
|
def _get_mine_dir(transcript_path: str = "") -> str:
    """Determine the directory to mine.

    Args:
        transcript_path: Optional transcript path from the hook input.

    Returns:
        ``MEMPAL_DIR`` from the environment when it is set and names an
        existing directory; otherwise the parent directory of
        *transcript_path* when that path (after ``~`` expansion) is an
        existing file; otherwise the empty string.
    """
    mempal_dir = os.environ.get("MEMPAL_DIR", "")
    if mempal_dir and os.path.isdir(mempal_dir):
        return mempal_dir

    if not transcript_path:
        return ""

    try:
        path = Path(transcript_path).expanduser()
        if path.is_file():
            return str(path.parent)
    except (OSError, RuntimeError, ValueError):
        # expanduser() raises RuntimeError when "~user" cannot be resolved;
        # the path arrives unvalidated, so treat that as "nothing to mine"
        # instead of letting it crash the hook.
        return ""
    return ""
|
|
|
|
|
|
def _maybe_auto_ingest(transcript_path: str = ""):
    """Launch ``mempalace mine`` as a background child when a mine directory exists.

    The directory comes from ``_get_mine_dir``; when it is empty nothing is
    done. The child is started with the current interpreter, is not waited
    for, and has both stdout and stderr appended to the hook log. Any
    ``OSError`` while preparing the log or spawning the process is ignored.
    """
    target_dir = _get_mine_dir(transcript_path)
    if target_dir:
        command = [sys.executable, "-m", "mempalace", "mine", target_dir]
        try:
            STATE_DIR.mkdir(parents=True, exist_ok=True)
            # The child gets its own copy of the descriptor, so closing the
            # file when the with-block exits does not cut off its output.
            with open(STATE_DIR / "hook.log", "a") as sink:
                subprocess.Popen(command, stdout=sink, stderr=sink)
        except OSError:
            pass
|
|
|
|
|
|
def _mine_sync(transcript_path: str = ""):
    """Run ``mempalace mine`` synchronously (for precompact -- data must land first).

    The mine directory comes from ``_get_mine_dir``; when it is empty nothing
    is done. The child runs with the current interpreter, with stdout and
    stderr appended to the hook log, and is given at most 60 seconds.

    This stays best effort: a timeout, a failure to launch, or a non-zero
    exit status never raises, but each is recorded via ``_log`` so that a
    mine which did not land is visible afterwards.
    """
    mine_dir = _get_mine_dir(transcript_path)
    if not mine_dir:
        return

    timeout_seconds = 60
    try:
        STATE_DIR.mkdir(parents=True, exist_ok=True)
        log_path = STATE_DIR / "hook.log"
        with open(log_path, "a") as log_f:
            result = subprocess.run(
                [sys.executable, "-m", "mempalace", "mine", mine_dir],
                stdout=log_f,
                stderr=log_f,
                timeout=timeout_seconds,
            )
    except subprocess.TimeoutExpired:
        # subprocess.run() has already killed the child at this point.
        # "!a" keeps the logged text ASCII-only regardless of the path.
        _log(f"WARNING: mempalace mine timed out after {timeout_seconds}s for {mine_dir!a}")
        return
    except OSError as exc:
        _log(f"WARNING: could not run mempalace mine for {mine_dir!a}: {exc!a}")
        return

    if result.returncode != 0:
        _log(f"WARNING: mempalace mine exited with status {result.returncode} for {mine_dir!a}")
|
|
|
|
|
|
# Harness names accepted by _parse_harness_input(); any other name makes it
# print an error to stderr and exit with status 1.
SUPPORTED_HARNESSES = {"claude-code", "codex"}
|
|
|
|
|
|
def _parse_harness_input(data: dict, harness: str) -> dict:
    """Parse stdin JSON according to the harness type.

    Args:
        data: Decoded hook input. Anything that is not a dict is treated as
            an empty dict.
        harness: Harness name; must be a member of ``SUPPORTED_HARNESSES``.

    Returns:
        A dict with three keys:
        - ``"session_id"``: sanitized session id ("unknown" when absent or null),
        - ``"stop_hook_active"``: the raw value from the input (default False),
        - ``"transcript_path"``: the path as a string ("" when absent or null).

    Exits the process with status 1 (after a message on stderr) when the
    harness is not supported.
    """
    if harness not in SUPPORTED_HARNESSES:
        print(f"Unknown harness: {harness}", file=sys.stderr)
        sys.exit(1)

    if not isinstance(data, dict):
        data = {}

    # JSON null must not turn into the literal string "None": map it to the
    # same defaults as a missing key.
    raw_session_id = data.get("session_id")
    raw_transcript = data.get("transcript_path")
    return {
        "session_id": _sanitize_session_id(
            "unknown" if raw_session_id is None else str(raw_session_id)
        ),
        "stop_hook_active": data.get("stop_hook_active", False),
        "transcript_path": "" if raw_transcript is None else str(raw_transcript),
    }
|
|
|
|
|
|
def hook_stop(data: dict, harness: str):
    """Stop hook: block every ``SAVE_INTERVAL`` human messages for auto-save.

    Args:
        data: Decoded hook input (session id, ``stop_hook_active`` flag,
            transcript path).
        harness: Harness name, validated by ``_parse_harness_input``.

    Emits ``{"decision": "block", "reason": STOP_BLOCK_REASON}`` when at
    least ``SAVE_INTERVAL`` human messages were added since the checkpoint
    stored in ``STATE_DIR / "<session_id>_last_save"``; otherwise emits ``{}``.
    When a save is triggered the checkpoint file is updated and a background
    ingest is attempted.
    """
    parsed = _parse_harness_input(data, harness)
    session_id = parsed["session_id"]
    stop_hook_active = parsed["stop_hook_active"]
    transcript_path = parsed["transcript_path"]

    # If already in a block-mode save cycle, let through (infinite-loop
    # prevention). Per the original author's notes, silent mode has no such
    # loop to prevent, and Claude Code's plugin dispatch sets this flag on
    # every fire after the first, so honouring it in silent mode would
    # suppress all later auto-saves.
    if str(stop_hook_active).lower() in ("true", "1", "yes") and not _stop_guard_is_silent():
        _output({})
        return

    # Count human messages
    exchange_count = _count_human_messages(transcript_path)

    # Track last save point. State handling is best effort: an unwritable
    # state directory must not stop the hook from answering.
    try:
        STATE_DIR.mkdir(parents=True, exist_ok=True)
    except OSError:
        pass
    last_save_file = STATE_DIR / f"{session_id}_last_save"
    last_save = 0
    if last_save_file.is_file():
        try:
            last_save = int(last_save_file.read_text(encoding="utf-8").strip())
        except (ValueError, OSError):
            last_save = 0

    since_last = exchange_count - last_save

    _log(f"Session {session_id}: {exchange_count} exchanges, {since_last} since last save")

    if since_last >= SAVE_INTERVAL and exchange_count > 0:
        # Update last save point
        try:
            last_save_file.write_text(str(exchange_count), encoding="utf-8")
        except OSError:
            pass

        _log(f"TRIGGERING SAVE at exchange {exchange_count}")

        # Background ingest; a no-op when no mine directory can be determined.
        _maybe_auto_ingest(transcript_path)

        _output({"decision": "block", "reason": STOP_BLOCK_REASON})
    else:
        _output({})


def _stop_guard_is_silent() -> bool:
    """Return the configured ``hook_silent_save`` value, or True on any failure.

    True is the safe default: when the configuration cannot be imported or
    read, saves proceed instead of being silently dropped.
    """
    try:
        from .config import MempalaceConfig
    except ImportError as exc:
        _log(
            f"WARNING: could not import MempalaceConfig for stop guard: {exc}; defaulting to silent mode"
        )
        return True
    try:
        return MempalaceConfig().hook_silent_save
    except Exception as exc:
        # Deliberately broad: the stated policy is "any config-read failure",
        # which includes unreadable or corrupt config, not just a missing
        # attribute. The failure is logged, never hidden.
        _log(f"WARNING: could not read hook_silent_save: {exc}; defaulting to silent mode")
        return True
|
|
|
|
|
|
def hook_session_start(data: dict, harness: str):
    """Session start hook: log the start, ensure the state directory, pass through.

    Args:
        data: Decoded hook input; only the session id is used.
        harness: Harness name, validated by ``_parse_harness_input``.

    Always emits ``{}`` — session start never blocks.
    """
    parsed = _parse_harness_input(data, harness)
    session_id = parsed["session_id"]

    _log(f"SESSION START for session {session_id}")

    # Initialize session state directory. Best effort, like _log(): a failure
    # here must not prevent the pass-through response below.
    try:
        STATE_DIR.mkdir(parents=True, exist_ok=True)
    except OSError:
        pass

    # Pass through — no blocking on session start
    _output({})
|
|
|
|
|
|
def hook_precompact(data: dict, harness: str):
    """Precompact hook: run the mine step to completion, then let compaction continue.

    The hook input is parsed, the trigger is logged with the session id, the
    transcript location is handed to the synchronous miner, and finally an
    empty JSON object is emitted (no block decision).
    """
    fields = _parse_harness_input(data, harness)

    _log(f"PRE-COMPACT triggered for session {fields['session_id']}")

    # Blocking call on purpose: the response is only written after it returns.
    _mine_sync(fields["transcript_path"])

    _output({})
|
|
|
|
|
|
def run_hook(hook_name: str, harness: str):
    """Main entry point: read stdin JSON, dispatch to hook handler.

    Args:
        hook_name: One of "session-start", "stop", "precompact".
        harness: Harness name passed through to the handler.

    Input that cannot be decoded, or that decodes to something other than a
    JSON object, is logged and replaced by an empty dict so the handler still
    runs. An unknown hook name prints an error to stderr and exits with
    status 1.
    """
    try:
        data = json.load(sys.stdin)
    except (json.JSONDecodeError, UnicodeDecodeError, EOFError):
        _log("WARNING: Failed to parse stdin JSON, proceeding with empty data")
        data = {}

    # Valid JSON is not necessarily an object; handlers call data.get(...).
    if not isinstance(data, dict):
        _log("WARNING: stdin JSON is not an object, proceeding with empty data")
        data = {}

    hooks = {
        "session-start": hook_session_start,
        "stop": hook_stop,
        "precompact": hook_precompact,
    }

    handler = hooks.get(hook_name)
    if handler is None:
        print(f"Unknown hook: {hook_name}", file=sys.stderr)
        sys.exit(1)

    handler(data, harness)
|