refac decorator, add configure_logt_formatting helper, fixed variable scope & pep8
This commit is contained in:
@@ -17,7 +17,7 @@ from googleapiclient.errors import HttpError
|
|||||||
from auth.scopes import SCOPES
|
from auth.scopes import SCOPES
|
||||||
from auth.oauth21_session_store import get_oauth21_session_store
|
from auth.oauth21_session_store import get_oauth21_session_store
|
||||||
from auth.credential_store import get_credential_store
|
from auth.credential_store import get_credential_store
|
||||||
from auth.oauth_config import get_oauth_config
|
from auth.oauth_config import get_oauth_config, is_stateless_mode
|
||||||
from core.config import (
|
from core.config import (
|
||||||
get_transport_mode,
|
get_transport_mode,
|
||||||
get_oauth_redirect_uri,
|
get_oauth_redirect_uri,
|
||||||
@@ -603,7 +603,6 @@ def get_credentials(
|
|||||||
)
|
)
|
||||||
|
|
||||||
if not credentials and user_google_email:
|
if not credentials and user_google_email:
|
||||||
from auth.oauth_config import is_stateless_mode
|
|
||||||
if not is_stateless_mode():
|
if not is_stateless_mode():
|
||||||
logger.debug(
|
logger.debug(
|
||||||
f"[get_credentials] No session credentials, trying credential store for user_google_email '{user_google_email}'."
|
f"[get_credentials] No session credentials, trying credential store for user_google_email '{user_google_email}'."
|
||||||
@@ -669,7 +668,6 @@ def get_credentials(
|
|||||||
|
|
||||||
# Save refreshed credentials (skip file save in stateless mode)
|
# Save refreshed credentials (skip file save in stateless mode)
|
||||||
if user_google_email: # Always save to credential store if email is known
|
if user_google_email: # Always save to credential store if email is known
|
||||||
from auth.oauth_config import is_stateless_mode
|
|
||||||
if not is_stateless_mode():
|
if not is_stateless_mode():
|
||||||
credential_store = get_credential_store()
|
credential_store = get_credential_store()
|
||||||
credential_store.store_credential(user_google_email, credentials)
|
credential_store.store_credential(user_google_email, credentials)
|
||||||
|
|||||||
@@ -16,7 +16,7 @@ from google.oauth2.credentials import Credentials
|
|||||||
from auth.oauth21_session_store import store_token_session
|
from auth.oauth21_session_store import store_token_session
|
||||||
from auth.google_auth import get_credential_store
|
from auth.google_auth import get_credential_store
|
||||||
from auth.scopes import get_current_scopes
|
from auth.scopes import get_current_scopes
|
||||||
from auth.oauth_config import get_oauth_config
|
from auth.oauth_config import get_oauth_config, is_stateless_mode
|
||||||
from auth.oauth_error_handling import (
|
from auth.oauth_error_handling import (
|
||||||
OAuthError, OAuthValidationError, OAuthConfigurationError,
|
OAuthError, OAuthValidationError, OAuthConfigurationError,
|
||||||
create_oauth_error_response, validate_token_request,
|
create_oauth_error_response, validate_token_request,
|
||||||
@@ -181,7 +181,6 @@ async def handle_proxy_token_exchange(request: Request):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Save credentials to file for legacy auth (skip in stateless mode)
|
# Save credentials to file for legacy auth (skip in stateless mode)
|
||||||
from auth.oauth_config import is_stateless_mode
|
|
||||||
if not is_stateless_mode():
|
if not is_stateless_mode():
|
||||||
store = get_credential_store()
|
store = get_credential_store()
|
||||||
if not store.store_credential(user_email, credentials):
|
if not store.store_credential(user_email, credentials):
|
||||||
|
|||||||
@@ -235,6 +235,95 @@ async def get_authenticated_google_service_oauth21(
|
|||||||
return service, user_google_email
|
return service, user_google_email
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_oauth21_user_email(authenticated_user: Optional[str], func_name: str) -> str:
|
||||||
|
"""
|
||||||
|
Extract user email for OAuth 2.1 mode.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
authenticated_user: The authenticated user from context
|
||||||
|
func_name: Name of the function being decorated (for error messages)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
User email string
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
Exception: If no authenticated user found in OAuth 2.1 mode
|
||||||
|
"""
|
||||||
|
if not authenticated_user:
|
||||||
|
raise Exception(
|
||||||
|
f"OAuth 2.1 mode requires an authenticated user for {func_name}, but none was found."
|
||||||
|
)
|
||||||
|
return authenticated_user
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_oauth20_user_email(
|
||||||
|
args: tuple,
|
||||||
|
kwargs: dict,
|
||||||
|
wrapper_sig: inspect.Signature
|
||||||
|
) -> str:
|
||||||
|
"""
|
||||||
|
Extract user email for OAuth 2.0 mode from function arguments.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
args: Positional arguments passed to wrapper
|
||||||
|
kwargs: Keyword arguments passed to wrapper
|
||||||
|
wrapper_sig: Function signature for parameter binding
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
User email string
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
Exception: If user_google_email parameter not found
|
||||||
|
"""
|
||||||
|
bound_args = wrapper_sig.bind(*args, **kwargs)
|
||||||
|
bound_args.apply_defaults()
|
||||||
|
|
||||||
|
user_google_email = bound_args.arguments.get("user_google_email")
|
||||||
|
if not user_google_email:
|
||||||
|
raise Exception(
|
||||||
|
"'user_google_email' parameter is required but was not found."
|
||||||
|
)
|
||||||
|
return user_google_email
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_oauth20_user_email_multiple_services(
|
||||||
|
args: tuple,
|
||||||
|
kwargs: dict,
|
||||||
|
original_sig: inspect.Signature
|
||||||
|
) -> str:
|
||||||
|
"""
|
||||||
|
Extract user email for OAuth 2.0 mode from function arguments (multiple services version).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
args: Positional arguments passed to wrapper
|
||||||
|
kwargs: Keyword arguments passed to wrapper
|
||||||
|
original_sig: Original function signature for parameter extraction
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
User email string
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
Exception: If user_google_email parameter not found
|
||||||
|
"""
|
||||||
|
param_names = list(original_sig.parameters.keys())
|
||||||
|
user_google_email = None
|
||||||
|
|
||||||
|
if "user_google_email" in kwargs:
|
||||||
|
user_google_email = kwargs["user_google_email"]
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
user_email_index = param_names.index("user_google_email")
|
||||||
|
if user_email_index < len(args):
|
||||||
|
user_google_email = args[user_email_index]
|
||||||
|
except ValueError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
if not user_google_email:
|
||||||
|
raise Exception("user_google_email parameter is required but not found")
|
||||||
|
|
||||||
|
return user_google_email
|
||||||
|
|
||||||
|
|
||||||
def _remove_user_email_arg_from_docstring(docstring: str) -> str:
|
def _remove_user_email_arg_from_docstring(docstring: str) -> str:
|
||||||
"""
|
"""
|
||||||
Remove user_google_email parameter documentation from docstring.
|
Remove user_google_email parameter documentation from docstring.
|
||||||
@@ -437,30 +526,16 @@ def require_google_service(
|
|||||||
# Note: `args` and `kwargs` are now the arguments for the *wrapper*,
|
# Note: `args` and `kwargs` are now the arguments for the *wrapper*,
|
||||||
# which does not include 'service'.
|
# which does not include 'service'.
|
||||||
|
|
||||||
# Extract user_google_email from the arguments passed to the wrapper
|
|
||||||
bound_args = wrapper_sig.bind(*args, **kwargs)
|
|
||||||
bound_args.apply_defaults()
|
|
||||||
|
|
||||||
# Get authentication context early to determine OAuth mode
|
# Get authentication context early to determine OAuth mode
|
||||||
authenticated_user, auth_method, mcp_session_id = _get_auth_context(
|
authenticated_user, auth_method, mcp_session_id = _get_auth_context(
|
||||||
func.__name__
|
func.__name__
|
||||||
)
|
)
|
||||||
|
|
||||||
# In OAuth 2.1 mode, user_google_email is not in the signature
|
# Extract user_google_email based on OAuth mode
|
||||||
# and we use the authenticated user directly
|
|
||||||
if is_oauth21_enabled():
|
if is_oauth21_enabled():
|
||||||
if not authenticated_user:
|
user_google_email = _extract_oauth21_user_email(authenticated_user, func.__name__)
|
||||||
raise Exception(
|
|
||||||
"OAuth 2.1 mode requires an authenticated user, but none was found."
|
|
||||||
)
|
|
||||||
user_google_email = authenticated_user
|
|
||||||
else:
|
else:
|
||||||
# OAuth 2.0 mode: extract from arguments
|
user_google_email = _extract_oauth20_user_email(args, kwargs, wrapper_sig)
|
||||||
user_google_email = bound_args.arguments.get("user_google_email")
|
|
||||||
if not user_google_email:
|
|
||||||
raise Exception(
|
|
||||||
"'user_google_email' parameter is required but was not found."
|
|
||||||
)
|
|
||||||
|
|
||||||
# Get service configuration from the decorator's arguments
|
# Get service configuration from the decorator's arguments
|
||||||
if service_type not in SERVICE_CONFIGS:
|
if service_type not in SERVICE_CONFIGS:
|
||||||
@@ -589,27 +664,9 @@ def require_multiple_services(service_configs: List[Dict[str, Any]]):
|
|||||||
|
|
||||||
# Extract user_google_email based on OAuth mode
|
# Extract user_google_email based on OAuth mode
|
||||||
if is_oauth21_enabled():
|
if is_oauth21_enabled():
|
||||||
if not authenticated_user:
|
user_google_email = _extract_oauth21_user_email(authenticated_user, tool_name)
|
||||||
raise Exception(
|
|
||||||
"OAuth 2.1 mode requires an authenticated user, but none was found."
|
|
||||||
)
|
|
||||||
user_google_email = authenticated_user
|
|
||||||
else:
|
else:
|
||||||
# OAuth 2.0 mode: extract from arguments
|
user_google_email = _extract_oauth20_user_email_multiple_services(args, kwargs, original_sig)
|
||||||
param_names = list(original_sig.parameters.keys())
|
|
||||||
user_google_email = None
|
|
||||||
if "user_google_email" in kwargs:
|
|
||||||
user_google_email = kwargs["user_google_email"]
|
|
||||||
else:
|
|
||||||
try:
|
|
||||||
user_email_index = param_names.index("user_google_email")
|
|
||||||
if user_email_index < len(args):
|
|
||||||
user_google_email = args[user_email_index]
|
|
||||||
except ValueError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
if not user_google_email:
|
|
||||||
raise Exception("user_google_email parameter is required but not found")
|
|
||||||
|
|
||||||
# Authenticate all services
|
# Authenticate all services
|
||||||
for config in service_configs:
|
for config in service_configs:
|
||||||
|
|||||||
@@ -5,7 +5,9 @@ Provides visually appealing log formatting with emojis and consistent styling
|
|||||||
to match the safe_print output format.
|
to match the safe_print output format.
|
||||||
"""
|
"""
|
||||||
import logging
|
import logging
|
||||||
|
import os
|
||||||
import re
|
import re
|
||||||
|
import sys
|
||||||
|
|
||||||
|
|
||||||
class EnhancedLogFormatter(logging.Formatter):
|
class EnhancedLogFormatter(logging.Formatter):
|
||||||
@@ -139,4 +141,52 @@ def setup_enhanced_logging(log_level: int = logging.INFO, use_colors: bool = Tru
|
|||||||
console_handler = logging.StreamHandler()
|
console_handler = logging.StreamHandler()
|
||||||
console_handler.setFormatter(formatter)
|
console_handler.setFormatter(formatter)
|
||||||
console_handler.setLevel(log_level)
|
console_handler.setLevel(log_level)
|
||||||
root_logger.addHandler(console_handler)
|
root_logger.addHandler(console_handler)
|
||||||
|
|
||||||
|
|
||||||
|
def configure_file_logging(logger_name: str = None) -> bool:
|
||||||
|
"""
|
||||||
|
Configure file logging based on stateless mode setting.
|
||||||
|
|
||||||
|
In stateless mode, file logging is completely disabled to avoid filesystem writes.
|
||||||
|
In normal mode, sets up detailed file logging to 'mcp_server_debug.log'.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
logger_name: Optional name for the logger (defaults to root logger)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: True if file logging was configured, False if skipped (stateless mode)
|
||||||
|
"""
|
||||||
|
# Check if stateless mode is enabled
|
||||||
|
stateless_mode = os.getenv("WORKSPACE_MCP_STATELESS_MODE", "false").lower() == "true"
|
||||||
|
|
||||||
|
if stateless_mode:
|
||||||
|
logger = logging.getLogger(logger_name)
|
||||||
|
logger.debug("File logging disabled in stateless mode")
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Configure file logging for normal mode
|
||||||
|
try:
|
||||||
|
target_logger = logging.getLogger(logger_name)
|
||||||
|
log_file_dir = os.path.dirname(os.path.abspath(__file__))
|
||||||
|
# Go up one level since we're in core/ subdirectory
|
||||||
|
log_file_dir = os.path.dirname(log_file_dir)
|
||||||
|
log_file_path = os.path.join(log_file_dir, 'mcp_server_debug.log')
|
||||||
|
|
||||||
|
file_handler = logging.FileHandler(log_file_path, mode='a')
|
||||||
|
file_handler.setLevel(logging.DEBUG)
|
||||||
|
|
||||||
|
file_formatter = logging.Formatter(
|
||||||
|
'%(asctime)s - %(name)s - %(levelname)s - %(process)d - %(threadName)s '
|
||||||
|
'[%(module)s.%(funcName)s:%(lineno)d] - %(message)s'
|
||||||
|
)
|
||||||
|
file_handler.setFormatter(file_formatter)
|
||||||
|
target_logger.addHandler(file_handler)
|
||||||
|
|
||||||
|
logger = logging.getLogger(logger_name)
|
||||||
|
logger.debug(f"Detailed file logging configured to: {log_file_path}")
|
||||||
|
return True
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
sys.stderr.write(f"CRITICAL: Failed to set up file logging to '{log_file_path}': {e}\n")
|
||||||
|
return False
|
||||||
@@ -33,30 +33,9 @@ logging.basicConfig(
|
|||||||
)
|
)
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
# Skip file logging in stateless mode
|
# Configure file logging based on stateless mode
|
||||||
stateless_mode = os.getenv("WORKSPACE_MCP_STATELESS_MODE", "false").lower() == "true"
|
from core.log_formatter import configure_file_logging
|
||||||
if not stateless_mode:
|
configure_file_logging()
|
||||||
# Configure file logging
|
|
||||||
try:
|
|
||||||
root_logger = logging.getLogger()
|
|
||||||
log_file_dir = os.path.dirname(os.path.abspath(__file__))
|
|
||||||
log_file_path = os.path.join(log_file_dir, 'mcp_server_debug.log')
|
|
||||||
|
|
||||||
file_handler = logging.FileHandler(log_file_path, mode='a')
|
|
||||||
file_handler.setLevel(logging.DEBUG)
|
|
||||||
|
|
||||||
file_formatter = logging.Formatter(
|
|
||||||
'%(asctime)s - %(name)s - %(levelname)s - %(process)d - %(threadName)s '
|
|
||||||
'[%(module)s.%(funcName)s:%(lineno)d] - %(message)s'
|
|
||||||
)
|
|
||||||
file_handler.setFormatter(file_formatter)
|
|
||||||
root_logger.addHandler(file_handler)
|
|
||||||
|
|
||||||
logger.debug(f"Detailed file logging configured to: {log_file_path}")
|
|
||||||
except Exception as e:
|
|
||||||
sys.stderr.write(f"CRITICAL: Failed to set up file logging to '{log_file_path}': {e}\n")
|
|
||||||
else:
|
|
||||||
logger.debug("File logging disabled in stateless mode")
|
|
||||||
|
|
||||||
def configure_safe_logging():
|
def configure_safe_logging():
|
||||||
"""Configure safe Unicode handling for logging."""
|
"""Configure safe Unicode handling for logging."""
|
||||||
|
|||||||
27
main.py
27
main.py
@@ -27,29 +27,12 @@ logging.basicConfig(
|
|||||||
)
|
)
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
# Skip file logging in stateless mode
|
# Configure file logging based on stateless mode
|
||||||
|
from core.log_formatter import configure_file_logging
|
||||||
|
configure_file_logging()
|
||||||
|
|
||||||
|
# Define stateless_mode for use in main() function
|
||||||
stateless_mode = os.getenv("WORKSPACE_MCP_STATELESS_MODE", "false").lower() == "true"
|
stateless_mode = os.getenv("WORKSPACE_MCP_STATELESS_MODE", "false").lower() == "true"
|
||||||
if not stateless_mode:
|
|
||||||
try:
|
|
||||||
root_logger = logging.getLogger()
|
|
||||||
log_file_dir = os.path.dirname(os.path.abspath(__file__))
|
|
||||||
log_file_path = os.path.join(log_file_dir, 'mcp_server_debug.log')
|
|
||||||
|
|
||||||
file_handler = logging.FileHandler(log_file_path, mode='a')
|
|
||||||
file_handler.setLevel(logging.DEBUG)
|
|
||||||
|
|
||||||
file_formatter = logging.Formatter(
|
|
||||||
'%(asctime)s - %(name)s - %(levelname)s - %(process)d - %(threadName)s '
|
|
||||||
'[%(module)s.%(funcName)s:%(lineno)d] - %(message)s'
|
|
||||||
)
|
|
||||||
file_handler.setFormatter(file_formatter)
|
|
||||||
root_logger.addHandler(file_handler)
|
|
||||||
|
|
||||||
logger.debug(f"Detailed file logging configured to: {log_file_path}")
|
|
||||||
except Exception as e:
|
|
||||||
sys.stderr.write(f"CRITICAL: Failed to set up file logging to '{log_file_path}': {e}\n")
|
|
||||||
else:
|
|
||||||
logger.debug("File logging disabled in stateless mode")
|
|
||||||
|
|
||||||
def safe_print(text):
|
def safe_print(text):
|
||||||
# Don't print to stderr when running as MCP server via uvx to avoid JSON parsing errors
|
# Don't print to stderr when running as MCP server via uvx to avoid JSON parsing errors
|
||||||
|
|||||||
Reference in New Issue
Block a user