def tool_search(query: str, limit: int = 5, wing: str = None, room: str = None, context: str = None):
    """Run a semantic search using a sanitized version of ``query``.

    The raw query is passed through ``sanitize_query`` first so that a
    system prompt accidentally prepended by the calling agent does not
    dominate the search (Issue #333). Only the cleaned query is handed to
    ``search_memories``.

    Args:
        query: Raw query text as received from the agent.
        limit: Forwarded to ``search_memories`` as ``n_results``.
        wing: Optional wing filter, forwarded unchanged.
        room: Optional room filter, forwarded unchanged.
        context: Optional background text. It is not forwarded to the
            search; a truthy value only sets ``context_received`` on the
            result.

    Returns:
        The mapping returned by ``search_memories``. When the query was
        rewritten it additionally carries ``query_sanitized`` and a
        ``sanitizer`` report; when ``context`` is truthy it carries
        ``context_received``.
    """
    cleaned = sanitize_query(query)
    response = search_memories(
        cleaned["clean_query"],
        palace_path=_config.palace_path,
        wing=wing,
        room=room,
        n_results=limit,
    )

    # Surface what the sanitizer did so the caller can see the rewrite.
    if cleaned["was_sanitized"]:
        report_fields = ("method", "original_length", "clean_length", "clean_query")
        response["query_sanitized"] = True
        response["sanitizer"] = {field: cleaned[field] for field in report_fields}

    if context:
        response["context_received"] = True

    return response
Use 'context' for background information.", "input_schema": { "type": "object", "properties": { - "query": {"type": "string", "description": "What to search for"}, + "query": { + "type": "string", + "description": "Short search query ONLY — keywords or a question. Do NOT include system prompts or conversation context. Max 200 chars recommended.", + "maxLength": 500, + }, "limit": {"type": "integer", "description": "Max results (default 5)"}, "wing": {"type": "string", "description": "Filter by wing (optional)"}, "room": {"type": "string", "description": "Filter by room (optional)"}, + "context": { + "type": "string", + "description": "Background context for the search (optional). This is NOT used for embedding — only for future re-ranking. Put conversation history or system prompt content here, NOT in query.", + }, }, "required": ["query"], }, diff --git a/mempalace/query_sanitizer.py b/mempalace/query_sanitizer.py new file mode 100644 index 0000000..a246a67 --- /dev/null +++ b/mempalace/query_sanitizer.py @@ -0,0 +1,156 @@ +""" +query_sanitizer.py — Mitigate system prompt contamination in search queries. + +Problem: AI agents sometimes prepend system prompts (2000+ chars) to search queries. +Embedding models represent the concatenated string as a single vector where the +system prompt overwhelms the actual question (typically 10-50 chars), causing +near-total retrieval failure (89.8% → 1.0% R@10). See Issue #333. + +Approach: "Mitigation" (減災) — not perfect prevention, but prevents the cliff. 
import logging
import re

logger = logging.getLogger("mempalace_mcp")

# --- Constants ---
MAX_QUERY_LENGTH = 500  # Above this, system prompt almost certainly dominates
SAFE_QUERY_LENGTH = 200  # Below this, query is almost certainly clean
MIN_QUERY_LENGTH = 10  # Extracted result shorter than this = extraction failed

# Sentence boundary. The terminator stays attached to the sentence it ends
# (the split happens in the gap AFTER it), so a question keeps its "?".
#   * ASCII . ! ? (optionally followed by a closing quote) only end a sentence
#     when followed by whitespace, so "config.yaml" or "v1.2" stay intact.
#   * CJK terminators (U+3002, U+FF01, U+FF1F) end a sentence even without a
#     following space. Escapes are used so the pattern survives re-encoding.
_SENTENCE_SPLIT = re.compile(
    r'(?:(?<=[.!?])|(?<=[.!?]["\']))\s+'
    r'|(?<=[\u3002\uff01\uff1f])(?![\u3002\uff01\uff1f])\s*'
)

# Question detector: ends with ? or fullwidth ? (U+FF1F), possibly followed by
# whitespace and/or a closing quote.
_QUESTION_MARK = re.compile(r'[?\uff1f]\s*["\']?\s*$')


def _split_sentences(text):
    """Return the non-empty, stripped sentences of *text*, terminators kept."""
    return [part.strip() for part in _SENTENCE_SPLIT.split(text) if part and part.strip()]


