Merge pull request #153 from taylorwilsdon/vscode_oauth_support

feat: Streamlined VSCode OAuth Support
This commit is contained in:
Taylor Wilsdon
2025-08-09 11:58:10 -04:00
committed by GitHub
14 changed files with 2292 additions and 1726 deletions

View File

@@ -51,6 +51,8 @@
A production-ready MCP server that integrates all major Google Workspace services with AI assistants. It supports both single-user operation and multi-user authentication via OAuth 2.1, making it a powerful backend for custom applications. Built with FastMCP for optimal performance, featuring advanced authentication handling, service caching, and streamlined development patterns. A production-ready MCP server that integrates all major Google Workspace services with AI assistants. It supports both single-user operation and multi-user authentication via OAuth 2.1, making it a powerful backend for custom applications. Built with FastMCP for optimal performance, featuring advanced authentication handling, service caching, and streamlined development patterns.
**🎉 Simplified Setup**: Now uses Google Desktop OAuth clients - no redirect URIs or port configuration needed!
## Features ## Features
- **🔐 Advanced OAuth 2.0 & OAuth 2.1**: Secure authentication with automatic token refresh, transport-aware callback handling, session management, centralized scope management, and OAuth 2.1 bearer token support for multi-user environments with innovative CORS proxy architecture - **🔐 Advanced OAuth 2.0 & OAuth 2.1**: Secure authentication with automatic token refresh, transport-aware callback handling, session management, centralized scope management, and OAuth 2.1 bearer token support for multi-user environments with innovative CORS proxy architecture
@@ -95,6 +97,8 @@ A production-ready MCP server that integrates all major Google Workspace service
| `GOOGLE_PSE_API_KEY` *(optional)* | API key for Google Custom Search - see [Custom Search Setup](#google-custom-search-setup) | | `GOOGLE_PSE_API_KEY` *(optional)* | API key for Google Custom Search - see [Custom Search Setup](#google-custom-search-setup) |
| `GOOGLE_PSE_ENGINE_ID` *(optional)* | Programmable Search Engine ID for Custom Search | | `GOOGLE_PSE_ENGINE_ID` *(optional)* | Programmable Search Engine ID for Custom Search |
| `MCP_ENABLE_OAUTH21` *(optional)* | Set to `true` to enable OAuth 2.1 support (requires streamable-http transport) | | `MCP_ENABLE_OAUTH21` *(optional)* | Set to `true` to enable OAuth 2.1 support (requires streamable-http transport) |
| `OAUTH_CUSTOM_REDIRECT_URIS` *(optional)* | Comma-separated list of additional redirect URIs |
| `OAUTH_ALLOWED_ORIGINS` *(optional)* | Comma-separated list of additional CORS origins |
| `OAUTHLIB_INSECURE_TRANSPORT=1` | Development only (allows `http://` redirect) | | `OAUTHLIB_INSECURE_TRANSPORT=1` | Development only (allows `http://` redirect) |
Claude Desktop stores these securely in the OS keychain; set them once in the extension pane. Claude Desktop stores these securely in the OS keychain; set them once in the extension pane.
@@ -114,12 +118,12 @@ Claude Desktop stores these securely in the OS keychain; set them once in the ex
### Configuration ### Configuration
1. **Google Cloud Setup**: 1. **Google Cloud Setup**:
- Create OAuth 2.0 credentials (web application) in [Google Cloud Console](https://console.cloud.google.com/) - Create OAuth 2.0 credentials in [Google Cloud Console](https://console.cloud.google.com/)
- Create a new project (or use an existing one) for your MCP server. - Create a new project (or use an existing one) for your MCP server.
- Navigate to APIs & Services → Credentials. - Navigate to APIs & Services → Credentials.
- Click Create Credentials → OAuth Client ID. - Click Create Credentials → OAuth Client ID.
- Choose Web Application as the application type. - **Choose Desktop Application as the application type** (simpler setup, no redirect URIs needed!)
- Add redirect URI: `http://localhost:8000/oauth2callback` - Download your credentials and note the Client ID and Client Secret
- **Enable APIs**: - **Enable APIs**:
- In the Google Cloud Console, go to APIs & Services → Library. - In the Google Cloud Console, go to APIs & Services → Library.
@@ -278,6 +282,29 @@ This architecture enables any OAuth 2.1 compliant client to authenticate users t
</details> </details>
**MCP Inspector**: No additional configuration needed with desktop OAuth client.
**Claude Code Inspector**: No additional configuration needed with desktop OAuth client.
### VS Code MCP Client Support
The server includes native support for VS Code's MCP client:
- **No Configuration Required**: Works out-of-the-box with VS Code's MCP extension
- **Standards Compliant**: Full OAuth 2.1 compliance with desktop OAuth clients
**VS Code mcp.json Configuration Example**:
```json
{
"servers": {
"google-workspace": {
"url": "http://localhost:8000/mcp/",
"type": "http"
}
}
}
```
### Connect to Claude Desktop ### Connect to Claude Desktop
The server supports two transport modes: The server supports two transport modes:
@@ -422,17 +449,18 @@ If you need to use HTTP mode with Claude Desktop:
### First-Time Authentication ### First-Time Authentication
The server features **transport-aware OAuth callback handling**: The server uses **Google Desktop OAuth** for simplified authentication:
- **Stdio Mode**: Automatically starts a minimal HTTP server on port 8000 for OAuth callbacks - **No redirect URIs needed**: Desktop OAuth clients handle authentication without complex callback URLs
- **HTTP Mode**: Uses the existing FastAPI server for OAuth callbacks - **Automatic flow**: The server manages the entire OAuth process transparently
- **Same OAuth Flow**: Both modes use `http://localhost:8000/oauth2callback` for consistency - **Transport-agnostic**: Works seamlessly in both stdio and HTTP modes
When calling a tool: When calling a tool:
1. Server returns authorization URL 1. Server returns authorization URL
2. Open URL in browser and authorize 2. Open URL in browser and authorize
3. Server handles OAuth callback automatically (on port 8000 in both modes) 3. Google provides an authorization code
4. Retry the original request 4. Paste the code when prompted (or it's handled automatically)
5. Server completes authentication and retries your request
--- ---

View File

@@ -24,6 +24,7 @@ from pydantic import AnyHttpUrl
try: try:
from fastmcp.server.auth import RemoteAuthProvider from fastmcp.server.auth import RemoteAuthProvider
from fastmcp.server.auth.providers.jwt import JWTVerifier from fastmcp.server.auth.providers.jwt import JWTVerifier
REMOTEAUTHPROVIDER_AVAILABLE = True REMOTEAUTHPROVIDER_AVAILABLE = True
except ImportError: except ImportError:
REMOTEAUTHPROVIDER_AVAILABLE = False REMOTEAUTHPROVIDER_AVAILABLE = False
@@ -35,9 +36,10 @@ except ImportError:
from auth.oauth_common_handlers import ( from auth.oauth_common_handlers import (
handle_oauth_authorize, handle_oauth_authorize,
handle_proxy_token_exchange, handle_proxy_token_exchange,
handle_oauth_protected_resource,
handle_oauth_authorization_server, handle_oauth_authorization_server,
handle_oauth_client_config, handle_oauth_client_config,
handle_oauth_register handle_oauth_register,
) )
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -45,12 +47,12 @@ logger = logging.getLogger(__name__)
class GoogleRemoteAuthProvider(RemoteAuthProvider): class GoogleRemoteAuthProvider(RemoteAuthProvider):
""" """
RemoteAuthProvider implementation for Google Workspace using FastMCP v2.11.1+. RemoteAuthProvider implementation for Google Workspace.
This provider extends RemoteAuthProvider to add: This provider extends RemoteAuthProvider to add:
- OAuth proxy endpoints for CORS workaround - OAuth proxy endpoints for CORS workaround
- Dynamic client registration support - Dynamic client registration support
- Enhanced session management with issuer tracking - Session management with issuer tracking
""" """
def __init__(self): def __init__(self):
@@ -65,51 +67,90 @@ class GoogleRemoteAuthProvider(RemoteAuthProvider):
self.port = int(os.getenv("PORT", os.getenv("WORKSPACE_MCP_PORT", 8000))) self.port = int(os.getenv("PORT", os.getenv("WORKSPACE_MCP_PORT", 8000)))
if not self.client_id: if not self.client_id:
logger.error("GOOGLE_OAUTH_CLIENT_ID not set - OAuth 2.1 authentication will not work") logger.error(
raise ValueError("GOOGLE_OAUTH_CLIENT_ID environment variable is required for OAuth 2.1 authentication") "GOOGLE_OAUTH_CLIENT_ID not set - OAuth 2.1 authentication will not work"
)
raise ValueError(
"GOOGLE_OAUTH_CLIENT_ID environment variable is required for OAuth 2.1 authentication"
)
# Configure JWT verifier for Google tokens # Configure JWT verifier for Google tokens
token_verifier = JWTVerifier( token_verifier = JWTVerifier(
jwks_uri="https://www.googleapis.com/oauth2/v3/certs", jwks_uri="https://www.googleapis.com/oauth2/v3/certs",
issuer="https://accounts.google.com", issuer="https://accounts.google.com",
audience=self.client_id, # Always use actual client_id audience=self.client_id, # Always use actual client_id
algorithm="RS256" algorithm="RS256",
) )
# Initialize RemoteAuthProvider with local server as the authorization server # Initialize RemoteAuthProvider with base URL (no /mcp/ suffix)
# This ensures OAuth discovery points to our proxy endpoints instead of Google directly # The /mcp/ resource URL is handled in the protected resource metadata endpoint
super().__init__( super().__init__(
token_verifier=token_verifier, token_verifier=token_verifier,
authorization_servers=[AnyHttpUrl(f"{self.base_url}:{self.port}")], authorization_servers=[AnyHttpUrl(f"{self.base_url}:{self.port}")],
resource_server_url=f"{self.base_url}:{self.port}" resource_server_url=f"{self.base_url}:{self.port}",
) )
logger.debug("GoogleRemoteAuthProvider initialized") logger.debug("GoogleRemoteAuthProvider")
def get_routes(self) -> List[Route]: def get_routes(self) -> List[Route]:
""" """
Add custom OAuth proxy endpoints to the standard protected resource routes. Add OAuth routes at canonical locations.
These endpoints work around Google's CORS restrictions and provide
dynamic client registration support.
""" """
# Get the standard OAuth protected resource routes from RemoteAuthProvider # Get the standard OAuth protected resource routes from RemoteAuthProvider
routes = super().get_routes() parent_routes = super().get_routes()
# Log what routes we're getting from the parent # Filter out the parent's oauth-protected-resource route since we're replacing it
logger.debug(f"Registered {len(routes)} OAuth routes from parent") routes = [
r
for r in parent_routes
if r.path != "/.well-known/oauth-protected-resource"
]
# Add our custom proxy endpoints using common handlers # Add our custom OAuth discovery endpoint that returns /mcp/ as the resource
routes.append(Route("/oauth2/authorize", handle_oauth_authorize, methods=["GET", "OPTIONS"])) routes.append(
Route(
"/.well-known/oauth-protected-resource",
handle_oauth_protected_resource,
methods=["GET", "OPTIONS"],
)
)
routes.append(Route("/oauth2/token", handle_proxy_token_exchange, methods=["POST", "OPTIONS"])) routes.append(
Route(
"/.well-known/oauth-authorization-server",
handle_oauth_authorization_server,
methods=["GET", "OPTIONS"],
)
)
routes.append(Route("/oauth2/register", handle_oauth_register, methods=["POST", "OPTIONS"])) routes.append(
Route(
"/.well-known/oauth-client",
handle_oauth_client_config,
methods=["GET", "OPTIONS"],
)
)
routes.append(Route("/.well-known/oauth-authorization-server", handle_oauth_authorization_server, methods=["GET", "OPTIONS"])) # Add OAuth flow endpoints
routes.append(
routes.append(Route("/.well-known/oauth-client", handle_oauth_client_config, methods=["GET", "OPTIONS"])) Route(
"/oauth2/authorize", handle_oauth_authorize, methods=["GET", "OPTIONS"]
)
)
routes.append(
Route(
"/oauth2/token",
handle_proxy_token_exchange,
methods=["POST", "OPTIONS"],
)
)
routes.append(
Route(
"/oauth2/register", handle_oauth_register, methods=["POST", "OPTIONS"]
)
)
logger.info(f"Registered {len(routes)} OAuth routes")
return routes return routes
async def verify_token(self, token: str) -> Optional[object]: async def verify_token(self, token: str) -> Optional[object]:
@@ -121,22 +162,30 @@ class GoogleRemoteAuthProvider(RemoteAuthProvider):
""" """
# Check if this is a Google OAuth access token (starts with ya29.) # Check if this is a Google OAuth access token (starts with ya29.)
if token.startswith("ya29."): if token.startswith("ya29."):
logger.debug("Detected Google OAuth access token, using tokeninfo verification") logger.debug(
"Detected Google OAuth access token, using tokeninfo verification"
)
try: try:
# Verify the access token using Google's tokeninfo endpoint # Verify the access token using Google's tokeninfo endpoint
async with aiohttp.ClientSession() as session: async with aiohttp.ClientSession() as session:
url = f"https://oauth2.googleapis.com/tokeninfo?access_token={token}" url = (
f"https://oauth2.googleapis.com/tokeninfo?access_token={token}"
)
async with session.get(url) as response: async with session.get(url) as response:
if response.status != 200: if response.status != 200:
logger.error(f"Token verification failed: {response.status}") logger.error(
f"Token verification failed: {response.status}"
)
return None return None
token_info = await response.json() token_info = await response.json()
# Verify the token is for our client # Verify the token is for our client
if token_info.get("aud") != self.client_id: if token_info.get("aud") != self.client_id:
logger.error(f"Token audience mismatch: expected {self.client_id}, got {token_info.get('aud')}") logger.error(
f"Token audience mismatch: expected {self.client_id}, got {token_info.get('aud')}"
)
return None return None
# Check if token is expired # Check if token is expired
@@ -151,7 +200,9 @@ class GoogleRemoteAuthProvider(RemoteAuthProvider):
# Calculate expires_at timestamp # Calculate expires_at timestamp
expires_in = int(token_info.get("expires_in", 0)) expires_in = int(token_info.get("expires_in", 0))
expires_at = int(time.time()) + expires_in if expires_in > 0 else 0 expires_at = (
int(time.time()) + expires_in if expires_in > 0 else 0
)
access_token = SimpleNamespace( access_token = SimpleNamespace(
claims={ claims={
@@ -166,12 +217,15 @@ class GoogleRemoteAuthProvider(RemoteAuthProvider):
client_id=self.client_id, # Add client_id at top level client_id=self.client_id, # Add client_id at top level
# Add other required fields # Add other required fields
sub=token_info.get("sub", ""), sub=token_info.get("sub", ""),
email=token_info.get("email", "") email=token_info.get("email", ""),
) )
user_email = token_info.get("email") user_email = token_info.get("email")
if user_email: if user_email:
from auth.oauth21_session_store import get_oauth21_session_store from auth.oauth21_session_store import (
get_oauth21_session_store,
)
store = get_oauth21_session_store() store = get_oauth21_session_store()
session_id = f"google_{token_info.get('sub', 'unknown')}" session_id = f"google_{token_info.get('sub', 'unknown')}"
@@ -179,10 +233,13 @@ class GoogleRemoteAuthProvider(RemoteAuthProvider):
mcp_session_id = None mcp_session_id = None
try: try:
from fastmcp.server.dependencies import get_context from fastmcp.server.dependencies import get_context
ctx = get_context() ctx = get_context()
if ctx and hasattr(ctx, 'session_id'): if ctx and hasattr(ctx, "session_id"):
mcp_session_id = ctx.session_id mcp_session_id = ctx.session_id
logger.debug(f"Binding MCP session {mcp_session_id} to user {user_email}") logger.debug(
f"Binding MCP session {mcp_session_id} to user {user_email}"
)
except Exception: except Exception:
pass pass
@@ -193,7 +250,7 @@ class GoogleRemoteAuthProvider(RemoteAuthProvider):
scopes=access_token.scopes, scopes=access_token.scopes,
session_id=session_id, session_id=session_id,
mcp_session_id=mcp_session_id, mcp_session_id=mcp_session_id,
issuer="https://accounts.google.com" issuer="https://accounts.google.com",
) )
logger.info(f"Verified OAuth token: {user_email}") logger.info(f"Verified OAuth token: {user_email}")
@@ -214,6 +271,7 @@ class GoogleRemoteAuthProvider(RemoteAuthProvider):
user_email = access_token.claims.get("email") user_email = access_token.claims.get("email")
if user_email: if user_email:
from auth.oauth21_session_store import get_oauth21_session_store from auth.oauth21_session_store import get_oauth21_session_store
store = get_oauth21_session_store() store = get_oauth21_session_store()
session_id = f"google_{access_token.claims.get('sub', 'unknown')}" session_id = f"google_{access_token.claims.get('sub', 'unknown')}"
@@ -223,9 +281,11 @@ class GoogleRemoteAuthProvider(RemoteAuthProvider):
access_token=token, access_token=token,
scopes=access_token.scopes or [], scopes=access_token.scopes or [],
session_id=session_id, session_id=session_id,
issuer="https://accounts.google.com" issuer="https://accounts.google.com",
) )
logger.debug(f"Successfully verified JWT token for user: {user_email}") logger.debug(
f"Successfully verified JWT token for user: {user_email}"
)
return access_token return access_token

View File

@@ -160,23 +160,6 @@ def set_auth_layer(auth_layer):
logger.info("set_auth_layer called - OAuth is now handled by FastMCP") logger.info("set_auth_layer called - OAuth is now handled by FastMCP")
_oauth21_enabled = False
def is_oauth21_enabled() -> bool:
"""
Check if the OAuth 2.1 authentication layer is active.
"""
global _oauth21_enabled
return _oauth21_enabled
def enable_oauth21():
"""
Enable the OAuth 2.1 authentication layer.
"""
global _oauth21_enabled
_oauth21_enabled = True
logger.debug("OAuth 2.1 authentication enabled")
async def get_legacy_auth_service( async def get_legacy_auth_service(
@@ -206,13 +189,16 @@ async def get_authenticated_google_service_oauth21(
tool_name: str, tool_name: str,
user_google_email: str, user_google_email: str,
required_scopes: list[str], required_scopes: list[str],
session_id: Optional[str] = None,
auth_token_email: Optional[str] = None,
allow_recent_auth: bool = False,
context: Optional[Dict[str, Any]] = None, context: Optional[Dict[str, Any]] = None,
) -> Tuple[Any, str]: ) -> Tuple[Any, str]:
""" """
Enhanced version of get_authenticated_google_service that supports OAuth 2.1. Enhanced version of get_authenticated_google_service that supports OAuth 2.1.
This function checks for OAuth 2.1 session context and uses it if available, This function checks for OAuth 2.1 session context and uses it if available,
otherwise falls back to legacy authentication. otherwise falls back to legacy authentication based on configuration.
Args: Args:
service_name: Google service name service_name: Google service name
@@ -220,20 +206,32 @@ async def get_authenticated_google_service_oauth21(
tool_name: Tool name for logging tool_name: Tool name for logging
user_google_email: User's Google email user_google_email: User's Google email
required_scopes: Required OAuth scopes required_scopes: Required OAuth scopes
session_id: Optional OAuth session ID
auth_token_email: Optional authenticated user email from token
allow_recent_auth: Whether to allow recently authenticated sessions
context: Optional context containing session information context: Optional context containing session information
Returns: Returns:
Tuple of (service instance, actual user email) Tuple of (service instance, actual user email)
""" """
# Check if OAuth 2.1 is truly enabled
if not is_oauth21_enabled():
logger.debug(f"[{tool_name}] OAuth 2.1 disabled, using legacy authentication")
return await get_legacy_auth_service(
service_name=service_name,
version=version,
tool_name=tool_name,
user_google_email=user_google_email,
required_scopes=required_scopes,
)
builder = get_oauth21_service_builder() builder = get_oauth21_service_builder()
# FastMCP handles context now - extract any session info # FastMCP handles context now - extract any session info
session_id = None if not session_id and context:
auth_context = None
if context:
session_id = builder.extract_session_from_context(context) session_id = builder.extract_session_from_context(context)
auth_context = context.get("auth_context")
auth_context = context.get("auth_context") if context else None
return await builder.get_authenticated_service_with_session( return await builder.get_authenticated_service_with_session(
service_name=service_name, service_name=service_name,
@@ -244,3 +242,34 @@ async def get_authenticated_google_service_oauth21(
session_id=session_id, session_id=session_id,
auth_context=auth_context, auth_context=auth_context,
) )
async def get_authenticated_google_service_oauth21_v2(
request: "OAuth21ServiceRequest",
) -> Tuple[Any, str]:
"""
Enhanced version of get_authenticated_google_service that supports OAuth 2.1.
This version uses a parameter object to reduce function complexity and
improve maintainability. It's the recommended approach for new code.
Args:
request: OAuth21ServiceRequest object containing all parameters
Returns:
Tuple of (service instance, actual user email)
"""
# Delegate to the original function for now
# This provides a migration path while maintaining backward compatibility
return await get_authenticated_google_service_oauth21(
service_name=request.service_name,
version=request.version,
tool_name=request.tool_name,
user_google_email=request.user_google_email,
required_scopes=request.required_scopes,
session_id=request.session_id,
auth_token_email=request.auth_token_email,
allow_recent_auth=request.allow_recent_auth,
context=request.context,
)

View File

@@ -5,7 +5,6 @@ In streamable-http mode: Uses the existing FastAPI server
In stdio mode: Starts a minimal HTTP server just for OAuth callbacks In stdio mode: Starts a minimal HTTP server just for OAuth callbacks
""" """
import os
import asyncio import asyncio
import logging import logging
import threading import threading
@@ -20,7 +19,7 @@ from urllib.parse import urlparse
from auth.scopes import SCOPES from auth.scopes import SCOPES
from auth.oauth_responses import create_error_response, create_success_response, create_server_error_response from auth.oauth_responses import create_error_response, create_success_response, create_server_error_response
from auth.google_auth import handle_auth_callback, check_client_secrets from auth.google_auth import handle_auth_callback, check_client_secrets
from core.config import get_oauth_redirect_uri from auth.oauth_config import get_oauth_redirect_uri
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)

View File

@@ -16,22 +16,24 @@ from google.oauth2.credentials import Credentials
from auth.oauth21_session_store import store_token_session from auth.oauth21_session_store import store_token_session
from auth.google_auth import save_credentials_to_file from auth.google_auth import save_credentials_to_file
from auth.scopes import get_current_scopes from auth.scopes import get_current_scopes
from core.config import WORKSPACE_MCP_BASE_URI, WORKSPACE_MCP_PORT, get_oauth_base_url from auth.oauth_config import get_oauth_config
from auth.oauth_error_handling import (
OAuthError, OAuthValidationError, OAuthConfigurationError,
create_oauth_error_response, validate_token_request,
validate_registration_request, get_development_cors_headers,
log_security_event
)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
async def handle_oauth_authorize(request: Request): async def handle_oauth_authorize(request: Request):
"""Common handler for OAuth authorization proxy.""" """Common handler for OAuth authorization proxy."""
origin = request.headers.get("origin")
if request.method == "OPTIONS": if request.method == "OPTIONS":
return JSONResponse( cors_headers = get_development_cors_headers(origin)
content={}, return JSONResponse(content={}, headers=cors_headers)
headers={
"Access-Control-Allow-Origin": "*",
"Access-Control-Allow-Methods": "GET, OPTIONS",
"Access-Control-Allow-Headers": "Content-Type"
}
)
# Get query parameters # Get query parameters
params = dict(request.query_params) params = dict(request.query_params)
@@ -55,35 +57,40 @@ async def handle_oauth_authorize(request: Request):
# Build Google authorization URL # Build Google authorization URL
google_auth_url = "https://accounts.google.com/o/oauth2/v2/auth?" + urlencode(params) google_auth_url = "https://accounts.google.com/o/oauth2/v2/auth?" + urlencode(params)
# Return redirect # Return redirect with development CORS headers if needed
cors_headers = get_development_cors_headers(origin)
return RedirectResponse( return RedirectResponse(
url=google_auth_url, url=google_auth_url,
status_code=302, status_code=302,
headers={ headers=cors_headers
"Access-Control-Allow-Origin": "*"
}
) )
async def handle_proxy_token_exchange(request: Request): async def handle_proxy_token_exchange(request: Request):
"""Common handler for OAuth token exchange proxy.""" """Common handler for OAuth token exchange proxy with comprehensive error handling."""
if request.method == "OPTIONS": origin = request.headers.get("origin")
return JSONResponse(
content={},
headers={
"Access-Control-Allow-Origin": "*",
"Access-Control-Allow-Methods": "POST, OPTIONS",
"Access-Control-Allow-Headers": "Content-Type, Authorization"
}
)
try:
# Get form data
body = await request.body()
content_type = request.headers.get("content-type", "application/x-www-form-urlencoded")
# Parse form data to add missing client credentials if request.method == "OPTIONS":
cors_headers = get_development_cors_headers(origin)
return JSONResponse(content={}, headers=cors_headers)
try:
# Get form data with validation
try:
body = await request.body()
content_type = request.headers.get("content-type", "application/x-www-form-urlencoded")
except Exception as e:
raise OAuthValidationError(f"Failed to read request body: {e}")
# Parse and validate form data
if content_type and "application/x-www-form-urlencoded" in content_type: if content_type and "application/x-www-form-urlencoded" in content_type:
form_data = parse_qs(body.decode('utf-8')) try:
form_data = parse_qs(body.decode('utf-8'))
except Exception as e:
raise OAuthValidationError(f"Invalid form data: {e}")
# Convert to single values and validate
request_data = {k: v[0] if v else '' for k, v in form_data.items()}
validate_token_request(request_data)
# Check if client_id is missing (public client) # Check if client_id is missing (public client)
if 'client_id' not in form_data or not form_data['client_id'][0]: if 'client_id' not in form_data or not form_data['client_id'][0]:
@@ -186,43 +193,57 @@ async def handle_proxy_token_exchange(request: Request):
except Exception as e: except Exception as e:
logger.error(f"Failed to store OAuth session: {e}") logger.error(f"Failed to store OAuth session: {e}")
# Add development CORS headers
cors_headers = get_development_cors_headers(origin)
response_headers = {
"Content-Type": "application/json",
"Cache-Control": "no-store"
}
response_headers.update(cors_headers)
return JSONResponse( return JSONResponse(
status_code=response.status, status_code=response.status,
content=response_data, content=response_data,
headers={ headers=response_headers
"Content-Type": "application/json",
"Access-Control-Allow-Origin": "*",
"Cache-Control": "no-store"
}
) )
except OAuthError as e:
log_security_event("oauth_token_exchange_error", {
"error_code": e.error_code,
"description": e.description
}, request)
return create_oauth_error_response(e, origin)
except Exception as e: except Exception as e:
logger.error(f"Error in token proxy: {e}") logger.error(f"Unexpected error in token proxy: {e}", exc_info=True)
return JSONResponse( log_security_event("oauth_token_exchange_unexpected_error", {
status_code=500, "error": str(e)
content={"error": "server_error", "error_description": str(e)}, }, request)
headers={"Access-Control-Allow-Origin": "*"} error = OAuthConfigurationError("Internal server error")
) return create_oauth_error_response(error, origin)
async def handle_oauth_protected_resource(request: Request): async def handle_oauth_protected_resource(request: Request):
"""Common handler for OAuth protected resource metadata.""" """
if request.method == "OPTIONS": Handle OAuth protected resource metadata requests.
return JSONResponse( """
content={}, origin = request.headers.get("origin")
headers={
"Access-Control-Allow-Origin": "*",
"Access-Control-Allow-Methods": "GET, OPTIONS",
"Access-Control-Allow-Headers": "Content-Type"
}
)
base_url = get_oauth_base_url() # Handle preflight
if request.method == "OPTIONS":
cors_headers = get_development_cors_headers(origin)
return JSONResponse(content={}, headers=cors_headers)
config = get_oauth_config()
base_url = config.get_oauth_base_url()
# For streamable-http transport, the MCP server runs at /mcp
# This is the actual resource being protected
resource_url = f"{base_url}/mcp"
# Build metadata response per RFC 9449
metadata = { metadata = {
"resource": base_url, "resource": resource_url, # The MCP server endpoint that needs protection
"authorization_servers": [ "authorization_servers": [base_url], # Our proxy acts as the auth server
base_url
],
"bearer_methods_supported": ["header"], "bearer_methods_supported": ["header"],
"scopes_supported": get_current_scopes(), "scopes_supported": get_current_scopes(),
"resource_documentation": "https://developers.google.com/workspace", "resource_documentation": "https://developers.google.com/workspace",
@@ -230,179 +251,143 @@ async def handle_oauth_protected_resource(request: Request):
"client_configuration_endpoint": f"{base_url}/.well-known/oauth-client", "client_configuration_endpoint": f"{base_url}/.well-known/oauth-client",
} }
# Log the response for debugging
logger.debug(f"Returning protected resource metadata: {metadata}")
# Add development CORS headers
cors_headers = get_development_cors_headers(origin)
response_headers = {
"Content-Type": "application/json; charset=utf-8",
"Cache-Control": "public, max-age=3600"
}
response_headers.update(cors_headers)
return JSONResponse( return JSONResponse(
content=metadata, content=metadata,
headers={ headers=response_headers
"Content-Type": "application/json",
"Access-Control-Allow-Origin": "*"
}
) )
async def handle_oauth_authorization_server(request: Request): async def handle_oauth_authorization_server(request: Request):
"""Common handler for OAuth authorization server metadata.""" """
Handle OAuth authorization server metadata.
"""
origin = request.headers.get("origin")
if request.method == "OPTIONS": if request.method == "OPTIONS":
return JSONResponse( cors_headers = get_development_cors_headers(origin)
content={}, return JSONResponse(content={}, headers=cors_headers)
headers={
"Access-Control-Allow-Origin": "*",
"Access-Control-Allow-Methods": "GET, OPTIONS",
"Access-Control-Allow-Headers": "Content-Type"
}
)
# Get base URL once and reuse config = get_oauth_config()
base_url = get_oauth_base_url()
try: # Get authorization server metadata from centralized config
# Fetch metadata from Google # Pass scopes directly to keep all metadata generation in one place
async with aiohttp.ClientSession() as session: metadata = config.get_authorization_server_metadata(scopes=get_current_scopes())
url = "https://accounts.google.com/.well-known/openid-configuration"
async with session.get(url) as response:
if response.status == 200:
metadata = await response.json()
# Add OAuth 2.1 required fields logger.debug(f"Returning authorization server metadata: {metadata}")
metadata.setdefault("code_challenge_methods_supported", ["S256"])
metadata.setdefault("pkce_required", True)
# Override endpoints to use our proxies # Add development CORS headers
metadata["token_endpoint"] = f"{base_url}/oauth2/token" cors_headers = get_development_cors_headers(origin)
metadata["authorization_endpoint"] = f"{base_url}/oauth2/authorize" response_headers = {
metadata["enable_dynamic_registration"] = True "Content-Type": "application/json; charset=utf-8",
metadata["registration_endpoint"] = f"{base_url}/oauth2/register" "Cache-Control": "public, max-age=3600"
return JSONResponse( }
content=metadata, response_headers.update(cors_headers)
headers={
"Content-Type": "application/json",
"Access-Control-Allow-Origin": "*"
}
)
# Fallback metadata return JSONResponse(
return JSONResponse( content=metadata,
content={ headers=response_headers
"issuer": "https://accounts.google.com", )
"authorization_endpoint": f"{base_url}/oauth2/authorize",
"token_endpoint": f"{base_url}/oauth2/token",
"userinfo_endpoint": "https://www.googleapis.com/oauth2/v2/userinfo",
"revocation_endpoint": "https://oauth2.googleapis.com/revoke",
"jwks_uri": "https://www.googleapis.com/oauth2/v3/certs",
"response_types_supported": ["code"],
"code_challenge_methods_supported": ["S256"],
"pkce_required": True,
"grant_types_supported": ["authorization_code", "refresh_token"],
"scopes_supported": get_current_scopes(),
"token_endpoint_auth_methods_supported": ["client_secret_basic", "client_secret_post"]
},
headers={
"Content-Type": "application/json",
"Access-Control-Allow-Origin": "*"
}
)
except Exception as e:
logger.error(f"Error fetching auth server metadata: {e}")
return JSONResponse(
status_code=500,
content={"error": "Failed to fetch authorization server metadata"},
headers={"Access-Control-Allow-Origin": "*"}
)
async def handle_oauth_client_config(request: Request): async def handle_oauth_client_config(request: Request):
"""Common handler for OAuth client configuration.""" """Common handler for OAuth client configuration."""
origin = request.headers.get("origin")
if request.method == "OPTIONS": if request.method == "OPTIONS":
return JSONResponse( cors_headers = get_development_cors_headers(origin)
content={}, return JSONResponse(content={}, headers=cors_headers)
headers={
"Access-Control-Allow-Origin": "*",
"Access-Control-Allow-Methods": "GET, OPTIONS",
"Access-Control-Allow-Headers": "Content-Type"
}
)
client_id = os.getenv("GOOGLE_OAUTH_CLIENT_ID") client_id = os.getenv("GOOGLE_OAUTH_CLIENT_ID")
if not client_id: if not client_id:
cors_headers = get_development_cors_headers(origin)
return JSONResponse( return JSONResponse(
status_code=404, status_code=404,
content={"error": "OAuth not configured"}, content={"error": "OAuth not configured"},
headers={"Access-Control-Allow-Origin": "*"} headers=cors_headers
) )
# Get OAuth configuration
config = get_oauth_config()
return JSONResponse( return JSONResponse(
content={ content={
"client_id": client_id, "client_id": client_id,
"client_name": "Google Workspace MCP Server", "client_name": "Google Workspace MCP Server",
"client_uri": f"{WORKSPACE_MCP_BASE_URI}:{WORKSPACE_MCP_PORT}", "client_uri": config.base_url,
"redirect_uris": [ "redirect_uris": [
f"{WORKSPACE_MCP_BASE_URI}:{WORKSPACE_MCP_PORT}/oauth2callback", f"{config.base_url}/oauth2callback",
"http://localhost:5173/auth/callback" "http://localhost:5173/auth/callback"
], ],
"grant_types": ["authorization_code", "refresh_token"], "grant_types": ["authorization_code", "refresh_token"],
"response_types": ["code"], "response_types": ["code"],
"scope": " ".join(get_current_scopes()), "scope": " ".join(get_current_scopes()),
"token_endpoint_auth_method": "client_secret_basic", "token_endpoint_auth_method": "client_secret_basic",
"code_challenge_methods": ["S256"] "code_challenge_methods": config.supported_code_challenge_methods[:1] # Primary method only
}, },
headers={ headers={
"Content-Type": "application/json", "Content-Type": "application/json; charset=utf-8",
"Access-Control-Allow-Origin": "*" "Cache-Control": "public, max-age=3600",
**get_development_cors_headers(origin)
} }
) )
async def handle_oauth_register(request: Request): async def handle_oauth_register(request: Request):
"""Common handler for OAuth dynamic client registration.""" """Common handler for OAuth dynamic client registration with comprehensive error handling."""
origin = request.headers.get("origin")
if request.method == "OPTIONS": if request.method == "OPTIONS":
return JSONResponse( cors_headers = get_development_cors_headers(origin)
content={}, return JSONResponse(content={}, headers=cors_headers)
headers={
"Access-Control-Allow-Origin": "*",
"Access-Control-Allow-Methods": "POST, OPTIONS",
"Access-Control-Allow-Headers": "Content-Type, Authorization"
}
)
client_id = os.getenv("GOOGLE_OAUTH_CLIENT_ID") config = get_oauth_config()
client_secret = os.getenv("GOOGLE_OAUTH_CLIENT_SECRET")
if not client_id or not client_secret: if not config.is_configured():
return JSONResponse( error = OAuthConfigurationError("OAuth client credentials not configured")
status_code=400, return create_oauth_error_response(error, origin)
content={"error": "invalid_request", "error_description": "OAuth not configured"},
headers={"Access-Control-Allow-Origin": "*"}
)
try: try:
# Parse the registration request # Parse and validate the registration request
body = await request.json() try:
logger.info(f"Dynamic client registration request received: {body}") body = await request.json()
except Exception as e:
raise OAuthValidationError(f"Invalid JSON in registration request: {e}")
validate_registration_request(body)
logger.info("Dynamic client registration request received")
# Extract redirect URIs from the request or use defaults # Extract redirect URIs from the request or use defaults
redirect_uris = body.get("redirect_uris", []) redirect_uris = body.get("redirect_uris", [])
if not redirect_uris: if not redirect_uris:
redirect_uris = [ redirect_uris = config.get_redirect_uris()
f"{WORKSPACE_MCP_BASE_URI}:{WORKSPACE_MCP_PORT}/oauth2callback",
"http://localhost:5173/auth/callback"
]
# Build the registration response with our pre-configured credentials # Build the registration response with our pre-configured credentials
response_data = { response_data = {
"client_id": client_id, "client_id": config.client_id,
"client_secret": client_secret, "client_secret": config.client_secret,
"client_name": body.get("client_name", "Google Workspace MCP Server"), "client_name": body.get("client_name", "Google Workspace MCP Server"),
"client_uri": body.get("client_uri", f"{WORKSPACE_MCP_BASE_URI}:{WORKSPACE_MCP_PORT}"), "client_uri": body.get("client_uri", config.base_url),
"redirect_uris": redirect_uris, "redirect_uris": redirect_uris,
"grant_types": body.get("grant_types", ["authorization_code", "refresh_token"]), "grant_types": body.get("grant_types", ["authorization_code", "refresh_token"]),
"response_types": body.get("response_types", ["code"]), "response_types": body.get("response_types", ["code"]),
"scope": body.get("scope", " ".join(get_current_scopes())), "scope": body.get("scope", " ".join(get_current_scopes())),
"token_endpoint_auth_method": body.get("token_endpoint_auth_method", "client_secret_basic"), "token_endpoint_auth_method": body.get("token_endpoint_auth_method", "client_secret_basic"),
"code_challenge_methods": ["S256"], "code_challenge_methods": config.supported_code_challenge_methods,
# Additional OAuth 2.1 fields # Additional OAuth 2.1 fields
"client_id_issued_at": int(time.time()), "client_id_issued_at": int(time.time()),
"registration_access_token": "not-required", # We don't implement client management "registration_access_token": "not-required", # We don't implement client management
"registration_client_uri": f"{get_oauth_base_url()}/oauth2/register/{client_id}" "registration_client_uri": f"{config.get_oauth_base_url()}/oauth2/register/{config.client_id}"
} }
logger.info("Dynamic client registration successful - returning pre-configured Google credentials") logger.info("Dynamic client registration successful - returning pre-configured Google credentials")
@@ -412,15 +397,21 @@ async def handle_oauth_register(request: Request):
content=response_data, content=response_data,
headers={ headers={
"Content-Type": "application/json", "Content-Type": "application/json",
"Access-Control-Allow-Origin": "*", "Cache-Control": "no-store",
"Cache-Control": "no-store" **get_development_cors_headers(origin)
} }
) )
except OAuthError as e:
log_security_event("oauth_registration_error", {
"error_code": e.error_code,
"description": e.description
}, request)
return create_oauth_error_response(e, origin)
except Exception as e: except Exception as e:
logger.error(f"Error in dynamic client registration: {e}") logger.error(f"Unexpected error in client registration: {e}", exc_info=True)
return JSONResponse( log_security_event("oauth_registration_unexpected_error", {
status_code=400, "error": str(e)
content={"error": "invalid_request", "error_description": str(e)}, }, request)
headers={"Access-Control-Allow-Origin": "*"} error = OAuthConfigurationError("Internal server error")
) return create_oauth_error_response(error, origin)

