feat: abstract credential store

This commit is contained in:
Shawn Zhu
2025-08-10 18:11:27 -04:00
parent c719b2573d
commit c505a4782b
3 changed files with 301 additions and 135 deletions

View File

@@ -17,6 +17,7 @@ from googleapiclient.discovery import build
from googleapiclient.errors import HttpError
from auth.scopes import SCOPES
from auth.oauth21_session_store import get_oauth21_session_store
from auth.credential_store import get_credential_store
from core.config import (
WORKSPACE_MCP_PORT,
WORKSPACE_MCP_BASE_URI,
@@ -81,72 +82,35 @@ def _find_any_credentials(
Returns:
First valid Credentials object found, or None if none exist.
"""
if not os.path.exists(base_dir):
logger.info(f"[single-user] Credentials directory not found: {base_dir}")
return None
# Scan for any .json credential files
for filename in os.listdir(base_dir):
if filename.endswith(".json"):
filepath = os.path.join(base_dir, filename)
try:
with open(filepath, "r") as f:
creds_data = json.load(f)
credentials = Credentials(
token=creds_data.get("token"),
refresh_token=creds_data.get("refresh_token"),
token_uri=creds_data.get("token_uri"),
client_id=creds_data.get("client_id"),
client_secret=creds_data.get("client_secret"),
scopes=creds_data.get("scopes"),
)
logger.info(f"[single-user] Found credentials in {filepath}")
return credentials
except (IOError, json.JSONDecodeError, KeyError) as e:
logger.warning(
f"[single-user] Error loading credentials from {filepath}: {e}"
)
continue
logger.info(f"[single-user] No valid credentials found in {base_dir}")
return None
def _get_user_credential_path(
user_google_email: str, base_dir: str = DEFAULT_CREDENTIALS_DIR
) -> str:
"""Constructs the path to a user's credential file."""
if not os.path.exists(base_dir):
os.makedirs(base_dir)
logger.info(f"Created credentials directory: {base_dir}")
return os.path.join(base_dir, f"{user_google_email}.json")
def save_credentials_to_file(
user_google_email: str,
credentials: Credentials,
base_dir: str = DEFAULT_CREDENTIALS_DIR,
):
"""Saves user credentials to a file."""
creds_path = _get_user_credential_path(user_google_email, base_dir)
creds_data = {
"token": credentials.token,
"refresh_token": credentials.refresh_token,
"token_uri": credentials.token_uri,
"client_id": credentials.client_id,
"client_secret": credentials.client_secret,
"scopes": credentials.scopes,
"expiry": credentials.expiry.isoformat() if credentials.expiry else None,
}
try:
with open(creds_path, "w") as f:
json.dump(creds_data, f)
logger.info(f"Credentials saved for user {user_google_email} to {creds_path}")
except IOError as e:
store = get_credential_store()
users = store.list_users()
if not users:
logger.info(
"[single-user] No users found with credentials via credential store"
)
return None
# Return credentials for the first user found
first_user = users[0]
credentials = store.get_credential(first_user)
if credentials:
logger.info(
f"[single-user] Found credentials for {first_user} via credential store"
)
return credentials
else:
logger.warning(
f"[single-user] Could not load credentials for {first_user} via credential store"
)
except Exception as e:
logger.error(
f"Error saving credentials for user {user_google_email} to {creds_path}: {e}"
f"[single-user] Error finding credentials via credential store: {e}"
)
raise
logger.info("[single-user] No valid credentials found via credential store")
return None
def save_credentials_to_session(session_id: str, credentials: Credentials):
@@ -161,7 +125,7 @@ def save_credentials_to_session(session_id: str, credentials: Credentials):
user_email = decoded_token.get("email")
except Exception as e:
logger.debug(f"Could not decode id_token to get email: {e}")
if user_email:
store = get_oauth21_session_store()
store.store_session(
@@ -180,54 +144,6 @@ def save_credentials_to_session(session_id: str, credentials: Credentials):
logger.warning(f"Could not save credentials to session store - no user email found for session: {session_id}")
def load_credentials_from_file(
user_google_email: str, base_dir: str = DEFAULT_CREDENTIALS_DIR
) -> Optional[Credentials]:
"""Loads user credentials from a file."""
creds_path = _get_user_credential_path(user_google_email, base_dir)
if not os.path.exists(creds_path):
logger.info(
f"No credentials file found for user {user_google_email} at {creds_path}"
)
return None
try:
with open(creds_path, "r") as f:
creds_data = json.load(f)
# Parse expiry if present
expiry = None
if creds_data.get("expiry"):
try:
expiry = datetime.fromisoformat(creds_data["expiry"])
# Ensure timezone-naive datetime for Google auth library compatibility
if expiry.tzinfo is not None:
expiry = expiry.replace(tzinfo=None)
except (ValueError, TypeError) as e:
logger.warning(
f"Could not parse expiry time for {user_google_email}: {e}"
)
credentials = Credentials(
token=creds_data.get("token"),
refresh_token=creds_data.get("refresh_token"),
token_uri=creds_data.get("token_uri"),
client_id=creds_data.get("client_id"),
client_secret=creds_data.get("client_secret"),
scopes=creds_data.get("scopes"),
expiry=expiry,
)
logger.debug(
f"Credentials loaded for user {user_google_email} from {creds_path}"
)
return credentials
except (IOError, json.JSONDecodeError, KeyError) as e:
logger.error(
f"Error loading or parsing credentials for user {user_google_email} from {creds_path}: {e}"
)
return None
def load_credentials_from_session(session_id: str) -> Optional[Credentials]:
"""Loads user credentials from OAuth21SessionStore."""
store = get_oauth21_session_store()
@@ -547,8 +463,9 @@ def handle_auth_callback(
user_google_email = user_info["email"]
logger.info(f"Identified user_google_email: {user_google_email}")
# Save the credentials to file
save_credentials_to_file(user_google_email, credentials, credentials_base_dir)
# Save the credentials
credential_store = get_credential_store()
credential_store.store_credential(user_google_email, credentials)
# Always save to OAuth21SessionStore for centralized management
store = get_oauth21_session_store()
@@ -689,11 +606,11 @@ def get_credentials(
if not credentials and user_google_email:
logger.debug(
f"[get_credentials] No session credentials, trying file for user_google_email '{user_google_email}'."
)
credentials = load_credentials_from_file(
user_google_email, credentials_base_dir
f"[get_credentials] No session credentials, trying credential store for user_google_email '{user_google_email}'."
)
store = get_credential_store()
credentials = store.get_credential(user_google_email)
if credentials and session_id:
logger.debug(
f"[get_credentials] Loaded from file for user '{user_google_email}', caching to session '{session_id}'."
@@ -747,11 +664,10 @@ def get_credentials(
)
# Save refreshed credentials
if user_google_email: # Always save to file if email is known
save_credentials_to_file(
user_google_email, credentials, credentials_base_dir
)
if user_google_email: # Always save to credential store if email is known
credential_store = get_credential_store()
credential_store.store_credential(user_google_email, credentials)
# Also update OAuth21SessionStore
store = get_oauth21_session_store()
store.store_session(
@@ -766,7 +682,7 @@ def get_credentials(
mcp_session_id=session_id,
issuer="https://accounts.google.com" # Add issuer for Google tokens
)
if session_id: # Update session cache if it was the source or is active
save_credentials_to_session(session_id, credentials)
return credentials
@@ -857,7 +773,9 @@ async def get_authenticated_google_service(
else:
logger.debug(f"[{tool_name}] Context variable returned None/empty session ID")
except Exception as e:
logger.debug(f"[{tool_name}] Could not get FastMCP session from context: {e}")
logger.debug(
f"[{tool_name}] Could not get FastMCP session from context: {e}"
)
# Fallback to direct FastMCP context if context variable not set
if not session_id and get_fastmcp_context:
@@ -894,20 +812,21 @@ async def get_authenticated_google_service(
)
if not credentials or not credentials.valid:
logger.warning(
f"[{tool_name}] No valid credentials. Email: '{user_google_email}'."
)
logger.info(
f"[{tool_name}] Valid email '{user_google_email}' provided, initiating auth flow."
)
logger.warning(f"[{tool_name}] No valid credentials. Email: '{user_google_email}'.")
logger.info(f"[{tool_name}] Valid email '{user_google_email}' provided, initiating auth flow.")
# Ensure OAuth callback is available
from auth.oauth_callback_server import ensure_oauth_callback_available
redirect_uri = get_oauth_redirect_uri()
success, error_msg = ensure_oauth_callback_available(get_transport_mode(), WORKSPACE_MCP_PORT, WORKSPACE_MCP_BASE_URI)
success, error_msg = ensure_oauth_callback_available(
get_transport_mode(), WORKSPACE_MCP_PORT, WORKSPACE_MCP_BASE_URI
)
if not success:
error_detail = f" ({error_msg})" if error_msg else ""
raise GoogleAuthenticationError(f"Cannot initiate OAuth flow - callback server unavailable{error_detail}")
raise GoogleAuthenticationError(
f"Cannot initiate OAuth flow - callback server unavailable{error_detail}"
)
# Generate auth URL and raise exception with it
auth_response = await start_auth_flow(