Merge pull request #153 from taylorwilsdon/vscode_oauth_support
feat: Streamlined VSCode OAuth Support
This commit is contained in:
46
README.md
46
README.md
@@ -51,6 +51,8 @@
|
|||||||
|
|
||||||
A production-ready MCP server that integrates all major Google Workspace services with AI assistants. It supports both single-user operation and multi-user authentication via OAuth 2.1, making it a powerful backend for custom applications. Built with FastMCP for optimal performance, featuring advanced authentication handling, service caching, and streamlined development patterns.
|
A production-ready MCP server that integrates all major Google Workspace services with AI assistants. It supports both single-user operation and multi-user authentication via OAuth 2.1, making it a powerful backend for custom applications. Built with FastMCP for optimal performance, featuring advanced authentication handling, service caching, and streamlined development patterns.
|
||||||
|
|
||||||
|
**🎉 Simplified Setup**: Now uses Google Desktop OAuth clients - no redirect URIs or port configuration needed!
|
||||||
|
|
||||||
## Features
|
## Features
|
||||||
|
|
||||||
- **🔐 Advanced OAuth 2.0 & OAuth 2.1**: Secure authentication with automatic token refresh, transport-aware callback handling, session management, centralized scope management, and OAuth 2.1 bearer token support for multi-user environments with innovative CORS proxy architecture
|
- **🔐 Advanced OAuth 2.0 & OAuth 2.1**: Secure authentication with automatic token refresh, transport-aware callback handling, session management, centralized scope management, and OAuth 2.1 bearer token support for multi-user environments with innovative CORS proxy architecture
|
||||||
@@ -95,6 +97,8 @@ A production-ready MCP server that integrates all major Google Workspace service
|
|||||||
| `GOOGLE_PSE_API_KEY` *(optional)* | API key for Google Custom Search - see [Custom Search Setup](#google-custom-search-setup) |
|
| `GOOGLE_PSE_API_KEY` *(optional)* | API key for Google Custom Search - see [Custom Search Setup](#google-custom-search-setup) |
|
||||||
| `GOOGLE_PSE_ENGINE_ID` *(optional)* | Programmable Search Engine ID for Custom Search |
|
| `GOOGLE_PSE_ENGINE_ID` *(optional)* | Programmable Search Engine ID for Custom Search |
|
||||||
| `MCP_ENABLE_OAUTH21` *(optional)* | Set to `true` to enable OAuth 2.1 support (requires streamable-http transport) |
|
| `MCP_ENABLE_OAUTH21` *(optional)* | Set to `true` to enable OAuth 2.1 support (requires streamable-http transport) |
|
||||||
|
| `OAUTH_CUSTOM_REDIRECT_URIS` *(optional)* | Comma-separated list of additional redirect URIs |
|
||||||
|
| `OAUTH_ALLOWED_ORIGINS` *(optional)* | Comma-separated list of additional CORS origins |
|
||||||
| `OAUTHLIB_INSECURE_TRANSPORT=1` | Development only (allows `http://` redirect) |
|
| `OAUTHLIB_INSECURE_TRANSPORT=1` | Development only (allows `http://` redirect) |
|
||||||
|
|
||||||
Claude Desktop stores these securely in the OS keychain; set them once in the extension pane.
|
Claude Desktop stores these securely in the OS keychain; set them once in the extension pane.
|
||||||
@@ -114,12 +118,12 @@ Claude Desktop stores these securely in the OS keychain; set them once in the ex
|
|||||||
### Configuration
|
### Configuration
|
||||||
|
|
||||||
1. **Google Cloud Setup**:
|
1. **Google Cloud Setup**:
|
||||||
- Create OAuth 2.0 credentials (web application) in [Google Cloud Console](https://console.cloud.google.com/)
|
- Create OAuth 2.0 credentials in [Google Cloud Console](https://console.cloud.google.com/)
|
||||||
- Create a new project (or use an existing one) for your MCP server.
|
- Create a new project (or use an existing one) for your MCP server.
|
||||||
- Navigate to APIs & Services → Credentials.
|
- Navigate to APIs & Services → Credentials.
|
||||||
- Click Create Credentials → OAuth Client ID.
|
- Click Create Credentials → OAuth Client ID.
|
||||||
- Choose Web Application as the application type.
|
- **Choose Desktop Application as the application type** (simpler setup, no redirect URIs needed!)
|
||||||
- Add redirect URI: `http://localhost:8000/oauth2callback`
|
- Download your credentials and note the Client ID and Client Secret
|
||||||
|
|
||||||
- **Enable APIs**:
|
- **Enable APIs**:
|
||||||
- In the Google Cloud Console, go to APIs & Services → Library.
|
- In the Google Cloud Console, go to APIs & Services → Library.
|
||||||
@@ -278,6 +282,29 @@ This architecture enables any OAuth 2.1 compliant client to authenticate users t
|
|||||||
|
|
||||||
</details>
|
</details>
|
||||||
|
|
||||||
|
**MCP Inspector**: No additional configuration needed with desktop OAuth client.
|
||||||
|
|
||||||
|
**Claude Code Inspector**: No additional configuration needed with desktop OAuth client.
|
||||||
|
|
||||||
|
### VS Code MCP Client Support
|
||||||
|
|
||||||
|
The server includes native support for VS Code's MCP client:
|
||||||
|
|
||||||
|
- **No Configuration Required**: Works out-of-the-box with VS Code's MCP extension
|
||||||
|
- **Standards Compliant**: Full OAuth 2.1 compliance with desktop OAuth clients
|
||||||
|
|
||||||
|
**VS Code mcp.json Configuration Example**:
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"servers": {
|
||||||
|
"google-workspace": {
|
||||||
|
"url": "http://localhost:8000/mcp/",
|
||||||
|
"type": "http"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
### Connect to Claude Desktop
|
### Connect to Claude Desktop
|
||||||
|
|
||||||
The server supports two transport modes:
|
The server supports two transport modes:
|
||||||
@@ -422,17 +449,18 @@ If you need to use HTTP mode with Claude Desktop:
|
|||||||
|
|
||||||
### First-Time Authentication
|
### First-Time Authentication
|
||||||
|
|
||||||
The server features **transport-aware OAuth callback handling**:
|
The server uses **Google Desktop OAuth** for simplified authentication:
|
||||||
|
|
||||||
- **Stdio Mode**: Automatically starts a minimal HTTP server on port 8000 for OAuth callbacks
|
- **No redirect URIs needed**: Desktop OAuth clients handle authentication without complex callback URLs
|
||||||
- **HTTP Mode**: Uses the existing FastAPI server for OAuth callbacks
|
- **Automatic flow**: The server manages the entire OAuth process transparently
|
||||||
- **Same OAuth Flow**: Both modes use `http://localhost:8000/oauth2callback` for consistency
|
- **Transport-agnostic**: Works seamlessly in both stdio and HTTP modes
|
||||||
|
|
||||||
When calling a tool:
|
When calling a tool:
|
||||||
1. Server returns authorization URL
|
1. Server returns authorization URL
|
||||||
2. Open URL in browser and authorize
|
2. Open URL in browser and authorize
|
||||||
3. Server handles OAuth callback automatically (on port 8000 in both modes)
|
3. Google provides an authorization code
|
||||||
4. Retry the original request
|
4. Paste the code when prompted (or it's handled automatically)
|
||||||
|
5. Server completes authentication and retries your request
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
|
|||||||
@@ -24,6 +24,7 @@ from pydantic import AnyHttpUrl
|
|||||||
try:
|
try:
|
||||||
from fastmcp.server.auth import RemoteAuthProvider
|
from fastmcp.server.auth import RemoteAuthProvider
|
||||||
from fastmcp.server.auth.providers.jwt import JWTVerifier
|
from fastmcp.server.auth.providers.jwt import JWTVerifier
|
||||||
|
|
||||||
REMOTEAUTHPROVIDER_AVAILABLE = True
|
REMOTEAUTHPROVIDER_AVAILABLE = True
|
||||||
except ImportError:
|
except ImportError:
|
||||||
REMOTEAUTHPROVIDER_AVAILABLE = False
|
REMOTEAUTHPROVIDER_AVAILABLE = False
|
||||||
@@ -35,9 +36,10 @@ except ImportError:
|
|||||||
from auth.oauth_common_handlers import (
|
from auth.oauth_common_handlers import (
|
||||||
handle_oauth_authorize,
|
handle_oauth_authorize,
|
||||||
handle_proxy_token_exchange,
|
handle_proxy_token_exchange,
|
||||||
|
handle_oauth_protected_resource,
|
||||||
handle_oauth_authorization_server,
|
handle_oauth_authorization_server,
|
||||||
handle_oauth_client_config,
|
handle_oauth_client_config,
|
||||||
handle_oauth_register
|
handle_oauth_register,
|
||||||
)
|
)
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -45,12 +47,12 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
class GoogleRemoteAuthProvider(RemoteAuthProvider):
|
class GoogleRemoteAuthProvider(RemoteAuthProvider):
|
||||||
"""
|
"""
|
||||||
RemoteAuthProvider implementation for Google Workspace using FastMCP v2.11.1+.
|
RemoteAuthProvider implementation for Google Workspace.
|
||||||
|
|
||||||
This provider extends RemoteAuthProvider to add:
|
This provider extends RemoteAuthProvider to add:
|
||||||
- OAuth proxy endpoints for CORS workaround
|
- OAuth proxy endpoints for CORS workaround
|
||||||
- Dynamic client registration support
|
- Dynamic client registration support
|
||||||
- Enhanced session management with issuer tracking
|
- Session management with issuer tracking
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
@@ -65,51 +67,90 @@ class GoogleRemoteAuthProvider(RemoteAuthProvider):
|
|||||||
self.port = int(os.getenv("PORT", os.getenv("WORKSPACE_MCP_PORT", 8000)))
|
self.port = int(os.getenv("PORT", os.getenv("WORKSPACE_MCP_PORT", 8000)))
|
||||||
|
|
||||||
if not self.client_id:
|
if not self.client_id:
|
||||||
logger.error("GOOGLE_OAUTH_CLIENT_ID not set - OAuth 2.1 authentication will not work")
|
logger.error(
|
||||||
raise ValueError("GOOGLE_OAUTH_CLIENT_ID environment variable is required for OAuth 2.1 authentication")
|
"GOOGLE_OAUTH_CLIENT_ID not set - OAuth 2.1 authentication will not work"
|
||||||
|
)
|
||||||
|
raise ValueError(
|
||||||
|
"GOOGLE_OAUTH_CLIENT_ID environment variable is required for OAuth 2.1 authentication"
|
||||||
|
)
|
||||||
|
|
||||||
# Configure JWT verifier for Google tokens
|
# Configure JWT verifier for Google tokens
|
||||||
token_verifier = JWTVerifier(
|
token_verifier = JWTVerifier(
|
||||||
jwks_uri="https://www.googleapis.com/oauth2/v3/certs",
|
jwks_uri="https://www.googleapis.com/oauth2/v3/certs",
|
||||||
issuer="https://accounts.google.com",
|
issuer="https://accounts.google.com",
|
||||||
audience=self.client_id, # Always use actual client_id
|
audience=self.client_id, # Always use actual client_id
|
||||||
algorithm="RS256"
|
algorithm="RS256",
|
||||||
)
|
)
|
||||||
|
|
||||||
# Initialize RemoteAuthProvider with local server as the authorization server
|
# Initialize RemoteAuthProvider with base URL (no /mcp/ suffix)
|
||||||
# This ensures OAuth discovery points to our proxy endpoints instead of Google directly
|
# The /mcp/ resource URL is handled in the protected resource metadata endpoint
|
||||||
super().__init__(
|
super().__init__(
|
||||||
token_verifier=token_verifier,
|
token_verifier=token_verifier,
|
||||||
authorization_servers=[AnyHttpUrl(f"{self.base_url}:{self.port}")],
|
authorization_servers=[AnyHttpUrl(f"{self.base_url}:{self.port}")],
|
||||||
resource_server_url=f"{self.base_url}:{self.port}"
|
resource_server_url=f"{self.base_url}:{self.port}",
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.debug("GoogleRemoteAuthProvider initialized")
|
logger.debug("GoogleRemoteAuthProvider")
|
||||||
|
|
||||||
def get_routes(self) -> List[Route]:
|
def get_routes(self) -> List[Route]:
|
||||||
"""
|
"""
|
||||||
Add custom OAuth proxy endpoints to the standard protected resource routes.
|
Add OAuth routes at canonical locations.
|
||||||
|
|
||||||
These endpoints work around Google's CORS restrictions and provide
|
|
||||||
dynamic client registration support.
|
|
||||||
"""
|
"""
|
||||||
# Get the standard OAuth protected resource routes from RemoteAuthProvider
|
# Get the standard OAuth protected resource routes from RemoteAuthProvider
|
||||||
routes = super().get_routes()
|
parent_routes = super().get_routes()
|
||||||
|
|
||||||
# Log what routes we're getting from the parent
|
# Filter out the parent's oauth-protected-resource route since we're replacing it
|
||||||
logger.debug(f"Registered {len(routes)} OAuth routes from parent")
|
routes = [
|
||||||
|
r
|
||||||
|
for r in parent_routes
|
||||||
|
if r.path != "/.well-known/oauth-protected-resource"
|
||||||
|
]
|
||||||
|
|
||||||
# Add our custom proxy endpoints using common handlers
|
# Add our custom OAuth discovery endpoint that returns /mcp/ as the resource
|
||||||
routes.append(Route("/oauth2/authorize", handle_oauth_authorize, methods=["GET", "OPTIONS"]))
|
routes.append(
|
||||||
|
Route(
|
||||||
|
"/.well-known/oauth-protected-resource",
|
||||||
|
handle_oauth_protected_resource,
|
||||||
|
methods=["GET", "OPTIONS"],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
routes.append(Route("/oauth2/token", handle_proxy_token_exchange, methods=["POST", "OPTIONS"]))
|
routes.append(
|
||||||
|
Route(
|
||||||
|
"/.well-known/oauth-authorization-server",
|
||||||
|
handle_oauth_authorization_server,
|
||||||
|
methods=["GET", "OPTIONS"],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
routes.append(Route("/oauth2/register", handle_oauth_register, methods=["POST", "OPTIONS"]))
|
routes.append(
|
||||||
|
Route(
|
||||||
|
"/.well-known/oauth-client",
|
||||||
|
handle_oauth_client_config,
|
||||||
|
methods=["GET", "OPTIONS"],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
routes.append(Route("/.well-known/oauth-authorization-server", handle_oauth_authorization_server, methods=["GET", "OPTIONS"]))
|
# Add OAuth flow endpoints
|
||||||
|
routes.append(
|
||||||
routes.append(Route("/.well-known/oauth-client", handle_oauth_client_config, methods=["GET", "OPTIONS"]))
|
Route(
|
||||||
|
"/oauth2/authorize", handle_oauth_authorize, methods=["GET", "OPTIONS"]
|
||||||
|
)
|
||||||
|
)
|
||||||
|
routes.append(
|
||||||
|
Route(
|
||||||
|
"/oauth2/token",
|
||||||
|
handle_proxy_token_exchange,
|
||||||
|
methods=["POST", "OPTIONS"],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
routes.append(
|
||||||
|
Route(
|
||||||
|
"/oauth2/register", handle_oauth_register, methods=["POST", "OPTIONS"]
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(f"Registered {len(routes)} OAuth routes")
|
||||||
return routes
|
return routes
|
||||||
|
|
||||||
async def verify_token(self, token: str) -> Optional[object]:
|
async def verify_token(self, token: str) -> Optional[object]:
|
||||||
@@ -121,22 +162,30 @@ class GoogleRemoteAuthProvider(RemoteAuthProvider):
|
|||||||
"""
|
"""
|
||||||
# Check if this is a Google OAuth access token (starts with ya29.)
|
# Check if this is a Google OAuth access token (starts with ya29.)
|
||||||
if token.startswith("ya29."):
|
if token.startswith("ya29."):
|
||||||
logger.debug("Detected Google OAuth access token, using tokeninfo verification")
|
logger.debug(
|
||||||
|
"Detected Google OAuth access token, using tokeninfo verification"
|
||||||
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Verify the access token using Google's tokeninfo endpoint
|
# Verify the access token using Google's tokeninfo endpoint
|
||||||
async with aiohttp.ClientSession() as session:
|
async with aiohttp.ClientSession() as session:
|
||||||
url = f"https://oauth2.googleapis.com/tokeninfo?access_token={token}"
|
url = (
|
||||||
|
f"https://oauth2.googleapis.com/tokeninfo?access_token={token}"
|
||||||
|
)
|
||||||
async with session.get(url) as response:
|
async with session.get(url) as response:
|
||||||
if response.status != 200:
|
if response.status != 200:
|
||||||
logger.error(f"Token verification failed: {response.status}")
|
logger.error(
|
||||||
|
f"Token verification failed: {response.status}"
|
||||||
|
)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
token_info = await response.json()
|
token_info = await response.json()
|
||||||
|
|
||||||
# Verify the token is for our client
|
# Verify the token is for our client
|
||||||
if token_info.get("aud") != self.client_id:
|
if token_info.get("aud") != self.client_id:
|
||||||
logger.error(f"Token audience mismatch: expected {self.client_id}, got {token_info.get('aud')}")
|
logger.error(
|
||||||
|
f"Token audience mismatch: expected {self.client_id}, got {token_info.get('aud')}"
|
||||||
|
)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# Check if token is expired
|
# Check if token is expired
|
||||||
@@ -151,7 +200,9 @@ class GoogleRemoteAuthProvider(RemoteAuthProvider):
|
|||||||
|
|
||||||
# Calculate expires_at timestamp
|
# Calculate expires_at timestamp
|
||||||
expires_in = int(token_info.get("expires_in", 0))
|
expires_in = int(token_info.get("expires_in", 0))
|
||||||
expires_at = int(time.time()) + expires_in if expires_in > 0 else 0
|
expires_at = (
|
||||||
|
int(time.time()) + expires_in if expires_in > 0 else 0
|
||||||
|
)
|
||||||
|
|
||||||
access_token = SimpleNamespace(
|
access_token = SimpleNamespace(
|
||||||
claims={
|
claims={
|
||||||
@@ -166,12 +217,15 @@ class GoogleRemoteAuthProvider(RemoteAuthProvider):
|
|||||||
client_id=self.client_id, # Add client_id at top level
|
client_id=self.client_id, # Add client_id at top level
|
||||||
# Add other required fields
|
# Add other required fields
|
||||||
sub=token_info.get("sub", ""),
|
sub=token_info.get("sub", ""),
|
||||||
email=token_info.get("email", "")
|
email=token_info.get("email", ""),
|
||||||
)
|
)
|
||||||
|
|
||||||
user_email = token_info.get("email")
|
user_email = token_info.get("email")
|
||||||
if user_email:
|
if user_email:
|
||||||
from auth.oauth21_session_store import get_oauth21_session_store
|
from auth.oauth21_session_store import (
|
||||||
|
get_oauth21_session_store,
|
||||||
|
)
|
||||||
|
|
||||||
store = get_oauth21_session_store()
|
store = get_oauth21_session_store()
|
||||||
session_id = f"google_{token_info.get('sub', 'unknown')}"
|
session_id = f"google_{token_info.get('sub', 'unknown')}"
|
||||||
|
|
||||||
@@ -179,10 +233,13 @@ class GoogleRemoteAuthProvider(RemoteAuthProvider):
|
|||||||
mcp_session_id = None
|
mcp_session_id = None
|
||||||
try:
|
try:
|
||||||
from fastmcp.server.dependencies import get_context
|
from fastmcp.server.dependencies import get_context
|
||||||
|
|
||||||
ctx = get_context()
|
ctx = get_context()
|
||||||
if ctx and hasattr(ctx, 'session_id'):
|
if ctx and hasattr(ctx, "session_id"):
|
||||||
mcp_session_id = ctx.session_id
|
mcp_session_id = ctx.session_id
|
||||||
logger.debug(f"Binding MCP session {mcp_session_id} to user {user_email}")
|
logger.debug(
|
||||||
|
f"Binding MCP session {mcp_session_id} to user {user_email}"
|
||||||
|
)
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@@ -193,7 +250,7 @@ class GoogleRemoteAuthProvider(RemoteAuthProvider):
|
|||||||
scopes=access_token.scopes,
|
scopes=access_token.scopes,
|
||||||
session_id=session_id,
|
session_id=session_id,
|
||||||
mcp_session_id=mcp_session_id,
|
mcp_session_id=mcp_session_id,
|
||||||
issuer="https://accounts.google.com"
|
issuer="https://accounts.google.com",
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info(f"Verified OAuth token: {user_email}")
|
logger.info(f"Verified OAuth token: {user_email}")
|
||||||
@@ -214,6 +271,7 @@ class GoogleRemoteAuthProvider(RemoteAuthProvider):
|
|||||||
user_email = access_token.claims.get("email")
|
user_email = access_token.claims.get("email")
|
||||||
if user_email:
|
if user_email:
|
||||||
from auth.oauth21_session_store import get_oauth21_session_store
|
from auth.oauth21_session_store import get_oauth21_session_store
|
||||||
|
|
||||||
store = get_oauth21_session_store()
|
store = get_oauth21_session_store()
|
||||||
session_id = f"google_{access_token.claims.get('sub', 'unknown')}"
|
session_id = f"google_{access_token.claims.get('sub', 'unknown')}"
|
||||||
|
|
||||||
@@ -223,9 +281,11 @@ class GoogleRemoteAuthProvider(RemoteAuthProvider):
|
|||||||
access_token=token,
|
access_token=token,
|
||||||
scopes=access_token.scopes or [],
|
scopes=access_token.scopes or [],
|
||||||
session_id=session_id,
|
session_id=session_id,
|
||||||
issuer="https://accounts.google.com"
|
issuer="https://accounts.google.com",
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.debug(f"Successfully verified JWT token for user: {user_email}")
|
logger.debug(
|
||||||
|
f"Successfully verified JWT token for user: {user_email}"
|
||||||
|
)
|
||||||
|
|
||||||
return access_token
|
return access_token
|
||||||
@@ -160,23 +160,6 @@ def set_auth_layer(auth_layer):
|
|||||||
logger.info("set_auth_layer called - OAuth is now handled by FastMCP")
|
logger.info("set_auth_layer called - OAuth is now handled by FastMCP")
|
||||||
|
|
||||||
|
|
||||||
_oauth21_enabled = False
|
|
||||||
|
|
||||||
def is_oauth21_enabled() -> bool:
|
|
||||||
"""
|
|
||||||
Check if the OAuth 2.1 authentication layer is active.
|
|
||||||
"""
|
|
||||||
global _oauth21_enabled
|
|
||||||
return _oauth21_enabled
|
|
||||||
|
|
||||||
|
|
||||||
def enable_oauth21():
|
|
||||||
"""
|
|
||||||
Enable the OAuth 2.1 authentication layer.
|
|
||||||
"""
|
|
||||||
global _oauth21_enabled
|
|
||||||
_oauth21_enabled = True
|
|
||||||
logger.debug("OAuth 2.1 authentication enabled")
|
|
||||||
|
|
||||||
|
|
||||||
async def get_legacy_auth_service(
|
async def get_legacy_auth_service(
|
||||||
@@ -206,13 +189,16 @@ async def get_authenticated_google_service_oauth21(
|
|||||||
tool_name: str,
|
tool_name: str,
|
||||||
user_google_email: str,
|
user_google_email: str,
|
||||||
required_scopes: list[str],
|
required_scopes: list[str],
|
||||||
|
session_id: Optional[str] = None,
|
||||||
|
auth_token_email: Optional[str] = None,
|
||||||
|
allow_recent_auth: bool = False,
|
||||||
context: Optional[Dict[str, Any]] = None,
|
context: Optional[Dict[str, Any]] = None,
|
||||||
) -> Tuple[Any, str]:
|
) -> Tuple[Any, str]:
|
||||||
"""
|
"""
|
||||||
Enhanced version of get_authenticated_google_service that supports OAuth 2.1.
|
Enhanced version of get_authenticated_google_service that supports OAuth 2.1.
|
||||||
|
|
||||||
This function checks for OAuth 2.1 session context and uses it if available,
|
This function checks for OAuth 2.1 session context and uses it if available,
|
||||||
otherwise falls back to legacy authentication.
|
otherwise falls back to legacy authentication based on configuration.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
service_name: Google service name
|
service_name: Google service name
|
||||||
@@ -220,20 +206,32 @@ async def get_authenticated_google_service_oauth21(
|
|||||||
tool_name: Tool name for logging
|
tool_name: Tool name for logging
|
||||||
user_google_email: User's Google email
|
user_google_email: User's Google email
|
||||||
required_scopes: Required OAuth scopes
|
required_scopes: Required OAuth scopes
|
||||||
|
session_id: Optional OAuth session ID
|
||||||
|
auth_token_email: Optional authenticated user email from token
|
||||||
|
allow_recent_auth: Whether to allow recently authenticated sessions
|
||||||
context: Optional context containing session information
|
context: Optional context containing session information
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Tuple of (service instance, actual user email)
|
Tuple of (service instance, actual user email)
|
||||||
"""
|
"""
|
||||||
|
# Check if OAuth 2.1 is truly enabled
|
||||||
|
if not is_oauth21_enabled():
|
||||||
|
logger.debug(f"[{tool_name}] OAuth 2.1 disabled, using legacy authentication")
|
||||||
|
return await get_legacy_auth_service(
|
||||||
|
service_name=service_name,
|
||||||
|
version=version,
|
||||||
|
tool_name=tool_name,
|
||||||
|
user_google_email=user_google_email,
|
||||||
|
required_scopes=required_scopes,
|
||||||
|
)
|
||||||
|
|
||||||
builder = get_oauth21_service_builder()
|
builder = get_oauth21_service_builder()
|
||||||
|
|
||||||
# FastMCP handles context now - extract any session info
|
# FastMCP handles context now - extract any session info
|
||||||
session_id = None
|
if not session_id and context:
|
||||||
auth_context = None
|
|
||||||
|
|
||||||
if context:
|
|
||||||
session_id = builder.extract_session_from_context(context)
|
session_id = builder.extract_session_from_context(context)
|
||||||
auth_context = context.get("auth_context")
|
|
||||||
|
auth_context = context.get("auth_context") if context else None
|
||||||
|
|
||||||
return await builder.get_authenticated_service_with_session(
|
return await builder.get_authenticated_service_with_session(
|
||||||
service_name=service_name,
|
service_name=service_name,
|
||||||
@@ -244,3 +242,34 @@ async def get_authenticated_google_service_oauth21(
|
|||||||
session_id=session_id,
|
session_id=session_id,
|
||||||
auth_context=auth_context,
|
auth_context=auth_context,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def get_authenticated_google_service_oauth21_v2(
|
||||||
|
request: "OAuth21ServiceRequest",
|
||||||
|
) -> Tuple[Any, str]:
|
||||||
|
"""
|
||||||
|
Enhanced version of get_authenticated_google_service that supports OAuth 2.1.
|
||||||
|
|
||||||
|
This version uses a parameter object to reduce function complexity and
|
||||||
|
improve maintainability. It's the recommended approach for new code.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
request: OAuth21ServiceRequest object containing all parameters
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (service instance, actual user email)
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Delegate to the original function for now
|
||||||
|
# This provides a migration path while maintaining backward compatibility
|
||||||
|
return await get_authenticated_google_service_oauth21(
|
||||||
|
service_name=request.service_name,
|
||||||
|
version=request.version,
|
||||||
|
tool_name=request.tool_name,
|
||||||
|
user_google_email=request.user_google_email,
|
||||||
|
required_scopes=request.required_scopes,
|
||||||
|
session_id=request.session_id,
|
||||||
|
auth_token_email=request.auth_token_email,
|
||||||
|
allow_recent_auth=request.allow_recent_auth,
|
||||||
|
context=request.context,
|
||||||
|
)
|
||||||
@@ -5,7 +5,6 @@ In streamable-http mode: Uses the existing FastAPI server
|
|||||||
In stdio mode: Starts a minimal HTTP server just for OAuth callbacks
|
In stdio mode: Starts a minimal HTTP server just for OAuth callbacks
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import os
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
import threading
|
import threading
|
||||||
@@ -20,7 +19,7 @@ from urllib.parse import urlparse
|
|||||||
from auth.scopes import SCOPES
|
from auth.scopes import SCOPES
|
||||||
from auth.oauth_responses import create_error_response, create_success_response, create_server_error_response
|
from auth.oauth_responses import create_error_response, create_success_response, create_server_error_response
|
||||||
from auth.google_auth import handle_auth_callback, check_client_secrets
|
from auth.google_auth import handle_auth_callback, check_client_secrets
|
||||||
from core.config import get_oauth_redirect_uri
|
from auth.oauth_config import get_oauth_redirect_uri
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|||||||
@@ -16,22 +16,24 @@ from google.oauth2.credentials import Credentials
|
|||||||
from auth.oauth21_session_store import store_token_session
|
from auth.oauth21_session_store import store_token_session
|
||||||
from auth.google_auth import save_credentials_to_file
|
from auth.google_auth import save_credentials_to_file
|
||||||
from auth.scopes import get_current_scopes
|
from auth.scopes import get_current_scopes
|
||||||
from core.config import WORKSPACE_MCP_BASE_URI, WORKSPACE_MCP_PORT, get_oauth_base_url
|
from auth.oauth_config import get_oauth_config
|
||||||
|
from auth.oauth_error_handling import (
|
||||||
|
OAuthError, OAuthValidationError, OAuthConfigurationError,
|
||||||
|
create_oauth_error_response, validate_token_request,
|
||||||
|
validate_registration_request, get_development_cors_headers,
|
||||||
|
log_security_event
|
||||||
|
)
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
async def handle_oauth_authorize(request: Request):
|
async def handle_oauth_authorize(request: Request):
|
||||||
"""Common handler for OAuth authorization proxy."""
|
"""Common handler for OAuth authorization proxy."""
|
||||||
|
origin = request.headers.get("origin")
|
||||||
|
|
||||||
if request.method == "OPTIONS":
|
if request.method == "OPTIONS":
|
||||||
return JSONResponse(
|
cors_headers = get_development_cors_headers(origin)
|
||||||
content={},
|
return JSONResponse(content={}, headers=cors_headers)
|
||||||
headers={
|
|
||||||
"Access-Control-Allow-Origin": "*",
|
|
||||||
"Access-Control-Allow-Methods": "GET, OPTIONS",
|
|
||||||
"Access-Control-Allow-Headers": "Content-Type"
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
# Get query parameters
|
# Get query parameters
|
||||||
params = dict(request.query_params)
|
params = dict(request.query_params)
|
||||||
@@ -55,35 +57,40 @@ async def handle_oauth_authorize(request: Request):
|
|||||||
# Build Google authorization URL
|
# Build Google authorization URL
|
||||||
google_auth_url = "https://accounts.google.com/o/oauth2/v2/auth?" + urlencode(params)
|
google_auth_url = "https://accounts.google.com/o/oauth2/v2/auth?" + urlencode(params)
|
||||||
|
|
||||||
# Return redirect
|
# Return redirect with development CORS headers if needed
|
||||||
|
cors_headers = get_development_cors_headers(origin)
|
||||||
return RedirectResponse(
|
return RedirectResponse(
|
||||||
url=google_auth_url,
|
url=google_auth_url,
|
||||||
status_code=302,
|
status_code=302,
|
||||||
headers={
|
headers=cors_headers
|
||||||
"Access-Control-Allow-Origin": "*"
|
|
||||||
}
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
async def handle_proxy_token_exchange(request: Request):
|
async def handle_proxy_token_exchange(request: Request):
|
||||||
"""Common handler for OAuth token exchange proxy."""
|
"""Common handler for OAuth token exchange proxy with comprehensive error handling."""
|
||||||
|
origin = request.headers.get("origin")
|
||||||
|
|
||||||
if request.method == "OPTIONS":
|
if request.method == "OPTIONS":
|
||||||
return JSONResponse(
|
cors_headers = get_development_cors_headers(origin)
|
||||||
content={},
|
return JSONResponse(content={}, headers=cors_headers)
|
||||||
headers={
|
try:
|
||||||
"Access-Control-Allow-Origin": "*",
|
# Get form data with validation
|
||||||
"Access-Control-Allow-Methods": "POST, OPTIONS",
|
|
||||||
"Access-Control-Allow-Headers": "Content-Type, Authorization"
|
|
||||||
}
|
|
||||||
)
|
|
||||||
try:
|
try:
|
||||||
# Get form data
|
|
||||||
body = await request.body()
|
body = await request.body()
|
||||||
content_type = request.headers.get("content-type", "application/x-www-form-urlencoded")
|
content_type = request.headers.get("content-type", "application/x-www-form-urlencoded")
|
||||||
|
except Exception as e:
|
||||||
|
raise OAuthValidationError(f"Failed to read request body: {e}")
|
||||||
|
|
||||||
# Parse form data to add missing client credentials
|
# Parse and validate form data
|
||||||
if content_type and "application/x-www-form-urlencoded" in content_type:
|
if content_type and "application/x-www-form-urlencoded" in content_type:
|
||||||
|
try:
|
||||||
form_data = parse_qs(body.decode('utf-8'))
|
form_data = parse_qs(body.decode('utf-8'))
|
||||||
|
except Exception as e:
|
||||||
|
raise OAuthValidationError(f"Invalid form data: {e}")
|
||||||
|
|
||||||
|
# Convert to single values and validate
|
||||||
|
request_data = {k: v[0] if v else '' for k, v in form_data.items()}
|
||||||
|
validate_token_request(request_data)
|
||||||
|
|
||||||
# Check if client_id is missing (public client)
|
# Check if client_id is missing (public client)
|
||||||
if 'client_id' not in form_data or not form_data['client_id'][0]:
|
if 'client_id' not in form_data or not form_data['client_id'][0]:
|
||||||
@@ -186,43 +193,57 @@ async def handle_proxy_token_exchange(request: Request):
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to store OAuth session: {e}")
|
logger.error(f"Failed to store OAuth session: {e}")
|
||||||
|
|
||||||
|
# Add development CORS headers
|
||||||
|
cors_headers = get_development_cors_headers(origin)
|
||||||
|
response_headers = {
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
"Cache-Control": "no-store"
|
||||||
|
}
|
||||||
|
response_headers.update(cors_headers)
|
||||||
|
|
||||||
return JSONResponse(
|
return JSONResponse(
|
||||||
status_code=response.status,
|
status_code=response.status,
|
||||||
content=response_data,
|
content=response_data,
|
||||||
headers={
|
headers=response_headers
|
||||||
"Content-Type": "application/json",
|
|
||||||
"Access-Control-Allow-Origin": "*",
|
|
||||||
"Cache-Control": "no-store"
|
|
||||||
}
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
except OAuthError as e:
|
||||||
|
log_security_event("oauth_token_exchange_error", {
|
||||||
|
"error_code": e.error_code,
|
||||||
|
"description": e.description
|
||||||
|
}, request)
|
||||||
|
return create_oauth_error_response(e, origin)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error in token proxy: {e}")
|
logger.error(f"Unexpected error in token proxy: {e}", exc_info=True)
|
||||||
return JSONResponse(
|
log_security_event("oauth_token_exchange_unexpected_error", {
|
||||||
status_code=500,
|
"error": str(e)
|
||||||
content={"error": "server_error", "error_description": str(e)},
|
}, request)
|
||||||
headers={"Access-Control-Allow-Origin": "*"}
|
error = OAuthConfigurationError("Internal server error")
|
||||||
)
|
return create_oauth_error_response(error, origin)
|
||||||
|
|
||||||
|
|
||||||
async def handle_oauth_protected_resource(request: Request):
|
async def handle_oauth_protected_resource(request: Request):
|
||||||
"""Common handler for OAuth protected resource metadata."""
|
"""
|
||||||
if request.method == "OPTIONS":
|
Handle OAuth protected resource metadata requests.
|
||||||
return JSONResponse(
|
"""
|
||||||
content={},
|
origin = request.headers.get("origin")
|
||||||
headers={
|
|
||||||
"Access-Control-Allow-Origin": "*",
|
|
||||||
"Access-Control-Allow-Methods": "GET, OPTIONS",
|
|
||||||
"Access-Control-Allow-Headers": "Content-Type"
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
base_url = get_oauth_base_url()
|
# Handle preflight
|
||||||
|
if request.method == "OPTIONS":
|
||||||
|
cors_headers = get_development_cors_headers(origin)
|
||||||
|
return JSONResponse(content={}, headers=cors_headers)
|
||||||
|
|
||||||
|
config = get_oauth_config()
|
||||||
|
base_url = config.get_oauth_base_url()
|
||||||
|
|
||||||
|
# For streamable-http transport, the MCP server runs at /mcp
|
||||||
|
# This is the actual resource being protected
|
||||||
|
resource_url = f"{base_url}/mcp"
|
||||||
|
|
||||||
|
# Build metadata response per RFC 9449
|
||||||
metadata = {
|
metadata = {
|
||||||
"resource": base_url,
|
"resource": resource_url, # The MCP server endpoint that needs protection
|
||||||
"authorization_servers": [
|
"authorization_servers": [base_url], # Our proxy acts as the auth server
|
||||||
base_url
|
|
||||||
],
|
|
||||||
"bearer_methods_supported": ["header"],
|
"bearer_methods_supported": ["header"],
|
||||||
"scopes_supported": get_current_scopes(),
|
"scopes_supported": get_current_scopes(),
|
||||||
"resource_documentation": "https://developers.google.com/workspace",
|
"resource_documentation": "https://developers.google.com/workspace",
|
||||||
@@ -230,179 +251,143 @@ async def handle_oauth_protected_resource(request: Request):
|
|||||||
"client_configuration_endpoint": f"{base_url}/.well-known/oauth-client",
|
"client_configuration_endpoint": f"{base_url}/.well-known/oauth-client",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
# Log the response for debugging
|
||||||
|
logger.debug(f"Returning protected resource metadata: {metadata}")
|
||||||
|
|
||||||
|
# Add development CORS headers
|
||||||
|
cors_headers = get_development_cors_headers(origin)
|
||||||
|
response_headers = {
|
||||||
|
"Content-Type": "application/json; charset=utf-8",
|
||||||
|
"Cache-Control": "public, max-age=3600"
|
||||||
|
}
|
||||||
|
response_headers.update(cors_headers)
|
||||||
|
|
||||||
return JSONResponse(
|
return JSONResponse(
|
||||||
content=metadata,
|
content=metadata,
|
||||||
headers={
|
headers=response_headers
|
||||||
"Content-Type": "application/json",
|
|
||||||
"Access-Control-Allow-Origin": "*"
|
|
||||||
}
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
async def handle_oauth_authorization_server(request: Request):
|
async def handle_oauth_authorization_server(request: Request):
|
||||||
"""Common handler for OAuth authorization server metadata."""
|
"""
|
||||||
|
Handle OAuth authorization server metadata.
|
||||||
|
"""
|
||||||
|
origin = request.headers.get("origin")
|
||||||
|
|
||||||
if request.method == "OPTIONS":
|
if request.method == "OPTIONS":
|
||||||
return JSONResponse(
|
cors_headers = get_development_cors_headers(origin)
|
||||||
content={},
|
return JSONResponse(content={}, headers=cors_headers)
|
||||||
headers={
|
|
||||||
"Access-Control-Allow-Origin": "*",
|
config = get_oauth_config()
|
||||||
"Access-Control-Allow-Methods": "GET, OPTIONS",
|
|
||||||
"Access-Control-Allow-Headers": "Content-Type"
|
# Get authorization server metadata from centralized config
|
||||||
|
# Pass scopes directly to keep all metadata generation in one place
|
||||||
|
metadata = config.get_authorization_server_metadata(scopes=get_current_scopes())
|
||||||
|
|
||||||
|
logger.debug(f"Returning authorization server metadata: {metadata}")
|
||||||
|
|
||||||
|
# Add development CORS headers
|
||||||
|
cors_headers = get_development_cors_headers(origin)
|
||||||
|
response_headers = {
|
||||||
|
"Content-Type": "application/json; charset=utf-8",
|
||||||
|
"Cache-Control": "public, max-age=3600"
|
||||||
}
|
}
|
||||||
)
|
response_headers.update(cors_headers)
|
||||||
|
|
||||||
# Get base URL once and reuse
|
|
||||||
base_url = get_oauth_base_url()
|
|
||||||
|
|
||||||
try:
|
|
||||||
# Fetch metadata from Google
|
|
||||||
async with aiohttp.ClientSession() as session:
|
|
||||||
url = "https://accounts.google.com/.well-known/openid-configuration"
|
|
||||||
async with session.get(url) as response:
|
|
||||||
if response.status == 200:
|
|
||||||
metadata = await response.json()
|
|
||||||
|
|
||||||
# Add OAuth 2.1 required fields
|
|
||||||
metadata.setdefault("code_challenge_methods_supported", ["S256"])
|
|
||||||
metadata.setdefault("pkce_required", True)
|
|
||||||
|
|
||||||
# Override endpoints to use our proxies
|
|
||||||
metadata["token_endpoint"] = f"{base_url}/oauth2/token"
|
|
||||||
metadata["authorization_endpoint"] = f"{base_url}/oauth2/authorize"
|
|
||||||
metadata["enable_dynamic_registration"] = True
|
|
||||||
metadata["registration_endpoint"] = f"{base_url}/oauth2/register"
|
|
||||||
return JSONResponse(
|
return JSONResponse(
|
||||||
content=metadata,
|
content=metadata,
|
||||||
headers={
|
headers=response_headers
|
||||||
"Content-Type": "application/json",
|
|
||||||
"Access-Control-Allow-Origin": "*"
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
# Fallback metadata
|
|
||||||
return JSONResponse(
|
|
||||||
content={
|
|
||||||
"issuer": "https://accounts.google.com",
|
|
||||||
"authorization_endpoint": f"{base_url}/oauth2/authorize",
|
|
||||||
"token_endpoint": f"{base_url}/oauth2/token",
|
|
||||||
"userinfo_endpoint": "https://www.googleapis.com/oauth2/v2/userinfo",
|
|
||||||
"revocation_endpoint": "https://oauth2.googleapis.com/revoke",
|
|
||||||
"jwks_uri": "https://www.googleapis.com/oauth2/v3/certs",
|
|
||||||
"response_types_supported": ["code"],
|
|
||||||
"code_challenge_methods_supported": ["S256"],
|
|
||||||
"pkce_required": True,
|
|
||||||
"grant_types_supported": ["authorization_code", "refresh_token"],
|
|
||||||
"scopes_supported": get_current_scopes(),
|
|
||||||
"token_endpoint_auth_methods_supported": ["client_secret_basic", "client_secret_post"]
|
|
||||||
},
|
|
||||||
headers={
|
|
||||||
"Content-Type": "application/json",
|
|
||||||
"Access-Control-Allow-Origin": "*"
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error fetching auth server metadata: {e}")
|
|
||||||
return JSONResponse(
|
|
||||||
status_code=500,
|
|
||||||
content={"error": "Failed to fetch authorization server metadata"},
|
|
||||||
headers={"Access-Control-Allow-Origin": "*"}
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
async def handle_oauth_client_config(request: Request):
|
async def handle_oauth_client_config(request: Request):
|
||||||
"""Common handler for OAuth client configuration."""
|
"""Common handler for OAuth client configuration."""
|
||||||
|
origin = request.headers.get("origin")
|
||||||
|
|
||||||
if request.method == "OPTIONS":
|
if request.method == "OPTIONS":
|
||||||
return JSONResponse(
|
cors_headers = get_development_cors_headers(origin)
|
||||||
content={},
|
return JSONResponse(content={}, headers=cors_headers)
|
||||||
headers={
|
|
||||||
"Access-Control-Allow-Origin": "*",
|
|
||||||
"Access-Control-Allow-Methods": "GET, OPTIONS",
|
|
||||||
"Access-Control-Allow-Headers": "Content-Type"
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
client_id = os.getenv("GOOGLE_OAUTH_CLIENT_ID")
|
client_id = os.getenv("GOOGLE_OAUTH_CLIENT_ID")
|
||||||
if not client_id:
|
if not client_id:
|
||||||
|
cors_headers = get_development_cors_headers(origin)
|
||||||
return JSONResponse(
|
return JSONResponse(
|
||||||
status_code=404,
|
status_code=404,
|
||||||
content={"error": "OAuth not configured"},
|
content={"error": "OAuth not configured"},
|
||||||
headers={"Access-Control-Allow-Origin": "*"}
|
headers=cors_headers
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Get OAuth configuration
|
||||||
|
config = get_oauth_config()
|
||||||
|
|
||||||
return JSONResponse(
|
return JSONResponse(
|
||||||
content={
|
content={
|
||||||
"client_id": client_id,
|
"client_id": client_id,
|
||||||
"client_name": "Google Workspace MCP Server",
|
"client_name": "Google Workspace MCP Server",
|
||||||
"client_uri": f"{WORKSPACE_MCP_BASE_URI}:{WORKSPACE_MCP_PORT}",
|
"client_uri": config.base_url,
|
||||||
"redirect_uris": [
|
"redirect_uris": [
|
||||||
f"{WORKSPACE_MCP_BASE_URI}:{WORKSPACE_MCP_PORT}/oauth2callback",
|
f"{config.base_url}/oauth2callback",
|
||||||
"http://localhost:5173/auth/callback"
|
"http://localhost:5173/auth/callback"
|
||||||
],
|
],
|
||||||
"grant_types": ["authorization_code", "refresh_token"],
|
"grant_types": ["authorization_code", "refresh_token"],
|
||||||
"response_types": ["code"],
|
"response_types": ["code"],
|
||||||
"scope": " ".join(get_current_scopes()),
|
"scope": " ".join(get_current_scopes()),
|
||||||
"token_endpoint_auth_method": "client_secret_basic",
|
"token_endpoint_auth_method": "client_secret_basic",
|
||||||
"code_challenge_methods": ["S256"]
|
"code_challenge_methods": config.supported_code_challenge_methods[:1] # Primary method only
|
||||||
},
|
},
|
||||||
headers={
|
headers={
|
||||||
"Content-Type": "application/json",
|
"Content-Type": "application/json; charset=utf-8",
|
||||||
"Access-Control-Allow-Origin": "*"
|
"Cache-Control": "public, max-age=3600",
|
||||||
|
**get_development_cors_headers(origin)
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
async def handle_oauth_register(request: Request):
|
async def handle_oauth_register(request: Request):
|
||||||
"""Common handler for OAuth dynamic client registration."""
|
"""Common handler for OAuth dynamic client registration with comprehensive error handling."""
|
||||||
|
origin = request.headers.get("origin")
|
||||||
|
|
||||||
if request.method == "OPTIONS":
|
if request.method == "OPTIONS":
|
||||||
return JSONResponse(
|
cors_headers = get_development_cors_headers(origin)
|
||||||
content={},
|
return JSONResponse(content={}, headers=cors_headers)
|
||||||
headers={
|
|
||||||
"Access-Control-Allow-Origin": "*",
|
|
||||||
"Access-Control-Allow-Methods": "POST, OPTIONS",
|
|
||||||
"Access-Control-Allow-Headers": "Content-Type, Authorization"
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
client_id = os.getenv("GOOGLE_OAUTH_CLIENT_ID")
|
config = get_oauth_config()
|
||||||
client_secret = os.getenv("GOOGLE_OAUTH_CLIENT_SECRET")
|
|
||||||
|
|
||||||
if not client_id or not client_secret:
|
if not config.is_configured():
|
||||||
return JSONResponse(
|
error = OAuthConfigurationError("OAuth client credentials not configured")
|
||||||
status_code=400,
|
return create_oauth_error_response(error, origin)
|
||||||
content={"error": "invalid_request", "error_description": "OAuth not configured"},
|
|
||||||
headers={"Access-Control-Allow-Origin": "*"}
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Parse the registration request
|
# Parse and validate the registration request
|
||||||
|
try:
|
||||||
body = await request.json()
|
body = await request.json()
|
||||||
logger.info(f"Dynamic client registration request received: {body}")
|
except Exception as e:
|
||||||
|
raise OAuthValidationError(f"Invalid JSON in registration request: {e}")
|
||||||
|
|
||||||
|
validate_registration_request(body)
|
||||||
|
logger.info("Dynamic client registration request received")
|
||||||
|
|
||||||
# Extract redirect URIs from the request or use defaults
|
# Extract redirect URIs from the request or use defaults
|
||||||
redirect_uris = body.get("redirect_uris", [])
|
redirect_uris = body.get("redirect_uris", [])
|
||||||
if not redirect_uris:
|
if not redirect_uris:
|
||||||
redirect_uris = [
|
redirect_uris = config.get_redirect_uris()
|
||||||
f"{WORKSPACE_MCP_BASE_URI}:{WORKSPACE_MCP_PORT}/oauth2callback",
|
|
||||||
"http://localhost:5173/auth/callback"
|
|
||||||
]
|
|
||||||
|
|
||||||
# Build the registration response with our pre-configured credentials
|
# Build the registration response with our pre-configured credentials
|
||||||
response_data = {
|
response_data = {
|
||||||
"client_id": client_id,
|
"client_id": config.client_id,
|
||||||
"client_secret": client_secret,
|
"client_secret": config.client_secret,
|
||||||
"client_name": body.get("client_name", "Google Workspace MCP Server"),
|
"client_name": body.get("client_name", "Google Workspace MCP Server"),
|
||||||
"client_uri": body.get("client_uri", f"{WORKSPACE_MCP_BASE_URI}:{WORKSPACE_MCP_PORT}"),
|
"client_uri": body.get("client_uri", config.base_url),
|
||||||
"redirect_uris": redirect_uris,
|
"redirect_uris": redirect_uris,
|
||||||
"grant_types": body.get("grant_types", ["authorization_code", "refresh_token"]),
|
"grant_types": body.get("grant_types", ["authorization_code", "refresh_token"]),
|
||||||
"response_types": body.get("response_types", ["code"]),
|
"response_types": body.get("response_types", ["code"]),
|
||||||
"scope": body.get("scope", " ".join(get_current_scopes())),
|
"scope": body.get("scope", " ".join(get_current_scopes())),
|
||||||
"token_endpoint_auth_method": body.get("token_endpoint_auth_method", "client_secret_basic"),
|
"token_endpoint_auth_method": body.get("token_endpoint_auth_method", "client_secret_basic"),
|
||||||
"code_challenge_methods": ["S256"],
|
"code_challenge_methods": config.supported_code_challenge_methods,
|
||||||
# Additional OAuth 2.1 fields
|
# Additional OAuth 2.1 fields
|
||||||
"client_id_issued_at": int(time.time()),
|
"client_id_issued_at": int(time.time()),
|
||||||
"registration_access_token": "not-required", # We don't implement client management
|
"registration_access_token": "not-required", # We don't implement client management
|
||||||
"registration_client_uri": f"{get_oauth_base_url()}/oauth2/register/{client_id}"
|
"registration_client_uri": f"{config.get_oauth_base_url()}/oauth2/register/{config.client_id}"
|
||||||
}
|
}
|
||||||
|
|
||||||
logger.info("Dynamic client registration successful - returning pre-configured Google credentials")
|
logger.info("Dynamic client registration successful - returning pre-configured Google credentials")
|
||||||
@@ -412,15 +397,21 @@ async def handle_oauth_register(request: Request):
|
|||||||
content=response_data,
|
content=response_data,
|
||||||
headers={
|
headers={
|
||||||
"Content-Type": "application/json",
|
"Content-Type": "application/json",
|
||||||
"Access-Control-Allow-Origin": "*",
|
"Cache-Control": "no-store",
|
||||||
"Cache-Control": "no-store"
|
**get_development_cors_headers(origin)
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
except OAuthError as e:
|
||||||
|
log_security_event("oauth_registration_error", {
|
||||||
|
"error_code": e.error_code,
|
||||||
|
"description": e.description
|
||||||
|
}, request)
|
||||||
|
return create_oauth_error_response(e, origin)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error in dynamic client registration: {e}")
|
logger.error(f"Unexpected error in client registration: {e}", exc_info=True)
|
||||||
return JSONResponse(
|
log_security_event("oauth_registration_unexpected_error", {
|
||||||
status_code=400,
|
"error": str(e)
|
||||||
content={"error": "invalid_request", "error_description": str(e)},
|
}, request)
|
||||||
headers={"Access-Control-Allow-Origin": "*"}
|
error = OAuthConfigurationError("Internal server error")
|
||||||
)
|
return create_oauth_error_response(error, origin)
|
||||||
319
auth/oauth_config.py
Normal file
319
auth/oauth_config.py
Normal file
@@ -0,0 +1,319 @@
|
|||||||
|
"""
|
||||||
|
OAuth Configuration Management
|
||||||
|
|
||||||
|
This module centralizes OAuth-related configuration to eliminate hardcoded values
|
||||||
|
scattered throughout the codebase. It provides environment variable support and
|
||||||
|
sensible defaults for all OAuth-related settings.
|
||||||
|
|
||||||
|
Supports both OAuth 2.0 and OAuth 2.1 with automatic client capability detection.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
from typing import List, Optional, Dict, Any
|
||||||
|
|
||||||
|
|
||||||
|
class OAuthConfig:
|
||||||
|
"""
|
||||||
|
Centralized OAuth configuration management.
|
||||||
|
|
||||||
|
This class eliminates the hardcoded configuration anti-pattern identified
|
||||||
|
in the challenge review by providing a single source of truth for all
|
||||||
|
OAuth-related configuration values.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
# Base server configuration
|
||||||
|
self.base_uri = os.getenv("WORKSPACE_MCP_BASE_URI", "http://localhost")
|
||||||
|
self.port = int(os.getenv("PORT", os.getenv("WORKSPACE_MCP_PORT", "8000")))
|
||||||
|
self.base_url = f"{self.base_uri}:{self.port}"
|
||||||
|
|
||||||
|
# OAuth client configuration
|
||||||
|
self.client_id = os.getenv("GOOGLE_OAUTH_CLIENT_ID")
|
||||||
|
self.client_secret = os.getenv("GOOGLE_OAUTH_CLIENT_SECRET")
|
||||||
|
|
||||||
|
# OAuth 2.1 configuration
|
||||||
|
self.oauth21_enabled = os.getenv("MCP_ENABLE_OAUTH21", "false").lower() == "true"
|
||||||
|
self.pkce_required = self.oauth21_enabled # PKCE is mandatory in OAuth 2.1
|
||||||
|
self.supported_code_challenge_methods = ["S256", "plain"] if not self.oauth21_enabled else ["S256"]
|
||||||
|
|
||||||
|
# Transport mode (will be set at runtime)
|
||||||
|
self._transport_mode = "stdio" # Default
|
||||||
|
|
||||||
|
# Redirect URI configuration
|
||||||
|
self.redirect_uri = self._get_redirect_uri()
|
||||||
|
|
||||||
|
def _get_redirect_uri(self) -> str:
|
||||||
|
"""
|
||||||
|
Get the OAuth redirect URI, supporting reverse proxy configurations.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The configured redirect URI
|
||||||
|
"""
|
||||||
|
explicit_uri = os.getenv("GOOGLE_OAUTH_REDIRECT_URI")
|
||||||
|
if explicit_uri:
|
||||||
|
return explicit_uri
|
||||||
|
return f"{self.base_url}/oauth2callback"
|
||||||
|
|
||||||
|
def get_redirect_uris(self) -> List[str]:
|
||||||
|
"""
|
||||||
|
Get all valid OAuth redirect URIs.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of all supported redirect URIs
|
||||||
|
"""
|
||||||
|
uris = []
|
||||||
|
|
||||||
|
# Primary redirect URI
|
||||||
|
uris.append(self.redirect_uri)
|
||||||
|
|
||||||
|
# Custom redirect URIs from environment
|
||||||
|
custom_uris = os.getenv("OAUTH_CUSTOM_REDIRECT_URIS")
|
||||||
|
if custom_uris:
|
||||||
|
uris.extend([uri.strip() for uri in custom_uris.split(",")])
|
||||||
|
|
||||||
|
# Remove duplicates while preserving order
|
||||||
|
return list(dict.fromkeys(uris))
|
||||||
|
|
||||||
|
def get_allowed_origins(self) -> List[str]:
|
||||||
|
"""
|
||||||
|
Get allowed CORS origins for OAuth endpoints.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of allowed origins for CORS
|
||||||
|
"""
|
||||||
|
origins = []
|
||||||
|
|
||||||
|
# Server's own origin
|
||||||
|
origins.append(self.base_url)
|
||||||
|
|
||||||
|
# VS Code and development origins
|
||||||
|
origins.extend([
|
||||||
|
"vscode-webview://",
|
||||||
|
"https://vscode.dev",
|
||||||
|
"https://github.dev",
|
||||||
|
])
|
||||||
|
|
||||||
|
# Custom origins from environment
|
||||||
|
custom_origins = os.getenv("OAUTH_ALLOWED_ORIGINS")
|
||||||
|
if custom_origins:
|
||||||
|
origins.extend([origin.strip() for origin in custom_origins.split(",")])
|
||||||
|
|
||||||
|
return list(dict.fromkeys(origins))
|
||||||
|
|
||||||
|
def is_configured(self) -> bool:
|
||||||
|
"""
|
||||||
|
Check if OAuth is properly configured.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if OAuth client credentials are available
|
||||||
|
"""
|
||||||
|
return bool(self.client_id and self.client_secret)
|
||||||
|
|
||||||
|
def get_oauth_base_url(self) -> str:
|
||||||
|
"""
|
||||||
|
Get OAuth base URL for constructing OAuth endpoints.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Base URL for OAuth endpoints
|
||||||
|
"""
|
||||||
|
return self.base_url
|
||||||
|
|
||||||
|
def validate_redirect_uri(self, uri: str) -> bool:
|
||||||
|
"""
|
||||||
|
Validate if a redirect URI is allowed.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
uri: The redirect URI to validate
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if the URI is allowed, False otherwise
|
||||||
|
"""
|
||||||
|
allowed_uris = self.get_redirect_uris()
|
||||||
|
return uri in allowed_uris
|
||||||
|
|
||||||
|
def get_environment_summary(self) -> dict:
|
||||||
|
"""
|
||||||
|
Get a summary of the current OAuth configuration.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary with configuration summary (excluding secrets)
|
||||||
|
"""
|
||||||
|
return {
|
||||||
|
"base_url": self.base_url,
|
||||||
|
"redirect_uri": self.redirect_uri,
|
||||||
|
"client_configured": bool(self.client_id),
|
||||||
|
"oauth21_enabled": self.oauth21_enabled,
|
||||||
|
"pkce_required": self.pkce_required,
|
||||||
|
"transport_mode": self._transport_mode,
|
||||||
|
"total_redirect_uris": len(self.get_redirect_uris()),
|
||||||
|
"total_allowed_origins": len(self.get_allowed_origins()),
|
||||||
|
}
|
||||||
|
|
||||||
|
def set_transport_mode(self, mode: str) -> None:
|
||||||
|
"""
|
||||||
|
Set the current transport mode for OAuth callback handling.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
mode: Transport mode ("stdio", "streamable-http", etc.)
|
||||||
|
"""
|
||||||
|
self._transport_mode = mode
|
||||||
|
|
||||||
|
def get_transport_mode(self) -> str:
|
||||||
|
"""
|
||||||
|
Get the current transport mode.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Current transport mode
|
||||||
|
"""
|
||||||
|
return self._transport_mode
|
||||||
|
|
||||||
|
def is_oauth21_enabled(self) -> bool:
|
||||||
|
"""
|
||||||
|
Check if OAuth 2.1 mode is enabled.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if OAuth 2.1 is enabled
|
||||||
|
"""
|
||||||
|
return self.oauth21_enabled
|
||||||
|
|
||||||
|
def detect_oauth_version(self, request_params: Dict[str, Any]) -> str:
|
||||||
|
"""
|
||||||
|
Detect OAuth version based on request parameters.
|
||||||
|
|
||||||
|
This method implements a conservative detection strategy:
|
||||||
|
- Only returns "oauth21" when we have clear indicators
|
||||||
|
- Defaults to "oauth20" for backward compatibility
|
||||||
|
- Respects the global oauth21_enabled flag
|
||||||
|
|
||||||
|
Args:
|
||||||
|
request_params: Request parameters from authorization or token request
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
"oauth21" or "oauth20" based on detection
|
||||||
|
"""
|
||||||
|
# If OAuth 2.1 is not enabled globally, always return OAuth 2.0
|
||||||
|
if not self.oauth21_enabled:
|
||||||
|
return "oauth20"
|
||||||
|
|
||||||
|
# Use the structured type for cleaner detection logic
|
||||||
|
from auth.oauth_types import OAuthVersionDetectionParams
|
||||||
|
params = OAuthVersionDetectionParams.from_request(request_params)
|
||||||
|
|
||||||
|
# Clear OAuth 2.1 indicator: PKCE is present
|
||||||
|
if params.has_pkce:
|
||||||
|
return "oauth21"
|
||||||
|
|
||||||
|
# For public clients in OAuth 2.1 mode, we require PKCE
|
||||||
|
# But since they didn't send PKCE, fall back to OAuth 2.0
|
||||||
|
# This ensures backward compatibility
|
||||||
|
|
||||||
|
# Default to OAuth 2.0 for maximum compatibility
|
||||||
|
return "oauth20"
|
||||||
|
|
||||||
|
def get_authorization_server_metadata(self, scopes: Optional[List[str]] = None) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Get OAuth authorization server metadata per RFC 8414.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
scopes: Optional list of supported scopes to include in metadata
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Authorization server metadata dictionary
|
||||||
|
"""
|
||||||
|
metadata = {
|
||||||
|
"issuer": self.base_url,
|
||||||
|
"authorization_endpoint": f"{self.base_url}/oauth2/authorize",
|
||||||
|
"token_endpoint": f"{self.base_url}/oauth2/token",
|
||||||
|
"registration_endpoint": f"{self.base_url}/oauth2/register",
|
||||||
|
"jwks_uri": "https://www.googleapis.com/oauth2/v3/certs",
|
||||||
|
"response_types_supported": ["code", "token"],
|
||||||
|
"grant_types_supported": ["authorization_code", "refresh_token"],
|
||||||
|
"token_endpoint_auth_methods_supported": ["client_secret_post", "client_secret_basic"],
|
||||||
|
"code_challenge_methods_supported": self.supported_code_challenge_methods,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Include scopes if provided
|
||||||
|
if scopes is not None:
|
||||||
|
metadata["scopes_supported"] = scopes
|
||||||
|
|
||||||
|
# Add OAuth 2.1 specific metadata
|
||||||
|
if self.oauth21_enabled:
|
||||||
|
metadata["pkce_required"] = True
|
||||||
|
# OAuth 2.1 deprecates implicit flow
|
||||||
|
metadata["response_types_supported"] = ["code"]
|
||||||
|
# OAuth 2.1 requires exact redirect URI matching
|
||||||
|
metadata["require_exact_redirect_uri"] = True
|
||||||
|
|
||||||
|
return metadata
|
||||||
|
|
||||||
|
|
||||||
|
# Global configuration instance
|
||||||
|
_oauth_config = None
|
||||||
|
|
||||||
|
|
||||||
|
def get_oauth_config() -> OAuthConfig:
|
||||||
|
"""
|
||||||
|
Get the global OAuth configuration instance.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The singleton OAuth configuration instance
|
||||||
|
"""
|
||||||
|
global _oauth_config
|
||||||
|
if _oauth_config is None:
|
||||||
|
_oauth_config = OAuthConfig()
|
||||||
|
return _oauth_config
|
||||||
|
|
||||||
|
|
||||||
|
def reload_oauth_config() -> OAuthConfig:
|
||||||
|
"""
|
||||||
|
Reload the OAuth configuration from environment variables.
|
||||||
|
|
||||||
|
This is useful for testing or when environment variables change.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The reloaded OAuth configuration instance
|
||||||
|
"""
|
||||||
|
global _oauth_config
|
||||||
|
_oauth_config = OAuthConfig()
|
||||||
|
return _oauth_config
|
||||||
|
|
||||||
|
|
||||||
|
# Convenience functions for backward compatibility
|
||||||
|
def get_oauth_base_url() -> str:
|
||||||
|
"""Get OAuth base URL."""
|
||||||
|
return get_oauth_config().get_oauth_base_url()
|
||||||
|
|
||||||
|
|
||||||
|
def get_redirect_uris() -> List[str]:
|
||||||
|
"""Get all valid OAuth redirect URIs."""
|
||||||
|
return get_oauth_config().get_redirect_uris()
|
||||||
|
|
||||||
|
|
||||||
|
def get_allowed_origins() -> List[str]:
|
||||||
|
"""Get allowed CORS origins."""
|
||||||
|
return get_oauth_config().get_allowed_origins()
|
||||||
|
|
||||||
|
|
||||||
|
def is_oauth_configured() -> bool:
|
||||||
|
"""Check if OAuth is properly configured."""
|
||||||
|
return get_oauth_config().is_configured()
|
||||||
|
|
||||||
|
|
||||||
|
def set_transport_mode(mode: str) -> None:
|
||||||
|
"""Set the current transport mode."""
|
||||||
|
get_oauth_config().set_transport_mode(mode)
|
||||||
|
|
||||||
|
|
||||||
|
def get_transport_mode() -> str:
|
||||||
|
"""Get the current transport mode."""
|
||||||
|
return get_oauth_config().get_transport_mode()
|
||||||
|
|
||||||
|
|
||||||
|
def is_oauth21_enabled() -> bool:
|
||||||
|
"""Check if OAuth 2.1 is enabled."""
|
||||||
|
return get_oauth_config().is_oauth21_enabled()
|
||||||
|
|
||||||
|
|
||||||
|
def get_oauth_redirect_uri() -> str:
|
||||||
|
"""Get the primary OAuth redirect URI."""
|
||||||
|
return get_oauth_config().redirect_uri
|
||||||
321
auth/oauth_error_handling.py
Normal file
321
auth/oauth_error_handling.py
Normal file
@@ -0,0 +1,321 @@
|
|||||||
|
"""
|
||||||
|
OAuth Error Handling and Validation
|
||||||
|
|
||||||
|
This module provides comprehensive error handling and input validation for OAuth
|
||||||
|
endpoints, addressing the inconsistent error handling identified in the challenge review.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import Optional, Dict, Any, List
|
||||||
|
from starlette.responses import JSONResponse
|
||||||
|
from starlette.requests import Request
|
||||||
|
from urllib.parse import urlparse
|
||||||
|
import re
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class OAuthError(Exception):
|
||||||
|
"""Base exception for OAuth-related errors."""
|
||||||
|
|
||||||
|
def __init__(self, error_code: str, description: str, status_code: int = 400):
|
||||||
|
self.error_code = error_code
|
||||||
|
self.description = description
|
||||||
|
self.status_code = status_code
|
||||||
|
super().__init__(f"{error_code}: {description}")
|
||||||
|
|
||||||
|
|
||||||
|
class OAuthValidationError(OAuthError):
|
||||||
|
"""Exception for OAuth validation errors."""
|
||||||
|
|
||||||
|
def __init__(self, description: str, field: Optional[str] = None):
|
||||||
|
error_code = "invalid_request"
|
||||||
|
if field:
|
||||||
|
description = f"Invalid {field}: {description}"
|
||||||
|
super().__init__(error_code, description, 400)
|
||||||
|
|
||||||
|
|
||||||
|
class OAuthConfigurationError(OAuthError):
|
||||||
|
"""Exception for OAuth configuration errors."""
|
||||||
|
|
||||||
|
def __init__(self, description: str):
|
||||||
|
super().__init__("server_error", description, 500)
|
||||||
|
|
||||||
|
|
||||||
|
def create_oauth_error_response(error: OAuthError, origin: Optional[str] = None) -> JSONResponse:
|
||||||
|
"""
|
||||||
|
Create a standardized OAuth error response.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
error: The OAuth error to convert to a response
|
||||||
|
origin: Optional origin for development CORS headers
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
JSONResponse with standardized error format
|
||||||
|
"""
|
||||||
|
headers = {
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
"Cache-Control": "no-store"
|
||||||
|
}
|
||||||
|
|
||||||
|
# Add development CORS headers if needed
|
||||||
|
cors_headers = get_development_cors_headers(origin)
|
||||||
|
headers.update(cors_headers)
|
||||||
|
|
||||||
|
content = {
|
||||||
|
"error": error.error_code,
|
||||||
|
"error_description": error.description
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.warning(f"OAuth error response: {error.error_code} - {error.description}")
|
||||||
|
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=error.status_code,
|
||||||
|
content=content,
|
||||||
|
headers=headers
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def validate_redirect_uri(uri: str) -> None:
|
||||||
|
"""
|
||||||
|
Validate an OAuth redirect URI.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
uri: The redirect URI to validate
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
OAuthValidationError: If the URI is invalid
|
||||||
|
"""
|
||||||
|
if not uri:
|
||||||
|
raise OAuthValidationError("Redirect URI is required", "redirect_uri")
|
||||||
|
|
||||||
|
try:
|
||||||
|
parsed = urlparse(uri)
|
||||||
|
except Exception:
|
||||||
|
raise OAuthValidationError("Malformed redirect URI", "redirect_uri")
|
||||||
|
|
||||||
|
# Basic URI validation
|
||||||
|
if not parsed.scheme or not parsed.netloc:
|
||||||
|
raise OAuthValidationError("Redirect URI must be absolute", "redirect_uri")
|
||||||
|
|
||||||
|
# Security checks
|
||||||
|
if parsed.scheme not in ["http", "https"]:
|
||||||
|
raise OAuthValidationError("Redirect URI must use HTTP or HTTPS", "redirect_uri")
|
||||||
|
|
||||||
|
# Additional security for production
|
||||||
|
if parsed.scheme == "http" and parsed.hostname not in ["localhost", "127.0.0.1"]:
|
||||||
|
logger.warning(f"Insecure redirect URI: {uri}")
|
||||||
|
|
||||||
|
|
||||||
|
def validate_client_id(client_id: str) -> None:
|
||||||
|
"""
|
||||||
|
Validate an OAuth client ID.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
client_id: The client ID to validate
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
OAuthValidationError: If the client ID is invalid
|
||||||
|
"""
|
||||||
|
if not client_id:
|
||||||
|
raise OAuthValidationError("Client ID is required", "client_id")
|
||||||
|
|
||||||
|
if len(client_id) < 10:
|
||||||
|
raise OAuthValidationError("Client ID is too short", "client_id")
|
||||||
|
|
||||||
|
# Basic format validation for Google client IDs
|
||||||
|
if not re.match(r'^[a-zA-Z0-9\-_.]+$', client_id):
|
||||||
|
raise OAuthValidationError("Client ID contains invalid characters", "client_id")
|
||||||
|
|
||||||
|
|
||||||
|
def validate_authorization_code(code: str) -> None:
|
||||||
|
"""
|
||||||
|
Validate an OAuth authorization code.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
code: The authorization code to validate
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
OAuthValidationError: If the code is invalid
|
||||||
|
"""
|
||||||
|
if not code:
|
||||||
|
raise OAuthValidationError("Authorization code is required", "code")
|
||||||
|
|
||||||
|
if len(code) < 10:
|
||||||
|
raise OAuthValidationError("Authorization code is too short", "code")
|
||||||
|
|
||||||
|
# Check for suspicious patterns
|
||||||
|
if any(char in code for char in [' ', '\n', '\t', '<', '>']):
|
||||||
|
raise OAuthValidationError("Authorization code contains invalid characters", "code")
|
||||||
|
|
||||||
|
|
||||||
|
def validate_scopes(scopes: List[str]) -> None:
|
||||||
|
"""
|
||||||
|
Validate OAuth scopes.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
scopes: List of scopes to validate
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
OAuthValidationError: If the scopes are invalid
|
||||||
|
"""
|
||||||
|
if not scopes:
|
||||||
|
return # Empty scopes list is acceptable
|
||||||
|
|
||||||
|
for scope in scopes:
|
||||||
|
if not scope:
|
||||||
|
raise OAuthValidationError("Empty scope is not allowed", "scope")
|
||||||
|
|
||||||
|
if len(scope) > 200:
|
||||||
|
raise OAuthValidationError("Scope is too long", "scope")
|
||||||
|
|
||||||
|
# Basic scope format validation
|
||||||
|
if not re.match(r'^[a-zA-Z0-9\-_.:/]+$', scope):
|
||||||
|
raise OAuthValidationError(f"Invalid scope format: {scope}", "scope")
|
||||||
|
|
||||||
|
|
||||||
|
def validate_token_request(request_data: Dict[str, Any]) -> None:
|
||||||
|
"""
|
||||||
|
Validate an OAuth token exchange request.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
request_data: The token request data to validate
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
OAuthValidationError: If the request is invalid
|
||||||
|
"""
|
||||||
|
grant_type = request_data.get("grant_type")
|
||||||
|
if not grant_type:
|
||||||
|
raise OAuthValidationError("Grant type is required", "grant_type")
|
||||||
|
|
||||||
|
if grant_type not in ["authorization_code", "refresh_token"]:
|
||||||
|
raise OAuthValidationError(f"Unsupported grant type: {grant_type}", "grant_type")
|
||||||
|
|
||||||
|
if grant_type == "authorization_code":
|
||||||
|
code = request_data.get("code")
|
||||||
|
validate_authorization_code(code)
|
||||||
|
|
||||||
|
redirect_uri = request_data.get("redirect_uri")
|
||||||
|
if redirect_uri:
|
||||||
|
validate_redirect_uri(redirect_uri)
|
||||||
|
|
||||||
|
client_id = request_data.get("client_id")
|
||||||
|
if client_id:
|
||||||
|
validate_client_id(client_id)
|
||||||
|
|
||||||
|
|
||||||
|
def validate_registration_request(request_data: Dict[str, Any]) -> None:
|
||||||
|
"""
|
||||||
|
Validate an OAuth client registration request.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
request_data: The registration request data to validate
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
OAuthValidationError: If the request is invalid
|
||||||
|
"""
|
||||||
|
# Validate redirect URIs if provided
|
||||||
|
redirect_uris = request_data.get("redirect_uris", [])
|
||||||
|
if redirect_uris:
|
||||||
|
if not isinstance(redirect_uris, list):
|
||||||
|
raise OAuthValidationError("redirect_uris must be an array", "redirect_uris")
|
||||||
|
|
||||||
|
for uri in redirect_uris:
|
||||||
|
validate_redirect_uri(uri)
|
||||||
|
|
||||||
|
# Validate grant types if provided
|
||||||
|
grant_types = request_data.get("grant_types", [])
|
||||||
|
if grant_types:
|
||||||
|
if not isinstance(grant_types, list):
|
||||||
|
raise OAuthValidationError("grant_types must be an array", "grant_types")
|
||||||
|
|
||||||
|
allowed_grant_types = ["authorization_code", "refresh_token"]
|
||||||
|
for grant_type in grant_types:
|
||||||
|
if grant_type not in allowed_grant_types:
|
||||||
|
raise OAuthValidationError(f"Unsupported grant type: {grant_type}", "grant_types")
|
||||||
|
|
||||||
|
# Validate response types if provided
|
||||||
|
response_types = request_data.get("response_types", [])
|
||||||
|
if response_types:
|
||||||
|
if not isinstance(response_types, list):
|
||||||
|
raise OAuthValidationError("response_types must be an array", "response_types")
|
||||||
|
|
||||||
|
allowed_response_types = ["code"]
|
||||||
|
for response_type in response_types:
|
||||||
|
if response_type not in allowed_response_types:
|
||||||
|
raise OAuthValidationError(f"Unsupported response type: {response_type}", "response_types")
|
||||||
|
|
||||||
|
|
||||||
|
def sanitize_user_input(value: str, max_length: int = 1000) -> str:
|
||||||
|
"""
|
||||||
|
Sanitize user input to prevent injection attacks.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
value: The input value to sanitize
|
||||||
|
max_length: Maximum allowed length
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Sanitized input value
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
OAuthValidationError: If the input is invalid
|
||||||
|
"""
|
||||||
|
if not isinstance(value, str):
|
||||||
|
raise OAuthValidationError("Input must be a string")
|
||||||
|
|
||||||
|
if len(value) > max_length:
|
||||||
|
raise OAuthValidationError(f"Input is too long (max {max_length} characters)")
|
||||||
|
|
||||||
|
# Remove potentially dangerous characters
|
||||||
|
sanitized = re.sub(r'[<>"\'\0\n\r\t]', '', value)
|
||||||
|
|
||||||
|
return sanitized.strip()
|
||||||
|
|
||||||
|
|
||||||
|
def log_security_event(event_type: str, details: Dict[str, Any], request: Optional[Request] = None) -> None:
|
||||||
|
"""
|
||||||
|
Log security-related events for monitoring.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
event_type: Type of security event
|
||||||
|
details: Event details
|
||||||
|
request: Optional request object for context
|
||||||
|
"""
|
||||||
|
log_data = {
|
||||||
|
"event_type": event_type,
|
||||||
|
"details": details
|
||||||
|
}
|
||||||
|
|
||||||
|
if request:
|
||||||
|
log_data["request"] = {
|
||||||
|
"method": request.method,
|
||||||
|
"path": request.url.path,
|
||||||
|
"user_agent": request.headers.get("user-agent", "unknown"),
|
||||||
|
"origin": request.headers.get("origin", "unknown")
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.warning(f"Security event: {log_data}")
|
||||||
|
|
||||||
|
|
||||||
|
def get_development_cors_headers(origin: Optional[str] = None) -> Dict[str, str]:
|
||||||
|
"""
|
||||||
|
Get minimal CORS headers for development scenarios only.
|
||||||
|
|
||||||
|
Only allows localhost origins for development tools and inspectors.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
origin: The request origin (will be validated)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
CORS headers for localhost origins only, empty dict otherwise
|
||||||
|
"""
|
||||||
|
# Only allow localhost origins for development
|
||||||
|
if origin and (origin.startswith("http://localhost:") or origin.startswith("http://127.0.0.1:")):
|
||||||
|
return {
|
||||||
|
"Access-Control-Allow-Origin": origin,
|
||||||
|
"Access-Control-Allow-Methods": "GET, POST, OPTIONS",
|
||||||
|
"Access-Control-Allow-Headers": "Content-Type, Authorization",
|
||||||
|
"Access-Control-Max-Age": "3600"
|
||||||
|
}
|
||||||
|
|
||||||
|
return {}
|
||||||
78
auth/oauth_types.py
Normal file
78
auth/oauth_types.py
Normal file
@@ -0,0 +1,78 @@
|
|||||||
|
"""
|
||||||
|
Type definitions for OAuth authentication.
|
||||||
|
|
||||||
|
This module provides structured types for OAuth-related parameters,
|
||||||
|
improving code maintainability and type safety.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Optional, List, Dict, Any
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class OAuth21ServiceRequest:
|
||||||
|
"""
|
||||||
|
Encapsulates parameters for OAuth 2.1 service authentication requests.
|
||||||
|
|
||||||
|
This parameter object pattern reduces function complexity and makes
|
||||||
|
it easier to extend authentication parameters in the future.
|
||||||
|
"""
|
||||||
|
service_name: str
|
||||||
|
version: str
|
||||||
|
tool_name: str
|
||||||
|
user_google_email: str
|
||||||
|
required_scopes: List[str]
|
||||||
|
session_id: Optional[str] = None
|
||||||
|
auth_token_email: Optional[str] = None
|
||||||
|
allow_recent_auth: bool = False
|
||||||
|
context: Optional[Dict[str, Any]] = None
|
||||||
|
|
||||||
|
def to_legacy_params(self) -> dict:
|
||||||
|
"""Convert to legacy parameter format for backward compatibility."""
|
||||||
|
return {
|
||||||
|
"service_name": self.service_name,
|
||||||
|
"version": self.version,
|
||||||
|
"tool_name": self.tool_name,
|
||||||
|
"user_google_email": self.user_google_email,
|
||||||
|
"required_scopes": self.required_scopes,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class OAuthVersionDetectionParams:
|
||||||
|
"""
|
||||||
|
Parameters used for OAuth version detection.
|
||||||
|
|
||||||
|
Encapsulates the various signals we use to determine
|
||||||
|
whether a client supports OAuth 2.1 or needs OAuth 2.0.
|
||||||
|
"""
|
||||||
|
client_id: Optional[str] = None
|
||||||
|
client_secret: Optional[str] = None
|
||||||
|
code_challenge: Optional[str] = None
|
||||||
|
code_challenge_method: Optional[str] = None
|
||||||
|
code_verifier: Optional[str] = None
|
||||||
|
authenticated_user: Optional[str] = None
|
||||||
|
session_id: Optional[str] = None
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_request(cls, request_params: Dict[str, Any]) -> "OAuthVersionDetectionParams":
|
||||||
|
"""Create from raw request parameters."""
|
||||||
|
return cls(
|
||||||
|
client_id=request_params.get("client_id"),
|
||||||
|
client_secret=request_params.get("client_secret"),
|
||||||
|
code_challenge=request_params.get("code_challenge"),
|
||||||
|
code_challenge_method=request_params.get("code_challenge_method"),
|
||||||
|
code_verifier=request_params.get("code_verifier"),
|
||||||
|
authenticated_user=request_params.get("authenticated_user"),
|
||||||
|
session_id=request_params.get("session_id"),
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def has_pkce(self) -> bool:
|
||||||
|
"""Check if PKCE parameters are present."""
|
||||||
|
return bool(self.code_challenge or self.code_verifier)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_public_client(self) -> bool:
|
||||||
|
"""Check if this appears to be a public client (no secret)."""
|
||||||
|
return bool(self.client_id and not self.client_secret)
|
||||||
@@ -334,9 +334,31 @@ def require_google_service(
|
|||||||
# Log authentication status
|
# Log authentication status
|
||||||
logger.debug(f"[{tool_name}] Auth: {authenticated_user or 'none'} via {auth_method or 'none'} (session: {mcp_session_id[:8] if mcp_session_id else 'none'})")
|
logger.debug(f"[{tool_name}] Auth: {authenticated_user or 'none'} via {auth_method or 'none'} (session: {mcp_session_id[:8] if mcp_session_id else 'none'})")
|
||||||
|
|
||||||
from auth.oauth21_integration import is_oauth21_enabled
|
from auth.oauth_config import is_oauth21_enabled, get_oauth_config
|
||||||
|
|
||||||
|
# Smart OAuth version detection and fallback
|
||||||
|
use_oauth21 = False
|
||||||
|
oauth_version = "oauth20" # Default
|
||||||
|
|
||||||
if is_oauth21_enabled():
|
if is_oauth21_enabled():
|
||||||
|
# OAuth 2.1 is enabled globally, check client capabilities
|
||||||
|
# Try to detect from context if this is an OAuth 2.1 capable client
|
||||||
|
config = get_oauth_config()
|
||||||
|
|
||||||
|
# Build request params from context for version detection
|
||||||
|
request_params = {}
|
||||||
|
if authenticated_user:
|
||||||
|
request_params["authenticated_user"] = authenticated_user
|
||||||
|
if mcp_session_id:
|
||||||
|
request_params["session_id"] = mcp_session_id
|
||||||
|
|
||||||
|
# Detect OAuth version based on client capabilities
|
||||||
|
oauth_version = config.detect_oauth_version(request_params)
|
||||||
|
use_oauth21 = (oauth_version == "oauth21")
|
||||||
|
|
||||||
|
logger.debug(f"[{tool_name}] OAuth version detected: {oauth_version}, will use OAuth 2.1: {use_oauth21}")
|
||||||
|
|
||||||
|
if use_oauth21:
|
||||||
logger.debug(f"[{tool_name}] Using OAuth 2.1 flow")
|
logger.debug(f"[{tool_name}] Using OAuth 2.1 flow")
|
||||||
# The downstream get_authenticated_google_service_oauth21 will handle
|
# The downstream get_authenticated_google_service_oauth21 will handle
|
||||||
# whether the user's token is valid for the requested resource.
|
# whether the user's token is valid for the requested resource.
|
||||||
@@ -352,8 +374,8 @@ def require_google_service(
|
|||||||
allow_recent_auth=False,
|
allow_recent_auth=False,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# If OAuth 2.1 is not enabled, always use the legacy authentication method.
|
# Use legacy OAuth 2.0 authentication
|
||||||
logger.debug(f"[{tool_name}] Using legacy OAuth flow")
|
logger.debug(f"[{tool_name}] Using legacy OAuth 2.0 flow")
|
||||||
service, actual_user_email = await get_authenticated_google_service(
|
service, actual_user_email = await get_authenticated_google_service(
|
||||||
service_name=service_name,
|
service_name=service_name,
|
||||||
version=service_version,
|
version=service_version,
|
||||||
@@ -464,7 +486,7 @@ def require_multiple_services(service_configs: List[Dict[str, Any]]):
|
|||||||
logger.debug(f"[{tool_name}] Could not get FastMCP context: {e}")
|
logger.debug(f"[{tool_name}] Could not get FastMCP context: {e}")
|
||||||
|
|
||||||
# Use the same logic as single service decorator
|
# Use the same logic as single service decorator
|
||||||
from auth.oauth21_integration import is_oauth21_enabled
|
from auth.oauth_config import is_oauth21_enabled
|
||||||
|
|
||||||
if is_oauth21_enabled():
|
if is_oauth21_enabled():
|
||||||
logger.debug(f"[{tool_name}] Attempting OAuth 2.1 authentication flow for {service_type}.")
|
logger.debug(f"[{tool_name}] Attempting OAuth 2.1 authentication flow for {service_type}.")
|
||||||
|
|||||||
@@ -2,57 +2,34 @@
|
|||||||
Shared configuration for Google Workspace MCP server.
|
Shared configuration for Google Workspace MCP server.
|
||||||
This module holds configuration values that need to be shared across modules
|
This module holds configuration values that need to be shared across modules
|
||||||
to avoid circular imports.
|
to avoid circular imports.
|
||||||
|
|
||||||
|
NOTE: OAuth configuration has been moved to auth.oauth_config for centralization.
|
||||||
|
This module now imports from there for backward compatibility.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import os
|
import os
|
||||||
|
from auth.oauth_config import (
|
||||||
|
get_oauth_base_url,
|
||||||
|
get_oauth_redirect_uri,
|
||||||
|
set_transport_mode,
|
||||||
|
get_transport_mode,
|
||||||
|
is_oauth21_enabled
|
||||||
|
)
|
||||||
|
|
||||||
# Server configuration
|
# Server configuration
|
||||||
WORKSPACE_MCP_PORT = int(os.getenv("PORT", os.getenv("WORKSPACE_MCP_PORT", 8000)))
|
WORKSPACE_MCP_PORT = int(os.getenv("PORT", os.getenv("WORKSPACE_MCP_PORT", 8000)))
|
||||||
WORKSPACE_MCP_BASE_URI = os.getenv("WORKSPACE_MCP_BASE_URI", "http://localhost")
|
WORKSPACE_MCP_BASE_URI = os.getenv("WORKSPACE_MCP_BASE_URI", "http://localhost")
|
||||||
|
|
||||||
# Disable USER_GOOGLE_EMAIL in OAuth 2.1 multi-user mode
|
# Disable USER_GOOGLE_EMAIL in OAuth 2.1 multi-user mode
|
||||||
_oauth21_enabled = os.getenv("MCP_ENABLE_OAUTH21", "false").lower() == "true"
|
USER_GOOGLE_EMAIL = None if is_oauth21_enabled() else os.getenv("USER_GOOGLE_EMAIL", None)
|
||||||
USER_GOOGLE_EMAIL = None if _oauth21_enabled else os.getenv("USER_GOOGLE_EMAIL", None)
|
|
||||||
|
|
||||||
# Transport mode (will be set by main.py)
|
# Re-export OAuth functions for backward compatibility
|
||||||
_current_transport_mode = "stdio" # Default to stdio
|
__all__ = [
|
||||||
|
'WORKSPACE_MCP_PORT',
|
||||||
|
'WORKSPACE_MCP_BASE_URI',
|
||||||
def set_transport_mode(mode: str):
|
'USER_GOOGLE_EMAIL',
|
||||||
"""Set the current transport mode for OAuth callback handling."""
|
'get_oauth_base_url',
|
||||||
global _current_transport_mode
|
'get_oauth_redirect_uri',
|
||||||
_current_transport_mode = mode
|
'set_transport_mode',
|
||||||
|
'get_transport_mode'
|
||||||
|
]
|
||||||
def get_transport_mode() -> str:
|
|
||||||
"""Get the current transport mode."""
|
|
||||||
return _current_transport_mode
|
|
||||||
|
|
||||||
|
|
||||||
# OAuth Configuration
|
|
||||||
# Determine base URL and redirect URI once at startup
|
|
||||||
_OAUTH_REDIRECT_URI = os.getenv("GOOGLE_OAUTH_REDIRECT_URI")
|
|
||||||
if _OAUTH_REDIRECT_URI:
|
|
||||||
# Extract base URL from the redirect URI (remove the /oauth2callback path)
|
|
||||||
_OAUTH_BASE_URL = _OAUTH_REDIRECT_URI.removesuffix("/oauth2callback")
|
|
||||||
else:
|
|
||||||
# Construct from base URI and port if not explicitly set
|
|
||||||
_OAUTH_BASE_URL = f"{WORKSPACE_MCP_BASE_URI}:{WORKSPACE_MCP_PORT}"
|
|
||||||
_OAUTH_REDIRECT_URI = f"{_OAUTH_BASE_URL}/oauth2callback"
|
|
||||||
|
|
||||||
def get_oauth_base_url() -> str:
|
|
||||||
"""Get OAuth base URL for constructing OAuth endpoints.
|
|
||||||
|
|
||||||
Returns the base URL (without paths) for OAuth endpoints,
|
|
||||||
respecting GOOGLE_OAUTH_REDIRECT_URI for reverse proxy scenarios.
|
|
||||||
"""
|
|
||||||
return _OAUTH_BASE_URL
|
|
||||||
|
|
||||||
def get_oauth_redirect_uri() -> str:
|
|
||||||
"""Get OAuth redirect URI based on current configuration.
|
|
||||||
|
|
||||||
Returns the redirect URI configured at startup, either from
|
|
||||||
GOOGLE_OAUTH_REDIRECT_URI environment variable or constructed
|
|
||||||
from WORKSPACE_MCP_BASE_URI and WORKSPACE_MCP_PORT.
|
|
||||||
"""
|
|
||||||
return _OAUTH_REDIRECT_URI
|
|
||||||
@@ -7,7 +7,6 @@ from fastapi.responses import HTMLResponse, JSONResponse
|
|||||||
from starlette.applications import Starlette
|
from starlette.applications import Starlette
|
||||||
from starlette.requests import Request
|
from starlette.requests import Request
|
||||||
from starlette.middleware import Middleware
|
from starlette.middleware import Middleware
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
|
||||||
|
|
||||||
from fastmcp import FastMCP
|
from fastmcp import FastMCP
|
||||||
|
|
||||||
@@ -19,8 +18,6 @@ from auth.auth_info_middleware import AuthInfoMiddleware
|
|||||||
from auth.fastmcp_google_auth import GoogleWorkspaceAuthProvider
|
from auth.fastmcp_google_auth import GoogleWorkspaceAuthProvider
|
||||||
from auth.scopes import SCOPES
|
from auth.scopes import SCOPES
|
||||||
from core.config import (
|
from core.config import (
|
||||||
WORKSPACE_MCP_PORT,
|
|
||||||
WORKSPACE_MCP_BASE_URI,
|
|
||||||
USER_GOOGLE_EMAIL,
|
USER_GOOGLE_EMAIL,
|
||||||
get_transport_mode,
|
get_transport_mode,
|
||||||
set_transport_mode as _set_transport_mode,
|
set_transport_mode as _set_transport_mode,
|
||||||
@@ -41,31 +38,25 @@ logger = logging.getLogger(__name__)
|
|||||||
_auth_provider: Optional[Union[GoogleWorkspaceAuthProvider, GoogleRemoteAuthProvider]] = None
|
_auth_provider: Optional[Union[GoogleWorkspaceAuthProvider, GoogleRemoteAuthProvider]] = None
|
||||||
|
|
||||||
# --- Middleware Definitions ---
|
# --- Middleware Definitions ---
|
||||||
cors_middleware = Middleware(
|
|
||||||
CORSMiddleware,
|
|
||||||
allow_origins=["*"],
|
|
||||||
allow_credentials=True,
|
|
||||||
allow_methods=["*"],
|
|
||||||
allow_headers=["*"],
|
|
||||||
)
|
|
||||||
session_middleware = Middleware(MCPSessionMiddleware)
|
session_middleware = Middleware(MCPSessionMiddleware)
|
||||||
|
|
||||||
# Custom FastMCP that adds CORS to streamable HTTP
|
# Custom FastMCP that adds secure middleware stack for OAuth 2.1
|
||||||
class CORSEnabledFastMCP(FastMCP):
|
class SecureFastMCP(FastMCP):
|
||||||
def streamable_http_app(self) -> "Starlette":
|
def streamable_http_app(self) -> "Starlette":
|
||||||
"""Override to add CORS and session middleware to the app."""
|
"""Override to add secure middleware stack for OAuth 2.1."""
|
||||||
app = super().streamable_http_app()
|
app = super().streamable_http_app()
|
||||||
# Add session middleware first (to set context before other middleware)
|
|
||||||
|
# Add middleware in order (first added = outermost layer)
|
||||||
|
# Session Management - extracts session info for MCP context
|
||||||
app.user_middleware.insert(0, session_middleware)
|
app.user_middleware.insert(0, session_middleware)
|
||||||
# Add CORS as the second middleware
|
|
||||||
app.user_middleware.insert(1, cors_middleware)
|
|
||||||
# Rebuild middleware stack
|
# Rebuild middleware stack
|
||||||
app.middleware_stack = app.build_middleware_stack()
|
app.middleware_stack = app.build_middleware_stack()
|
||||||
logger.info("Added session and CORS middleware to streamable HTTP app")
|
logger.info("Added middleware stack: Session Management")
|
||||||
return app
|
return app
|
||||||
|
|
||||||
# --- Server Instance ---
|
# --- Server Instance ---
|
||||||
server = CORSEnabledFastMCP(
|
server = SecureFastMCP(
|
||||||
name="google_workspace",
|
name="google_workspace",
|
||||||
auth=None,
|
auth=None,
|
||||||
)
|
)
|
||||||
@@ -86,32 +77,37 @@ def configure_server_for_http():
|
|||||||
This must be called BEFORE server.run().
|
This must be called BEFORE server.run().
|
||||||
"""
|
"""
|
||||||
global _auth_provider
|
global _auth_provider
|
||||||
|
|
||||||
transport_mode = get_transport_mode()
|
transport_mode = get_transport_mode()
|
||||||
|
|
||||||
if transport_mode != "streamable-http":
|
if transport_mode != "streamable-http":
|
||||||
return
|
return
|
||||||
|
|
||||||
oauth21_enabled = os.getenv("MCP_ENABLE_OAUTH21", "false").lower() == "true"
|
# Use centralized OAuth configuration
|
||||||
|
from auth.oauth_config import get_oauth_config
|
||||||
|
config = get_oauth_config()
|
||||||
|
|
||||||
|
# Check if OAuth 2.1 is enabled via centralized config
|
||||||
|
oauth21_enabled = config.is_oauth21_enabled()
|
||||||
|
|
||||||
if oauth21_enabled:
|
if oauth21_enabled:
|
||||||
if not os.getenv("GOOGLE_OAUTH_CLIENT_ID"):
|
if not config.is_configured():
|
||||||
logger.warning("⚠️ OAuth 2.1 enabled but GOOGLE_OAUTH_CLIENT_ID not set")
|
logger.warning("⚠️ OAuth 2.1 enabled but OAuth credentials not configured")
|
||||||
return
|
return
|
||||||
|
|
||||||
if GOOGLE_REMOTE_AUTH_AVAILABLE:
|
if GOOGLE_REMOTE_AUTH_AVAILABLE:
|
||||||
logger.info("🔐 OAuth 2.1 enabled")
|
logger.info("🔐 OAuth 2.1 enabled with automatic OAuth 2.0 fallback for legacy clients")
|
||||||
try:
|
try:
|
||||||
_auth_provider = GoogleRemoteAuthProvider()
|
_auth_provider = GoogleRemoteAuthProvider()
|
||||||
server.auth = _auth_provider
|
server.auth = _auth_provider
|
||||||
set_auth_provider(_auth_provider)
|
set_auth_provider(_auth_provider)
|
||||||
from auth.oauth21_integration import enable_oauth21
|
logger.debug("OAuth 2.1 authentication enabled")
|
||||||
enable_oauth21()
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to initialize GoogleRemoteAuthProvider: {e}", exc_info=True)
|
logger.error(f"Failed to initialize GoogleRemoteAuthProvider: {e}", exc_info=True)
|
||||||
else:
|
else:
|
||||||
logger.error("OAuth 2.1 is enabled, but GoogleRemoteAuthProvider is not available.")
|
logger.error("OAuth 2.1 is enabled, but GoogleRemoteAuthProvider is not available.")
|
||||||
else:
|
else:
|
||||||
logger.info("OAuth 2.1 is DISABLED. Server will use legacy tool-based authentication.")
|
logger.info("OAuth 2.0 mode - Server will use legacy authentication.")
|
||||||
server.auth = None
|
server.auth = None
|
||||||
|
|
||||||
def get_auth_provider() -> Optional[Union[GoogleWorkspaceAuthProvider, GoogleRemoteAuthProvider]]:
|
def get_auth_provider() -> Optional[Union[GoogleWorkspaceAuthProvider, GoogleRemoteAuthProvider]]:
|
||||||
|
|||||||
@@ -1,256 +0,0 @@
|
|||||||
#!/usr/bin/env python3
|
|
||||||
"""
|
|
||||||
Auto-installer for Google Workspace MCP in Claude Desktop
|
|
||||||
Enhanced version with OAuth configuration and installation options
|
|
||||||
"""
|
|
||||||
|
|
||||||
import json
|
|
||||||
import os
|
|
||||||
import platform
|
|
||||||
import sys
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Dict, Optional, Tuple
|
|
||||||
|
|
||||||
|
|
||||||
def get_claude_config_path() -> Path:
|
|
||||||
"""Get the Claude Desktop config file path for the current platform."""
|
|
||||||
system = platform.system()
|
|
||||||
if system == "Darwin": # macOS
|
|
||||||
return Path.home() / "Library/Application Support/Claude/claude_desktop_config.json"
|
|
||||||
elif system == "Windows":
|
|
||||||
appdata = os.environ.get("APPDATA")
|
|
||||||
if not appdata:
|
|
||||||
raise RuntimeError("APPDATA environment variable not found")
|
|
||||||
return Path(appdata) / "Claude/claude_desktop_config.json"
|
|
||||||
else:
|
|
||||||
raise RuntimeError(f"Unsupported platform: {system}")
|
|
||||||
|
|
||||||
|
|
||||||
def prompt_yes_no(question: str, default: bool = True) -> bool:
|
|
||||||
"""Prompt user for yes/no question."""
|
|
||||||
default_str = "Y/n" if default else "y/N"
|
|
||||||
while True:
|
|
||||||
response = input(f"{question} [{default_str}]: ").strip().lower()
|
|
||||||
if not response:
|
|
||||||
return default
|
|
||||||
if response in ['y', 'yes']:
|
|
||||||
return True
|
|
||||||
if response in ['n', 'no']:
|
|
||||||
return False
|
|
||||||
print("Please answer 'y' or 'n'")
|
|
||||||
|
|
||||||
|
|
||||||
def get_oauth_credentials() -> Tuple[Optional[Dict[str, str]], Optional[str]]:
|
|
||||||
"""Get OAuth credentials from user."""
|
|
||||||
print("\n🔑 OAuth Credentials Setup")
|
|
||||||
print("You need Google OAuth 2.0 credentials to use this server.")
|
|
||||||
print("\nYou can provide credentials in two ways:")
|
|
||||||
print("1. Environment variables (recommended for production)")
|
|
||||||
print("2. Client secrets JSON file")
|
|
||||||
|
|
||||||
use_env = prompt_yes_no("\nDo you want to use environment variables?", default=True)
|
|
||||||
|
|
||||||
env_vars = {}
|
|
||||||
client_secret_path = None
|
|
||||||
|
|
||||||
if use_env:
|
|
||||||
print("\n📝 Enter your OAuth credentials:")
|
|
||||||
client_id = input("Client ID (ends with .apps.googleusercontent.com): ").strip()
|
|
||||||
client_secret = input("Client Secret: ").strip()
|
|
||||||
|
|
||||||
if not client_id or not client_secret:
|
|
||||||
print("❌ Both Client ID and Client Secret are required!")
|
|
||||||
return None, None
|
|
||||||
|
|
||||||
env_vars["GOOGLE_OAUTH_CLIENT_ID"] = client_id
|
|
||||||
env_vars["GOOGLE_OAUTH_CLIENT_SECRET"] = client_secret
|
|
||||||
|
|
||||||
# Optional redirect URI
|
|
||||||
custom_redirect = input("Redirect URI (press Enter for default http://localhost:8000/oauth2callback): ").strip()
|
|
||||||
if custom_redirect:
|
|
||||||
env_vars["GOOGLE_OAUTH_REDIRECT_URI"] = custom_redirect
|
|
||||||
|
|
||||||
else:
|
|
||||||
print("\n📁 Client secrets file setup:")
|
|
||||||
default_path = "client_secret.json"
|
|
||||||
file_path = input(f"Path to client_secret.json file [{default_path}]: ").strip()
|
|
||||||
|
|
||||||
if not file_path:
|
|
||||||
file_path = default_path
|
|
||||||
|
|
||||||
# Check if file exists
|
|
||||||
if not Path(file_path).exists():
|
|
||||||
print(f"❌ File not found: {file_path}")
|
|
||||||
return None, None
|
|
||||||
|
|
||||||
client_secret_path = file_path
|
|
||||||
|
|
||||||
# Optional: Default user email
|
|
||||||
print("\n📧 Optional: Default user email (for single-user setups)")
|
|
||||||
user_email = input("Your Google email (press Enter to skip): ").strip()
|
|
||||||
if user_email:
|
|
||||||
env_vars["USER_GOOGLE_EMAIL"] = user_email
|
|
||||||
|
|
||||||
# Development mode
|
|
||||||
if prompt_yes_no("\n🔧 Enable development mode (OAUTHLIB_INSECURE_TRANSPORT)?", default=False):
|
|
||||||
env_vars["OAUTHLIB_INSECURE_TRANSPORT"] = "1"
|
|
||||||
|
|
||||||
return env_vars, client_secret_path
|
|
||||||
|
|
||||||
|
|
||||||
def get_installation_options() -> Dict[str, any]:
|
|
||||||
"""Get installation options from user."""
|
|
||||||
options = {}
|
|
||||||
|
|
||||||
print("\n⚙️ Installation Options")
|
|
||||||
|
|
||||||
# Installation method
|
|
||||||
print("\nChoose installation method:")
|
|
||||||
print("1. uvx (recommended - auto-installs from PyPI)")
|
|
||||||
print("2. Development mode (requires local repository)")
|
|
||||||
|
|
||||||
method = input("Select method [1]: ").strip()
|
|
||||||
if method == "2":
|
|
||||||
options["dev_mode"] = True
|
|
||||||
cwd = input("Path to google_workspace_mcp repository [current directory]: ").strip()
|
|
||||||
options["cwd"] = cwd if cwd else os.getcwd()
|
|
||||||
else:
|
|
||||||
options["dev_mode"] = False
|
|
||||||
|
|
||||||
# Single-user mode
|
|
||||||
if prompt_yes_no("\n👤 Enable single-user mode (simplified authentication)?", default=False):
|
|
||||||
options["single_user"] = True
|
|
||||||
|
|
||||||
# Tool selection
|
|
||||||
print("\n🛠️ Tool Selection")
|
|
||||||
print("Available tools: gmail, drive, calendar, docs, sheets, forms, chat")
|
|
||||||
print("Leave empty to enable all tools")
|
|
||||||
tools = input("Enter tools to enable (comma-separated): ").strip()
|
|
||||||
if tools:
|
|
||||||
options["tools"] = [t.strip() for t in tools.split(",")]
|
|
||||||
|
|
||||||
# Transport mode
|
|
||||||
if prompt_yes_no("\n🌐 Use HTTP transport mode (for debugging)?", default=False):
|
|
||||||
options["http_mode"] = True
|
|
||||||
|
|
||||||
return options
|
|
||||||
|
|
||||||
|
|
||||||
def create_server_config(options: Dict, env_vars: Dict, client_secret_path: Optional[str]) -> Dict:
|
|
||||||
"""Create the server configuration."""
|
|
||||||
config = {}
|
|
||||||
|
|
||||||
if options.get("dev_mode"):
|
|
||||||
config["command"] = "uv"
|
|
||||||
config["args"] = ["run", "--directory", options["cwd"], "main.py"]
|
|
||||||
else:
|
|
||||||
config["command"] = "uvx"
|
|
||||||
config["args"] = ["workspace-mcp"]
|
|
||||||
|
|
||||||
# Add command line arguments
|
|
||||||
if options.get("single_user"):
|
|
||||||
config["args"].append("--single-user")
|
|
||||||
|
|
||||||
if options.get("tools"):
|
|
||||||
config["args"].extend(["--tools"] + options["tools"])
|
|
||||||
|
|
||||||
if options.get("http_mode"):
|
|
||||||
config["args"].extend(["--transport", "streamable-http"])
|
|
||||||
|
|
||||||
# Add environment variables
|
|
||||||
if env_vars or client_secret_path:
|
|
||||||
config["env"] = {}
|
|
||||||
|
|
||||||
if env_vars:
|
|
||||||
config["env"].update(env_vars)
|
|
||||||
|
|
||||||
if client_secret_path:
|
|
||||||
config["env"]["GOOGLE_CLIENT_SECRET_PATH"] = client_secret_path
|
|
||||||
|
|
||||||
return config
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
|
||||||
print("🚀 Google Workspace MCP Installer for Claude Desktop")
|
|
||||||
print("=" * 50)
|
|
||||||
|
|
||||||
try:
|
|
||||||
config_path = get_claude_config_path()
|
|
||||||
|
|
||||||
# Check if config already exists
|
|
||||||
existing_config = {}
|
|
||||||
if config_path.exists():
|
|
||||||
with open(config_path, 'r') as f:
|
|
||||||
existing_config = json.load(f)
|
|
||||||
|
|
||||||
if "mcpServers" in existing_config and "Google Workspace" in existing_config["mcpServers"]:
|
|
||||||
print(f"\n⚠️ Google Workspace MCP is already configured in {config_path}")
|
|
||||||
if not prompt_yes_no("Do you want to reconfigure it?", default=True):
|
|
||||||
print("Installation cancelled.")
|
|
||||||
return
|
|
||||||
|
|
||||||
# Get OAuth credentials
|
|
||||||
env_vars, client_secret_path = get_oauth_credentials()
|
|
||||||
if env_vars is None and client_secret_path is None:
|
|
||||||
print("\n❌ OAuth credentials are required. Installation cancelled.")
|
|
||||||
sys.exit(1)
|
|
||||||
|
|
||||||
# Get installation options
|
|
||||||
options = get_installation_options()
|
|
||||||
|
|
||||||
# Create server configuration
|
|
||||||
server_config = create_server_config(options, env_vars, client_secret_path)
|
|
||||||
|
|
||||||
# Prepare final config
|
|
||||||
if "mcpServers" not in existing_config:
|
|
||||||
existing_config["mcpServers"] = {}
|
|
||||||
|
|
||||||
existing_config["mcpServers"]["Google Workspace"] = server_config
|
|
||||||
|
|
||||||
# Create directory if needed
|
|
||||||
config_path.parent.mkdir(parents=True, exist_ok=True)
|
|
||||||
|
|
||||||
# Write configuration
|
|
||||||
with open(config_path, 'w') as f:
|
|
||||||
json.dump(existing_config, f, indent=2)
|
|
||||||
|
|
||||||
print("\n✅ Successfully configured Google Workspace MCP!")
|
|
||||||
print(f"📁 Config file: {config_path}")
|
|
||||||
|
|
||||||
print("\n📋 Configuration Summary:")
|
|
||||||
print(f" • Installation method: {'Development' if options.get('dev_mode') else 'uvx (PyPI)'}")
|
|
||||||
print(f" • Authentication: {'Environment variables' if env_vars else 'Client secrets file'}")
|
|
||||||
if options.get("single_user"):
|
|
||||||
print(" • Single-user mode: Enabled")
|
|
||||||
if options.get("tools"):
|
|
||||||
print(f" • Tools: {', '.join(options['tools'])}")
|
|
||||||
else:
|
|
||||||
print(" • Tools: All enabled")
|
|
||||||
if options.get("http_mode"):
|
|
||||||
print(" • Transport: HTTP mode")
|
|
||||||
else:
|
|
||||||
print(" • Transport: stdio (default)")
|
|
||||||
|
|
||||||
print("\n🚀 Next steps:")
|
|
||||||
print("1. Restart Claude Desktop")
|
|
||||||
print("2. The Google Workspace tools will be available in your chats!")
|
|
||||||
print("\n💡 The server will start automatically when Claude Desktop needs it.")
|
|
||||||
|
|
||||||
if options.get("http_mode"):
|
|
||||||
print("\n⚠️ Note: HTTP mode requires additional setup.")
|
|
||||||
print(" You may need to install and configure mcp-remote.")
|
|
||||||
print(" See the README for details.")
|
|
||||||
|
|
||||||
except KeyboardInterrupt:
|
|
||||||
print("\n\nInstallation cancelled by user.")
|
|
||||||
sys.exit(0)
|
|
||||||
except Exception as e:
|
|
||||||
print(f"\n❌ Error: {e}")
|
|
||||||
print("\n📋 Manual installation:")
|
|
||||||
print("1. Open Claude Desktop Settings → Developer → Edit Config")
|
|
||||||
print("2. Add the server configuration shown in the README")
|
|
||||||
sys.exit(1)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
main()
|
|
||||||
18
main.py
18
main.py
@@ -4,17 +4,19 @@ import os
|
|||||||
import sys
|
import sys
|
||||||
from importlib import metadata
|
from importlib import metadata
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
|
|
||||||
|
# Load environment variables from .env file BEFORE any other imports
|
||||||
|
dotenv_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), '.env')
|
||||||
|
load_dotenv(dotenv_path=dotenv_path)
|
||||||
|
|
||||||
from core.server import server, set_transport_mode, configure_server_for_http
|
from core.server import server, set_transport_mode, configure_server_for_http
|
||||||
|
from auth.oauth_config import reload_oauth_config
|
||||||
|
reload_oauth_config()
|
||||||
|
|
||||||
# Suppress googleapiclient discovery cache warning
|
# Suppress googleapiclient discovery cache warning
|
||||||
logging.getLogger('googleapiclient.discovery_cache').setLevel(logging.ERROR)
|
logging.getLogger('googleapiclient.discovery_cache').setLevel(logging.ERROR)
|
||||||
from core.utils import check_credentials_directory_permissions
|
from core.utils import check_credentials_directory_permissions
|
||||||
|
|
||||||
# Load environment variables from .env file, specifying an explicit path
|
|
||||||
# This prevents accidentally loading a .env file from a different directory
|
|
||||||
dotenv_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), '.env')
|
|
||||||
load_dotenv(dotenv_path=dotenv_path)
|
|
||||||
|
|
||||||
logging.basicConfig(
|
logging.basicConfig(
|
||||||
level=logging.INFO,
|
level=logging.INFO,
|
||||||
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
||||||
@@ -92,6 +94,7 @@ def main():
|
|||||||
# Active Configuration
|
# Active Configuration
|
||||||
safe_print("⚙️ Active Configuration:")
|
safe_print("⚙️ Active Configuration:")
|
||||||
|
|
||||||
|
|
||||||
# Redact client secret for security
|
# Redact client secret for security
|
||||||
client_secret = os.getenv('GOOGLE_OAUTH_CLIENT_SECRET', 'Not Set')
|
client_secret = os.getenv('GOOGLE_OAUTH_CLIENT_SECRET', 'Not Set')
|
||||||
redacted_secret = f"{client_secret[:4]}...{client_secret[-4:]}" if len(client_secret) > 8 else "Invalid or too short"
|
redacted_secret = f"{client_secret[:4]}...{client_secret[-4:]}" if len(client_secret) > 8 else "Invalid or too short"
|
||||||
@@ -153,7 +156,6 @@ def main():
|
|||||||
|
|
||||||
safe_print("📊 Configuration Summary:")
|
safe_print("📊 Configuration Summary:")
|
||||||
safe_print(f" 🔧 Tools Enabled: {len(tools_to_import)}/{len(tool_imports)}")
|
safe_print(f" 🔧 Tools Enabled: {len(tools_to_import)}/{len(tool_imports)}")
|
||||||
safe_print(" 🔑 Auth Method: OAuth 2.0 with PKCE")
|
|
||||||
safe_print(f" 📝 Log Level: {logging.getLogger().getEffectiveLevel()}")
|
safe_print(f" 📝 Log Level: {logging.getLogger().getEffectiveLevel()}")
|
||||||
safe_print("")
|
safe_print("")
|
||||||
|
|
||||||
@@ -182,10 +184,10 @@ def main():
|
|||||||
# Configure auth initialization for FastMCP lifecycle events
|
# Configure auth initialization for FastMCP lifecycle events
|
||||||
if args.transport == 'streamable-http':
|
if args.transport == 'streamable-http':
|
||||||
configure_server_for_http()
|
configure_server_for_http()
|
||||||
safe_print(f"")
|
safe_print("")
|
||||||
safe_print(f"🚀 Starting HTTP server on {base_uri}:{port}")
|
safe_print(f"🚀 Starting HTTP server on {base_uri}:{port}")
|
||||||
else:
|
else:
|
||||||
safe_print(f"")
|
safe_print("")
|
||||||
safe_print("🚀 Starting STDIO server")
|
safe_print("🚀 Starting STDIO server")
|
||||||
# Start minimal OAuth callback server for stdio mode
|
# Start minimal OAuth callback server for stdio mode
|
||||||
from auth.oauth_callback_server import ensure_oauth_callback_available
|
from auth.oauth_callback_server import ensure_oauth_callback_available
|
||||||
|
|||||||
Reference in New Issue
Block a user