diff --git a/tests/test_hooks_cli.py b/tests/test_hooks_cli.py
index 8eeffed..d6951e2 100644
--- a/tests/test_hooks_cli.py
+++ b/tests/test_hooks_cli.py
@@ -42,29 +42,43 @@ def _write_transcript(path: Path, entries: list[dict]):
def test_count_human_messages_basic(tmp_path):
transcript = tmp_path / "t.jsonl"
- _write_transcript(transcript, [
- {"message": {"role": "user", "content": "hello"}},
- {"message": {"role": "assistant", "content": "hi"}},
- {"message": {"role": "user", "content": "bye"}},
- ])
+ _write_transcript(
+ transcript,
+ [
+ {"message": {"role": "user", "content": "hello"}},
+ {"message": {"role": "assistant", "content": "hi"}},
+ {"message": {"role": "user", "content": "bye"}},
+ ],
+ )
assert _count_human_messages(str(transcript)) == 2
def test_count_skips_command_messages(tmp_path):
transcript = tmp_path / "t.jsonl"
- _write_transcript(transcript, [
- {"message": {"role": "user", "content": "status"}},
- {"message": {"role": "user", "content": "real question"}},
- ])
+ _write_transcript(
+ transcript,
+ [
+ {"message": {"role": "user", "content": "status"}},
+ {"message": {"role": "user", "content": "real question"}},
+ ],
+ )
assert _count_human_messages(str(transcript)) == 1
def test_count_handles_list_content(tmp_path):
transcript = tmp_path / "t.jsonl"
- _write_transcript(transcript, [
- {"message": {"role": "user", "content": [{"type": "text", "text": "hello"}]}},
- {"message": {"role": "user", "content": [{"type": "text", "text": "x"}]}},
- ])
+ _write_transcript(
+ transcript,
+ [
+ {"message": {"role": "user", "content": [{"type": "text", "text": "hello"}]}},
+ {
+ "message": {
+ "role": "user",
+ "content": [{"type": "text", "text": "x"}],
+ }
+ },
+ ],
+ )
assert _count_human_messages(str(transcript)) == 1
@@ -90,6 +104,7 @@ def test_count_malformed_json_lines(tmp_path):
def _capture_hook_output(hook_fn, data, harness="claude-code", state_dir=None):
"""Run a hook and capture its JSON stdout output."""
import io
+
buf = io.StringIO()
patches = [patch("mempalace.hooks_cli._output", side_effect=lambda d: buf.write(json.dumps(d)))]
if state_dir:
@@ -123,10 +138,10 @@ def test_stop_hook_passthrough_when_active_string(tmp_path):
def test_stop_hook_passthrough_below_interval(tmp_path):
transcript = tmp_path / "t.jsonl"
- _write_transcript(transcript, [
- {"message": {"role": "user", "content": f"msg {i}"}}
- for i in range(SAVE_INTERVAL - 1)
- ])
+ _write_transcript(
+ transcript,
+ [{"message": {"role": "user", "content": f"msg {i}"}} for i in range(SAVE_INTERVAL - 1)],
+ )
result = _capture_hook_output(
hook_stop,
{"session_id": "test", "stop_hook_active": False, "transcript_path": str(transcript)},
@@ -137,10 +152,10 @@ def test_stop_hook_passthrough_below_interval(tmp_path):
def test_stop_hook_blocks_at_interval(tmp_path):
transcript = tmp_path / "t.jsonl"
- _write_transcript(transcript, [
- {"message": {"role": "user", "content": f"msg {i}"}}
- for i in range(SAVE_INTERVAL)
- ])
+ _write_transcript(
+ transcript,
+ [{"message": {"role": "user", "content": f"msg {i}"}} for i in range(SAVE_INTERVAL)],
+ )
result = _capture_hook_output(
hook_stop,
{"session_id": "test", "stop_hook_active": False, "transcript_path": str(transcript)},
@@ -152,10 +167,10 @@ def test_stop_hook_blocks_at_interval(tmp_path):
def test_stop_hook_tracks_save_point(tmp_path):
transcript = tmp_path / "t.jsonl"
- _write_transcript(transcript, [
- {"message": {"role": "user", "content": f"msg {i}"}}
- for i in range(SAVE_INTERVAL)
- ])
+ _write_transcript(
+ transcript,
+ [{"message": {"role": "user", "content": f"msg {i}"}} for i in range(SAVE_INTERVAL)],
+ )
data = {"session_id": "test", "stop_hook_active": False, "transcript_path": str(transcript)}
# First call blocks
diff --git a/tests/test_knowledge_graph.py b/tests/test_knowledge_graph.py
index 535eace..d7d9838 100644
--- a/tests/test_knowledge_graph.py
+++ b/tests/test_knowledge_graph.py
@@ -6,7 +6,6 @@ timeline, stats, and edge cases (duplicate triples, ID collisions).
"""
-
class TestEntityOperations:
def test_add_entity(self, kg):
eid = kg.add_entity("Alice", entity_type="person")
@@ -125,6 +124,7 @@ class TestWALMode:
conn.close()
assert mode == "wal"
+
class TestStats:
def test_stats_empty(self, kg):
stats = kg.stats()