almost there, working out session persistence

This commit is contained in:
Taylor Wilsdon
2025-08-02 15:40:23 -04:00
parent 06ef1223dd
commit 9470a41dde
6 changed files with 210 additions and 49 deletions

View File

@@ -1,8 +1,14 @@
import aiohttp
import logging
import jwt
import os
import time
from typing import Any, Optional
from importlib import metadata
from urllib.parse import urlencode
from datetime import datetime, timedelta
from fastapi import Header
from fastapi.responses import HTMLResponse
@@ -14,13 +20,15 @@ from starlette.responses import RedirectResponse
from starlette.middleware import Middleware
from fastapi.middleware.cors import CORSMiddleware
from auth.google_auth import handle_auth_callback, start_auth_flow, check_client_secrets
from auth.oauth21_session_store import get_oauth21_session_store
from auth.google_auth import handle_auth_callback, start_auth_flow, check_client_secrets, save_credentials_to_file
from google.oauth2.credentials import Credentials
from auth.oauth_callback_server import get_oauth_redirect_uri, ensure_oauth_callback_available
from auth.oauth_responses import create_error_response, create_success_response, create_server_error_response
# FastMCP OAuth imports
from auth.fastmcp_google_auth import GoogleWorkspaceAuthProvider
from auth.oauth21_google_bridge import set_auth_provider
from auth.oauth21_google_bridge import set_auth_provider, store_token_session
# Import shared configuration
from auth.scopes import (
@@ -78,7 +86,7 @@ _current_transport_mode = "stdio" # Default to stdio
_auth_provider: Optional[GoogleWorkspaceAuthProvider] = None
# Create middleware configuration
from starlette.middleware import Middleware
from auth.mcp_session_middleware import MCPSessionMiddleware
cors_middleware = Middleware(
CORSMiddleware,
@@ -88,16 +96,20 @@ cors_middleware = Middleware(
allow_headers=["*"],
)
session_middleware = Middleware(MCPSessionMiddleware)
# Custom FastMCP that adds CORS to streamable HTTP
class CORSEnabledFastMCP(FastMCP):
def streamable_http_app(self) -> "Starlette":
"""Override to add CORS middleware to the app."""
"""Override to add CORS and session middleware to the app."""
app = super().streamable_http_app()
# Add CORS as the first middleware
app.user_middleware.insert(0, cors_middleware)
# Add session middleware first (to set context before other middleware)
app.user_middleware.insert(0, session_middleware)
# Add CORS as the second middleware
app.user_middleware.insert(1, cors_middleware)
# Rebuild middleware stack
app.middleware_stack = app.build_middleware_stack()
logger.info("Added CORS middleware to streamable HTTP app")
logger.info("Added session and CORS middleware to streamable HTTP app")
return app
# Initialize auth provider for HTTP transport
@@ -177,7 +189,6 @@ def get_auth_provider() -> Optional[GoogleWorkspaceAuthProvider]:
@server.custom_route("/health", methods=["GET"])
async def health_check(request: Request):
"""Health check endpoint for container orchestration."""
from fastapi.responses import JSONResponse
try:
version = metadata.version("workspace-mcp")
except metadata.PackageNotFoundError:
@@ -232,7 +243,6 @@ async def oauth2_callback(request: Request) -> HTMLResponse:
# Store Google credentials in OAuth 2.1 session store
try:
from auth.oauth21_session_store import get_oauth21_session_store
store = get_oauth21_session_store()
store.store_session(
user_email=verified_user_id,
@@ -337,17 +347,46 @@ async def oauth_protected_resource(request: Request):
],
"bearer_methods_supported": ["header"],
"scopes_supported": [
# Base scopes
"openid",
"https://www.googleapis.com/auth/userinfo.email",
"https://www.googleapis.com/auth/userinfo.profile",
# Calendar scopes
"https://www.googleapis.com/auth/calendar",
"https://www.googleapis.com/auth/calendar.readonly",
"https://www.googleapis.com/auth/calendar.events",
# Drive scopes
"https://www.googleapis.com/auth/drive",
"https://www.googleapis.com/auth/drive.readonly",
"https://www.googleapis.com/auth/drive.file",
# Gmail scopes
"https://www.googleapis.com/auth/gmail.modify",
"https://www.googleapis.com/auth/gmail.readonly",
"https://www.googleapis.com/auth/gmail.send",
"https://www.googleapis.com/auth/gmail.compose",
"https://www.googleapis.com/auth/gmail.labels",
# Docs scopes
"https://www.googleapis.com/auth/documents",
"https://www.googleapis.com/auth/documents.readonly",
# Sheets scopes
"https://www.googleapis.com/auth/spreadsheets",
"https://www.googleapis.com/auth/spreadsheets.readonly",
# Slides scopes
"https://www.googleapis.com/auth/presentations",
"https://www.googleapis.com/auth/presentations.readonly",
# Chat scopes
"https://www.googleapis.com/auth/chat.spaces",
"https://www.googleapis.com/auth/forms",
"https://www.googleapis.com/auth/tasks"
"https://www.googleapis.com/auth/chat.messages",
"https://www.googleapis.com/auth/chat.messages.readonly",
# Forms scopes
"https://www.googleapis.com/auth/forms.body",
"https://www.googleapis.com/auth/forms.body.readonly",
"https://www.googleapis.com/auth/forms.responses.readonly",
# Tasks scopes
"https://www.googleapis.com/auth/tasks",
"https://www.googleapis.com/auth/tasks.readonly",
# Search scope
"https://www.googleapis.com/auth/cse"
],
"resource_documentation": "https://developers.google.com/workspace",
"client_registration_required": True,
@@ -464,7 +503,37 @@ async def oauth_client_config(request: Request):
],
"grant_types": ["authorization_code", "refresh_token"],
"response_types": ["code"],
"scope": "openid email profile https://www.googleapis.com/auth/calendar https://www.googleapis.com/auth/drive https://www.googleapis.com/auth/gmail.modify",
"scope": " ".join([
"openid",
"https://www.googleapis.com/auth/userinfo.email",
"https://www.googleapis.com/auth/userinfo.profile",
"https://www.googleapis.com/auth/calendar",
"https://www.googleapis.com/auth/calendar.readonly",
"https://www.googleapis.com/auth/calendar.events",
"https://www.googleapis.com/auth/drive",
"https://www.googleapis.com/auth/drive.readonly",
"https://www.googleapis.com/auth/drive.file",
"https://www.googleapis.com/auth/gmail.modify",
"https://www.googleapis.com/auth/gmail.readonly",
"https://www.googleapis.com/auth/gmail.send",
"https://www.googleapis.com/auth/gmail.compose",
"https://www.googleapis.com/auth/gmail.labels",
"https://www.googleapis.com/auth/documents",
"https://www.googleapis.com/auth/documents.readonly",
"https://www.googleapis.com/auth/spreadsheets",
"https://www.googleapis.com/auth/spreadsheets.readonly",
"https://www.googleapis.com/auth/presentations",
"https://www.googleapis.com/auth/presentations.readonly",
"https://www.googleapis.com/auth/chat.spaces",
"https://www.googleapis.com/auth/chat.messages",
"https://www.googleapis.com/auth/chat.messages.readonly",
"https://www.googleapis.com/auth/forms.body",
"https://www.googleapis.com/auth/forms.body.readonly",
"https://www.googleapis.com/auth/forms.responses.readonly",
"https://www.googleapis.com/auth/tasks",
"https://www.googleapis.com/auth/tasks.readonly",
"https://www.googleapis.com/auth/cse"
]),
"token_endpoint_auth_method": "client_secret_basic",
"code_challenge_methods": ["S256"]
},
@@ -488,8 +557,6 @@ async def oauth_authorize(request: Request):
}
)
from urllib.parse import urlencode
# Get query parameters
params = dict(request.query_params)
@@ -526,9 +593,6 @@ async def proxy_token_exchange(request: Request):
"Access-Control-Allow-Headers": "Content-Type, Authorization"
}
)
import aiohttp
try:
# Get form data
body = await request.body()
@@ -545,7 +609,44 @@ async def proxy_token_exchange(request: Request):
if response.status != 200:
logger.error(f"Token exchange failed: {response.status} - {response_data}")
else:
logger.info("Token exchange successful")
logger.info(f"Token exchange successful")
# Store the token session for credential bridging
if "access_token" in response_data:
try:
# Extract user email from ID token if present
if "id_token" in response_data:
# Decode without verification (we trust Google's response)
id_token_claims = jwt.decode(response_data["id_token"], options={"verify_signature": False})
user_email = id_token_claims.get("email")
if user_email:
# Store the token session
session_id = store_token_session(response_data, user_email)
logger.info(f"Stored OAuth session for {user_email} (session: {session_id})")
# Also create and store Google credentials
expiry = None
if "expires_in" in response_data:
# Google auth library expects timezone-naive datetime
expiry = datetime.utcnow() + timedelta(seconds=response_data["expires_in"])
credentials = Credentials(
token=response_data["access_token"],
refresh_token=response_data.get("refresh_token"),
token_uri="https://oauth2.googleapis.com/token",
client_id=os.getenv("GOOGLE_OAUTH_CLIENT_ID"),
client_secret=os.getenv("GOOGLE_OAUTH_CLIENT_SECRET"),
scopes=response_data.get("scope", "").split() if response_data.get("scope") else None,
expiry=expiry
)
# Save credentials to file for legacy auth
save_credentials_to_file(user_email, credentials)
logger.info(f"Saved Google credentials for {user_email}")
except Exception as e:
logger.error(f"Failed to store OAuth session: {e}")
return JSONResponse(
status_code=response.status,
@@ -571,7 +672,7 @@ async def proxy_token_exchange(request: Request):
async def oauth_register(request: Request):
"""
Dynamic client registration workaround endpoint.
Google doesn't support OAuth 2.1 dynamic client registration, so this endpoint
accepts any registration request and returns our pre-configured Google OAuth
credentials, allowing standards-compliant clients to work seamlessly.
@@ -585,22 +686,22 @@ async def oauth_register(request: Request):
"Access-Control-Allow-Headers": "Content-Type, Authorization"
}
)
client_id = os.getenv("GOOGLE_OAUTH_CLIENT_ID")
client_secret = os.getenv("GOOGLE_OAUTH_CLIENT_SECRET")
if not client_id or not client_secret:
return JSONResponse(
status_code=400,
content={"error": "invalid_request", "error_description": "OAuth not configured"},
headers={"Access-Control-Allow-Origin": "*"}
)
try:
# Parse the registration request
body = await request.json()
logger.info(f"Dynamic client registration request received: {body}")
# Extract redirect URIs from the request or use defaults
redirect_uris = body.get("redirect_uris", [])
if not redirect_uris:
@@ -608,9 +709,8 @@ async def oauth_register(request: Request):
f"{WORKSPACE_MCP_BASE_URI}:{WORKSPACE_MCP_PORT}/oauth2callback",
"http://localhost:5173/auth/callback"
]
# Build the registration response with our pre-configured credentials
import time
response_data = {
"client_id": client_id,
"client_secret": client_secret,
@@ -619,7 +719,37 @@ async def oauth_register(request: Request):
"redirect_uris": redirect_uris,
"grant_types": body.get("grant_types", ["authorization_code", "refresh_token"]),
"response_types": body.get("response_types", ["code"]),
"scope": body.get("scope", "openid email profile https://www.googleapis.com/auth/calendar https://www.googleapis.com/auth/drive https://www.googleapis.com/auth/gmail.modify"),
"scope": body.get("scope", " ".join([
"openid",
"https://www.googleapis.com/auth/userinfo.email",
"https://www.googleapis.com/auth/userinfo.profile",
"https://www.googleapis.com/auth/calendar",
"https://www.googleapis.com/auth/calendar.readonly",
"https://www.googleapis.com/auth/calendar.events",
"https://www.googleapis.com/auth/drive",
"https://www.googleapis.com/auth/drive.readonly",
"https://www.googleapis.com/auth/drive.file",
"https://www.googleapis.com/auth/gmail.modify",
"https://www.googleapis.com/auth/gmail.readonly",
"https://www.googleapis.com/auth/gmail.send",
"https://www.googleapis.com/auth/gmail.compose",
"https://www.googleapis.com/auth/gmail.labels",
"https://www.googleapis.com/auth/documents",
"https://www.googleapis.com/auth/documents.readonly",
"https://www.googleapis.com/auth/spreadsheets",
"https://www.googleapis.com/auth/spreadsheets.readonly",
"https://www.googleapis.com/auth/presentations",
"https://www.googleapis.com/auth/presentations.readonly",
"https://www.googleapis.com/auth/chat.spaces",
"https://www.googleapis.com/auth/chat.messages",
"https://www.googleapis.com/auth/chat.messages.readonly",
"https://www.googleapis.com/auth/forms.body",
"https://www.googleapis.com/auth/forms.body.readonly",
"https://www.googleapis.com/auth/forms.responses.readonly",
"https://www.googleapis.com/auth/tasks",
"https://www.googleapis.com/auth/tasks.readonly",
"https://www.googleapis.com/auth/cse"
])),
"token_endpoint_auth_method": body.get("token_endpoint_auth_method", "client_secret_basic"),
"code_challenge_methods": ["S256"],
# Additional OAuth 2.1 fields
@@ -627,9 +757,9 @@ async def oauth_register(request: Request):
"registration_access_token": "not-required", # We don't implement client management
"registration_client_uri": f"{WORKSPACE_MCP_BASE_URI}:{WORKSPACE_MCP_PORT}/oauth2/register/{client_id}"
}
logger.info("Dynamic client registration successful - returning pre-configured Google credentials")
return JSONResponse(
status_code=201,
content=response_data,
@@ -639,7 +769,7 @@ async def oauth_register(request: Request):
"Cache-Control": "no-store"
}
)
except Exception as e:
logger.error(f"Error in dynamic client registration: {e}")
return JSONResponse(