319
auth/oauth_config.py Normal file
View File

@@ -0,0 +1,319 @@
"""
OAuth Configuration Management
This module centralizes OAuth-related configuration to eliminate hardcoded values
scattered throughout the codebase. It provides environment variable support and
sensible defaults for all OAuth-related settings.
Supports both OAuth 2.0 and OAuth 2.1 with automatic client capability detection.
"""
import os
from typing import List, Optional, Dict, Any
class OAuthConfig:
"""
Centralized OAuth configuration management.
This class eliminates the hardcoded configuration anti-pattern identified
in the challenge review by providing a single source of truth for all
OAuth-related configuration values.
"""
def __init__(self):
# Base server configuration
self.base_uri = os.getenv("WORKSPACE_MCP_BASE_URI", "http://localhost")
self.port = int(os.getenv("PORT", os.getenv("WORKSPACE_MCP_PORT", "8000")))
self.base_url = f"{self.base_uri}:{self.port}"
# OAuth client configuration
self.client_id = os.getenv("GOOGLE_OAUTH_CLIENT_ID")
self.client_secret = os.getenv("GOOGLE_OAUTH_CLIENT_SECRET")
# OAuth 2.1 configuration
self.oauth21_enabled = os.getenv("MCP_ENABLE_OAUTH21", "false").lower() == "true"
self.pkce_required = self.oauth21_enabled # PKCE is mandatory in OAuth 2.1
self.supported_code_challenge_methods = ["S256", "plain"] if not self.oauth21_enabled else ["S256"]
# Transport mode (will be set at runtime)
self._transport_mode = "stdio" # Default
# Redirect URI configuration
self.redirect_uri = self._get_redirect_uri()
def _get_redirect_uri(self) -> str:
"""
Get the OAuth redirect URI, supporting reverse proxy configurations.
Returns:
The configured redirect URI
"""
explicit_uri = os.getenv("GOOGLE_OAUTH_REDIRECT_URI")
if explicit_uri:
return explicit_uri
return f"{self.base_url}/oauth2callback"
def get_redirect_uris(self) -> List[str]:
"""
Get all valid OAuth redirect URIs.
Returns:
List of all supported redirect URIs
"""
uris = []
# Primary redirect URI
uris.append(self.redirect_uri)
# Custom redirect URIs from environment
custom_uris = os.getenv("OAUTH_CUSTOM_REDIRECT_URIS")
if custom_uris:
uris.extend([uri.strip() for uri in custom_uris.split(",")])
# Remove duplicates while preserving order
return list(dict.fromkeys(uris))
def get_allowed_origins(self) -> List[str]:
"""
Get allowed CORS origins for OAuth endpoints.
Returns:
List of allowed origins for CORS
"""
origins = []
# Server's own origin
origins.append(self.base_url)
# VS Code and development origins
origins.extend([
"vscode-webview://",
"https://vscode.dev",
"https://github.dev",
])
# Custom origins from environment
custom_origins = os.getenv("OAUTH_ALLOWED_ORIGINS")
if custom_origins:
origins.extend([origin.strip() for origin in custom_origins.split(",")])
return list(dict.fromkeys(origins))
def is_configured(self) -> bool:
"""
Check if OAuth is properly configured.
Returns:
True if OAuth client credentials are available
"""
return bool(self.client_id and self.client_secret)
def get_oauth_base_url(self) -> str:
"""
Get OAuth base URL for constructing OAuth endpoints.
Returns:
Base URL for OAuth endpoints
"""
return self.base_url
def validate_redirect_uri(self, uri: str) -> bool:
"""
Validate if a redirect URI is allowed.
Args:
uri: The redirect URI to validate
Returns:
True if the URI is allowed, False otherwise
"""
allowed_uris = self.get_redirect_uris()
return uri in allowed_uris
def get_environment_summary(self) -> dict:
"""
Get a summary of the current OAuth configuration.
Returns:
Dictionary with configuration summary (excluding secrets)
"""
return {
"base_url": self.base_url,
"redirect_uri": self.redirect_uri,
"client_configured": bool(self.client_id),
"oauth21_enabled": self.oauth21_enabled,
"pkce_required": self.pkce_required,
"transport_mode": self._transport_mode,
"total_redirect_uris": len(self.get_redirect_uris()),
"total_allowed_origins": len(self.get_allowed_origins()),
}
def set_transport_mode(self, mode: str) -> None:
"""
Set the current transport mode for OAuth callback handling.
Args:
mode: Transport mode ("stdio", "streamable-http", etc.)
"""
self._transport_mode = mode
def get_transport_mode(self) -> str:
"""
Get the current transport mode.
Returns:
Current transport mode
"""
return self._transport_mode
def is_oauth21_enabled(self) -> bool:
"""
Check if OAuth 2.1 mode is enabled.
Returns:
True if OAuth 2.1 is enabled
"""
return self.oauth21_enabled
def detect_oauth_version(self, request_params: Dict[str, Any]) -> str:
"""
Detect OAuth version based on request parameters.
This method implements a conservative detection strategy:
- Only returns "oauth21" when we have clear indicators
- Defaults to "oauth20" for backward compatibility
- Respects the global oauth21_enabled flag
Args:
request_params: Request parameters from authorization or token request
Returns:
"oauth21" or "oauth20" based on detection
"""
# If OAuth 2.1 is not enabled globally, always return OAuth 2.0
if not self.oauth21_enabled:
return "oauth20"
# Use the structured type for cleaner detection logic
from auth.oauth_types import OAuthVersionDetectionParams
params = OAuthVersionDetectionParams.from_request(request_params)
# Clear OAuth 2.1 indicator: PKCE is present
if params.has_pkce:
return "oauth21"
# For public clients in OAuth 2.1 mode, we require PKCE
# But since they didn't send PKCE, fall back to OAuth 2.0
# This ensures backward compatibility
# Default to OAuth 2.0 for maximum compatibility
return "oauth20"
def get_authorization_server_metadata(self, scopes: Optional[List[str]] = None) -> Dict[str, Any]:
"""
Get OAuth authorization server metadata per RFC 8414.
Args:
scopes: Optional list of supported scopes to include in metadata
Returns:
Authorization server metadata dictionary
"""
metadata = {
"issuer": self.base_url,
"authorization_endpoint": f"{self.base_url}/oauth2/authorize",
"token_endpoint": f"{self.base_url}/oauth2/token",
"registration_endpoint": f"{self.base_url}/oauth2/register",
"jwks_uri": "https://www.googleapis.com/oauth2/v3/certs",
"response_types_supported": ["code", "token"],
"grant_types_supported": ["authorization_code", "refresh_token"],
"token_endpoint_auth_methods_supported": ["client_secret_post", "client_secret_basic"],
"code_challenge_methods_supported": self.supported_code_challenge_methods,
}
# Include scopes if provided
if scopes is not None:
metadata["scopes_supported"] = scopes
# Add OAuth 2.1 specific metadata
if self.oauth21_enabled:
metadata["pkce_required"] = True
# OAuth 2.1 deprecates implicit flow
metadata["response_types_supported"] = ["code"]
# OAuth 2.1 requires exact redirect URI matching
metadata["require_exact_redirect_uri"] = True
return metadata
# Global configuration instance
_oauth_config = None
def get_oauth_config() -> OAuthConfig:
"""
Get the global OAuth configuration instance.
Returns:
The singleton OAuth configuration instance
"""
global _oauth_config
if _oauth_config is None:
_oauth_config = OAuthConfig()
return _oauth_config
def reload_oauth_config() -> OAuthConfig:
"""
Reload the OAuth configuration from environment variables.
This is useful for testing or when environment variables change.
Returns:
The reloaded OAuth configuration instance
"""
global _oauth_config
_oauth_config = OAuthConfig()
return _oauth_config
# Convenience functions for backward compatibility
def get_oauth_base_url() -> str:
"""Get OAuth base URL."""
return get_oauth_config().get_oauth_base_url()
def get_redirect_uris() -> List[str]:
"""Get all valid OAuth redirect URIs."""
return get_oauth_config().get_redirect_uris()
def get_allowed_origins() -> List[str]:
"""Get allowed CORS origins."""
return get_oauth_config().get_allowed_origins()
def is_oauth_configured() -> bool:
"""Check if OAuth is properly configured."""
return get_oauth_config().is_configured()
def set_transport_mode(mode: str) -> None:
"""Set the current transport mode."""
get_oauth_config().set_transport_mode(mode)
def get_transport_mode() -> str:
"""Get the current transport mode."""
return get_oauth_config().get_transport_mode()
def is_oauth21_enabled() -> bool:
"""Check if OAuth 2.1 is enabled."""
return get_oauth_config().is_oauth21_enabled()
def get_oauth_redirect_uri() -> str:
"""Get the primary OAuth redirect URI."""
return get_oauth_config().redirect_uri