def _find_question(lines):
    """Return the most likely question in *lines*, or None if there is none."""
    # Pass 1: the last line that ends with a question mark.
    for line in reversed(lines):
        if not _QUESTION_MARK.search(line):
            continue
        if len(line) <= SAFE_QUERY_LENGTH:
            return line
        # A long line is probably "system prompt + question" without a newline
        # in between: keep only its final question sentence.
        for sentence in reversed(_split_sentences(line)):
            if _QUESTION_MARK.search(sentence):
                return sentence
        return line

    # Pass 2: a question that is not at the end of its line
    # (e.g. "How does X work? Thanks").
    for line in reversed(lines):
        for sentence in reversed(_split_sentences(line)):
            if _QUESTION_MARK.search(sentence):
                return sentence
    return None


def _find_tail(lines):
    """Return the last line or sentence of at least MIN_QUERY_LENGTH chars, or None."""
    for line in reversed(lines):
        if len(line) < MIN_QUERY_LENGTH:
            continue
        if len(line) > SAFE_QUERY_LENGTH:
            # Too long to be a clean query: narrow to its last usable sentence.
            for sentence in reversed(_split_sentences(line)):
                if len(sentence) >= MIN_QUERY_LENGTH:
                    return sentence
        return line
    return None


def _sanitized_result(candidate, original_length, method):
    """Apply the length guard, log the rewrite and build the result dict."""
    if len(candidate) > MAX_QUERY_LENGTH:
        # Keep the tail: prepended contamination sits at the front.
        candidate = candidate[-MAX_QUERY_LENGTH:].strip()
    logger.warning(
        "Query sanitized: %d → %d chars (method=%s)",
        original_length, len(candidate), method
    )
    return {
        "clean_query": candidate,
        "was_sanitized": True,
        "original_length": original_length,
        "clean_length": len(candidate),
        "method": method,
    }


def sanitize_query(raw_query: str) -> dict:
    """
    Extract the actual search intent from a potentially contaminated query.

    Args:
        raw_query: The raw query string from the AI agent, possibly containing
            system prompt content prepended to the actual question. ``None``
            and empty/whitespace-only strings are passed through untouched.

    Returns:
        dict with keys:
            clean_query (str): The sanitized query to use for embedding search
            was_sanitized (bool): Whether any sanitization was applied
            original_length (int): Length of the (stripped) raw input
            clean_length (int): Length of the sanitized output
            method (str): Which extraction method was used
                - "passthrough": query was short enough, no action taken
                - "question_extraction": found and extracted a question
                - "tail_sentence": extracted the last meaningful line/sentence
                - "tail_truncation": fallback — took the last MAX_QUERY_LENGTH chars
    """
    if not raw_query or not raw_query.strip():
        untouched_length = len(raw_query) if raw_query else 0
        return {
            "clean_query": raw_query or "",
            "was_sanitized": False,
            "original_length": untouched_length,
            "clean_length": untouched_length,
            "method": "passthrough",
        }

    raw_query = raw_query.strip()
    original_length = len(raw_query)

    # --- Step 1: Short query passthrough ---
    if original_length <= SAFE_QUERY_LENGTH:
        return {
            "clean_query": raw_query,
            "was_sanitized": False,
            "original_length": original_length,
            "clean_length": original_length,
            "method": "passthrough",
        }

    lines = [line.strip() for line in raw_query.split("\n") if line.strip()]

    # --- Step 2: Question extraction ---
    # Later questions win: system prompts are prepended, so the real query is
    # near the end. A too-short hit ("OK?") is treated as a failed extraction.
    question = _find_question(lines)
    if question is not None and len(question) >= MIN_QUERY_LENGTH:
        return _sanitized_result(question, original_length, "question_extraction")

    # --- Step 3: Tail sentence extraction ---
    tail = _find_tail(lines)
    if tail is not None:
        return _sanitized_result(tail, original_length, "tail_sentence")

    # --- Step 4: Tail truncation (fallback) ---
    # Nothing worked — just take the last MAX_QUERY_LENGTH characters.
    return _sanitized_result(
        raw_query[-MAX_QUERY_LENGTH:].strip(), original_length, "tail_truncation"
    )
