From 66011fd8144d2ba61445f901c9745f2e04b7fbd4 Mon Sep 17 00:00:00 2001 From: Taylor Wilsdon Date: Tue, 17 Mar 2026 08:36:08 -0400 Subject: [PATCH] remove consent from reauth flow and add tests --- auth/google_auth.py | 144 +++++++++++++++++- ...test_google_auth_callback_refresh_token.py | 116 ++++++++++++++ .../auth/test_google_auth_prompt_selection.py | 119 +++++++++++++++ 3 files changed, 374 insertions(+), 5 deletions(-) create mode 100644 tests/auth/test_google_auth_callback_refresh_token.py create mode 100644 tests/auth/test_google_auth_prompt_selection.py diff --git a/auth/google_auth.py b/auth/google_auth.py index c915c86..fe70499 100644 --- a/auth/google_auth.py +++ b/auth/google_auth.py @@ -338,6 +338,85 @@ def create_oauth_flow( return flow +def _determine_oauth_prompt( + user_google_email: Optional[str], + required_scopes: List[str], + session_id: Optional[str] = None, +) -> str: + """ + Determine which OAuth prompt to use for a new authorization URL. + + Uses `select_account` for re-auth when existing credentials already cover + required scopes. Uses `consent` for first-time auth and scope expansion. + """ + normalized_email = ( + user_google_email.strip() + if user_google_email + and user_google_email.strip() + and user_google_email.lower() != "default" + else None + ) + + # If no explicit email was provided, attempt to resolve it from session mapping. + if not normalized_email and session_id: + try: + session_user = get_oauth21_session_store().get_user_by_mcp_session( + session_id + ) + if session_user: + normalized_email = session_user + except Exception as e: + logger.debug(f"Could not resolve user from session for prompt choice: {e}") + + if not normalized_email: + logger.info( + "[start_auth_flow] Using prompt='consent' (no known user email for re-auth detection)." + ) + return "consent" + + existing_credentials: Optional[Credentials] = None + + # Prefer credentials bound to the current session when available. + if session_id: + try: + session_store = get_oauth21_session_store() + mapped_user = session_store.get_user_by_mcp_session(session_id) + if mapped_user == normalized_email: + existing_credentials = session_store.get_credentials_by_mcp_session( + session_id + ) + except Exception as e: + logger.debug( + f"Could not read OAuth 2.1 session store for prompt choice: {e}" + ) + + # Fall back to credential file store in stateful mode. + if not existing_credentials and not is_stateless_mode(): + try: + existing_credentials = get_credential_store().get_credential( + normalized_email + ) + except Exception as e: + logger.debug(f"Could not read credential store for prompt choice: {e}") + + if not existing_credentials: + logger.info( + f"[start_auth_flow] Using prompt='consent' (no existing credentials for {normalized_email})." + ) + return "consent" + + if has_required_scopes(existing_credentials.scopes, required_scopes): + logger.info( + f"[start_auth_flow] Using prompt='select_account' for re-auth of {normalized_email}." + ) + return "select_account" + + logger.info( + f"[start_auth_flow] Using prompt='consent' (existing credentials for {normalized_email} are missing required scopes)." + ) + return "consent" + + # --- Core OAuth Logic --- @@ -387,15 +466,14 @@ async def start_auth_flow( os.environ["OAUTHLIB_INSECURE_TRANSPORT"] = "1" oauth_state = os.urandom(16).hex() + current_scopes = get_current_scopes() flow = create_oauth_flow( - scopes=get_current_scopes(), # Use scopes for enabled tools only + scopes=current_scopes, # Use scopes for enabled tools only redirect_uri=redirect_uri, # Use passed redirect_uri state=oauth_state, ) - auth_url, _ = flow.authorization_url(access_type="offline", prompt="consent") - session_id = None try: session_id = get_fastmcp_session_id() @@ -404,6 +482,13 @@ async def start_auth_flow( f"Could not retrieve FastMCP session ID for state binding: {e}" ) + prompt_type = _determine_oauth_prompt( + user_google_email=user_google_email, + required_scopes=current_scopes, + session_id=session_id, + ) + auth_url, _ = flow.authorization_url(access_type="offline", prompt=prompt_type) + store = get_oauth21_session_store() store.store_oauth_state( oauth_state, @@ -568,12 +653,61 @@ def handle_auth_callback( user_google_email = user_info["email"] logger.info(f"Identified user_google_email: {user_google_email}") - # Save the credentials credential_store = get_credential_store() + if not credentials.refresh_token: + fallback_refresh_token = None + + if session_id: + try: + session_credentials = store.get_credentials_by_mcp_session( + session_id + ) + if session_credentials and session_credentials.refresh_token: + fallback_refresh_token = session_credentials.refresh_token + logger.info( + "OAuth callback response omitted refresh token; preserving existing refresh token from session store." + ) + except Exception as e: + logger.debug( + f"Could not check session store for existing refresh token: {e}" + ) + + if not fallback_refresh_token and not is_stateless_mode(): + try: + existing_credentials = credential_store.get_credential( + user_google_email + ) + if existing_credentials and existing_credentials.refresh_token: + fallback_refresh_token = existing_credentials.refresh_token + logger.info( + "OAuth callback response omitted refresh token; preserving existing refresh token from credential store." + ) + except Exception as e: + logger.debug( + f"Could not check credential store for existing refresh token: {e}" + ) + + if fallback_refresh_token: + credentials = Credentials( + token=credentials.token, + refresh_token=fallback_refresh_token, + id_token=getattr(credentials, "id_token", None), + token_uri=credentials.token_uri, + client_id=credentials.client_id, + client_secret=credentials.client_secret, + scopes=credentials.scopes, + expiry=credentials.expiry, + quota_project_id=getattr(credentials, "quota_project_id", None), + ) + else: + logger.warning( + "OAuth callback did not include a refresh token and no previous refresh token was available to preserve." + ) + + # Save the credentials credential_store.store_credential(user_google_email, credentials) # Always save to OAuth21SessionStore for centralized management - store = get_oauth21_session_store() store.store_session( user_email=user_google_email, access_token=credentials.token, diff --git a/tests/auth/test_google_auth_callback_refresh_token.py b/tests/auth/test_google_auth_callback_refresh_token.py new file mode 100644 index 0000000..298e810 --- /dev/null +++ b/tests/auth/test_google_auth_callback_refresh_token.py @@ -0,0 +1,116 @@ +from google.oauth2.credentials import Credentials + +from auth.google_auth import handle_auth_callback + + +class _DummyFlow: + def __init__(self, credentials): + self.credentials = credentials + + def fetch_token(self, authorization_response): # noqa: ARG002 + return None + + +class _DummyOAuthStore: + def __init__(self, session_credentials=None): + self._session_credentials = session_credentials + self.stored_refresh_token = None + + def validate_and_consume_oauth_state(self, state, session_id=None): # noqa: ARG002 + return {"session_id": session_id, "code_verifier": "verifier"} + + def get_credentials_by_mcp_session(self, mcp_session_id): # noqa: ARG002 + return self._session_credentials + + def store_session(self, **kwargs): + self.stored_refresh_token = kwargs.get("refresh_token") + + +class _DummyCredentialStore: + def __init__(self, existing_credentials=None): + self._existing_credentials = existing_credentials + self.saved_credentials = None + + def get_credential(self, user_email): # noqa: ARG002 + return self._existing_credentials + + def store_credential(self, user_email, credentials): # noqa: ARG002 + self.saved_credentials = credentials + return True + + +def _make_credentials(refresh_token): + return Credentials( + token="access-token", + refresh_token=refresh_token, + token_uri="https://oauth2.googleapis.com/token", + client_id="client-id", + client_secret="client-secret", + scopes=["scope.a"], + ) + + +def test_callback_preserves_refresh_token_from_credential_store(monkeypatch): + callback_credentials = _make_credentials(refresh_token=None) + oauth_store = _DummyOAuthStore(session_credentials=None) + credential_store = _DummyCredentialStore( + existing_credentials=_make_credentials(refresh_token="file-refresh-token") + ) + + monkeypatch.setattr( + "auth.google_auth.create_oauth_flow", + lambda **kwargs: _DummyFlow(callback_credentials), # noqa: ARG005 + ) + monkeypatch.setattr("auth.google_auth.get_oauth21_session_store", lambda: oauth_store) + monkeypatch.setattr("auth.google_auth.get_credential_store", lambda: credential_store) + monkeypatch.setattr( + "auth.google_auth.get_user_info", + lambda credentials: {"email": "user@gmail.com"}, # noqa: ARG005 + ) + monkeypatch.setattr("auth.google_auth.save_credentials_to_session", lambda *args: None) + monkeypatch.setattr("auth.google_auth.is_stateless_mode", lambda: False) + + _email, credentials = handle_auth_callback( + scopes=["scope.a"], + authorization_response="http://localhost/callback?state=abc123&code=code123", + redirect_uri="http://localhost/callback", + session_id="session-1", + ) + + assert credentials.refresh_token == "file-refresh-token" + assert credential_store.saved_credentials.refresh_token == "file-refresh-token" + assert oauth_store.stored_refresh_token == "file-refresh-token" + + +def test_callback_prefers_session_refresh_token_over_credential_store(monkeypatch): + callback_credentials = _make_credentials(refresh_token=None) + oauth_store = _DummyOAuthStore( + session_credentials=_make_credentials(refresh_token="session-refresh-token") + ) + credential_store = _DummyCredentialStore( + existing_credentials=_make_credentials(refresh_token="file-refresh-token") + ) + + monkeypatch.setattr( + "auth.google_auth.create_oauth_flow", + lambda **kwargs: _DummyFlow(callback_credentials), # noqa: ARG005 + ) + monkeypatch.setattr("auth.google_auth.get_oauth21_session_store", lambda: oauth_store) + monkeypatch.setattr("auth.google_auth.get_credential_store", lambda: credential_store) + monkeypatch.setattr( + "auth.google_auth.get_user_info", + lambda credentials: {"email": "user@gmail.com"}, # noqa: ARG005 + ) + monkeypatch.setattr("auth.google_auth.save_credentials_to_session", lambda *args: None) + monkeypatch.setattr("auth.google_auth.is_stateless_mode", lambda: False) + + _email, credentials = handle_auth_callback( + scopes=["scope.a"], + authorization_response="http://localhost/callback?state=abc123&code=code123", + redirect_uri="http://localhost/callback", + session_id="session-1", + ) + + assert credentials.refresh_token == "session-refresh-token" + assert credential_store.saved_credentials.refresh_token == "session-refresh-token" + assert oauth_store.stored_refresh_token == "session-refresh-token" diff --git a/tests/auth/test_google_auth_prompt_selection.py b/tests/auth/test_google_auth_prompt_selection.py new file mode 100644 index 0000000..d5fb254 --- /dev/null +++ b/tests/auth/test_google_auth_prompt_selection.py @@ -0,0 +1,119 @@ +from types import SimpleNamespace + +from auth.google_auth import _determine_oauth_prompt + + +class _DummyCredentialStore: + def __init__(self, credentials_by_email=None): + self._credentials_by_email = credentials_by_email or {} + + def get_credential(self, user_email): + return self._credentials_by_email.get(user_email) + + +class _DummySessionStore: + def __init__(self, user_by_session=None, credentials_by_session=None): + self._user_by_session = user_by_session or {} + self._credentials_by_session = credentials_by_session or {} + + def get_user_by_mcp_session(self, mcp_session_id): + return self._user_by_session.get(mcp_session_id) + + def get_credentials_by_mcp_session(self, mcp_session_id): + return self._credentials_by_session.get(mcp_session_id) + + +def _credentials_with_scopes(scopes): + return SimpleNamespace(scopes=scopes) + + +def test_prompt_select_account_when_existing_credentials_cover_scopes(monkeypatch): + required_scopes = ["scope.a", "scope.b"] + monkeypatch.setattr( + "auth.google_auth.get_oauth21_session_store", + lambda: _DummySessionStore(), + ) + monkeypatch.setattr( + "auth.google_auth.get_credential_store", + lambda: _DummyCredentialStore( + {"user@gmail.com": _credentials_with_scopes(required_scopes)} + ), + ) + monkeypatch.setattr("auth.google_auth.is_stateless_mode", lambda: False) + + prompt = _determine_oauth_prompt( + user_google_email="user@gmail.com", + required_scopes=required_scopes, + session_id=None, + ) + + assert prompt == "select_account" + + +def test_prompt_consent_when_existing_credentials_missing_scopes(monkeypatch): + monkeypatch.setattr( + "auth.google_auth.get_oauth21_session_store", + lambda: _DummySessionStore(), + ) + monkeypatch.setattr( + "auth.google_auth.get_credential_store", + lambda: _DummyCredentialStore( + {"user@gmail.com": _credentials_with_scopes(["scope.a"])} + ), + ) + monkeypatch.setattr("auth.google_auth.is_stateless_mode", lambda: False) + + prompt = _determine_oauth_prompt( + user_google_email="user@gmail.com", + required_scopes=["scope.a", "scope.b"], + session_id=None, + ) + + assert prompt == "consent" + + +def test_prompt_consent_when_no_existing_credentials(monkeypatch): + monkeypatch.setattr( + "auth.google_auth.get_oauth21_session_store", + lambda: _DummySessionStore(), + ) + monkeypatch.setattr( + "auth.google_auth.get_credential_store", + lambda: _DummyCredentialStore(), + ) + monkeypatch.setattr("auth.google_auth.is_stateless_mode", lambda: False) + + prompt = _determine_oauth_prompt( + user_google_email="new_user@gmail.com", + required_scopes=["scope.a"], + session_id=None, + ) + + assert prompt == "consent" + + +def test_prompt_uses_session_mapping_when_email_not_provided(monkeypatch): + session_id = "session-123" + required_scopes = ["scope.a"] + monkeypatch.setattr( + "auth.google_auth.get_oauth21_session_store", + lambda: _DummySessionStore( + user_by_session={session_id: "mapped@gmail.com"}, + credentials_by_session={ + session_id: _credentials_with_scopes(required_scopes) + }, + ), + ) + monkeypatch.setattr( + "auth.google_auth.get_credential_store", + lambda: _DummyCredentialStore(), + ) + monkeypatch.setattr("auth.google_auth.is_stateless_mode", lambda: False) + + prompt = _determine_oauth_prompt( + user_google_email=None, + required_scopes=required_scopes, + session_id=session_id, + ) + + assert prompt == "select_account"