View File

@@ -0,0 +1,321 @@
"""
OAuth Error Handling and Validation
This module provides comprehensive error handling and input validation for OAuth
endpoints, addressing the inconsistent error handling identified in the challenge review.
"""
import logging
from typing import Optional, Dict, Any, List
from starlette.responses import JSONResponse
from starlette.requests import Request
from urllib.parse import urlparse
import re
logger = logging.getLogger(__name__)
class OAuthError(Exception):
"""Base exception for OAuth-related errors."""
def __init__(self, error_code: str, description: str, status_code: int = 400):
self.error_code = error_code
self.description = description
self.status_code = status_code
super().__init__(f"{error_code}: {description}")
class OAuthValidationError(OAuthError):
"""Exception for OAuth validation errors."""
def __init__(self, description: str, field: Optional[str] = None):
error_code = "invalid_request"
if field:
description = f"Invalid {field}: {description}"
super().__init__(error_code, description, 400)
class OAuthConfigurationError(OAuthError):
"""Exception for OAuth configuration errors."""
def __init__(self, description: str):
super().__init__("server_error", description, 500)
def create_oauth_error_response(error: OAuthError, origin: Optional[str] = None) -> JSONResponse:
"""
Create a standardized OAuth error response.
Args:
error: The OAuth error to convert to a response
origin: Optional origin for development CORS headers
Returns:
JSONResponse with standardized error format
"""
headers = {
"Content-Type": "application/json",
"Cache-Control": "no-store"
}
# Add development CORS headers if needed
cors_headers = get_development_cors_headers(origin)
headers.update(cors_headers)
content = {
"error": error.error_code,
"error_description": error.description
}
logger.warning(f"OAuth error response: {error.error_code} - {error.description}")
return JSONResponse(
status_code=error.status_code,
content=content,
headers=headers
)
def validate_redirect_uri(uri: str) -> None:
"""
Validate an OAuth redirect URI.
Args:
uri: The redirect URI to validate
Raises:
OAuthValidationError: If the URI is invalid
"""
if not uri:
raise OAuthValidationError("Redirect URI is required", "redirect_uri")
try:
parsed = urlparse(uri)
except Exception:
raise OAuthValidationError("Malformed redirect URI", "redirect_uri")
# Basic URI validation
if not parsed.scheme or not parsed.netloc:
raise OAuthValidationError("Redirect URI must be absolute", "redirect_uri")
# Security checks
if parsed.scheme not in ["http", "https"]:
raise OAuthValidationError("Redirect URI must use HTTP or HTTPS", "redirect_uri")
# Additional security for production
if parsed.scheme == "http" and parsed.hostname not in ["localhost", "127.0.0.1"]:
logger.warning(f"Insecure redirect URI: {uri}")
def validate_client_id(client_id: str) -> None:
"""
Validate an OAuth client ID.
Args:
client_id: The client ID to validate
Raises:
OAuthValidationError: If the client ID is invalid
"""
if not client_id:
raise OAuthValidationError("Client ID is required", "client_id")
if len(client_id) < 10:
raise OAuthValidationError("Client ID is too short", "client_id")
# Basic format validation for Google client IDs
if not re.match(r'^[a-zA-Z0-9\-_.]+$', client_id):
raise OAuthValidationError("Client ID contains invalid characters", "client_id")
def validate_authorization_code(code: str) -> None:
"""
Validate an OAuth authorization code.
Args:
code: The authorization code to validate
Raises:
OAuthValidationError: If the code is invalid
"""
if not code:
raise OAuthValidationError("Authorization code is required", "code")
if len(code) < 10:
raise OAuthValidationError("Authorization code is too short", "code")
# Check for suspicious patterns
if any(char in code for char in [' ', '\n', '\t', '<', '>']):
raise OAuthValidationError("Authorization code contains invalid characters", "code")
def validate_scopes(scopes: List[str]) -> None:
"""
Validate OAuth scopes.
Args:
scopes: List of scopes to validate
Raises:
OAuthValidationError: If the scopes are invalid
"""
if not scopes:
return # Empty scopes list is acceptable
for scope in scopes:
if not scope:
raise OAuthValidationError("Empty scope is not allowed", "scope")
if len(scope) > 200:
raise OAuthValidationError("Scope is too long", "scope")
# Basic scope format validation
if not re.match(r'^[a-zA-Z0-9\-_.:/]+$', scope):
raise OAuthValidationError(f"Invalid scope format: {scope}", "scope")
def validate_token_request(request_data: Dict[str, Any]) -> None:
"""
Validate an OAuth token exchange request.
Args:
request_data: The token request data to validate
Raises:
OAuthValidationError: If the request is invalid
"""
grant_type = request_data.get("grant_type")
if not grant_type:
raise OAuthValidationError("Grant type is required", "grant_type")
if grant_type not in ["authorization_code", "refresh_token"]:
raise OAuthValidationError(f"Unsupported grant type: {grant_type}", "grant_type")
if grant_type == "authorization_code":
code = request_data.get("code")
validate_authorization_code(code)
redirect_uri = request_data.get("redirect_uri")
if redirect_uri:
validate_redirect_uri(redirect_uri)
client_id = request_data.get("client_id")
if client_id:
validate_client_id(client_id)
def validate_registration_request(request_data: Dict[str, Any]) -> None:
"""
Validate an OAuth client registration request.
Args:
request_data: The registration request data to validate
Raises:
OAuthValidationError: If the request is invalid
"""
# Validate redirect URIs if provided
redirect_uris = request_data.get("redirect_uris", [])
if redirect_uris:
if not isinstance(redirect_uris, list):
raise OAuthValidationError("redirect_uris must be an array", "redirect_uris")
for uri in redirect_uris:
validate_redirect_uri(uri)
# Validate grant types if provided
grant_types = request_data.get("grant_types", [])
if grant_types:
if not isinstance(grant_types, list):
raise OAuthValidationError("grant_types must be an array", "grant_types")
allowed_grant_types = ["authorization_code", "refresh_token"]
for grant_type in grant_types:
if grant_type not in allowed_grant_types:
raise OAuthValidationError(f"Unsupported grant type: {grant_type}", "grant_types")
# Validate response types if provided
response_types = request_data.get("response_types", [])
if response_types:
if not isinstance(response_types, list):
raise OAuthValidationError("response_types must be an array", "response_types")
allowed_response_types = ["code"]
for response_type in response_types:
if response_type not in allowed_response_types:
raise OAuthValidationError(f"Unsupported response type: {response_type}", "response_types")
def sanitize_user_input(value: str, max_length: int = 1000) -> str:
"""
Sanitize user input to prevent injection attacks.
Args:
value: The input value to sanitize
max_length: Maximum allowed length
Returns:
Sanitized input value
Raises:
OAuthValidationError: If the input is invalid
"""
if not isinstance(value, str):
raise OAuthValidationError("Input must be a string")
if len(value) > max_length:
raise OAuthValidationError(f"Input is too long (max {max_length} characters)")
# Remove potentially dangerous characters
sanitized = re.sub(r'[<>"\'\0\n\r\t]', '', value)
return sanitized.strip()
def log_security_event(event_type: str, details: Dict[str, Any], request: Optional[Request] = None) -> None:
"""
Log security-related events for monitoring.
Args:
event_type: Type of security event
details: Event details
request: Optional request object for context
"""
log_data = {
"event_type": event_type,
"details": details
}
if request:
log_data["request"] = {
"method": request.method,
"path": request.url.path,
"user_agent": request.headers.get("user-agent", "unknown"),
"origin": request.headers.get("origin", "unknown")
}
logger.warning(f"Security event: {log_data}")
def get_development_cors_headers(origin: Optional[str] = None) -> Dict[str, str]:
"""
Get minimal CORS headers for development scenarios only.
Only allows localhost origins for development tools and inspectors.
Args:
origin: The request origin (will be validated)
Returns:
CORS headers for localhost origins only, empty dict otherwise
"""
# Only allow localhost origins for development
if origin and (origin.startswith("http://localhost:") or origin.startswith("http://127.0.0.1:")):
return {
"Access-Control-Allow-Origin": origin,
"Access-Control-Allow-Methods": "GET, POST, OPTIONS",
"Access-Control-Allow-Headers": "Content-Type, Authorization",
"Access-Control-Max-Age": "3600"
}
return {}

