Skip to content

Commit bebdd13

Browse files
committed
update authentication backend to respect cookies and bearer authorization
1 parent b26a55f commit bebdd13

File tree

2 files changed

+196
-78
lines changed

2 files changed

+196
-78
lines changed

backend/auth.py

Lines changed: 17 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
import requests
55
from config import CLIENT_ID, ISSUER_URI, LOGGER, PUBLIC_ISSUER_URI
6-
from fastapi import HTTPException, Request, status
6+
from fastapi import Request
77
from fastapi.responses import RedirectResponse
88
from jose import jwk, jwt
99
from starlette.requests import HTTPConnection
@@ -25,37 +25,17 @@ def delete_auth_cookie(response):
2525
)
2626

2727

28-
async def authenticate_connection(connection, token: str | None):
29-
user_payload = None
28+
def get_user_payload(connection: HTTPConnection) -> dict | None:
29+
token = get_token_source(connection)
3030

31-
if token:
32-
try:
33-
user_payload = verify_token(token)
34-
except Exception as e:
35-
logger.warning(f"Token validation failed: {e}")
31+
if not token:
32+
return None
3633

37-
if user_payload:
38-
return user_payload
39-
40-
if connection.scope["type"] == "http":
41-
response = RedirectResponse("/login", status_code=302)
42-
delete_auth_cookie(response)
43-
44-
accept_header = connection.headers.get("accept", "")
45-
46-
if "text/html" in accept_header:
47-
raise UnauthenticatedRedirect(response=response)
48-
49-
raise HTTPException(
50-
status_code=401,
51-
detail="Not authenticated",
52-
headers={"WWW-Authenticate": "Bearer"},
53-
)
54-
55-
if connection.scope["type"] == "websocket":
56-
raise HTTPException(status_code=403, detail="Not authenticated")
57-
58-
raise RuntimeError("Unsupported connection type")
34+
try:
35+
return verify_token(token)
36+
except Exception as e:
37+
logger.warning(f"Token validation failed: {e}")
38+
return None
5939

6040

6141
def get_public_key(token: str) -> Any:
@@ -71,7 +51,6 @@ def get_public_key(token: str) -> Any:
7151
if kid in jwks_cache:
7252
return jwks_cache[kid]
7353

74-
# TODO use openid config endpoint to obtain well-known endpoints in the future
7554
jwks_uri = f"{ISSUER_URI}/protocol/openid-connect/certs"
7655

7756
response = requests.get(jwks_uri, timeout=5)
@@ -87,12 +66,8 @@ def get_public_key(token: str) -> Any:
8766
raise Exception("Public key not found in JWKS")
8867

8968
except Exception as e:
90-
print(f"Auth Error: {e}")
91-
raise HTTPException(
92-
status_code=status.HTTP_401_UNAUTHORIZED,
93-
detail="Could not validate credentials",
94-
headers={"WWW-Authenticate": "Bearer"},
95-
)
69+
logger.error(f"Auth Error: {e}")
70+
raise e
9671

9772

9873
def verify_token(token: str) -> dict:
@@ -124,16 +99,14 @@ def verify_token(token: str) -> dict:
12499
f"Invalid audience/azp. Expected {CLIENT_ID}, got azp={azp}, aud={aud}",
125100
)
126101
except jwt.ExpiredSignatureError:
127-
raise HTTPException(status_code=401, detail="Token has expired")
102+
raise Exception("Token has expired")
128103
except jwt.JWTError as e:
129-
raise HTTPException(status_code=401, detail=f"Invalid token: {e!s}")
130-
except Exception:
131-
raise HTTPException(status_code=401, detail="Authentication failed")
104+
raise Exception(f"Invalid token: {e!s}")
105+
except Exception as e:
106+
raise Exception(f"Authentication failed: {e!s}")
132107

133108

134-
async def get_token_source(
135-
connection: HTTPConnection,
136-
) -> str | None:
109+
def get_token_source(connection: HTTPConnection) -> str | None:
137110
auth_header = connection.headers.get("authorization")
138111
if auth_header and auth_header.startswith("Bearer "):
139112
return auth_header.split(" ")[1]

backend/main.py

