Skip to content

Commit db6988b

Browse files
committed
Merge non-merged fix for oauth refresh tokens jlowin#2483
1 parent 10f355c commit db6988b

File tree

1 file changed

+82
-68
lines changed

1 file changed

+82
-68
lines changed

src/fastmcp/server/auth/oauth_proxy.py

Lines changed: 82 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -179,6 +179,28 @@ class JTIMapping(BaseModel):
179179
created_at: float # Unix timestamp
180180

181181

182+
class RefreshTokenMetadata(BaseModel):
183+
"""Metadata for a refresh token, stored keyed by token hash.
184+
185+
We store only metadata (not the token itself) for security - if storage
186+
is compromised, attackers get hashes they can't reverse into usable tokens.
187+
"""
188+
189+
client_id: str
190+
scopes: list[str]
191+
expires_at: int | None = None
192+
created_at: float
193+
194+
195+
def _hash_token(token: str) -> str:
196+
"""Hash a token for secure storage lookup.
197+
198+
Uses SHA-256 to create a one-way hash. The original token cannot be
199+
recovered from the hash, providing defense in depth if storage is compromised.
200+
"""
201+
return hashlib.sha256(token.encode()).hexdigest()
202+
203+
182204
class ProxyDCRClient(OAuthClientInformationFull):
183205
"""Client for DCR proxy with configurable redirect URI validation.
184206
@@ -624,14 +646,18 @@ class OAuthProxy(OAuthProvider):
624646
625647
State Management
626648
---------------
627-
The proxy maintains minimal but crucial state:
649+
The proxy maintains minimal but crucial state via pluggable storage (client_storage):
628650
- _oauth_transactions: Active authorization flows with client context
629651
- _client_codes: Authorization codes with PKCE challenges and upstream tokens
630-
- _access_tokens, _refresh_tokens: Token storage for revocation
631-
- Token relationship mappings for cleanup and rotation
652+
- _jti_mapping_store: Maps FastMCP token JTIs to upstream token IDs
653+
- _refresh_token_store: Refresh token metadata (keyed by token hash)
654+
655+
All state is stored in the configured client_storage backend (Redis, disk, etc.)
656+
enabling horizontal scaling across multiple instances.
632657
633658
Security Considerations
634659
----------------------
660+
- Refresh tokens stored by hash only (defense in depth if storage compromised)
635661
- PKCE enforced end-to-end (client to proxy, proxy to upstream)
636662
- Authorization codes are single-use with short expiry
637663
- Transaction IDs are cryptographically random
@@ -895,13 +921,17 @@ def __init__(
895921
raise_on_validation_error=True,
896922
)
897923

898-
# Local state for token bookkeeping only (no client caching)
899-
self._access_tokens: dict[str, AccessToken] = {}
900-
self._refresh_tokens: dict[str, RefreshToken] = {}
901-
902-
# Token relation mappings for cleanup
903-
self._access_to_refresh: dict[str, str] = {}
904-
self._refresh_to_access: dict[str, str] = {}
924+
# Refresh token metadata storage, keyed by token hash for security.
925+
# We only store metadata (not the token itself) - if storage is compromised,
926+
# attackers get hashes they can't reverse into usable tokens.
927+
self._refresh_token_store: PydanticAdapter[RefreshTokenMetadata] = (
928+
PydanticAdapter[RefreshTokenMetadata](
929+
key_value=self._client_storage,
930+
pydantic_model=RefreshTokenMetadata,
931+
default_collection="mcp-refresh-tokens",
932+
raise_on_validation_error=True,
933+
)
934+
)
905935

906936
# Use the provided token validator
907937
self._token_validator: TokenVerifier = token_verifier
@@ -1254,25 +1284,18 @@ async def exchange_authorization_code(
12541284
ttl=60 * 60 * 24 * 30, # Auto-expire with refresh token (30 days)
12551285
)
12561286

1257-
# Store FastMCP access token for MCP framework validation
1258-
self._access_tokens[fastmcp_access_token] = AccessToken(
1259-
token=fastmcp_access_token,
1260-
client_id=client.client_id,
1261-
scopes=authorization_code.scopes,
1262-
expires_at=int(time.time() + expires_in),
1263-
)
1264-
1265-
# Store FastMCP refresh token if provided
1287+
# Store refresh token metadata (keyed by hash for security)
12661288
if fastmcp_refresh_token:
1267-
self._refresh_tokens[fastmcp_refresh_token] = RefreshToken(
1268-
token=fastmcp_refresh_token,
1269-
client_id=client.client_id,
1270-
scopes=authorization_code.scopes,
1271-
expires_at=None,
1289+
await self._refresh_token_store.put(
1290+
key=_hash_token(fastmcp_refresh_token),
1291+
value=RefreshTokenMetadata(
1292+
client_id=client.client_id,
1293+
scopes=authorization_code.scopes,
1294+
expires_at=None,
1295+
created_at=time.time(),
1296+
),
1297+
ttl=60 * 60 * 24 * 30, # 30 days
12721298
)
1273-
# Maintain token relationships for cleanup
1274-
self._access_to_refresh[fastmcp_access_token] = fastmcp_refresh_token
1275-
self._refresh_to_access[fastmcp_refresh_token] = fastmcp_access_token
12761299

12771300
logger.debug(
12781301
"Issued FastMCP tokens for client=%s (access_jti=%s, refresh_jti=%s)",
@@ -1316,8 +1339,20 @@ async def load_refresh_token(
13161339
client: OAuthClientInformationFull,
13171340
refresh_token: str,
13181341
) -> RefreshToken | None:
1319-
"""Load refresh token from local storage."""
1320-
return self._refresh_tokens.get(refresh_token)
1342+
"""Load refresh token metadata from distributed storage.
1343+
1344+
Looks up by token hash and reconstructs the RefreshToken object.
1345+
"""
1346+
token_hash = _hash_token(refresh_token)
1347+
metadata = await self._refresh_token_store.get(key=token_hash)
1348+
if not metadata:
1349+
return None
1350+
return RefreshToken(
1351+
token=refresh_token,
1352+
client_id=metadata.client_id,
1353+
scopes=metadata.scopes,
1354+
expires_at=metadata.expires_at,
1355+
)
13211356

13221357
async def exchange_refresh_token(
13231358
self,
@@ -1488,30 +1523,20 @@ async def exchange_refresh_token(
14881523
"Rotated refresh token (old JTI invalidated - one-time use enforced)"
14891524
)
14901525

1491-
# Update local token tracking
1492-
self._access_tokens[new_fastmcp_access] = AccessToken(
1493-
token=new_fastmcp_access,
1494-
client_id=client.client_id,
1495-
scopes=scopes,
1496-
expires_at=int(time.time() + new_expires_in),
1497-
)
1498-
self._refresh_tokens[new_fastmcp_refresh] = RefreshToken(
1499-
token=new_fastmcp_refresh,
1500-
client_id=client.client_id,
1501-
scopes=scopes,
1502-
expires_at=None,
1526+
# Store new refresh token metadata (keyed by hash)
1527+
await self._refresh_token_store.put(
1528+
key=_hash_token(new_fastmcp_refresh),
1529+
value=RefreshTokenMetadata(
1530+
client_id=client.client_id,
1531+
scopes=scopes,
1532+
expires_at=None,
1533+
created_at=time.time(),
1534+
),
1535+
ttl=refresh_ttl,
15031536
)
15041537

1505-
# Update token relationship mappings
1506-
self._access_to_refresh[new_fastmcp_access] = new_fastmcp_refresh
1507-
self._refresh_to_access[new_fastmcp_refresh] = new_fastmcp_access
1508-
1509-
# Clean up old token from in-memory tracking
1510-
self._refresh_tokens.pop(refresh_token.token, None)
1511-
old_access = self._refresh_to_access.pop(refresh_token.token, None)
1512-
if old_access:
1513-
self._access_tokens.pop(old_access, None)
1514-
self._access_to_refresh.pop(old_access, None)
1538+
# Delete old refresh token (by hash)
1539+
await self._refresh_token_store.delete(key=_hash_token(refresh_token.token))
15151540

15161541
logger.info(
15171542
"Issued new FastMCP tokens (rotated refresh) for client=%s (access_jti=%s, refresh_jti=%s)",
@@ -1592,24 +1617,13 @@ async def load_access_token(self, token: str) -> AccessToken | None:
15921617
async def revoke_token(self, token: AccessToken | RefreshToken) -> None:
15931618
"""Revoke token locally and with upstream server if supported.
15941619
1595-
Removes tokens from local storage and attempts to revoke them with
1596-
the upstream server if a revocation endpoint is configured.
1620+
For refresh tokens, removes from local storage by hash.
1621+
For all tokens, attempts upstream revocation if endpoint is configured.
1622+
Access token JTI mappings expire via TTL.
15971623
"""
1598-
# Clean up local token storage
1599-
if isinstance(token, AccessToken):
1600-
self._access_tokens.pop(token.token, None)
1601-
# Also remove associated refresh token
1602-
paired_refresh = self._access_to_refresh.pop(token.token, None)
1603-
if paired_refresh:
1604-
self._refresh_tokens.pop(paired_refresh, None)
1605-
self._refresh_to_access.pop(paired_refresh, None)
1606-
else: # RefreshToken
1607-
self._refresh_tokens.pop(token.token, None)
1608-
# Also remove associated access token
1609-
paired_access = self._refresh_to_access.pop(token.token, None)
1610-
if paired_access:
1611-
self._access_tokens.pop(paired_access, None)
1612-
self._access_to_refresh.pop(paired_access, None)
1624+
# For refresh tokens, delete from local storage by hash
1625+
if isinstance(token, RefreshToken):
1626+
await self._refresh_token_store.delete(key=_hash_token(token.token))
16131627

16141628
# Attempt upstream revocation if endpoint is configured
16151629
if self._upstream_revocation_endpoint:

0 commit comments

Comments
 (0)