78
auth/oauth_types.py Normal file
View File

@@ -0,0 +1,78 @@
"""
Type definitions for OAuth authentication.
This module provides structured types for OAuth-related parameters,
improving code maintainability and type safety.
"""
from dataclasses import dataclass
from typing import Optional, List, Dict, Any
@dataclass
class OAuth21ServiceRequest:
"""
Encapsulates parameters for OAuth 2.1 service authentication requests.
This parameter object pattern reduces function complexity and makes
it easier to extend authentication parameters in the future.
"""
service_name: str
version: str
tool_name: str
user_google_email: str
required_scopes: List[str]
session_id: Optional[str] = None
auth_token_email: Optional[str] = None
allow_recent_auth: bool = False
context: Optional[Dict[str, Any]] = None
def to_legacy_params(self) -> dict:
"""Convert to legacy parameter format for backward compatibility."""
return {
"service_name": self.service_name,
"version": self.version,
"tool_name": self.tool_name,
"user_google_email": self.user_google_email,
"required_scopes": self.required_scopes,
}
@dataclass
class OAuthVersionDetectionParams:
"""
Parameters used for OAuth version detection.
Encapsulates the various signals we use to determine
whether a client supports OAuth 2.1 or needs OAuth 2.0.
"""
client_id: Optional[str] = None
client_secret: Optional[str] = None
code_challenge: Optional[str] = None
code_challenge_method: Optional[str] = None
code_verifier: Optional[str] = None
authenticated_user: Optional[str] = None
session_id: Optional[str] = None
@classmethod
def from_request(cls, request_params: Dict[str, Any]) -> "OAuthVersionDetectionParams":
"""Create from raw request parameters."""
return cls(
client_id=request_params.get("client_id"),
client_secret=request_params.get("client_secret"),
code_challenge=request_params.get("code_challenge"),
code_challenge_method=request_params.get("code_challenge_method"),
code_verifier=request_params.get("code_verifier"),
authenticated_user=request_params.get("authenticated_user"),
session_id=request_params.get("session_id"),
)
@property
def has_pkce(self) -> bool:
"""Check if PKCE parameters are present."""
return bool(self.code_challenge or self.code_verifier)
@property
def is_public_client(self) -> bool:
"""Check if this appears to be a public client (no secret)."""
return bool(self.client_id and not self.client_secret)

