timezone awareness handling improvements and tasks fix

This commit is contained in:
Taylor Wilsdon
2025-10-18 13:01:43 -04:00
parent a995fa4fde
commit 33b41a59d8
3 changed files with 79 additions and 12 deletions

View File

@@ -18,6 +18,37 @@ from google.oauth2.credentials import Credentials
logger = logging.getLogger(__name__)
def _normalize_expiry_to_naive_utc(expiry: Optional[Any]) -> Optional[datetime]:
"""
Convert expiry values to timezone-naive UTC datetimes for google-auth compatibility.
google-auth Credentials expect naive UTC datetimes for expiry comparison.
"""
if expiry is None:
return None
if isinstance(expiry, datetime):
if expiry.tzinfo is not None:
try:
return expiry.astimezone(timezone.utc).replace(tzinfo=None)
except Exception: # pragma: no cover - defensive
logger.debug("Failed to normalize aware expiry; returning without tzinfo")
return expiry.replace(tzinfo=None)
return expiry
if isinstance(expiry, str):
try:
parsed = datetime.fromisoformat(expiry.replace("Z", "+00:00"))
except ValueError:
logger.debug("Failed to parse expiry string '%s'", expiry)
return None
return _normalize_expiry_to_naive_utc(parsed)
logger.debug("Unsupported expiry type '%s' (%s)", expiry, type(expiry))
return None
# Context variable to store the current session information
_current_session_context: contextvars.ContextVar[Optional['SessionContext']] = contextvars.ContextVar(
'current_session_context',
@@ -279,6 +310,7 @@ class OAuth21SessionStore:
issuer: Token issuer (e.g., "https://accounts.google.com")
"""
with self._lock:
normalized_expiry = _normalize_expiry_to_naive_utc(expiry)
session_info = {
"access_token": access_token,
"refresh_token": refresh_token,
@@ -286,7 +318,7 @@ class OAuth21SessionStore:
"client_id": client_id,
"client_secret": client_secret,
"scopes": scopes or [],
"expiry": expiry,
"expiry": normalized_expiry,
"session_id": session_id,
"mcp_session_id": mcp_session_id,
"issuer": issuer,
@@ -339,7 +371,7 @@ class OAuth21SessionStore:
client_id=session_info.get("client_id"),
client_secret=session_info.get("client_secret"),
scopes=session_info.get("scopes", []),
expiry=session_info.get("expiry"),
expiry=_normalize_expiry_to_naive_utc(session_info.get("expiry")),
)
logger.debug(f"Retrieved OAuth 2.1 credentials for {user_email}")
@@ -616,7 +648,8 @@ def _build_credentials_from_provider(access_token: AccessToken) -> Optional[Cred
expires_at = getattr(access_entry, "expires_at", None)
if expires_at:
try:
expiry = datetime.utcfromtimestamp(expires_at)
expiry_candidate = datetime.fromtimestamp(expires_at, tz=timezone.utc)
expiry = _normalize_expiry_to_naive_utc(expiry_candidate)
except Exception: # pragma: no cover - defensive
expiry = None
@@ -655,7 +688,7 @@ def ensure_session_from_access_token(
expires_at = getattr(access_token, "expires_at", None)
if expires_at:
try:
expiry = datetime.utcfromtimestamp(expires_at)
expiry = datetime.fromtimestamp(expires_at, tz=timezone.utc)
except Exception: # pragma: no cover - defensive
expiry = None
@@ -666,7 +699,7 @@ def ensure_session_from_access_token(
client_id=client_id,
client_secret=client_secret,
scopes=getattr(access_token, "scopes", None),
expiry=expiry,
expiry=_normalize_expiry_to_naive_utc(expiry),
)
if email:
@@ -721,7 +754,7 @@ def get_credentials_from_token(access_token: str, user_email: Optional[str] = No
# Otherwise, create minimal credentials with just the access token
# Assume token is valid for 1 hour (typical for Google tokens)
expiry = datetime.now(timezone.utc) + timedelta(hours=1)
expiry = _normalize_expiry_to_naive_utc(datetime.now(timezone.utc) + timedelta(hours=1))
client_id, client_secret = _resolve_client_credentials()
credentials = Credentials(
@@ -776,7 +809,9 @@ def store_token_session(token_response: dict, user_email: str, mcp_session_id: O
client_id, client_secret = _resolve_client_credentials()
scopes = token_response.get("scope", "")
scopes_list = scopes.split() if scopes else None
expiry = datetime.now(timezone.utc) + timedelta(seconds=token_response.get("expires_in", 3600))
expiry = _normalize_expiry_to_naive_utc(
datetime.now(timezone.utc) + timedelta(seconds=token_response.get("expires_in", 3600))
)
store.store_session(
user_email=user_email,