simplify cors middleware implementation
This commit is contained in:
@@ -42,13 +42,13 @@ class OAuthConfigurationError(OAuthError):
|
||||
super().__init__("server_error", description, 500)
|
||||
|
||||
|
||||
def create_oauth_error_response(error: OAuthError, cors_headers: Optional[Dict[str, str]] = None) -> JSONResponse:
|
||||
def create_oauth_error_response(error: OAuthError, origin: Optional[str] = None) -> JSONResponse:
|
||||
"""
|
||||
Create a standardized OAuth error response.
|
||||
|
||||
Args:
|
||||
error: The OAuth error to convert to a response
|
||||
cors_headers: Optional CORS headers to include
|
||||
origin: Optional origin for development CORS headers
|
||||
|
||||
Returns:
|
||||
JSONResponse with standardized error format
|
||||
@@ -58,8 +58,9 @@ def create_oauth_error_response(error: OAuthError, cors_headers: Optional[Dict[s
|
||||
"Cache-Control": "no-store"
|
||||
}
|
||||
|
||||
if cors_headers:
|
||||
headers.update(cors_headers)
|
||||
# Add development CORS headers if needed
|
||||
cors_headers = get_development_cors_headers(origin)
|
||||
headers.update(cors_headers)
|
||||
|
||||
content = {
|
||||
"error": error.error_code,
|
||||
@@ -296,46 +297,25 @@ def log_security_event(event_type: str, details: Dict[str, Any], request: Option
|
||||
logger.warning(f"Security event: {log_data}")
|
||||
|
||||
|
||||
def get_safe_cors_headers(origin: Optional[str] = None) -> Dict[str, str]:
|
||||
def get_development_cors_headers(origin: Optional[str] = None) -> Dict[str, str]:
|
||||
"""
|
||||
Get safe CORS headers for error responses.
|
||||
Get minimal CORS headers for development scenarios only.
|
||||
|
||||
Only allows localhost origins for development tools and inspectors.
|
||||
|
||||
Args:
|
||||
origin: The request origin (will be validated)
|
||||
|
||||
Returns:
|
||||
Safe CORS headers
|
||||
CORS headers for localhost origins only, empty dict otherwise
|
||||
"""
|
||||
headers = {
|
||||
"Access-Control-Allow-Methods": "GET, POST, OPTIONS",
|
||||
"Access-Control-Allow-Headers": "Content-Type, Authorization, mcp-protocol-version, x-requested-with",
|
||||
"Access-Control-Max-Age": "3600"
|
||||
}
|
||||
# Only allow localhost origins for development
|
||||
if origin and (origin.startswith("http://localhost:") or origin.startswith("http://127.0.0.1:")):
|
||||
return {
|
||||
"Access-Control-Allow-Origin": origin,
|
||||
"Access-Control-Allow-Methods": "GET, POST, OPTIONS",
|
||||
"Access-Control-Allow-Headers": "Content-Type, Authorization",
|
||||
"Access-Control-Max-Age": "3600"
|
||||
}
|
||||
|
||||
# Only allow specific origins, not wildcards
|
||||
if origin and _is_safe_origin(origin):
|
||||
headers["Access-Control-Allow-Origin"] = origin
|
||||
headers["Access-Control-Allow-Credentials"] = "true"
|
||||
|
||||
return headers
|
||||
|
||||
|
||||
def _is_safe_origin(origin: str) -> bool:
|
||||
"""
|
||||
Check if an origin is safe for CORS.
|
||||
|
||||
Args:
|
||||
origin: The origin to check
|
||||
|
||||
Returns:
|
||||
True if the origin is safe
|
||||
"""
|
||||
# Always allow localhost origins for development
|
||||
if origin.startswith("http://localhost:") or origin.startswith("http://127.0.0.1:"):
|
||||
return True
|
||||
|
||||
from auth.oauth_config import get_oauth_config
|
||||
config = get_oauth_config()
|
||||
allowed_origins = config.get_allowed_origins()
|
||||
|
||||
return origin in allowed_origins or origin.startswith("vscode-webview://") or origin == "null"
|
||||
return {}
|
||||
Reference in New Issue
Block a user