@@ -179,6 +179,28 @@ class JTIMapping(BaseModel):
179179 created_at : float # Unix timestamp
180180
181181
182+ class RefreshTokenMetadata (BaseModel ):
183+ """Metadata for a refresh token, stored keyed by token hash.
184+
185+ We store only metadata (not the token itself) for security - if storage
186+ is compromised, attackers get hashes they can't reverse into usable tokens.
187+ """
188+
189+ client_id : str
190+ scopes : list [str ]
191+ expires_at : int | None = None
192+ created_at : float
193+
194+
195+ def _hash_token (token : str ) -> str :
196+ """Hash a token for secure storage lookup.
197+
198+ Uses SHA-256 to create a one-way hash. The original token cannot be
199+ recovered from the hash, providing defense in depth if storage is compromised.
200+ """
201+ return hashlib .sha256 (token .encode ()).hexdigest ()
202+
203+
182204class ProxyDCRClient (OAuthClientInformationFull ):
183205 """Client for DCR proxy with configurable redirect URI validation.
184206
@@ -624,14 +646,18 @@ class OAuthProxy(OAuthProvider):
624646
625647 State Management
626648 ---------------
627- The proxy maintains minimal but crucial state:
649+ The proxy maintains minimal but crucial state via pluggable storage (client_storage) :
628650 - _oauth_transactions: Active authorization flows with client context
629651 - _client_codes: Authorization codes with PKCE challenges and upstream tokens
630- - _access_tokens, _refresh_tokens: Token storage for revocation
631- - Token relationship mappings for cleanup and rotation
652+ - _jti_mapping_store: Maps FastMCP token JTIs to upstream token IDs
653+ - _refresh_token_store: Refresh token metadata (keyed by token hash)
654+
655+ All state is stored in the configured client_storage backend (Redis, disk, etc.)
656+ enabling horizontal scaling across multiple instances.
632657
633658 Security Considerations
634659 ----------------------
660+ - Refresh tokens stored by hash only (defense in depth if storage compromised)
635661 - PKCE enforced end-to-end (client to proxy, proxy to upstream)
636662 - Authorization codes are single-use with short expiry
637663 - Transaction IDs are cryptographically random
@@ -895,13 +921,17 @@ def __init__(
895921 raise_on_validation_error = True ,
896922 )
897923
898- # Local state for token bookkeeping only (no client caching)
899- self ._access_tokens : dict [str , AccessToken ] = {}
900- self ._refresh_tokens : dict [str , RefreshToken ] = {}
901-
902- # Token relation mappings for cleanup
903- self ._access_to_refresh : dict [str , str ] = {}
904- self ._refresh_to_access : dict [str , str ] = {}
924+ # Refresh token metadata storage, keyed by token hash for security.
925+ # We only store metadata (not the token itself) - if storage is compromised,
926+ # attackers get hashes they can't reverse into usable tokens.
927+ self ._refresh_token_store : PydanticAdapter [RefreshTokenMetadata ] = (
928+ PydanticAdapter [RefreshTokenMetadata ](
929+ key_value = self ._client_storage ,
930+ pydantic_model = RefreshTokenMetadata ,
931+ default_collection = "mcp-refresh-tokens" ,
932+ raise_on_validation_error = True ,
933+ )
934+ )
905935
906936 # Use the provided token validator
907937 self ._token_validator : TokenVerifier = token_verifier
@@ -1254,25 +1284,18 @@ async def exchange_authorization_code(
12541284 ttl = 60 * 60 * 24 * 30 , # Auto-expire with refresh token (30 days)
12551285 )
12561286
1257- # Store FastMCP access token for MCP framework validation
1258- self ._access_tokens [fastmcp_access_token ] = AccessToken (
1259- token = fastmcp_access_token ,
1260- client_id = client .client_id ,
1261- scopes = authorization_code .scopes ,
1262- expires_at = int (time .time () + expires_in ),
1263- )
1264-
1265- # Store FastMCP refresh token if provided
1287+ # Store refresh token metadata (keyed by hash for security)
12661288 if fastmcp_refresh_token :
1267- self ._refresh_tokens [fastmcp_refresh_token ] = RefreshToken (
1268- token = fastmcp_refresh_token ,
1269- client_id = client .client_id ,
1270- scopes = authorization_code .scopes ,
1271- expires_at = None ,
1289+ await self ._refresh_token_store .put (
1290+ key = _hash_token (fastmcp_refresh_token ),
1291+ value = RefreshTokenMetadata (
1292+ client_id = client .client_id ,
1293+ scopes = authorization_code .scopes ,
1294+ expires_at = None ,
1295+ created_at = time .time (),
1296+ ),
1297+ ttl = 60 * 60 * 24 * 30 , # 30 days
12721298 )
1273- # Maintain token relationships for cleanup
1274- self ._access_to_refresh [fastmcp_access_token ] = fastmcp_refresh_token
1275- self ._refresh_to_access [fastmcp_refresh_token ] = fastmcp_access_token
12761299
12771300 logger .debug (
12781301 "Issued FastMCP tokens for client=%s (access_jti=%s, refresh_jti=%s)" ,
@@ -1316,8 +1339,20 @@ async def load_refresh_token(
13161339 client : OAuthClientInformationFull ,
13171340 refresh_token : str ,
13181341 ) -> RefreshToken | None :
1319- """Load refresh token from local storage."""
1320- return self ._refresh_tokens .get (refresh_token )
1342+ """Load refresh token metadata from distributed storage.
1343+
1344+ Looks up by token hash and reconstructs the RefreshToken object.
1345+ """
1346+ token_hash = _hash_token (refresh_token )
1347+ metadata = await self ._refresh_token_store .get (key = token_hash )
1348+ if not metadata :
1349+ return None
1350+ return RefreshToken (
1351+ token = refresh_token ,
1352+ client_id = metadata .client_id ,
1353+ scopes = metadata .scopes ,
1354+ expires_at = metadata .expires_at ,
1355+ )
13211356
13221357 async def exchange_refresh_token (
13231358 self ,
@@ -1488,30 +1523,20 @@ async def exchange_refresh_token(
14881523 "Rotated refresh token (old JTI invalidated - one-time use enforced)"
14891524 )
14901525
1491- # Update local token tracking
1492- self ._access_tokens [new_fastmcp_access ] = AccessToken (
1493- token = new_fastmcp_access ,
1494- client_id = client .client_id ,
1495- scopes = scopes ,
1496- expires_at = int (time .time () + new_expires_in ),
1497- )
1498- self ._refresh_tokens [new_fastmcp_refresh ] = RefreshToken (
1499- token = new_fastmcp_refresh ,
1500- client_id = client .client_id ,
1501- scopes = scopes ,
1502- expires_at = None ,
1526+ # Store new refresh token metadata (keyed by hash)
1527+ await self ._refresh_token_store .put (
1528+ key = _hash_token (new_fastmcp_refresh ),
1529+ value = RefreshTokenMetadata (
1530+ client_id = client .client_id ,
1531+ scopes = scopes ,
1532+ expires_at = None ,
1533+ created_at = time .time (),
1534+ ),
1535+ ttl = refresh_ttl ,
15031536 )
15041537
1505- # Update token relationship mappings
1506- self ._access_to_refresh [new_fastmcp_access ] = new_fastmcp_refresh
1507- self ._refresh_to_access [new_fastmcp_refresh ] = new_fastmcp_access
1508-
1509- # Clean up old token from in-memory tracking
1510- self ._refresh_tokens .pop (refresh_token .token , None )
1511- old_access = self ._refresh_to_access .pop (refresh_token .token , None )
1512- if old_access :
1513- self ._access_tokens .pop (old_access , None )
1514- self ._access_to_refresh .pop (old_access , None )
1538+ # Delete old refresh token (by hash)
1539+ await self ._refresh_token_store .delete (key = _hash_token (refresh_token .token ))
15151540
15161541 logger .info (
15171542 "Issued new FastMCP tokens (rotated refresh) for client=%s (access_jti=%s, refresh_jti=%s)" ,
@@ -1592,24 +1617,13 @@ async def load_access_token(self, token: str) -> AccessToken | None:
15921617 async def revoke_token (self , token : AccessToken | RefreshToken ) -> None :
15931618 """Revoke token locally and with upstream server if supported.
15941619
1595- Removes tokens from local storage and attempts to revoke them with
1596- the upstream server if a revocation endpoint is configured.
1620+ For refresh tokens, removes from local storage by hash.
1621+ For all tokens, attempts upstream revocation if endpoint is configured.
1622+ Access token JTI mappings expire via TTL.
15971623 """
1598- # Clean up local token storage
1599- if isinstance (token , AccessToken ):
1600- self ._access_tokens .pop (token .token , None )
1601- # Also remove associated refresh token
1602- paired_refresh = self ._access_to_refresh .pop (token .token , None )
1603- if paired_refresh :
1604- self ._refresh_tokens .pop (paired_refresh , None )
1605- self ._refresh_to_access .pop (paired_refresh , None )
1606- else : # RefreshToken
1607- self ._refresh_tokens .pop (token .token , None )
1608- # Also remove associated access token
1609- paired_access = self ._refresh_to_access .pop (token .token , None )
1610- if paired_access :
1611- self ._access_tokens .pop (paired_access , None )
1612- self ._access_to_refresh .pop (paired_access , None )
1624+ # For refresh tokens, delete from local storage by hash
1625+ if isinstance (token , RefreshToken ):
1626+ await self ._refresh_token_store .delete (key = _hash_token (token .token ))
16131627
16141628 # Attempt upstream revocation if endpoint is configured
16151629 if self ._upstream_revocation_endpoint :
0 commit comments