Lines changed: 179 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,7 @@
55
from api.resolvers import Mutation, Query, Subscription
66
from auth import (
77
UnauthenticatedRedirect,
8-
authenticate_connection,
9-
get_token_source,
8+
get_user_payload,
109
unauthenticated_redirect_handler,
1110
)
1211
from config import (
@@ -19,55 +18,87 @@
1918
from database.models.user import User
2019
from database.session import get_db_session
2120
from fastapi import Depends, FastAPI, HTTPException, Request
22-
from fastapi.responses import RedirectResponse
21+
from fastapi.responses import HTMLResponse, RedirectResponse
22+
from graphql import FieldNode
2323
from sqlalchemy import select
2424
from starlette.requests import HTTPConnection
25-
from starlette.status import HTTP_400_BAD_REQUEST
2625
from strawberry import Schema
26+
from strawberry.extensions import SchemaExtension
2727
from strawberry.fastapi import GraphQLRouter
2828

2929
logger = logging.getLogger(LOGGER)
3030

3131

32+
class GlobalAuthExtension(SchemaExtension):
33+
def on_execute(self):
34+
execution_context = self.execution_context
35+
user = execution_context.context.user
36+
37+
if user:
38+
yield
39+
return
40+
41+
document = execution_context.graphql_document
42+
if document:
43+
for definition in document.definitions:
44+
if definition.kind == "operation_definition":
45+
for selection in definition.selection_set.selections:
46+
if not isinstance(selection, FieldNode):
47+
raise HTTPException(
48+
status_code=401,
49+
detail="Not authenticated",
50+
)
51+
if not selection.name.value.startswith("__"):
52+
raise HTTPException(
53+
status_code=401,
54+
detail="Not authenticated",
55+
)
56+
57+
yield
58+
59+
3260
async def get_context(
3361
connection: HTTPConnection,
34-
token: str | None = Depends(get_token_source),
3562
session=Depends(get_db_session),
3663
) -> Context:
37-
user_payload = await authenticate_connection(connection, token)
64+
user_payload = get_user_payload(connection)
3865

39-
user_id = user_payload.get("sub")
40-
username = user_payload.get("preferred_username") or user_payload.get(
41-
"name",
42-
)
43-
firstname = user_payload.get("given_name")
44-
lastname = user_payload.get("family_name")
66+
db_user = None
4567

46-
if not (user_id and username and firstname and lastname):
47-
raise HTTPException(
48-
status_code=HTTP_400_BAD_REQUEST,
49-
detail="Missing required user details.",
68+
if user_payload:
69+
user_id = user_payload.get("sub")
70+
username = user_payload.get("preferred_username") or user_payload.get(
71+
"name",
5072
)
73+
firstname = user_payload.get("given_name")
74+
lastname = user_payload.get("family_name")
5175

52-
result = await session.execute(select(User).where(User.id == user_id))
53-
db_user = result.scalars().first()
76+
if user_id:
77+
result = await session.execute(
78+
select(User).where(User.id == user_id),
79+
)
80+
db_user = result.scalars().first()
5481

55-
if not db_user:
56-
db_user = User(
57-
id=user_id,
58-
name=username,
59-
firstname=firstname,
60-
lastname=lastname,
61-
title="User",
62-
)
63-
session.add(db_user)
64-
elif db_user.name != username or db_user.firstname != firstname:
65-
db_user.name = username
66-
db_user.firstname = firstname
67-
db_user.lastname = lastname
82+
if not db_user:
83+
db_user = User(
84+
id=user_id,
85+
name=username,
86+
firstname=firstname,
87+
lastname=lastname,
88+
title="User",
89+
)
90+
session.add(db_user)
91+
elif (
92+
db_user.name != username
93+
or db_user.firstname != firstname
94+
or db_user.lastname != lastname
95+
):
96+
db_user.name = username
97+
db_user.firstname = firstname
98+
db_user.lastname = lastname
6899

69-
await session.commit()
70-
await session.refresh(db_user)
100+
await session.commit()
101+
await session.refresh(db_user)
71102

72103
return Context(db=session, user=db_user)
73104

@@ -76,8 +107,107 @@ async def get_context(
76107
query=Query,
77108
mutation=Mutation,
78109
subscription=Subscription,
110+
extensions=[GlobalAuthExtension],
79111
)
80-
graphql_app = GraphQLRouter(
112+
113+
114+
class AuthedGraphQLRouter(GraphQLRouter):
115+
async def render_graphql_ide(self, request: Request) -> HTMLResponse:
116+
response = await super().render_graphql_ide(request)
117+
118+
redirect_uri = f"{request.base_url}callback"
119+
login_url = (
120+
f"{ISSUER_URI}/protocol/openid-connect/auth"
121+
f"?client_id={CLIENT_ID}"
122+
f"&response_type=code"
123+
f"&scope=openid profile email"
124+
f"&redirect_uri={redirect_uri}"
125+
)
126+
logout_url = f"{request.base_url}logout"
127+
128+
is_authenticated = (
129+
"true" if request.cookies.get("access_token") else "false"
130+
)
131+
132+
injection_script = f"""
133+
<script>
134+
(function() {{
135+
var loginUrl = "{login_url}";
136+
var logoutUrl = "{logout_url}";
137+
var isAuthenticated = {is_authenticated};
138+
139+
function injectAuthButton() {{
140+
var sidebars = document.querySelectorAll('.graphiql-sidebar-section');
141+
var sidebar = sidebars[0];
142+
143+
if (sidebar && !document.getElementById('custom-auth-button')) {{
144+
var button = document.createElement('button');
145+
button.id = 'custom-auth-button';
146+
button.className = 'graphiql-un-styled';
147+
button.type = 'button';
148+
149+
if (isAuthenticated) {{
150+
button.setAttribute('aria-label', 'Logout');
151+
button.title = 'Logout';
152+
button.innerHTML = `
153+
<svg height="1em" viewBox="0 0 24 24" fill="none"
154+
stroke="currentColor" stroke-width="1.5"
155+
stroke-linecap="round" stroke-linejoin="round"
156+
xmlns="http://www.w3.org/2000/svg">
157+
<path d="M9 21H5a2 2 0 0 1-2-2V5a2 2 0 0 1 2-2h4"></path>
158+
<polyline points="16 17 21 12 16 7"></polyline>
159+
<line x1="21" y1="12" x2="9" y2="12"></line>
160+
</svg>
161+
`;
162+
button.onclick = function() {{
163+
window.location.href = logoutUrl;
164+
}};
165+
}} else {{
166+
button.setAttribute('aria-label', 'Login with OIDC');
167+
button.title = 'Login with OIDC';
168+
button.innerHTML = `
169+
<svg height="1em" viewBox="0 0 24 24" fill="none"
170+
stroke="currentColor" stroke-width="1.5"
171+
stroke-linecap="round" stroke-linejoin="round"
172+
xmlns="http://www.w3.org/2000/svg">
173+
<path d="M15 3h4a2 2 0 0 1 2 2v14a2 2 0 0 1-2 2h-4"></path>
174+
<polyline points="10 17 15 12 10 7"></polyline>
175+
<line x1="15" y1="12" x2="3" y2="12"></line>
176+
</svg>
177+
`;
178+
button.onclick = function() {{
179+
window.location.href = loginUrl;
180+
}};
181+
}}
182+
183+
sidebar.appendChild(button);
184+
}}
185+
}}
186+
187+
var observer = new MutationObserver(function(mutations) {{
188+
injectAuthButton();
189+
}});
190+
191+
observer.observe(document.body, {{
192+
childList: true,
193+
subtree: true
194+
}});
195+
196+
injectAuthButton();
197+
}})();
198+
</script>
199+
"""
200+
201+
html_content = response.body.decode("utf-8")
202+
new_html = html_content.replace(
203+
"</body>",
204+
f"{injection_script}</body>",
205+
)
206+
207+
return HTMLResponse(new_html)
208+
209+
210+
graphql_app = AuthedGraphQLRouter(
81211
schema,
82212
context_getter=get_context,
83213
graphql_ide=IS_DEV,
@@ -97,6 +227,19 @@ async def get_context(
97227
)
98228

99229

230+
@app.get("/logout")
231+
def logout(_: Request):
232+
response = RedirectResponse(url="/graphql")
233+
response.delete_cookie(
234+
"access_token",
235+
path="/",
236+
secure=True,
237+
httponly=True,
238+
samesite="lax",
239+
)
240+
return response
241+
242+
100243
@app.get("/callback")
101244
def oauth_callback(code: str, request: Request):
102245
token_endpoint = f"{ISSUER_URI}/protocol/openid-connect/token"
@@ -123,6 +266,8 @@ def oauth_callback(code: str, request: Request):
123266
access_token,
124267
httponly=True,
125268
max_age=3600,
269+
samesite="lax",
270+
secure=True,
126271
)
127272
return resp
128273

0 commit comments

Comments
 (0)