almost there, working out session persistence

This commit is contained in:
Taylor Wilsdon
2025-08-02 15:40:23 -04:00
parent 06ef1223dd
commit 9470a41dde
6 changed files with 210 additions and 49 deletions

View File

@@ -164,6 +164,9 @@ def load_credentials_from_file(
if creds_data.get("expiry"):
try:
expiry = datetime.fromisoformat(creds_data["expiry"])
# Ensure timezone-naive datetime for Google auth library compatibility
if expiry.tzinfo is not None:
expiry = expiry.replace(tzinfo=None)
except (ValueError, TypeError) as e:
logger.warning(
f"Could not parse expiry time for {user_google_email}: {e}"

View File

@@ -44,21 +44,46 @@ class MCPSessionMiddleware(BaseHTTPMiddleware):
headers = dict(request.headers)
session_id = extract_session_from_headers(headers)
# Try to get OAuth 2.1 auth context
# Try to get OAuth 2.1 auth context from FastMCP
auth_context = None
user_email = None
# Check for FastMCP auth context
if hasattr(request.state, "auth"):
auth_context = request.state.auth
# Extract user email from auth claims if available
if hasattr(auth_context, 'claims') and auth_context.claims:
user_email = auth_context.claims.get('email')
# Also check Authorization header for bearer tokens
auth_header = headers.get("authorization")
if auth_header and auth_header.lower().startswith("bearer ") and not user_email:
try:
import jwt
token = auth_header[7:] # Remove "Bearer " prefix
# Decode without verification to extract email
claims = jwt.decode(token, options={"verify_signature": False})
user_email = claims.get('email')
if user_email:
logger.debug(f"Extracted user email from JWT: {user_email}")
except:
pass
# Build session context
if session_id or auth_context:
if session_id or auth_context or user_email:
# Create session ID from user email if not provided
if not session_id and user_email:
session_id = f"google_{user_email}"
session_context = SessionContext(
session_id=session_id or (auth_context.session_id if auth_context else None),
user_id=auth_context.user_id if auth_context else None,
user_id=user_email or (auth_context.user_id if auth_context else None),
auth_context=auth_context,
request=request,
metadata={
"path": request.url.path,
"method": request.method,
"user_email": user_email,
}
)

View File

@@ -58,7 +58,8 @@ def get_credentials_from_token(access_token: str, user_email: Optional[str] = No
# Calculate expiry
expiry = None
if "expires_in" in token_data:
expiry = datetime.now(timezone.utc) + timedelta(seconds=token_data["expires_in"])
# Google auth library expects timezone-naive datetime
expiry = datetime.utcnow() + timedelta(seconds=token_data["expires_in"])
credentials = Credentials(
token=token_data["access_token"],
@@ -76,7 +77,8 @@ def get_credentials_from_token(access_token: str, user_email: Optional[str] = No
# Otherwise, create minimal credentials with just the access token
else:
# Assume token is valid for 1 hour (typical for Google tokens)
expiry = datetime.now(timezone.utc) + timedelta(hours=1)
# Google auth library expects timezone-naive datetime
expiry = datetime.utcnow() + timedelta(hours=1)
credentials = Credentials(
token=access_token,
@@ -129,7 +131,7 @@ def store_token_session(token_response: dict, user_email: str) -> str:
client_id=_auth_provider.client_id,
client_secret=_auth_provider.client_secret,
scopes=token_response.get("scope", "").split() if token_response.get("scope") else None,
expiry=datetime.now(timezone.utc) + timedelta(seconds=token_response.get("expires_in", 3600)),
expiry=datetime.utcnow() + timedelta(seconds=token_response.get("expires_in", 3600)),
session_id=session_id,
)

View File

@@ -54,6 +54,7 @@ def get_session_context() -> Optional[SessionContext]:
Returns:
The current session context or None
"""
print('called get_session_context')
return _current_session_context.get()

View File

@@ -1,8 +1,14 @@
import aiohttp
import logging
import jwt
import os
import time
from typing import Any, Optional
from importlib import metadata
from urllib.parse import urlencode
from datetime import datetime, timedelta
from fastapi import Header
from fastapi.responses import HTMLResponse
@@ -14,13 +20,15 @@ from starlette.responses import RedirectResponse
from starlette.middleware import Middleware
from fastapi.middleware.cors import CORSMiddleware
from auth.google_auth import handle_auth_callback, start_auth_flow, check_client_secrets
from auth.oauth21_session_store import get_oauth21_session_store
from auth.google_auth import handle_auth_callback, start_auth_flow, check_client_secrets, save_credentials_to_file
from google.oauth2.credentials import Credentials
from auth.oauth_callback_server import get_oauth_redirect_uri, ensure_oauth_callback_available
from auth.oauth_responses import create_error_response, create_success_response, create_server_error_response
# FastMCP OAuth imports
from auth.fastmcp_google_auth import GoogleWorkspaceAuthProvider
from auth.oauth21_google_bridge import set_auth_provider
from auth.oauth21_google_bridge import set_auth_provider, store_token_session
# Import shared configuration
from auth.scopes import (
@@ -78,7 +86,7 @@ _current_transport_mode = "stdio" # Default to stdio
_auth_provider: Optional[GoogleWorkspaceAuthProvider] = None
# Create middleware configuration
from starlette.middleware import Middleware
from auth.mcp_session_middleware import MCPSessionMiddleware
cors_middleware = Middleware(
CORSMiddleware,
@@ -88,16 +96,20 @@ cors_middleware = Middleware(
allow_headers=["*"],
)
session_middleware = Middleware(MCPSessionMiddleware)
# Custom FastMCP that adds CORS to streamable HTTP
class CORSEnabledFastMCP(FastMCP):
def streamable_http_app(self) -> "Starlette":
"""Override to add CORS middleware to the app."""
"""Override to add CORS and session middleware to the app."""
app = super().streamable_http_app()
# Add CORS as the first middleware
app.user_middleware.insert(0, cors_middleware)
# Add session middleware first (to set context before other middleware)
app.user_middleware.insert(0, session_middleware)
# Add CORS as the second middleware
app.user_middleware.insert(1, cors_middleware)
# Rebuild middleware stack
app.middleware_stack = app.build_middleware_stack()
logger.info("Added CORS middleware to streamable HTTP app")
logger.info("Added session and CORS middleware to streamable HTTP app")
return app
# Initialize auth provider for HTTP transport
@@ -177,7 +189,6 @@ def get_auth_provider() -> Optional[GoogleWorkspaceAuthProvider]:
@server.custom_route("/health", methods=["GET"])
async def health_check(request: Request):
"""Health check endpoint for container orchestration."""
from fastapi.responses import JSONResponse
try:
version = metadata.version("workspace-mcp")
except metadata.PackageNotFoundError:
@@ -232,7 +243,6 @@ async def oauth2_callback(request: Request) -> HTMLResponse:
# Store Google credentials in OAuth 2.1 session store
try:
from auth.oauth21_session_store import get_oauth21_session_store
store = get_oauth21_session_store()
store.store_session(
user_email=verified_user_id,
@@ -337,17 +347,46 @@ async def oauth_protected_resource(request: Request):
],
"bearer_methods_supported": ["header"],
"scopes_supported": [
# Base scopes
"openid",
"https://www.googleapis.com/auth/userinfo.email",
"https://www.googleapis.com/auth/userinfo.profile",
# Calendar scopes
"https://www.googleapis.com/auth/calendar",
"https://www.googleapis.com/auth/calendar.readonly",
"https://www.googleapis.com/auth/calendar.events",
# Drive scopes
"https://www.googleapis.com/auth/drive",
"https://www.googleapis.com/auth/drive.readonly",
"https://www.googleapis.com/auth/drive.file",
# Gmail scopes
"https://www.googleapis.com/auth/gmail.modify",
"https://www.googleapis.com/auth/gmail.readonly",
"https://www.googleapis.com/auth/gmail.send",
"https://www.googleapis.com/auth/gmail.compose",
"https://www.googleapis.com/auth/gmail.labels",
# Docs scopes
"https://www.googleapis.com/auth/documents",
"https://www.googleapis.com/auth/documents.readonly",
# Sheets scopes
"https://www.googleapis.com/auth/spreadsheets",
"https://www.googleapis.com/auth/spreadsheets.readonly",
# Slides scopes
"https://www.googleapis.com/auth/presentations",
"https://www.googleapis.com/auth/presentations.readonly",
# Chat scopes
"https://www.googleapis.com/auth/chat.spaces",
"https://www.googleapis.com/auth/forms",
"https://www.googleapis.com/auth/tasks"
"https://www.googleapis.com/auth/chat.messages",
"https://www.googleapis.com/auth/chat.messages.readonly",
# Forms scopes
"https://www.googleapis.com/auth/forms.body",
"https://www.googleapis.com/auth/forms.body.readonly",
"https://www.googleapis.com/auth/forms.responses.readonly",
# Tasks scopes
"https://www.googleapis.com/auth/tasks",
"https://www.googleapis.com/auth/tasks.readonly",
# Search scope
"https://www.googleapis.com/auth/cse"
],
"resource_documentation": "https://developers.google.com/workspace",
"client_registration_required": True,
@@ -464,7 +503,37 @@ async def oauth_client_config(request: Request):
],
"grant_types": ["authorization_code", "refresh_token"],
"response_types": ["code"],
"scope": "openid email profile https://www.googleapis.com/auth/calendar https://www.googleapis.com/auth/drive https://www.googleapis.com/auth/gmail.modify",
"scope": " ".join([
"openid",
"https://www.googleapis.com/auth/userinfo.email",
"https://www.googleapis.com/auth/userinfo.profile",
"https://www.googleapis.com/auth/calendar",
"https://www.googleapis.com/auth/calendar.readonly",
"https://www.googleapis.com/auth/calendar.events",
"https://www.googleapis.com/auth/drive",
"https://www.googleapis.com/auth/drive.readonly",
"https://www.googleapis.com/auth/drive.file",
"https://www.googleapis.com/auth/gmail.modify",
"https://www.googleapis.com/auth/gmail.readonly",
"https://www.googleapis.com/auth/gmail.send",
"https://www.googleapis.com/auth/gmail.compose",
"https://www.googleapis.com/auth/gmail.labels",
"https://www.googleapis.com/auth/documents",
"https://www.googleapis.com/auth/documents.readonly",
"https://www.googleapis.com/auth/spreadsheets",
"https://www.googleapis.com/auth/spreadsheets.readonly",
"https://www.googleapis.com/auth/presentations",
"https://www.googleapis.com/auth/presentations.readonly",
"https://www.googleapis.com/auth/chat.spaces",
"https://www.googleapis.com/auth/chat.messages",
"https://www.googleapis.com/auth/chat.messages.readonly",
"https://www.googleapis.com/auth/forms.body",
"https://www.googleapis.com/auth/forms.body.readonly",
"https://www.googleapis.com/auth/forms.responses.readonly",
"https://www.googleapis.com/auth/tasks",
"https://www.googleapis.com/auth/tasks.readonly",
"https://www.googleapis.com/auth/cse"
]),
"token_endpoint_auth_method": "client_secret_basic",
"code_challenge_methods": ["S256"]
},
@@ -488,8 +557,6 @@ async def oauth_authorize(request: Request):
}
)
from urllib.parse import urlencode
# Get query parameters
params = dict(request.query_params)
@@ -526,9 +593,6 @@ async def proxy_token_exchange(request: Request):
"Access-Control-Allow-Headers": "Content-Type, Authorization"
}
)
import aiohttp
try:
# Get form data
body = await request.body()
@@ -545,7 +609,44 @@ async def proxy_token_exchange(request: Request):
if response.status != 200:
logger.error(f"Token exchange failed: {response.status} - {response_data}")
else:
logger.info("Token exchange successful")
logger.info(f"Token exchange successful")
# Store the token session for credential bridging
if "access_token" in response_data:
try:
# Extract user email from ID token if present
if "id_token" in response_data:
# Decode without verification (we trust Google's response)
id_token_claims = jwt.decode(response_data["id_token"], options={"verify_signature": False})
user_email = id_token_claims.get("email")
if user_email:
# Store the token session
session_id = store_token_session(response_data, user_email)
logger.info(f"Stored OAuth session for {user_email} (session: {session_id})")
# Also create and store Google credentials
expiry = None
if "expires_in" in response_data:
# Google auth library expects timezone-naive datetime
expiry = datetime.utcnow() + timedelta(seconds=response_data["expires_in"])
credentials = Credentials(
token=response_data["access_token"],
refresh_token=response_data.get("refresh_token"),
token_uri="https://oauth2.googleapis.com/token",
client_id=os.getenv("GOOGLE_OAUTH_CLIENT_ID"),
client_secret=os.getenv("GOOGLE_OAUTH_CLIENT_SECRET"),
scopes=response_data.get("scope", "").split() if response_data.get("scope") else None,
expiry=expiry
)
# Save credentials to file for legacy auth
save_credentials_to_file(user_email, credentials)
logger.info(f"Saved Google credentials for {user_email}")
except Exception as e:
logger.error(f"Failed to store OAuth session: {e}")
return JSONResponse(
status_code=response.status,
@@ -610,7 +711,6 @@ async def oauth_register(request: Request):
]
# Build the registration response with our pre-configured credentials
import time
response_data = {
"client_id": client_id,
"client_secret": client_secret,
@@ -619,7 +719,37 @@ async def oauth_register(request: Request):
"redirect_uris": redirect_uris,
"grant_types": body.get("grant_types", ["authorization_code", "refresh_token"]),
"response_types": body.get("response_types", ["code"]),
"scope": body.get("scope", "openid email profile https://www.googleapis.com/auth/calendar https://www.googleapis.com/auth/drive https://www.googleapis.com/auth/gmail.modify"),
"scope": body.get("scope", " ".join([
"openid",
"https://www.googleapis.com/auth/userinfo.email",
"https://www.googleapis.com/auth/userinfo.profile",
"https://www.googleapis.com/auth/calendar",
"https://www.googleapis.com/auth/calendar.readonly",
"https://www.googleapis.com/auth/calendar.events",
"https://www.googleapis.com/auth/drive",
"https://www.googleapis.com/auth/drive.readonly",
"https://www.googleapis.com/auth/drive.file",
"https://www.googleapis.com/auth/gmail.modify",
"https://www.googleapis.com/auth/gmail.readonly",
"https://www.googleapis.com/auth/gmail.send",
"https://www.googleapis.com/auth/gmail.compose",
"https://www.googleapis.com/auth/gmail.labels",
"https://www.googleapis.com/auth/documents",
"https://www.googleapis.com/auth/documents.readonly",
"https://www.googleapis.com/auth/spreadsheets",
"https://www.googleapis.com/auth/spreadsheets.readonly",
"https://www.googleapis.com/auth/presentations",
"https://www.googleapis.com/auth/presentations.readonly",
"https://www.googleapis.com/auth/chat.spaces",
"https://www.googleapis.com/auth/chat.messages",
"https://www.googleapis.com/auth/chat.messages.readonly",
"https://www.googleapis.com/auth/forms.body",
"https://www.googleapis.com/auth/forms.body.readonly",
"https://www.googleapis.com/auth/forms.responses.readonly",
"https://www.googleapis.com/auth/tasks",
"https://www.googleapis.com/auth/tasks.readonly",
"https://www.googleapis.com/auth/cse"
])),
"token_endpoint_auth_method": body.get("token_endpoint_auth_method", "client_secret_basic"),
"code_challenge_methods": ["S256"],
# Additional OAuth 2.1 fields

View File

@@ -13,7 +13,7 @@ from core.server import server, set_transport_mode, initialize_auth, shutdown_au
from core.utils import check_credentials_directory_permissions
logging.basicConfig(
level=logging.INFO,
level=logging.DEBUG,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)