Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
311 changes: 297 additions & 14 deletions crates/rmcp/src/transport/auth.rs
Original file line number Diff line number Diff line change
Expand Up @@ -789,12 +789,23 @@ impl AuthorizationManager {
attempts < self.scope_upgrade_config.max_upgrade_attempts
}

/// select scopes to request from authorization server
pub fn select_scopes(
&self,
www_authenticate_scope: Option<&str>,
default_scopes: &[&str],
) -> Vec<String> {
let mut scopes = self.select_base_scopes(www_authenticate_scope, default_scopes);
self.add_offline_access_if_supported(&mut scopes);
scopes
}

/// select scopes based on SEP-835 priority:
/// 1. scope from WWW-Authenticate header (argument or stored from initial 401 probe)
/// 2. scopes_supported from protected resource metadata (RFC 9728)
/// 3. scopes_supported from authorization server metadata
/// 4. provided default scopes
pub fn select_scopes(
fn select_base_scopes(
&self,
www_authenticate_scope: Option<&str>,
default_scopes: &[&str],
Expand Down Expand Up @@ -829,6 +840,21 @@ impl AuthorizationManager {
default_scopes.iter().map(|s| s.to_string()).collect()
}

/// SEP-2207: when the AS advertises `offline_access` in `scopes_supported`, append
/// it so OIDC-flavored Authorization Servers will issue refresh tokens.
fn add_offline_access_if_supported(&self, scopes: &mut Vec<String>) {
if scopes.is_empty() || scopes.iter().any(|s| s == "offline_access") {
return;
}
if let Some(metadata) = &self.metadata {
if let Some(supported) = &metadata.scopes_supported {
if supported.iter().any(|s| s == "offline_access") {
scopes.push("offline_access".to_string());
}
}
}
}

/// attempt to upgrade scopes after receiving a 403 insufficient_scope error.
/// returns the authorization URL for re-authorization with upgraded scopes.
pub async fn request_scope_upgrade(&self, required_scope: &str) -> Result<String, AuthError> {
Expand Down Expand Up @@ -949,22 +975,32 @@ impl AuthorizationManager {
Ok(token_result)
}

/// get access token, if expired, refresh it automatically
/// get access token from local credential store.
/// if expired, refresh it automatically when a refresh token is available.
/// when the access token has expired and no refresh token is available, it returns
/// [`AuthError::AuthorizationRequired`] so the caller can re-authenticate.
pub async fn get_access_token(&self) -> Result<String, AuthError> {
// Load credentials from store
let stored = self.credential_store.load().await?;
let credentials = stored.and_then(|s| s.token_response);

if let Some(creds) = credentials.as_ref() {
// check token expiry if we have a refresh token or an expiry time
if creds.refresh_token().is_some() || creds.expires_in().is_some() {
let expires_in = creds.expires_in().unwrap_or(Duration::from_secs(0));
if let Some(expires_in) = creds.expires_in() {
if expires_in <= Duration::from_secs(0) {
tracing::info!("Access token expired, refreshing.");

let new_creds = self.refresh_token().await?;
tracing::info!("Refreshed access token.");
return Ok(new_creds.access_token().secret().to_string());
if creds.refresh_token().is_some() {
tracing::info!("Access token expired, attempting refresh.");
match self.refresh_token().await {
Ok(new_creds) => {
tracing::info!("Refreshed access token.");
return Ok(new_creds.access_token().secret().to_string());
}
Err(e) => {
tracing::warn!("Token refresh failed: {e}");
}
}
} else {
tracing::info!("Access token expired and no refresh token available.");
}
return Err(AuthError::AuthorizationRequired);
}
}

Expand All @@ -991,7 +1027,6 @@ impl AuthorizationManager {
let refresh_token = current_credentials.refresh_token().ok_or_else(|| {
AuthError::TokenRefreshFailed("No refresh token available".to_string())
})?;
debug!("refresh token: {:?}", refresh_token);

let token_result = oauth_client
.exchange_refresh_token(&RefreshToken::new(refresh_token.secret().to_string()))
Expand Down Expand Up @@ -1664,7 +1699,9 @@ impl OAuthState {
let selected_scopes: Vec<String> = if scopes.is_empty() {
manager.select_scopes(None, &[])
} else {
scopes.iter().map(|s| s.to_string()).collect()
let mut s: Vec<String> = scopes.iter().map(|s| s.to_string()).collect();
manager.add_offline_access_if_supported(&mut s);
s
};
let scope_refs: Vec<&str> = selected_scopes.iter().map(|s| s.as_str()).collect();
debug!("start session");
Expand Down Expand Up @@ -1823,7 +1860,7 @@ impl OAuthState {

#[cfg(test)]
mod tests {
use std::{collections::HashMap, sync::Arc};
use std::{collections::HashMap, sync::Arc, time::Duration};

use oauth2::{AuthType, CsrfToken, PkceCodeVerifier};
use url::Url;
Expand Down Expand Up @@ -2594,6 +2631,252 @@ mod tests {
assert_eq!(result.len(), 2);
}

// -- SEP-2207: get_access_token refresh behavior --

fn make_token_response(
access_token: &str,
expires_in: Option<Duration>,
refresh_token: Option<&str>,
) -> super::OAuthTokenResponse {
use oauth2::{AccessToken, basic::BasicTokenType};
let mut resp = super::OAuthTokenResponse::new(
AccessToken::new(access_token.to_string()),
BasicTokenType::Bearer,
oauth2::EmptyExtraTokenFields {},
);
resp.set_expires_in(expires_in.as_ref());
if let Some(rt) = refresh_token {
resp.set_refresh_token(Some(oauth2::RefreshToken::new(rt.to_string())));
}
resp
}

#[tokio::test]
async fn get_access_token_returns_token_when_not_expired() {
let mgr = AuthorizationManager::new("http://localhost").await.unwrap();
let creds = super::StoredCredentials {
client_id: "test".to_string(),
token_response: Some(make_token_response(
"valid-token",
Some(Duration::from_secs(3600)),
None,
)),
granted_scopes: vec![],
};
mgr.credential_store.save(creds).await.unwrap();

let token = mgr.get_access_token().await.unwrap();
assert_eq!(token, "valid-token");
}

#[tokio::test]
async fn get_access_token_returns_token_when_no_expires_in() {
let mgr = AuthorizationManager::new("http://localhost").await.unwrap();
let creds = super::StoredCredentials {
client_id: "test".to_string(),
token_response: Some(make_token_response(
"no-expiry-token",
None,
Some("refresh-tok"),
)),
granted_scopes: vec![],
};
mgr.credential_store.save(creds).await.unwrap();

let token = mgr.get_access_token().await.unwrap();
assert_eq!(token, "no-expiry-token");
}

#[tokio::test]
async fn get_access_token_requires_reauth_when_expired_without_refresh_token() {
let mgr = AuthorizationManager::new("http://localhost").await.unwrap();
let creds = super::StoredCredentials {
client_id: "test".to_string(),
token_response: Some(make_token_response(
"expired-token",
Some(Duration::from_secs(0)),
None,
)),
granted_scopes: vec![],
};
mgr.credential_store.save(creds).await.unwrap();

let err = mgr.get_access_token().await.unwrap_err();
assert!(
matches!(err, AuthError::AuthorizationRequired),
"expected AuthorizationRequired, got: {err:?}"
);
}

#[tokio::test]
async fn get_access_token_requires_reauth_when_refresh_fails() {
let mgr = AuthorizationManager::new("http://localhost").await.unwrap();
let creds = super::StoredCredentials {
client_id: "test".to_string(),
token_response: Some(make_token_response(
"expired-token",
Some(Duration::from_secs(0)),
Some("bad-refresh"),
)),
granted_scopes: vec![],
};
mgr.credential_store.save(creds).await.unwrap();

// No oauth_client configured, so refresh_token() will fail with InternalError.
// get_access_token should catch that and return AuthorizationRequired.
let err = mgr.get_access_token().await.unwrap_err();
assert!(
matches!(err, AuthError::AuthorizationRequired),
"expected AuthorizationRequired on refresh failure, got: {err:?}"
);
}

#[tokio::test]
async fn get_access_token_requires_reauth_when_no_credentials() {
let mgr = AuthorizationManager::new("http://localhost").await.unwrap();

let err = mgr.get_access_token().await.unwrap_err();
assert!(
matches!(err, AuthError::AuthorizationRequired),
"expected AuthorizationRequired with no credentials, got: {err:?}"
);
}

// -- SEP-2207: offline_access --

#[tokio::test]
async fn select_scopes_adds_offline_access_when_as_supports_it() {
let mgr = manager_with_metadata(Some(AuthorizationMetadata {
authorization_endpoint: "http://localhost/authorize".to_string(),
token_endpoint: "http://localhost/token".to_string(),
scopes_supported: Some(vec!["profile".to_string(), "offline_access".to_string()]),
..Default::default()
}))
.await;
*mgr.resource_scopes.write().await = vec!["profile".to_string()];

let scopes = mgr.select_scopes(None, &[]);
assert!(
scopes.contains(&"offline_access".to_string()),
"offline_access should be added when AS supports it"
);
assert!(scopes.contains(&"profile".to_string()));
}

#[tokio::test]
async fn select_scopes_does_not_add_offline_access_when_as_does_not_support_it() {
let mgr = manager_with_metadata(Some(AuthorizationMetadata {
authorization_endpoint: "http://localhost/authorize".to_string(),
token_endpoint: "http://localhost/token".to_string(),
scopes_supported: Some(vec!["profile".to_string(), "email".to_string()]),
..Default::default()
}))
.await;
*mgr.resource_scopes.write().await = vec!["profile".to_string()];

let scopes = mgr.select_scopes(None, &[]);
assert!(
!scopes.contains(&"offline_access".to_string()),
"offline_access should not be added when AS does not support it"
);
}

#[tokio::test]
async fn select_scopes_falls_back_to_defaults() {
let mgr = manager_with_metadata(Some(AuthorizationMetadata {
authorization_endpoint: "http://localhost/authorize".to_string(),
token_endpoint: "http://localhost/token".to_string(),
scopes_supported: None,
..Default::default()
}))
.await;

let scopes = mgr.select_scopes(None, &["default_scope"]);
assert_eq!(scopes, vec!["default_scope".to_string()]);
}

#[tokio::test]
async fn select_scopes_does_not_duplicate_offline_access() {
let mgr = manager_with_metadata(Some(AuthorizationMetadata {
authorization_endpoint: "http://localhost/authorize".to_string(),
token_endpoint: "http://localhost/token".to_string(),
scopes_supported: Some(vec!["profile".to_string(), "offline_access".to_string()]),
..Default::default()
}))
.await;

// When AS metadata is the scope source and already contains offline_access,
// it should appear exactly once.
let scopes = mgr.select_scopes(None, &[]);
let count = scopes.iter().filter(|s| *s == "offline_access").count();
assert_eq!(count, 1, "offline_access should not be duplicated");
}

#[tokio::test]
async fn select_scopes_adds_offline_access_to_www_authenticate_scopes() {
let mgr = manager_with_metadata(Some(AuthorizationMetadata {
authorization_endpoint: "http://localhost/authorize".to_string(),
token_endpoint: "http://localhost/token".to_string(),
scopes_supported: Some(vec!["profile".to_string(), "offline_access".to_string()]),
..Default::default()
}))
.await;
*mgr.www_auth_scopes.write().await = vec!["profile".to_string()];

let scopes = mgr.select_scopes(None, &[]);
assert!(scopes.contains(&"offline_access".to_string()));
assert!(scopes.contains(&"profile".to_string()));
}

#[tokio::test]
async fn select_scopes_adds_offline_access_to_www_authenticate_argument() {
let mgr = manager_with_metadata(Some(AuthorizationMetadata {
authorization_endpoint: "http://localhost/authorize".to_string(),
token_endpoint: "http://localhost/token".to_string(),
scopes_supported: Some(vec!["profile".to_string(), "offline_access".to_string()]),
..Default::default()
}))
.await;

let scopes = mgr.select_scopes(Some("profile email"), &[]);
assert!(scopes.contains(&"offline_access".to_string()));
assert!(scopes.contains(&"profile".to_string()));
assert!(scopes.contains(&"email".to_string()));
}

#[tokio::test]
async fn add_offline_access_if_supported_works_with_explicit_scopes() {
let mgr = manager_with_metadata(Some(AuthorizationMetadata {
authorization_endpoint: "http://localhost/authorize".to_string(),
token_endpoint: "http://localhost/token".to_string(),
scopes_supported: Some(vec!["profile".to_string(), "offline_access".to_string()]),
..Default::default()
}))
.await;

let mut explicit = vec!["read".to_string(), "write".to_string()];
mgr.add_offline_access_if_supported(&mut explicit);
assert!(explicit.contains(&"offline_access".to_string()));
}

#[tokio::test]
async fn add_offline_access_if_supported_skips_empty_scopes() {
let mgr = manager_with_metadata(Some(AuthorizationMetadata {
authorization_endpoint: "http://localhost/authorize".to_string(),
token_endpoint: "http://localhost/token".to_string(),
scopes_supported: Some(vec!["profile".to_string(), "offline_access".to_string()]),
..Default::default()
}))
.await;

let mut empty: Vec<String> = vec![];
mgr.add_offline_access_if_supported(&mut empty);
assert!(
empty.is_empty(),
"offline_access should not be the only scope"
);
}

#[test]
fn scope_upgrade_config_default_values() {
let config = ScopeUpgradeConfig::default();
Expand Down