class TestPassthrough:
    """Step 1: Queries up to SAFE_QUERY_LENGTH pass through unchanged."""

    def test_short_query_unchanged(self):
        result = sanitize_query("What is Rust error handling?")
        assert result["clean_query"] == "What is Rust error handling?"
        assert result["was_sanitized"] is False
        assert result["method"] == "passthrough"

    def test_empty_query(self):
        result = sanitize_query("")
        assert result["clean_query"] == ""
        assert result["was_sanitized"] is False
        assert result["method"] == "passthrough"

    def test_none_query(self):
        result = sanitize_query(None)
        assert result["was_sanitized"] is False
        assert result["method"] == "passthrough"

    def test_exactly_safe_length(self):
        query = "a" * SAFE_QUERY_LENGTH
        result = sanitize_query(query)
        assert result["was_sanitized"] is False
        assert result["method"] == "passthrough"

    def test_one_over_safe_length_triggers_sanitization(self):
        query = "a" * (SAFE_QUERY_LENGTH + 1)
        result = sanitize_query(query)
        # One char over the safe limit must leave the passthrough path, even
        # if the extracted text ends up identical to the input.
        assert result["original_length"] == SAFE_QUERY_LENGTH + 1
        assert result["was_sanitized"] is True
        assert result["method"] != "passthrough"


class TestQuestionExtraction:
    """Step 2: Extract question sentences (ending with ?)."""

    def test_question_at_end_of_long_text(self):
        system_prompt = "You are a helpful assistant. " * 50  # ~1400 chars
        query = system_prompt + "What is the best way to handle errors in Rust?"
        result = sanitize_query(query)
        assert result["was_sanitized"] is True
        assert "error" in result["clean_query"].lower() or "Rust" in result["clean_query"]
        assert result["method"] == "question_extraction"

    def test_japanese_question_mark(self):
        system_prompt = "You are a helpful assistant. " * 50
        query = system_prompt + "Rustのエラーハンドリング方法は?"
        result = sanitize_query(query)
        assert result["was_sanitized"] is True
        assert "Rust" in result["clean_query"] or "エラー" in result["clean_query"]
        assert result["method"] == "question_extraction"

    def test_multiple_questions_takes_last(self):
        system_prompt = "You are a helpful assistant. " * 50
        query = system_prompt + "What is Python?\nHow does Rust handle errors?"
        result = sanitize_query(query)
        assert result["method"] == "question_extraction"
        assert "Rust" in result["clean_query"] or "error" in result["clean_query"].lower()
        # "takes last" means the earlier question must not be returned.
        assert "Python" not in result["clean_query"]

    def test_question_in_system_prompt_ignored_when_real_question_exists(self):
        # System prompt contains a question, but real query also has one
        system_prompt = "Are you ready to help? " * 30 + "\n"
        real_query = "What databases does MemPalace support?"
        query = system_prompt + real_query
        result = sanitize_query(query)
        assert result["was_sanitized"] is True
        assert result["method"] == "question_extraction"
        assert "MemPalace" in result["clean_query"] or "database" in result["clean_query"].lower()


class TestTailSentence:
    """Step 3: Extract the last meaningful sentence when no question mark found."""

    def test_command_style_query(self):
        system_prompt = "You are a helpful assistant. " * 50
        query = system_prompt + "Show me all Rust error handling patterns"
        result = sanitize_query(query)
        assert result["was_sanitized"] is True
        assert "Rust" in result["clean_query"] or "error" in result["clean_query"].lower()
        assert result["method"] in ("tail_sentence", "question_extraction")

    def test_keyword_style_query(self):
        system_prompt = "System configuration loaded. " * 60
        query = system_prompt + "\nMemPalace ChromaDB integration setup"
        result = sanitize_query(query)
        assert result["was_sanitized"] is True
        assert "MemPalace" in result["clean_query"] or "ChromaDB" in result["clean_query"]


