remove consent from reauth flow and add tests
This commit is contained in:
@@ -338,6 +338,85 @@ def create_oauth_flow(
|
||||
return flow
|
||||
|
||||
|
||||
def _determine_oauth_prompt(
|
||||
user_google_email: Optional[str],
|
||||
required_scopes: List[str],
|
||||
session_id: Optional[str] = None,
|
||||
) -> str:
|
||||
"""
|
||||
Determine which OAuth prompt to use for a new authorization URL.
|
||||
|
||||
Uses `select_account` for re-auth when existing credentials already cover
|
||||
required scopes. Uses `consent` for first-time auth and scope expansion.
|
||||
"""
|
||||
normalized_email = (
|
||||
user_google_email.strip()
|
||||
if user_google_email
|
||||
and user_google_email.strip()
|
||||
and user_google_email.lower() != "default"
|
||||
else None
|
||||
)
|
||||
|
||||
# If no explicit email was provided, attempt to resolve it from session mapping.
|
||||
if not normalized_email and session_id:
|
||||
try:
|
||||
session_user = get_oauth21_session_store().get_user_by_mcp_session(
|
||||
session_id
|
||||
)
|
||||
if session_user:
|
||||
normalized_email = session_user
|
||||
except Exception as e:
|
||||
logger.debug(f"Could not resolve user from session for prompt choice: {e}")
|
||||
|
||||
if not normalized_email:
|
||||
logger.info(
|
||||
"[start_auth_flow] Using prompt='consent' (no known user email for re-auth detection)."
|
||||
)
|
||||
return "consent"
|
||||
|
||||
existing_credentials: Optional[Credentials] = None
|
||||
|
||||
# Prefer credentials bound to the current session when available.
|
||||
if session_id:
|
||||
try:
|
||||
session_store = get_oauth21_session_store()
|
||||
mapped_user = session_store.get_user_by_mcp_session(session_id)
|
||||
if mapped_user == normalized_email:
|
||||
existing_credentials = session_store.get_credentials_by_mcp_session(
|
||||
session_id
|
||||
)
|
||||
except Exception as e:
|
||||
logger.debug(
|
||||
f"Could not read OAuth 2.1 session store for prompt choice: {e}"
|
||||
)
|
||||
|
||||
# Fall back to credential file store in stateful mode.
|
||||
if not existing_credentials and not is_stateless_mode():
|
||||
try:
|
||||
existing_credentials = get_credential_store().get_credential(
|
||||
normalized_email
|
||||
)
|
||||
except Exception as e:
|
||||
logger.debug(f"Could not read credential store for prompt choice: {e}")
|
||||
|
||||
if not existing_credentials:
|
||||
logger.info(
|
||||
f"[start_auth_flow] Using prompt='consent' (no existing credentials for {normalized_email})."
|
||||
)
|
||||
return "consent"
|
||||
|
||||
if has_required_scopes(existing_credentials.scopes, required_scopes):
|
||||
logger.info(
|
||||
f"[start_auth_flow] Using prompt='select_account' for re-auth of {normalized_email}."
|
||||
)
|
||||
return "select_account"
|
||||
|
||||
logger.info(
|
||||
f"[start_auth_flow] Using prompt='consent' (existing credentials for {normalized_email} are missing required scopes)."
|
||||
)
|
||||
return "consent"
|
||||
|
||||
|
||||
# --- Core OAuth Logic ---
|
||||
|
||||
|
||||
@@ -387,15 +466,14 @@ async def start_auth_flow(
|
||||
os.environ["OAUTHLIB_INSECURE_TRANSPORT"] = "1"
|
||||
|
||||
oauth_state = os.urandom(16).hex()
|
||||
current_scopes = get_current_scopes()
|
||||
|
||||
flow = create_oauth_flow(
|
||||
scopes=get_current_scopes(), # Use scopes for enabled tools only
|
||||
scopes=current_scopes, # Use scopes for enabled tools only
|
||||
redirect_uri=redirect_uri, # Use passed redirect_uri
|
||||
state=oauth_state,
|
||||
)
|
||||
|
||||
auth_url, _ = flow.authorization_url(access_type="offline", prompt="consent")
|
||||
|
||||
session_id = None
|
||||
try:
|
||||
session_id = get_fastmcp_session_id()
|
||||
@@ -404,6 +482,13 @@ async def start_auth_flow(
|
||||
f"Could not retrieve FastMCP session ID for state binding: {e}"
|
||||
)
|
||||
|
||||
prompt_type = _determine_oauth_prompt(
|
||||
user_google_email=user_google_email,
|
||||
required_scopes=current_scopes,
|
||||
session_id=session_id,
|
||||
)
|
||||
auth_url, _ = flow.authorization_url(access_type="offline", prompt=prompt_type)
|
||||
|
||||
store = get_oauth21_session_store()
|
||||
store.store_oauth_state(
|
||||
oauth_state,
|
||||
@@ -568,12 +653,61 @@ def handle_auth_callback(
|
||||
user_google_email = user_info["email"]
|
||||
logger.info(f"Identified user_google_email: {user_google_email}")
|
||||
|
||||
# Save the credentials
|
||||
credential_store = get_credential_store()
|
||||
if not credentials.refresh_token:
|
||||
fallback_refresh_token = None
|
||||
|
||||
if session_id:
|
||||
try:
|
||||
session_credentials = store.get_credentials_by_mcp_session(
|
||||
session_id
|
||||
)
|
||||
if session_credentials and session_credentials.refresh_token:
|
||||
fallback_refresh_token = session_credentials.refresh_token
|
||||
logger.info(
|
||||
"OAuth callback response omitted refresh token; preserving existing refresh token from session store."
|
||||
)
|
||||
except Exception as e:
|
||||
logger.debug(
|
||||
f"Could not check session store for existing refresh token: {e}"
|
||||
)
|
||||
|
||||
if not fallback_refresh_token and not is_stateless_mode():
|
||||
try:
|
||||
existing_credentials = credential_store.get_credential(
|
||||
user_google_email
|
||||
)
|
||||
if existing_credentials and existing_credentials.refresh_token:
|
||||
fallback_refresh_token = existing_credentials.refresh_token
|
||||
logger.info(
|
||||
"OAuth callback response omitted refresh token; preserving existing refresh token from credential store."
|
||||
)
|
||||
except Exception as e:
|
||||
logger.debug(
|
||||
f"Could not check credential store for existing refresh token: {e}"
|
||||
)
|
||||
|
||||
if fallback_refresh_token:
|
||||
credentials = Credentials(
|
||||
token=credentials.token,
|
||||
refresh_token=fallback_refresh_token,
|
||||
id_token=getattr(credentials, "id_token", None),
|
||||
token_uri=credentials.token_uri,
|
||||
client_id=credentials.client_id,
|
||||
client_secret=credentials.client_secret,
|
||||
scopes=credentials.scopes,
|
||||
expiry=credentials.expiry,
|
||||
quota_project_id=getattr(credentials, "quota_project_id", None),
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
"OAuth callback did not include a refresh token and no previous refresh token was available to preserve."
|
||||
)
|
||||
|
||||
# Save the credentials
|
||||
credential_store.store_credential(user_google_email, credentials)
|
||||
|
||||
# Always save to OAuth21SessionStore for centralized management
|
||||
store = get_oauth21_session_store()
|
||||
store.store_session(
|
||||
user_email=user_google_email,
|
||||
access_token=credentials.token,
|
||||
|
||||
Reference in New Issue
Block a user