diff --git a/crates/rmcp/src/transport/auth.rs b/crates/rmcp/src/transport/auth.rs index 1ff2ddd7..c129a239 100644 --- a/crates/rmcp/src/transport/auth.rs +++ b/crates/rmcp/src/transport/auth.rs @@ -789,12 +789,23 @@ impl AuthorizationManager { attempts < self.scope_upgrade_config.max_upgrade_attempts } + /// select scopes to request from authorization server + pub fn select_scopes( + &self, + www_authenticate_scope: Option<&str>, + default_scopes: &[&str], + ) -> Vec { + let mut scopes = self.select_base_scopes(www_authenticate_scope, default_scopes); + self.add_offline_access_if_supported(&mut scopes); + scopes + } + /// select scopes based on SEP-835 priority: /// 1. scope from WWW-Authenticate header (argument or stored from initial 401 probe) /// 2. scopes_supported from protected resource metadata (RFC 9728) /// 3. scopes_supported from authorization server metadata /// 4. provided default scopes - pub fn select_scopes( + fn select_base_scopes( &self, www_authenticate_scope: Option<&str>, default_scopes: &[&str], @@ -829,6 +840,21 @@ impl AuthorizationManager { default_scopes.iter().map(|s| s.to_string()).collect() } + /// SEP-2207: when the AS advertises `offline_access` in `scopes_supported`, append + /// it so OIDC-flavored Authorization Servers will issue refresh tokens. + fn add_offline_access_if_supported(&self, scopes: &mut Vec) { + if scopes.is_empty() || scopes.iter().any(|s| s == "offline_access") { + return; + } + if let Some(metadata) = &self.metadata { + if let Some(supported) = &metadata.scopes_supported { + if supported.iter().any(|s| s == "offline_access") { + scopes.push("offline_access".to_string()); + } + } + } + } + /// attempt to upgrade scopes after receiving a 403 insufficient_scope error. /// returns the authorization URL for re-authorization with upgraded scopes. pub async fn request_scope_upgrade(&self, required_scope: &str) -> Result { @@ -949,22 +975,32 @@ impl AuthorizationManager { Ok(token_result) } - /// get access token, if expired, refresh it automatically + /// get access token from local credential store. + /// if expired, refresh it automatically when a refresh token is available. + /// when the access token has expired and no refresh token is available, it returns + /// [`AuthError::AuthorizationRequired`] so the caller can re-authenticate. pub async fn get_access_token(&self) -> Result { - // Load credentials from store let stored = self.credential_store.load().await?; let credentials = stored.and_then(|s| s.token_response); if let Some(creds) = credentials.as_ref() { - // check token expiry if we have a refresh token or an expiry time - if creds.refresh_token().is_some() || creds.expires_in().is_some() { - let expires_in = creds.expires_in().unwrap_or(Duration::from_secs(0)); + if let Some(expires_in) = creds.expires_in() { if expires_in <= Duration::from_secs(0) { - tracing::info!("Access token expired, refreshing."); - - let new_creds = self.refresh_token().await?; - tracing::info!("Refreshed access token."); - return Ok(new_creds.access_token().secret().to_string()); + if creds.refresh_token().is_some() { + tracing::info!("Access token expired, attempting refresh."); + match self.refresh_token().await { + Ok(new_creds) => { + tracing::info!("Refreshed access token."); + return Ok(new_creds.access_token().secret().to_string()); + } + Err(e) => { + tracing::warn!("Token refresh failed: {e}"); + } + } + } else { + tracing::info!("Access token expired and no refresh token available."); + } + return Err(AuthError::AuthorizationRequired); } } @@ -991,7 +1027,6 @@ impl AuthorizationManager { let refresh_token = current_credentials.refresh_token().ok_or_else(|| { AuthError::TokenRefreshFailed("No refresh token available".to_string()) })?; - debug!("refresh token: {:?}", refresh_token); let token_result = oauth_client .exchange_refresh_token(&RefreshToken::new(refresh_token.secret().to_string())) @@ -1664,7 +1699,9 @@ impl OAuthState { let selected_scopes: Vec = if scopes.is_empty() { manager.select_scopes(None, &[]) } else { - scopes.iter().map(|s| s.to_string()).collect() + let mut s: Vec = scopes.iter().map(|s| s.to_string()).collect(); + manager.add_offline_access_if_supported(&mut s); + s }; let scope_refs: Vec<&str> = selected_scopes.iter().map(|s| s.as_str()).collect(); debug!("start session"); @@ -1823,7 +1860,7 @@ impl OAuthState { #[cfg(test)] mod tests { - use std::{collections::HashMap, sync::Arc}; + use std::{collections::HashMap, sync::Arc, time::Duration}; use oauth2::{AuthType, CsrfToken, PkceCodeVerifier}; use url::Url; @@ -2594,6 +2631,252 @@ mod tests { assert_eq!(result.len(), 2); } + // -- SEP-2207: get_access_token refresh behavior -- + + fn make_token_response( + access_token: &str, + expires_in: Option, + refresh_token: Option<&str>, + ) -> super::OAuthTokenResponse { + use oauth2::{AccessToken, basic::BasicTokenType}; + let mut resp = super::OAuthTokenResponse::new( + AccessToken::new(access_token.to_string()), + BasicTokenType::Bearer, + oauth2::EmptyExtraTokenFields {}, + ); + resp.set_expires_in(expires_in.as_ref()); + if let Some(rt) = refresh_token { + resp.set_refresh_token(Some(oauth2::RefreshToken::new(rt.to_string()))); + } + resp + } + + #[tokio::test] + async fn get_access_token_returns_token_when_not_expired() { + let mgr = AuthorizationManager::new("http://localhost").await.unwrap(); + let creds = super::StoredCredentials { + client_id: "test".to_string(), + token_response: Some(make_token_response( + "valid-token", + Some(Duration::from_secs(3600)), + None, + )), + granted_scopes: vec![], + }; + mgr.credential_store.save(creds).await.unwrap(); + + let token = mgr.get_access_token().await.unwrap(); + assert_eq!(token, "valid-token"); + } + + #[tokio::test] + async fn get_access_token_returns_token_when_no_expires_in() { + let mgr = AuthorizationManager::new("http://localhost").await.unwrap(); + let creds = super::StoredCredentials { + client_id: "test".to_string(), + token_response: Some(make_token_response( + "no-expiry-token", + None, + Some("refresh-tok"), + )), + granted_scopes: vec![], + }; + mgr.credential_store.save(creds).await.unwrap(); + + let token = mgr.get_access_token().await.unwrap(); + assert_eq!(token, "no-expiry-token"); + } + + #[tokio::test] + async fn get_access_token_requires_reauth_when_expired_without_refresh_token() { + let mgr = AuthorizationManager::new("http://localhost").await.unwrap(); + let creds = super::StoredCredentials { + client_id: "test".to_string(), + token_response: Some(make_token_response( + "expired-token", + Some(Duration::from_secs(0)), + None, + )), + granted_scopes: vec![], + }; + mgr.credential_store.save(creds).await.unwrap(); + + let err = mgr.get_access_token().await.unwrap_err(); + assert!( + matches!(err, AuthError::AuthorizationRequired), + "expected AuthorizationRequired, got: {err:?}" + ); + } + + #[tokio::test] + async fn get_access_token_requires_reauth_when_refresh_fails() { + let mgr = AuthorizationManager::new("http://localhost").await.unwrap(); + let creds = super::StoredCredentials { + client_id: "test".to_string(), + token_response: Some(make_token_response( + "expired-token", + Some(Duration::from_secs(0)), + Some("bad-refresh"), + )), + granted_scopes: vec![], + }; + mgr.credential_store.save(creds).await.unwrap(); + + // No oauth_client configured, so refresh_token() will fail with InternalError. + // get_access_token should catch that and return AuthorizationRequired. + let err = mgr.get_access_token().await.unwrap_err(); + assert!( + matches!(err, AuthError::AuthorizationRequired), + "expected AuthorizationRequired on refresh failure, got: {err:?}" + ); + } + + #[tokio::test] + async fn get_access_token_requires_reauth_when_no_credentials() { + let mgr = AuthorizationManager::new("http://localhost").await.unwrap(); + + let err = mgr.get_access_token().await.unwrap_err(); + assert!( + matches!(err, AuthError::AuthorizationRequired), + "expected AuthorizationRequired with no credentials, got: {err:?}" + ); + } + + // -- SEP-2207: offline_access -- + + #[tokio::test] + async fn select_scopes_adds_offline_access_when_as_supports_it() { + let mgr = manager_with_metadata(Some(AuthorizationMetadata { + authorization_endpoint: "http://localhost/authorize".to_string(), + token_endpoint: "http://localhost/token".to_string(), + scopes_supported: Some(vec!["profile".to_string(), "offline_access".to_string()]), + ..Default::default() + })) + .await; + *mgr.resource_scopes.write().await = vec!["profile".to_string()]; + + let scopes = mgr.select_scopes(None, &[]); + assert!( + scopes.contains(&"offline_access".to_string()), + "offline_access should be added when AS supports it" + ); + assert!(scopes.contains(&"profile".to_string())); + } + + #[tokio::test] + async fn select_scopes_does_not_add_offline_access_when_as_does_not_support_it() { + let mgr = manager_with_metadata(Some(AuthorizationMetadata { + authorization_endpoint: "http://localhost/authorize".to_string(), + token_endpoint: "http://localhost/token".to_string(), + scopes_supported: Some(vec!["profile".to_string(), "email".to_string()]), + ..Default::default() + })) + .await; + *mgr.resource_scopes.write().await = vec!["profile".to_string()]; + + let scopes = mgr.select_scopes(None, &[]); + assert!( + !scopes.contains(&"offline_access".to_string()), + "offline_access should not be added when AS does not support it" + ); + } + + #[tokio::test] + async fn select_scopes_falls_back_to_defaults() { + let mgr = manager_with_metadata(Some(AuthorizationMetadata { + authorization_endpoint: "http://localhost/authorize".to_string(), + token_endpoint: "http://localhost/token".to_string(), + scopes_supported: None, + ..Default::default() + })) + .await; + + let scopes = mgr.select_scopes(None, &["default_scope"]); + assert_eq!(scopes, vec!["default_scope".to_string()]); + } + + #[tokio::test] + async fn select_scopes_does_not_duplicate_offline_access() { + let mgr = manager_with_metadata(Some(AuthorizationMetadata { + authorization_endpoint: "http://localhost/authorize".to_string(), + token_endpoint: "http://localhost/token".to_string(), + scopes_supported: Some(vec!["profile".to_string(), "offline_access".to_string()]), + ..Default::default() + })) + .await; + + // When AS metadata is the scope source and already contains offline_access, + // it should appear exactly once. + let scopes = mgr.select_scopes(None, &[]); + let count = scopes.iter().filter(|s| *s == "offline_access").count(); + assert_eq!(count, 1, "offline_access should not be duplicated"); + } + + #[tokio::test] + async fn select_scopes_adds_offline_access_to_www_authenticate_scopes() { + let mgr = manager_with_metadata(Some(AuthorizationMetadata { + authorization_endpoint: "http://localhost/authorize".to_string(), + token_endpoint: "http://localhost/token".to_string(), + scopes_supported: Some(vec!["profile".to_string(), "offline_access".to_string()]), + ..Default::default() + })) + .await; + *mgr.www_auth_scopes.write().await = vec!["profile".to_string()]; + + let scopes = mgr.select_scopes(None, &[]); + assert!(scopes.contains(&"offline_access".to_string())); + assert!(scopes.contains(&"profile".to_string())); + } + + #[tokio::test] + async fn select_scopes_adds_offline_access_to_www_authenticate_argument() { + let mgr = manager_with_metadata(Some(AuthorizationMetadata { + authorization_endpoint: "http://localhost/authorize".to_string(), + token_endpoint: "http://localhost/token".to_string(), + scopes_supported: Some(vec!["profile".to_string(), "offline_access".to_string()]), + ..Default::default() + })) + .await; + + let scopes = mgr.select_scopes(Some("profile email"), &[]); + assert!(scopes.contains(&"offline_access".to_string())); + assert!(scopes.contains(&"profile".to_string())); + assert!(scopes.contains(&"email".to_string())); + } + + #[tokio::test] + async fn add_offline_access_if_supported_works_with_explicit_scopes() { + let mgr = manager_with_metadata(Some(AuthorizationMetadata { + authorization_endpoint: "http://localhost/authorize".to_string(), + token_endpoint: "http://localhost/token".to_string(), + scopes_supported: Some(vec!["profile".to_string(), "offline_access".to_string()]), + ..Default::default() + })) + .await; + + let mut explicit = vec!["read".to_string(), "write".to_string()]; + mgr.add_offline_access_if_supported(&mut explicit); + assert!(explicit.contains(&"offline_access".to_string())); + } + + #[tokio::test] + async fn add_offline_access_if_supported_skips_empty_scopes() { + let mgr = manager_with_metadata(Some(AuthorizationMetadata { + authorization_endpoint: "http://localhost/authorize".to_string(), + token_endpoint: "http://localhost/token".to_string(), + scopes_supported: Some(vec!["profile".to_string(), "offline_access".to_string()]), + ..Default::default() + })) + .await; + + let mut empty: Vec = vec![]; + mgr.add_offline_access_if_supported(&mut empty); + assert!( + empty.is_empty(), + "offline_access should not be the only scope" + ); + } + #[test] fn scope_upgrade_config_default_values() { let config = ScopeUpgradeConfig::default();