simplify cors middleware implementation

This commit is contained in:
Taylor Wilsdon
2025-08-09 11:44:12 -04:00
parent 42dcd54b18
commit 374dc9c3e7
6 changed files with 77 additions and 512 deletions

View File

@@ -42,13 +42,13 @@ class OAuthConfigurationError(OAuthError):
super().__init__("server_error", description, 500)
def create_oauth_error_response(error: OAuthError, cors_headers: Optional[Dict[str, str]] = None) -> JSONResponse:
def create_oauth_error_response(error: OAuthError, origin: Optional[str] = None) -> JSONResponse:
"""
Create a standardized OAuth error response.
Args:
error: The OAuth error to convert to a response
cors_headers: Optional CORS headers to include
origin: Optional origin for development CORS headers
Returns:
JSONResponse with standardized error format
@@ -58,8 +58,9 @@ def create_oauth_error_response(error: OAuthError, cors_headers: Optional[Dict[s
"Cache-Control": "no-store"
}
if cors_headers:
headers.update(cors_headers)
# Add development CORS headers if needed
cors_headers = get_development_cors_headers(origin)
headers.update(cors_headers)
content = {
"error": error.error_code,
@@ -296,46 +297,25 @@ def log_security_event(event_type: str, details: Dict[str, Any], request: Option
logger.warning(f"Security event: {log_data}")
def get_safe_cors_headers(origin: Optional[str] = None) -> Dict[str, str]:
def get_development_cors_headers(origin: Optional[str] = None) -> Dict[str, str]:
"""
Get safe CORS headers for error responses.
Get minimal CORS headers for development scenarios only.
Only allows localhost origins for development tools and inspectors.
Args:
origin: The request origin (will be validated)
Returns:
Safe CORS headers
CORS headers for localhost origins only, empty dict otherwise
"""
headers = {
"Access-Control-Allow-Methods": "GET, POST, OPTIONS",
"Access-Control-Allow-Headers": "Content-Type, Authorization, mcp-protocol-version, x-requested-with",
"Access-Control-Max-Age": "3600"
}
# Only allow localhost origins for development
if origin and (origin.startswith("http://localhost:") or origin.startswith("http://127.0.0.1:")):
return {
"Access-Control-Allow-Origin": origin,
"Access-Control-Allow-Methods": "GET, POST, OPTIONS",
"Access-Control-Allow-Headers": "Content-Type, Authorization",
"Access-Control-Max-Age": "3600"
}
# Only allow specific origins, not wildcards
if origin and _is_safe_origin(origin):
headers["Access-Control-Allow-Origin"] = origin
headers["Access-Control-Allow-Credentials"] = "true"
return headers
def _is_safe_origin(origin: str) -> bool:
"""
Check if an origin is safe for CORS.
Args:
origin: The origin to check
Returns:
True if the origin is safe
"""
# Always allow localhost origins for development
if origin.startswith("http://localhost:") or origin.startswith("http://127.0.0.1:"):
return True
from auth.oauth_config import get_oauth_config
config = get_oauth_config()
allowed_origins = config.get_allowed_origins()
return origin in allowed_origins or origin.startswith("vscode-webview://") or origin == "null"
return {}