11import logging
2- from typing import Any
2+ from typing import Any , Optional
33
44import requests
5- from config import CLIENT_ID , ISSUER_URI , LOGGER , PUBLIC_ISSUER_URI
5+ from config import CLIENT_ID , FRONTEND_CLIENT_ID , ISSUER_URI , LOGGER , PUBLIC_ISSUER_URI
66from fastapi import Request
77from fastapi .responses import RedirectResponse
88from jose import jwk , jwt
@@ -21,11 +21,11 @@ def delete_auth_cookie(response):
2121 path = "/" ,
2222 secure = True ,
2323 httponly = True ,
24- samesite = "none " ,
24+ samesite = "lax " ,
2525 )
2626
2727
28- def get_user_payload (connection : HTTPConnection ) -> dict | None :
28+ def get_user_payload (connection : HTTPConnection ) -> Optional [ dict ] :
2929 token = get_token_source (connection )
3030
3131 if not token :
@@ -34,7 +34,7 @@ def get_user_payload(connection: HTTPConnection) -> dict | None:
3434 try :
3535 return verify_token (token )
3636 except Exception as e :
37- logger .warning (f"Token validation failed: { e } " )
37+ logger .warning (f"Auth failed for token : { e } " )
3838 return None
3939
4040
@@ -46,27 +46,33 @@ def get_public_key(token: str) -> Any:
4646 kid = header .get ("kid" )
4747
4848 if not kid :
49- raise Exception ("Token header missing kid" )
49+ raise Exception ("Token header missing ' kid' field " )
5050
5151 if kid in jwks_cache :
5252 return jwks_cache [kid ]
5353
5454 jwks_uri = f"{ ISSUER_URI } /protocol/openid-connect/certs"
5555
56- response = requests .get (jwks_uri , timeout = 5 )
57- response .raise_for_status ()
58- jwks = response .json ()
56+ try :
57+ response = requests .get (jwks_uri , timeout = 5 )
58+ response .raise_for_status ()
59+ jwks = response .json ()
60+ except Exception as net_err :
61+ logger .error (f"Failed to fetch JWKS from { jwks_uri } : { net_err } " )
62+ raise Exception (
63+ "Could not reach authentication server to verify token" ,
64+ )
5965
6066 for key_data in jwks .get ("keys" , []):
6167 if key_data .get ("kid" ) == kid :
6268 key = jwk .construct (key_data )
6369 jwks_cache [kid ] = key
6470 return key
6571
66- raise Exception ("Public key not found in JWKS" )
72+ raise Exception (f "Public key (kid= { kid } ) not found in JWKS" )
6773
6874 except Exception as e :
69- logger .error (f"Auth Error : { e } " )
75+ logger .error (f"Key retrieval error : { e } " )
7076 raise e
7177
7278
@@ -89,27 +95,34 @@ def verify_token(token: str) -> dict:
8995 elif aud is None :
9096 aud = []
9197
92- if azp and azp == CLIENT_ID :
98+ if azp and azp == CLIENT_ID or azp == FRONTEND_CLIENT_ID :
9399 return payload
94100
95- if CLIENT_ID in aud :
101+ if CLIENT_ID in aud or FRONTEND_CLIENT_ID in aud :
96102 return payload
97103
98- raise Exception (
99- f"Invalid audience/azp. Expected { CLIENT_ID } , got azp={ azp } , aud={ aud } " ,
104+ error_msg = (
105+ f"Audience/AZP mismatch. "
106+ f"Configured CLIENT_ID='{ CLIENT_ID } '. "
107+ f"Token azp='{ azp } ', aud='{ aud } '."
100108 )
109+ logger .warning (error_msg )
110+ raise Exception (error_msg )
111+
101112 except jwt .ExpiredSignatureError :
102113 raise Exception ("Token has expired" )
103114 except jwt .JWTError as e :
104- raise Exception (f"Invalid token: { e !s} " )
115+ raise Exception (f"Invalid token format or signature : { e !s} " )
105116 except Exception as e :
106- raise Exception (f"Authentication failed: { e !s} " )
117+ raise Exception (f"{ e !s} " )
107118
108119
109120def get_token_source (connection : HTTPConnection ) -> str | None :
110121 auth_header = connection .headers .get ("authorization" )
111- if auth_header and auth_header .startswith ("Bearer " ):
112- return auth_header .split (" " )[1 ]
122+ if auth_header :
123+ parts = auth_header .split ()
124+ if len (parts ) == 2 and parts [0 ].lower () == "bearer" :
125+ return parts [1 ]
113126
114127 return connection .cookies .get (AUTH_COOKIE_NAME )
115128
0 commit comments