diff --git a/tests/test_hooks_cli.py b/tests/test_hooks_cli.py index 8eeffed..d6951e2 100644 --- a/tests/test_hooks_cli.py +++ b/tests/test_hooks_cli.py @@ -42,29 +42,43 @@ def _write_transcript(path: Path, entries: list[dict]): def test_count_human_messages_basic(tmp_path): transcript = tmp_path / "t.jsonl" - _write_transcript(transcript, [ - {"message": {"role": "user", "content": "hello"}}, - {"message": {"role": "assistant", "content": "hi"}}, - {"message": {"role": "user", "content": "bye"}}, - ]) + _write_transcript( + transcript, + [ + {"message": {"role": "user", "content": "hello"}}, + {"message": {"role": "assistant", "content": "hi"}}, + {"message": {"role": "user", "content": "bye"}}, + ], + ) assert _count_human_messages(str(transcript)) == 2 def test_count_skips_command_messages(tmp_path): transcript = tmp_path / "t.jsonl" - _write_transcript(transcript, [ - {"message": {"role": "user", "content": "status"}}, - {"message": {"role": "user", "content": "real question"}}, - ]) + _write_transcript( + transcript, + [ + {"message": {"role": "user", "content": "status"}}, + {"message": {"role": "user", "content": "real question"}}, + ], + ) assert _count_human_messages(str(transcript)) == 1 def test_count_handles_list_content(tmp_path): transcript = tmp_path / "t.jsonl" - _write_transcript(transcript, [ - {"message": {"role": "user", "content": [{"type": "text", "text": "hello"}]}}, - {"message": {"role": "user", "content": [{"type": "text", "text": "x"}]}}, - ]) + _write_transcript( + transcript, + [ + {"message": {"role": "user", "content": [{"type": "text", "text": "hello"}]}}, + { + "message": { + "role": "user", + "content": [{"type": "text", "text": "x"}], + } + }, + ], + ) assert _count_human_messages(str(transcript)) == 1 @@ -90,6 +104,7 @@ def test_count_malformed_json_lines(tmp_path): def _capture_hook_output(hook_fn, data, harness="claude-code", state_dir=None): """Run a hook and capture its JSON stdout output.""" import io + buf = io.StringIO() patches = [patch("mempalace.hooks_cli._output", side_effect=lambda d: buf.write(json.dumps(d)))] if state_dir: @@ -123,10 +138,10 @@ def test_stop_hook_passthrough_when_active_string(tmp_path): def test_stop_hook_passthrough_below_interval(tmp_path): transcript = tmp_path / "t.jsonl" - _write_transcript(transcript, [ - {"message": {"role": "user", "content": f"msg {i}"}} - for i in range(SAVE_INTERVAL - 1) - ]) + _write_transcript( + transcript, + [{"message": {"role": "user", "content": f"msg {i}"}} for i in range(SAVE_INTERVAL - 1)], + ) result = _capture_hook_output( hook_stop, {"session_id": "test", "stop_hook_active": False, "transcript_path": str(transcript)}, @@ -137,10 +152,10 @@ def test_stop_hook_passthrough_below_interval(tmp_path): def test_stop_hook_blocks_at_interval(tmp_path): transcript = tmp_path / "t.jsonl" - _write_transcript(transcript, [ - {"message": {"role": "user", "content": f"msg {i}"}} - for i in range(SAVE_INTERVAL) - ]) + _write_transcript( + transcript, + [{"message": {"role": "user", "content": f"msg {i}"}} for i in range(SAVE_INTERVAL)], + ) result = _capture_hook_output( hook_stop, {"session_id": "test", "stop_hook_active": False, "transcript_path": str(transcript)}, @@ -152,10 +167,10 @@ def test_stop_hook_blocks_at_interval(tmp_path): def test_stop_hook_tracks_save_point(tmp_path): transcript = tmp_path / "t.jsonl" - _write_transcript(transcript, [ - {"message": {"role": "user", "content": f"msg {i}"}} - for i in range(SAVE_INTERVAL) - ]) + _write_transcript( + transcript, + [{"message": {"role": "user", "content": f"msg {i}"}} for i in range(SAVE_INTERVAL)], + ) data = {"session_id": "test", "stop_hook_active": False, "transcript_path": str(transcript)} # First call blocks diff --git a/tests/test_knowledge_graph.py b/tests/test_knowledge_graph.py index 535eace..d7d9838 100644 --- a/tests/test_knowledge_graph.py +++ b/tests/test_knowledge_graph.py @@ -6,7 +6,6 @@ timeline, stats, and edge cases (duplicate triples, ID collisions). """ - class TestEntityOperations: def test_add_entity(self, kg): eid = kg.add_entity("Alice", entity_type="person") @@ -125,6 +124,7 @@ class TestWALMode: conn.close() assert mode == "wal" + class TestStats: def test_stats_empty(self, kg): stats = kg.stats()