Merge pull request #387 from milla-jovovich/ben/security-hardening
security: harden inputs, fix shell injection, optimize DB access
This commit is contained in:
@@ -64,13 +64,20 @@ MEMPAL_DIR=""
|
|||||||
# Read JSON input from stdin
|
# Read JSON input from stdin
|
||||||
INPUT=$(cat)
|
INPUT=$(cat)
|
||||||
|
|
||||||
# Parse fields from Claude Code's JSON
|
# Parse all fields in a single Python call (3x faster than separate invocations)
|
||||||
SESSION_ID=$(echo "$INPUT" | python3 -c "import sys,json; print(json.load(sys.stdin).get('session_id','unknown'))" 2>/dev/null)
|
eval $(echo "$INPUT" | python3 -c "
|
||||||
# Sanitize SESSION_ID to prevent path traversal (only allow alnum, dash, underscore)
|
import sys, json
|
||||||
SESSION_ID=$(echo "$SESSION_ID" | tr -cd 'a-zA-Z0-9_-')
|
data = json.load(sys.stdin)
|
||||||
[ -z "$SESSION_ID" ] && SESSION_ID="unknown"
|
sid = data.get('session_id', 'unknown')
|
||||||
STOP_HOOK_ACTIVE=$(echo "$INPUT" | python3 -c "import sys,json; print(json.load(sys.stdin).get('stop_hook_active', False))" 2>/dev/null)
|
sha = data.get('stop_hook_active', False)
|
||||||
TRANSCRIPT_PATH=$(echo "$INPUT" | python3 -c "import sys,json; print(json.load(sys.stdin).get('transcript_path',''))" 2>/dev/null)
|
tp = data.get('transcript_path', '')
|
||||||
|
# Shell-safe output — only allow alphanumeric, underscore, hyphen, slash, dot, tilde
|
||||||
|
import re
|
||||||
|
safe = lambda s: re.sub(r'[^a-zA-Z0-9_/.\-~]', '', str(s))
|
||||||
|
print(f'SESSION_ID=\"{safe(sid)}\"')
|
||||||
|
print(f'STOP_HOOK_ACTIVE=\"{sha}\"')
|
||||||
|
print(f'TRANSCRIPT_PATH=\"{safe(tp)}\"')
|
||||||
|
" 2>/dev/null)
|
||||||
|
|
||||||
# Expand ~ in path
|
# Expand ~ in path
|
||||||
TRANSCRIPT_PATH="${TRANSCRIPT_PATH/#\~/$HOME}"
|
TRANSCRIPT_PATH="${TRANSCRIPT_PATH/#\~/$HOME}"
|
||||||
@@ -83,6 +90,7 @@ if [ "$STOP_HOOK_ACTIVE" = "True" ] || [ "$STOP_HOOK_ACTIVE" = "true" ]; then
|
|||||||
fi
|
fi
|
||||||
|
|
||||||
# Count human messages in the JSONL transcript
|
# Count human messages in the JSONL transcript
|
||||||
|
# SECURITY: Pass transcript path as sys.argv to avoid shell injection via crafted paths
|
||||||
if [ -f "$TRANSCRIPT_PATH" ]; then
|
if [ -f "$TRANSCRIPT_PATH" ]; then
|
||||||
EXCHANGE_COUNT=$(python3 - "$TRANSCRIPT_PATH" <<'PYEOF'
|
EXCHANGE_COUNT=$(python3 - "$TRANSCRIPT_PATH" <<'PYEOF'
|
||||||
import json, sys
|
import json, sys
|
||||||
@@ -94,7 +102,6 @@ with open(sys.argv[1]) as f:
|
|||||||
msg = entry.get('message', {})
|
msg = entry.get('message', {})
|
||||||
if isinstance(msg, dict) and msg.get('role') == 'user':
|
if isinstance(msg, dict) and msg.get('role') == 'user':
|
||||||
content = msg.get('content', '')
|
content = msg.get('content', '')
|
||||||
# Skip system/command messages — only count real human input
|
|
||||||
if isinstance(content, str) and '<command-message>' in content:
|
if isinstance(content, str) and '<command-message>' in content:
|
||||||
continue
|
continue
|
||||||
count += 1
|
count += 1
|
||||||
|
|||||||
@@ -6,8 +6,58 @@ Priority: env vars > config file (~/.mempalace/config.json) > defaults
|
|||||||
|
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
|
import re
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
|
|
||||||
|
# ── Input validation ──────────────────────────────────────────────────────────
|
||||||
|
# Shared sanitizers for wing/room/entity names. Prevents path traversal,
|
||||||
|
# excessively long strings, and special characters that could cause issues
|
||||||
|
# in file paths, SQLite, or ChromaDB metadata.
|
||||||
|
|
||||||
|
MAX_NAME_LENGTH = 128
|
||||||
|
_SAFE_NAME_RE = re.compile(r"^[a-zA-Z0-9][a-zA-Z0-9_ .'-]{0,126}[a-zA-Z0-9]?$")
|
||||||
|
|
||||||
|
|
||||||
|
def sanitize_name(value: str, field_name: str = "name") -> str:
|
||||||
|
"""Validate and sanitize a wing/room/entity name.
|
||||||
|
|
||||||
|
Raises ValueError if the name is invalid.
|
||||||
|
"""
|
||||||
|
if not isinstance(value, str) or not value.strip():
|
||||||
|
raise ValueError(f"{field_name} must be a non-empty string")
|
||||||
|
|
||||||
|
value = value.strip()
|
||||||
|
|
||||||
|
if len(value) > MAX_NAME_LENGTH:
|
||||||
|
raise ValueError(f"{field_name} exceeds maximum length of {MAX_NAME_LENGTH} characters")
|
||||||
|
|
||||||
|
# Block path traversal
|
||||||
|
if ".." in value or "/" in value or "\\" in value:
|
||||||
|
raise ValueError(f"{field_name} contains invalid path characters")
|
||||||
|
|
||||||
|
# Block null bytes
|
||||||
|
if "\x00" in value:
|
||||||
|
raise ValueError(f"{field_name} contains null bytes")
|
||||||
|
|
||||||
|
# Enforce safe character set
|
||||||
|
if not _SAFE_NAME_RE.match(value):
|
||||||
|
raise ValueError(f"{field_name} contains invalid characters")
|
||||||
|
|
||||||
|
return value
|
||||||
|
|
||||||
|
|
||||||
|
def sanitize_content(value: str, max_length: int = 100_000) -> str:
|
||||||
|
"""Validate drawer/diary content length."""
|
||||||
|
if not isinstance(value, str) or not value.strip():
|
||||||
|
raise ValueError("content must be a non-empty string")
|
||||||
|
if len(value) > max_length:
|
||||||
|
raise ValueError(f"content exceeds maximum length of {max_length} characters")
|
||||||
|
if "\x00" in value:
|
||||||
|
raise ValueError("content contains null bytes")
|
||||||
|
return value
|
||||||
|
|
||||||
|
|
||||||
DEFAULT_PALACE_PATH = os.path.expanduser("~/.mempalace/palace")
|
DEFAULT_PALACE_PATH = os.path.expanduser("~/.mempalace/palace")
|
||||||
DEFAULT_COLLECTION_NAME = "mempalace_drawers"
|
DEFAULT_COLLECTION_NAME = "mempalace_drawers"
|
||||||
|
|
||||||
@@ -126,6 +176,11 @@ class MempalaceConfig:
|
|||||||
def init(self):
|
def init(self):
|
||||||
"""Create config directory and write default config.json if it doesn't exist."""
|
"""Create config directory and write default config.json if it doesn't exist."""
|
||||||
self._config_dir.mkdir(parents=True, exist_ok=True)
|
self._config_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
# Restrict directory permissions to owner only (Unix)
|
||||||
|
try:
|
||||||
|
self._config_dir.chmod(0o700)
|
||||||
|
except (OSError, NotImplementedError):
|
||||||
|
pass # Windows doesn't support Unix permissions
|
||||||
if not self._config_file.exists():
|
if not self._config_file.exists():
|
||||||
default_config = {
|
default_config = {
|
||||||
"palace_path": DEFAULT_PALACE_PATH,
|
"palace_path": DEFAULT_PALACE_PATH,
|
||||||
@@ -135,6 +190,11 @@ class MempalaceConfig:
|
|||||||
}
|
}
|
||||||
with open(self._config_file, "w") as f:
|
with open(self._config_file, "w") as f:
|
||||||
json.dump(default_config, f, indent=2)
|
json.dump(default_config, f, indent=2)
|
||||||
|
# Restrict config file to owner read/write only
|
||||||
|
try:
|
||||||
|
self._config_file.chmod(0o600)
|
||||||
|
except (OSError, NotImplementedError):
|
||||||
|
pass
|
||||||
return self._config_file
|
return self._config_file
|
||||||
|
|
||||||
def save_people_map(self, people_map):
|
def save_people_map(self, people_map):
|
||||||
|
|||||||
+11
-35
@@ -15,9 +15,8 @@ from pathlib import Path
|
|||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
|
|
||||||
import chromadb
|
|
||||||
|
|
||||||
from .normalize import normalize
|
from .normalize import normalize
|
||||||
|
from .palace import SKIP_DIRS, get_collection, file_already_mined
|
||||||
|
|
||||||
|
|
||||||
# File types that might contain conversations
|
# File types that might contain conversations
|
||||||
@@ -28,22 +27,8 @@ CONVO_EXTENSIONS = {
|
|||||||
".jsonl",
|
".jsonl",
|
||||||
}
|
}
|
||||||
|
|
||||||
SKIP_DIRS = {
|
|
||||||
".git",
|
|
||||||
"node_modules",
|
|
||||||
"__pycache__",
|
|
||||||
".venv",
|
|
||||||
"venv",
|
|
||||||
"env",
|
|
||||||
"dist",
|
|
||||||
"build",
|
|
||||||
".next",
|
|
||||||
".mempalace",
|
|
||||||
"tool-results",
|
|
||||||
"memory",
|
|
||||||
}
|
|
||||||
|
|
||||||
MIN_CHUNK_SIZE = 30
|
MIN_CHUNK_SIZE = 30
|
||||||
|
MAX_FILE_SIZE = 10 * 1024 * 1024 # 10 MB — skip files larger than this
|
||||||
|
|
||||||
|
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
@@ -211,23 +196,6 @@ def detect_convo_room(content: str) -> str:
|
|||||||
# =============================================================================
|
# =============================================================================
|
||||||
|
|
||||||
|
|
||||||
def get_collection(palace_path: str):
|
|
||||||
os.makedirs(palace_path, exist_ok=True)
|
|
||||||
client = chromadb.PersistentClient(path=palace_path)
|
|
||||||
try:
|
|
||||||
return client.get_collection("mempalace_drawers")
|
|
||||||
except Exception:
|
|
||||||
return client.create_collection("mempalace_drawers")
|
|
||||||
|
|
||||||
|
|
||||||
def file_already_mined(collection, source_file: str) -> bool:
|
|
||||||
try:
|
|
||||||
results = collection.get(where={"source_file": source_file}, limit=1)
|
|
||||||
return len(results.get("ids", [])) > 0
|
|
||||||
except Exception:
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
# SCAN FOR CONVERSATION FILES
|
# SCAN FOR CONVERSATION FILES
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
@@ -244,6 +212,14 @@ def scan_convos(convo_dir: str) -> list:
|
|||||||
continue
|
continue
|
||||||
filepath = Path(root) / filename
|
filepath = Path(root) / filename
|
||||||
if filepath.suffix.lower() in CONVO_EXTENSIONS:
|
if filepath.suffix.lower() in CONVO_EXTENSIONS:
|
||||||
|
# Skip symlinks and oversized files
|
||||||
|
if filepath.is_symlink():
|
||||||
|
continue
|
||||||
|
try:
|
||||||
|
if filepath.stat().st_size > MAX_FILE_SIZE:
|
||||||
|
continue
|
||||||
|
except OSError:
|
||||||
|
continue
|
||||||
files.append(filepath)
|
files.append(filepath)
|
||||||
return files
|
return files
|
||||||
|
|
||||||
@@ -356,7 +332,7 @@ def mine_convos(
|
|||||||
chunk_room = chunk.get("memory_type", room) if extract_mode == "general" else room
|
chunk_room = chunk.get("memory_type", room) if extract_mode == "general" else room
|
||||||
if extract_mode == "general":
|
if extract_mode == "general":
|
||||||
room_counts[chunk_room] += 1
|
room_counts[chunk_room] += 1
|
||||||
drawer_id = f"drawer_{wing}_{chunk_room}_{hashlib.md5((source_file + str(chunk['chunk_index'])).encode(), usedforsecurity=False).hexdigest()[:16]}"
|
drawer_id = f"drawer_{wing}_{chunk_room}_{hashlib.sha256((source_file + str(chunk['chunk_index'])).encode()).hexdigest()[:24]}"
|
||||||
try:
|
try:
|
||||||
collection.add(
|
collection.add(
|
||||||
documents=[chunk["content"]],
|
documents=[chunk["content"]],
|
||||||
|
|||||||
@@ -50,11 +50,14 @@ class KnowledgeGraph:
|
|||||||
def __init__(self, db_path: str = None):
|
def __init__(self, db_path: str = None):
|
||||||
self.db_path = db_path or DEFAULT_KG_PATH
|
self.db_path = db_path or DEFAULT_KG_PATH
|
||||||
Path(self.db_path).parent.mkdir(parents=True, exist_ok=True)
|
Path(self.db_path).parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
self._connection = None
|
||||||
self._init_db()
|
self._init_db()
|
||||||
|
|
||||||
def _init_db(self):
|
def _init_db(self):
|
||||||
conn = self._conn()
|
conn = self._conn()
|
||||||
conn.executescript("""
|
conn.executescript("""
|
||||||
|
PRAGMA journal_mode=WAL;
|
||||||
|
|
||||||
CREATE TABLE IF NOT EXISTS entities (
|
CREATE TABLE IF NOT EXISTS entities (
|
||||||
id TEXT PRIMARY KEY,
|
id TEXT PRIMARY KEY,
|
||||||
name TEXT NOT NULL,
|
name TEXT NOT NULL,
|
||||||
@@ -84,12 +87,19 @@ class KnowledgeGraph:
|
|||||||
CREATE INDEX IF NOT EXISTS idx_triples_valid ON triples(valid_from, valid_to);
|
CREATE INDEX IF NOT EXISTS idx_triples_valid ON triples(valid_from, valid_to);
|
||||||
""")
|
""")
|
||||||
conn.commit()
|
conn.commit()
|
||||||
conn.close()
|
|
||||||
|
|
||||||
def _conn(self):
|
def _conn(self):
|
||||||
conn = sqlite3.connect(self.db_path, timeout=10)
|
if self._connection is None:
|
||||||
conn.execute("PRAGMA journal_mode=WAL")
|
self._connection = sqlite3.connect(self.db_path, timeout=10, check_same_thread=False)
|
||||||
return conn
|
self._connection.execute("PRAGMA journal_mode=WAL")
|
||||||
|
self._connection.row_factory = sqlite3.Row
|
||||||
|
return self._connection
|
||||||
|
|
||||||
|
def close(self):
|
||||||
|
"""Close the database connection."""
|
||||||
|
if self._connection is not None:
|
||||||
|
self._connection.close()
|
||||||
|
self._connection = None
|
||||||
|
|
||||||
def _entity_id(self, name: str) -> str:
|
def _entity_id(self, name: str) -> str:
|
||||||
return name.lower().replace(" ", "_").replace("'", "")
|
return name.lower().replace(" ", "_").replace("'", "")
|
||||||
@@ -101,12 +111,11 @@ class KnowledgeGraph:
|
|||||||
eid = self._entity_id(name)
|
eid = self._entity_id(name)
|
||||||
props = json.dumps(properties or {})
|
props = json.dumps(properties or {})
|
||||||
conn = self._conn()
|
conn = self._conn()
|
||||||
conn.execute(
|
with conn:
|
||||||
"INSERT OR REPLACE INTO entities (id, name, type, properties) VALUES (?, ?, ?, ?)",
|
conn.execute(
|
||||||
(eid, name, entity_type, props),
|
"INSERT OR REPLACE INTO entities (id, name, type, properties) VALUES (?, ?, ?, ?)",
|
||||||
)
|
(eid, name, entity_type, props),
|
||||||
conn.commit()
|
)
|
||||||
conn.close()
|
|
||||||
return eid
|
return eid
|
||||||
|
|
||||||
def add_triple(
|
def add_triple(
|
||||||
@@ -134,38 +143,38 @@ class KnowledgeGraph:
|
|||||||
|
|
||||||
# Auto-create entities if they don't exist
|
# Auto-create entities if they don't exist
|
||||||
conn = self._conn()
|
conn = self._conn()
|
||||||
conn.execute("INSERT OR IGNORE INTO entities (id, name) VALUES (?, ?)", (sub_id, subject))
|
with conn:
|
||||||
conn.execute("INSERT OR IGNORE INTO entities (id, name) VALUES (?, ?)", (obj_id, obj))
|
conn.execute(
|
||||||
|
"INSERT OR IGNORE INTO entities (id, name) VALUES (?, ?)", (sub_id, subject)
|
||||||
|
)
|
||||||
|
conn.execute("INSERT OR IGNORE INTO entities (id, name) VALUES (?, ?)", (obj_id, obj))
|
||||||
|
|
||||||
# Check for existing identical triple
|
# Check for existing identical triple
|
||||||
existing = conn.execute(
|
existing = conn.execute(
|
||||||
"SELECT id FROM triples WHERE subject=? AND predicate=? AND object=? AND valid_to IS NULL",
|
"SELECT id FROM triples WHERE subject=? AND predicate=? AND object=? AND valid_to IS NULL",
|
||||||
(sub_id, pred, obj_id),
|
(sub_id, pred, obj_id),
|
||||||
).fetchone()
|
).fetchone()
|
||||||
|
|
||||||
if existing:
|
if existing:
|
||||||
conn.close()
|
return existing["id"] # Already exists and still valid
|
||||||
return existing[0] # Already exists and still valid
|
|
||||||
|
|
||||||
triple_id = f"t_{sub_id}_{pred}_{obj_id}_{hashlib.md5(f'{valid_from}{datetime.now().isoformat()}'.encode()).hexdigest()[:8]}"
|
triple_id = f"t_{sub_id}_{pred}_{obj_id}_{hashlib.sha256(f'{valid_from}{datetime.now().isoformat()}'.encode()).hexdigest()[:12]}"
|
||||||
|
|
||||||
conn.execute(
|
conn.execute(
|
||||||
"""INSERT INTO triples (id, subject, predicate, object, valid_from, valid_to, confidence, source_closet, source_file)
|
"""INSERT INTO triples (id, subject, predicate, object, valid_from, valid_to, confidence, source_closet, source_file)
|
||||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)""",
|
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)""",
|
||||||
(
|
(
|
||||||
triple_id,
|
triple_id,
|
||||||
sub_id,
|
sub_id,
|
||||||
pred,
|
pred,
|
||||||
obj_id,
|
obj_id,
|
||||||
valid_from,
|
valid_from,
|
||||||
valid_to,
|
valid_to,
|
||||||
confidence,
|
confidence,
|
||||||
source_closet,
|
source_closet,
|
||||||
source_file,
|
source_file,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
conn.commit()
|
|
||||||
conn.close()
|
|
||||||
return triple_id
|
return triple_id
|
||||||
|
|
||||||
def invalidate(self, subject: str, predicate: str, obj: str, ended: str = None):
|
def invalidate(self, subject: str, predicate: str, obj: str, ended: str = None):
|
||||||
@@ -176,12 +185,11 @@ class KnowledgeGraph:
|
|||||||
ended = ended or date.today().isoformat()
|
ended = ended or date.today().isoformat()
|
||||||
|
|
||||||
conn = self._conn()
|
conn = self._conn()
|
||||||
conn.execute(
|
with conn:
|
||||||
"UPDATE triples SET valid_to=? WHERE subject=? AND predicate=? AND object=? AND valid_to IS NULL",
|
conn.execute(
|
||||||
(ended, sub_id, pred, obj_id),
|
"UPDATE triples SET valid_to=? WHERE subject=? AND predicate=? AND object=? AND valid_to IS NULL",
|
||||||
)
|
(ended, sub_id, pred, obj_id),
|
||||||
conn.commit()
|
)
|
||||||
conn.close()
|
|
||||||
|
|
||||||
# ── Query operations ──────────────────────────────────────────────────
|
# ── Query operations ──────────────────────────────────────────────────
|
||||||
|
|
||||||
@@ -208,13 +216,13 @@ class KnowledgeGraph:
|
|||||||
{
|
{
|
||||||
"direction": "outgoing",
|
"direction": "outgoing",
|
||||||
"subject": name,
|
"subject": name,
|
||||||
"predicate": row[2],
|
"predicate": row["predicate"],
|
||||||
"object": row[10], # obj_name
|
"object": row["obj_name"],
|
||||||
"valid_from": row[4],
|
"valid_from": row["valid_from"],
|
||||||
"valid_to": row[5],
|
"valid_to": row["valid_to"],
|
||||||
"confidence": row[6],
|
"confidence": row["confidence"],
|
||||||
"source_closet": row[7],
|
"source_closet": row["source_closet"],
|
||||||
"current": row[5] is None,
|
"current": row["valid_to"] is None,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -228,18 +236,17 @@ class KnowledgeGraph:
|
|||||||
results.append(
|
results.append(
|
||||||
{
|
{
|
||||||
"direction": "incoming",
|
"direction": "incoming",
|
||||||
"subject": row[10], # sub_name
|
"subject": row["sub_name"],
|
||||||
"predicate": row[2],
|
"predicate": row["predicate"],
|
||||||
"object": name,
|
"object": name,
|
||||||
"valid_from": row[4],
|
"valid_from": row["valid_from"],
|
||||||
"valid_to": row[5],
|
"valid_to": row["valid_to"],
|
||||||
"confidence": row[6],
|
"confidence": row["confidence"],
|
||||||
"source_closet": row[7],
|
"source_closet": row["source_closet"],
|
||||||
"current": row[5] is None,
|
"current": row["valid_to"] is None,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
conn.close()
|
|
||||||
return results
|
return results
|
||||||
|
|
||||||
def query_relationship(self, predicate: str, as_of: str = None):
|
def query_relationship(self, predicate: str, as_of: str = None):
|
||||||
@@ -262,15 +269,14 @@ class KnowledgeGraph:
|
|||||||
for row in conn.execute(query, params).fetchall():
|
for row in conn.execute(query, params).fetchall():
|
||||||
results.append(
|
results.append(
|
||||||
{
|
{
|
||||||
"subject": row[10],
|
"subject": row["sub_name"],
|
||||||
"predicate": pred,
|
"predicate": pred,
|
||||||
"object": row[11],
|
"object": row["obj_name"],
|
||||||
"valid_from": row[4],
|
"valid_from": row["valid_from"],
|
||||||
"valid_to": row[5],
|
"valid_to": row["valid_to"],
|
||||||
"current": row[5] is None,
|
"current": row["valid_to"] is None,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
conn.close()
|
|
||||||
return results
|
return results
|
||||||
|
|
||||||
def timeline(self, entity_name: str = None):
|
def timeline(self, entity_name: str = None):
|
||||||
@@ -300,15 +306,14 @@ class KnowledgeGraph:
|
|||||||
LIMIT 100
|
LIMIT 100
|
||||||
""").fetchall()
|
""").fetchall()
|
||||||
|
|
||||||
conn.close()
|
|
||||||
return [
|
return [
|
||||||
{
|
{
|
||||||
"subject": r[10],
|
"subject": r["sub_name"],
|
||||||
"predicate": r[2],
|
"predicate": r["predicate"],
|
||||||
"object": r[11],
|
"object": r["obj_name"],
|
||||||
"valid_from": r[4],
|
"valid_from": r["valid_from"],
|
||||||
"valid_to": r[5],
|
"valid_to": r["valid_to"],
|
||||||
"current": r[5] is None,
|
"current": r["valid_to"] is None,
|
||||||
}
|
}
|
||||||
for r in rows
|
for r in rows
|
||||||
]
|
]
|
||||||
@@ -317,17 +322,18 @@ class KnowledgeGraph:
|
|||||||
|
|
||||||
def stats(self):
|
def stats(self):
|
||||||
conn = self._conn()
|
conn = self._conn()
|
||||||
entities = conn.execute("SELECT COUNT(*) FROM entities").fetchone()[0]
|
entities = conn.execute("SELECT COUNT(*) as cnt FROM entities").fetchone()["cnt"]
|
||||||
triples = conn.execute("SELECT COUNT(*) FROM triples").fetchone()[0]
|
triples = conn.execute("SELECT COUNT(*) as cnt FROM triples").fetchone()["cnt"]
|
||||||
current = conn.execute("SELECT COUNT(*) FROM triples WHERE valid_to IS NULL").fetchone()[0]
|
current = conn.execute(
|
||||||
|
"SELECT COUNT(*) as cnt FROM triples WHERE valid_to IS NULL"
|
||||||
|
).fetchone()["cnt"]
|
||||||
expired = triples - current
|
expired = triples - current
|
||||||
predicates = [
|
predicates = [
|
||||||
r[0]
|
r["predicate"]
|
||||||
for r in conn.execute(
|
for r in conn.execute(
|
||||||
"SELECT DISTINCT predicate FROM triples ORDER BY predicate"
|
"SELECT DISTINCT predicate FROM triples ORDER BY predicate"
|
||||||
).fetchall()
|
).fetchall()
|
||||||
]
|
]
|
||||||
conn.close()
|
|
||||||
return {
|
return {
|
||||||
"entities": entities,
|
"entities": entities,
|
||||||
"triples": triples,
|
"triples": triples,
|
||||||
|
|||||||
+126
-8
@@ -24,8 +24,9 @@ import json
|
|||||||
import logging
|
import logging
|
||||||
import hashlib
|
import hashlib
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
from .config import MempalaceConfig
|
from .config import MempalaceConfig, sanitize_name, sanitize_content
|
||||||
from .version import __version__
|
from .version import __version__
|
||||||
from .searcher import search_memories
|
from .searcher import search_memories
|
||||||
from .palace_graph import traverse, find_tunnels, graph_stats
|
from .palace_graph import traverse, find_tunnels, graph_stats
|
||||||
@@ -66,16 +67,60 @@ _client_cache = None
|
|||||||
_collection_cache = None
|
_collection_cache = None
|
||||||
|
|
||||||
|
|
||||||
|
# ==================== WRITE-AHEAD LOG ====================
|
||||||
|
# Every write operation is logged to a JSONL file before execution.
|
||||||
|
# This provides an audit trail for detecting memory poisoning and
|
||||||
|
# enables review/rollback of writes from external or untrusted sources.
|
||||||
|
|
||||||
|
_WAL_DIR = Path(os.path.expanduser("~/.mempalace/wal"))
|
||||||
|
_WAL_DIR.mkdir(parents=True, exist_ok=True)
|
||||||
|
try:
|
||||||
|
_WAL_DIR.chmod(0o700)
|
||||||
|
except (OSError, NotImplementedError):
|
||||||
|
pass
|
||||||
|
_WAL_FILE = _WAL_DIR / "write_log.jsonl"
|
||||||
|
|
||||||
|
|
||||||
|
def _wal_log(operation: str, params: dict, result: dict = None):
|
||||||
|
"""Append a write operation to the write-ahead log."""
|
||||||
|
entry = {
|
||||||
|
"timestamp": datetime.now().isoformat(),
|
||||||
|
"operation": operation,
|
||||||
|
"params": params,
|
||||||
|
"result": result,
|
||||||
|
}
|
||||||
|
try:
|
||||||
|
with open(_WAL_FILE, "a", encoding="utf-8") as f:
|
||||||
|
f.write(json.dumps(entry, default=str) + "\n")
|
||||||
|
try:
|
||||||
|
_WAL_FILE.chmod(0o600)
|
||||||
|
except (OSError, NotImplementedError):
|
||||||
|
pass
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"WAL write failed: {e}")
|
||||||
|
|
||||||
|
|
||||||
|
_client_cache = None
|
||||||
|
_collection_cache = None
|
||||||
|
|
||||||
|
|
||||||
|
def _get_client():
|
||||||
|
"""Return a singleton ChromaDB PersistentClient."""
|
||||||
|
global _client_cache
|
||||||
|
if _client_cache is None:
|
||||||
|
_client_cache = chromadb.PersistentClient(path=_config.palace_path)
|
||||||
|
return _client_cache
|
||||||
|
|
||||||
|
|
||||||
def _get_collection(create=False):
|
def _get_collection(create=False):
|
||||||
"""Return the ChromaDB collection, caching the client between calls."""
|
"""Return the ChromaDB collection, caching the client between calls."""
|
||||||
global _client_cache, _collection_cache
|
global _collection_cache
|
||||||
try:
|
try:
|
||||||
if _client_cache is None:
|
client = _get_client()
|
||||||
_client_cache = chromadb.PersistentClient(path=_config.palace_path)
|
|
||||||
if create:
|
if create:
|
||||||
_collection_cache = _client_cache.get_or_create_collection(_config.collection_name)
|
_collection_cache = client.get_or_create_collection(_config.collection_name)
|
||||||
elif _collection_cache is None:
|
elif _collection_cache is None:
|
||||||
_collection_cache = _client_cache.get_collection(_config.collection_name)
|
_collection_cache = client.get_collection(_config.collection_name)
|
||||||
return _collection_cache
|
return _collection_cache
|
||||||
except Exception:
|
except Exception:
|
||||||
return None
|
return None
|
||||||
@@ -282,11 +327,30 @@ def tool_add_drawer(
|
|||||||
wing: str, room: str, content: str, source_file: str = None, added_by: str = "mcp"
|
wing: str, room: str, content: str, source_file: str = None, added_by: str = "mcp"
|
||||||
):
|
):
|
||||||
"""File verbatim content into a wing/room. Checks for duplicates first."""
|
"""File verbatim content into a wing/room. Checks for duplicates first."""
|
||||||
|
try:
|
||||||
|
wing = sanitize_name(wing, "wing")
|
||||||
|
room = sanitize_name(room, "room")
|
||||||
|
content = sanitize_content(content)
|
||||||
|
except ValueError as e:
|
||||||
|
return {"success": False, "error": str(e)}
|
||||||
|
|
||||||
col = _get_collection(create=True)
|
col = _get_collection(create=True)
|
||||||
if not col:
|
if not col:
|
||||||
return _no_palace()
|
return _no_palace()
|
||||||
|
|
||||||
drawer_id = f"drawer_{wing}_{room}_{hashlib.md5(content.encode()).hexdigest()[:16]}"
|
drawer_id = f"drawer_{wing}_{room}_{hashlib.sha256((wing + room + content[:100]).encode()).hexdigest()[:24]}"
|
||||||
|
|
||||||
|
_wal_log(
|
||||||
|
"add_drawer",
|
||||||
|
{
|
||||||
|
"drawer_id": drawer_id,
|
||||||
|
"wing": wing,
|
||||||
|
"room": room,
|
||||||
|
"added_by": added_by,
|
||||||
|
"content_length": len(content),
|
||||||
|
"content_preview": content[:200],
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
# Idempotency: if the deterministic ID already exists, return success as a no-op.
|
# Idempotency: if the deterministic ID already exists, return success as a no-op.
|
||||||
try:
|
try:
|
||||||
@@ -325,6 +389,19 @@ def tool_delete_drawer(drawer_id: str):
|
|||||||
existing = col.get(ids=[drawer_id])
|
existing = col.get(ids=[drawer_id])
|
||||||
if not existing["ids"]:
|
if not existing["ids"]:
|
||||||
return {"success": False, "error": f"Drawer not found: {drawer_id}"}
|
return {"success": False, "error": f"Drawer not found: {drawer_id}"}
|
||||||
|
|
||||||
|
# Log the deletion with the content being removed for audit trail
|
||||||
|
deleted_content = existing.get("documents", [""])[0] if existing.get("documents") else ""
|
||||||
|
deleted_meta = existing.get("metadatas", [{}])[0] if existing.get("metadatas") else {}
|
||||||
|
_wal_log(
|
||||||
|
"delete_drawer",
|
||||||
|
{
|
||||||
|
"drawer_id": drawer_id,
|
||||||
|
"deleted_meta": deleted_meta,
|
||||||
|
"content_preview": deleted_content[:200],
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
col.delete(ids=[drawer_id])
|
col.delete(ids=[drawer_id])
|
||||||
logger.info(f"Deleted drawer: {drawer_id}")
|
logger.info(f"Deleted drawer: {drawer_id}")
|
||||||
@@ -346,6 +423,23 @@ def tool_kg_add(
|
|||||||
subject: str, predicate: str, object: str, valid_from: str = None, source_closet: str = None
|
subject: str, predicate: str, object: str, valid_from: str = None, source_closet: str = None
|
||||||
):
|
):
|
||||||
"""Add a relationship to the knowledge graph."""
|
"""Add a relationship to the knowledge graph."""
|
||||||
|
try:
|
||||||
|
subject = sanitize_name(subject, "subject")
|
||||||
|
predicate = sanitize_name(predicate, "predicate")
|
||||||
|
object = sanitize_name(object, "object")
|
||||||
|
except ValueError as e:
|
||||||
|
return {"success": False, "error": str(e)}
|
||||||
|
|
||||||
|
_wal_log(
|
||||||
|
"kg_add",
|
||||||
|
{
|
||||||
|
"subject": subject,
|
||||||
|
"predicate": predicate,
|
||||||
|
"object": object,
|
||||||
|
"valid_from": valid_from,
|
||||||
|
"source_closet": source_closet,
|
||||||
|
},
|
||||||
|
)
|
||||||
triple_id = _kg.add_triple(
|
triple_id = _kg.add_triple(
|
||||||
subject, predicate, object, valid_from=valid_from, source_closet=source_closet
|
subject, predicate, object, valid_from=valid_from, source_closet=source_closet
|
||||||
)
|
)
|
||||||
@@ -354,6 +448,10 @@ def tool_kg_add(
|
|||||||
|
|
||||||
def tool_kg_invalidate(subject: str, predicate: str, object: str, ended: str = None):
|
def tool_kg_invalidate(subject: str, predicate: str, object: str, ended: str = None):
|
||||||
"""Mark a fact as no longer true (set end date)."""
|
"""Mark a fact as no longer true (set end date)."""
|
||||||
|
_wal_log(
|
||||||
|
"kg_invalidate",
|
||||||
|
{"subject": subject, "predicate": predicate, "object": object, "ended": ended},
|
||||||
|
)
|
||||||
_kg.invalidate(subject, predicate, object, ended=ended)
|
_kg.invalidate(subject, predicate, object, ended=ended)
|
||||||
return {
|
return {
|
||||||
"success": True,
|
"success": True,
|
||||||
@@ -384,6 +482,12 @@ def tool_diary_write(agent_name: str, entry: str, topic: str = "general"):
|
|||||||
This is the agent's personal journal — observations, thoughts,
|
This is the agent's personal journal — observations, thoughts,
|
||||||
what it worked on, what it noticed, what it thinks matters.
|
what it worked on, what it noticed, what it thinks matters.
|
||||||
"""
|
"""
|
||||||
|
try:
|
||||||
|
agent_name = sanitize_name(agent_name, "agent_name")
|
||||||
|
entry = sanitize_content(entry)
|
||||||
|
except ValueError as e:
|
||||||
|
return {"success": False, "error": str(e)}
|
||||||
|
|
||||||
wing = f"wing_{agent_name.lower().replace(' ', '_')}"
|
wing = f"wing_{agent_name.lower().replace(' ', '_')}"
|
||||||
room = "diary"
|
room = "diary"
|
||||||
col = _get_collection(create=True)
|
col = _get_collection(create=True)
|
||||||
@@ -391,9 +495,23 @@ def tool_diary_write(agent_name: str, entry: str, topic: str = "general"):
|
|||||||
return _no_palace()
|
return _no_palace()
|
||||||
|
|
||||||
now = datetime.now()
|
now = datetime.now()
|
||||||
entry_id = f"diary_{wing}_{now.strftime('%Y%m%d_%H%M%S')}_{hashlib.md5(entry[:50].encode()).hexdigest()[:8]}"
|
entry_id = f"diary_{wing}_{now.strftime('%Y%m%d_%H%M%S')}_{hashlib.sha256(entry[:50].encode()).hexdigest()[:12]}"
|
||||||
|
|
||||||
|
_wal_log(
|
||||||
|
"diary_write",
|
||||||
|
{
|
||||||
|
"agent_name": agent_name,
|
||||||
|
"topic": topic,
|
||||||
|
"entry_id": entry_id,
|
||||||
|
"entry_preview": entry[:200],
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
# TODO: Future versions should expand AAAK before embedding to improve
|
||||||
|
# semantic search quality. For now, store raw AAAK in metadata so it's
|
||||||
|
# preserved, and keep the document as-is for embedding (even though
|
||||||
|
# compressed AAAK degrades embedding quality).
|
||||||
col.add(
|
col.add(
|
||||||
ids=[entry_id],
|
ids=[entry_id],
|
||||||
documents=[entry],
|
documents=[entry],
|
||||||
|
|||||||
+14
-58
@@ -17,6 +17,8 @@ from collections import defaultdict
|
|||||||
|
|
||||||
import chromadb
|
import chromadb
|
||||||
|
|
||||||
|
from .palace import SKIP_DIRS, get_collection, file_already_mined
|
||||||
|
|
||||||
READABLE_EXTENSIONS = {
|
READABLE_EXTENSIONS = {
|
||||||
".txt",
|
".txt",
|
||||||
".md",
|
".md",
|
||||||
@@ -40,32 +42,6 @@ READABLE_EXTENSIONS = {
|
|||||||
".toml",
|
".toml",
|
||||||
}
|
}
|
||||||
|
|
||||||
SKIP_DIRS = {
|
|
||||||
".git",
|
|
||||||
"node_modules",
|
|
||||||
"__pycache__",
|
|
||||||
".venv",
|
|
||||||
"venv",
|
|
||||||
"env",
|
|
||||||
"dist",
|
|
||||||
"build",
|
|
||||||
".next",
|
|
||||||
"coverage",
|
|
||||||
".mempalace",
|
|
||||||
".ruff_cache",
|
|
||||||
".mypy_cache",
|
|
||||||
".pytest_cache",
|
|
||||||
".cache",
|
|
||||||
".tox",
|
|
||||||
".nox",
|
|
||||||
".idea",
|
|
||||||
".vscode",
|
|
||||||
".ipynb_checkpoints",
|
|
||||||
".eggs",
|
|
||||||
"htmlcov",
|
|
||||||
"target",
|
|
||||||
}
|
|
||||||
|
|
||||||
SKIP_FILENAMES = {
|
SKIP_FILENAMES = {
|
||||||
"mempalace.yaml",
|
"mempalace.yaml",
|
||||||
"mempalace.yml",
|
"mempalace.yml",
|
||||||
@@ -78,6 +54,7 @@ SKIP_FILENAMES = {
|
|||||||
CHUNK_SIZE = 800 # chars per drawer
|
CHUNK_SIZE = 800 # chars per drawer
|
||||||
CHUNK_OVERLAP = 100 # overlap between chunks
|
CHUNK_OVERLAP = 100 # overlap between chunks
|
||||||
MIN_CHUNK_SIZE = 50 # skip tiny chunks
|
MIN_CHUNK_SIZE = 50 # skip tiny chunks
|
||||||
|
MAX_FILE_SIZE = 10 * 1024 * 1024 # 10 MB — skip files larger than this
|
||||||
|
|
||||||
|
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
@@ -393,41 +370,11 @@ def chunk_text(content: str, source_file: str) -> list:
|
|||||||
# =============================================================================
|
# =============================================================================
|
||||||
|
|
||||||
|
|
||||||
def get_collection(palace_path: str):
|
|
||||||
os.makedirs(palace_path, exist_ok=True)
|
|
||||||
client = chromadb.PersistentClient(path=palace_path)
|
|
||||||
try:
|
|
||||||
return client.get_collection("mempalace_drawers")
|
|
||||||
except Exception:
|
|
||||||
return client.create_collection("mempalace_drawers")
|
|
||||||
|
|
||||||
|
|
||||||
def file_already_mined(collection, source_file: str) -> bool:
|
|
||||||
"""Fast check: has this file been filed before and is unchanged?
|
|
||||||
|
|
||||||
Compares the stored mtime in drawer metadata against the file's current
|
|
||||||
mtime. Returns False (needs re-mining) when the file has been modified
|
|
||||||
since it was last mined, or when no mtime was stored.
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
results = collection.get(where={"source_file": source_file}, limit=1)
|
|
||||||
if not results.get("ids"):
|
|
||||||
return False
|
|
||||||
stored_meta = results["metadatas"][0] if results.get("metadatas") else {}
|
|
||||||
stored_mtime = stored_meta.get("source_mtime")
|
|
||||||
if stored_mtime is None:
|
|
||||||
return False
|
|
||||||
current_mtime = os.path.getmtime(source_file)
|
|
||||||
return float(stored_mtime) == current_mtime
|
|
||||||
except Exception:
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
def add_drawer(
|
def add_drawer(
|
||||||
collection, wing: str, room: str, content: str, source_file: str, chunk_index: int, agent: str
|
collection, wing: str, room: str, content: str, source_file: str, chunk_index: int, agent: str
|
||||||
):
|
):
|
||||||
"""Add one drawer to the palace."""
|
"""Add one drawer to the palace."""
|
||||||
drawer_id = f"drawer_{wing}_{room}_{hashlib.md5((source_file + str(chunk_index)).encode(), usedforsecurity=False).hexdigest()[:16]}"
|
drawer_id = f"drawer_{wing}_{room}_{hashlib.sha256((source_file + str(chunk_index)).encode()).hexdigest()[:24]}"
|
||||||
try:
|
try:
|
||||||
metadata = {
|
metadata = {
|
||||||
"wing": wing,
|
"wing": wing,
|
||||||
@@ -470,7 +417,7 @@ def process_file(
|
|||||||
|
|
||||||
# Skip if already filed
|
# Skip if already filed
|
||||||
source_file = str(filepath)
|
source_file = str(filepath)
|
||||||
if not dry_run and file_already_mined(collection, source_file):
|
if not dry_run and file_already_mined(collection, source_file, check_mtime=True):
|
||||||
return 0, None
|
return 0, None
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -562,6 +509,15 @@ def scan_project(
|
|||||||
if respect_gitignore and active_matchers and not force_include:
|
if respect_gitignore and active_matchers and not force_include:
|
||||||
if is_gitignored(filepath, active_matchers, is_dir=False):
|
if is_gitignored(filepath, active_matchers, is_dir=False):
|
||||||
continue
|
continue
|
||||||
|
# Skip symlinks — prevents following links to /dev/urandom, etc.
|
||||||
|
if filepath.is_symlink():
|
||||||
|
continue
|
||||||
|
# Skip files exceeding size limit
|
||||||
|
try:
|
||||||
|
if filepath.stat().st_size > MAX_FILE_SIZE:
|
||||||
|
continue
|
||||||
|
except OSError:
|
||||||
|
continue
|
||||||
files.append(filepath)
|
files.append(filepath)
|
||||||
return files
|
return files
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,71 @@
|
|||||||
|
"""
|
||||||
|
palace.py — Shared palace operations.
|
||||||
|
|
||||||
|
Consolidates ChromaDB access patterns used by both miners and the MCP server.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import chromadb
|
||||||
|
|
||||||
|
SKIP_DIRS = {
|
||||||
|
".git",
|
||||||
|
"node_modules",
|
||||||
|
"__pycache__",
|
||||||
|
".venv",
|
||||||
|
"venv",
|
||||||
|
"env",
|
||||||
|
"dist",
|
||||||
|
"build",
|
||||||
|
".next",
|
||||||
|
"coverage",
|
||||||
|
".mempalace",
|
||||||
|
".ruff_cache",
|
||||||
|
".mypy_cache",
|
||||||
|
".pytest_cache",
|
||||||
|
".cache",
|
||||||
|
".tox",
|
||||||
|
".nox",
|
||||||
|
".idea",
|
||||||
|
".vscode",
|
||||||
|
".ipynb_checkpoints",
|
||||||
|
".eggs",
|
||||||
|
"htmlcov",
|
||||||
|
"target",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def get_collection(palace_path: str, collection_name: str = "mempalace_drawers"):
|
||||||
|
"""Get or create the palace ChromaDB collection."""
|
||||||
|
os.makedirs(palace_path, exist_ok=True)
|
||||||
|
try:
|
||||||
|
os.chmod(palace_path, 0o700)
|
||||||
|
except (OSError, NotImplementedError):
|
||||||
|
pass
|
||||||
|
client = chromadb.PersistentClient(path=palace_path)
|
||||||
|
try:
|
||||||
|
return client.get_collection(collection_name)
|
||||||
|
except Exception:
|
||||||
|
return client.create_collection(collection_name)
|
||||||
|
|
||||||
|
|
||||||
|
def file_already_mined(collection, source_file: str, check_mtime: bool = False) -> bool:
|
||||||
|
"""Check if a file has already been filed in the palace.
|
||||||
|
|
||||||
|
When check_mtime=True (used by project miner), returns False if the file
|
||||||
|
has been modified since it was last mined, so it gets re-mined.
|
||||||
|
When check_mtime=False (used by convo miner), just checks existence.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
results = collection.get(where={"source_file": source_file}, limit=1)
|
||||||
|
if not results.get("ids"):
|
||||||
|
return False
|
||||||
|
if check_mtime:
|
||||||
|
stored_meta = results.get("metadatas", [{}])[0]
|
||||||
|
stored_mtime = stored_meta.get("source_mtime")
|
||||||
|
if stored_mtime is None:
|
||||||
|
return False
|
||||||
|
current_mtime = os.path.getmtime(source_file)
|
||||||
|
return float(stored_mtime) == current_mtime
|
||||||
|
return True
|
||||||
|
except Exception:
|
||||||
|
return False
|
||||||
+1
-1
@@ -26,7 +26,7 @@ classifiers = [
|
|||||||
]
|
]
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"chromadb>=0.5.0,<0.7",
|
"chromadb>=0.5.0,<0.7",
|
||||||
"pyyaml>=6.0",
|
"pyyaml>=6.0,<7",
|
||||||
]
|
]
|
||||||
|
|
||||||
[project.urls]
|
[project.urls]
|
||||||
|
|||||||
@@ -1,12 +1,14 @@
|
|||||||
import os
|
import os
|
||||||
import shutil
|
import shutil
|
||||||
import tempfile
|
import tempfile
|
||||||
|
import time
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import chromadb
|
import chromadb
|
||||||
import yaml
|
import yaml
|
||||||
|
|
||||||
from mempalace.miner import mine, scan_project
|
from mempalace.miner import mine, scan_project
|
||||||
|
from mempalace.palace import file_already_mined
|
||||||
|
|
||||||
|
|
||||||
def write_file(path: Path, content: str):
|
def write_file(path: Path, content: str):
|
||||||
@@ -206,3 +208,54 @@ def test_scan_project_skip_dirs_still_apply_without_override():
|
|||||||
assert scanned_files(project_root, respect_gitignore=False) == ["main.py"]
|
assert scanned_files(project_root, respect_gitignore=False) == ["main.py"]
|
||||||
finally:
|
finally:
|
||||||
shutil.rmtree(tmpdir)
|
shutil.rmtree(tmpdir)
|
||||||
|
|
||||||
|
|
||||||
|
def test_file_already_mined_check_mtime():
|
||||||
|
tmpdir = tempfile.mkdtemp()
|
||||||
|
try:
|
||||||
|
palace_path = os.path.join(tmpdir, "palace")
|
||||||
|
os.makedirs(palace_path)
|
||||||
|
client = chromadb.PersistentClient(path=palace_path)
|
||||||
|
col = client.get_or_create_collection("mempalace_drawers")
|
||||||
|
|
||||||
|
test_file = os.path.join(tmpdir, "test.txt")
|
||||||
|
with open(test_file, "w") as f:
|
||||||
|
f.write("hello world")
|
||||||
|
|
||||||
|
mtime = os.path.getmtime(test_file)
|
||||||
|
|
||||||
|
# Not mined yet
|
||||||
|
assert file_already_mined(col, test_file) is False
|
||||||
|
assert file_already_mined(col, test_file, check_mtime=True) is False
|
||||||
|
|
||||||
|
# Add it with mtime
|
||||||
|
col.add(
|
||||||
|
ids=["d1"],
|
||||||
|
documents=["hello world"],
|
||||||
|
metadatas=[{"source_file": test_file, "source_mtime": str(mtime)}],
|
||||||
|
)
|
||||||
|
|
||||||
|
# Already mined (no mtime check)
|
||||||
|
assert file_already_mined(col, test_file) is True
|
||||||
|
# Already mined (mtime matches)
|
||||||
|
assert file_already_mined(col, test_file, check_mtime=True) is True
|
||||||
|
|
||||||
|
# Modify file so mtime changes
|
||||||
|
time.sleep(0.1)
|
||||||
|
with open(test_file, "w") as f:
|
||||||
|
f.write("modified content")
|
||||||
|
|
||||||
|
# Still mined without mtime check
|
||||||
|
assert file_already_mined(col, test_file) is True
|
||||||
|
# Needs re-mining with mtime check
|
||||||
|
assert file_already_mined(col, test_file, check_mtime=True) is False
|
||||||
|
|
||||||
|
# Record with no mtime stored should return False for check_mtime
|
||||||
|
col.add(
|
||||||
|
ids=["d2"],
|
||||||
|
documents=["other"],
|
||||||
|
metadatas=[{"source_file": "/fake/no_mtime.txt"}],
|
||||||
|
)
|
||||||
|
assert file_already_mined(col, "/fake/no_mtime.txt", check_mtime=True) is False
|
||||||
|
finally:
|
||||||
|
shutil.rmtree(tmpdir)
|
||||||
|
|||||||
Reference in New Issue
Block a user