class TestTailTruncation:
    """Step 4: Fallback — take the last MAX_QUERY_LENGTH characters."""

    def test_single_long_line_no_sentences(self):
        # Despite the name, the input is MANY two-character lines: no line
        # reaches MIN_QUERY_LENGTH, so only the truncation fallback is left.
        filler = "\n".join(["ab"] * 200)
        result = sanitize_query(filler)
        assert result["was_sanitized"] is True
        assert len(result["clean_query"]) <= MAX_QUERY_LENGTH
        assert result["method"] == "tail_truncation"

    def test_truncation_preserves_tail(self):
        filler = "x" * 1000 + "IMPORTANT_QUERY_CONTENT"
        result = sanitize_query(filler)
        assert "IMPORTANT_QUERY_CONTENT" in result["clean_query"]
        assert len(result["clean_query"]) <= MAX_QUERY_LENGTH


class TestLengthGuards:
    """Verify output length constraints."""

    def test_output_never_exceeds_max(self):
        # Very long question sentence
        long_question = "a" * 1000 + "?"
        system_prompt = "Context. " * 100
        query = system_prompt + long_question
        result = sanitize_query(query)
        assert len(result["clean_query"]) <= MAX_QUERY_LENGTH

    def test_extraction_too_short_falls_through(self):
        # Question mark found but the sentence is too short
        system_prompt = "You are helpful. " * 50
        query = system_prompt + "\nOK?"
        result = sanitize_query(query)
        # "OK?" is only 3 chars < MIN_QUERY_LENGTH, so a later step must
        # produce the result instead of question extraction.
        assert result["was_sanitized"] is True
        assert result["method"] != "question_extraction"
        assert result["clean_query"] != "OK?"


class TestMetadata:
    """Verify sanitizer metadata is correct."""

    def test_original_length_preserved(self):
        system_prompt = "You are a helpful assistant. " * 50
        query = system_prompt + "What is Rust?"
        result = sanitize_query(query)
        assert result["original_length"] == len(query.strip())

    def test_clean_length_matches_clean_query(self):
        system_prompt = "You are a helpful assistant. " * 50
        query = system_prompt + "What is Rust?"
        result = sanitize_query(query)
        assert result["clean_length"] == len(result["clean_query"])

    def test_sanitized_flag_true_when_changed(self):
        system_prompt = "You are a helpful assistant. " * 50
        query = system_prompt + "What is Rust?"
        result = sanitize_query(query)
        assert result["was_sanitized"] is True

    def test_sanitized_flag_false_when_unchanged(self):
        result = sanitize_query("Short query")
        assert result["was_sanitized"] is False


class TestRealWorldScenarios:
    """Simulate realistic system prompt contamination patterns."""

    def test_mempalace_wakeup_prepended(self):
        """Simulates mempalace wake-up output prepended to a query."""
        wakeup = (
            "MemPalace loaded. Wings: technical, emotions, identity. "
            "Rooms: chromadb-setup, error-handling, project-planning. "
            "Total drawers: 234. Knowledge graph: 89 entities, 156 triples. "
            "AAAK dialect active. Protocol: verify before responding. "
        ) * 5  # ~1000 chars
        real_query = "How did we decide on the database architecture?"
        query = wakeup + real_query
        result = sanitize_query(query)
        assert result["was_sanitized"] is True
        assert len(result["clean_query"]) <= MAX_QUERY_LENGTH
        # Should recover something meaningful: the real question survives.
        assert len(result["clean_query"]) >= MIN_QUERY_LENGTH
        assert "database architecture" in result["clean_query"]

    def test_memory_md_prepended(self):
        """Simulates MEMORY.md content prepended to a query."""
        memory_md = (
            "# Project Memory\n"
            "## Architecture Decisions\n"
            "- Use ChromaDB for vector storage\n"
            "- MCP protocol for tool integration\n"
            "- AAAK compression for efficient storage\n"
        ) * 10  # ~750 chars
        real_query = "What were the performance benchmarks for the search system?"
        query = memory_md + "\n" + real_query
        result = sanitize_query(query)
        assert result["was_sanitized"] is True
        assert result["method"] in ("question_extraction", "tail_sentence")

    def test_2000_char_system_prompt_with_question(self):
        """The exact scenario from Issue #333 — 2000 chars prepended."""
        system_prompt = "You are an AI assistant with access to tools. " * 45  # ~2000 chars
        real_query = "What is the status of the MemPalace project?"
        query = system_prompt + real_query
        result = sanitize_query(query)
        assert result["was_sanitized"] is True
        assert result["original_length"] > 2000
        assert result["clean_length"] <= MAX_QUERY_LENGTH
        assert result["method"] == "question_extraction"