View File

@@ -334,9 +334,31 @@ def require_google_service(
# Log authentication status # Log authentication status
logger.debug(f"[{tool_name}] Auth: {authenticated_user or 'none'} via {auth_method or 'none'} (session: {mcp_session_id[:8] if mcp_session_id else 'none'})") logger.debug(f"[{tool_name}] Auth: {authenticated_user or 'none'} via {auth_method or 'none'} (session: {mcp_session_id[:8] if mcp_session_id else 'none'})")
from auth.oauth21_integration import is_oauth21_enabled from auth.oauth_config import is_oauth21_enabled, get_oauth_config
# Smart OAuth version detection and fallback
use_oauth21 = False
oauth_version = "oauth20" # Default
if is_oauth21_enabled(): if is_oauth21_enabled():
# OAuth 2.1 is enabled globally, check client capabilities
# Try to detect from context if this is an OAuth 2.1 capable client
config = get_oauth_config()
# Build request params from context for version detection
request_params = {}
if authenticated_user:
request_params["authenticated_user"] = authenticated_user
if mcp_session_id:
request_params["session_id"] = mcp_session_id
# Detect OAuth version based on client capabilities
oauth_version = config.detect_oauth_version(request_params)
use_oauth21 = (oauth_version == "oauth21")
logger.debug(f"[{tool_name}] OAuth version detected: {oauth_version}, will use OAuth 2.1: {use_oauth21}")
if use_oauth21:
logger.debug(f"[{tool_name}] Using OAuth 2.1 flow") logger.debug(f"[{tool_name}] Using OAuth 2.1 flow")
# The downstream get_authenticated_google_service_oauth21 will handle # The downstream get_authenticated_google_service_oauth21 will handle
# whether the user's token is valid for the requested resource. # whether the user's token is valid for the requested resource.
@@ -352,8 +374,8 @@ def require_google_service(
allow_recent_auth=False, allow_recent_auth=False,
) )
else: else:
# If OAuth 2.1 is not enabled, always use the legacy authentication method. # Use legacy OAuth 2.0 authentication
logger.debug(f"[{tool_name}] Using legacy OAuth flow") logger.debug(f"[{tool_name}] Using legacy OAuth 2.0 flow")
service, actual_user_email = await get_authenticated_google_service( service, actual_user_email = await get_authenticated_google_service(
service_name=service_name, service_name=service_name,
version=service_version, version=service_version,
@@ -464,7 +486,7 @@ def require_multiple_services(service_configs: List[Dict[str, Any]]):
logger.debug(f"[{tool_name}] Could not get FastMCP context: {e}") logger.debug(f"[{tool_name}] Could not get FastMCP context: {e}")
# Use the same logic as single service decorator # Use the same logic as single service decorator
from auth.oauth21_integration import is_oauth21_enabled from auth.oauth_config import is_oauth21_enabled
if is_oauth21_enabled(): if is_oauth21_enabled():
logger.debug(f"[{tool_name}] Attempting OAuth 2.1 authentication flow for {service_type}.") logger.debug(f"[{tool_name}] Attempting OAuth 2.1 authentication flow for {service_type}.")

View File

@@ -2,57 +2,34 @@
Shared configuration for Google Workspace MCP server. Shared configuration for Google Workspace MCP server.
This module holds configuration values that need to be shared across modules This module holds configuration values that need to be shared across modules
to avoid circular imports. to avoid circular imports.
NOTE: OAuth configuration has been moved to auth.oauth_config for centralization.
This module now imports from there for backward compatibility.
""" """
import os import os
from auth.oauth_config import (
get_oauth_base_url,
get_oauth_redirect_uri,
set_transport_mode,
get_transport_mode,
is_oauth21_enabled
)
# Server configuration # Server configuration
WORKSPACE_MCP_PORT = int(os.getenv("PORT", os.getenv("WORKSPACE_MCP_PORT", 8000))) WORKSPACE_MCP_PORT = int(os.getenv("PORT", os.getenv("WORKSPACE_MCP_PORT", 8000)))
WORKSPACE_MCP_BASE_URI = os.getenv("WORKSPACE_MCP_BASE_URI", "http://localhost") WORKSPACE_MCP_BASE_URI = os.getenv("WORKSPACE_MCP_BASE_URI", "http://localhost")
# Disable USER_GOOGLE_EMAIL in OAuth 2.1 multi-user mode # Disable USER_GOOGLE_EMAIL in OAuth 2.1 multi-user mode
_oauth21_enabled = os.getenv("MCP_ENABLE_OAUTH21", "false").lower() == "true" USER_GOOGLE_EMAIL = None if is_oauth21_enabled() else os.getenv("USER_GOOGLE_EMAIL", None)
USER_GOOGLE_EMAIL = None if _oauth21_enabled else os.getenv("USER_GOOGLE_EMAIL", None)
# Transport mode (will be set by main.py) # Re-export OAuth functions for backward compatibility
_current_transport_mode = "stdio" # Default to stdio __all__ = [
'WORKSPACE_MCP_PORT',
'WORKSPACE_MCP_BASE_URI',
def set_transport_mode(mode: str): 'USER_GOOGLE_EMAIL',
"""Set the current transport mode for OAuth callback handling.""" 'get_oauth_base_url',
global _current_transport_mode 'get_oauth_redirect_uri',
_current_transport_mode = mode 'set_transport_mode',
'get_transport_mode'
]
def get_transport_mode() -> str:
"""Get the current transport mode."""
return _current_transport_mode
# OAuth Configuration
# Determine base URL and redirect URI once at startup
_OAUTH_REDIRECT_URI = os.getenv("GOOGLE_OAUTH_REDIRECT_URI")
if _OAUTH_REDIRECT_URI:
# Extract base URL from the redirect URI (remove the /oauth2callback path)
_OAUTH_BASE_URL = _OAUTH_REDIRECT_URI.removesuffix("/oauth2callback")
else:
# Construct from base URI and port if not explicitly set
_OAUTH_BASE_URL = f"{WORKSPACE_MCP_BASE_URI}:{WORKSPACE_MCP_PORT}"
_OAUTH_REDIRECT_URI = f"{_OAUTH_BASE_URL}/oauth2callback"
def get_oauth_base_url() -> str:
"""Get OAuth base URL for constructing OAuth endpoints.
Returns the base URL (without paths) for OAuth endpoints,
respecting GOOGLE_OAUTH_REDIRECT_URI for reverse proxy scenarios.
"""
return _OAUTH_BASE_URL
def get_oauth_redirect_uri() -> str:
"""Get OAuth redirect URI based on current configuration.
Returns the redirect URI configured at startup, either from
GOOGLE_OAUTH_REDIRECT_URI environment variable or constructed
from WORKSPACE_MCP_BASE_URI and WORKSPACE_MCP_PORT.
"""
return _OAUTH_REDIRECT_URI

View File

@@ -7,7 +7,6 @@ from fastapi.responses import HTMLResponse, JSONResponse
from starlette.applications import Starlette from starlette.applications import Starlette
from starlette.requests import Request from starlette.requests import Request
from starlette.middleware import Middleware from starlette.middleware import Middleware
from fastapi.middleware.cors import CORSMiddleware
from fastmcp import FastMCP from fastmcp import FastMCP
@@ -19,8 +18,6 @@ from auth.auth_info_middleware import AuthInfoMiddleware
from auth.fastmcp_google_auth import GoogleWorkspaceAuthProvider from auth.fastmcp_google_auth import GoogleWorkspaceAuthProvider
from auth.scopes import SCOPES from auth.scopes import SCOPES
from core.config import ( from core.config import (
WORKSPACE_MCP_PORT,
WORKSPACE_MCP_BASE_URI,
USER_GOOGLE_EMAIL, USER_GOOGLE_EMAIL,
get_transport_mode, get_transport_mode,
set_transport_mode as _set_transport_mode, set_transport_mode as _set_transport_mode,
@@ -41,31 +38,25 @@ logger = logging.getLogger(__name__)
_auth_provider: Optional[Union[GoogleWorkspaceAuthProvider, GoogleRemoteAuthProvider]] = None _auth_provider: Optional[Union[GoogleWorkspaceAuthProvider, GoogleRemoteAuthProvider]] = None
# --- Middleware Definitions --- # --- Middleware Definitions ---
cors_middleware = Middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
session_middleware = Middleware(MCPSessionMiddleware) session_middleware = Middleware(MCPSessionMiddleware)
# Custom FastMCP that adds CORS to streamable HTTP # Custom FastMCP that adds secure middleware stack for OAuth 2.1
class CORSEnabledFastMCP(FastMCP): class SecureFastMCP(FastMCP):
def streamable_http_app(self) -> "Starlette": def streamable_http_app(self) -> "Starlette":
"""Override to add CORS and session middleware to the app.""" """Override to add secure middleware stack for OAuth 2.1."""
app = super().streamable_http_app() app = super().streamable_http_app()
# Add session middleware first (to set context before other middleware)
# Add middleware in order (first added = outermost layer)
# Session Management - extracts session info for MCP context
app.user_middleware.insert(0, session_middleware) app.user_middleware.insert(0, session_middleware)
# Add CORS as the second middleware
app.user_middleware.insert(1, cors_middleware)
# Rebuild middleware stack # Rebuild middleware stack
app.middleware_stack = app.build_middleware_stack() app.middleware_stack = app.build_middleware_stack()
logger.info("Added session and CORS middleware to streamable HTTP app") logger.info("Added middleware stack: Session Management")
return app return app
# --- Server Instance --- # --- Server Instance ---
server = CORSEnabledFastMCP( server = SecureFastMCP(
name="google_workspace", name="google_workspace",
auth=None, auth=None,
) )
@@ -86,32 +77,37 @@ def configure_server_for_http():
This must be called BEFORE server.run(). This must be called BEFORE server.run().
""" """
global _auth_provider global _auth_provider
transport_mode = get_transport_mode() transport_mode = get_transport_mode()
if transport_mode != "streamable-http": if transport_mode != "streamable-http":
return return
oauth21_enabled = os.getenv("MCP_ENABLE_OAUTH21", "false").lower() == "true" # Use centralized OAuth configuration
from auth.oauth_config import get_oauth_config
config = get_oauth_config()
# Check if OAuth 2.1 is enabled via centralized config
oauth21_enabled = config.is_oauth21_enabled()
if oauth21_enabled: if oauth21_enabled:
if not os.getenv("GOOGLE_OAUTH_CLIENT_ID"): if not config.is_configured():
logger.warning("⚠️ OAuth 2.1 enabled but GOOGLE_OAUTH_CLIENT_ID not set") logger.warning("⚠️ OAuth 2.1 enabled but OAuth credentials not configured")
return return
if GOOGLE_REMOTE_AUTH_AVAILABLE: if GOOGLE_REMOTE_AUTH_AVAILABLE:
logger.info("🔐 OAuth 2.1 enabled") logger.info("🔐 OAuth 2.1 enabled with automatic OAuth 2.0 fallback for legacy clients")
try: try:
_auth_provider = GoogleRemoteAuthProvider() _auth_provider = GoogleRemoteAuthProvider()
server.auth = _auth_provider server.auth = _auth_provider
set_auth_provider(_auth_provider) set_auth_provider(_auth_provider)
from auth.oauth21_integration import enable_oauth21 logger.debug("OAuth 2.1 authentication enabled")
enable_oauth21()
except Exception as e: except Exception as e:
logger.error(f"Failed to initialize GoogleRemoteAuthProvider: {e}", exc_info=True) logger.error(f"Failed to initialize GoogleRemoteAuthProvider: {e}", exc_info=True)
else: else:
logger.error("OAuth 2.1 is enabled, but GoogleRemoteAuthProvider is not available.") logger.error("OAuth 2.1 is enabled, but GoogleRemoteAuthProvider is not available.")
else: else:
logger.info("OAuth 2.1 is DISABLED. Server will use legacy tool-based authentication.") logger.info("OAuth 2.0 mode - Server will use legacy authentication.")
server.auth = None server.auth = None
def get_auth_provider() -> Optional[Union[GoogleWorkspaceAuthProvider, GoogleRemoteAuthProvider]]: def get_auth_provider() -> Optional[Union[GoogleWorkspaceAuthProvider, GoogleRemoteAuthProvider]]:

View File

@@ -1,256 +0,0 @@
#!/usr/bin/env python3
"""
Auto-installer for Google Workspace MCP in Claude Desktop
Enhanced version with OAuth configuration and installation options
"""
import json
import os
import platform
import sys
from pathlib import Path
from typing import Dict, Optional, Tuple
def get_claude_config_path() -> Path:
"""Get the Claude Desktop config file path for the current platform."""
system = platform.system()
if system == "Darwin": # macOS
return Path.home() / "Library/Application Support/Claude/claude_desktop_config.json"
elif system == "Windows":
appdata = os.environ.get("APPDATA")
if not appdata:
raise RuntimeError("APPDATA environment variable not found")
return Path(appdata) / "Claude/claude_desktop_config.json"
else:
raise RuntimeError(f"Unsupported platform: {system}")
def prompt_yes_no(question: str, default: bool = True) -> bool:
"""Prompt user for yes/no question."""
default_str = "Y/n" if default else "y/N"
while True:
response = input(f"{question} [{default_str}]: ").strip().lower()
if not response:
return default
if response in ['y', 'yes']:
return True
if response in ['n', 'no']:
return False
print("Please answer 'y' or 'n'")
def get_oauth_credentials() -> Tuple[Optional[Dict[str, str]], Optional[str]]:
"""Get OAuth credentials from user."""
print("\n🔑 OAuth Credentials Setup")
print("You need Google OAuth 2.0 credentials to use this server.")
print("\nYou can provide credentials in two ways:")
print("1. Environment variables (recommended for production)")
print("2. Client secrets JSON file")
use_env = prompt_yes_no("\nDo you want to use environment variables?", default=True)
env_vars = {}
client_secret_path = None
if use_env:
print("\n📝 Enter your OAuth credentials:")
client_id = input("Client ID (ends with .apps.googleusercontent.com): ").strip()
client_secret = input("Client Secret: ").strip()
if not client_id or not client_secret:
print("❌ Both Client ID and Client Secret are required!")
return None, None
env_vars["GOOGLE_OAUTH_CLIENT_ID"] = client_id
env_vars["GOOGLE_OAUTH_CLIENT_SECRET"] = client_secret
# Optional redirect URI
custom_redirect = input("Redirect URI (press Enter for default http://localhost:8000/oauth2callback): ").strip()
if custom_redirect:
env_vars["GOOGLE_OAUTH_REDIRECT_URI"] = custom_redirect
else:
print("\n📁 Client secrets file setup:")
default_path = "client_secret.json"
file_path = input(f"Path to client_secret.json file [{default_path}]: ").strip()
if not file_path:
file_path = default_path
# Check if file exists
if not Path(file_path).exists():
print(f"❌ File not found: {file_path}")
return None, None
client_secret_path = file_path
# Optional: Default user email
print("\n📧 Optional: Default user email (for single-user setups)")
user_email = input("Your Google email (press Enter to skip): ").strip()
if user_email:
env_vars["USER_GOOGLE_EMAIL"] = user_email
# Development mode
if prompt_yes_no("\n🔧 Enable development mode (OAUTHLIB_INSECURE_TRANSPORT)?", default=False):
env_vars["OAUTHLIB_INSECURE_TRANSPORT"] = "1"
return env_vars, client_secret_path
def get_installation_options() -> Dict[str, any]:
"""Get installation options from user."""
options = {}
print("\n⚙️ Installation Options")
# Installation method
print("\nChoose installation method:")
print("1. uvx (recommended - auto-installs from PyPI)")
print("2. Development mode (requires local repository)")
method = input("Select method [1]: ").strip()
if method == "2":
options["dev_mode"] = True
cwd = input("Path to google_workspace_mcp repository [current directory]: ").strip()
options["cwd"] = cwd if cwd else os.getcwd()
else:
options["dev_mode"] = False
# Single-user mode
if prompt_yes_no("\n👤 Enable single-user mode (simplified authentication)?", default=False):
options["single_user"] = True
# Tool selection
print("\n🛠️ Tool Selection")
print("Available tools: gmail, drive, calendar, docs, sheets, forms, chat")
print("Leave empty to enable all tools")
tools = input("Enter tools to enable (comma-separated): ").strip()
if tools:
options["tools"] = [t.strip() for t in tools.split(",")]
# Transport mode
if prompt_yes_no("\n🌐 Use HTTP transport mode (for debugging)?", default=False):
options["http_mode"] = True
return options
def create_server_config(options: Dict, env_vars: Dict, client_secret_path: Optional[str]) -> Dict:
"""Create the server configuration."""
config = {}
if options.get("dev_mode"):
config["command"] = "uv"
config["args"] = ["run", "--directory", options["cwd"], "main.py"]
else:
config["command"] = "uvx"
config["args"] = ["workspace-mcp"]
# Add command line arguments
if options.get("single_user"):
config["args"].append("--single-user")
if options.get("tools"):
config["args"].extend(["--tools"] + options["tools"])
if options.get("http_mode"):
config["args"].extend(["--transport", "streamable-http"])
# Add environment variables
if env_vars or client_secret_path:
config["env"] = {}
if env_vars:
config["env"].update(env_vars)
if client_secret_path:
config["env"]["GOOGLE_CLIENT_SECRET_PATH"] = client_secret_path
return config
def main():
print("🚀 Google Workspace MCP Installer for Claude Desktop")
print("=" * 50)
try:
config_path = get_claude_config_path()
# Check if config already exists
existing_config = {}
if config_path.exists():
with open(config_path, 'r') as f:
existing_config = json.load(f)
if "mcpServers" in existing_config and "Google Workspace" in existing_config["mcpServers"]:
print(f"\n⚠️ Google Workspace MCP is already configured in {config_path}")
if not prompt_yes_no("Do you want to reconfigure it?", default=True):
print("Installation cancelled.")
return
# Get OAuth credentials
env_vars, client_secret_path = get_oauth_credentials()
if env_vars is None and client_secret_path is None:
print("\n❌ OAuth credentials are required. Installation cancelled.")
sys.exit(1)
# Get installation options
options = get_installation_options()
# Create server configuration
server_config = create_server_config(options, env_vars, client_secret_path)
# Prepare final config
if "mcpServers" not in existing_config:
existing_config["mcpServers"] = {}
existing_config["mcpServers"]["Google Workspace"] = server_config
# Create directory if needed
config_path.parent.mkdir(parents=True, exist_ok=True)
# Write configuration
with open(config_path, 'w') as f:
json.dump(existing_config, f, indent=2)
print("\n✅ Successfully configured Google Workspace MCP!")
print(f"📁 Config file: {config_path}")
print("\n📋 Configuration Summary:")
print(f" • Installation method: {'Development' if options.get('dev_mode') else 'uvx (PyPI)'}")
print(f" • Authentication: {'Environment variables' if env_vars else 'Client secrets file'}")
if options.get("single_user"):
print(" • Single-user mode: Enabled")
if options.get("tools"):
print(f" • Tools: {', '.join(options['tools'])}")
else:
print(" • Tools: All enabled")
if options.get("http_mode"):
print(" • Transport: HTTP mode")
else:
print(" • Transport: stdio (default)")
print("\n🚀 Next steps:")
print("1. Restart Claude Desktop")
print("2. The Google Workspace tools will be available in your chats!")
print("\n💡 The server will start automatically when Claude Desktop needs it.")
if options.get("http_mode"):
print("\n⚠️ Note: HTTP mode requires additional setup.")
print(" You may need to install and configure mcp-remote.")
print(" See the README for details.")
except KeyboardInterrupt:
print("\n\nInstallation cancelled by user.")
sys.exit(0)
except Exception as e:
print(f"\n❌ Error: {e}")
print("\n📋 Manual installation:")
print("1. Open Claude Desktop Settings → Developer → Edit Config")
print("2. Add the server configuration shown in the README")
sys.exit(1)
if __name__ == "__main__":
main()

18
main.py
View File

@@ -4,17 +4,19 @@ import os
import sys import sys
from importlib import metadata from importlib import metadata
from dotenv import load_dotenv from dotenv import load_dotenv
# Load environment variables from .env file BEFORE any other imports
dotenv_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), '.env')
load_dotenv(dotenv_path=dotenv_path)
from core.server import server, set_transport_mode, configure_server_for_http from core.server import server, set_transport_mode, configure_server_for_http
from auth.oauth_config import reload_oauth_config
reload_oauth_config()
# Suppress googleapiclient discovery cache warning # Suppress googleapiclient discovery cache warning
logging.getLogger('googleapiclient.discovery_cache').setLevel(logging.ERROR) logging.getLogger('googleapiclient.discovery_cache').setLevel(logging.ERROR)
from core.utils import check_credentials_directory_permissions from core.utils import check_credentials_directory_permissions
# Load environment variables from .env file, specifying an explicit path
# This prevents accidentally loading a .env file from a different directory
dotenv_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), '.env')
load_dotenv(dotenv_path=dotenv_path)
logging.basicConfig( logging.basicConfig(
level=logging.INFO, level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
@@ -92,6 +94,7 @@ def main():
# Active Configuration # Active Configuration
safe_print("⚙️ Active Configuration:") safe_print("⚙️ Active Configuration:")
# Redact client secret for security # Redact client secret for security
client_secret = os.getenv('GOOGLE_OAUTH_CLIENT_SECRET', 'Not Set') client_secret = os.getenv('GOOGLE_OAUTH_CLIENT_SECRET', 'Not Set')
redacted_secret = f"{client_secret[:4]}...{client_secret[-4:]}" if len(client_secret) > 8 else "Invalid or too short" redacted_secret = f"{client_secret[:4]}...{client_secret[-4:]}" if len(client_secret) > 8 else "Invalid or too short"
@@ -153,7 +156,6 @@ def main():
safe_print("📊 Configuration Summary:") safe_print("📊 Configuration Summary:")
safe_print(f" 🔧 Tools Enabled: {len(tools_to_import)}/{len(tool_imports)}") safe_print(f" 🔧 Tools Enabled: {len(tools_to_import)}/{len(tool_imports)}")
safe_print(" 🔑 Auth Method: OAuth 2.0 with PKCE")
safe_print(f" 📝 Log Level: {logging.getLogger().getEffectiveLevel()}") safe_print(f" 📝 Log Level: {logging.getLogger().getEffectiveLevel()}")
safe_print("") safe_print("")
@@ -182,10 +184,10 @@ def main():
# Configure auth initialization for FastMCP lifecycle events # Configure auth initialization for FastMCP lifecycle events
if args.transport == 'streamable-http': if args.transport == 'streamable-http':
configure_server_for_http() configure_server_for_http()
safe_print(f"") safe_print("")
safe_print(f"🚀 Starting HTTP server on {base_uri}:{port}") safe_print(f"🚀 Starting HTTP server on {base_uri}:{port}")
else: else:
safe_print(f"") safe_print("")
safe_print("🚀 Starting STDIO server") safe_print("🚀 Starting STDIO server")
# Start minimal OAuth callback server for stdio mode # Start minimal OAuth callback server for stdio mode
from auth.oauth_callback_server import ensure_oauth_callback_available from auth.oauth_callback_server import ensure_oauth_callback_available

2202
uv.lock generated

File diff suppressed because it is too large Load Diff