193 lines
5.7 KiB
Python
193 lines
5.7 KiB
Python
|
|
import contextlib
|
||
|
|
import json
|
||
|
|
from pathlib import Path
|
||
|
|
from unittest.mock import patch
|
||
|
|
|
||
|
|
from mempalace.hooks_cli import (
|
||
|
|
SAVE_INTERVAL,
|
||
|
|
STOP_BLOCK_REASON,
|
||
|
|
PRECOMPACT_BLOCK_REASON,
|
||
|
|
_count_human_messages,
|
||
|
|
_sanitize_session_id,
|
||
|
|
hook_stop,
|
||
|
|
hook_session_start,
|
||
|
|
hook_precompact,
|
||
|
|
)
|
||
|
|
|
||
|
|
|
||
|
|
# --- _sanitize_session_id ---
|
||
|
|
|
||
|
|
|
||
|
|
def test_sanitize_normal_id():
|
||
|
|
assert _sanitize_session_id("abc-123_XYZ") == "abc-123_XYZ"
|
||
|
|
|
||
|
|
|
||
|
|
def test_sanitize_strips_dangerous_chars():
|
||
|
|
assert _sanitize_session_id("../../etc/passwd") == "etcpasswd"
|
||
|
|
|
||
|
|
|
||
|
|
def test_sanitize_empty_returns_unknown():
|
||
|
|
assert _sanitize_session_id("") == "unknown"
|
||
|
|
assert _sanitize_session_id("!!!") == "unknown"
|
||
|
|
|
||
|
|
|
||
|
|
# --- _count_human_messages ---
|
||
|
|
|
||
|
|
|
||
|
|
def _write_transcript(path: Path, entries: list[dict]):
|
||
|
|
with open(path, "w", encoding="utf-8") as f:
|
||
|
|
for entry in entries:
|
||
|
|
f.write(json.dumps(entry) + "\n")
|
||
|
|
|
||
|
|
|
||
|
|
def test_count_human_messages_basic(tmp_path):
|
||
|
|
transcript = tmp_path / "t.jsonl"
|
||
|
|
_write_transcript(transcript, [
|
||
|
|
{"message": {"role": "user", "content": "hello"}},
|
||
|
|
{"message": {"role": "assistant", "content": "hi"}},
|
||
|
|
{"message": {"role": "user", "content": "bye"}},
|
||
|
|
])
|
||
|
|
assert _count_human_messages(str(transcript)) == 2
|
||
|
|
|
||
|
|
|
||
|
|
def test_count_skips_command_messages(tmp_path):
|
||
|
|
transcript = tmp_path / "t.jsonl"
|
||
|
|
_write_transcript(transcript, [
|
||
|
|
{"message": {"role": "user", "content": "<command-message>status</command-message>"}},
|
||
|
|
{"message": {"role": "user", "content": "real question"}},
|
||
|
|
])
|
||
|
|
assert _count_human_messages(str(transcript)) == 1
|
||
|
|
|
||
|
|
|
||
|
|
def test_count_handles_list_content(tmp_path):
|
||
|
|
transcript = tmp_path / "t.jsonl"
|
||
|
|
_write_transcript(transcript, [
|
||
|
|
{"message": {"role": "user", "content": [{"type": "text", "text": "hello"}]}},
|
||
|
|
{"message": {"role": "user", "content": [{"type": "text", "text": "<command-message>x</command-message>"}]}},
|
||
|
|
])
|
||
|
|
assert _count_human_messages(str(transcript)) == 1
|
||
|
|
|
||
|
|
|
||
|
|
def test_count_missing_file():
|
||
|
|
assert _count_human_messages("/nonexistent/path.jsonl") == 0
|
||
|
|
|
||
|
|
|
||
|
|
def test_count_empty_file(tmp_path):
|
||
|
|
transcript = tmp_path / "t.jsonl"
|
||
|
|
transcript.write_text("")
|
||
|
|
assert _count_human_messages(str(transcript)) == 0
|
||
|
|
|
||
|
|
|
||
|
|
def test_count_malformed_json_lines(tmp_path):
|
||
|
|
transcript = tmp_path / "t.jsonl"
|
||
|
|
transcript.write_text('not json\n{"message": {"role": "user", "content": "ok"}}\n')
|
||
|
|
assert _count_human_messages(str(transcript)) == 1
|
||
|
|
|
||
|
|
|
||
|
|
# --- hook_stop ---
|
||
|
|
|
||
|
|
|
||
|
|
def _capture_hook_output(hook_fn, data, harness="claude-code", state_dir=None):
|
||
|
|
"""Run a hook and capture its JSON stdout output."""
|
||
|
|
import io
|
||
|
|
buf = io.StringIO()
|
||
|
|
patches = [patch("mempalace.hooks_cli._output", side_effect=lambda d: buf.write(json.dumps(d)))]
|
||
|
|
if state_dir:
|
||
|
|
patches.append(patch("mempalace.hooks_cli.STATE_DIR", state_dir))
|
||
|
|
with contextlib.ExitStack() as stack:
|
||
|
|
for p in patches:
|
||
|
|
stack.enter_context(p)
|
||
|
|
hook_fn(data, harness)
|
||
|
|
return json.loads(buf.getvalue())
|
||
|
|
|
||
|
|
|
||
|
|
def test_stop_hook_passthrough_when_active(tmp_path):
|
||
|
|
with patch("mempalace.hooks_cli.STATE_DIR", tmp_path):
|
||
|
|
result = _capture_hook_output(
|
||
|
|
hook_stop,
|
||
|
|
{"session_id": "test", "stop_hook_active": True, "transcript_path": ""},
|
||
|
|
state_dir=tmp_path,
|
||
|
|
)
|
||
|
|
assert result == {}
|
||
|
|
|
||
|
|
|
||
|
|
def test_stop_hook_passthrough_when_active_string(tmp_path):
|
||
|
|
with patch("mempalace.hooks_cli.STATE_DIR", tmp_path):
|
||
|
|
result = _capture_hook_output(
|
||
|
|
hook_stop,
|
||
|
|
{"session_id": "test", "stop_hook_active": "true", "transcript_path": ""},
|
||
|
|
state_dir=tmp_path,
|
||
|
|
)
|
||
|
|
assert result == {}
|
||
|
|
|
||
|
|
|
||
|
|
def test_stop_hook_passthrough_below_interval(tmp_path):
|
||
|
|
transcript = tmp_path / "t.jsonl"
|
||
|
|
_write_transcript(transcript, [
|
||
|
|
{"message": {"role": "user", "content": f"msg {i}"}}
|
||
|
|
for i in range(SAVE_INTERVAL - 1)
|
||
|
|
])
|
||
|
|
result = _capture_hook_output(
|
||
|
|
hook_stop,
|
||
|
|
{"session_id": "test", "stop_hook_active": False, "transcript_path": str(transcript)},
|
||
|
|
state_dir=tmp_path,
|
||
|
|
)
|
||
|
|
assert result == {}
|
||
|
|
|
||
|
|
|
||
|
|
def test_stop_hook_blocks_at_interval(tmp_path):
|
||
|
|
transcript = tmp_path / "t.jsonl"
|
||
|
|
_write_transcript(transcript, [
|
||
|
|
{"message": {"role": "user", "content": f"msg {i}"}}
|
||
|
|
for i in range(SAVE_INTERVAL)
|
||
|
|
])
|
||
|
|
result = _capture_hook_output(
|
||
|
|
hook_stop,
|
||
|
|
{"session_id": "test", "stop_hook_active": False, "transcript_path": str(transcript)},
|
||
|
|
state_dir=tmp_path,
|
||
|
|
)
|
||
|
|
assert result["decision"] == "block"
|
||
|
|
assert result["reason"] == STOP_BLOCK_REASON
|
||
|
|
|
||
|
|
|
||
|
|
def test_stop_hook_tracks_save_point(tmp_path):
|
||
|
|
transcript = tmp_path / "t.jsonl"
|
||
|
|
_write_transcript(transcript, [
|
||
|
|
{"message": {"role": "user", "content": f"msg {i}"}}
|
||
|
|
for i in range(SAVE_INTERVAL)
|
||
|
|
])
|
||
|
|
data = {"session_id": "test", "stop_hook_active": False, "transcript_path": str(transcript)}
|
||
|
|
|
||
|
|
# First call blocks
|
||
|
|
result = _capture_hook_output(hook_stop, data, state_dir=tmp_path)
|
||
|
|
assert result["decision"] == "block"
|
||
|
|
|
||
|
|
# Second call with same count passes through (already saved)
|
||
|
|
result = _capture_hook_output(hook_stop, data, state_dir=tmp_path)
|
||
|
|
assert result == {}
|
||
|
|
|
||
|
|
|
||
|
|
# --- hook_session_start ---
|
||
|
|
|
||
|
|
|
||
|
|
def test_session_start_passes_through(tmp_path):
|
||
|
|
result = _capture_hook_output(
|
||
|
|
hook_session_start,
|
||
|
|
{"session_id": "test"},
|
||
|
|
state_dir=tmp_path,
|
||
|
|
)
|
||
|
|
assert result == {}
|
||
|
|
|
||
|
|
|
||
|
|
# --- hook_precompact ---
|
||
|
|
|
||
|
|
|
||
|
|
def test_precompact_always_blocks(tmp_path):
|
||
|
|
result = _capture_hook_output(
|
||
|
|
hook_precompact,
|
||
|
|
{"session_id": "test"},
|
||
|
|
state_dir=tmp_path,
|
||
|
|
)
|
||
|
|
assert result["decision"] == "block"
|
||
|
|
assert result["reason"] == PRECOMPACT_BLOCK_REASON
|