apply ruff formatting

This commit is contained in:
Taylor Wilsdon
2025-12-13 13:49:28 -08:00
parent 1d80a24ca4
commit 6b8352a354
50 changed files with 4010 additions and 2842 deletions

View File

@@ -17,13 +17,18 @@ from fastapi.responses import FileResponse, JSONResponse
from typing import Optional
from urllib.parse import urlparse
from auth.scopes import SCOPES, get_current_scopes # noqa
from auth.oauth_responses import create_error_response, create_success_response, create_server_error_response
from auth.scopes import SCOPES, get_current_scopes # noqa
from auth.oauth_responses import (
create_error_response,
create_success_response,
create_server_error_response,
)
from auth.google_auth import handle_auth_callback, check_client_secrets
from auth.oauth_config import get_oauth_redirect_uri
logger = logging.getLogger(__name__)
class MinimalOAuthServer:
"""
Minimal HTTP server for OAuth callbacks in stdio mode.
@@ -59,7 +64,9 @@ class MinimalOAuthServer:
return create_error_response(error_message)
if not code:
error_message = "Authentication failed: No authorization code received from Google."
error_message = (
"Authentication failed: No authorization code received from Google."
)
logger.error(error_message)
return create_error_response(error_message)
@@ -69,7 +76,9 @@ class MinimalOAuthServer:
if error_message:
return create_server_error_response(error_message)
logger.info(f"OAuth callback: Received code (state: {state}). Attempting to exchange for tokens.")
logger.info(
f"OAuth callback: Received code (state: {state}). Attempting to exchange for tokens."
)
# Session ID tracking removed - not needed
@@ -79,16 +88,20 @@ class MinimalOAuthServer:
scopes=get_current_scopes(),
authorization_response=str(request.url),
redirect_uri=redirect_uri,
session_id=None
session_id=None,
)
logger.info(f"OAuth callback: Successfully authenticated user: {verified_user_id} (state: {state}).")
logger.info(
f"OAuth callback: Successfully authenticated user: {verified_user_id} (state: {state})."
)
# Return success page using shared template
return create_success_response(verified_user_id)
except Exception as e:
error_message_detail = f"Error processing OAuth callback (state: {state}): {str(e)}"
error_message_detail = (
f"Error processing OAuth callback (state: {state}): {str(e)}"
)
logger.error(error_message_detail, exc_info=True)
return create_server_error_response(str(e))
@@ -101,24 +114,22 @@ class MinimalOAuthServer:
"""Serve a stored attachment file."""
storage = get_attachment_storage()
metadata = storage.get_attachment_metadata(file_id)
if not metadata:
return JSONResponse(
{"error": "Attachment not found or expired"},
status_code=404
{"error": "Attachment not found or expired"}, status_code=404
)
file_path = storage.get_attachment_path(file_id)
if not file_path:
return JSONResponse(
{"error": "Attachment file not found"},
status_code=404
{"error": "Attachment file not found"}, status_code=404
)
return FileResponse(
path=str(file_path),
filename=metadata["filename"],
media_type=metadata["mime_type"]
media_type=metadata["mime_type"],
)
def start(self) -> tuple[bool, str]:
@@ -136,9 +147,9 @@ class MinimalOAuthServer:
# Extract hostname from base_uri (e.g., "http://localhost" -> "localhost")
try:
parsed_uri = urlparse(self.base_uri)
hostname = parsed_uri.hostname or 'localhost'
hostname = parsed_uri.hostname or "localhost"
except Exception:
hostname = 'localhost'
hostname = "localhost"
try:
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
@@ -156,7 +167,7 @@ class MinimalOAuthServer:
host=hostname,
port=self.port,
log_level="warning",
access_log=False
access_log=False,
)
self.server = uvicorn.Server(config)
asyncio.run(self.server.serve())
@@ -178,7 +189,9 @@ class MinimalOAuthServer:
result = s.connect_ex((hostname, self.port))
if result == 0:
self.is_running = True
logger.info(f"Minimal OAuth server started on {hostname}:{self.port}")
logger.info(
f"Minimal OAuth server started on {hostname}:{self.port}"
)
return True, ""
except Exception:
pass
@@ -195,7 +208,7 @@ class MinimalOAuthServer:
try:
if self.server:
if hasattr(self.server, 'should_exit'):
if hasattr(self.server, "should_exit"):
self.server.should_exit = True
if self.server_thread and self.server_thread.is_alive():
@@ -211,7 +224,10 @@ class MinimalOAuthServer:
# Global instance for stdio mode
_minimal_oauth_server: Optional[MinimalOAuthServer] = None
def ensure_oauth_callback_available(transport_mode: str = "stdio", port: int = 8000, base_uri: str = "http://localhost") -> tuple[bool, str]:
def ensure_oauth_callback_available(
transport_mode: str = "stdio", port: int = 8000, base_uri: str = "http://localhost"
) -> tuple[bool, str]:
"""
Ensure OAuth callback endpoint is available for the given transport mode.
@@ -230,7 +246,9 @@ def ensure_oauth_callback_available(transport_mode: str = "stdio", port: int = 8
if transport_mode == "streamable-http":
# In streamable-http mode, the main FastAPI server should handle callbacks
logger.debug("Using existing FastAPI server for OAuth callbacks (streamable-http mode)")
logger.debug(
"Using existing FastAPI server for OAuth callbacks (streamable-http mode)"
)
return True, ""
elif transport_mode == "stdio":
@@ -243,10 +261,14 @@ def ensure_oauth_callback_available(transport_mode: str = "stdio", port: int = 8
logger.info("Starting minimal OAuth server for stdio mode")
success, error_msg = _minimal_oauth_server.start()
if success:
logger.info(f"Minimal OAuth server successfully started on {base_uri}:{port}")
logger.info(
f"Minimal OAuth server successfully started on {base_uri}:{port}"
)
return True, ""
else:
logger.error(f"Failed to start minimal OAuth server on {base_uri}:{port}: {error_msg}")
logger.error(
f"Failed to start minimal OAuth server on {base_uri}:{port}: {error_msg}"
)
return False, error_msg
else:
logger.info("Minimal OAuth server is already running")
@@ -257,6 +279,7 @@ def ensure_oauth_callback_available(transport_mode: str = "stdio", port: int = 8
logger.error(error_msg)
return False, error_msg
def cleanup_oauth_callback_server():
"""Clean up the minimal OAuth server if it was started."""
global _minimal_oauth_server