Enhanced Session Management better guardrails

This commit is contained in:
Taylor Wilsdon
2025-08-03 15:51:04 -04:00
parent ff9b7ecd07
commit 71e2f1ba3e
4 changed files with 493 additions and 71 deletions

View File

@@ -241,6 +241,16 @@ async def oauth2_callback(request: Request) -> HTMLResponse:
# Store Google credentials in OAuth 2.1 session store
try:
store = get_oauth21_session_store()
# Try to get MCP session ID from request for binding
mcp_session_id = None
try:
if hasattr(request, 'state') and hasattr(request.state, 'session_id'):
mcp_session_id = request.state.session_id
logger.info(f"OAuth callback: Found MCP session ID for binding: {mcp_session_id}")
except Exception as e:
logger.debug(f"OAuth callback: Could not get MCP session ID: {e}")
store.store_session(
user_email=verified_user_id,
access_token=credentials.token,
@@ -251,8 +261,9 @@ async def oauth2_callback(request: Request) -> HTMLResponse:
scopes=credentials.scopes,
expiry=credentials.expiry,
session_id=f"google-{state}", # Use state as a pseudo session ID
mcp_session_id=mcp_session_id, # Bind to MCP session if available
)
logger.info(f"Stored Google credentials in OAuth 2.1 session store for {verified_user_id}")
logger.info(f"Stored Google credentials in OAuth 2.1 session store for {verified_user_id} (mcp: {mcp_session_id})")
except Exception as e:
logger.error(f"Failed to store Google credentials in OAuth 2.1 store: {e}")
@@ -534,6 +545,29 @@ async def proxy_token_exchange(request: Request):
# Get form data
body = await request.body()
content_type = request.headers.get("content-type", "application/x-www-form-urlencoded")
# Parse form data to add missing client credentials
from urllib.parse import parse_qs, urlencode
if content_type and "application/x-www-form-urlencoded" in content_type:
form_data = parse_qs(body.decode('utf-8'))
# Check if client_id is missing (public client)
if 'client_id' not in form_data or not form_data['client_id'][0]:
client_id = os.getenv("GOOGLE_OAUTH_CLIENT_ID")
if client_id:
form_data['client_id'] = [client_id]
logger.debug(f"Added missing client_id to token request")
# Check if client_secret is missing (public client using PKCE)
if 'client_secret' not in form_data:
client_secret = os.getenv("GOOGLE_OAUTH_CLIENT_SECRET")
if client_secret:
form_data['client_secret'] = [client_secret]
logger.debug(f"Added missing client_secret to token request")
# Reconstruct body with added credentials
body = urlencode(form_data, doseq=True).encode('utf-8')
# Forward request to Google
async with aiohttp.ClientSession() as session:
@@ -570,37 +604,47 @@ async def proxy_token_exchange(request: Request):
issuer="https://accounts.google.com"
)
user_email = id_token_claims.get("email")
if user_email:
# Try to get FastMCP session ID from request context for binding
mcp_session_id = None
try:
# Check if this is a streamable HTTP request with session
if hasattr(request, 'state') and hasattr(request.state, 'session_id'):
mcp_session_id = request.state.session_id
logger.info(f"Found MCP session ID for binding: {mcp_session_id}")
except Exception as e:
logger.debug(f"Could not get MCP session ID: {e}")
# Store the token session with MCP session binding
session_id = store_token_session(response_data, user_email, mcp_session_id)
logger.info(f"Stored OAuth session for {user_email} (session: {session_id}, mcp: {mcp_session_id})")
# Also create and store Google credentials
expiry = None
if "expires_in" in response_data:
# Google auth library expects timezone-naive datetime
expiry = datetime.utcnow() + timedelta(seconds=response_data["expires_in"])
credentials = Credentials(
token=response_data["access_token"],
refresh_token=response_data.get("refresh_token"),
token_uri="https://oauth2.googleapis.com/token",
client_id=os.getenv("GOOGLE_OAUTH_CLIENT_ID"),
client_secret=os.getenv("GOOGLE_OAUTH_CLIENT_SECRET"),
scopes=response_data.get("scope", "").split() if response_data.get("scope") else None,
expiry=expiry
)
# Save credentials to file for legacy auth
save_credentials_to_file(user_email, credentials)
logger.info(f"Saved Google credentials for {user_email}")
except jwt.ExpiredSignatureError:
logger.error("ID token has expired - cannot extract user email")
except jwt.InvalidTokenError as e:
logger.error(f"Invalid ID token - cannot extract user email: {e}")
except Exception as e:
logger.error(f"Failed to verify ID token: {e}")
# Fallback to unverified decode for backwards compatibility (with warning)
logger.warning("Using unverified ID token decode as fallback - this should be fixed")
id_token_claims = jwt.decode(response_data["id_token"], options={"verify_signature": False})
user_email = id_token_claims.get("email")
if user_email:
# Store the token session
session_id = store_token_session(response_data, user_email)
logger.info(f"Stored OAuth session for {user_email} (session: {session_id})")
# Also create and store Google credentials
expiry = None
if "expires_in" in response_data:
# Google auth library expects timezone-naive datetime
expiry = datetime.utcnow() + timedelta(seconds=response_data["expires_in"])
credentials = Credentials(
token=response_data["access_token"],
refresh_token=response_data.get("refresh_token"),
token_uri="https://oauth2.googleapis.com/token",
client_id=os.getenv("GOOGLE_OAUTH_CLIENT_ID"),
client_secret=os.getenv("GOOGLE_OAUTH_CLIENT_SECRET"),
scopes=response_data.get("scope", "").split() if response_data.get("scope") else None,
expiry=expiry
)
# Save credentials to file for legacy auth
save_credentials_to_file(user_email, credentials)
logger.info(f"Saved Google credentials for {user_email}")
logger.error(f"Failed to verify ID token - cannot extract user email: {e}")
except Exception as e:
logger.error(f"Failed to store OAuth session: {e}")