diff --git a/tests/test_entity_detector.py b/tests/test_entity_detector.py index b3d378b..0fad76b 100644 --- a/tests/test_entity_detector.py +++ b/tests/test_entity_detector.py @@ -40,10 +40,7 @@ def test_extract_candidates_requires_min_frequency(): def test_extract_candidates_finds_multi_word_names(): # Multi-word names need 3+ occurrences and no stopwords - text = ( - "Claude Code is great. Claude Code rocks. " - "Claude Code works. Claude Code rules." - ) + text = "Claude Code is great. Claude Code rocks. Claude Code works. Claude Code rules." result = extract_candidates(text) assert "Claude Code" in result diff --git a/tests/test_general_extractor.py b/tests/test_general_extractor.py index 6a68fc3..0f5d46c 100644 --- a/tests/test_general_extractor.py +++ b/tests/test_general_extractor.py @@ -226,7 +226,13 @@ def test_split_into_segments_single_block(): def test_all_markers_has_five_types(): - assert set(ALL_MARKERS.keys()) == {"decision", "preference", "milestone", "problem", "emotional"} + assert set(ALL_MARKERS.keys()) == { + "decision", + "preference", + "milestone", + "problem", + "emotional", + } # ── POSITIVE_WORDS / NEGATIVE_WORDS ──────────────────────────────────── diff --git a/tests/test_hooks_cli.py b/tests/test_hooks_cli.py index 0ff1fb3..5a1870e 100644 --- a/tests/test_hooks_cli.py +++ b/tests/test_hooks_cli.py @@ -49,29 +49,43 @@ def _write_transcript(path: Path, entries: list[dict]): def test_count_human_messages_basic(tmp_path): transcript = tmp_path / "t.jsonl" - _write_transcript(transcript, [ - {"message": {"role": "user", "content": "hello"}}, - {"message": {"role": "assistant", "content": "hi"}}, - {"message": {"role": "user", "content": "bye"}}, - ]) + _write_transcript( + transcript, + [ + {"message": {"role": "user", "content": "hello"}}, + {"message": {"role": "assistant", "content": "hi"}}, + {"message": {"role": "user", "content": "bye"}}, + ], + ) assert _count_human_messages(str(transcript)) == 2 def test_count_skips_command_messages(tmp_path): transcript = tmp_path / "t.jsonl" - _write_transcript(transcript, [ - {"message": {"role": "user", "content": "status"}}, - {"message": {"role": "user", "content": "real question"}}, - ]) + _write_transcript( + transcript, + [ + {"message": {"role": "user", "content": "status"}}, + {"message": {"role": "user", "content": "real question"}}, + ], + ) assert _count_human_messages(str(transcript)) == 1 def test_count_handles_list_content(tmp_path): transcript = tmp_path / "t.jsonl" - _write_transcript(transcript, [ - {"message": {"role": "user", "content": [{"type": "text", "text": "hello"}]}}, - {"message": {"role": "user", "content": [{"type": "text", "text": "x"}]}}, - ]) + _write_transcript( + transcript, + [ + {"message": {"role": "user", "content": [{"type": "text", "text": "hello"}]}}, + { + "message": { + "role": "user", + "content": [{"type": "text", "text": "x"}], + } + }, + ], + ) assert _count_human_messages(str(transcript)) == 1 @@ -97,6 +111,7 @@ def test_count_malformed_json_lines(tmp_path): def _capture_hook_output(hook_fn, data, harness="claude-code", state_dir=None): """Run a hook and capture its JSON stdout output.""" import io + buf = io.StringIO() patches = [patch("mempalace.hooks_cli._output", side_effect=lambda d: buf.write(json.dumps(d)))] if state_dir: @@ -130,10 +145,10 @@ def test_stop_hook_passthrough_when_active_string(tmp_path): def test_stop_hook_passthrough_below_interval(tmp_path): transcript = tmp_path / "t.jsonl" - _write_transcript(transcript, [ - {"message": {"role": "user", "content": f"msg {i}"}} - for i in range(SAVE_INTERVAL - 1) - ]) + _write_transcript( + transcript, + [{"message": {"role": "user", "content": f"msg {i}"}} for i in range(SAVE_INTERVAL - 1)], + ) result = _capture_hook_output( hook_stop, {"session_id": "test", "stop_hook_active": False, "transcript_path": str(transcript)}, @@ -144,10 +159,10 @@ def test_stop_hook_passthrough_below_interval(tmp_path): def test_stop_hook_blocks_at_interval(tmp_path): transcript = tmp_path / "t.jsonl" - _write_transcript(transcript, [ - {"message": {"role": "user", "content": f"msg {i}"}} - for i in range(SAVE_INTERVAL) - ]) + _write_transcript( + transcript, + [{"message": {"role": "user", "content": f"msg {i}"}} for i in range(SAVE_INTERVAL)], + ) result = _capture_hook_output( hook_stop, {"session_id": "test", "stop_hook_active": False, "transcript_path": str(transcript)}, @@ -159,10 +174,10 @@ def test_stop_hook_blocks_at_interval(tmp_path): def test_stop_hook_tracks_save_point(tmp_path): transcript = tmp_path / "t.jsonl" - _write_transcript(transcript, [ - {"message": {"role": "user", "content": f"msg {i}"}} - for i in range(SAVE_INTERVAL) - ]) + _write_transcript( + transcript, + [{"message": {"role": "user", "content": f"msg {i}"}} for i in range(SAVE_INTERVAL)], + ) data = {"session_id": "test", "stop_hook_active": False, "transcript_path": str(transcript)} # First call blocks @@ -274,10 +289,10 @@ def test_parse_harness_input_valid(): def test_stop_hook_oserror_on_last_save_read(tmp_path): """When last_save_file has invalid content, falls back to 0.""" transcript = tmp_path / "t.jsonl" - _write_transcript(transcript, [ - {"message": {"role": "user", "content": f"msg {i}"}} - for i in range(SAVE_INTERVAL) - ]) + _write_transcript( + transcript, + [{"message": {"role": "user", "content": f"msg {i}"}} for i in range(SAVE_INTERVAL)], + ) # Write invalid content to last save file (tmp_path / "test_last_save").write_text("not_a_number") result = _capture_hook_output( @@ -291,10 +306,10 @@ def test_stop_hook_oserror_on_last_save_read(tmp_path): def test_stop_hook_oserror_on_write(tmp_path): """When write to last_save_file fails, hook still outputs correctly.""" transcript = tmp_path / "t.jsonl" - _write_transcript(transcript, [ - {"message": {"role": "user", "content": f"msg {i}"}} - for i in range(SAVE_INTERVAL) - ]) + _write_transcript( + transcript, + [{"message": {"role": "user", "content": f"msg {i}"}} for i in range(SAVE_INTERVAL)], + ) def bad_write_text(*args, **kwargs): raise OSError("disk full") @@ -303,7 +318,11 @@ def test_stop_hook_oserror_on_write(tmp_path): with patch.object(Path, "write_text", bad_write_text): result = _capture_hook_output( hook_stop, - {"session_id": "test", "stop_hook_active": False, "transcript_path": str(transcript)}, + { + "session_id": "test", + "stop_hook_active": False, + "transcript_path": str(transcript), + }, state_dir=tmp_path, ) assert result["decision"] == "block" @@ -356,15 +375,16 @@ def test_run_hook_dispatches_session_start(tmp_path): def test_run_hook_dispatches_stop(tmp_path): transcript = tmp_path / "t.jsonl" - _write_transcript(transcript, [ - {"message": {"role": "user", "content": f"msg {i}"}} - for i in range(3) - ]) - stdin_data = json.dumps({ - "session_id": "run-test", - "stop_hook_active": False, - "transcript_path": str(transcript), - }) + _write_transcript( + transcript, [{"message": {"role": "user", "content": f"msg {i}"}} for i in range(3)] + ) + stdin_data = json.dumps( + { + "session_id": "run-test", + "stop_hook_active": False, + "transcript_path": str(transcript), + } + ) with patch("sys.stdin", io.StringIO(stdin_data)): with patch("mempalace.hooks_cli.STATE_DIR", tmp_path): with patch("mempalace.hooks_cli._output") as mock_output: diff --git a/tests/test_normalize.py b/tests/test_normalize.py index d613e58..fc50251 100644 --- a/tests/test_normalize.py +++ b/tests/test_normalize.py @@ -158,12 +158,8 @@ def test_claude_code_jsonl_non_dict_entries(): def test_codex_jsonl_valid(): lines = [ json.dumps({"type": "session_meta", "payload": {}}), - json.dumps( - {"type": "event_msg", "payload": {"type": "user_message", "message": "Q"}} - ), - json.dumps( - {"type": "event_msg", "payload": {"type": "agent_message", "message": "A"}} - ), + json.dumps({"type": "event_msg", "payload": {"type": "user_message", "message": "Q"}}), + json.dumps({"type": "event_msg", "payload": {"type": "agent_message", "message": "A"}}), ] result = _try_codex_jsonl("\n".join(lines)) assert result is not None @@ -173,12 +169,8 @@ def test_codex_jsonl_valid(): def test_codex_jsonl_no_session_meta(): """Without session_meta, codex parser returns None.""" lines = [ - json.dumps( - {"type": "event_msg", "payload": {"type": "user_message", "message": "Q"}} - ), - json.dumps( - {"type": "event_msg", "payload": {"type": "agent_message", "message": "A"}} - ), + json.dumps({"type": "event_msg", "payload": {"type": "user_message", "message": "Q"}}), + json.dumps({"type": "event_msg", "payload": {"type": "agent_message", "message": "A"}}), ] result = _try_codex_jsonl("\n".join(lines)) assert result is None @@ -199,15 +191,9 @@ def test_codex_jsonl_skips_non_event_msg(): def test_codex_jsonl_non_string_message(): lines = [ json.dumps({"type": "session_meta"}), - json.dumps( - {"type": "event_msg", "payload": {"type": "user_message", "message": 123}} - ), - json.dumps( - {"type": "event_msg", "payload": {"type": "user_message", "message": "Q"}} - ), - json.dumps( - {"type": "event_msg", "payload": {"type": "agent_message", "message": "A"}} - ), + json.dumps({"type": "event_msg", "payload": {"type": "user_message", "message": 123}}), + json.dumps({"type": "event_msg", "payload": {"type": "user_message", "message": "Q"}}), + json.dumps({"type": "event_msg", "payload": {"type": "agent_message", "message": "A"}}), ] result = _try_codex_jsonl("\n".join(lines)) assert result is not None @@ -216,15 +202,9 @@ def test_codex_jsonl_non_string_message(): def test_codex_jsonl_empty_text_skipped(): lines = [ json.dumps({"type": "session_meta"}), - json.dumps( - {"type": "event_msg", "payload": {"type": "user_message", "message": " "}} - ), - json.dumps( - {"type": "event_msg", "payload": {"type": "user_message", "message": "Q"}} - ), - json.dumps( - {"type": "event_msg", "payload": {"type": "agent_message", "message": "A"}} - ), + json.dumps({"type": "event_msg", "payload": {"type": "user_message", "message": " "}}), + json.dumps({"type": "event_msg", "payload": {"type": "user_message", "message": "Q"}}), + json.dumps({"type": "event_msg", "payload": {"type": "agent_message", "message": "A"}}), ] result = _try_codex_jsonl("\n".join(lines)) assert result is not None @@ -234,12 +214,8 @@ def test_codex_jsonl_payload_not_dict(): lines = [ json.dumps({"type": "session_meta"}), json.dumps({"type": "event_msg", "payload": "not a dict"}), - json.dumps( - {"type": "event_msg", "payload": {"type": "user_message", "message": "Q"}} - ), - json.dumps( - {"type": "event_msg", "payload": {"type": "agent_message", "message": "A"}} - ), + json.dumps({"type": "event_msg", "payload": {"type": "user_message", "message": "Q"}}), + json.dumps({"type": "event_msg", "payload": {"type": "agent_message", "message": "A"}}), ] result = _try_codex_jsonl("\n".join(lines)) assert result is not None diff --git a/tests/test_palace_graph.py b/tests/test_palace_graph.py index 7875e98..ddda272 100644 --- a/tests/test_palace_graph.py +++ b/tests/test_palace_graph.py @@ -15,8 +15,8 @@ def _make_fake_collection(metadatas, ids=None): col.count.return_value = len(metadatas) def fake_get(limit=1000, offset=0, include=None): - batch_meta = metadatas[offset:offset + limit] - batch_ids = ids[offset:offset + limit] + batch_meta = metadatas[offset : offset + limit] + batch_ids = ids[offset : offset + limit] return {"ids": batch_ids, "metadatas": batch_meta} col.get.side_effect = fake_get @@ -51,20 +51,34 @@ class TestBuildGraph: assert edges == [] def test_single_wing_no_edges(self): - col = _make_fake_collection([ - {"room": "auth", "wing": "wing_code", "hall": "security", "date": "2026-01-01"}, - {"room": "auth", "wing": "wing_code", "hall": "security", "date": "2026-01-02"}, - ]) + col = _make_fake_collection( + [ + {"room": "auth", "wing": "wing_code", "hall": "security", "date": "2026-01-01"}, + {"room": "auth", "wing": "wing_code", "hall": "security", "date": "2026-01-02"}, + ] + ) nodes, edges = build_graph(col=col) assert "auth" in nodes assert nodes["auth"]["count"] == 2 assert edges == [] def test_multi_wing_creates_edges(self): - col = _make_fake_collection([ - {"room": "chromadb", "wing": "wing_code", "hall": "databases", "date": "2026-01-01"}, - {"room": "chromadb", "wing": "wing_project", "hall": "databases", "date": "2026-01-02"}, - ]) + col = _make_fake_collection( + [ + { + "room": "chromadb", + "wing": "wing_code", + "hall": "databases", + "date": "2026-01-01", + }, + { + "room": "chromadb", + "wing": "wing_project", + "hall": "databases", + "date": "2026-01-02", + }, + ] + ) nodes, edges = build_graph(col=col) assert "chromadb" in nodes assert len(edges) == 1 @@ -73,24 +87,30 @@ class TestBuildGraph: assert edges[0]["hall"] == "databases" def test_general_room_excluded(self): - col = _make_fake_collection([ - {"room": "general", "wing": "wing_code", "hall": "misc", "date": ""}, - ]) + col = _make_fake_collection( + [ + {"room": "general", "wing": "wing_code", "hall": "misc", "date": ""}, + ] + ) nodes, edges = build_graph(col=col) assert "general" not in nodes def test_missing_wing_excluded(self): - col = _make_fake_collection([ - {"room": "orphan", "wing": "", "hall": "misc", "date": ""}, - ]) + col = _make_fake_collection( + [ + {"room": "orphan", "wing": "", "hall": "misc", "date": ""}, + ] + ) nodes, edges = build_graph(col=col) assert "orphan" not in nodes def test_dates_capped_at_five(self): - col = _make_fake_collection([ - {"room": "busy", "wing": "w", "hall": "h", "date": f"2026-01-{i:02d}"} - for i in range(1, 10) - ]) + col = _make_fake_collection( + [ + {"room": "busy", "wing": "w", "hall": "h", "date": f"2026-01-{i:02d}"} + for i in range(1, 10) + ] + ) nodes, _ = build_graph(col=col) assert len(nodes["busy"]["dates"]) <= 5 @@ -100,11 +120,13 @@ class TestBuildGraph: class TestTraverse: def _build_col(self): - return _make_fake_collection([ - {"room": "auth", "wing": "wing_code", "hall": "security", "date": "2026-01-01"}, - {"room": "login", "wing": "wing_code", "hall": "security", "date": "2026-01-01"}, - {"room": "deploy", "wing": "wing_ops", "hall": "infra", "date": "2026-01-01"}, - ]) + return _make_fake_collection( + [ + {"room": "auth", "wing": "wing_code", "hall": "security", "date": "2026-01-01"}, + {"room": "login", "wing": "wing_code", "hall": "security", "date": "2026-01-01"}, + {"room": "deploy", "wing": "wing_ops", "hall": "infra", "date": "2026-01-01"}, + ] + ) def test_traverse_known_room(self): col = self._build_col() @@ -135,11 +157,13 @@ class TestTraverse: class TestFindTunnels: def _build_tunnel_col(self): - return _make_fake_collection([ - {"room": "chromadb", "wing": "wing_code", "hall": "db", "date": "2026-01-01"}, - {"room": "chromadb", "wing": "wing_project", "hall": "db", "date": "2026-01-02"}, - {"room": "auth", "wing": "wing_code", "hall": "security", "date": "2026-01-01"}, - ]) + return _make_fake_collection( + [ + {"room": "chromadb", "wing": "wing_code", "hall": "db", "date": "2026-01-01"}, + {"room": "chromadb", "wing": "wing_project", "hall": "db", "date": "2026-01-02"}, + {"room": "auth", "wing": "wing_code", "hall": "security", "date": "2026-01-01"}, + ] + ) def test_find_all_tunnels(self): col = self._build_tunnel_col() @@ -176,11 +200,13 @@ class TestGraphStats: assert stats["total_edges"] == 0 def test_stats_with_data(self): - col = _make_fake_collection([ - {"room": "chromadb", "wing": "wing_code", "hall": "db", "date": "2026-01-01"}, - {"room": "chromadb", "wing": "wing_project", "hall": "db", "date": "2026-01-02"}, - {"room": "auth", "wing": "wing_code", "hall": "security", "date": "2026-01-01"}, - ]) + col = _make_fake_collection( + [ + {"room": "chromadb", "wing": "wing_code", "hall": "db", "date": "2026-01-01"}, + {"room": "chromadb", "wing": "wing_project", "hall": "db", "date": "2026-01-02"}, + {"room": "auth", "wing": "wing_code", "hall": "security", "date": "2026-01-01"}, + ] + ) stats = graph_stats(col=col) assert stats["total_rooms"] == 2 assert stats["tunnel_rooms"] == 1 diff --git a/tests/test_spellcheck.py b/tests/test_spellcheck.py index 50da5f0..f2c7484 100644 --- a/tests/test_spellcheck.py +++ b/tests/test_spellcheck.py @@ -97,6 +97,7 @@ def test_spellcheck_user_text_passthrough_no_autocorrect(): def test_spellcheck_user_text_with_speller(): """When a speller is available, it corrects words.""" + def fake_speller(word): corrections = {"knoe": "know", "befor": "before"} return corrections.get(word, word) @@ -111,6 +112,7 @@ def test_spellcheck_user_text_with_speller(): def test_spellcheck_preserves_technical_terms(): """Technical terms should never be touched even with a speller.""" + def fake_speller(word): return "WRONG"