84e2aa16e4
build_graph() scans every drawer's metadata in 1000-item batches on every call — O(n) per graph build with no caching. At 50K+ drawers this costs several seconds per MCP tool call (traverse, find_tunnels, graph_stats all call build_graph on every invocation). Add a module-level cache (nodes + edges + timestamp) with a 60-second TTL. Cache is invalidated via invalidate_graph_cache(), exported for write operations to call. Tests updated with setup_method cache resets and two new tests verifying cache hit and invalidation behaviour.
487 lines
16 KiB
Python
487 lines
16 KiB
Python
"""
|
|
palace_graph.py — Graph traversal layer for MemPalace
|
|
======================================================
|
|
|
|
Builds a navigable graph from the palace structure:
|
|
- Nodes = rooms (named ideas)
|
|
- Edges = shared rooms across wings (tunnels)
|
|
- Edge types = halls (the corridors)
|
|
|
|
Enables queries like:
|
|
"Start at chromadb-setup in wing_code, walk to wing_myproject"
|
|
"Find all rooms connected to riley-college-apps"
|
|
"What topics bridge wing_hardware and wing_myproject?"
|
|
|
|
No external graph DB needed — built from ChromaDB metadata.
|
|
"""
|
|
|
|
import hashlib
|
|
import json
|
|
import os
|
|
import time
|
|
from collections import Counter, defaultdict
|
|
from datetime import datetime, timezone
|
|
|
|
from .config import MempalaceConfig
|
|
from .palace import get_collection as _get_palace_collection
|
|
from .palace import mine_lock
|
|
|
|
# Module-level graph cache — mirrors _metadata_cache pattern in mcp_server.py
|
|
_graph_cache_nodes = None
|
|
_graph_cache_edges = None
|
|
_graph_cache_time = 0.0
|
|
_GRAPH_CACHE_TTL = 60.0 # seconds — graph changes less often than metadata
|
|
|
|
|
|
def invalidate_graph_cache():
|
|
"""Clear the graph cache. Called from mcp_server.py on writes."""
|
|
global _graph_cache_nodes, _graph_cache_edges, _graph_cache_time
|
|
_graph_cache_nodes = None
|
|
_graph_cache_edges = None
|
|
_graph_cache_time = 0.0
|
|
|
|
|
|
def _get_collection(config=None):
|
|
config = config or MempalaceConfig()
|
|
try:
|
|
return _get_palace_collection(
|
|
config.palace_path,
|
|
collection_name=config.collection_name,
|
|
create=False,
|
|
)
|
|
except Exception:
|
|
return None
|
|
|
|
|
|
def build_graph(col=None, config=None):
|
|
"""
|
|
Build the palace graph from ChromaDB metadata.
|
|
|
|
Returns cached result if fresh (within TTL). Cache is invalidated
|
|
on writes via invalidate_graph_cache().
|
|
|
|
Returns:
|
|
nodes: dict of {room: {wings: set, halls: set, count: int}}
|
|
edges: list of {room, wing_a, wing_b, hall} — one per tunnel crossing
|
|
"""
|
|
global _graph_cache_nodes, _graph_cache_edges, _graph_cache_time
|
|
now = time.time()
|
|
if _graph_cache_nodes is not None and (now - _graph_cache_time) < _GRAPH_CACHE_TTL:
|
|
return _graph_cache_nodes, _graph_cache_edges
|
|
|
|
if col is None:
|
|
col = _get_collection(config)
|
|
if not col:
|
|
return {}, []
|
|
|
|
total = col.count()
|
|
room_data = defaultdict(lambda: {"wings": set(), "halls": set(), "count": 0, "dates": set()})
|
|
|
|
offset = 0
|
|
while offset < total:
|
|
batch = col.get(limit=1000, offset=offset, include=["metadatas"])
|
|
for meta in batch["metadatas"]:
|
|
room = meta.get("room", "")
|
|
wing = meta.get("wing", "")
|
|
hall = meta.get("hall", "")
|
|
date = meta.get("date", "")
|
|
if room and room != "general" and wing:
|
|
room_data[room]["wings"].add(wing)
|
|
if hall:
|
|
room_data[room]["halls"].add(hall)
|
|
if date:
|
|
room_data[room]["dates"].add(date)
|
|
room_data[room]["count"] += 1
|
|
if not batch["ids"]:
|
|
break
|
|
offset += len(batch["ids"])
|
|
|
|
# Build edges from rooms that span multiple wings
|
|
edges = []
|
|
for room, data in room_data.items():
|
|
wings = sorted(data["wings"])
|
|
if len(wings) >= 2:
|
|
for i, wa in enumerate(wings):
|
|
for wb in wings[i + 1 :]:
|
|
for hall in data["halls"]:
|
|
edges.append(
|
|
{
|
|
"room": room,
|
|
"wing_a": wa,
|
|
"wing_b": wb,
|
|
"hall": hall,
|
|
"count": data["count"],
|
|
}
|
|
)
|
|
|
|
# Convert sets to lists for JSON serialization
|
|
nodes = {}
|
|
for room, data in room_data.items():
|
|
nodes[room] = {
|
|
"wings": sorted(data["wings"]),
|
|
"halls": sorted(data["halls"]),
|
|
"count": data["count"],
|
|
"dates": sorted(data["dates"])[-5:] if data["dates"] else [],
|
|
}
|
|
|
|
_graph_cache_nodes = nodes
|
|
_graph_cache_edges = edges
|
|
_graph_cache_time = time.time()
|
|
|
|
return nodes, edges
|
|
|
|
|
|
def traverse(start_room: str, col=None, config=None, max_hops: int = 2):
|
|
"""
|
|
Walk the graph from a starting room. Find connected rooms
|
|
through shared wings.
|
|
|
|
Returns list of paths: [{room, wing, hall, hop_distance}]
|
|
"""
|
|
nodes, edges = build_graph(col, config)
|
|
|
|
if start_room not in nodes:
|
|
return {
|
|
"error": f"Room '{start_room}' not found",
|
|
"suggestions": _fuzzy_match(start_room, nodes),
|
|
}
|
|
|
|
start = nodes[start_room]
|
|
visited = {start_room}
|
|
results = [
|
|
{
|
|
"room": start_room,
|
|
"wings": start["wings"],
|
|
"halls": start["halls"],
|
|
"count": start["count"],
|
|
"hop": 0,
|
|
}
|
|
]
|
|
|
|
# BFS traversal
|
|
frontier = [(start_room, 0)]
|
|
while frontier:
|
|
current_room, depth = frontier.pop(0)
|
|
if depth >= max_hops:
|
|
continue
|
|
|
|
current = nodes.get(current_room, {})
|
|
current_wings = set(current.get("wings", []))
|
|
|
|
# Find all rooms that share a wing with current room
|
|
for room, data in nodes.items():
|
|
if room in visited:
|
|
continue
|
|
shared_wings = current_wings & set(data["wings"])
|
|
if shared_wings:
|
|
visited.add(room)
|
|
results.append(
|
|
{
|
|
"room": room,
|
|
"wings": data["wings"],
|
|
"halls": data["halls"],
|
|
"count": data["count"],
|
|
"hop": depth + 1,
|
|
"connected_via": sorted(shared_wings),
|
|
}
|
|
)
|
|
if depth + 1 < max_hops:
|
|
frontier.append((room, depth + 1))
|
|
|
|
# Sort by relevance (hop distance, then count)
|
|
results.sort(key=lambda x: (x["hop"], -x["count"]))
|
|
return results[:50] # cap results
|
|
|
|
|
|
def find_tunnels(wing_a: str = None, wing_b: str = None, col=None, config=None):
|
|
"""
|
|
Find rooms that connect two wings (or all tunnel rooms if no wings specified).
|
|
These are the "hallways" — same named idea appearing in multiple domains.
|
|
"""
|
|
nodes, edges = build_graph(col, config)
|
|
|
|
tunnels = []
|
|
for room, data in nodes.items():
|
|
wings = data["wings"]
|
|
if len(wings) < 2:
|
|
continue
|
|
|
|
if wing_a and wing_a not in wings:
|
|
continue
|
|
if wing_b and wing_b not in wings:
|
|
continue
|
|
|
|
tunnels.append(
|
|
{
|
|
"room": room,
|
|
"wings": wings,
|
|
"halls": data["halls"],
|
|
"count": data["count"],
|
|
"recent": data["dates"][-1] if data["dates"] else "",
|
|
}
|
|
)
|
|
|
|
tunnels.sort(key=lambda x: -x["count"])
|
|
return tunnels[:50]
|
|
|
|
|
|
def graph_stats(col=None, config=None):
|
|
"""Summary statistics about the palace graph."""
|
|
nodes, edges = build_graph(col, config)
|
|
|
|
tunnel_rooms = sum(1 for n in nodes.values() if len(n["wings"]) >= 2)
|
|
wing_counts = Counter()
|
|
for data in nodes.values():
|
|
for w in data["wings"]:
|
|
wing_counts[w] += 1
|
|
|
|
return {
|
|
"total_rooms": len(nodes),
|
|
"tunnel_rooms": tunnel_rooms,
|
|
"total_edges": len(edges),
|
|
"rooms_per_wing": dict(wing_counts.most_common()),
|
|
"top_tunnels": [
|
|
{"room": r, "wings": d["wings"], "count": d["count"]}
|
|
for r, d in sorted(nodes.items(), key=lambda x: -len(x[1]["wings"]))[:10]
|
|
if len(d["wings"]) >= 2
|
|
],
|
|
}
|
|
|
|
|
|
def _fuzzy_match(query: str, nodes: dict, n: int = 5):
|
|
"""Find rooms that approximately match a query string."""
|
|
query_lower = query.lower()
|
|
scored = []
|
|
for room in nodes:
|
|
# Simple substring matching
|
|
if query_lower in room:
|
|
scored.append((room, 1.0))
|
|
elif any(word in room for word in query_lower.split("-")):
|
|
scored.append((room, 0.5))
|
|
scored.sort(key=lambda x: -x[1])
|
|
return [r for r, _ in scored[:n]]
|
|
|
|
|
|
# =============================================================================
|
|
# EXPLICIT TUNNELS — agent-created cross-wing links
|
|
# =============================================================================
|
|
# Passive tunnels are discovered from shared room names across wings.
|
|
# Explicit tunnels are created by agents when they notice a connection
|
|
# between two specific drawers or rooms in different wings/projects.
|
|
#
|
|
# Stored as a JSON file at ~/.mempalace/tunnels.json so they persist
|
|
# across palace rebuilds (not in ChromaDB which can be recreated).
|
|
|
|
|
|
_TUNNEL_FILE = os.path.join(os.path.expanduser("~"), ".mempalace", "tunnels.json")
|
|
|
|
|
|
def _load_tunnels():
|
|
"""Load explicit tunnels from disk.
|
|
|
|
Returns an empty list if the file is missing or corrupt (e.g. truncated
|
|
by a crash mid-write on a system that lacks atomic-rename semantics).
|
|
"""
|
|
if not os.path.exists(_TUNNEL_FILE):
|
|
return []
|
|
try:
|
|
with open(_TUNNEL_FILE, "r", encoding="utf-8") as f:
|
|
data = json.load(f)
|
|
except Exception:
|
|
return []
|
|
return data if isinstance(data, list) else []
|
|
|
|
|
|
def _save_tunnels(tunnels):
|
|
"""Persist explicit tunnels atomically.
|
|
|
|
Writes to ``tunnels.json.tmp`` then ``os.replace``s it into place, so
|
|
a crash mid-write can never leave a partial/empty tunnels.json that
|
|
silently wipes every tunnel on next read.
|
|
"""
|
|
os.makedirs(os.path.dirname(_TUNNEL_FILE), exist_ok=True)
|
|
tmp_path = _TUNNEL_FILE + ".tmp"
|
|
with open(tmp_path, "w", encoding="utf-8") as f:
|
|
json.dump(tunnels, f, indent=2)
|
|
f.flush()
|
|
try:
|
|
os.fsync(f.fileno())
|
|
except OSError:
|
|
# Not all filesystems (or Windows file handles) support fsync — tolerate.
|
|
pass
|
|
os.replace(tmp_path, _TUNNEL_FILE)
|
|
|
|
|
|
def _endpoint_key(wing: str, room: str) -> str:
|
|
return f"{wing}/{room}"
|
|
|
|
|
|
def _canonical_tunnel_id(
|
|
source_wing: str, source_room: str, target_wing: str, target_room: str
|
|
) -> str:
|
|
"""Compute a symmetric tunnel ID.
|
|
|
|
Tunnels are conceptually undirected — "auth relates to users" is the
|
|
same connection as "users relates to auth". Sort the two endpoints
|
|
before hashing so ``create_tunnel(A, B)`` and ``create_tunnel(B, A)``
|
|
resolve to the same ID and dedup into one record.
|
|
"""
|
|
src = _endpoint_key(source_wing, source_room)
|
|
tgt = _endpoint_key(target_wing, target_room)
|
|
a, b = sorted((src, tgt))
|
|
return hashlib.sha256(f"{a}↔{b}".encode()).hexdigest()[:16]
|
|
|
|
|
|
def _require_name(value: str, field: str) -> str:
|
|
"""Reject empty / non-string endpoint identifiers."""
|
|
if not isinstance(value, str) or not value.strip():
|
|
raise ValueError(f"{field} must be a non-empty string")
|
|
return value.strip()
|
|
|
|
|
|
def create_tunnel(
|
|
source_wing: str,
|
|
source_room: str,
|
|
target_wing: str,
|
|
target_room: str,
|
|
label: str = "",
|
|
source_drawer_id: str = None,
|
|
target_drawer_id: str = None,
|
|
):
|
|
"""Create an explicit (symmetric) tunnel between two locations in the palace.
|
|
|
|
Tunnels are undirected: ``create_tunnel(A, B)`` and ``create_tunnel(B, A)``
|
|
resolve to the same canonical ID. A second call with the same endpoints
|
|
updates the stored label (and drawer IDs, if provided) rather than
|
|
creating a duplicate.
|
|
|
|
The ``source`` / ``target`` fields on the returned dict preserve the
|
|
argument order the caller used, so callers can display it directionally
|
|
if they like. The ID and dedup are symmetric.
|
|
|
|
Args:
|
|
source_wing: Wing of the source (e.g., "project_api").
|
|
source_room: Room in the source wing.
|
|
target_wing: Wing of the target (e.g., "project_database").
|
|
target_room: Room in the target wing.
|
|
label: Description of the connection.
|
|
source_drawer_id: Optional specific drawer ID.
|
|
target_drawer_id: Optional specific drawer ID.
|
|
|
|
Returns:
|
|
The stored tunnel dict.
|
|
|
|
Raises:
|
|
ValueError: if any wing or room is empty or non-string.
|
|
"""
|
|
source_wing = _require_name(source_wing, "source_wing")
|
|
source_room = _require_name(source_room, "source_room")
|
|
target_wing = _require_name(target_wing, "target_wing")
|
|
target_room = _require_name(target_room, "target_room")
|
|
|
|
tunnel_id = _canonical_tunnel_id(source_wing, source_room, target_wing, target_room)
|
|
|
|
tunnel = {
|
|
"id": tunnel_id,
|
|
"source": {"wing": source_wing, "room": source_room},
|
|
"target": {"wing": target_wing, "room": target_room},
|
|
"label": label,
|
|
"created_at": datetime.now(timezone.utc).isoformat(),
|
|
}
|
|
if source_drawer_id:
|
|
tunnel["source"]["drawer_id"] = source_drawer_id
|
|
if target_drawer_id:
|
|
tunnel["target"]["drawer_id"] = target_drawer_id
|
|
|
|
# Serialize the load → mutate → save cycle. Without this, two concurrent
|
|
# create_tunnel calls can both read the same snapshot and the later
|
|
# writer silently drops the earlier writer's tunnel.
|
|
with mine_lock(_TUNNEL_FILE):
|
|
tunnels = _load_tunnels()
|
|
for existing in tunnels:
|
|
if existing.get("id") == tunnel_id:
|
|
# Preserve original creation timestamp on label updates.
|
|
tunnel["created_at"] = existing.get("created_at", tunnel["created_at"])
|
|
tunnel["updated_at"] = datetime.now(timezone.utc).isoformat()
|
|
existing.clear()
|
|
existing.update(tunnel)
|
|
_save_tunnels(tunnels)
|
|
return existing
|
|
tunnels.append(tunnel)
|
|
_save_tunnels(tunnels)
|
|
return tunnel
|
|
|
|
|
|
def list_tunnels(wing: str = None):
|
|
"""List all explicit tunnels, optionally filtered by wing.
|
|
|
|
Returns tunnels where ``wing`` appears as either source or target
|
|
(tunnels are symmetric, so either endpoint is a valid filter match).
|
|
"""
|
|
tunnels = _load_tunnels()
|
|
if wing:
|
|
tunnels = [t for t in tunnels if t["source"]["wing"] == wing or t["target"]["wing"] == wing]
|
|
return tunnels
|
|
|
|
|
|
def delete_tunnel(tunnel_id: str):
|
|
"""Delete an explicit tunnel by ID. Returns ``{"deleted": <id>}``."""
|
|
with mine_lock(_TUNNEL_FILE):
|
|
tunnels = _load_tunnels()
|
|
tunnels = [t for t in tunnels if t.get("id") != tunnel_id]
|
|
_save_tunnels(tunnels)
|
|
return {"deleted": tunnel_id}
|
|
|
|
|
|
def follow_tunnels(wing: str, room: str, col=None, config=None):
|
|
"""Follow explicit tunnels from a room — returns connected drawers.
|
|
|
|
Given a location (wing/room), finds all tunnels leading from or to it,
|
|
and optionally fetches the connected drawer content.
|
|
"""
|
|
tunnels = _load_tunnels()
|
|
connections = []
|
|
|
|
for t in tunnels:
|
|
src = t["source"]
|
|
tgt = t["target"]
|
|
|
|
if src["wing"] == wing and src["room"] == room:
|
|
connections.append(
|
|
{
|
|
"direction": "outgoing",
|
|
"connected_wing": tgt["wing"],
|
|
"connected_room": tgt["room"],
|
|
"label": t.get("label", ""),
|
|
"drawer_id": tgt.get("drawer_id"),
|
|
"tunnel_id": t["id"],
|
|
}
|
|
)
|
|
elif tgt["wing"] == wing and tgt["room"] == room:
|
|
connections.append(
|
|
{
|
|
"direction": "incoming",
|
|
"connected_wing": src["wing"],
|
|
"connected_room": src["room"],
|
|
"label": t.get("label", ""),
|
|
"drawer_id": src.get("drawer_id"),
|
|
"tunnel_id": t["id"],
|
|
}
|
|
)
|
|
|
|
# If we have a collection, fetch drawer content for connected items
|
|
if col and connections:
|
|
drawer_ids = [c["drawer_id"] for c in connections if c.get("drawer_id")]
|
|
if drawer_ids:
|
|
try:
|
|
results = col.get(ids=drawer_ids, include=["documents", "metadatas"])
|
|
drawer_map = dict(zip(results["ids"], results["documents"]))
|
|
for c in connections:
|
|
did = c.get("drawer_id")
|
|
if did and did in drawer_map:
|
|
c["drawer_preview"] = drawer_map[did][:300]
|
|
except Exception:
|
|
pass
|
|
|
|
return connections
|