Merge branch 'main' of github.com:taylorwilsdon/google_workspace_mcp into feat/555-auto-reply-headers
This commit is contained in:
@@ -338,6 +338,85 @@ def create_oauth_flow(
|
|||||||
return flow
|
return flow
|
||||||
|
|
||||||
|
|
||||||
|
def _determine_oauth_prompt(
|
||||||
|
user_google_email: Optional[str],
|
||||||
|
required_scopes: List[str],
|
||||||
|
session_id: Optional[str] = None,
|
||||||
|
) -> str:
|
||||||
|
"""
|
||||||
|
Determine which OAuth prompt to use for a new authorization URL.
|
||||||
|
|
||||||
|
Uses `select_account` for re-auth when existing credentials already cover
|
||||||
|
required scopes. Uses `consent` for first-time auth and scope expansion.
|
||||||
|
"""
|
||||||
|
normalized_email = (
|
||||||
|
user_google_email.strip()
|
||||||
|
if user_google_email
|
||||||
|
and user_google_email.strip()
|
||||||
|
and user_google_email.lower() != "default"
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
|
||||||
|
# If no explicit email was provided, attempt to resolve it from session mapping.
|
||||||
|
if not normalized_email and session_id:
|
||||||
|
try:
|
||||||
|
session_user = get_oauth21_session_store().get_user_by_mcp_session(
|
||||||
|
session_id
|
||||||
|
)
|
||||||
|
if session_user:
|
||||||
|
normalized_email = session_user
|
||||||
|
except Exception as e:
|
||||||
|
logger.debug(f"Could not resolve user from session for prompt choice: {e}")
|
||||||
|
|
||||||
|
if not normalized_email:
|
||||||
|
logger.info(
|
||||||
|
"[start_auth_flow] Using prompt='consent' (no known user email for re-auth detection)."
|
||||||
|
)
|
||||||
|
return "consent"
|
||||||
|
|
||||||
|
existing_credentials: Optional[Credentials] = None
|
||||||
|
|
||||||
|
# Prefer credentials bound to the current session when available.
|
||||||
|
if session_id:
|
||||||
|
try:
|
||||||
|
session_store = get_oauth21_session_store()
|
||||||
|
mapped_user = session_store.get_user_by_mcp_session(session_id)
|
||||||
|
if mapped_user == normalized_email:
|
||||||
|
existing_credentials = session_store.get_credentials_by_mcp_session(
|
||||||
|
session_id
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.debug(
|
||||||
|
f"Could not read OAuth 2.1 session store for prompt choice: {e}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Fall back to credential file store in stateful mode.
|
||||||
|
if not existing_credentials and not is_stateless_mode():
|
||||||
|
try:
|
||||||
|
existing_credentials = get_credential_store().get_credential(
|
||||||
|
normalized_email
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.debug(f"Could not read credential store for prompt choice: {e}")
|
||||||
|
|
||||||
|
if not existing_credentials:
|
||||||
|
logger.info(
|
||||||
|
f"[start_auth_flow] Using prompt='consent' (no existing credentials for {normalized_email})."
|
||||||
|
)
|
||||||
|
return "consent"
|
||||||
|
|
||||||
|
if has_required_scopes(existing_credentials.scopes, required_scopes):
|
||||||
|
logger.info(
|
||||||
|
f"[start_auth_flow] Using prompt='select_account' for re-auth of {normalized_email}."
|
||||||
|
)
|
||||||
|
return "select_account"
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"[start_auth_flow] Using prompt='consent' (existing credentials for {normalized_email} are missing required scopes)."
|
||||||
|
)
|
||||||
|
return "consent"
|
||||||
|
|
||||||
|
|
||||||
# --- Core OAuth Logic ---
|
# --- Core OAuth Logic ---
|
||||||
|
|
||||||
|
|
||||||
@@ -387,15 +466,14 @@ async def start_auth_flow(
|
|||||||
os.environ["OAUTHLIB_INSECURE_TRANSPORT"] = "1"
|
os.environ["OAUTHLIB_INSECURE_TRANSPORT"] = "1"
|
||||||
|
|
||||||
oauth_state = os.urandom(16).hex()
|
oauth_state = os.urandom(16).hex()
|
||||||
|
current_scopes = get_current_scopes()
|
||||||
|
|
||||||
flow = create_oauth_flow(
|
flow = create_oauth_flow(
|
||||||
scopes=get_current_scopes(), # Use scopes for enabled tools only
|
scopes=current_scopes, # Use scopes for enabled tools only
|
||||||
redirect_uri=redirect_uri, # Use passed redirect_uri
|
redirect_uri=redirect_uri, # Use passed redirect_uri
|
||||||
state=oauth_state,
|
state=oauth_state,
|
||||||
)
|
)
|
||||||
|
|
||||||
auth_url, _ = flow.authorization_url(access_type="offline", prompt="consent")
|
|
||||||
|
|
||||||
session_id = None
|
session_id = None
|
||||||
try:
|
try:
|
||||||
session_id = get_fastmcp_session_id()
|
session_id = get_fastmcp_session_id()
|
||||||
@@ -404,6 +482,13 @@ async def start_auth_flow(
|
|||||||
f"Could not retrieve FastMCP session ID for state binding: {e}"
|
f"Could not retrieve FastMCP session ID for state binding: {e}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
prompt_type = _determine_oauth_prompt(
|
||||||
|
user_google_email=user_google_email,
|
||||||
|
required_scopes=current_scopes,
|
||||||
|
session_id=session_id,
|
||||||
|
)
|
||||||
|
auth_url, _ = flow.authorization_url(access_type="offline", prompt=prompt_type)
|
||||||
|
|
||||||
store = get_oauth21_session_store()
|
store = get_oauth21_session_store()
|
||||||
store.store_oauth_state(
|
store.store_oauth_state(
|
||||||
oauth_state,
|
oauth_state,
|
||||||
@@ -568,12 +653,61 @@ def handle_auth_callback(
|
|||||||
user_google_email = user_info["email"]
|
user_google_email = user_info["email"]
|
||||||
logger.info(f"Identified user_google_email: {user_google_email}")
|
logger.info(f"Identified user_google_email: {user_google_email}")
|
||||||
|
|
||||||
# Save the credentials
|
|
||||||
credential_store = get_credential_store()
|
credential_store = get_credential_store()
|
||||||
|
if not credentials.refresh_token:
|
||||||
|
fallback_refresh_token = None
|
||||||
|
|
||||||
|
if session_id:
|
||||||
|
try:
|
||||||
|
session_credentials = store.get_credentials_by_mcp_session(
|
||||||
|
session_id
|
||||||
|
)
|
||||||
|
if session_credentials and session_credentials.refresh_token:
|
||||||
|
fallback_refresh_token = session_credentials.refresh_token
|
||||||
|
logger.info(
|
||||||
|
"OAuth callback response omitted refresh token; preserving existing refresh token from session store."
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.debug(
|
||||||
|
f"Could not check session store for existing refresh token: {e}"
|
||||||
|
)
|
||||||
|
|
||||||
|
if not fallback_refresh_token and not is_stateless_mode():
|
||||||
|
try:
|
||||||
|
existing_credentials = credential_store.get_credential(
|
||||||
|
user_google_email
|
||||||
|
)
|
||||||
|
if existing_credentials and existing_credentials.refresh_token:
|
||||||
|
fallback_refresh_token = existing_credentials.refresh_token
|
||||||
|
logger.info(
|
||||||
|
"OAuth callback response omitted refresh token; preserving existing refresh token from credential store."
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.debug(
|
||||||
|
f"Could not check credential store for existing refresh token: {e}"
|
||||||
|
)
|
||||||
|
|
||||||
|
if fallback_refresh_token:
|
||||||
|
credentials = Credentials(
|
||||||
|
token=credentials.token,
|
||||||
|
refresh_token=fallback_refresh_token,
|
||||||
|
id_token=getattr(credentials, "id_token", None),
|
||||||
|
token_uri=credentials.token_uri,
|
||||||
|
client_id=credentials.client_id,
|
||||||
|
client_secret=credentials.client_secret,
|
||||||
|
scopes=credentials.scopes,
|
||||||
|
expiry=credentials.expiry,
|
||||||
|
quota_project_id=getattr(credentials, "quota_project_id", None),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logger.warning(
|
||||||
|
"OAuth callback did not include a refresh token and no previous refresh token was available to preserve."
|
||||||
|
)
|
||||||
|
|
||||||
|
# Save the credentials
|
||||||
credential_store.store_credential(user_google_email, credentials)
|
credential_store.store_credential(user_google_email, credentials)
|
||||||
|
|
||||||
# Always save to OAuth21SessionStore for centralized management
|
# Always save to OAuth21SessionStore for centralized management
|
||||||
store = get_oauth21_session_store()
|
|
||||||
store.store_session(
|
store.store_session(
|
||||||
user_email=user_google_email,
|
user_email=user_google_email,
|
||||||
access_token=credentials.token,
|
access_token=credentials.token,
|
||||||
|
|||||||
128
tests/auth/test_google_auth_callback_refresh_token.py
Normal file
128
tests/auth/test_google_auth_callback_refresh_token.py
Normal file
@@ -0,0 +1,128 @@
|
|||||||
|
from google.oauth2.credentials import Credentials
|
||||||
|
|
||||||
|
from auth.google_auth import handle_auth_callback
|
||||||
|
|
||||||
|
|
||||||
|
class _DummyFlow:
|
||||||
|
def __init__(self, credentials):
|
||||||
|
self.credentials = credentials
|
||||||
|
|
||||||
|
def fetch_token(self, authorization_response): # noqa: ARG002
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
class _DummyOAuthStore:
|
||||||
|
def __init__(self, session_credentials=None):
|
||||||
|
self._session_credentials = session_credentials
|
||||||
|
self.stored_refresh_token = None
|
||||||
|
|
||||||
|
def validate_and_consume_oauth_state(self, state, session_id=None): # noqa: ARG002
|
||||||
|
return {"session_id": session_id, "code_verifier": "verifier"}
|
||||||
|
|
||||||
|
def get_credentials_by_mcp_session(self, mcp_session_id): # noqa: ARG002
|
||||||
|
return self._session_credentials
|
||||||
|
|
||||||
|
def store_session(self, **kwargs):
|
||||||
|
self.stored_refresh_token = kwargs.get("refresh_token")
|
||||||
|
|
||||||
|
|
||||||
|
class _DummyCredentialStore:
|
||||||
|
def __init__(self, existing_credentials=None):
|
||||||
|
self._existing_credentials = existing_credentials
|
||||||
|
self.saved_credentials = None
|
||||||
|
|
||||||
|
def get_credential(self, user_email): # noqa: ARG002
|
||||||
|
return self._existing_credentials
|
||||||
|
|
||||||
|
def store_credential(self, user_email, credentials): # noqa: ARG002
|
||||||
|
self.saved_credentials = credentials
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
def _make_credentials(refresh_token):
|
||||||
|
return Credentials(
|
||||||
|
token="access-token",
|
||||||
|
refresh_token=refresh_token,
|
||||||
|
token_uri="https://oauth2.googleapis.com/token",
|
||||||
|
client_id="client-id",
|
||||||
|
client_secret="client-secret",
|
||||||
|
scopes=["scope.a"],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_callback_preserves_refresh_token_from_credential_store(monkeypatch):
|
||||||
|
callback_credentials = _make_credentials(refresh_token=None)
|
||||||
|
oauth_store = _DummyOAuthStore(session_credentials=None)
|
||||||
|
credential_store = _DummyCredentialStore(
|
||||||
|
existing_credentials=_make_credentials(refresh_token="file-refresh-token")
|
||||||
|
)
|
||||||
|
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"auth.google_auth.create_oauth_flow",
|
||||||
|
lambda **kwargs: _DummyFlow(callback_credentials), # noqa: ARG005
|
||||||
|
)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"auth.google_auth.get_oauth21_session_store", lambda: oauth_store
|
||||||
|
)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"auth.google_auth.get_credential_store", lambda: credential_store
|
||||||
|
)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"auth.google_auth.get_user_info",
|
||||||
|
lambda credentials: {"email": "user@gmail.com"}, # noqa: ARG005
|
||||||
|
)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"auth.google_auth.save_credentials_to_session", lambda *args: None
|
||||||
|
)
|
||||||
|
monkeypatch.setattr("auth.google_auth.is_stateless_mode", lambda: False)
|
||||||
|
|
||||||
|
_email, credentials = handle_auth_callback(
|
||||||
|
scopes=["scope.a"],
|
||||||
|
authorization_response="http://localhost/callback?state=abc123&code=code123",
|
||||||
|
redirect_uri="http://localhost/callback",
|
||||||
|
session_id="session-1",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert credentials.refresh_token == "file-refresh-token"
|
||||||
|
assert credential_store.saved_credentials.refresh_token == "file-refresh-token"
|
||||||
|
assert oauth_store.stored_refresh_token == "file-refresh-token"
|
||||||
|
|
||||||
|
|
||||||
|
def test_callback_prefers_session_refresh_token_over_credential_store(monkeypatch):
|
||||||
|
callback_credentials = _make_credentials(refresh_token=None)
|
||||||
|
oauth_store = _DummyOAuthStore(
|
||||||
|
session_credentials=_make_credentials(refresh_token="session-refresh-token")
|
||||||
|
)
|
||||||
|
credential_store = _DummyCredentialStore(
|
||||||
|
existing_credentials=_make_credentials(refresh_token="file-refresh-token")
|
||||||
|
)
|
||||||
|
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"auth.google_auth.create_oauth_flow",
|
||||||
|
lambda **kwargs: _DummyFlow(callback_credentials), # noqa: ARG005
|
||||||
|
)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"auth.google_auth.get_oauth21_session_store", lambda: oauth_store
|
||||||
|
)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"auth.google_auth.get_credential_store", lambda: credential_store
|
||||||
|
)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"auth.google_auth.get_user_info",
|
||||||
|
lambda credentials: {"email": "user@gmail.com"}, # noqa: ARG005
|
||||||
|
)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"auth.google_auth.save_credentials_to_session", lambda *args: None
|
||||||
|
)
|
||||||
|
monkeypatch.setattr("auth.google_auth.is_stateless_mode", lambda: False)
|
||||||
|
|
||||||
|
_email, credentials = handle_auth_callback(
|
||||||
|
scopes=["scope.a"],
|
||||||
|
authorization_response="http://localhost/callback?state=abc123&code=code123",
|
||||||
|
redirect_uri="http://localhost/callback",
|
||||||
|
session_id="session-1",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert credentials.refresh_token == "session-refresh-token"
|
||||||
|
assert credential_store.saved_credentials.refresh_token == "session-refresh-token"
|
||||||
|
assert oauth_store.stored_refresh_token == "session-refresh-token"
|
||||||
119
tests/auth/test_google_auth_prompt_selection.py
Normal file
119
tests/auth/test_google_auth_prompt_selection.py
Normal file
@@ -0,0 +1,119 @@
|
|||||||
|
from types import SimpleNamespace
|
||||||
|
|
||||||
|
from auth.google_auth import _determine_oauth_prompt
|
||||||
|
|
||||||
|
|
||||||
|
class _DummyCredentialStore:
|
||||||
|
def __init__(self, credentials_by_email=None):
|
||||||
|
self._credentials_by_email = credentials_by_email or {}
|
||||||
|
|
||||||
|
def get_credential(self, user_email):
|
||||||
|
return self._credentials_by_email.get(user_email)
|
||||||
|
|
||||||
|
|
||||||
|
class _DummySessionStore:
|
||||||
|
def __init__(self, user_by_session=None, credentials_by_session=None):
|
||||||
|
self._user_by_session = user_by_session or {}
|
||||||
|
self._credentials_by_session = credentials_by_session or {}
|
||||||
|
|
||||||
|
def get_user_by_mcp_session(self, mcp_session_id):
|
||||||
|
return self._user_by_session.get(mcp_session_id)
|
||||||
|
|
||||||
|
def get_credentials_by_mcp_session(self, mcp_session_id):
|
||||||
|
return self._credentials_by_session.get(mcp_session_id)
|
||||||
|
|
||||||
|
|
||||||
|
def _credentials_with_scopes(scopes):
|
||||||
|
return SimpleNamespace(scopes=scopes)
|
||||||
|
|
||||||
|
|
||||||
|
def test_prompt_select_account_when_existing_credentials_cover_scopes(monkeypatch):
|
||||||
|
required_scopes = ["scope.a", "scope.b"]
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"auth.google_auth.get_oauth21_session_store",
|
||||||
|
lambda: _DummySessionStore(),
|
||||||
|
)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"auth.google_auth.get_credential_store",
|
||||||
|
lambda: _DummyCredentialStore(
|
||||||
|
{"user@gmail.com": _credentials_with_scopes(required_scopes)}
|
||||||
|
),
|
||||||
|
)
|
||||||
|
monkeypatch.setattr("auth.google_auth.is_stateless_mode", lambda: False)
|
||||||
|
|
||||||
|
prompt = _determine_oauth_prompt(
|
||||||
|
user_google_email="user@gmail.com",
|
||||||
|
required_scopes=required_scopes,
|
||||||
|
session_id=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert prompt == "select_account"
|
||||||
|
|
||||||
|
|
||||||
|
def test_prompt_consent_when_existing_credentials_missing_scopes(monkeypatch):
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"auth.google_auth.get_oauth21_session_store",
|
||||||
|
lambda: _DummySessionStore(),
|
||||||
|
)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"auth.google_auth.get_credential_store",
|
||||||
|
lambda: _DummyCredentialStore(
|
||||||
|
{"user@gmail.com": _credentials_with_scopes(["scope.a"])}
|
||||||
|
),
|
||||||
|
)
|
||||||
|
monkeypatch.setattr("auth.google_auth.is_stateless_mode", lambda: False)
|
||||||
|
|
||||||
|
prompt = _determine_oauth_prompt(
|
||||||
|
user_google_email="user@gmail.com",
|
||||||
|
required_scopes=["scope.a", "scope.b"],
|
||||||
|
session_id=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert prompt == "consent"
|
||||||
|
|
||||||
|
|
||||||
|
def test_prompt_consent_when_no_existing_credentials(monkeypatch):
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"auth.google_auth.get_oauth21_session_store",
|
||||||
|
lambda: _DummySessionStore(),
|
||||||
|
)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"auth.google_auth.get_credential_store",
|
||||||
|
lambda: _DummyCredentialStore(),
|
||||||
|
)
|
||||||
|
monkeypatch.setattr("auth.google_auth.is_stateless_mode", lambda: False)
|
||||||
|
|
||||||
|
prompt = _determine_oauth_prompt(
|
||||||
|
user_google_email="new_user@gmail.com",
|
||||||
|
required_scopes=["scope.a"],
|
||||||
|
session_id=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert prompt == "consent"
|
||||||
|
|
||||||
|
|
||||||
|
def test_prompt_uses_session_mapping_when_email_not_provided(monkeypatch):
|
||||||
|
session_id = "session-123"
|
||||||
|
required_scopes = ["scope.a"]
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"auth.google_auth.get_oauth21_session_store",
|
||||||
|
lambda: _DummySessionStore(
|
||||||
|
user_by_session={session_id: "mapped@gmail.com"},
|
||||||
|
credentials_by_session={
|
||||||
|
session_id: _credentials_with_scopes(required_scopes)
|
||||||
|
},
|
||||||
|
),
|
||||||
|
)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"auth.google_auth.get_credential_store",
|
||||||
|
lambda: _DummyCredentialStore(),
|
||||||
|
)
|
||||||
|
monkeypatch.setattr("auth.google_auth.is_stateless_mode", lambda: False)
|
||||||
|
|
||||||
|
prompt = _determine_oauth_prompt(
|
||||||
|
user_google_email=None,
|
||||||
|
required_scopes=required_scopes,
|
||||||
|
session_id=session_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert prompt == "select_account"
|
||||||
Reference in New Issue
Block a user