session binding and legacy compatibility
This commit is contained in:
@@ -154,8 +154,99 @@ class OAuth21SessionStore:
|
||||
self._sessions: Dict[str, Dict[str, Any]] = {}
|
||||
self._mcp_session_mapping: Dict[str, str] = {} # Maps FastMCP session ID -> user email
|
||||
self._session_auth_binding: Dict[str, str] = {} # Maps session ID -> authenticated user email (immutable)
|
||||
self._oauth_states: Dict[str, Dict[str, Any]] = {}
|
||||
self._lock = RLock()
|
||||
|
||||
def _cleanup_expired_oauth_states_locked(self):
|
||||
"""Remove expired OAuth state entries. Caller must hold lock."""
|
||||
now = datetime.utcnow()
|
||||
expired_states = [
|
||||
state
|
||||
for state, data in self._oauth_states.items()
|
||||
if data.get("expires_at") and data["expires_at"] <= now
|
||||
]
|
||||
for state in expired_states:
|
||||
del self._oauth_states[state]
|
||||
logger.debug(
|
||||
"Removed expired OAuth state: %s",
|
||||
state[:8] if len(state) > 8 else state,
|
||||
)
|
||||
|
||||
def store_oauth_state(
|
||||
self,
|
||||
state: str,
|
||||
session_id: Optional[str] = None,
|
||||
expires_in_seconds: int = 600,
|
||||
) -> None:
|
||||
"""Persist an OAuth state value for later validation."""
|
||||
if not state:
|
||||
raise ValueError("OAuth state must be provided")
|
||||
if expires_in_seconds < 0:
|
||||
raise ValueError("expires_in_seconds must be non-negative")
|
||||
|
||||
with self._lock:
|
||||
self._cleanup_expired_oauth_states_locked()
|
||||
now = datetime.utcnow()
|
||||
expiry = now + timedelta(seconds=expires_in_seconds)
|
||||
self._oauth_states[state] = {
|
||||
"session_id": session_id,
|
||||
"expires_at": expiry,
|
||||
"created_at": now,
|
||||
}
|
||||
logger.debug(
|
||||
"Stored OAuth state %s (expires at %s)",
|
||||
state[:8] if len(state) > 8 else state,
|
||||
expiry.isoformat(),
|
||||
)
|
||||
|
||||
def validate_and_consume_oauth_state(
|
||||
self,
|
||||
state: str,
|
||||
session_id: Optional[str] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Validate that a state value exists and consume it.
|
||||
|
||||
Args:
|
||||
state: The OAuth state returned by Google.
|
||||
session_id: Optional session identifier that initiated the flow.
|
||||
|
||||
Returns:
|
||||
Metadata associated with the state.
|
||||
|
||||
Raises:
|
||||
ValueError: If the state is missing, expired, or does not match the session.
|
||||
"""
|
||||
if not state:
|
||||
raise ValueError("Missing OAuth state parameter")
|
||||
|
||||
with self._lock:
|
||||
self._cleanup_expired_oauth_states_locked()
|
||||
state_info = self._oauth_states.get(state)
|
||||
|
||||
if not state_info:
|
||||
logger.error("SECURITY: OAuth callback received unknown or expired state")
|
||||
raise ValueError("Invalid or expired OAuth state parameter")
|
||||
|
||||
bound_session = state_info.get("session_id")
|
||||
if bound_session and session_id and bound_session != session_id:
|
||||
# Consume the state to prevent replay attempts
|
||||
del self._oauth_states[state]
|
||||
logger.error(
|
||||
"SECURITY: OAuth state session mismatch (expected %s, got %s)",
|
||||
bound_session,
|
||||
session_id,
|
||||
)
|
||||
raise ValueError("OAuth state does not match the initiating session")
|
||||
|
||||
# State is valid – consume it to prevent reuse
|
||||
del self._oauth_states[state]
|
||||
logger.debug(
|
||||
"Validated OAuth state %s",
|
||||
state[:8] if len(state) > 8 else state,
|
||||
)
|
||||
return state_info
|
||||
|
||||
def store_session(
|
||||
self,
|
||||
user_email: str,
|
||||
@@ -427,6 +518,13 @@ class OAuth21SessionStore:
|
||||
with self._lock:
|
||||
return mcp_session_id in self._mcp_session_mapping
|
||||
|
||||
def get_single_user_email(self) -> Optional[str]:
|
||||
"""Return the sole authenticated user email when exactly one session exists."""
|
||||
with self._lock:
|
||||
if len(self._sessions) == 1:
|
||||
return next(iter(self._sessions))
|
||||
return None
|
||||
|
||||
def get_stats(self) -> Dict[str, Any]:
|
||||
"""Get store statistics."""
|
||||
with self._lock:
|
||||
@@ -568,4 +666,4 @@ def store_token_session(token_response: dict, user_email: str, mcp_session_id: O
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to store token session: {e}")
|
||||
return ""
|
||||
return ""
|
||||
|
||||
Reference in New Issue
Block a user