diff --git a/tests/test_closets.py b/tests/test_closets.py index b365102..57c989d 100644 --- a/tests/test_closets.py +++ b/tests/test_closets.py @@ -21,6 +21,13 @@ from mempalace.palace import ( ) from mempalace.miner import _extract_entities_for_metadata from mempalace.searcher import _bm25_score, _hybrid_rank +from mempalace.palace_graph import ( + create_tunnel, + list_tunnels, + delete_tunnel, + follow_tunnels, + _TUNNEL_FILE, +) # ── mine_lock ──────────────────────────────────────────────────────────── @@ -199,3 +206,57 @@ class TestDiaryIngest: ingest_diaries(diary_dir, palace_dir, force=True) result = ingest_diaries(diary_dir, palace_dir) # second run, no force assert result["days_updated"] == 0 + + +# ── tunnels ────────────────────────────────────────────────────────────── + + +class TestTunnels: + def setup_method(self): + # Use temp tunnel file + self._orig = _TUNNEL_FILE + import mempalace.palace_graph as pg + self._tmpdir = tempfile.mkdtemp() + pg._TUNNEL_FILE = os.path.join(self._tmpdir, "tunnels.json") + + def teardown_method(self): + import mempalace.palace_graph as pg + pg._TUNNEL_FILE = self._orig + + def test_create_tunnel(self): + t = create_tunnel("wing_api", "auth", "wing_db", "users", label="auth uses users table") + assert t["id"] + assert t["source"]["wing"] == "wing_api" + assert t["target"]["wing"] == "wing_db" + assert t["label"] == "auth uses users table" + + def test_list_tunnels(self): + create_tunnel("wing_a", "room1", "wing_b", "room2") + create_tunnel("wing_a", "room3", "wing_c", "room4") + all_t = list_tunnels() + assert len(all_t) == 2 + filtered = list_tunnels("wing_a") + assert len(filtered) == 2 + filtered_c = list_tunnels("wing_c") + assert len(filtered_c) == 1 + + def test_delete_tunnel(self): + t = create_tunnel("wing_x", "r1", "wing_y", "r2") + delete_tunnel(t["id"]) + assert len(list_tunnels()) == 0 + + def test_dedup_same_endpoints(self): + create_tunnel("wing_a", "r1", "wing_b", "r2", label="first") + create_tunnel("wing_a", "r1", "wing_b", "r2", label="updated") + tunnels = list_tunnels() + assert len(tunnels) == 1 + assert tunnels[0]["label"] == "updated" + + def test_follow_tunnels(self): + create_tunnel("wing_api", "auth", "wing_db", "users") + create_tunnel("wing_api", "auth", "wing_frontend", "login") + connections = follow_tunnels("wing_api", "auth") + assert len(connections) == 2 + wings = {c["connected_wing"] for c in connections} + assert "wing_db" in wings + assert "wing_frontend" in wings